Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 14 additions & 8 deletions ggml/src/ggml-metal/ggml-metal-device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,13 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_conv(ggml_metal_librar
char base[256];
char name[256];

snprintf(base, 256, "kernel_ssm_conv_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->src[1]->type));
const char * suffix = "";

if (op->src[1]->ne[0] % 4 == 0) {
suffix = "_4";
}

snprintf(base, 256, "kernel_ssm_conv_%s_%s%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->src[1]->type), suffix);
snprintf(name, 256, "%s", base);

ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
Expand All @@ -352,15 +358,15 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_conv(ggml_metal_librar
}

ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_scan(ggml_metal_library_t lib, const ggml_tensor * op) {
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);

char base[256];
char name[256];

if (op->src[3]->ne[0] == 1) {
snprintf(base, 256, "kernel_ssm_scan_group_%s", ggml_type_name(op->src[0]->type));
} else {
snprintf(base, 256, "kernel_ssm_scan_%s", ggml_type_name(op->src[0]->type));
}
snprintf(name, 256, "%s", base);
const int nsg = (ne00 + 31)/32;

snprintf(base, 256, "kernel_ssm_scan_%s", ggml_type_name(op->src[0]->type));
snprintf(name, 256, "%s_nsg=%d", base, nsg);

ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
if (res) {
Expand All @@ -369,7 +375,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_scan(ggml_metal_librar

res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);

ggml_metal_pipeline_set_smem(res, 32*sizeof(float));
ggml_metal_pipeline_set_smem(res, 32*sizeof(float)*nsg);

return res;
}
Expand Down
4 changes: 1 addition & 3 deletions ggml/src/ggml-metal/ggml-metal-device.m
Original file line number Diff line number Diff line change
Expand Up @@ -776,9 +776,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
};
}
case GGML_OP_GET_ROWS:
{
return op->ne[3] == 1;
}
return true;
case GGML_OP_SET_ROWS:
{
if (op->src[0]->type != GGML_TYPE_F32) {
Expand Down
18 changes: 16 additions & 2 deletions ggml/src/ggml-metal/ggml-metal-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,7 @@ typedef struct {
} ggml_metal_kargs_clamp;

typedef struct {
int64_t nk0;
int64_t ne00;
int64_t ne01;
int64_t ne02;
Expand Down Expand Up @@ -572,32 +573,45 @@ typedef struct {
int64_t n_seq_tokens;
int64_t n_seqs;
uint64_t s_off;
uint64_t nb00;
uint64_t nb01;
uint64_t nb02;
uint64_t nb03;
uint64_t nb10;
uint64_t nb11;
uint64_t nb12;
uint64_t ns12;
uint64_t nb13;
uint64_t nb20;
uint64_t nb21;
uint64_t ns21;
uint64_t nb22;
int64_t ne30;
uint64_t nb31;
uint64_t nb41;
uint64_t nb42;
uint64_t ns42;
uint64_t nb43;
uint64_t nb51;
uint64_t nb52;
uint64_t ns52;
uint64_t nb53;
uint64_t nb0;
} ggml_metal_kargs_ssm_scan;

typedef struct {
int64_t ne00;
int32_t ne00t;
int32_t ne00;
uint64_t nb01;
uint64_t nb02;
int64_t ne10;
uint64_t nb03;
int32_t ne10;
uint64_t nb10;
uint64_t nb11;
uint64_t nb12;
uint64_t nb1;
uint64_t nb2;
uint64_t nb3;
} ggml_metal_kargs_get_rows;

typedef struct {
Expand Down
78 changes: 46 additions & 32 deletions ggml/src/ggml-metal/ggml-metal-ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -577,6 +577,7 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) {
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_cpy(lib, op->src[0]->type, op->type);

ggml_metal_kargs_cpy args = {
/*.nk0 =*/ ne00,
/*.ne00 =*/ ne00,
/*.ne01 =*/ ne01,
/*.ne02 =*/ ne02,
Expand Down Expand Up @@ -906,23 +907,31 @@ int ggml_metal_op_get_rows(ggml_metal_op_t ctx, int idx) {
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_get_rows(lib, op->src[0]->type);

ggml_metal_kargs_get_rows args = {
/*.ne00 =*/ ne00,
/*.nb01 =*/ nb01,
/*.nb02 =*/ nb02,
/*.ne10 =*/ ne10,
/*.nb10 =*/ nb10,
/*.nb11 =*/ nb11,
/*.nb1 =*/ nb1,
/*.nb2 =*/ nb2,
/*.ne00t =*/ ggml_is_quantized(op->src[0]->type) ? ne00/16 : ne00,
/*.ne00 =*/ ne00,
/*.nb01 =*/ nb01,
/*.nb02 =*/ nb02,
/*.nb03 =*/ nb03,
/*.ne10 =*/ ne10,
/*.nb10 =*/ nb10,
/*.nb11 =*/ nb11,
/*.nb12 =*/ nb12,
/*.nb1 =*/ nb1,
/*.nb2 =*/ nb2,
/*.nb3 =*/ nb3,
};

const int nth = std::min(args.ne00t, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));

const int nw0 = (args.ne00t + nth - 1)/nth;

ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3);

ggml_metal_encoder_dispatch_threadgroups(enc, ne10, ne11, ne12, 32, 1, 1);
ggml_metal_encoder_dispatch_threadgroups(enc, nw0*ne10, ne11, ne12, nth, 1, 1);

return 1;
}
Expand Down Expand Up @@ -1117,7 +1126,7 @@ int ggml_metal_op_ssm_conv(ggml_metal_op_t ctx, int idx) {
ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 1);
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[1]), 2);
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 3);
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 3);

ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne1, ne02, 1, 1, 1);

Expand Down Expand Up @@ -1172,25 +1181,36 @@ int ggml_metal_op_ssm_scan(ggml_metal_op_t ctx, int idx) {
/*.n_seq_tokens =*/ n_seq_tokens,
/*.n_seqs =*/ n_seqs,
/*.s_off =*/ ggml_nelements(op->src[1]) * sizeof(float),
/*.nb00 =*/ nb00,
/*.nb01 =*/ nb01,
/*.nb02 =*/ nb02,
/*.nb03 =*/ nb03,
/*.nb10 =*/ nb10,
/*.nb11 =*/ nb11,
/*.nb12 =*/ nb12,
/*.ns12 =*/ nb12/nb10,
/*.nb13 =*/ nb13,
/*.nb20 =*/ nb20,
/*.nb21 =*/ nb21,
/*.ns21 =*/ nb21/nb20,
/*.nb22 =*/ nb22,
/*.ne30 =*/ ne30,
/*.nb31 =*/ nb31,
/*.nb41 =*/ nb41,
/*.nb42 =*/ nb42,
/*.ns42 =*/ nb42/nb40,
/*.nb43 =*/ nb43,
/*.nb51 =*/ nb51,
/*.nb52 =*/ nb52,
/*.ns52 =*/ nb52/nb50,
/*.nb53 =*/ nb53,
/*.nb0 =*/ nb0,
};

ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_ssm_scan(lib, op);

GGML_ASSERT(d_state <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));

const size_t sms = ggml_metal_pipeline_get_smem(pipeline);

ggml_metal_encoder_set_pipeline(enc, pipeline);
Expand All @@ -1206,13 +1226,7 @@ int ggml_metal_op_ssm_scan(ggml_metal_op_t ctx, int idx) {

ggml_metal_encoder_set_threadgroup_memory_size(enc, sms, 0);

if (ne30 == 1) {
// Mamba-2
ggml_metal_encoder_dispatch_threadgroups(enc, d_inner, n_head, n_seqs, d_state, 1, 1);
} else {
GGML_ASSERT(d_inner == 1);
ggml_metal_encoder_dispatch_threadgroups(enc, n_head, n_seqs, 1, d_state, 1, 1);
}
ggml_metal_encoder_dispatch_threadgroups(enc, d_inner, n_head, n_seqs, d_state, 1, 1);

return 1;
}
Expand Down Expand Up @@ -1273,37 +1287,35 @@ int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) {

GGML_ASSERT(ne00 % ggml_blck_size(op->src[0]->type) == 0);

// TODO: support
//const int32_t nk00 = ne00/ggml_blck_size(op->type);
const int32_t nk00 = ne00;

int nth = 32; // SIMD width

while (nth < nk00 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
nth *= 2;
int64_t nk0 = ne00;
if (ggml_is_quantized(op->src[0]->type)) {
nk0 = ne00/16;
} else if (ggml_is_quantized(op->type)) {
nk0 = ne00/ggml_blck_size(op->type);
}

nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
int nth = std::min<int>(nk0, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));

// when rows are small, we can batch them together in a single threadgroup
int nrptg = 1;

// TODO: relax this constraint in the future
if (ggml_blck_size(op->src[0]->type) == 1 && ggml_blck_size(op->type) == 1) {
if (nth > nk00) {
nrptg = (nth + nk00 - 1)/nk00;
nth = nk00;
if (nth > nk0) {
nrptg = (nth + nk0 - 1)/nk0;
nth = nk0;

if (nrptg*nth > ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
nrptg--;
}
}
}

nth = std::min(nth, nk00);
nth = std::min<int>(nth, nk0);

ggml_metal_kargs_cpy args = {
/*.ne00 =*/ nk00,
/*.nk0 =*/ nk0,
/*.ne00 =*/ ne00,
/*.ne01 =*/ ne01,
/*.ne02 =*/ ne02,
/*.ne03 =*/ ne03,
Expand All @@ -1321,12 +1333,14 @@ int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) {
/*.nb3 =*/ nb3,
};

const int nw0 = nrptg == 1 ? (nk0 + nth - 1)/nth : 1;

ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);

ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, nrptg, 1);
ggml_metal_encoder_dispatch_threadgroups(enc, nw0*(ne01 + nrptg - 1)/nrptg, ne02, ne03, nth, nrptg, 1);

return 1;
}
Expand Down
Loading
Loading