@@ -1552,7 +1552,7 @@ class vk_perf_logger {
15521552 }
15531553 if (node->op == GGML_OP_MUL_MAT || node->op == GGML_OP_MUL_MAT_ID) {
15541554 const uint64_t m = node->src[0]->ne[1];
1555- const uint64_t n = node->ne[1];
1555+ const uint64_t n = (node->op == GGML_OP_MUL_MAT) ? node->ne[1] : node->ne[2];
15561556 const uint64_t k = node->src[1]->ne[0];
15571557 const uint64_t batch = node->src[1]->ne[2] * node->src[1]->ne[3];
15581558 std::string name = ggml_op_name(node->op);
@@ -6744,23 +6744,36 @@ static bool ggml_vk_should_use_mmvq(const vk_device& device, uint32_t m, uint32_
67446744 return true;
67456745 }
67466746
6747+ // Quantization overhead is not worth it for small k
67476748 switch (device->vendor_id) {
67486749 case VK_VENDOR_ID_NVIDIA:
6750+ if (k <= 4096) {
6751+ return false;
6752+ }
6753+
67496754 switch (src0_type) {
6750- case GGML_TYPE_Q8_0:
67516755 case GGML_TYPE_MXFP4:
6756+ case GGML_TYPE_Q8_0:
67526757 return device->architecture == vk_device_architecture::NVIDIA_PRE_TURING;
67536758 default:
67546759 return true;
67556760 }
67566761 case VK_VENDOR_ID_AMD:
6762+ if (k < 2048) {
6763+ return false;
6764+ }
6765+
67576766 switch (src0_type) {
67586767 case GGML_TYPE_Q8_0:
67596768 return device->architecture == vk_device_architecture::AMD_GCN;
67606769 default:
67616770 return true;
67626771 }
67636772 case VK_VENDOR_ID_INTEL:
6773+ if (k < 2048) {
6774+ return false;
6775+ }
6776+
67646777 switch (src0_type) {
67656778 // From tests on A770 Linux, may need more tuning
67666779 case GGML_TYPE_Q4_0:
@@ -6774,7 +6787,6 @@ static bool ggml_vk_should_use_mmvq(const vk_device& device, uint32_t m, uint32_
67746787 }
67756788
67766789 GGML_UNUSED(m);
6777- GGML_UNUSED(k);
67786790}
67796791
67806792static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const struct ggml_cgraph * cgraph, int node_idx) {
@@ -7274,7 +7286,7 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
72747286
72757287 const bool y_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig;
72767288
7277- bool quantize_y = ctx->device->integer_dot_product && src1->type == GGML_TYPE_F32 && ggml_is_contiguous(src1) && (ne11 * ne10) % 4 == 0 && ne21 >= 8;
7289+ bool quantize_y = ctx->device->integer_dot_product && src1->type == GGML_TYPE_F32 && ggml_is_contiguous(src1) && (ne11 * ne10) % 4 == 0;
72787290
72797291 // Check for mmq first
72807292 vk_matmul_pipeline mmp = quantize_y ? ggml_vk_get_mul_mat_mat_id_pipeline(ctx, src0->type, GGML_TYPE_Q8_1, (ggml_prec)dst->op_params[0]) : nullptr;
@@ -7509,7 +7521,7 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
75097521 const bool y_non_contig = !ggml_vk_dim01_contiguous(src1);
75107522
75117523 const bool f16_f32_kernel = src1->type == GGML_TYPE_F32;
7512- bool quantize_y = ctx->device->integer_dot_product && src1->type == GGML_TYPE_F32 && ggml_is_contiguous(src1) && (ne11 * ne10) % 4 == 0 && ggml_vk_should_use_mmvq(ctx->device, ne01, ne11, ne10, src0->type);
7524+ bool quantize_y = ctx->device->integer_dot_product && src1->type == GGML_TYPE_F32 && ggml_is_contiguous(src1) && (ne11 * ne10) % 4 == 0 && ggml_vk_should_use_mmvq(ctx->device, ne01, ne12, ne10, src0->type);
75137525
75147526 vk_pipeline to_fp16_vk_0 = nullptr;
75157527 vk_pipeline to_fp16_vk_1 = nullptr;
0 commit comments