Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 25 additions & 9 deletions ggml/src/ggml-cuda/dmmv.cu
Original file line number Diff line number Diff line change
Expand Up @@ -416,10 +416,11 @@ static __global__ void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx,

static __device__ void convert_f16(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
const half * x = (const half *) vx;

// load 2 halfs into register in a single instruction
const half2 x_reg = *((half2 *) &(x[ib + iqs]));
// automatic half -> float type cast if dfloat == float
v.x = x[ib + iqs + 0];
v.y = x[ib + iqs + 1];
v.x = __low2float(x_reg);
v.y = __high2float(x_reg);
}

static constexpr __device__ dequantize_kernel_t get_dequantize_kernel(ggml_type type) {
Expand Down Expand Up @@ -476,13 +477,28 @@ static __global__ void dequantize_mul_mat_vec(const void * __restrict__ vx, cons
// matrix multiplication
// for qr = 2 the y index needs to increase by 1 per j iter because of y_offset = qk/2
#ifdef GGML_CUDA_F16
tmp += __hmul2(v, {
y[iybs + iqs + j/qr + 0],
y[iybs + iqs + j/qr + y_offset]
});
if ( y_offset == 1 ) {
// load 2 dfloats into register in a single instruction
const dfloat2 y_reg = *((dfloat2 *) &(y[iybs + iqs + j/qr]));
tmp += __hmul2(v, y_reg);
}
else {
tmp += __hmul2(v, {
y[iybs + iqs + j/qr + 0],
y[iybs + iqs + j/qr + y_offset]
});
}
#else
tmp += v.x * y[iybs + iqs + j/qr + 0];
tmp += v.y * y[iybs + iqs + j/qr + y_offset];
if ( y_offset == 1 ) {
// load 2 dfloats into register in a single instruction
const dfloat2 y_reg = *((dfloat2 *) &(y[iybs + iqs + j/qr]));
tmp += v.x * y_reg.x;
tmp += v.y * y_reg.y;
}
else {
tmp += v.x * y[iybs + iqs + j/qr + 0];
tmp += v.y * y[iybs + iqs + j/qr + y_offset];
}
#endif // GGML_CUDA_F16
}
}
Expand Down
Loading