39 changes: 38 additions & 1 deletion ggml/src/ggml-metal/ggml-metal-device.cpp
@@ -411,6 +411,38 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv(ggml_me
return res;
}

ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv_batched(ggml_metal_library_t lib, const ggml_tensor * op, int ssm_conv_bs) {
GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32);
GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);

GGML_ASSERT(ggml_is_contiguous(op->src[0]));
GGML_ASSERT(ggml_is_contiguous(op->src[1]));

char base[256];
char name[256];

const char * suffix = "";
if (op->src[1]->ne[0] % 4 == 0) {
suffix = "_4";
}

snprintf(base, 256, "kernel_ssm_conv_%s_%s_batched%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->src[1]->type), suffix);
snprintf(name, 256, "%s_ssm_conv_bs=%d", base, ssm_conv_bs);

ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
if (!res.pipeline) {
ggml_metal_cv_t cv = ggml_metal_cv_init();

ggml_metal_cv_set_int16(cv, ssm_conv_bs, FC_SSM_CONV + 0);

res = ggml_metal_library_compile_pipeline(lib, base, name, cv);

ggml_metal_cv_free(cv);
}

return res;
}

ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_scan(ggml_metal_library_t lib, const ggml_tensor * op) {
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);

@@ -427,7 +459,12 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_scan(ggml_me
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
}

res.smem = 32*sizeof(float)*nsg;
// Shared memory layout:
// - sgptg * NW floats for partial sums (nsg * 32)
// - sgptg floats for shared_x_dt (nsg)
// - sgptg floats for shared_dA (nsg)
// Total: nsg * (32 + 2) floats
res.smem = (32 + 2)*sizeof(float)*nsg;

return res;
}
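The (32 + 2) floats-per-simdgroup budget above is worth spelling out. A minimal host-side sketch (not part of the patch; NW = 32 and nsg = 4 are assumed example values) of how the threadgroup memory decomposes:

#include <cstdio>

int main() {
    const int NW  = 32; // threads per simdgroup (N_SIMDWIDTH)
    const int nsg = 4;  // simdgroups per threadgroup (example value)

    const size_t partial_sums = size_t(NW) * nsg * sizeof(float); // reduction scratch (existing)
    const size_t shared_x_dt  = size_t(nsg) * sizeof(float);      // pre-computed x*dt, one slot per batched token
    const size_t shared_dA    = size_t(nsg) * sizeof(float);      // pre-computed softplus(dt), one slot per batched token

    // (32 + 2) * sizeof(float) * nsg = 34 * 4 * 4 = 544 bytes for this example
    printf("smem = %zu bytes\n", partial_sums + shared_x_dt + shared_dA);
    return 0;
}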
1 change: 1 addition & 0 deletions ggml/src/ggml-metal/ggml-metal-device.h
@@ -117,6 +117,7 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cumsum_ad
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tri (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_soft_max (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv_batched (ggml_metal_library_t lib, const struct ggml_tensor * op, int ssm_conv_bs);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_scan (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rwkv (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_ext (ggml_metal_library_t lib, enum ggml_type tsrc0, enum ggml_type tsrc1, int nsg, int nxpsg, int r1ptg);
1 change: 1 addition & 0 deletions ggml/src/ggml-metal/ggml-metal-impl.h
@@ -77,6 +77,7 @@
#define FC_MUL_MV 600
#define FC_MUL_MM 700
#define FC_ROPE 800
#define FC_SSM_CONV 900

// op-specific constants
#define OP_FLASH_ATTN_EXT_NQPTG 8
42 changes: 35 additions & 7 deletions ggml/src/ggml-metal/ggml-metal-ops.cpp
@@ -1365,15 +1365,43 @@ int ggml_metal_op_ssm_conv(ggml_metal_op_t ctx, int idx) {
/*.nb2 =*/ nb2,
};

auto pipeline = ggml_metal_library_get_pipeline_ssm_conv(lib, op);
// Use batched kernel for prefill (ne1 > 1) to reduce threadgroup dispatch overhead
const bool use_batched = (ne1 > 1);

if (use_batched) {
// Determine the smallest power of 2 that's >= ne1, but <= 256
int BATCH_SIZE;
if (ne1 > 128) BATCH_SIZE = 256;
else if (ne1 > 64 ) BATCH_SIZE = 128;
else if (ne1 > 32 ) BATCH_SIZE = 64;
else if (ne1 > 16 ) BATCH_SIZE = 32;
else if (ne1 > 8 ) BATCH_SIZE = 16;
else if (ne1 > 4 ) BATCH_SIZE = 8;
else BATCH_SIZE = 2;

auto pipeline = ggml_metal_library_get_pipeline_ssm_conv_batched(lib, op, BATCH_SIZE);

ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 1);
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[1]), 2);
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 3);
ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 1);
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[1]), 2);
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 3);

ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne1, ne02, 1, 1, 1);
// Dispatch: ne01 rows, ceil(ne1/BATCH_SIZE) token batches, ne02 sequences
// Each threadgroup has BATCH_SIZE threads, each handling one token
const int n_token_batches = (ne1 + BATCH_SIZE - 1) / BATCH_SIZE;
ggml_metal_encoder_dispatch_threadgroups(enc, ne01, n_token_batches, ne02, BATCH_SIZE, 1, 1);
} else {
auto pipeline = ggml_metal_library_get_pipeline_ssm_conv(lib, op);

ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 1);
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[1]), 2);
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 3);

ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne1, ne02, 1, 1, 1);
}
Comment on lines +1394 to +1404
Member
Is the old kernel faster for ne1 == 1? If not, we can remove it and always use the batched kernel?

Collaborator Author
Good question, I'll test that today.

Collaborator Author
It looks like non-batched is significantly faster for ne1 == 1, so I think we should keep both paths.


return 1;
}
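To make the dispatch arithmetic above concrete, here is a stand-alone sketch (not part of the patch) that mirrors the BATCH_SIZE ladder and the ceil-division used for the threadgroup count; the ne1 = 515 case corresponds to the prefill shape added to test-backend-ops below.

#include <cstdio>

// Mirrors the power-of-two batch-size selection in ggml_metal_op_ssm_conv.
static int pick_batch_size(int ne1) {
    if (ne1 > 128) return 256;
    if (ne1 >  64) return 128;
    if (ne1 >  32) return  64;
    if (ne1 >  16) return  32;
    if (ne1 >   8) return  16;
    if (ne1 >   4) return   8;
    return 2;
}

int main() {
    for (int ne1 : {2, 5, 33, 515}) {
        const int bs              = pick_batch_size(ne1);
        const int n_token_batches = (ne1 + bs - 1) / bs; // ceil(ne1 / bs)
        // e.g. ne1 = 515 -> BATCH_SIZE = 256 and 3 token batches of 256 threads each
        printf("ne1 = %3d -> BATCH_SIZE = %3d, token batches = %d\n", ne1, bs, n_token_batches);
    }
    return 0;
}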
136 changes: 127 additions & 9 deletions ggml/src/ggml-metal/ggml-metal.metal
@@ -2343,7 +2343,102 @@ kernel void kernel_ssm_conv_f32_f32_4(
x[0] = sumf;
}

constant short FC_ssm_conv_bs [[function_constant(FC_SSM_CONV + 0)]];

// Batched version: each threadgroup processes multiple tokens for better efficiency
// Thread layout: each thread handles one token, threadgroup covers BATCH_SIZE tokens
kernel void kernel_ssm_conv_f32_f32_batched(
constant ggml_metal_kargs_ssm_conv & args,
device const void * src0,
device const void * src1,
device float * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
// tgpig.x = row index (ir)
// tgpig.y = batch of tokens (i2_base / BATCH_SIZE)
// tgpig.z = sequence index (i3)
// tpitg.x = thread within batch (0..BATCH_SIZE-1)
const short BATCH_SIZE = FC_ssm_conv_bs;

const int64_t ir = tgpig.x;
const int64_t i2_base = tgpig.y * BATCH_SIZE;
const int64_t i3 = tgpig.z;
const int64_t i2_off = tpitg.x;
const int64_t i2 = i2_base + i2_off;

const int64_t nc = args.ne10; // conv kernel size (typically 4)
const int64_t n_t = args.ne1; // number of tokens

// Bounds check for partial batches at the end
if (i2 >= n_t) {
return;
}

// Load conv weights (shared across all tokens for this row)
device const float * c = (device const float *) ((device const char *) src1 + ir*args.nb11);

// Load source for this specific token
device const float * s = (device const float *) ((device const char *) src0 + ir*args.nb01 + i2*args.nb00 + i3*args.nb02);

// Output location for this token
device float * x = (device float *) ((device char *) dst + ir*args.nb0 + i2*args.nb1 + i3*args.nb2);

float sumf = 0.0f;
for (int64_t i0 = 0; i0 < nc; ++i0) {
sumf += s[i0] * c[i0];
}

x[0] = sumf;
}

kernel void kernel_ssm_conv_f32_f32_batched_4(
constant ggml_metal_kargs_ssm_conv & args,
device const void * src0,
device const void * src1,
device float * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
// tgpig.x = row index (ir)
// tgpig.y = batch of tokens (i2_base / BATCH_SIZE)
// tgpig.z = sequence index (i3)
// tpitg.x = thread within batch (0..BATCH_SIZE-1)
const short BATCH_SIZE = FC_ssm_conv_bs;

const int64_t ir = tgpig.x;
const int64_t i2_base = tgpig.y * BATCH_SIZE;
const int64_t i3 = tgpig.z;
const int64_t i2_off = tpitg.x;
const int64_t i2 = i2_base + i2_off;

const int64_t nc = args.ne10; // conv kernel size (typically 4)
const int64_t n_t = args.ne1; // number of tokens

// Bounds check for partial batches at the end
if (i2 >= n_t) {
return;
}

// Load conv weights (shared across all tokens for this row)
device const float4 * c = (device const float4 *) ((device const char *) src1 + ir*args.nb11);

// Load source for this specific token
device const float4 * s = (device const float4 *) ((device const char *) src0 + ir*args.nb01 + i2*args.nb00 + i3*args.nb02);

// Output location for this token
device float * x = (device float *) ((device char *) dst + ir*args.nb0 + i2*args.nb1 + i3*args.nb2);

float sumf = 0.0f;
for (int64_t i0 = 0; i0 < nc/4; ++i0) {
sumf += dot(s[i0], c[i0]);
}

x[0] = sumf;
}

// ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-2 part
// Optimized version: reduces redundant memory loads by having one thread load shared values
kernel void kernel_ssm_scan_f32(
constant ggml_metal_kargs_ssm_scan & args,
device const void * src0,
@@ -2363,7 +2458,15 @@ kernel void kernel_ssm_scan_f32(
uint3 tgpg[[threadgroups_per_grid]]) {
constexpr short NW = N_SIMDWIDTH;

shared[tpitg.x] = 0.0f;
// Shared memory layout:
// [0..sgptg*NW-1]: partial sums for reduction (existing)
// [sgptg*NW..sgptg*NW+sgptg-1]: pre-computed x_dt values for each token in batch
// [sgptg*NW+sgptg..sgptg*NW+2*sgptg-1]: pre-computed dA values for each token in batch
threadgroup float * shared_sums = shared;
threadgroup float * shared_x_dt = shared + sgptg * NW;
threadgroup float * shared_dA = shared + sgptg * NW + sgptg;

shared_sums[tpitg.x] = 0.0f;

const int32_t i0 = tpitg.x;
const int32_t i1 = tgpig.x;
@@ -2403,32 +2506,47 @@
for (int i2 = 0; i2 < n_t; i2 += sgptg) {
threadgroup_barrier(mem_flags::mem_threadgroup);

for (int t = 0; t < sgptg && i2 + t < n_t; t++) {
const float dt0 = dt[0];
// Pre-compute x_dt and dA for this batch of tokens
// Only first sgptg threads do the loads and expensive math
if (i0 < sgptg && i2 + i0 < n_t) {
// ns12 and ns21 are element strides (nb12/nb10, nb21/nb20)
device const float * x_t = x + i0 * args.ns12;
device const float * dt_t = dt + i0 * args.ns21;

const float dt0 = dt_t[0];
const float dtsp = dt0 <= 20.0f ? log(1.0f + exp(dt0)) : dt0;
const float x_dt = x[0] * dtsp;
const float dA = exp(dtsp * A0);
shared_x_dt[i0] = x_t[0] * dtsp;
shared_dA[i0] = dtsp; // Store dtsp, compute exp(dtsp * A0) per-thread since A0 varies
}

threadgroup_barrier(mem_flags::mem_threadgroup);

for (int t = 0; t < sgptg && i2 + t < n_t; t++) {
const float x_dt = shared_x_dt[t];
const float dA = exp(shared_dA[t] * A0);

s = (s0 * dA) + (B[i0] * x_dt);

const float sumf = simd_sum(s * C[i0]);

if (tiisg == 0) {
shared[t*NW + sgitg] = sumf;
shared_sums[t*NW + sgitg] = sumf;
}

// recurse
s0 = s;

x += args.ns12;
dt += args.ns21;
B += args.ns42;
C += args.ns52;
}

// Advance pointers for next batch
x += sgptg * args.ns12;
dt += sgptg * args.ns21;

threadgroup_barrier(mem_flags::mem_threadgroup);

const float sumf = simd_sum(shared[sgitg*NW + tiisg]);
const float sumf = simd_sum(shared_sums[sgitg*NW + tiisg]);

if (tiisg == 0 && i2 + sgitg < n_t) {
y[sgitg*nh*nr] = sumf;
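For orientation, each thread of the batched conv kernels above produces the same scalar as the original per-token kernel: a length-nc dot product of the token's input window with the row's conv weights. A minimal CPU sketch (not part of the patch; the flat row of length nc + n_t - 1 is an illustrative assumption matching the strides used by the kernel):

#include <vector>

// One (row, token) output of kernel_ssm_conv_f32_f32_batched, in plain C++:
// token i2 reads input elements [i2, i2 + nc) of the padded row.
static float ssm_conv_one_token(const std::vector<float> & s_row, // padded input row, length nc + n_t - 1
                                const std::vector<float> & c,     // conv weights for this row, length nc
                                int i2, int nc) {
    float sumf = 0.0f;
    for (int i0 = 0; i0 < nc; ++i0) {
        sumf += s_row[i2 + i0] * c[i0];
    }
    return sumf;
}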
7 changes: 7 additions & 0 deletions tests/test-backend-ops.cpp
@@ -8164,6 +8164,13 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
}
}

// Examples from granite-4.0-h-1b/ggml-model-Q8_0.gguf
test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {515, 3328, 1, 1}, {4, 3328, 1, 1})); // prefill
test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {4, 3328, 1, 1}, {4, 3328, 1, 1})); // generate
test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 128, 64, 48, 1, 512, 1)); // prefill
test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 128, 64, 48, 1, 1, 1)); // generate


return test_cases;
}
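If helpful when reviewing, the new shapes can be exercised in isolation with the existing test-backend-ops filters, e.g. test-backend-ops perf -o SSM_CONV -b Metal (and likewise -o SSM_SCAN); the op and backend filter names here are assumed from the current CLI, not introduced by this change.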
