Skip to content

Commit cbc8343

Browse files
ikawrakow and Iwan Kawrakow authored
Make IQ1_M work for QK_K = 64 (#6327)
* iq1_m: make it work for QK_K = 64 (WIP)
* iq1_m: make it work for QK_K = 64 (scalar and AVX2)
* iq1_m: QK_K = 64 seems to work on Metal and ARM_NEON

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
1 parent e562b97 commit cbc8343

File tree

3 files changed: 97 additions and 4 deletions.

ggml-common.h

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -377,13 +377,20 @@ typedef struct {
377377
} block_iq1_s;
378378
static_assert(sizeof(block_iq1_s) == sizeof(ggml_half) + QK_K/8 + QK_K/16, "wrong iq1_s block size/padding");
379379

380-
// 1.8125 bpw
380+
// 1.75 bpw (2.0 bpw if QK_K == 64, where the block carries a separate fp16 scale d)
381381
typedef struct {
382382
uint8_t qs[QK_K/8]; // grid index, low 8 bits
383383
uint8_t qh[QK_K/16]; // grid index, high 3 bits + grid shift bit (for two groups of 8)
384-
uint8_t scales[QK_K/32]; // 4-bit block scales
384+
#if QK_K == 64
385+
ggml_half d;
386+
#endif
387+
uint8_t scales[QK_K/32]; // 3-bit block scales (4-bit if QK_K == 64)
385388
} block_iq1_m;
389+
#if QK_K == 64
390+
static_assert(sizeof(block_iq1_m) == QK_K/8 + QK_K/16 + QK_K/32 + sizeof(ggml_half), "wrong iq1_m block size/padding");
391+
#else
386392
static_assert(sizeof(block_iq1_m) == QK_K/8 + QK_K/16 + QK_K/32, "wrong iq1_m block size/padding");
393+
#endif
387394

388395
// Used by IQ1_M quants
389396
typedef union {

ggml-metal.metal

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4497,7 +4497,9 @@ void kernel_mul_mv_iq1_m_f32_impl(
44974497

44984498
device const float * y4 = y + 32 * ix;
44994499

4500+
#if QK_K != 64
45004501
iq1m_scale_t scale;
4502+
#endif
45014503

45024504
for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
45034505

@@ -4519,7 +4521,9 @@ void kernel_mul_mv_iq1_m_f32_impl(
45194521

45204522
for (int row = 0; row < N_DST; row++) {
45214523

4524+
#if QK_K != 64
45224525
scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
4526+
#endif
45234527

45244528
constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700)));
45254529
constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((qh[0] << 4) & 0x700)));
@@ -4535,8 +4539,14 @@ void kernel_mul_mv_iq1_m_f32_impl(
45354539
}
45364540
const float delta1 = sumy[0] * (qh[0] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA) + sumy[1] * (qh[0] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
45374541
const float delta2 = sumy[2] * (qh[1] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA) + sumy[3] * (qh[1] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
4542+
#if QK_K == 64
4543+
const float d = (float) *((device const half *)(sc - 1));
4544+
sumf[row] += d * ((sum[0] + delta1) * (2*((sc[0] >> (8*(ib%2)+0)) & 0xf) + 1) +
4545+
(sum[1] + delta2) * (2*((sc[0] >> (8*(ib%2)+4)) & 0xf) + 1));
4546+
#else
45384547
sumf[row] += (float)scale.f16 * ((sum[0] + delta1) * (2*((sc[ib/2] >> (6*(ib%2)+0)) & 7) + 1) +
45394548
(sum[1] + delta2) * (2*((sc[ib/2] >> (6*(ib%2)+3)) & 7) + 1));
4549+
#endif
45404550

45414551
sc += nb*sizeof(block_iq1_m)/2;
45424552
qs += nb*sizeof(block_iq1_m);
@@ -5277,13 +5287,21 @@ void dequantize_iq1_m(device const block_iq1_m * xb, short il, thread type4x4 &
52775287
// il is 0...15 for QK_K = 256 => index of block of 32 is il/2
52785288
const int ib32 = il/2;
52795289
il = il%2;
5280-
iq1m_scale_t scale;
52815290
device const uint16_t * sc = (device const uint16_t *)xb->scales;
5291+
#if QK_K == 64
5292+
const float d = xb->d;
5293+
#else
5294+
iq1m_scale_t scale;
52825295
scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
52835296
const float d = scale.f16;
5297+
#endif
52845298
device const uint8_t * qs = xb->qs + 4*ib32 + 2*il;
52855299
device const uint8_t * qh = xb->qh + 2*ib32 + il;
5300+
#if QK_K == 64
5301+
const float dl = d * (2*((sc[ib32/2] >> (8*(ib32%2)+4*il)) & 0xf) + 1);
5302+
#else
52865303
const float dl = d * (2*((sc[ib32/2] >> (6*(ib32%2)+3*il)) & 7) + 1);
5304+
#endif
52875305
const float ml1 = dl * (qh[0] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
52885306
const float ml2 = dl * (qh[0] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
52895307
constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700)));

ggml-quants.c

Lines changed: 69 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3481,19 +3481,30 @@ void dequantize_row_iq1_m(const block_iq1_m * restrict x, float * restrict y, in
34813481
float delta[4];
34823482
uint16_t idx[4];
34833483

3484+
#if QK_K != 64
34843485
iq1m_scale_t scale;
3486+
#endif
34853487

34863488
for (int i = 0; i < nb; i++) {
34873489

34883490
const uint16_t * sc = (const uint16_t *)x[i].scales;
3491+
#if QK_K == 64
3492+
const float d = GGML_FP16_TO_FP32(x[i].d);
3493+
#else
34893494
scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
34903495
const float d = GGML_FP16_TO_FP32(scale.f16);
3496+
#endif
34913497
const uint8_t * qs = x[i].qs;
34923498
const uint8_t * qh = x[i].qh;
34933499

34943500
for (int ib = 0; ib < QK_K/32; ++ib) {
3501+
#if QK_K == 64
3502+
const float dl1 = d * (2*((sc[ib/2] >> (8*(ib%2)+0)) & 0xf) + 1);
3503+
const float dl2 = d * (2*((sc[ib/2] >> (8*(ib%2)+4)) & 0xf) + 1);
3504+
#else
34953505
const float dl1 = d * (2*((sc[ib/2] >> (6*(ib%2)+0)) & 0x7) + 1);
34963506
const float dl2 = d * (2*((sc[ib/2] >> (6*(ib%2)+3)) & 0x7) + 1);
3507+
#endif
34973508
idx[0] = qs[0] | ((qh[0] << 8) & 0x700);
34983509
idx[1] = qs[1] | ((qh[0] << 4) & 0x700);
34993510
idx[2] = qs[2] | ((qh[1] << 8) & 0x700);
@@ -9756,11 +9767,17 @@ void ggml_vec_dot_iq1_m_q8_K (int n, float * restrict s, size_t bs, const void
97569767

97579768
const int nb = n / QK_K;
97589769

9770+
#if QK_K != 64
97599771
iq1m_scale_t scale;
9772+
#endif
97609773

97619774
#if defined __ARM_NEON
97629775

9776+
#if QK_K == 64
9777+
const int32x4_t mask = vdupq_n_s32(0xf);
9778+
#else
97639779
const int32x4_t mask = vdupq_n_s32(0x7);
9780+
#endif
97649781
const int32x4_t mone = vdupq_n_s32(1);
97659782
const int32x4_t mzero = vdupq_n_s32(0);
97669783

@@ -9784,7 +9801,9 @@ void ggml_vec_dot_iq1_m_q8_K (int n, float * restrict s, size_t bs, const void
97849801
const uint8_t * qh = x[i].qh;
97859802
const uint16_t * sc = (const uint16_t *)x[i].scales;
97869803

9804+
#if QK_K != 64
97879805
scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
9806+
#endif
97889807

97899808
int32x4_t sumi1 = mzero;
97909809
int32x4_t sumi2 = mzero;
@@ -9813,7 +9832,11 @@ void ggml_vec_dot_iq1_m_q8_K (int n, float * restrict s, size_t bs, const void
98139832
const int32x4_t p4 = vpaddq_s32(ggml_vdotq_s32(mzero, deltas.val[aux8[2]], q8b.val[2]), ggml_vdotq_s32(mzero, deltas.val[aux8[3]], q8b.val[3]));
98149833
const int32x4_t p34 = vpaddq_s32(p3, p4);
98159834

9835+
#if QK_K == 64
9836+
int32x4_t scales_4 = ggml_vld1q_u32(sc[0] >> 0, sc[0] >> 4, sc[0] >> 8, sc[0] >> 12);
9837+
#else
98169838
int32x4_t scales_4 = ggml_vld1q_u32(sc[ib/2] >> 0, sc[ib/2] >> 3, sc[ib/2] >> 6, sc[ib/2] >> 9);
9839+
#endif
98179840
scales_4 = vaddq_s32(vshlq_n_s32(vandq_s32(scales_4, mask), 1), mone);
98189841

98199842
sumi1 = vmlaq_s32(sumi1, scales_4, p12);
@@ -9823,14 +9846,22 @@ void ggml_vec_dot_iq1_m_q8_K (int n, float * restrict s, size_t bs, const void
98239846

98249847
}
98259848

9849+
#if QK_K == 64
9850+
sumf += y[i].d * GGML_FP16_TO_FP32(x[i].d) * (vaddvq_s32(sumi1) + IQ1M_DELTA * vaddvq_s32(sumi2));
9851+
#else
98269852
sumf += y[i].d * GGML_FP16_TO_FP32(scale.f16) * (vaddvq_s32(sumi1) + IQ1M_DELTA * vaddvq_s32(sumi2));
9853+
#endif
98279854
}
98289855

98299856
*s = sumf;
98309857

98319858
#elif defined __AVX2__
98329859

9860+
#if QK_K == 64
9861+
const __m256i mask = _mm256_set1_epi16(0xf);
9862+
#else
98339863
const __m256i mask = _mm256_set1_epi16(0x7);
9864+
#endif
98349865
const __m256i mone = _mm256_set1_epi16(1);
98359866

98369867
__m256 accum1 = _mm256_setzero_ps();
@@ -9842,7 +9873,9 @@ void ggml_vec_dot_iq1_m_q8_K (int n, float * restrict s, size_t bs, const void
98429873
const uint8_t * qh = x[i].qh;
98439874
const uint16_t * sc = (const uint16_t *)x[i].scales;
98449875

9876+
#if QK_K != 64
98459877
scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
9878+
#endif
98469879

98479880
__m256i sumi1 = _mm256_setzero_si256();
98489881
__m256i sumi2 = _mm256_setzero_si256();
@@ -9872,8 +9905,13 @@ void ggml_vec_dot_iq1_m_q8_K (int n, float * restrict s, size_t bs, const void
98729905

98739906
const __m256i dot3 = mul_add_epi8(delta1, q8b_1);
98749907
const __m256i dot4 = mul_add_epi8(delta2, q8b_2);
9908+
#if QK_K == 64
9909+
__m256i scale1 = MM256_SET_M128I(_mm_set1_epi16(sc[0] >> 4), _mm_set1_epi16(sc[0] >> 0));
9910+
__m256i scale2 = MM256_SET_M128I(_mm_set1_epi16(sc[0] >> 12), _mm_set1_epi16(sc[0] >> 8));
9911+
#else
98759912
__m256i scale1 = MM256_SET_M128I(_mm_set1_epi16(sc[ib/2] >> 3), _mm_set1_epi16(sc[ib/2] >> 0));
98769913
__m256i scale2 = MM256_SET_M128I(_mm_set1_epi16(sc[ib/2] >> 9), _mm_set1_epi16(sc[ib/2] >> 6));
9914+
#endif
98779915
scale1 = _mm256_add_epi16(_mm256_slli_epi16(_mm256_and_si256(scale1, mask), 1), mone);
98789916
scale2 = _mm256_add_epi16(_mm256_slli_epi16(_mm256_and_si256(scale2, mask), 1), mone);
98799917
const __m256i p1 = _mm256_madd_epi16(dot1, scale1);
@@ -9887,7 +9925,11 @@ void ggml_vec_dot_iq1_m_q8_K (int n, float * restrict s, size_t bs, const void
98879925
qs += 8; qh += 4;
98889926
}
98899927

9928+
#if QK_K == 64
9929+
const __m256 d = _mm256_set1_ps(y[i].d * GGML_FP16_TO_FP32(x[i].d));
9930+
#else
98909931
const __m256 d = _mm256_set1_ps(y[i].d * GGML_FP16_TO_FP32(scale.f16));
9932+
#endif
98919933
accum1 = _mm256_fmadd_ps(d, _mm256_cvtepi32_ps(sumi1), accum1);
98929934
accum2 = _mm256_fmadd_ps(d, _mm256_cvtepi32_ps(sumi2), accum2);
98939935

@@ -9907,7 +9949,9 @@ void ggml_vec_dot_iq1_m_q8_K (int n, float * restrict s, size_t bs, const void
99079949
const uint8_t * qh = x[i].qh;
99089950
const uint16_t * sc = (const uint16_t *)x[i].scales;
99099951

9952+
#if QK_K != 64
99109953
scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
9954+
#endif
99119955

99129956
int sumi1 = 0, sumi2 = 0;
99139957
for (int ib = 0; ib < QK_K/32; ++ib) {
@@ -9927,15 +9971,24 @@ void ggml_vec_dot_iq1_m_q8_K (int n, float * restrict s, size_t bs, const void
99279971
sum1[l/2] += lsum1;
99289972
sum2[l/2] += lsum2*delta[l];
99299973
}
9974+
#if QK_K == 64
9975+
const int ls1 = 2*((sc[0] >> (8*(ib%2)+0)) & 0xf) + 1;
9976+
const int ls2 = 2*((sc[0] >> (8*(ib%2)+4)) & 0xf) + 1;
9977+
#else
99309978
const int ls1 = 2*((sc[ib/2] >> (6*(ib%2)+0)) & 0x7) + 1;
99319979
const int ls2 = 2*((sc[ib/2] >> (6*(ib%2)+3)) & 0x7) + 1;
9980+
#endif
99329981
sumi1 += sum1[0] * ls1 + sum1[1] * ls2;
99339982
sumi2 += sum2[0] * ls1 + sum2[1] * ls2;
99349983
qs += 4;
99359984
qh += 2;
99369985
}
99379986

9987+
#if QK_K == 64
9988+
sumf += GGML_FP16_TO_FP32(x[i].d) * y[i].d * (sumi1 + IQ1M_DELTA * sumi2);
9989+
#else
99389990
sumf += GGML_FP16_TO_FP32(scale.f16) * y[i].d * (sumi1 + IQ1M_DELTA * sumi2);
9991+
#endif
99399992
}
99409993

99419994
*s = sumf;
@@ -11986,7 +12039,9 @@ static void quantize_row_iq1_m_impl(const float * restrict x, void * restrict vy
1198612039

1198712040
for (int ibl = 0; ibl < nbl; ++ibl) {
1198812041

11989-
//y[ibl].d = GGML_FP32_TO_FP16(0.f);
12042+
#if QK_K == 64
12043+
y[ibl].d = GGML_FP32_TO_FP16(0.f);
12044+
#endif
1199012045
memset(y[ibl].qs, 0, QK_K/8);
1199112046
memset(y[ibl].qh, 0, QK_K/16);
1199212047
memset(y[ibl].scales, 0, QK_K/32);
@@ -12161,13 +12216,22 @@ static void quantize_row_iq1_m_impl(const float * restrict x, void * restrict vy
1216112216
}
1216212217

1216312218
uint16_t * sc = (uint16_t *)y[ibl].scales;
12219+
#if QK_K == 64
12220+
float d = max_scale/31;
12221+
#else
1216412222
float d = max_scale/15;
12223+
#endif
1216512224
float id = 1/d;
1216612225
float sumqx_f = 0, sumq2_f = 0;
1216712226
for (int ib = 0; ib < QK_K/block_size; ++ib) {
1216812227
int l = nearest_int(0.5f*(id*scales[ib+0]-1));
12228+
#if QK_K == 64
12229+
l = MAX(0, MIN(15, l));
12230+
sc[ib/4] |= (l << 4*(ib%4));
12231+
#else
1216912232
l = MAX(0, MIN(7, l));
1217012233
sc[ib/4] |= (l << 3*(ib%4));
12234+
#endif
1217112235
y[ibl].qh[ib] |= masks[shifts[ib]];
1217212236
const float * xb = xbl + block_size*ib;
1217312237
if (quant_weights) {
@@ -12190,10 +12254,14 @@ static void quantize_row_iq1_m_impl(const float * restrict x, void * restrict vy
1219012254
}
1219112255
if (sumq2_f > 0) d = sumqx_f/sumq2_f;
1219212256
s.f16 = GGML_FP32_TO_FP16(d*1.1125f); // 1.1125f is another fudge factor. Don't ask me why it is needed.
12257+
#if QK_K == 64
12258+
y[ibl].d = s.f16;
12259+
#else
1219312260
sc[0] |= ((s.u16 & 0x000f) << 12);
1219412261
sc[1] |= ((s.u16 & 0x00f0) << 8);
1219512262
sc[2] |= ((s.u16 & 0x0f00) << 4);
1219612263
sc[3] |= ((s.u16 & 0xf000) << 0);
12264+
#endif
1219712265
}
1219812266
}
1219912267

0 commit comments

Comments
 (0)