@@ -3481,19 +3481,30 @@ void dequantize_row_iq1_m(const block_iq1_m * restrict x, float * restrict y, in
34813481 float delta[4];
34823482 uint16_t idx[4];
34833483
3484+ #if QK_K != 64
34843485 iq1m_scale_t scale;
3486+ #endif
34853487
34863488 for (int i = 0; i < nb; i++) {
34873489
34883490 const uint16_t * sc = (const uint16_t *)x[i].scales;
3491+ #if QK_K == 64
3492+ const float d = GGML_FP16_TO_FP32(x[i].d);
3493+ #else
34893494 scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
34903495 const float d = GGML_FP16_TO_FP32(scale.f16);
3496+ #endif
34913497 const uint8_t * qs = x[i].qs;
34923498 const uint8_t * qh = x[i].qh;
34933499
34943500 for (int ib = 0; ib < QK_K/32; ++ib) {
3501+ #if QK_K == 64
3502+ const float dl1 = d * (2*((sc[ib/2] >> (8*(ib%2)+0)) & 0xf) + 1);
3503+ const float dl2 = d * (2*((sc[ib/2] >> (8*(ib%2)+4)) & 0xf) + 1);
3504+ #else
34953505 const float dl1 = d * (2*((sc[ib/2] >> (6*(ib%2)+0)) & 0x7) + 1);
34963506 const float dl2 = d * (2*((sc[ib/2] >> (6*(ib%2)+3)) & 0x7) + 1);
3507+ #endif
34973508 idx[0] = qs[0] | ((qh[0] << 8) & 0x700);
34983509 idx[1] = qs[1] | ((qh[0] << 4) & 0x700);
34993510 idx[2] = qs[2] | ((qh[1] << 8) & 0x700);
@@ -9756,11 +9767,17 @@ void ggml_vec_dot_iq1_m_q8_K (int n, float * restrict s, size_t bs, const void
97569767
97579768 const int nb = n / QK_K;
97589769
9770+ #if QK_K != 64
97599771 iq1m_scale_t scale;
9772+ #endif
97609773
97619774#if defined __ARM_NEON
97629775
9776+ #if QK_K == 64
9777+ const int32x4_t mask = vdupq_n_s32(0xf);
9778+ #else
97639779 const int32x4_t mask = vdupq_n_s32(0x7);
9780+ #endif
97649781 const int32x4_t mone = vdupq_n_s32(1);
97659782 const int32x4_t mzero = vdupq_n_s32(0);
97669783
@@ -9784,7 +9801,9 @@ void ggml_vec_dot_iq1_m_q8_K (int n, float * restrict s, size_t bs, const void
97849801 const uint8_t * qh = x[i].qh;
97859802 const uint16_t * sc = (const uint16_t *)x[i].scales;
97869803
9804+ #if QK_K != 64
97879805 scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
9806+ #endif
97889807
97899808 int32x4_t sumi1 = mzero;
97909809 int32x4_t sumi2 = mzero;
@@ -9813,7 +9832,11 @@ void ggml_vec_dot_iq1_m_q8_K (int n, float * restrict s, size_t bs, const void
98139832 const int32x4_t p4 = vpaddq_s32(ggml_vdotq_s32(mzero, deltas.val[aux8[2]], q8b.val[2]), ggml_vdotq_s32(mzero, deltas.val[aux8[3]], q8b.val[3]));
98149833 const int32x4_t p34 = vpaddq_s32(p3, p4);
98159834
9835+ #if QK_K == 64
9836+ int32x4_t scales_4 = ggml_vld1q_u32(sc[0] >> 0, sc[0] >> 4, sc[0] >> 8, sc[0] >> 12);
9837+ #else
98169838 int32x4_t scales_4 = ggml_vld1q_u32(sc[ib/2] >> 0, sc[ib/2] >> 3, sc[ib/2] >> 6, sc[ib/2] >> 9);
9839+ #endif
98179840 scales_4 = vaddq_s32(vshlq_n_s32(vandq_s32(scales_4, mask), 1), mone);
98189841
98199842 sumi1 = vmlaq_s32(sumi1, scales_4, p12);
@@ -9823,14 +9846,22 @@ void ggml_vec_dot_iq1_m_q8_K (int n, float * restrict s, size_t bs, const void
98239846
98249847 }
98259848
9849+ #if QK_K == 64
9850+ sumf += y[i].d * GGML_FP16_TO_FP32(x[i].d) * (vaddvq_s32(sumi1) + IQ1M_DELTA * vaddvq_s32(sumi2));
9851+ #else
98269852 sumf += y[i].d * GGML_FP16_TO_FP32(scale.f16) * (vaddvq_s32(sumi1) + IQ1M_DELTA * vaddvq_s32(sumi2));
9853+ #endif
98279854 }
98289855
98299856 *s = sumf;
98309857
98319858#elif defined __AVX2__
98329859
9860+ #if QK_K == 64
9861+ const __m256i mask = _mm256_set1_epi16(0xf);
9862+ #else
98339863 const __m256i mask = _mm256_set1_epi16(0x7);
9864+ #endif
98349865 const __m256i mone = _mm256_set1_epi16(1);
98359866
98369867 __m256 accum1 = _mm256_setzero_ps();
@@ -9842,7 +9873,9 @@ void ggml_vec_dot_iq1_m_q8_K (int n, float * restrict s, size_t bs, const void
98429873 const uint8_t * qh = x[i].qh;
98439874 const uint16_t * sc = (const uint16_t *)x[i].scales;
98449875
9876+ #if QK_K != 64
98459877 scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
9878+ #endif
98469879
98479880 __m256i sumi1 = _mm256_setzero_si256();
98489881 __m256i sumi2 = _mm256_setzero_si256();
@@ -9872,8 +9905,13 @@ void ggml_vec_dot_iq1_m_q8_K (int n, float * restrict s, size_t bs, const void
98729905
98739906 const __m256i dot3 = mul_add_epi8(delta1, q8b_1);
98749907 const __m256i dot4 = mul_add_epi8(delta2, q8b_2);
9908+ #if QK_K == 64
9909+ __m256i scale1 = MM256_SET_M128I(_mm_set1_epi16(sc[0] >> 4), _mm_set1_epi16(sc[0] >> 0));
9910+ __m256i scale2 = MM256_SET_M128I(_mm_set1_epi16(sc[0] >> 12), _mm_set1_epi16(sc[0] >> 8));
9911+ #else
98759912 __m256i scale1 = MM256_SET_M128I(_mm_set1_epi16(sc[ib/2] >> 3), _mm_set1_epi16(sc[ib/2] >> 0));
98769913 __m256i scale2 = MM256_SET_M128I(_mm_set1_epi16(sc[ib/2] >> 9), _mm_set1_epi16(sc[ib/2] >> 6));
9914+ #endif
98779915 scale1 = _mm256_add_epi16(_mm256_slli_epi16(_mm256_and_si256(scale1, mask), 1), mone);
98789916 scale2 = _mm256_add_epi16(_mm256_slli_epi16(_mm256_and_si256(scale2, mask), 1), mone);
98799917 const __m256i p1 = _mm256_madd_epi16(dot1, scale1);
@@ -9887,7 +9925,11 @@ void ggml_vec_dot_iq1_m_q8_K (int n, float * restrict s, size_t bs, const void
98879925 qs += 8; qh += 4;
98889926 }
98899927
9928+ #if QK_K == 64
9929+ const __m256 d = _mm256_set1_ps(y[i].d * GGML_FP16_TO_FP32(x[i].d));
9930+ #else
98909931 const __m256 d = _mm256_set1_ps(y[i].d * GGML_FP16_TO_FP32(scale.f16));
9932+ #endif
98919933 accum1 = _mm256_fmadd_ps(d, _mm256_cvtepi32_ps(sumi1), accum1);
98929934 accum2 = _mm256_fmadd_ps(d, _mm256_cvtepi32_ps(sumi2), accum2);
98939935
@@ -9907,7 +9949,9 @@ void ggml_vec_dot_iq1_m_q8_K (int n, float * restrict s, size_t bs, const void
99079949 const uint8_t * qh = x[i].qh;
99089950 const uint16_t * sc = (const uint16_t *)x[i].scales;
99099951
9952+ #if QK_K != 64
99109953 scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
9954+ #endif
99119955
99129956 int sumi1 = 0, sumi2 = 0;
99139957 for (int ib = 0; ib < QK_K/32; ++ib) {
@@ -9927,15 +9971,24 @@ void ggml_vec_dot_iq1_m_q8_K (int n, float * restrict s, size_t bs, const void
99279971 sum1[l/2] += lsum1;
99289972 sum2[l/2] += lsum2*delta[l];
99299973 }
9974+ #if QK_K == 64
9975+ const int ls1 = 2*((sc[0] >> (8*(ib%2)+0)) & 0xf) + 1;
9976+ const int ls2 = 2*((sc[0] >> (8*(ib%2)+4)) & 0xf) + 1;
9977+ #else
99309978 const int ls1 = 2*((sc[ib/2] >> (6*(ib%2)+0)) & 0x7) + 1;
99319979 const int ls2 = 2*((sc[ib/2] >> (6*(ib%2)+3)) & 0x7) + 1;
9980+ #endif
99329981 sumi1 += sum1[0] * ls1 + sum1[1] * ls2;
99339982 sumi2 += sum2[0] * ls1 + sum2[1] * ls2;
99349983 qs += 4;
99359984 qh += 2;
99369985 }
99379986
9987+ #if QK_K == 64
9988+ sumf += GGML_FP16_TO_FP32(x[i].d) * y[i].d * (sumi1 + IQ1M_DELTA * sumi2);
9989+ #else
99389990 sumf += GGML_FP16_TO_FP32(scale.f16) * y[i].d * (sumi1 + IQ1M_DELTA * sumi2);
9991+ #endif
99399992 }
99409993
99419994 *s = sumf;
@@ -11986,7 +12039,9 @@ static void quantize_row_iq1_m_impl(const float * restrict x, void * restrict vy
1198612039
1198712040 for (int ibl = 0; ibl < nbl; ++ibl) {
1198812041
11989- //y[ibl].d = GGML_FP32_TO_FP16(0.f);
12042+ #if QK_K == 64
12043+ y[ibl].d = GGML_FP32_TO_FP16(0.f);
12044+ #endif
1199012045 memset(y[ibl].qs, 0, QK_K/8);
1199112046 memset(y[ibl].qh, 0, QK_K/16);
1199212047 memset(y[ibl].scales, 0, QK_K/32);
@@ -12161,13 +12216,22 @@ static void quantize_row_iq1_m_impl(const float * restrict x, void * restrict vy
1216112216 }
1216212217
1216312218 uint16_t * sc = (uint16_t *)y[ibl].scales;
12219+ #if QK_K == 64
12220+ float d = max_scale/31;
12221+ #else
1216412222 float d = max_scale/15;
12223+ #endif
1216512224 float id = 1/d;
1216612225 float sumqx_f = 0, sumq2_f = 0;
1216712226 for (int ib = 0; ib < QK_K/block_size; ++ib) {
1216812227 int l = nearest_int(0.5f*(id*scales[ib+0]-1));
12228+ #if QK_K == 64
12229+ l = MAX(0, MIN(15, l));
12230+ sc[ib/4] |= (l << 4*(ib%4));
12231+ #else
1216912232 l = MAX(0, MIN(7, l));
1217012233 sc[ib/4] |= (l << 3*(ib%4));
12234+ #endif
1217112235 y[ibl].qh[ib] |= masks[shifts[ib]];
1217212236 const float * xb = xbl + block_size*ib;
1217312237 if (quant_weights) {
@@ -12190,10 +12254,14 @@ static void quantize_row_iq1_m_impl(const float * restrict x, void * restrict vy
1219012254 }
1219112255 if (sumq2_f > 0) d = sumqx_f/sumq2_f;
1219212256 s.f16 = GGML_FP32_TO_FP16(d*1.1125f); // 1.1125f is another fudge factor. Don't ask me why it is needed.
12257+ #if QK_K == 64
12258+ y[ibl].d = s.f16;
12259+ #else
1219312260 sc[0] |= ((s.u16 & 0x000f) << 12);
1219412261 sc[1] |= ((s.u16 & 0x00f0) << 8);
1219512262 sc[2] |= ((s.u16 & 0x0f00) << 4);
1219612263 sc[3] |= ((s.u16 & 0xf000) << 0);
12264+ #endif
1219712265 }
1219812266}
1219912267
0 commit comments