Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 17 additions & 16 deletions ggml/src/ggml-metal/ggml-metal-device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -414,19 +414,26 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rwkv(ggml_metal_library_t
return res;
}

ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_ext(ggml_metal_library_t lib, ggml_type tsrc0, ggml_type tsrc1, int r1ptg) {
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_ext(ggml_metal_library_t lib, ggml_type tsrc0, ggml_type tsrc1, int nsg, int nxpsg, int r1ptg) {
char base[256];
char name[256];

snprintf(base, 256, "kernel_mul_mv_ext_%s_%s_r1_%d", ggml_type_name(tsrc0), ggml_type_name(tsrc1), r1ptg);
snprintf(name, 256, "%s", base);
snprintf(name, 256, "%s_nsg=%d_nxpsg=%d", base, nsg, nxpsg);

ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
if (res) {
return res;
}

res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
ggml_metal_cv_t cv = ggml_metal_cv_init();

ggml_metal_cv_set_int16(cv, nsg, FC_MUL_MV + 0);
ggml_metal_cv_set_int16(cv, nxpsg, FC_MUL_MV + 1);

res = ggml_metal_library_compile_pipeline(lib, base, name, cv);

ggml_metal_cv_free(cv);

return res;
}
Expand Down Expand Up @@ -608,7 +615,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv(ggml_metal_library_
};

snprintf(base, 256, "kernel_mul_mv_%s_%s%s", ggml_type_name(tsrc0), ggml_type_name(tsrc1), suffix);
snprintf(name, 256, "%s", base);
snprintf(name, 256, "%s_nsg=%d", base, nsg);

ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
if (res) {
Expand Down Expand Up @@ -824,7 +831,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_id(ggml_metal_libra
};

snprintf(base, 256, "kernel_mul_mv_id_%s_%s%s", ggml_type_name(tsrc0), ggml_type_name(tsrc1), suffix);
snprintf(name, 256, "%s", base);
snprintf(name, 256, "%s_nsg=%d", base, nsg);

ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
if (res) {
Expand Down Expand Up @@ -923,11 +930,8 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext(
dk,
dv);

snprintf(name, 256, "kernel_%s_%s_dk%d_dv%d_mask=%d_sinks=%d_bias=%d_scap=%d_ns10=%d_ns20=%d_nsg=%d",
"flash_attn_ext",
ggml_type_name(op->src[1]->type),
dk,
dv,
snprintf(name, 256, "%s_mask=%d_sinks=%d_bias=%d_scap=%d_ns10=%d_ns20=%d_nsg=%d",
base,
has_mask,
has_sinks,
has_bias,
Expand Down Expand Up @@ -985,11 +989,8 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec(
dk,
dv);

snprintf(name, 256, "kernel_%s_%s_dk%d_dv%d_mask=%d_sink=%d_bias=%d_softcap=%d_ns10=%d_ns20=%d_nsg=%d_nwg=%d",
"flash_attn_ext_vec",
ggml_type_name(op->src[1]->type),
dk,
dv,
snprintf(name, 256, "%s_mask=%d_sink=%d_bias=%d_softcap=%d_ns10=%d_ns20=%d_nsg=%d_nwg=%d",
base,
has_mask,
has_sinks,
has_bias,
Expand Down Expand Up @@ -1033,7 +1034,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec_reduce(
char name[256];

snprintf(base, 256, "kernel_flash_attn_ext_vec_reduce");
snprintf(name, 256, "kernel_flash_attn_ext_vec_reduce_dv=%d_nwg=%d", dv, nwg);
snprintf(name, 256, "%s_dv=%d_nwg=%d", base, dv, nwg);

ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
if (res) {
Expand Down
2 changes: 1 addition & 1 deletion ggml/src/ggml-metal/ggml-metal-device.h
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_soft_max (ggml_me
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_conv (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_scan (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rwkv (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_ext (ggml_metal_library_t lib, enum ggml_type tsrc0, enum ggml_type tsrc1, int r1ptg);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_ext (ggml_metal_library_t lib, enum ggml_type tsrc0, enum ggml_type tsrc1, int nsg, int nxpsg, int r1ptg);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm (ggml_metal_library_t lib, enum ggml_type tsrc0, enum ggml_type tsrc1);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm_id_map0 (ggml_metal_library_t lib, int ne02, int ne20);
Expand Down
7 changes: 2 additions & 5 deletions ggml/src/ggml-metal/ggml-metal-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,13 @@
#define N_R0_Q3_K 2
#define N_SG_Q3_K 2

#define N_R0_Q4_K 4
#define N_R0_Q4_K 2
#define N_SG_Q4_K 2

#define N_R0_Q5_K 2
#define N_SG_Q5_K 2

#define N_R0_Q6_K 1
#define N_R0_Q6_K 2
#define N_SG_Q6_K 2

#define N_R0_IQ1_S 4
Expand Down Expand Up @@ -374,9 +374,6 @@ typedef struct {
int32_t ne1;
int16_t r2;
int16_t r3;
int16_t nsg;
int16_t nxpsg;
int16_t r1ptg;
} ggml_metal_kargs_mul_mv_ext;

typedef struct {
Expand Down
5 changes: 1 addition & 4 deletions ggml/src/ggml-metal/ggml-metal-ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1444,7 +1444,7 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
GGML_ABORT("unsupported ne11");
};

ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mv_ext(lib, op->src[0]->type, op->src[1]->type, r1ptg);
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mv_ext(lib, op->src[0]->type, op->src[1]->type, nsg, nxpsg, r1ptg);

ggml_metal_kargs_mul_mv_ext args = {
/*.ne00 =*/ ne00,
Expand All @@ -1465,9 +1465,6 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
/*.ne1 =*/ ne1,
/*.r2 =*/ r2,
/*.r3 =*/ r3,
/*.nsg =*/ nsg,
/*.nxpsg =*/ nxpsg,
/*.r1ptg =*/ r1ptg,
};

ggml_metal_encoder_set_pipeline(enc, pipeline);
Expand Down
Loading
Loading