From 95c8b9c1b77f5ae5953d7e3ce1ad009f9be6bd3b Mon Sep 17 00:00:00 2001
From: Alan Gray
Date: Tue, 8 Oct 2024 05:00:26 -0700
Subject: [PATCH 1/3] Vectorize load instructions in dmmv f16 CUDA kernel

Replaces scalar with vector load instructions, which substantially
improves performance on NVIDIA HBM GPUs, e.g. gives a 1.27X overall
speedup for Meta-Llama-3-8B-Instruct-F16 BS1 inference evaluation on
H100 SXM 80GB HBM3. On GDDR GPUs, there is a slight (1.01X) speedup.
---
 ggml/src/ggml-cuda/dmmv.cu | 37 ++++++++++++++++++++++++++++---------
 1 file changed, 28 insertions(+), 9 deletions(-)

diff --git a/ggml/src/ggml-cuda/dmmv.cu b/ggml/src/ggml-cuda/dmmv.cu
index 96a5adef5b2b5..2a9543fd1f12e 100644
--- a/ggml/src/ggml-cuda/dmmv.cu
+++ b/ggml/src/ggml-cuda/dmmv.cu
@@ -416,10 +416,11 @@ static __global__ void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx,
 
 static __device__ void convert_f16(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
     const half * x = (const half *) vx;
-
+    // load 2 halfs into register in a single instruction
+    const half2 x_reg = *((half2 *) &(x[ib + iqs]));
     // automatic half -> float type cast if dfloat == float
-    v.x = x[ib + iqs + 0];
-    v.y = x[ib + iqs + 1];
+    v.x = x_reg.x;
+    v.y = x_reg.y;
 }
 
 static constexpr __device__ dequantize_kernel_t get_dequantize_kernel(ggml_type type) {
@@ -476,13 +477,31 @@ static __global__ void dequantize_mul_mat_vec(const void * __restrict__ vx, cons
             // matrix multiplication
             // for qr = 2 the y index needs to increase by 1 per j iter because of y_offset = qk/2
 #ifdef GGML_CUDA_F16
-            tmp += __hmul2(v, {
-                y[iybs + iqs + j/qr + 0],
-                y[iybs + iqs + j/qr + y_offset]
-            });
+            if ( y_offset == 1 ) {
+                // load 2 dfloats into register in a single instruction
+                const dfloat2 y_reg = *((dfloat2 *) &(y[iybs + iqs + j/qr]));
+                tmp += __hmul2(v, {
+                    y_reg.x;
+                    y_reg.y;
+                });
+            }
+            else {
+                tmp += __hmul2(v, {
+                    y[iybs + iqs + j/qr + 0],
+                    y[iybs + iqs + j/qr + y_offset]
+                });
+            }
 #else
-            tmp += v.x * y[iybs + iqs + j/qr + 0];
-            tmp += v.y * y[iybs + iqs + j/qr + y_offset];
+            if ( y_offset == 1 ) {
+                // load 2 dfloats into register in a single instruction
+                const dfloat2 y_reg = *((dfloat2 *) &(y[iybs + iqs + j/qr]));
+                tmp += v.x * y_reg.x;
+                tmp += v.y * y_reg.y;
+            }
+            else {
+                tmp += v.x * y[iybs + iqs + j/qr + 0];
+                tmp += v.y * y[iybs + iqs + j/qr + y_offset];
+            }
 #endif // GGML_CUDA_F16
         }
     }

From d07dc44c6320c8ee0d7f57416d7808907377119a Mon Sep 17 00:00:00 2001
From: Alan Gray
Date: Thu, 10 Oct 2024 06:05:12 -0700
Subject: [PATCH 2/3] addressed comment

---
 ggml/src/ggml-cuda/dmmv.cu | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/ggml/src/ggml-cuda/dmmv.cu b/ggml/src/ggml-cuda/dmmv.cu
index 2a9543fd1f12e..b727d4ff01bac 100644
--- a/ggml/src/ggml-cuda/dmmv.cu
+++ b/ggml/src/ggml-cuda/dmmv.cu
@@ -419,8 +419,8 @@ static __device__ void convert_f16(const void * vx, const int64_t ib, const int
     // load 2 halfs into register in a single instruction
     const half2 x_reg = *((half2 *) &(x[ib + iqs]));
     // automatic half -> float type cast if dfloat == float
-    v.x = x_reg.x;
-    v.y = x_reg.y;
+    v.x = __low2float(x_reg);
+    v.y = __high2float(x_reg);
 }
 
 static constexpr __device__ dequantize_kernel_t get_dequantize_kernel(ggml_type type) {

From d150c7e309af8d11d896c2bb8622030bfb1635ca Mon Sep 17 00:00:00 2001
From: agray3
Date: Thu, 10 Oct 2024 15:47:11 +0100
Subject: [PATCH 3/3] Update ggml/src/ggml-cuda/dmmv.cu
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: Johannes Gäßler
---
 ggml/src/ggml-cuda/dmmv.cu | 5 +----
 1 file changed, 1 insertion(+), 4 deletions(-)

diff --git a/ggml/src/ggml-cuda/dmmv.cu b/ggml/src/ggml-cuda/dmmv.cu
index b727d4ff01bac..00e21b5d77e3c 100644
--- a/ggml/src/ggml-cuda/dmmv.cu
+++ b/ggml/src/ggml-cuda/dmmv.cu
@@ -480,10 +480,7 @@ static __global__ void dequantize_mul_mat_vec(const void * __restrict__ vx, cons
             if ( y_offset == 1 ) {
                 // load 2 dfloats into register in a single instruction
                 const dfloat2 y_reg = *((dfloat2 *) &(y[iybs + iqs + j/qr]));
-                tmp += __hmul2(v, {
-                    y_reg.x;
-                    y_reg.y;
-                });
+                tmp += __hmul2(v, y_reg);
             }
             else {
                 tmp += __hmul2(v, {
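
The standalone CUDA sketch below is not part of the patch series; it illustrates the same idea on a toy fp16 dot product: two adjacent 16-bit values are fetched with a single 32-bit half2 load instead of two scalar loads, then unpacked with the __low2float/__high2float intrinsics adopted in PATCH 2/3. The kernel names, problem size, and launch configuration are illustrative assumptions, not ggml code.

// Standalone sketch, NOT part of the patch: a toy fp16 dot product that
// contrasts the scalar loads replaced by this series with the vectorized
// half2 load it introduces.
#include <cuda_fp16.h>
#include <cstdio>

// Scalar path: two independent 16-bit loads per pair of elements.
__global__ void dot_scalar(const half * x, const half * y, float * out, int n) {
    float tmp = 0.0f;
    for (int i = 2*threadIdx.x; i < n; i += 2*blockDim.x) {
        tmp += __half2float(x[i + 0]) * __half2float(y[i + 0]);
        tmp += __half2float(x[i + 1]) * __half2float(y[i + 1]);
    }
    atomicAdd(out, tmp);
}

// Vectorized path: one 32-bit load fetches two adjacent halves at once.
// Requires the two values to be contiguous and the address 4-byte aligned
// (true here because i is always even and the allocations are aligned).
__global__ void dot_vec(const half * x, const half * y, float * out, int n) {
    float tmp = 0.0f;
    for (int i = 2*threadIdx.x; i < n; i += 2*blockDim.x) {
        const half2 x_reg = *((const half2 *) &x[i]);
        const half2 y_reg = *((const half2 *) &y[i]);
        tmp += __low2float(x_reg)  * __low2float(y_reg);   // same intrinsics as PATCH 2/3
        tmp += __high2float(x_reg) * __high2float(y_reg);
    }
    atomicAdd(out, tmp);
}

int main() {
    const int n = 1 << 20;
    half *x, *y; float *out;
    cudaMallocManaged(&x,   n * sizeof(half));
    cudaMallocManaged(&y,   n * sizeof(half));
    cudaMallocManaged(&out, 2 * sizeof(float));
    for (int i = 0; i < n; ++i) { x[i] = __float2half(1.0f); y[i] = __float2half(0.5f); }
    out[0] = 0.0f; out[1] = 0.0f;
    dot_scalar<<<1, 256>>>(x, y, &out[0], n);
    dot_vec   <<<1, 256>>>(x, y, &out[1], n);
    cudaDeviceSynchronize();
    printf("scalar: %f  vectorized: %f  expected: %f\n", out[0], out[1], 0.5f * n);
    cudaFree(x); cudaFree(y); cudaFree(out);
    return 0;
}

In the dmmv kernel the contiguity condition appears as the y_offset == 1 check: only then are the two y operands adjacent in memory, so a single dfloat2 load can replace the two scalar loads; otherwise the original per-element loads are kept.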