Conversation

@gabe-l-hart
Collaborator
Description

This branch adds performance improvements on Metal for both SSM_CONV and SSM_SCAN. The kernels were heavily edited by Claude Code, but I've reviewed all changes.

Changes

  • SSM_CONV: Implement a batched version that uses threadgroups of 256 threads for multi-token prefill (see the sketch after this list).
    • Split what was the y dimension of the outer grid into ceil(y / BATCH_SIZE) threadgroups and use BATCH_SIZE as the x dimension of the threadgroup (inner grid)
    • Recompute the per-token offsets from tgpig.y and tpitg.x
  • SSM_SCAN: Reduce redundant x_dt / dA computations
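
For the indexing in concrete terms, here is a minimal Metal sketch of how a batched kernel can recover the per-token index after this grid split. The kernel name, the standalone ne1 binding, and the buffer layout are illustrative assumptions, not the exact code from this PR:

    #include <metal_stdlib>
    using namespace metal;

    constant int BATCH_SIZE = 256;

    // Host side dispatches:
    //   grid        = (ne01, ceil(ne1 / BATCH_SIZE), ne02)
    //   threadgroup = (BATCH_SIZE, 1, 1)
    kernel void ssm_conv_batched_sketch(
            constant int   & ne1 [[buffer(0)]],  // total number of tokens (assumed binding)
            device   float * dst [[buffer(1)]],
            uint3 tgpig [[threadgroup_position_in_grid]],
            uint3 tpitg [[thread_position_in_threadgroup]]) {
        const int i1 = tgpig.x;                         // row (channel)
        const int i2 = tgpig.y * BATCH_SIZE + tpitg.x;  // token index, recomputed from tgpig.y and tpitg.x
        const int i3 = tgpig.z;                         // sequence
        if (i2 >= ne1) {
            return;                                     // guard the ragged final token batch
        }
        // ... per-token convolution work for (i1, i2, i3) would go here ...
    }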

Performance

./bin/llama-batched-bench -m ~/models/ibm-granite/granite-4.0-h-1b/ggml-model-Q8_0.gguf -c 131072 -b 2048 -ub 512 -npp 1024,4096,8192 -ntg 128 -npl 1,4,8 -ngl 99

Baseline (c8554b6)

|   PP |  TG | B |  N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s |    T s |   S t/s |
|------|-----|---|-------|--------|----------|--------|----------|--------|---------|
| 1024 | 128 | 1 |  1152 |  0.719 |  1423.52 |  1.407 |    91.00 |  2.126 |  541.87 |
| 1024 | 128 | 4 |  4608 |  2.877 |  1423.62 |  2.947 |   173.72 |  5.824 |  791.15 |
| 1024 | 128 | 8 |  9216 |  5.748 |  1425.13 |  5.269 |   194.33 | 11.018 |  836.48 |
| 4096 | 128 | 1 |  4224 |  2.887 |  1418.90 |  1.415 |    90.49 |  4.301 |  982.03 |
| 4096 | 128 | 4 | 16896 | 11.537 |  1420.09 |  2.990 |   171.24 | 14.527 | 1163.05 |
| 4096 | 128 | 8 | 33792 | 23.169 |  1414.28 |  6.213 |   164.80 | 29.383 | 1150.06 |
| 8192 | 128 | 1 |  8320 |  6.458 |  1268.54 |  1.437 |    89.11 |  7.894 | 1053.92 |
| 8192 | 128 | 4 | 33280 | 23.539 |  1392.09 |  3.236 |   158.24 | 26.774 | 1242.98 |
| 8192 | 128 | 8 | 66560 | 47.488 |  1380.06 |  5.973 |   171.43 | 53.461 | 1245.02 |

SSM_CONV improvements

|   PP |  TG | B |  N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s |    T s |   S t/s |
|------|-----|---|-------|--------|----------|--------|----------|--------|---------|
| 1024 | 128 | 1 |  1152 |  0.457 |  2240.26 |  1.407 |    91.00 |  1.864 |  618.12 |
| 1024 | 128 | 4 |  4608 |  1.821 |  2249.29 |  2.948 |   173.68 |  4.769 |  966.24 |
| 1024 | 128 | 8 |  9216 |  3.641 |  2250.21 |  5.265 |   194.47 |  8.906 | 1034.80 |
| 4096 | 128 | 1 |  4224 |  1.834 |  2233.20 |  1.410 |    90.76 |  3.244 | 1301.91 |
| 4096 | 128 | 4 | 16896 |  7.332 |  2234.56 |  2.997 |   170.86 | 10.329 | 1635.82 |
| 4096 | 128 | 8 | 33792 | 14.683 |  2231.72 |  5.347 |   191.51 | 20.030 | 1687.09 |
| 8192 | 128 | 1 |  8320 |  3.723 |  2200.12 |  1.425 |    89.81 |  5.149 | 1615.96 |
| 8192 | 128 | 4 | 33280 | 15.000 |  2184.60 |  3.074 |   166.54 | 18.074 | 1841.32 |
| 8192 | 128 | 8 | 66560 | 33.971 |  1929.15 |  6.071 |   168.67 | 40.043 | 1662.23 |

SSM_CONV + SSM_SCAN improvements

|   PP |  TG | B |  N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s |    T s |   S t/s |
|------|-----|---|-------|--------|----------|--------|----------|--------|---------|
| 1024 | 128 | 1 |  1152 |  0.437 |  2345.00 |  1.411 |    90.75 |  1.847 |  623.65 |
| 1024 | 128 | 4 |  4608 |  1.732 |  2364.53 |  2.977 |   171.98 |  4.709 |  978.47 |
| 1024 | 128 | 8 |  9216 |  3.487 |  2349.39 |  5.374 |   190.53 |  8.861 | 1040.03 |
| 4096 | 128 | 1 |  4224 |  1.753 |  2336.90 |  1.425 |    89.80 |  3.178 | 1329.05 |
| 4096 | 128 | 4 | 16896 |  7.007 |  2338.09 |  3.020 |   169.55 | 10.027 | 1685.03 |
| 4096 | 128 | 8 | 33792 | 14.042 |  2333.49 |  5.412 |   189.22 | 19.454 | 1737.01 |
| 8192 | 128 | 1 |  8320 |  3.572 |  2293.39 |  1.434 |    89.25 |  5.006 | 1661.94 |
| 8192 | 128 | 4 | 33280 | 14.758 |  2220.40 |  3.208 |   159.62 | 17.965 | 1852.45 |
| 8192 | 128 | 8 | 66560 | 33.036 |  1983.79 |  6.053 |   169.17 | 39.089 | 1702.79 |

@github-actions bot added the ggml (changes relating to the ggml tensor library for machine learning) and Apple Metal (https://en.wikipedia.org/wiki/Metal_(API)) labels Dec 9, 2025
@jeffbolznv
Collaborator

Mind adding some representative cases to test-backend-ops perf?

Comment on lines +1385 to +1404
} else {
    auto pipeline = ggml_metal_library_get_pipeline_ssm_conv(lib, op);

    ggml_metal_encoder_set_pipeline(enc, pipeline);
    ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);
    ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 1);
    ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[1]), 2);
    ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 3);

    ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne1, ne02, 1, 1, 1);
}
Member

Is the old kernel faster for ne1 == 1? If not, we can remove it and always use the batched kernel?

Collaborator Author

Good question, I'll test that today.

Collaborator Author

It looks like non-batched is significantly faster for ne1 == 1, so I think we should keep both paths.

Comment on lines 1368 to 1404
+ // Use batched kernel for prefill (ne1 > 1) to reduce threadgroup dispatch overhead
+ constexpr int BATCH_SIZE = 256;
+ const bool use_batched = (ne1 > 1);
+
- ggml_metal_encoder_set_pipeline(enc, pipeline);
- ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);
- ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 1);
- ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[1]), 2);
- ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 3);
-
- ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne1, ne02, 1, 1, 1);
+ if (use_batched) {
+     auto pipeline = ggml_metal_library_get_pipeline_ssm_conv_batched(lib, op);
+
+     ggml_metal_encoder_set_pipeline(enc, pipeline);
+     ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);
+     ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 1);
+     ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[1]), 2);
+     ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 3);
+
+     // Dispatch: ne01 rows, ceil(ne1/BATCH_SIZE) token batches, ne02 sequences
+     // Each threadgroup has BATCH_SIZE threads, each handling one token
+     const int n_token_batches = (ne1 + BATCH_SIZE - 1) / BATCH_SIZE;
+     ggml_metal_encoder_dispatch_threadgroups(enc, ne01, n_token_batches, ne02, BATCH_SIZE, 1, 1);
+ } else {
+     auto pipeline = ggml_metal_library_get_pipeline_ssm_conv(lib, op);
+
+     ggml_metal_encoder_set_pipeline(enc, pipeline);
+     ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);
+     ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 1);
+     ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[1]), 2);
+     ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 3);
+
+     ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne1, ne02, 1, 1, 1);
+ }
Member

Interesting - I am quite surprised that this change makes such a big difference. I have to try this approach for all other kernels that launch threadgroups with just 1 thread: unary, binary, scale, clamp, fill, etc.

Collaborator Author

Honestly, I was very surprised too. All credit to Claude Code with Opus 4.5 for the insight.


// Batched version: each threadgroup processes multiple tokens for better efficiency
// Thread layout: each thread handles one token, threadgroup covers BATCH_SIZE tokens
template<int BATCH_SIZE>
Member

This parameter does not have to be a template. The better pattern is to make it a function constant. For example, see how this works:

constant short FC_mul_mv_nsg   [[function_constant(FC_MUL_MV + 0)]];
constant short FC_mul_mv_nxpsg [[function_constant(FC_MUL_MV + 1)]];

template<typename block_q_type, short NR0, typename args_t>
void mul_vec_q_n_f32_impl(
        args_t args,
        device const char * src0,
        device const char * src1,
        device       char * dst,
        threadgroup  char * shmem,
        uint3  tgpig,
        ushort tiisg,
        ushort sgitg) {
    const short NSG = FC_mul_mv_nsg;

Then observe how we pass the function constants during the construction of the pipeline:

ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_ext(ggml_metal_library_t lib, ggml_type tsrc0, ggml_type tsrc1, int nsg, int nxpsg, int r1ptg) {
    char base[256];
    char name[256];

    snprintf(base, 256, "kernel_mul_mv_ext_%s_%s_r1_%d", ggml_type_name(tsrc0), ggml_type_name(tsrc1), r1ptg);
    snprintf(name, 256, "%s_nsg=%d_nxpsg=%d", base, nsg, nxpsg);

    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);

    if (!res.pipeline) {
        ggml_metal_cv_t cv = ggml_metal_cv_init();

        ggml_metal_cv_set_int16(cv, nsg,   FC_MUL_MV + 0);
        ggml_metal_cv_set_int16(cv, nxpsg, FC_MUL_MV + 1);

        res = ggml_metal_library_compile_pipeline(lib, base, name, cv);

        ggml_metal_cv_free(cv);
    }

    return res;
}

This way, we can construct pipelines at runtime with different FC_ssm_conv_bs, instead of defining many template instantiations that would increase the compile time.

When you do that, add logic in ggml_metal_op_ssm_conv to determine the smallest power of 2 that is greater than or equal to ne1 and less than or equal to 256. Use that power to construct a pipeline with the respective batch size.

For example, if ne1 == 100, we want a pipeline with FC_ssm_conv_bs == 128. And so on.
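
As a sketch of that selection logic on the host side (the helper name is hypothetical; the real code would feed the result into ggml_metal_cv_set_int16 when constructing the pipeline):

    // Smallest power of 2 that is >= ne1, clamped to at most 256.
    static int ssm_conv_batch_size(int ne1) {
        int bs = 1;
        while (bs < ne1 && bs < 256) {
            bs *= 2;
        }
        return bs;
    }
    // ssm_conv_batch_size(1)   ->   1
    // ssm_conv_batch_size(100) -> 128
    // ssm_conv_batch_size(512) -> 256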

@gabe-l-hart
Collaborator Author

Test results for test-backend-ops perf -o SSM_CONV -b Metal

@jeffbolznv I've honestly never quite known how to interpret the results of test-backend-ops perf. Intuitively, I would think a higher us/run would mean slower and a higher GB/s would mean more efficient throughput, but these results seem to show I have that exactly backwards.

@ggerganov Assuming my intuition is backwards, it looks like the batched implementation is indeed always faster. Would it be worth also adding a float4 variant of the batched implementation?

Test results for test-backend-ops perf:

Baseline (non-batch)

ggml_metal_library_compile_pipeline: compiling pipeline: base = 'kernel_ssm_conv_f32_f32', name = 'kernel_ssm_conv_f32_f32'
ggml_metal_library_compile_pipeline: loaded kernel_ssm_conv_f32_f32                       0x12010a0a0 | th_max = 1024 | th_width =   32
  SSM_CONV(type=f32,ne_a=[4,1024,1,1],ne_b=[3,1024,1,1]):             139264 runs -     7.19 us/run -       36 kB/run -    4.77 GB/s
  SSM_CONV(type=f32,ne_a=[8,1024,1,1],ne_b=[3,1024,1,1]):              65536 runs -    16.12 us/run -       68 kB/run -    4.02 GB/s
  SSM_CONV(type=f32,ne_a=[4,1024,4,1],ne_b=[3,1024,1,1]):              49152 runs -    20.51 us/run -      108 kB/run -    5.02 GB/s
  SSM_CONV(type=f32,ne_a=[4,1536,1,1],ne_b=[3,1536,1,1]):             114688 runs -     9.31 us/run -       54 kB/run -    5.53 GB/s
  SSM_CONV(type=f32,ne_a=[8,1536,1,1],ne_b=[3,1536,1,1]):              49152 runs -    22.91 us/run -      102 kB/run -    4.25 GB/s
  SSM_CONV(type=f32,ne_a=[4,1536,4,1],ne_b=[3,1536,1,1]):              40960 runs -    29.60 us/run -      162 kB/run -    5.22 GB/s
  SSM_CONV(type=f32,ne_a=[4,2048,1,1],ne_b=[3,2048,1,1]):              90112 runs -    11.61 us/run -       72 kB/run -    5.92 GB/s
  SSM_CONV(type=f32,ne_a=[8,2048,1,1],ne_b=[3,2048,1,1]):              40960 runs -    29.68 us/run -      136 kB/run -    4.37 GB/s
  SSM_CONV(type=f32,ne_a=[4,2048,4,1],ne_b=[3,2048,1,1]):              32768 runs -    38.45 us/run -      216 kB/run -    5.36 GB/s
ggml_metal_library_compile_pipeline: compiling pipeline: base = 'kernel_ssm_conv_f32_f32_4', name = 'kernel_ssm_conv_f32_f32_4'
ggml_metal_library_compile_pipeline: loaded kernel_ssm_conv_f32_f32_4                     0x11df051c0 | th_max = 1024 | th_width =   32
  SSM_CONV(type=f32,ne_a=[4,1024,1,1],ne_b=[4,1024,1,1]):             204800 runs -     5.04 us/run -       36 kB/run -    6.82 GB/s
  SSM_CONV(type=f32,ne_a=[8,1024,1,1],ne_b=[4,1024,1,1]):              73728 runs -    14.26 us/run -       68 kB/run -    4.55 GB/s
  SSM_CONV(type=f32,ne_a=[4,1024,4,1],ne_b=[4,1024,1,1]):              90112 runs -    11.88 us/run -       96 kB/run -    7.71 GB/s
  SSM_CONV(type=f32,ne_a=[4,1536,1,1],ne_b=[4,1536,1,1]):             163840 runs -     6.13 us/run -       54 kB/run -    8.40 GB/s
  SSM_CONV(type=f32,ne_a=[8,1536,1,1],ne_b=[4,1536,1,1]):              57344 runs -    19.96 us/run -      102 kB/run -    4.87 GB/s
  SSM_CONV(type=f32,ne_a=[4,1536,4,1],ne_b=[4,1536,1,1]):              65536 runs -    16.36 us/run -      144 kB/run -    8.39 GB/s
  SSM_CONV(type=f32,ne_a=[4,2048,1,1],ne_b=[4,2048,1,1]):             139264 runs -     7.38 us/run -       72 kB/run -    9.30 GB/s
  SSM_CONV(type=f32,ne_a=[8,2048,1,1],ne_b=[4,2048,1,1]):              40960 runs -    25.68 us/run -      136 kB/run -    5.05 GB/s
  SSM_CONV(type=f32,ne_a=[4,2048,4,1],ne_b=[4,2048,1,1]):              49152 runs -    20.84 us/run -      192 kB/run -    8.78 GB/s

Batched for n_t > 1

ggml_metal_library_compile_pipeline: compiling pipeline: base = 'kernel_ssm_conv_f32_f32_b256', name = 'kernel_ssm_conv_f32_f32_b256'
ggml_metal_library_compile_pipeline: loaded kernel_ssm_conv_f32_f32_b256                  0x107505090 | th_max = 1024 | th_width =   32
  SSM_CONV(type=f32,ne_a=[4,1024,1,1],ne_b=[3,1024,1,1]):              57344 runs -    19.91 us/run -       36 kB/run -    1.72 GB/s
  SSM_CONV(type=f32,ne_a=[8,1024,1,1],ne_b=[3,1024,1,1]):              57344 runs -    20.04 us/run -       68 kB/run -    3.24 GB/s
  SSM_CONV(type=f32,ne_a=[4,1024,4,1],ne_b=[3,1024,1,1]):              49152 runs -    21.20 us/run -      108 kB/run -    4.86 GB/s
  SSM_CONV(type=f32,ne_a=[4,1536,1,1],ne_b=[3,1536,1,1]):              40960 runs -    28.42 us/run -       54 kB/run -    1.81 GB/s
  SSM_CONV(type=f32,ne_a=[8,1536,1,1],ne_b=[3,1536,1,1]):              40960 runs -    28.68 us/run -      102 kB/run -    3.39 GB/s
  SSM_CONV(type=f32,ne_a=[4,1536,4,1],ne_b=[3,1536,1,1]):              32768 runs -    30.60 us/run -      162 kB/run -    5.05 GB/s
  SSM_CONV(type=f32,ne_a=[4,2048,1,1],ne_b=[3,2048,1,1]):              32768 runs -    37.49 us/run -       72 kB/run -    1.83 GB/s
  SSM_CONV(type=f32,ne_a=[8,2048,1,1],ne_b=[3,2048,1,1]):              32768 runs -    37.75 us/run -      136 kB/run -    3.44 GB/s
  SSM_CONV(type=f32,ne_a=[4,2048,4,1],ne_b=[3,2048,1,1]):              32768 runs -    40.05 us/run -      216 kB/run -    5.14 GB/s
ggml_metal_library_compile_pipeline: compiling pipeline: base = 'kernel_ssm_conv_f32_f32_4', name = 'kernel_ssm_conv_f32_f32_4'
ggml_metal_library_compile_pipeline: loaded kernel_ssm_conv_f32_f32_4                     0x301505560 | th_max = 1024 | th_width =   32
  SSM_CONV(type=f32,ne_a=[4,1024,1,1],ne_b=[4,1024,1,1]):             204800 runs -     5.03 us/run -       36 kB/run -    6.82 GB/s
  SSM_CONV(type=f32,ne_a=[8,1024,1,1],ne_b=[4,1024,1,1]):              57344 runs -    19.56 us/run -       68 kB/run -    3.31 GB/s
  SSM_CONV(type=f32,ne_a=[4,1024,4,1],ne_b=[4,1024,1,1]):              90112 runs -    11.93 us/run -       96 kB/run -    7.68 GB/s
  SSM_CONV(type=f32,ne_a=[4,1536,1,1],ne_b=[4,1536,1,1]):             163840 runs -     6.17 us/run -       54 kB/run -    8.35 GB/s
  SSM_CONV(type=f32,ne_a=[8,1536,1,1],ne_b=[4,1536,1,1]):              40960 runs -    28.45 us/run -      102 kB/run -    3.42 GB/s
  SSM_CONV(type=f32,ne_a=[4,1536,4,1],ne_b=[4,1536,1,1]):              65536 runs -    16.32 us/run -      144 kB/run -    8.41 GB/s
  SSM_CONV(type=f32,ne_a=[4,2048,1,1],ne_b=[4,2048,1,1]):             139264 runs -     7.36 us/run -       72 kB/run -    9.33 GB/s
  SSM_CONV(type=f32,ne_a=[8,2048,1,1],ne_b=[4,2048,1,1]):              32768 runs -    37.16 us/run -      136 kB/run -    3.49 GB/s
  SSM_CONV(type=f32,ne_a=[4,2048,4,1],ne_b=[4,2048,1,1]):              49152 runs -    20.85 us/run -      192 kB/run -    8.78 GB/s

Batched always

  SSM_CONV(type=f32,ne_a=[4,1024,1,1],ne_b=[3,1024,1,1]):              57344 runs -    19.91 us/run -       36 kB/run -    1.72 GB/s
  SSM_CONV(type=f32,ne_a=[8,1024,1,1],ne_b=[3,1024,1,1]):              57344 runs -    19.81 us/run -       68 kB/run -    3.27 GB/s
  SSM_CONV(type=f32,ne_a=[4,1024,4,1],ne_b=[3,1024,1,1]):              49152 runs -    21.16 us/run -      108 kB/run -    4.87 GB/s
  SSM_CONV(type=f32,ne_a=[4,1536,1,1],ne_b=[3,1536,1,1]):              40960 runs -    28.41 us/run -       54 kB/run -    1.81 GB/s
  SSM_CONV(type=f32,ne_a=[8,1536,1,1],ne_b=[3,1536,1,1]):              40960 runs -    28.69 us/run -      102 kB/run -    3.39 GB/s
  SSM_CONV(type=f32,ne_a=[4,1536,4,1],ne_b=[3,1536,1,1]):              32768 runs -    30.57 us/run -      162 kB/run -    5.05 GB/s
  SSM_CONV(type=f32,ne_a=[4,2048,1,1],ne_b=[3,2048,1,1]):              32768 runs -    37.20 us/run -       72 kB/run -    1.85 GB/s
  SSM_CONV(type=f32,ne_a=[8,2048,1,1],ne_b=[3,2048,1,1]):              32768 runs -    37.75 us/run -      136 kB/run -    3.44 GB/s
  SSM_CONV(type=f32,ne_a=[4,2048,4,1],ne_b=[3,2048,1,1]):              32768 runs -    39.89 us/run -      216 kB/run -    5.16 GB/s
  # NOTE: n_t == 1
  SSM_CONV(type=f32,ne_a=[4,1024,1,1],ne_b=[4,1024,1,1]):              57344 runs -    19.49 us/run -       36 kB/run -    1.76 GB/s
  SSM_CONV(type=f32,ne_a=[8,1024,1,1],ne_b=[4,1024,1,1]):              57344 runs -    19.60 us/run -       68 kB/run -    3.31 GB/s
  SSM_CONV(type=f32,ne_a=[4,1024,4,1],ne_b=[4,1024,1,1]):              49152 runs -    20.92 us/run -       96 kB/run -    4.38 GB/s
  # NOTE: n_t == 1
  SSM_CONV(type=f32,ne_a=[4,1536,1,1],ne_b=[4,1536,1,1]):              40960 runs -    28.35 us/run -       54 kB/run -    1.82 GB/s
  SSM_CONV(type=f32,ne_a=[8,1536,1,1],ne_b=[4,1536,1,1]):              40960 runs -    28.43 us/run -      102 kB/run -    3.42 GB/s
  SSM_CONV(type=f32,ne_a=[4,1536,4,1],ne_b=[4,1536,1,1]):              40960 runs -    29.87 us/run -      144 kB/run -    4.60 GB/s
  # NOTE: n_t == 1
  SSM_CONV(type=f32,ne_a=[4,2048,1,1],ne_b=[4,2048,1,1]):              32768 runs -    36.92 us/run -       72 kB/run -    1.86 GB/s
  SSM_CONV(type=f32,ne_a=[8,2048,1,1],ne_b=[4,2048,1,1]):              32768 runs -    37.20 us/run -      136 kB/run -    3.49 GB/s
  SSM_CONV(type=f32,ne_a=[4,2048,4,1],ne_b=[4,2048,1,1]):              32768 runs -    38.95 us/run -      192 kB/run -    4.70 GB/s

@jeffbolznv
Collaborator

You do want lower us/run, higher GB/s. The two values are the same data, just GB/s is computed by summing tensor sizes and dividing by the runtime.
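
As a concrete check against the first baseline SSM_CONV row above (and assuming the kB/GB figures are binary units): 36 kB per run in 7.19 us gives (36 x 1024 B) / 7.19e-6 s ≈ 5.13e9 B/s ≈ 4.77 GiB/s, matching the reported 4.77 GB/s.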

@ggerganov
Member

@gabe-l-hart These perf numbers don't add up to the observed llama-batched-bench results. Ideally, note down the actual shapes that are used during prompt processing and text generation and add tests that correspond to those shapes.

@gabe-l-hart
Collaborator Author

gabe-l-hart commented Dec 9, 2025

Ok, I'm glad I'm not crazy! This does seem very fishy given the improved results with pp. I'll try to make these tests more representative.

This was done using Claude Code. It found a number of optimizations around
how the threads were organized, resulting in a huge performance boost!

Branch: Mamba2SSD

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
This used Claude Code and resulted in a modest performance improvement
while maintaining correctness.

Branch: Mamba2SSD

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
Branch: SSMKernelImprovements

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
@gabe-l-hart
Collaborator Author

Oh boy, that makes much more sense!

    // Examples from granite-4.0-h-1b/ggml-model-Q8_0.gguf
    // d_inner == 3072
    // d_conv == 4
    test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {515, 3328, 1, 1}, {4, 3328, 1, 1}));

Batched

ggml_metal_library_compile_pipeline: compiling pipeline: base = 'kernel_ssm_conv_f32_f32_b256', name = 'kernel_ssm_conv_f32_f32_b256'
ggml_metal_library_compile_pipeline: loaded kernel_ssm_conv_f32_f32_b256                  0x136b083d0 | th_max = 1024 | th_width =   32
  SSM_CONV(type=f32,ne_a=[515,3328,1,1],ne_b=[4,3328,1,1]):                    15024 runs -    70.64 us/run -    13403 kB/run -  180.96 GB/s

Un-Batched

ggml_metal_library_compile_pipeline: compiling pipeline: base = 'kernel_ssm_conv_f32_f32_4', name = 'kernel_ssm_conv_f32_f32_4'
ggml_metal_library_compile_pipeline: loaded kernel_ssm_conv_f32_f32_4                     0x12320be70 | th_max = 1024 | th_width =   32
  SSM_CONV(type=f32,ne_a=[515,3328,1,1],ne_b=[4,3328,1,1]):                     2504 runs -  3801.15 us/run -    13403 kB/run -    3.36 GB/s

@gabe-l-hart
Collaborator Author

With a single-token (generate) example:

Batch for prefill only

ggml_metal_library_compile_pipeline: compiling pipeline: base = 'kernel_ssm_conv_f32_f32_b256', name = 'kernel_ssm_conv_f32_f32_b256'
ggml_metal_library_compile_pipeline: loaded kernel_ssm_conv_f32_f32_b256                  0x1292076d0 | th_max = 1024 | th_width =   32
  SSM_CONV(type=f32,ne_a=[515,3328,1,1],ne_b=[4,3328,1,1]):                    15024 runs -    70.69 us/run -    13403 kB/run -  180.82 GB/s
ggml_metal_library_compile_pipeline: compiling pipeline: base = 'kernel_ssm_conv_f32_f32_4', name = 'kernel_ssm_conv_f32_f32_4'
ggml_metal_library_compile_pipeline: loaded kernel_ssm_conv_f32_f32_4                     0x12910b550 | th_max = 1024 | th_width =   32
  SSM_CONV(type=f32,ne_a=[4,3328,1,1],ne_b=[4,3328,1,1]):              98304 runs -    10.22 us/run -      117 kB/run -   10.91 GB/s
  Backend Metal: OK

Batch for both

ggml_metal_library_compile_pipeline: compiling pipeline: base = 'kernel_ssm_conv_f32_f32_b256', name = 'kernel_ssm_conv_f32_f32_b256'
ggml_metal_library_compile_pipeline: loaded kernel_ssm_conv_f32_f32_b256                  0x157b070e0 | th_max = 1024 | th_width =   32
  SSM_CONV(type=f32,ne_a=[515,3328,1,1],ne_b=[4,3328,1,1]):                    15024 runs -    70.58 us/run -    13403 kB/run -  181.11 GB/s
  SSM_CONV(type=f32,ne_a=[4,3328,1,1],ne_b=[4,3328,1,1]):              24576 runs -    59.17 us/run -      117 kB/run -    1.89 GB/s
  Backend Metal: OK

Non-Batch for both

ggml_metal_library_compile_pipeline: compiling pipeline: base = 'kernel_ssm_conv_f32_f32_4', name = 'kernel_ssm_conv_f32_f32_4'
ggml_metal_library_compile_pipeline: loaded kernel_ssm_conv_f32_f32_4                     0x135e075e0 | th_max = 1024 | th_width =   32
  SSM_CONV(type=f32,ne_a=[515,3328,1,1],ne_b=[4,3328,1,1]):                     2504 runs -  3801.51 us/run -    13403 kB/run -    3.36 GB/s
  SSM_CONV(type=f32,ne_a=[4,3328,1,1],ne_b=[4,3328,1,1]):              98304 runs -    10.26 us/run -      117 kB/run -   10.88 GB/s
  Backend Metal: OK

Given this, I think we should keep both versions, dispatched on n_t as it is currently.

Branch: SSMKernelImprovements

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
Branch: SSMKernelImprovements

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
@gabe-l-hart force-pushed the SSMKernelImprovements branch from 427ae08 to da044cd on December 9, 2025 18:04
Branch: SSMKernelImprovements

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
@gabe-l-hart
Collaborator Author

Some similar numbers for representative tests on SSM_SCAN:

Without Optimizations

ggml_metal_library_compile_pipeline: compiling pipeline: base = 'kernel_ssm_scan_f32', name = 'kernel_ssm_scan_f32_nsg=4'
ggml_metal_library_compile_pipeline: loaded kernel_ssm_scan_f32_nsg=4                     0x149906a50 | th_max = 1024 | th_width =   32
  SSM_SCAN(type=f32,d_state=128,head_dim=64,n_head=48,n_group=1,n_seq_tokens=512,n_seqs=1):                     2102 runs -  2178.75 us/run -    15968 kB/run -    6.99 GB/s
  SSM_SCAN(type=f32,d_state=128,head_dim=64,n_head=48,n_group=1,n_seq_tokens=1,n_seqs=1):              40960 runs -    24.85 us/run -     3097 kB/run -  118.88 GB/s
  Backend Metal: OK

With Optimizations

ggml_metal_library_compile_pipeline: compiling pipeline: base = 'kernel_ssm_scan_f32', name = 'kernel_ssm_scan_f32_nsg=4'
ggml_metal_library_compile_pipeline: loaded kernel_ssm_scan_f32_nsg=4                     0x146408690 | th_max = 1024 | th_width =   32
  SSM_SCAN(type=f32,d_state=128,head_dim=64,n_head=48,n_group=1,n_seq_tokens=512,n_seqs=1):                     2102 runs -  1909.71 us/run -    15968 kB/run -    7.97 GB/s
  SSM_SCAN(type=f32,d_state=128,head_dim=64,n_head=48,n_group=1,n_seq_tokens=1,n_seqs=1):              40960 runs -    26.50 us/run -     3097 kB/run -  111.48 GB/s
  Backend Metal: OK

x[0] = sumf;
}

// typedef decltype(kernel_ssm_conv_f32_f32_batched<1>) kernel_ssm_conv_batched_t;
Collaborator Author

Remove this

Branch: SSMKernelImprovements

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
Branch: SSMKernelImprovements

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
@gabe-l-hart
Collaborator Author

Another small speedup with a float4 version of the SSM_CONV batched impl:

ggml_metal_library_compile_pipeline: compiling pipeline: base = 'kernel_ssm_conv_f32_f32_batched_4', name = 'kernel_ssm_conv_f32_f32_batched_4_ssm_conv_bs=256'
ggml_metal_library_compile_pipeline: loaded kernel_ssm_conv_f32_f32_batched_4_ssm_conv_bs=256      0x126304280 | th_max = 1024 | th_width =   32
  SSM_CONV(type=f32,ne_a=[515,3328,1,1],ne_b=[4,3328,1,1]):                    15024 runs -    68.79 us/run -    13403 kB/run -  185.80 GB/s
ggml_metal_library_compile_pipeline: compiling pipeline: base = 'kernel_ssm_conv_f32_f32_4', name = 'kernel_ssm_conv_f32_f32_4'
ggml_metal_library_compile_pipeline: loaded kernel_ssm_conv_f32_f32_4                     0x12610d700 | th_max = 1024 | th_width =   32
  SSM_CONV(type=f32,ne_a=[4,3328,1,1],ne_b=[4,3328,1,1]):              98304 runs -    10.33 us/run -      117 kB/run -   10.80 GB/s
  Backend Metal: OK

@gabe-l-hart
Collaborator Author

Hm, the failing test looks suspiciously related to this PR, but it's failing on CONV_2D which seems to pass just fine on my machine.

2025-12-09T18:44:28.7929820Z Failing tests:
2025-12-09T18:44:28.7930170Z   CONV_2D(ne_input=[1,1,25,2],ne_kernel=[2,1,25,1],type_kernel=f16,stride0=3,stride1=5,padding0=5,padding1=5,dilation0=2,dilation1=4,cwhn=0)
2025-12-09T18:44:28.7930420Z   Backend Metal: FAIL

Member

@ggerganov left a comment

The failing test seems like a fluke - should be safe to ignore

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
@ggerganov merged commit 086a63e into ggml-org:master Dec 9, 2025
62 of 69 checks passed
@gabe-l-hart deleted the SSMKernelImprovements branch December 9, 2025 19:30
@github-actions bot added the testing (Everything test related) label Dec 9, 2025