@@ -1105,6 +1105,61 @@ static __device__ __forceinline__ void dequantize_q8_0(const void * vx, const in
11051105#endif // GGML_CUDA_F16
11061106}
11071107
// Dequantize q4_0 blocks to dst_t.
// Launch layout: one CUDA block of 32 threads covers up to 8 q4_0 quant
// blocks, i.e. 256 output values per CUDA block (see the 256*blockIdx.x
// output stride below). nb32 is the total number of q4_0 blocks; threads
// mapped past the end return early.
template <typename dst_t>
static __global__ void dequantize_block_q4_0(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb32) {

    const int cuda_block = blockIdx.x;

    // assume blockDim.x == 32
    const int lane    = threadIdx.x;
    const int quarter = lane / 8;  // which 4-byte slice of the quant block this thread handles
    const int sub     = lane % 8;  // which of the 8 quant blocks within this CUDA block

    const int qb = 8*cuda_block + sub;
    if (qb >= nb32) {
        return;
    }

    dst_t * out = yy + 256*cuda_block + 32*sub + 4*quarter;

    const block_q4_0 * src = (const block_q4_0 *) vx + qb;
    const float scale = __half2float(src->d);
    const float shift = -8*scale;  // q4_0: value = d * (nibble - 8)

    const uint8_t * nibbles = src->qs + 4*quarter;

    for (int j = 0; j < 4; ++j) {
        out[j +  0] = scale * (nibbles[j] & 0xF) + shift;  // low nibbles -> first half
        out[j + 16] = scale * (nibbles[j] >>  4) + shift;  // high nibbles -> second half
    }
}
1135+
// Dequantize q4_1 blocks to dst_t.
// Same launch layout as dequantize_block_q4_0: 32 threads per CUDA block,
// 8 q4_1 quant blocks (256 values) per CUDA block. q4_1 stores a (scale, min)
// half2 pair per block: value = d.x * nibble + d.y.
template <typename dst_t>
static __global__ void dequantize_block_q4_1(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb32) {

    const int cuda_block = blockIdx.x;

    // assume blockDim.x == 32
    const int lane    = threadIdx.x;
    const int quarter = lane / 8;  // which 4-byte slice of the quant block this thread handles
    const int sub     = lane % 8;  // which of the 8 quant blocks within this CUDA block

    const int qb = 8*cuda_block + sub;
    if (qb >= nb32) {
        return;
    }

    dst_t * out = yy + 256*cuda_block + 32*sub + 4*quarter;

    const block_q4_1 * src = (const block_q4_1 *) vx + qb;
    const float2 sm = __half22float2(src->dm);  // sm.x = scale, sm.y = min

    const uint8_t * nibbles = src->qs + 4*quarter;

    for (int j = 0; j < 4; ++j) {
        out[j +  0] = sm.x * (nibbles[j] & 0xF) + sm.y;  // low nibbles -> first half
        out[j + 16] = sm.x * (nibbles[j] >>  4) + sm.y;  // high nibbles -> second half
    }
}
1162+
11081163// ================================== k-quants
11091164
11101165template <typename dst_t >
@@ -6253,6 +6308,20 @@ static void dequantize_row_q3_K_cuda(const void * vx, dst_t * y, const int k, cu
62536308#endif
62546309}
62556310
// Host launcher for dequantize_block_q4_0.
// k is the number of values to dequantize; each q4_0 block holds 32 values
// and each CUDA block of 32 threads processes 8 quant blocks (256 values),
// so the grid is the ceiling of k/256. Launched asynchronously on `stream`.
template <typename dst_t>
static void dequantize_q4_0_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
    const int nb32      = k / 32;          // number of q4_0 quant blocks
    const int grid_size = (k + 255) / 256; // ceil-div: CUDA blocks needed
    dequantize_block_q4_0<<<grid_size, 32, 0, stream>>>(vx, y, nb32);
}
6317+
// Host launcher for dequantize_block_q4_1.
// Mirrors dequantize_q4_0_cuda: 32 values per q4_1 block, 256 values per
// CUDA block of 32 threads, grid = ceil(k/256). Launched asynchronously
// on `stream`.
template <typename dst_t>
static void dequantize_q4_1_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
    const int nb32      = k / 32;          // number of q4_1 quant blocks
    const int grid_size = (k + 255) / 256; // ceil-div: CUDA blocks needed
    dequantize_block_q4_1<<<grid_size, 32, 0, stream>>>(vx, y, nb32);
}
6324+
62566325template <typename dst_t >
62576326static void dequantize_row_q4_K_cuda (const void * vx, dst_t * y, const int k, cudaStream_t stream) {
62586327 const int nb = k / QK_K;
@@ -6301,9 +6370,9 @@ static to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
63016370 int id;
63026371 switch (type) {
63036372 case GGML_TYPE_Q4_0:
6304- return dequantize_block_cuda<QK4_0, QR4_0, dequantize_q4_0> ;
6373+ return dequantize_q4_0_cuda ;
63056374 case GGML_TYPE_Q4_1:
6306- return dequantize_block_cuda<QK4_1, QR4_1, dequantize_q4_1> ;
6375+ return dequantize_q4_1_cuda ;
63076376 case GGML_TYPE_Q5_0:
63086377 return dequantize_block_cuda<QK5_0, QR5_0, dequantize_q5_0>;
63096378 case GGML_TYPE_Q5_1:
@@ -6338,9 +6407,9 @@ static to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
63386407static to_fp32_cuda_t ggml_get_to_fp32_cuda (ggml_type type) {
63396408 switch (type) {
63406409 case GGML_TYPE_Q4_0:
6341- return dequantize_block_cuda<QK4_0, QR4_0, dequantize_q4_0> ;
6410+ return dequantize_q4_0_cuda ;
63426411 case GGML_TYPE_Q4_1:
6343- return dequantize_block_cuda<QK4_1, QR4_1, dequantize_q4_1> ;
6412+ return dequantize_q4_1_cuda ;
63446413 case GGML_TYPE_Q5_0:
63456414 return dequantize_block_cuda<QK5_0, QR5_0, dequantize_q5_0>;
63466415 case GGML_TYPE_Q5_1:
0 commit comments