diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp index 1137e210773af..5f9370449bb2d 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.cpp +++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -1546,9 +1546,8 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) { !ggml_is_transposed(op->src[1]) && // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel - props_dev->has_simdgroup_mm && ne00 >= 64 && - (ne11 > ne11_mm_min || (ggml_is_quantized(op->src[0]->type) && ne12 > 1))) { - //printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12); + props_dev->has_simdgroup_mm && ne00 >= 64 && ne11 > ne11_mm_min) { + //GGML_LOG_INFO("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12); // some Metal matrix data types require aligned pointers // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5) diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 45d91def88bf2..ddc285042d284 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -7487,7 +7487,7 @@ kernel void kernel_mul_mv_iq1_m_f32( kernel_mul_mv_iq1_m_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } -template +template void kernel_mul_mv_iq4_nl_f32_impl( args_t args, device const char * src0, @@ -7500,13 +7500,12 @@ void kernel_mul_mv_iq4_nl_f32_impl( const short NSG = FC_mul_mv_nsg; threadgroup float * shmem_f32 = (threadgroup float *) shmem; - const int nb = args.ne00/QK4_NL; const int r0 = tgpig.x; const int r1 = tgpig.y; const int im = tgpig.z; - const int first_row = (r0 * NSG + sgitg) * nr0; + const int first_row = (r0 * NSG + sgitg) * NR0; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; @@ -7517,6 +7516,9 @@ void kernel_mul_mv_iq4_nl_f32_impl( device const block_iq4_nl * x = (device const block_iq4_nl *) (src0 + offset0); device const float * y = (device const float *) (src1 + offset1); + const int nb = args.ne00/QK4_NL; + const int ns01 = args.nb01/args.nb00; + const short ix = tiisg/2; // 0...15 const short it = tiisg%2; // 0 or 1 @@ -7524,24 +7526,25 @@ void kernel_mul_mv_iq4_nl_f32_impl( threadgroup_barrier(mem_flags::mem_threadgroup); float4 yl[4]; - float sumf[nr0]={0.f}; + float sumf[NR0]={0.f}; - device const float * yb = y + ix * QK4_NL + it * 8; + device const float * yb = y + ix*QK4_NL + it*8; uint32_t aux32[2]; thread const uint8_t * q8 = (thread const uint8_t *)aux32; float4 qf1, qf2; - for (int ib = ix; ib < nb; ib += 16) { + // [TAG_MUL_MV_WEIRD] + for (int ib = ix; ib < nb && ib < ns01; ib += 16) { device const float4 * y4 = (device const float4 *)yb; yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5]; - for (short row = 0; row < nr0; row++) { - device const block_iq4_nl & xb = x[row*nb + ib]; + for (short row = 0; row < NR0; row++) { + device const block_iq4_nl & xb = x[row*ns01 + ib]; device const uint16_t * q4 = (device const uint16_t *)(xb.qs + 8*it); float4 acc1 = {0.f}, acc2 = {0.f}; @@ -7572,7 +7575,7 @@ void kernel_mul_mv_iq4_nl_f32_impl( device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; - for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) { + for (int row = 0; row < NR0 && first_row + row < args.ne0; ++row) { float sum_all = simd_sum(sumf[row]); if (tiisg == 0) { dst_f32[first_row + row] = sum_all; @@ -7594,7 +7597,7 @@ kernel void kernel_mul_mv_iq4_nl_f32( kernel_mul_mv_iq4_nl_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } -template +template void kernel_mul_mv_iq4_xs_f32_impl( args_t args, device const char * src0, @@ -7607,12 +7610,11 @@ void kernel_mul_mv_iq4_xs_f32_impl( const short NSG = FC_mul_mv_nsg; threadgroup float * shmem_f32 = (threadgroup float *) shmem; - const int nb = args.ne00/QK_K; const int r0 = tgpig.x; const int r1 = tgpig.y; const int im = tgpig.z; - const int first_row = (r0 * NSG + sgitg) * nr0; + const int first_row = (r0 * NSG + sgitg) * NR0; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; @@ -7623,6 +7625,9 @@ void kernel_mul_mv_iq4_xs_f32_impl( device const block_iq4_xs * x = (device const block_iq4_xs *) (src0 + offset0); device const float * y = (device const float *) (src1 + offset1); + const int nb = args.ne00/QK_K; + const int ns01 = args.nb01/args.nb00; + const short ix = tiisg/16; // 0 or 1 const short it = tiisg%16; // 0...15 const short ib = it/2; @@ -7632,7 +7637,7 @@ void kernel_mul_mv_iq4_xs_f32_impl( threadgroup_barrier(mem_flags::mem_threadgroup); float4 yl[4]; - float sumf[nr0]={0.f}; + float sumf[NR0]={0.f}; device const float * yb = y + ix * QK_K + ib * 32 + il * 8; @@ -7641,15 +7646,16 @@ void kernel_mul_mv_iq4_xs_f32_impl( float4 qf1, qf2; - for (int ibl = ix; ibl < nb; ibl += 2) { + // [TAG_MUL_MV_WEIRD] + for (int ibl = ix; ibl < nb && ibl < ns01; ibl += 2) { device const float4 * y4 = (device const float4 *)yb; yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5]; - for (short row = 0; row < nr0; ++row) { - device const block_iq4_xs & xb = x[row*nb + ibl]; + for (short row = 0; row < NR0; ++row) { + device const block_iq4_xs & xb = x[row*ns01 + ibl]; device const uint32_t * q4 = (device const uint32_t *)(xb.qs + 16*ib + 8*il); float4 acc1 = {0.f}, acc2 = {0.f}; @@ -7679,7 +7685,7 @@ void kernel_mul_mv_iq4_xs_f32_impl( device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; - for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) { + for (int row = 0; row < NR0 && first_row + row < args.ne0; ++row) { float sum_all = simd_sum(sumf[row]); if (tiisg == 0) { dst_f32[first_row + row] = sum_all; @@ -7701,7 +7707,7 @@ kernel void kernel_mul_mv_iq4_xs_f32( kernel_mul_mv_iq4_xs_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } -template +template void kernel_mul_mv_mxfp4_f32_impl( args_t args, device const char * src0, @@ -7714,13 +7720,12 @@ void kernel_mul_mv_mxfp4_f32_impl( const short NSG = FC_mul_mv_nsg; threadgroup float * shmem_f32 = (threadgroup float *) shmem; - const int nb = args.ne00/QK_MXFP4; const int r0 = tgpig.x; const int r1 = tgpig.y; const int im = tgpig.z; - const int first_row = (r0 * NSG + sgitg) * nr0; + const int first_row = (r0 * NSG + sgitg) * NR0; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; @@ -7731,6 +7736,9 @@ void kernel_mul_mv_mxfp4_f32_impl( device const block_mxfp4 * x = (device const block_mxfp4 *) (src0 + offset0); device const float * y = (device const float *) (src1 + offset1); + const int nb = args.ne00/QK_MXFP4; + const int ns01 = args.nb01/args.nb00; // this can be larger than nb for permuted src0 tensors + const short ix = tiisg/2; // 0...15 const short it = tiisg%2; // 0 or 1 @@ -7738,20 +7746,22 @@ void kernel_mul_mv_mxfp4_f32_impl( threadgroup_barrier(mem_flags::mem_threadgroup); float4 yl[4]; - float sumf[nr0]={0.f}; + float sumf[NR0]={0.f}; - device const float * yb = y + ix * QK_MXFP4 + it * 8; + device const float * yb = y + ix*QK_MXFP4 + it*8; + + // note: just the check `ib < nb` is enough, but adding the redundant `&& ib < ns01` check makes the kernel a bit faster + // no idea why that is - needs some deeper investigation [TAG_MUL_MV_WEIRD] + for (int ib = ix; ib < nb && ib < ns01; ib += 16) { + device const float4 * y4 = (device const float4 *) yb; - for (int ib = ix; ib < nb; ib += 16) { - device const float4 * y4 = (device const float4 *)yb; yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5]; -#pragma unroll(nr0) - for (short row = 0; row < nr0; row++) { - device const block_mxfp4 & xb = x[row*nb + ib]; + FOR_UNROLL (short row = 0; row < NR0; row++) { + device const block_mxfp4 & xb = x[row*ns01 + ib]; device const uint8_t * q2 = (device const uint8_t *)(xb.qs + 8*it); float4 acc1 = yl[0]*float4(shmem_f32[q2[0] & 0x0F], shmem_f32[q2[1] & 0x0F], shmem_f32[q2[2] & 0x0F], shmem_f32[q2[3] & 0x0F]); @@ -7769,7 +7779,7 @@ void kernel_mul_mv_mxfp4_f32_impl( device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; - for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) { + for (int row = 0; row < NR0 && first_row + row < args.ne0; ++row) { float sum_all = simd_sum(sumf[row]); if (tiisg == 0) { dst_f32[first_row + row] = sum_all; diff --git a/src/llama-model.cpp b/src/llama-model.cpp index a5fe5b749c355..36d495d6cfeab 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -16313,10 +16313,10 @@ struct llm_build_granite_hybrid : public llm_graph_context_mamba { } ggml_tensor * build_layer_ffn( - ggml_tensor * cur, - ggml_tensor * inpSA, - const llama_model & model, - const int il) { + ggml_tensor * cur, + ggml_tensor * inpSA, + const llama_model & model, + const int il) { // For Granite architectures - scale residual if (hparams.f_residual_scale) {