@@ -357,38 +357,31 @@ template <> struct reorder_vec_dot_q_sycl<GGML_TYPE_Q8_0> {
357357 using q8_0_block = ggml_sycl_reordered::block_q_t <GGML_TYPE_Q8_0>;
358358 using q8_0_traits = typename q8_0_block::traits;
359359
360- __dpct_inline__ float vec_dot_q8_0_q8_1_impl (const int * v, const int * u, const float & d8_0, const sycl::half2 & ds8) {
361- int sumi = 0 ;
362-
363- #pragma unroll
364- for (size_t i = 0 ; i < q8_0_traits::vdr_mmvq; ++i) {
365- // Q8_0 values are signed int8, no nibble extraction needed
366- // Direct dp4a: each int packs 4 int8 values
367- sumi = dpct::dp4a (v[i], u[i], sumi);
368- }
369-
370- const sycl::float2 ds8f = ds8.convert <float , sycl::rounding_mode::automatic>();
371-
372- // Q8_0 has no bias term (values are signed), so just scale
373- return d8_0 * sumi * ds8f.x ();
374- }
375-
376360 __dpct_inline__ float operator ()(const void * __restrict__ vbq, const std::pair<int , int > ibx_offset,
377361 const std::pair<int , int > d_offset, const int8_t * q8_1_quant_ptr,
378362 const sycl::half2 * q8_1_ds, const int & iqs) {
379- const int8_t * bq8_0 = static_cast <const int8_t *>(vbq) + ibx_offset.first ;
380- const ggml_half d = *(reinterpret_cast <const ggml_half *>(static_cast <const uint8_t *>(vbq) + d_offset.first ));
381- int v[q8_0_traits::vdr_mmvq];
382- int u[q8_0_traits::vdr_mmvq];
363+ const uint8_t * base = static_cast <const uint8_t *>(vbq); // byte view of vbq so both offsets below are applied in bytes
364+ const int8_t * qs = reinterpret_cast <const int8_t *>(base + ibx_offset.first ); // int8 quants start at byte offset ibx_offset.first
365+ const ggml_half d = *reinterpret_cast <const ggml_half *>(base + d_offset.first ); // ggml_half scale read at byte offset d_offset.first
366+
367+ int v[q8_0_traits::vdr_mmvq]; // packed ints loaded from qs
368+ int u[q8_0_traits::vdr_mmvq]; // packed ints loaded from q8_1_quant_ptr
383369
384370#pragma unroll
385371 for (size_t i = 0 ; i < q8_0_traits::vdr_mmvq; ++i) {
386- v[i] = get_int_from_int8 (bq8_0 , iqs + i);
372+ v[i] = get_int_from_int8 (qs , iqs + i); // int at index iqs + i of the Q8_0 quants
387373 u[i] = get_int_from_int8_aligned (q8_1_quant_ptr, iqs + i);
388374 }
389375
390- return vec_dot_q8_0_q8_1_impl (v, u, d, *q8_1_ds);
391- };
376+ int sumi = 0 ; // integer accumulator threaded through dp4a
377+ #pragma unroll
378+ for (size_t i = 0 ; i < q8_0_traits::vdr_mmvq; ++i) {
379+ sumi = dpct::dp4a (v[i], u[i], sumi); // dp4a takes the running sum as its third argument and returns the new sum
380+ }
381+
382+ const sycl::half2 ds_values = *q8_1_ds; // local copy of the half2 pointed to by q8_1_ds
383+ return static_cast <float >(d) * static_cast <float >(ds_values[0 ]) * sumi; // only element 0 of the half2 is used; element 1 is ignored
384+ }
392385};
393386
394387static inline float vec_dot_q4_K_q8_1_common (const int * __restrict__ q4, const uint16_t * __restrict__ scales,
@@ -481,6 +474,65 @@ template <> struct reorder_vec_dot_q_sycl<GGML_TYPE_Q4_K> {
481474 }
482475};
483476
477+ template <> struct reorder_vec_dot_q_sycl <GGML_TYPE_Q5_K> { // Q5_K specialization: operator() gathers packed operands and delegates to vec_dot_q5_K_q8_1_impl_vmmq
478+ static constexpr ggml_type gtype = GGML_TYPE_Q5_K;
479+
480+ using q5_k_block = ggml_sycl_reordered::block_q_t <GGML_TYPE_Q5_K>;
481+ using q5_k_traits = typename q5_k_block::traits;
482+
483+ __dpct_inline__ float operator ()(const void * __restrict__ vbq, const std::pair<int , int > ibx_offset,
484+ const std::pair<int , int > d_offset, const int8_t * q8_1_quant_ptr,
485+ const sycl::half2 * q8_1_ds, const int & iqs) { // returns the float produced by vec_dot_q5_K_q8_1_impl_vmmq
486+ const uint8_t * base = static_cast <const uint8_t *>(vbq); // byte view of vbq; every offset below is added in bytes
487+ const uint8_t * qs = base + ibx_offset.first ; // low 4 bits of each quant, at byte offset ibx_offset.first
488+ const uint8_t * qh_base = base + ibx_offset.second ; // high bit of each quant, at byte offset ibx_offset.second
489+ const uint8_t * scs = base + d_offset.first ; // scale bytes at byte offset d_offset.first (read as uint16_t words below)
490+ const ggml_half2 * dms = reinterpret_cast <const ggml_half2 *>(base + d_offset.second ); // half2 at byte offset d_offset.second, passed to the impl as *dms
491+
492+ const int bq8_offset = QR5_K * ((iqs / 2 ) / (QI8_1 / 2 )); // first q8_1 block used by this call; indexes q8_1_ds and steps q8_1_quant_ptr by QK8_1 below
493+ const int * ql_ptr = (const int *) (qs + 16 * bq8_offset + 4 * ((iqs / 2 ) % 4 )); // 16 bytes per bq8_offset step plus a 4-byte step chosen by (iqs / 2) % 4
494+ const int * qh_ptr = (const int *) (qh_base + 4 * ((iqs / 2 ) % 4 )); // same 4-byte sub-offset as ql_ptr; values are shifted by bq8_offset when loaded below
495+ const uint16_t * scales = (const uint16_t *) scs;
496+
497+ int vl[2 ]; // two packed ints from the low-bit plane
498+ int vh[2 ]; // two packed ints from the high-bit plane, pre-shifted
499+ int u[2 * QR5_K]; // two packed q8_1 ints per block, QR5_K blocks
500+ float d8[QR5_K]; // per-block float taken from element 0 of each q8_1_ds entry
501+
502+ vl[0 ] = ql_ptr[0 ];
503+ vl[1 ] = ql_ptr[4 ]; // second load is 4 ints past the first
504+
505+ vh[0 ] = qh_ptr[0 ] >> bq8_offset; // right shift moves bit bq8_offset of every byte down to bit 0
506+ vh[1 ] = qh_ptr[4 ] >> bq8_offset;
507+
508+ uint16_t aux[2 ];
509+ const int j = (QR5_K * ((iqs / 2 ) / (QI8_1 / 2 ))) / 2 ; // same value as bq8_offset / 2; indexes scales as uint16_t words
510+ if (j < 2 ) { // low 6 bits of each byte are taken directly via the 0x3f3f mask
511+ aux[0 ] = scales[j + 0 ] & 0x3f3f ;
512+ aux[1 ] = scales[j + 2 ] & 0x3f3f ;
513+ } else { // 4 bits per byte from scales[j + 2] combined with the top 2 bits (0xc0c0) of another word, moved to bits 4-5
514+ aux[0 ] = ((scales[j + 2 ] >> 0 ) & 0x0f0f ) | ((scales[j - 2 ] & 0xc0c0 ) >> 2 );
515+ aux[1 ] = ((scales[j + 2 ] >> 4 ) & 0x0f0f ) | ((scales[j - 0 ] & 0xc0c0 ) >> 2 );
516+ }
517+
518+ const uint8_t * sc = (const uint8_t *) aux; // byte view of aux; sc covers the two bytes of aux[0]
519+ const uint8_t * m = sc + 2 ; // m covers the two bytes of aux[1]
520+
521+ for (int i = 0 ; i < QR5_K; ++i) { // gather one float and two packed ints from each of QR5_K q8_1 blocks
522+ const int8_t * quant_base_ptr = q8_1_quant_ptr + (bq8_offset + i) * QK8_1; // quants of block (bq8_offset + i), QK8_1 int8 entries per block
523+ sycl::half2 ds_values = *(q8_1_ds + bq8_offset + i); // half2 entry for the same block
524+
525+ d8[i] = ds_values[0 ]; // element 0 only, converted to float (presumably the block scale d -- confirm)
526+
527+ const int * q8 = (const int *) quant_base_ptr + ((iqs / 2 ) % 4 ); // int offset (iqs / 2) % 4, matching the 4-byte sub-offset used for ql_ptr/qh_ptr
528+ u[2 * i + 0 ] = q8[0 ];
529+ u[2 * i + 1 ] = q8[4 ]; // 4 ints past the first, mirroring the vl/vh loads
530+ }
531+
532+ return vec_dot_q5_K_q8_1_impl_vmmq (vl, vh, u, sc, m, *dms, d8);
533+ }
534+ };
535+
484536template <> struct reorder_vec_dot_q_sycl <GGML_TYPE_Q6_K> {
485537 static constexpr ggml_type gtype = GGML_TYPE_Q6_K;
486538
0 commit comments