5 changes: 2 additions & 3 deletions ggml/src/ggml-metal/ggml-metal-ops.cpp
@@ -1546,9 +1546,8 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
!ggml_is_transposed(op->src[1]) &&
// for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
// AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
props_dev->has_simdgroup_mm && ne00 >= 64 &&
(ne11 > ne11_mm_min || (ggml_is_quantized(op->src[0]->type) && ne12 > 1))) {
//printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
props_dev->has_simdgroup_mm && ne00 >= 64 && ne11 > ne11_mm_min) {
//GGML_LOG_INFO("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);

// some Metal matrix data types require aligned pointers
// ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
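Note (illustrative, not part of the patch): the simplified dispatch rule after this hunk, restated as a small C++ sketch. The helper name is hypothetical and the real check in ggml_metal_op_mul_mat additionally requires src1 to be non-transposed plus the other conditions kept as context above.

    // hypothetical helper restating the condition from the hunk above
    static bool use_simdgroup_mm(bool has_simdgroup_mm, int64_t ne00, int64_t ne11, int64_t ne11_mm_min) {
        // the matrix-matrix kernel needs simdgroup matrix support (A14+/M1+ SoCs);
        // AMD GPUs and older A-chips fall back to the matrix-vector kernels
        return has_simdgroup_mm && ne00 >= 64 && ne11 > ne11_mm_min;
    }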
66 changes: 38 additions & 28 deletions ggml/src/ggml-metal/ggml-metal.metal
@@ -7487,7 +7487,7 @@ kernel void kernel_mul_mv_iq1_m_f32(
kernel_mul_mv_iq1_m_f32_impl<N_R0_IQ1_M, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
}

template<int nr0, typename args_t>
template<int NR0, typename args_t>
void kernel_mul_mv_iq4_nl_f32_impl(
args_t args,
device const char * src0,
@@ -7500,13 +7500,12 @@ void kernel_mul_mv_iq4_nl_f32_impl(
const short NSG = FC_mul_mv_nsg;

threadgroup float * shmem_f32 = (threadgroup float *) shmem;
const int nb = args.ne00/QK4_NL;

const int r0 = tgpig.x;
const int r1 = tgpig.y;
const int im = tgpig.z;

const int first_row = (r0 * NSG + sgitg) * nr0;
const int first_row = (r0 * NSG + sgitg) * NR0;

const uint i12 = im%args.ne12;
const uint i13 = im/args.ne12;
@@ -7517,31 +7516,35 @@ void kernel_mul_mv_iq4_nl_f32_impl(
device const block_iq4_nl * x = (device const block_iq4_nl *) (src0 + offset0);
device const float * y = (device const float *) (src1 + offset1);

const int nb = args.ne00/QK4_NL;
const int ns01 = args.nb01/args.nb00;

const short ix = tiisg/2; // 0...15
const short it = tiisg%2; // 0 or 1

shmem_f32[tiisg] = kvalues_iq4nl_f[tiisg%16];
threadgroup_barrier(mem_flags::mem_threadgroup);

float4 yl[4];
float sumf[nr0]={0.f};
float sumf[NR0]={0.f};

device const float * yb = y + ix * QK4_NL + it * 8;
device const float * yb = y + ix*QK4_NL + it*8;

uint32_t aux32[2];
thread const uint8_t * q8 = (thread const uint8_t *)aux32;

float4 qf1, qf2;

for (int ib = ix; ib < nb; ib += 16) {
// [TAG_MUL_MV_WEIRD]
for (int ib = ix; ib < nb && ib < ns01; ib += 16) {
device const float4 * y4 = (device const float4 *)yb;
yl[0] = y4[0];
yl[1] = y4[4];
yl[2] = y4[1];
yl[3] = y4[5];

for (short row = 0; row < nr0; row++) {
device const block_iq4_nl & xb = x[row*nb + ib];
for (short row = 0; row < NR0; row++) {
device const block_iq4_nl & xb = x[row*ns01 + ib];
device const uint16_t * q4 = (device const uint16_t *)(xb.qs + 8*it);

float4 acc1 = {0.f}, acc2 = {0.f};
@@ -7572,7 +7575,7 @@

device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;

for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
for (int row = 0; row < NR0 && first_row + row < args.ne0; ++row) {
float sum_all = simd_sum(sumf[row]);
if (tiisg == 0) {
dst_f32[first_row + row] = sum_all;
@@ -7594,7 +7597,7 @@ kernel void kernel_mul_mv_iq4_nl_f32(
kernel_mul_mv_iq4_nl_f32_impl<N_R0_IQ4_NL, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
}

template<int nr0, typename args_t>
template<int NR0, typename args_t>
void kernel_mul_mv_iq4_xs_f32_impl(
args_t args,
device const char * src0,
@@ -7607,12 +7610,11 @@ void kernel_mul_mv_iq4_xs_f32_impl(
const short NSG = FC_mul_mv_nsg;

threadgroup float * shmem_f32 = (threadgroup float *) shmem;
const int nb = args.ne00/QK_K;

const int r0 = tgpig.x;
const int r1 = tgpig.y;
const int im = tgpig.z;
const int first_row = (r0 * NSG + sgitg) * nr0;
const int first_row = (r0 * NSG + sgitg) * NR0;

const uint i12 = im%args.ne12;
const uint i13 = im/args.ne12;
@@ -7623,6 +7625,9 @@ void kernel_mul_mv_iq4_xs_f32_impl(
device const block_iq4_xs * x = (device const block_iq4_xs *) (src0 + offset0);
device const float * y = (device const float *) (src1 + offset1);

const int nb = args.ne00/QK_K;
const int ns01 = args.nb01/args.nb00;

const short ix = tiisg/16; // 0 or 1
const short it = tiisg%16; // 0...15
const short ib = it/2;
@@ -7632,7 +7637,7 @@
threadgroup_barrier(mem_flags::mem_threadgroup);

float4 yl[4];
float sumf[nr0]={0.f};
float sumf[NR0]={0.f};

device const float * yb = y + ix * QK_K + ib * 32 + il * 8;

@@ -7641,15 +7646,16 @@

float4 qf1, qf2;

for (int ibl = ix; ibl < nb; ibl += 2) {
// [TAG_MUL_MV_WEIRD]
for (int ibl = ix; ibl < nb && ibl < ns01; ibl += 2) {
device const float4 * y4 = (device const float4 *)yb;
yl[0] = y4[0];
yl[1] = y4[4];
yl[2] = y4[1];
yl[3] = y4[5];

for (short row = 0; row < nr0; ++row) {
device const block_iq4_xs & xb = x[row*nb + ibl];
for (short row = 0; row < NR0; ++row) {
device const block_iq4_xs & xb = x[row*ns01 + ibl];
device const uint32_t * q4 = (device const uint32_t *)(xb.qs + 16*ib + 8*il);

float4 acc1 = {0.f}, acc2 = {0.f};
@@ -7679,7 +7685,7 @@

device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;

for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
for (int row = 0; row < NR0 && first_row + row < args.ne0; ++row) {
float sum_all = simd_sum(sumf[row]);
if (tiisg == 0) {
dst_f32[first_row + row] = sum_all;
@@ -7701,7 +7707,7 @@
kernel_mul_mv_iq4_xs_f32_impl<N_R0_IQ4_XS, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
}

template<int nr0, typename args_t>
template<int NR0, typename args_t>
void kernel_mul_mv_mxfp4_f32_impl(
args_t args,
device const char * src0,
@@ -7714,13 +7720,12 @@ void kernel_mul_mv_mxfp4_f32_impl(
const short NSG = FC_mul_mv_nsg;

threadgroup float * shmem_f32 = (threadgroup float *) shmem;
const int nb = args.ne00/QK_MXFP4;

const int r0 = tgpig.x;
const int r1 = tgpig.y;
const int im = tgpig.z;

const int first_row = (r0 * NSG + sgitg) * nr0;
const int first_row = (r0 * NSG + sgitg) * NR0;

const uint i12 = im%args.ne12;
const uint i13 = im/args.ne12;
@@ -7731,27 +7736,32 @@ void kernel_mul_mv_mxfp4_f32_impl(
device const block_mxfp4 * x = (device const block_mxfp4 *) (src0 + offset0);
device const float * y = (device const float *) (src1 + offset1);

const int nb = args.ne00/QK_MXFP4;
const int ns01 = args.nb01/args.nb00; // this can be larger than nb for permuted src0 tensors

const short ix = tiisg/2; // 0...15
const short it = tiisg%2; // 0 or 1

shmem_f32[tiisg] = kvalues_mxfp4_f[tiisg%16];
threadgroup_barrier(mem_flags::mem_threadgroup);

float4 yl[4];
float sumf[nr0]={0.f};
float sumf[NR0]={0.f};

device const float * yb = y + ix * QK_MXFP4 + it * 8;
device const float * yb = y + ix*QK_MXFP4 + it*8;

// note: just the check `ib < nb` is enough, but adding the redundant `&& ib < ns01` check makes the kernel a bit faster
// no idea why that is - needs some deeper investigation [TAG_MUL_MV_WEIRD]
for (int ib = ix; ib < nb && ib < ns01; ib += 16) {
device const float4 * y4 = (device const float4 *) yb;

for (int ib = ix; ib < nb; ib += 16) {
device const float4 * y4 = (device const float4 *)yb;
yl[0] = y4[0];
yl[1] = y4[4];
yl[2] = y4[1];
yl[3] = y4[5];

#pragma unroll(nr0)
for (short row = 0; row < nr0; row++) {
device const block_mxfp4 & xb = x[row*nb + ib];
FOR_UNROLL (short row = 0; row < NR0; row++) {
device const block_mxfp4 & xb = x[row*ns01 + ib];
device const uint8_t * q2 = (device const uint8_t *)(xb.qs + 8*it);

float4 acc1 = yl[0]*float4(shmem_f32[q2[0] & 0x0F], shmem_f32[q2[1] & 0x0F], shmem_f32[q2[2] & 0x0F], shmem_f32[q2[3] & 0x0F]);
@@ -7769,7 +7779,7 @@ void kernel_mul_mv_mxfp4_f32_impl(

device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;

for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
for (int row = 0; row < NR0 && first_row + row < args.ne0; ++row) {
float sum_all = simd_sum(sumf[row]);
if (tiisg == 0) {
dst_f32[first_row + row] = sum_all;
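Note (illustrative, not part of the patch): the idea behind the new ns01 stride in these kernels is that the number of quantized blocks holding data in a src0 row (nb = ne00/QK) and the distance between consecutive rows measured in blocks (ns01 = nb01/nb00) are no longer assumed equal; as the comment in the mxfp4 hunk says, ns01 can be larger than nb for permuted src0 tensors, so rows are indexed with ns01. A small host-side C++ sketch with a hypothetical helper and argument struct:

    #include <cassert>
    #include <cstdint>

    // hypothetical stand-in for the kernel arguments used above
    struct mul_mv_args {
        int64_t ne00; // elements per src0 row
        int64_t nb00; // byte size of one quantized block
        int64_t nb01; // byte stride between consecutive src0 rows
    };

    // index of block `ib` in row `row`, mirroring `x[row*ns01 + ib]` in the kernels
    static int64_t src0_block_index(const mul_mv_args & args, int64_t row, int64_t ib, int64_t qk) {
        const int64_t nb   = args.ne00/qk;        // blocks holding actual data in a row
        const int64_t ns01 = args.nb01/args.nb00; // row stride in blocks; can exceed nb for permuted views
        assert(ib < nb);
        return row*ns01 + ib; // using nb here would be wrong whenever ns01 != nb
    }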
8 changes: 4 additions & 4 deletions src/llama-model.cpp
@@ -16313,10 +16313,10 @@ struct llm_build_granite_hybrid : public llm_graph_context_mamba {
}

ggml_tensor * build_layer_ffn(
ggml_tensor * cur,
ggml_tensor * inpSA,
const llama_model & model,
const int il) {
ggml_tensor * cur,
ggml_tensor * inpSA,
const llama_model & model,
const int il) {

// For Granite architectures - scale residual
if (hparams.f_residual_scale) {