diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp index c647baef87..33ab43d58f 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.cpp +++ b/ggml/src/ggml-metal/ggml-metal-device.cpp @@ -57,7 +57,7 @@ ggml_metal_pipeline_t ggml_metal_pipelines_get(ggml_metal_pipelines_t ppls, cons return ppls->data[name]; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_base(ggml_metal_library_t lib, ggml_op op) { +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_base(ggml_metal_library_t lib, ggml_op op) { char base[256]; char name[256]; @@ -71,34 +71,30 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_base(ggml_metal_library_t snprintf(base, 256, "kernel_%s", op_str); snprintf(name, 256, "%s", base); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); - return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_cpy(ggml_metal_library_t lib, ggml_type tsrc, ggml_type tdst) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cpy(ggml_metal_library_t lib, ggml_type tsrc, ggml_type tdst) { char base[256]; char name[256]; snprintf(base, 256, "kernel_cpy_%s_%s", ggml_type_name(tsrc), ggml_type_name(tdst)); snprintf(name, 256, "%s", base); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); - return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pool_2d(ggml_metal_library_t lib, const ggml_tensor * op, ggml_op_pool op_pool) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pool_2d(ggml_metal_library_t lib, const ggml_tensor * op, ggml_op_pool op_pool) { GGML_ASSERT(ggml_is_contiguous(op->src[0])); GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32 && op->src[0]->type == op->type); @@ -115,68 +111,60 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pool_2d(ggml_metal_library snprintf(base, 256, "kernel_pool_2d_%s_%s", pool_str, ggml_type_name(op->src[0]->type)); snprintf(name, 256, "%s", base); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); - return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_get_rows(ggml_metal_library_t lib, ggml_type tsrc) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_get_rows(ggml_metal_library_t lib, ggml_type tsrc) { char base[256]; char name[256]; snprintf(base, 256, "kernel_get_rows_%s", ggml_type_name(tsrc)); snprintf(name, 256, "%s", base); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); - return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_set_rows(ggml_metal_library_t lib, ggml_type tidx, ggml_type tdst) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_set_rows(ggml_metal_library_t lib, ggml_type tidx, ggml_type tdst) { char base[256]; char name[256]; snprintf(base, 256, "kernel_set_rows_%s_%s", ggml_type_name(tdst), ggml_type_name(tidx)); snprintf(name, 256, "%s", base); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); - return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_repeat(ggml_metal_library_t lib, ggml_type tsrc) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_repeat(ggml_metal_library_t lib, ggml_type tsrc) { char base[256]; char name[256]; snprintf(base, 256, "kernel_repeat_%s", ggml_type_name(tsrc)); snprintf(name, 256, "%s", base); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); - return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_unary(ggml_metal_library_t lib, const ggml_tensor * op) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_unary(ggml_metal_library_t lib, const ggml_tensor * op) { GGML_ASSERT(ggml_is_contiguous(op->src[0])); char base[256]; @@ -224,17 +212,15 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_unary(ggml_metal_library_t snprintf(base, 256, "kernel_%s_%s%s", op_str, ggml_type_name(op->src[0]->type), suffix); snprintf(name, 256, "%s", base); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); - return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_glu(ggml_metal_library_t lib, const ggml_tensor * op) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_glu(ggml_metal_library_t lib, const ggml_tensor * op) { GGML_ASSERT(ggml_is_contiguous_1(op->src[0])); char base[256]; @@ -258,17 +244,15 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_glu(ggml_metal_library_t l snprintf(base, 256, "kernel_%s_%s", op_str, ggml_type_name(op->src[0]->type)); snprintf(name, 256, "%s", base); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); - return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum(ggml_metal_library_t lib, const ggml_tensor * op) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_sum(ggml_metal_library_t lib, const ggml_tensor * op) { assert(op->op == GGML_OP_SUM); char base[256]; @@ -277,17 +261,15 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum(ggml_metal_library_t l snprintf(base, 256, "kernel_op_sum_%s", ggml_type_name(op->src[0]->type)); snprintf(name, 256, "%s", base); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); - return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum_rows(ggml_metal_library_t lib, const ggml_tensor * op) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_sum_rows(ggml_metal_library_t lib, const ggml_tensor * op) { GGML_ASSERT(op->src[0]->nb[0] == ggml_type_size(op->src[0]->type)); char base[256]; @@ -306,19 +288,17 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum_rows(ggml_metal_librar snprintf(name, 256, "%s", base); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); - - ggml_metal_pipeline_set_smem(res, 32*sizeof(float)); + res.smem = 32*sizeof(float); return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_cumsum_blk(ggml_metal_library_t lib, const ggml_tensor * op) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cumsum_blk(ggml_metal_library_t lib, const ggml_tensor * op) { GGML_ASSERT(op->op == GGML_OP_CUMSUM); char base[256]; @@ -327,17 +307,15 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_cumsum_blk(ggml_metal_libr snprintf(base, 256, "kernel_cumsum_blk_%s", ggml_type_name(op->src[0]->type)); snprintf(name, 256, "%s", base); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); - return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_cumsum_add(ggml_metal_library_t lib, const ggml_tensor * op) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cumsum_add(ggml_metal_library_t lib, const ggml_tensor * op) { GGML_ASSERT(op->op == GGML_OP_CUMSUM); char base[256]; @@ -346,17 +324,15 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_cumsum_add(ggml_metal_libr snprintf(base, 256, "kernel_cumsum_add_%s", ggml_type_name(op->src[0]->type)); snprintf(name, 256, "%s", base); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); - return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_soft_max(ggml_metal_library_t lib, const ggml_tensor * op) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_soft_max(ggml_metal_library_t lib, const ggml_tensor * op) { GGML_ASSERT(!op->src[1] || op->src[1]->type == GGML_TYPE_F16 || op->src[1]->type == GGML_TYPE_F32); char base[256]; @@ -373,19 +349,17 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_soft_max(ggml_metal_librar snprintf(base, 256, "kernel_soft_max_%s%s", ggml_type_name(tsrc1), suffix); snprintf(name, 256, "%s", base); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); - - ggml_metal_pipeline_set_smem(res, 32*sizeof(float)); + res.smem = 32*sizeof(float); return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_conv(ggml_metal_library_t lib, const ggml_tensor * op) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv(ggml_metal_library_t lib, const ggml_tensor * op) { GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32); GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32); @@ -404,17 +378,15 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_conv(ggml_metal_librar snprintf(base, 256, "kernel_ssm_conv_%s_%s%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->src[1]->type), suffix); snprintf(name, 256, "%s", base); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); - return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_scan(ggml_metal_library_t lib, const ggml_tensor * op) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_scan(ggml_metal_library_t lib, const ggml_tensor * op) { GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); char base[256]; @@ -425,19 +397,17 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_scan(ggml_metal_librar snprintf(base, 256, "kernel_ssm_scan_%s", ggml_type_name(op->src[0]->type)); snprintf(name, 256, "%s_nsg=%d", base, nsg); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); - - ggml_metal_pipeline_set_smem(res, 32*sizeof(float)*nsg); + res.smem = 32*sizeof(float)*nsg; return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rwkv(ggml_metal_library_t lib, const ggml_tensor * op) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rwkv(ggml_metal_library_t lib, const ggml_tensor * op) { char base[256]; char name[256]; @@ -467,41 +437,37 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rwkv(ggml_metal_library_t snprintf(name, 256, "%s", base); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); - return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_ext(ggml_metal_library_t lib, ggml_type tsrc0, ggml_type tsrc1, int nsg, int nxpsg, int r1ptg) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_ext(ggml_metal_library_t lib, ggml_type tsrc0, ggml_type tsrc1, int nsg, int nxpsg, int r1ptg) { char base[256]; char name[256]; snprintf(base, 256, "kernel_mul_mv_ext_%s_%s_r1_%d", ggml_type_name(tsrc0), ggml_type_name(tsrc1), r1ptg); snprintf(name, 256, "%s_nsg=%d_nxpsg=%d", base, nsg, nxpsg); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; - } - - ggml_metal_cv_t cv = ggml_metal_cv_init(); + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + ggml_metal_cv_t cv = ggml_metal_cv_init(); - ggml_metal_cv_set_int16(cv, nsg, FC_MUL_MV + 0); - ggml_metal_cv_set_int16(cv, nxpsg, FC_MUL_MV + 1); + ggml_metal_cv_set_int16(cv, nsg, FC_MUL_MV + 0); + ggml_metal_cv_set_int16(cv, nxpsg, FC_MUL_MV + 1); - res = ggml_metal_library_compile_pipeline(lib, base, name, cv); + res = ggml_metal_library_compile_pipeline(lib, base, name, cv); - ggml_metal_cv_free(cv); + ggml_metal_cv_free(cv); + } return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm(ggml_metal_library_t lib, const ggml_tensor * op) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm(ggml_metal_library_t lib, const ggml_tensor * op) { char base[256]; char name[256]; @@ -514,27 +480,25 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm(ggml_metal_library_ snprintf(base, 256, "kernel_mul_mm_%s_%s", ggml_type_name(tsrc0), ggml_type_name(tsrc1)); snprintf(name, 256, "%s_bci=%d_bco=%d", base, bc_inp, bc_out); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; - } + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + ggml_metal_cv_t cv = ggml_metal_cv_init(); - ggml_metal_cv_t cv = ggml_metal_cv_init(); + ggml_metal_cv_set_bool(cv, bc_inp, FC_MUL_MM + 0); + ggml_metal_cv_set_bool(cv, bc_out, FC_MUL_MM + 1); - ggml_metal_cv_set_bool(cv, bc_inp, FC_MUL_MM + 0); - ggml_metal_cv_set_bool(cv, bc_out, FC_MUL_MM + 1); + res = ggml_metal_library_compile_pipeline(lib, base, name, cv); - res = ggml_metal_library_compile_pipeline(lib, base, name, cv); - - ggml_metal_cv_free(cv); + ggml_metal_cv_free(cv); + } // when the output size is not multiple of 64x32, we need extra smem to prevent out-of-bounds writes - ggml_metal_pipeline_set_smem(res, bc_out ? 8192 : 4096 + 2048); + res.smem = bc_out ? 8192 : 4096 + 2048; return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv(ggml_metal_library_t lib, const ggml_tensor * op) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv(ggml_metal_library_t lib, const ggml_tensor * op) { GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne); @@ -689,49 +653,43 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv(ggml_metal_library_ snprintf(base, 256, "kernel_mul_mv_%s_%s%s", ggml_type_name(tsrc0), ggml_type_name(tsrc1), suffix); snprintf(name, 256, "%s_nsg=%d", base, nsg); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; - } - - ggml_metal_cv_t cv = ggml_metal_cv_init(); + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + ggml_metal_cv_t cv = ggml_metal_cv_init(); - ggml_metal_cv_set_int16(cv, nsg, FC_MUL_MV + 0); + ggml_metal_cv_set_int16(cv, nsg, FC_MUL_MV + 0); - res = ggml_metal_library_compile_pipeline(lib, base, name, cv); + res = ggml_metal_library_compile_pipeline(lib, base, name, cv); - ggml_metal_cv_free(cv); + ggml_metal_cv_free(cv); + } - ggml_metal_pipeline_set_nr0 (res, nr0); - ggml_metal_pipeline_set_nr1 (res, nr1); - ggml_metal_pipeline_set_nsg (res, nsg); - ggml_metal_pipeline_set_smem(res, smem); + res.nr0 = nr0; + res.nr1 = nr1; + res.nsg = nsg; + res.smem = smem; return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm_id_map0(ggml_metal_library_t lib, int ne02, int ne20) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm_id_map0(ggml_metal_library_t lib, int ne02, int ne20) { char base[256]; char name[256]; snprintf(base, 256, "kernel_mul_mm_id_map0_ne20_%d", ne20); snprintf(name, 256, "%s_ne02=%d", base, ne02); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); - - const size_t smem = (size_t) ne02*ne20*sizeof(uint16_t); - - ggml_metal_pipeline_set_smem(res, smem); + res.smem = (size_t) ne02*ne20*sizeof(uint16_t); return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm_id(ggml_metal_library_t lib, const ggml_tensor * op) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm_id(ggml_metal_library_t lib, const ggml_tensor * op) { char base[256]; char name[256]; @@ -743,25 +701,23 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm_id(ggml_metal_libra snprintf(base, 256, "kernel_mul_mm_id_%s_%s", ggml_type_name(tsrc0), ggml_type_name(tsrc1)); snprintf(name, 256, "%s_bci=%d", base, bc_inp); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; - } + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + ggml_metal_cv_t cv = ggml_metal_cv_init(); - ggml_metal_cv_t cv = ggml_metal_cv_init(); + ggml_metal_cv_set_bool(cv, bc_inp, FC_MUL_MM + 0); - ggml_metal_cv_set_bool(cv, bc_inp, FC_MUL_MM + 0); + res = ggml_metal_library_compile_pipeline(lib, base, name, cv); - res = ggml_metal_library_compile_pipeline(lib, base, name, cv); - - ggml_metal_cv_free(cv); + ggml_metal_cv_free(cv); + } - ggml_metal_pipeline_set_smem(res, 8192); + res.smem = 8192; return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_id(ggml_metal_library_t lib, const ggml_tensor * op) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_id(ggml_metal_library_t lib, const ggml_tensor * op) { GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne); @@ -909,28 +865,26 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_id(ggml_metal_libra snprintf(base, 256, "kernel_mul_mv_id_%s_%s%s", ggml_type_name(tsrc0), ggml_type_name(tsrc1), suffix); snprintf(name, 256, "%s_nsg=%d", base, nsg); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; - } - - ggml_metal_cv_t cv = ggml_metal_cv_init(); + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + ggml_metal_cv_t cv = ggml_metal_cv_init(); - ggml_metal_cv_set_int16(cv, nsg, FC_MUL_MV + 0); + ggml_metal_cv_set_int16(cv, nsg, FC_MUL_MV + 0); - res = ggml_metal_library_compile_pipeline(lib, base, name, cv); + res = ggml_metal_library_compile_pipeline(lib, base, name, cv); - ggml_metal_cv_free(cv); + ggml_metal_cv_free(cv); + } - ggml_metal_pipeline_set_nr0 (res, nr0); - ggml_metal_pipeline_set_nr1 (res, nr1); - ggml_metal_pipeline_set_nsg (res, nsg); - ggml_metal_pipeline_set_smem(res, smem); + res.nr0 = nr0; + res.nr1 = nr1; + res.nsg = nsg; + res.smem = smem; return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argmax(ggml_metal_library_t lib, const ggml_tensor * op) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_argmax(ggml_metal_library_t lib, const ggml_tensor * op) { GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32); GGML_ASSERT(ggml_is_contiguous_1(op->src[0])); GGML_ASSERT(op->src[0]->nb[0] == ggml_type_size(op->src[0]->type)); @@ -941,19 +895,17 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argmax(ggml_metal_library_ snprintf(base, 256, "kernel_argmax_%s", ggml_type_name(op->src[0]->type)); snprintf(name, 256, "%s", base); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); - - ggml_metal_pipeline_set_smem(res, 32*(sizeof(float) + sizeof(int32_t))); + res.smem = 32*(sizeof(float) + sizeof(int32_t)); return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort(ggml_metal_library_t lib, const ggml_tensor * op) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_argsort(ggml_metal_library_t lib, const ggml_tensor * op) { assert(op->op == GGML_OP_ARGSORT); char base[256]; @@ -971,17 +923,15 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort(ggml_metal_library snprintf(base, 256, "kernel_argsort_%s_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->type), order_str); snprintf(name, 256, "%s", base); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); - return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort_merge(ggml_metal_library_t lib, const ggml_tensor * op) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_argsort_merge(ggml_metal_library_t lib, const ggml_tensor * op) { assert(op->op == GGML_OP_ARGSORT); char base[256]; @@ -999,18 +949,16 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort_merge(ggml_metal_l snprintf(base, 256, "kernel_argsort_merge_%s_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->type), order_str); snprintf(name, 256, "%s", base); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); - return res; } // note: reuse the argsort kernel for top_k -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_top_k(ggml_metal_library_t lib, const ggml_tensor * op) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_top_k(ggml_metal_library_t lib, const ggml_tensor * op) { assert(op->op == GGML_OP_TOP_K); char base[256]; @@ -1029,17 +977,15 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_top_k(ggml_metal_library_t snprintf(base, 256, "kernel_argsort_%s_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->type), order_str); snprintf(name, 256, "%s", base); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); - return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_top_k_merge(ggml_metal_library_t lib, const ggml_tensor * op) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_top_k_merge(ggml_metal_library_t lib, const ggml_tensor * op) { assert(op->op == GGML_OP_TOP_K); char base[256]; @@ -1057,17 +1003,15 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_top_k_merge(ggml_metal_lib snprintf(base, 256, "kernel_argsort_merge_%s_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->type), order_str); snprintf(name, 256, "%s", base); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); - return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_pad( +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext_pad( ggml_metal_library_t lib, const struct ggml_tensor * op, bool has_mask, @@ -1086,33 +1030,31 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_pad( has_mask, ncpsg); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; - } - - ggml_metal_cv_t cv = ggml_metal_cv_init(); + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + ggml_metal_cv_t cv = ggml_metal_cv_init(); - ggml_metal_cv_set_bool(cv, has_mask, FC_FLASH_ATTN_EXT_PAD + 0); - //ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT_PAD + 1); - //ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT_PAD + 2); - //ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT_PAD + 3); + ggml_metal_cv_set_bool(cv, has_mask, FC_FLASH_ATTN_EXT_PAD + 0); + //ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT_PAD + 1); + //ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT_PAD + 2); + //ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT_PAD + 3); - //ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT_PAD + 20); - //ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_PAD + 21); - //ggml_metal_cv_set_int32(cv, nsg, FC_FLASH_ATTN_EXT_PAD + 22); - //ggml_metal_cv_set_int32(cv, nwg, FC_FLASH_ATTN_EXT_PAD + 23); - //ggml_metal_cv_set_int32(cv, nqptg, FC_FLASH_ATTN_EXT_PAD + 24); - ggml_metal_cv_set_int32(cv, ncpsg, FC_FLASH_ATTN_EXT_PAD + 25); + //ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT_PAD + 20); + //ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_PAD + 21); + //ggml_metal_cv_set_int32(cv, nsg, FC_FLASH_ATTN_EXT_PAD + 22); + //ggml_metal_cv_set_int32(cv, nwg, FC_FLASH_ATTN_EXT_PAD + 23); + //ggml_metal_cv_set_int32(cv, nqptg, FC_FLASH_ATTN_EXT_PAD + 24); + ggml_metal_cv_set_int32(cv, ncpsg, FC_FLASH_ATTN_EXT_PAD + 25); - res = ggml_metal_library_compile_pipeline(lib, base, name, cv); + res = ggml_metal_library_compile_pipeline(lib, base, name, cv); - ggml_metal_cv_free(cv); + ggml_metal_cv_free(cv); + } return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_blk( +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext_blk( ggml_metal_library_t lib, const struct ggml_tensor * op, int32_t nqptg, @@ -1131,33 +1073,31 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_blk( nqptg, ncpsg); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; - } - - ggml_metal_cv_t cv = ggml_metal_cv_init(); + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + ggml_metal_cv_t cv = ggml_metal_cv_init(); - //ggml_metal_cv_set_bool(cv, has_mask, FC_FLASH_ATTN_EXT_BLK + 0); - //ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT_BLK + 1); - //ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT_BLK + 2); - //ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT_BLK + 3); + //ggml_metal_cv_set_bool(cv, has_mask, FC_FLASH_ATTN_EXT_BLK + 0); + //ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT_BLK + 1); + //ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT_BLK + 2); + //ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT_BLK + 3); - //ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT_BLK + 20); - //ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_BLK + 21); - //ggml_metal_cv_set_int32(cv, nsg, FC_FLASH_ATTN_EXT_BLK + 22); - //ggml_metal_cv_set_int32(cv, nwg, FC_FLASH_ATTN_EXT_BLK + 23); - ggml_metal_cv_set_int32(cv, nqptg, FC_FLASH_ATTN_EXT_BLK + 24); - ggml_metal_cv_set_int32(cv, ncpsg, FC_FLASH_ATTN_EXT_BLK + 25); + //ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT_BLK + 20); + //ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_BLK + 21); + //ggml_metal_cv_set_int32(cv, nsg, FC_FLASH_ATTN_EXT_BLK + 22); + //ggml_metal_cv_set_int32(cv, nwg, FC_FLASH_ATTN_EXT_BLK + 23); + ggml_metal_cv_set_int32(cv, nqptg, FC_FLASH_ATTN_EXT_BLK + 24); + ggml_metal_cv_set_int32(cv, ncpsg, FC_FLASH_ATTN_EXT_BLK + 25); - res = ggml_metal_library_compile_pipeline(lib, base, name, cv); + res = ggml_metal_library_compile_pipeline(lib, base, name, cv); - ggml_metal_cv_free(cv); + ggml_metal_cv_free(cv); + } return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext( +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext( ggml_metal_library_t lib, const ggml_tensor * op, bool has_mask, @@ -1198,33 +1138,31 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext( ns20, nsg); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; - } - - ggml_metal_cv_t cv = ggml_metal_cv_init(); + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + ggml_metal_cv_t cv = ggml_metal_cv_init(); - ggml_metal_cv_set_bool(cv, has_mask, FC_FLASH_ATTN_EXT + 0); - ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT + 1); - ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT + 2); - ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT + 3); - ggml_metal_cv_set_bool(cv, has_kvpad, FC_FLASH_ATTN_EXT + 4); + ggml_metal_cv_set_bool(cv, has_mask, FC_FLASH_ATTN_EXT + 0); + ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT + 1); + ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT + 2); + ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT + 3); + ggml_metal_cv_set_bool(cv, has_kvpad, FC_FLASH_ATTN_EXT + 4); - ggml_metal_cv_set_bool(cv, bc_mask, FC_FLASH_ATTN_EXT + 10); + ggml_metal_cv_set_bool(cv, bc_mask, FC_FLASH_ATTN_EXT + 10); - ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT + 20); - ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT + 21); - ggml_metal_cv_set_int32(cv, nsg, FC_FLASH_ATTN_EXT + 22); + ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT + 20); + ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT + 21); + ggml_metal_cv_set_int32(cv, nsg, FC_FLASH_ATTN_EXT + 22); - res = ggml_metal_library_compile_pipeline(lib, base, name, cv); + res = ggml_metal_library_compile_pipeline(lib, base, name, cv); - ggml_metal_cv_free(cv); + ggml_metal_cv_free(cv); + } return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec( +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext_vec( ggml_metal_library_t lib, const ggml_tensor * op, bool has_mask, @@ -1262,32 +1200,30 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec( ns20, nsg, nwg); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; - } - - ggml_metal_cv_t cv = ggml_metal_cv_init(); + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + ggml_metal_cv_t cv = ggml_metal_cv_init(); - ggml_metal_cv_set_bool(cv, has_mask, FC_FLASH_ATTN_EXT_VEC + 0); - ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT_VEC + 1); - ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT_VEC + 2); - ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT_VEC + 3); - ggml_metal_cv_set_bool(cv, has_kvpad, FC_FLASH_ATTN_EXT_VEC + 4); + ggml_metal_cv_set_bool(cv, has_mask, FC_FLASH_ATTN_EXT_VEC + 0); + ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT_VEC + 1); + ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT_VEC + 2); + ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT_VEC + 3); + ggml_metal_cv_set_bool(cv, has_kvpad, FC_FLASH_ATTN_EXT_VEC + 4); - ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT_VEC + 20); - ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_VEC + 21); - ggml_metal_cv_set_int32(cv, nsg, FC_FLASH_ATTN_EXT_VEC + 22); - ggml_metal_cv_set_int32(cv, nwg, FC_FLASH_ATTN_EXT_VEC + 23); + ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT_VEC + 20); + ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_VEC + 21); + ggml_metal_cv_set_int32(cv, nsg, FC_FLASH_ATTN_EXT_VEC + 22); + ggml_metal_cv_set_int32(cv, nwg, FC_FLASH_ATTN_EXT_VEC + 23); - res = ggml_metal_library_compile_pipeline(lib, base, name, cv); + res = ggml_metal_library_compile_pipeline(lib, base, name, cv); - ggml_metal_cv_free(cv); + ggml_metal_cv_free(cv); + } return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec_reduce( +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext_vec_reduce( ggml_metal_library_t lib, const ggml_tensor * op, int32_t dv, @@ -1300,26 +1236,24 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec_reduce( snprintf(base, 256, "kernel_flash_attn_ext_vec_reduce"); snprintf(name, 256, "%s_dv=%d_nwg=%d", base, dv, nwg); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; - } - - ggml_metal_cv_t cv = ggml_metal_cv_init(); + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + ggml_metal_cv_t cv = ggml_metal_cv_init(); - ggml_metal_cv_set_int32(cv, dv, FC_FLASH_ATTN_EXT_VEC_REDUCE + 0); - ggml_metal_cv_set_int32(cv, nwg, FC_FLASH_ATTN_EXT_VEC_REDUCE + 1); + ggml_metal_cv_set_int32(cv, dv, FC_FLASH_ATTN_EXT_VEC_REDUCE + 0); + ggml_metal_cv_set_int32(cv, nwg, FC_FLASH_ATTN_EXT_VEC_REDUCE + 1); - res = ggml_metal_library_compile_pipeline(lib, base, name, cv); + res = ggml_metal_library_compile_pipeline(lib, base, name, cv); - ggml_metal_cv_free(cv); + ggml_metal_cv_free(cv); + } return res; GGML_UNUSED(op); } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_bin( +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin( ggml_metal_library_t lib, ggml_op op, int32_t n_fuse, @@ -1344,17 +1278,15 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_bin( snprintf(name, 256, "%s", base); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); - return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_l2_norm(ggml_metal_library_t lib, const ggml_tensor * op) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_l2_norm(ggml_metal_library_t lib, const ggml_tensor * op) { assert(op->op == GGML_OP_L2_NORM); GGML_ASSERT(op->src[0]->ne[0] % 4 == 0); @@ -1366,19 +1298,17 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_l2_norm(ggml_metal_library snprintf(base, 256, "kernel_l2_norm_f32"); snprintf(name, 256, "%s", base); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); - - ggml_metal_pipeline_set_smem(res, 32*sizeof(float)); + res.smem = 32*sizeof(float); return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_group_norm(ggml_metal_library_t lib, const ggml_tensor * op) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_group_norm(ggml_metal_library_t lib, const ggml_tensor * op) { assert(op->op == GGML_OP_GROUP_NORM); GGML_ASSERT(ggml_is_contiguous(op->src[0])); @@ -1389,19 +1319,17 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_group_norm(ggml_metal_libr snprintf(base, 256, "kernel_group_norm_f32"); snprintf(name, 256, "%s", base); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); - - ggml_metal_pipeline_set_smem(res, 32*sizeof(float)); + res.smem = 32*sizeof(float); return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_norm(ggml_metal_library_t lib, const ggml_tensor * op, int n_fuse) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_norm(ggml_metal_library_t lib, const ggml_tensor * op, int n_fuse) { assert(op->op == GGML_OP_NORM || op->op == GGML_OP_RMS_NORM); GGML_ASSERT(ggml_is_contiguous_rows(op->src[0])); @@ -1434,19 +1362,17 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_norm(ggml_metal_library_t snprintf(name, 256, "%s", base); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); - - ggml_metal_pipeline_set_smem(res, 32*sizeof(float)); + res.smem = 32*sizeof(float); return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rope(ggml_metal_library_t lib, const ggml_tensor * op) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rope(ggml_metal_library_t lib, const ggml_tensor * op) { assert(op->op == GGML_OP_ROPE); char base[256]; @@ -1473,23 +1399,21 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rope(ggml_metal_library_t snprintf(name, 256, "%s_imrope=%d", base, is_imrope ? 1 : 0); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; - } - - ggml_metal_cv_t cv = ggml_metal_cv_init(); + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + ggml_metal_cv_t cv = ggml_metal_cv_init(); - ggml_metal_cv_set_bool(cv, is_imrope, FC_ROPE + 0); + ggml_metal_cv_set_bool(cv, is_imrope, FC_ROPE + 0); - res = ggml_metal_library_compile_pipeline(lib, base, name, cv); + res = ggml_metal_library_compile_pipeline(lib, base, name, cv); - ggml_metal_cv_free(cv); + ggml_metal_cv_free(cv); + } return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_im2col(ggml_metal_library_t lib, const ggml_tensor * op) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_im2col(ggml_metal_library_t lib, const ggml_tensor * op) { assert(op->op == GGML_OP_IM2COL); GGML_ASSERT(ggml_is_contiguous(op->src[1])); @@ -1502,17 +1426,15 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_im2col(ggml_metal_library_ snprintf(base, 256, "kernel_im2col_%s", ggml_type_name(op->type)); snprintf(name, 256, "%s", base); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); - return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_conv_transpose_1d(ggml_metal_library_t lib, const ggml_tensor * op) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_transpose_1d(ggml_metal_library_t lib, const ggml_tensor * op) { assert(op->op == GGML_OP_CONV_TRANSPOSE_1D); GGML_ASSERT(ggml_is_contiguous(op->src[0])); @@ -1527,17 +1449,15 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_conv_transpose_1d(ggml_met snprintf(base, 256, "kernel_conv_transpose_1d_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->src[1]->type)); snprintf(name, 256, "%s", base); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); - return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_conv_transpose_2d(ggml_metal_library_t lib, const ggml_tensor * op) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_transpose_2d(ggml_metal_library_t lib, const ggml_tensor * op) { assert(op->op == GGML_OP_CONV_TRANSPOSE_2D); GGML_ASSERT(ggml_is_contiguous(op->src[0])); @@ -1552,17 +1472,15 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_conv_transpose_2d(ggml_met snprintf(base, 256, "kernel_conv_transpose_2d_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->src[1]->type)); snprintf(name, 256, "%s", base); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); - return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_conv_2d(ggml_metal_library_t lib, const ggml_tensor * op) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_2d(ggml_metal_library_t lib, const ggml_tensor * op) { assert(op->op == GGML_OP_CONV_2D); GGML_ASSERT(ggml_is_contiguous(op->src[0])); @@ -1576,17 +1494,15 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_conv_2d(ggml_metal_library snprintf(base, 256, "kernel_conv_2d_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->src[1]->type)); snprintf(name, 256, "%s", base); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); - return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_upscale(ggml_metal_library_t lib, const ggml_tensor * op) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_upscale(ggml_metal_library_t lib, const ggml_tensor * op) { assert(op->op == GGML_OP_UPSCALE); char base[256]; @@ -1595,17 +1511,15 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_upscale(ggml_metal_library snprintf(base, 256, "kernel_upscale_%s", ggml_type_name(op->src[0]->type)); snprintf(name, 256, "%s", base); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); - return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad(ggml_metal_library_t lib, const ggml_tensor * op) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pad(ggml_metal_library_t lib, const ggml_tensor * op) { assert(op->op == GGML_OP_PAD); char base[256]; @@ -1614,8 +1528,8 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad(ggml_metal_library_t l snprintf(base, 256, "kernel_pad_%s", ggml_type_name(op->src[0]->type)); snprintf(name, 256, "%s", base); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (res.pipeline) { return res; } @@ -1624,7 +1538,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad(ggml_metal_library_t l return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad_reflect_1d(ggml_metal_library_t lib, const ggml_tensor * op) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pad_reflect_1d(ggml_metal_library_t lib, const ggml_tensor * op) { assert(op->op == GGML_OP_PAD_REFLECT_1D); char base[256]; @@ -1633,17 +1547,15 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad_reflect_1d(ggml_metal_ snprintf(base, 256, "kernel_pad_reflect_1d_%s", ggml_type_name(op->src[0]->type)); snprintf(name, 256, "%s", base); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); - return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_arange(ggml_metal_library_t lib, const ggml_tensor * op) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_arange(ggml_metal_library_t lib, const ggml_tensor * op) { assert(op->op == GGML_OP_ARANGE); char base[256]; @@ -1652,17 +1564,15 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_arange(ggml_metal_library_ snprintf(base, 256, "kernel_arange_%s", ggml_type_name(op->type)); snprintf(name, 256, "%s", base); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); - return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_timestep_embedding(ggml_metal_library_t lib, const ggml_tensor * op) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_timestep_embedding(ggml_metal_library_t lib, const ggml_tensor * op) { assert(op->op == GGML_OP_TIMESTEP_EMBEDDING); char base[256]; @@ -1671,17 +1581,15 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_timestep_embedding(ggml_me snprintf(base, 256, "kernel_timestep_embedding_%s", ggml_type_name(op->src[0]->type)); snprintf(name, 256, "%s", base); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); - return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_opt_step_adamw(ggml_metal_library_t lib, const ggml_tensor * op) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_opt_step_adamw(ggml_metal_library_t lib, const ggml_tensor * op) { assert(op->op == GGML_OP_OPT_STEP_ADAMW); char base[256]; @@ -1690,17 +1598,15 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_opt_step_adamw(ggml_metal_ snprintf(base, 256, "kernel_opt_step_adamw_%s", ggml_type_name(op->src[0]->type)); snprintf(name, 256, "%s", base); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); - return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_opt_step_sgd(ggml_metal_library_t lib, const ggml_tensor * op) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_opt_step_sgd(ggml_metal_library_t lib, const ggml_tensor * op) { assert(op->op == GGML_OP_OPT_STEP_SGD); char base[256]; @@ -1709,12 +1615,10 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_opt_step_sgd(ggml_metal_li snprintf(base, 256, "kernel_opt_step_sgd_%s", ggml_type_name(op->src[0]->type)); snprintf(name, 256, "%s", base); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); - return res; } diff --git a/ggml/src/ggml-metal/ggml-metal-device.h b/ggml/src/ggml-metal/ggml-metal-device.h index 3976e622b9..17baef2017 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.h +++ b/ggml/src/ggml-metal/ggml-metal-device.h @@ -35,20 +35,6 @@ typedef struct ggml_metal_pipeline * ggml_metal_pipeline_t; ggml_metal_pipeline_t ggml_metal_pipeline_init(void); void ggml_metal_pipeline_free(ggml_metal_pipeline_t pipeline); -void ggml_metal_pipeline_set_nsg(ggml_metal_pipeline_t pipeline, int nsg); -int ggml_metal_pipeline_get_nsg(ggml_metal_pipeline_t pipeline); - -void ggml_metal_pipeline_set_nr0(ggml_metal_pipeline_t pipeline, int nr0); -int ggml_metal_pipeline_get_nr0(ggml_metal_pipeline_t pipeline); - -void ggml_metal_pipeline_set_nr1(ggml_metal_pipeline_t pipeline, int nr1); -int ggml_metal_pipeline_get_nr1(ggml_metal_pipeline_t pipeline); - -void ggml_metal_pipeline_set_smem(ggml_metal_pipeline_t pipeline, size_t smem); -size_t ggml_metal_pipeline_get_smem(ggml_metal_pipeline_t pipeline); - -int ggml_metal_pipeline_max_theads_per_threadgroup(ggml_metal_pipeline_t pipeline); - // a collection of pipelines typedef struct ggml_metal_pipelines * ggml_metal_pipelines_t; @@ -58,6 +44,19 @@ void ggml_metal_pipelines_free(ggml_metal_pipelines_t ppls); void ggml_metal_pipelines_add(ggml_metal_pipelines_t ppls, const char * name, ggml_metal_pipeline_t pipeline); ggml_metal_pipeline_t ggml_metal_pipelines_get(ggml_metal_pipelines_t ppls, const char * name); +struct ggml_metal_pipeline_with_params { + ggml_metal_pipeline_t pipeline; + + int nsg; + + int nr0; + int nr1; + + size_t smem; +}; + +int ggml_metal_pipeline_max_theads_per_threadgroup(struct ggml_metal_pipeline_with_params pipeline); + // // MTLCommandBuffer wrapper // @@ -76,7 +75,7 @@ void ggml_metal_encoder_free(ggml_metal_encoder_t encoder); void ggml_metal_encoder_debug_group_push(ggml_metal_encoder_t encoder, const char * name); void ggml_metal_encoder_debug_group_pop (ggml_metal_encoder_t encoder); -void ggml_metal_encoder_set_pipeline(ggml_metal_encoder_t encoder, ggml_metal_pipeline_t pipeline); +void ggml_metal_encoder_set_pipeline(ggml_metal_encoder_t encoder, struct ggml_metal_pipeline_with_params pipeline); void ggml_metal_encoder_set_bytes (ggml_metal_encoder_t encoder, void * data, size_t size, int idx); void ggml_metal_encoder_set_buffer(ggml_metal_encoder_t encoder, struct ggml_metal_buffer_id buffer, int idx); @@ -100,66 +99,66 @@ ggml_metal_library_t ggml_metal_library_init_from_source(ggml_metal_device_t dev void ggml_metal_library_free(ggml_metal_library_t lib); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline (ggml_metal_library_t lib, const char * name); -ggml_metal_pipeline_t ggml_metal_library_compile_pipeline(ggml_metal_library_t lib, const char * base, const char * name, ggml_metal_cv_t cv); - -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_base (ggml_metal_library_t lib, enum ggml_op op); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_cpy (ggml_metal_library_t lib, enum ggml_type tsrc, enum ggml_type tdst); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pool_2d (ggml_metal_library_t lib, const struct ggml_tensor * op, enum ggml_op_pool op_pool); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_get_rows (ggml_metal_library_t lib, enum ggml_type tsrc); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_set_rows (ggml_metal_library_t lib, enum ggml_type tidx, enum ggml_type tdst); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_repeat (ggml_metal_library_t lib, enum ggml_type tsrc); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_unary (ggml_metal_library_t lib, const struct ggml_tensor * op); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_glu (ggml_metal_library_t lib, const struct ggml_tensor * op); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum (ggml_metal_library_t lib, const struct ggml_tensor * op); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum_rows (ggml_metal_library_t lib, const struct ggml_tensor * op); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_cumsum_blk (ggml_metal_library_t lib, const struct ggml_tensor * op); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_cumsum_add (ggml_metal_library_t lib, const struct ggml_tensor * op); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_soft_max (ggml_metal_library_t lib, const struct ggml_tensor * op); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_conv (ggml_metal_library_t lib, const struct ggml_tensor * op); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_scan (ggml_metal_library_t lib, const struct ggml_tensor * op); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rwkv (ggml_metal_library_t lib, const struct ggml_tensor * op); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_ext (ggml_metal_library_t lib, enum ggml_type tsrc0, enum ggml_type tsrc1, int nsg, int nxpsg, int r1ptg); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm (ggml_metal_library_t lib, const struct ggml_tensor * op); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv (ggml_metal_library_t lib, const struct ggml_tensor * op); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm_id_map0 (ggml_metal_library_t lib, int ne02, int ne20); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm_id (ggml_metal_library_t lib, const struct ggml_tensor * op); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_id (ggml_metal_library_t lib, const struct ggml_tensor * op); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argmax (ggml_metal_library_t lib, const struct ggml_tensor * op); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort (ggml_metal_library_t lib, const struct ggml_tensor * op); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort_merge (ggml_metal_library_t lib, const struct ggml_tensor * op); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_top_k (ggml_metal_library_t lib, const struct ggml_tensor * op); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_top_k_merge (ggml_metal_library_t lib, const struct ggml_tensor * op); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_bin (ggml_metal_library_t lib, enum ggml_op op, int32_t n_fuse, bool row); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_l2_norm (ggml_metal_library_t lib, const struct ggml_tensor * op); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_group_norm (ggml_metal_library_t lib, const struct ggml_tensor * op); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_norm (ggml_metal_library_t lib, const struct ggml_tensor * op, int32_t n_fuse); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rope (ggml_metal_library_t lib, const struct ggml_tensor * op); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_im2col (ggml_metal_library_t lib, const struct ggml_tensor * op); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_conv_transpose_1d (ggml_metal_library_t lib, const struct ggml_tensor * op); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_conv_transpose_2d (ggml_metal_library_t lib, const struct ggml_tensor * op); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_conv_2d (ggml_metal_library_t lib, const struct ggml_tensor * op); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_upscale (ggml_metal_library_t lib, const struct ggml_tensor * op); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad (ggml_metal_library_t lib, const struct ggml_tensor * op); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad_reflect_1d (ggml_metal_library_t lib, const struct ggml_tensor * op); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_arange (ggml_metal_library_t lib, const struct ggml_tensor * op); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_timestep_embedding(ggml_metal_library_t lib, const struct ggml_tensor * op); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_opt_step_adamw (ggml_metal_library_t lib, const struct ggml_tensor * op); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_opt_step_sgd (ggml_metal_library_t lib, const struct ggml_tensor * op); - -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_pad( +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline (ggml_metal_library_t lib, const char * name); +struct ggml_metal_pipeline_with_params ggml_metal_library_compile_pipeline(ggml_metal_library_t lib, const char * base, const char * name, ggml_metal_cv_t cv); + +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_base (ggml_metal_library_t lib, enum ggml_op op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cpy (ggml_metal_library_t lib, enum ggml_type tsrc, enum ggml_type tdst); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pool_2d (ggml_metal_library_t lib, const struct ggml_tensor * op, enum ggml_op_pool op_pool); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_get_rows (ggml_metal_library_t lib, enum ggml_type tsrc); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_set_rows (ggml_metal_library_t lib, enum ggml_type tidx, enum ggml_type tdst); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_repeat (ggml_metal_library_t lib, enum ggml_type tsrc); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_unary (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_glu (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_sum (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_sum_rows (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cumsum_blk (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cumsum_add (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_soft_max (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_scan (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rwkv (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_ext (ggml_metal_library_t lib, enum ggml_type tsrc0, enum ggml_type tsrc1, int nsg, int nxpsg, int r1ptg); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm_id_map0 (ggml_metal_library_t lib, int ne02, int ne20); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm_id (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_id (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_argmax (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_argsort (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_argsort_merge (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_top_k (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_top_k_merge (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin (ggml_metal_library_t lib, enum ggml_op op, int32_t n_fuse, bool row); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_l2_norm (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_group_norm (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_norm (ggml_metal_library_t lib, const struct ggml_tensor * op, int32_t n_fuse); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rope (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_im2col (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_transpose_1d (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_transpose_2d (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_2d (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_upscale (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pad (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pad_reflect_1d (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_arange (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_timestep_embedding(ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_opt_step_adamw (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_opt_step_sgd (ggml_metal_library_t lib, const struct ggml_tensor * op); + +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext_pad( ggml_metal_library_t lib, const struct ggml_tensor * op, bool has_mask, int32_t ncpsg); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_blk( +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext_blk( ggml_metal_library_t lib, const struct ggml_tensor * op, int32_t nqptg, int32_t ncpsg); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext( +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext( ggml_metal_library_t lib, const struct ggml_tensor * op, bool has_mask, @@ -169,7 +168,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext( bool has_kvpad, int32_t nsg); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec( +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext_vec( ggml_metal_library_t lib, const struct ggml_tensor * op, bool has_mask, @@ -180,7 +179,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec( int32_t nsg, int32_t nwg); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec_reduce( +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext_vec_reduce( ggml_metal_library_t lib, const struct ggml_tensor * op, int32_t dv, diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m index 4d2bfcf91c..d22672a816 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.m +++ b/ggml/src/ggml-metal/ggml-metal-device.m @@ -75,14 +75,6 @@ void ggml_metal_cv_set_bool(ggml_metal_cv_t cv, bool value, int32_t idx) { struct ggml_metal_pipeline { id obj; - - // suggested dispatch sizes - int nsg; - - int nr0; - int nr1; - - size_t smem; }; ggml_metal_pipeline_t ggml_metal_pipeline_init(void) { @@ -90,10 +82,6 @@ ggml_metal_pipeline_t ggml_metal_pipeline_init(void) { *res = (struct ggml_metal_pipeline) { /*.obj =*/ nil, - /*.nsg =*/ 0, - /*.nr0 =*/ 0, - /*.nr1 =*/ 0, - /*.smem =*/ 0, }; return res; @@ -105,40 +93,8 @@ void ggml_metal_pipeline_free(ggml_metal_pipeline_t pipeline) { free(pipeline); } -void ggml_metal_pipeline_set_nsg(ggml_metal_pipeline_t pipeline, int nsg) { - pipeline->nsg = nsg; -} - -int ggml_metal_pipeline_get_nsg(ggml_metal_pipeline_t pipeline) { - return pipeline->nsg; -} - -void ggml_metal_pipeline_set_nr0(ggml_metal_pipeline_t pipeline, int nr0) { - pipeline->nr0 = nr0; -} - -int ggml_metal_pipeline_get_nr0(ggml_metal_pipeline_t pipeline) { - return pipeline->nr0; -} - -void ggml_metal_pipeline_set_nr1(ggml_metal_pipeline_t pipeline, int nr1) { - pipeline->nr1 = nr1; -} - -int ggml_metal_pipeline_get_nr1(ggml_metal_pipeline_t pipeline) { - return pipeline->nr1; -} - -void ggml_metal_pipeline_set_smem(ggml_metal_pipeline_t pipeline, size_t smem) { - pipeline->smem = smem; -} - -size_t ggml_metal_pipeline_get_smem(ggml_metal_pipeline_t pipeline) { - return pipeline->smem; -} - -int ggml_metal_pipeline_max_theads_per_threadgroup(ggml_metal_pipeline_t pipeline) { - return pipeline->obj.maxTotalThreadsPerThreadgroup; +int ggml_metal_pipeline_max_theads_per_threadgroup(struct ggml_metal_pipeline_with_params pipeline) { + return pipeline.pipeline->obj.maxTotalThreadsPerThreadgroup; } struct ggml_metal_library { @@ -389,28 +345,42 @@ void ggml_metal_library_free(ggml_metal_library_t lib) { free(lib); } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline(ggml_metal_library_t lib, const char * name) { +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline(ggml_metal_library_t lib, const char * name) { [lib->lock lock]; - ggml_metal_pipeline_t res = ggml_metal_pipelines_get(lib->pipelines, name); + struct ggml_metal_pipeline_with_params res = { + /*.pipeline =*/ nil, + /*.nr0 =*/ 0, + /*.nr1 =*/ 0, + /*.nsg =*/ 0, + /*.smem =*/ 0, + }; + + res.pipeline = ggml_metal_pipelines_get(lib->pipelines, name); [lib->lock unlock]; return res; } -ggml_metal_pipeline_t ggml_metal_library_compile_pipeline(ggml_metal_library_t lib, const char * base, const char * name, ggml_metal_cv_t cv) { +struct ggml_metal_pipeline_with_params ggml_metal_library_compile_pipeline(ggml_metal_library_t lib, const char * base, const char * name, ggml_metal_cv_t cv) { + struct ggml_metal_pipeline_with_params res = { + /*.pipeline =*/ nil, + /*.nr0 =*/ 0, + /*.nr1 =*/ 0, + /*.nsg =*/ 0, + /*.smem =*/ 0, + }; + [lib->lock lock]; - ggml_metal_pipeline_t res = ggml_metal_pipelines_get(lib->pipelines, name); - if (res) { + res.pipeline = ggml_metal_pipelines_get(lib->pipelines, name); + if (res.pipeline) { [lib->lock unlock]; return res; } - res = ggml_metal_pipeline_init(); - @autoreleasepool { NSError * error = nil; @@ -432,26 +402,43 @@ ggml_metal_pipeline_t ggml_metal_library_compile_pipeline(ggml_metal_library_t l GGML_LOG_ERROR("%s: %s\n", __func__, [[error description] UTF8String]); } - return nil; + return res; } - res->obj = [lib->device newComputePipelineStateWithFunction:mtl_function error:&error]; + id obj = [lib->device newComputePipelineStateWithFunction:mtl_function error:&error]; [mtl_function release]; - GGML_LOG_DEBUG("%s: loaded %-40s %16p | th_max = %4d | th_width = %4d\n", __func__, name, (void *) res->obj, - (int) res->obj.maxTotalThreadsPerThreadgroup, - (int) res->obj.threadExecutionWidth); + if (!obj) { + [lib->lock unlock]; + + GGML_LOG_ERROR("%s: failed to create pipeline state: base = '%s', name = '%s'\n", __func__, base, name); + if (error) { + GGML_LOG_ERROR("%s: %s\n", __func__, [[error description] UTF8String]); + } + + return res; + } + + GGML_LOG_DEBUG("%s: loaded %-40s %16p | th_max = %4d | th_width = %4d\n", __func__, name, + (void *) obj, + (int) obj.maxTotalThreadsPerThreadgroup, + (int) obj.threadExecutionWidth); + + if (obj.maxTotalThreadsPerThreadgroup == 0 || obj.threadExecutionWidth == 0) { + [obj release]; - if (res->obj.maxTotalThreadsPerThreadgroup == 0 || res->obj.threadExecutionWidth == 0) { [lib->lock unlock]; GGML_LOG_ERROR("%s: incompatible pipeline %s\n", __func__, name); - return nil; + return res; } - ggml_metal_pipelines_add(lib->pipelines, name, res); + res.pipeline = ggml_metal_pipeline_init(); + res.pipeline->obj = obj; + + ggml_metal_pipelines_add(lib->pipelines, name, res.pipeline); } [lib->lock unlock]; @@ -496,8 +483,8 @@ void ggml_metal_encoder_debug_group_pop (ggml_metal_encoder_t encoder) { [encoder->obj popDebugGroup]; } -void ggml_metal_encoder_set_pipeline(ggml_metal_encoder_t encoder, ggml_metal_pipeline_t pipeline) { - [encoder->obj setComputePipelineState:pipeline->obj]; +void ggml_metal_encoder_set_pipeline(ggml_metal_encoder_t encoder, struct ggml_metal_pipeline_with_params pipeline) { + [encoder->obj setComputePipelineState:pipeline.pipeline->obj]; } void ggml_metal_encoder_set_bytes(ggml_metal_encoder_t encoder, void * data, size_t size, int idx) { @@ -622,8 +609,8 @@ ggml_metal_device_t ggml_metal_device_init(void) { GGML_LOG_WARN("%s: - the tensor API is not supported in this environment - disabling\n", __func__); dev->props.has_tensor = false; } else { - ggml_metal_pipeline_t ppl = ggml_metal_library_compile_pipeline(lib, "dummy_kernel", "dummy_kernel", nil); - if (!ppl) { + struct ggml_metal_pipeline_with_params ppl = ggml_metal_library_compile_pipeline(lib, "dummy_kernel", "dummy_kernel", nil); + if (!ppl.pipeline) { GGML_LOG_WARN("%s: - the tensor API is not supported in this environment - disabling\n", __func__); dev->props.has_tensor = false; } @@ -672,8 +659,8 @@ ggml_metal_device_t ggml_metal_device_init(void) { GGML_LOG_WARN("%s: - the tensor API does not support bfloat - disabling bfloat support\n", __func__); dev->props.has_bfloat = false; } else { - ggml_metal_pipeline_t ppl = ggml_metal_library_compile_pipeline(lib, "dummy_kernel", "dummy_kernel", nil); - if (!ppl) { + struct ggml_metal_pipeline_with_params ppl = ggml_metal_library_compile_pipeline(lib, "dummy_kernel", "dummy_kernel", nil); + if (!ppl.pipeline) { GGML_LOG_WARN("%s: - the tensor API does not support bfloat - disabling bfloat support\n", __func__); dev->props.has_bfloat = false; } diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp index 9871e976f2..edb227a210 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.cpp +++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -524,7 +524,7 @@ int ggml_metal_op_concat(ggml_metal_op_t ctx, int idx) { /*.dim =*/ dim, }; - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_base(lib, GGML_OP_CONCAT); + auto pipeline = ggml_metal_library_get_pipeline_base(lib, GGML_OP_CONCAT); ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); @@ -550,7 +550,7 @@ int ggml_metal_op_repeat(ggml_metal_op_t ctx, int idx) { GGML_TENSOR_LOCALS( int32_t, ne, op, ne); GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_repeat(lib, op->type); + auto pipeline = ggml_metal_library_get_pipeline_repeat(lib, op->type); ggml_metal_kargs_repeat args = { /*.ne00 =*/ ne00, @@ -616,7 +616,7 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) { // TODO: make a simpler cpy_bytes kernel //const id pipeline = ctx->pipelines[GGML_METAL_PIPELINE_TYPE_CPY_F32_F32].obj; - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_cpy(lib, op->src[0]->type, op->type); + auto pipeline = ggml_metal_library_get_pipeline_cpy(lib, op->src[0]->type, op->type); ggml_metal_kargs_cpy args = { /*.nk0 =*/ ne00, @@ -679,7 +679,7 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) { /*.o1 =*/ { 0 }, }; - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_bin(lib, GGML_OP_ADD, 1, false); + auto pipeline = ggml_metal_library_get_pipeline_bin(lib, GGML_OP_ADD, 1, false); ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); @@ -721,7 +721,7 @@ int ggml_metal_op_scale(ggml_metal_op_t ctx, int idx) { n /= 4; } - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_unary(lib, op); + auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op); ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); @@ -760,7 +760,7 @@ int ggml_metal_op_clamp(ggml_metal_op_t ctx, int idx) { n /= 4; } - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_unary(lib, op); + auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op); ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); @@ -789,7 +789,7 @@ int ggml_metal_op_unary(ggml_metal_op_t ctx, int idx) { n /= 4; } - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_unary(lib, op); + auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op); ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 0); @@ -817,7 +817,7 @@ int ggml_metal_op_glu(ggml_metal_op_t ctx, int idx) { GGML_ASSERT(ggml_are_same_shape(op->src[0], op->src[1])); } - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_glu(lib, op); + auto pipeline = ggml_metal_library_get_pipeline_glu(lib, op); const int32_t swp = ggml_get_op_params_i32(op, 1); const float alpha = ggml_get_op_params_f32(op, 2); @@ -870,7 +870,7 @@ int ggml_metal_op_sum(ggml_metal_op_t ctx, int idx) { /*.np =*/ n, }; - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_sum(lib, op); + auto pipeline = ggml_metal_library_get_pipeline_sum(lib, op); int nth = 32; // SIMD width @@ -925,7 +925,7 @@ int ggml_metal_op_sum_rows(ggml_metal_op_t ctx, int idx) { /*.nb3 =*/ nb3, }; - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_sum_rows(lib, op); + auto pipeline = ggml_metal_library_get_pipeline_sum_rows(lib, op); int nth = 32; // SIMD width @@ -936,7 +936,7 @@ int ggml_metal_op_sum_rows(ggml_metal_op_t ctx, int idx) { nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); nth = std::min(nth, ne00); - const size_t smem = ggml_metal_pipeline_get_smem(pipeline); + const size_t smem = pipeline.smem; ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); @@ -963,7 +963,7 @@ int ggml_metal_op_cumsum(ggml_metal_op_t ctx, int idx) { GGML_TENSOR_LOCALS( int32_t, ne, op, ne); GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); - ggml_metal_pipeline_t pipeline_blk = ggml_metal_library_get_pipeline_cumsum_blk(lib, op); + auto pipeline_blk = ggml_metal_library_get_pipeline_cumsum_blk(lib, op); int nth = 1; while (nth < ne00 && 2*nth <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline_blk)) { @@ -1060,7 +1060,7 @@ int ggml_metal_op_cumsum(ggml_metal_op_t ctx, int idx) { ggml_metal_op_concurrency_reset(ctx); { - ggml_metal_pipeline_t pipeline_add = ggml_metal_library_get_pipeline_cumsum_add(lib, op); + auto pipeline_add = ggml_metal_library_get_pipeline_cumsum_add(lib, op); ggml_metal_kargs_cumsum_add args = { /*.ne00 =*/ ne00, @@ -1106,7 +1106,7 @@ int ggml_metal_op_get_rows(ggml_metal_op_t ctx, int idx) { GGML_TENSOR_LOCALS( int32_t, ne, op, ne); GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_get_rows(lib, op->src[0]->type); + auto pipeline = ggml_metal_library_get_pipeline_get_rows(lib, op->src[0]->type); ggml_metal_kargs_get_rows args = { /*.ne00t =*/ ggml_is_quantized(op->src[0]->type) ? ne00/16 : ne00, @@ -1151,7 +1151,7 @@ int ggml_metal_op_set_rows(ggml_metal_op_t ctx, int idx) { GGML_TENSOR_LOCALS( int32_t, ne, op, ne); GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_set_rows(lib, op->src[1]->type, op->type); + auto pipeline = ggml_metal_library_get_pipeline_set_rows(lib, op->src[1]->type, op->type); const int32_t nk0 = ne0/ggml_blck_size(op->type); @@ -1252,7 +1252,7 @@ int ggml_metal_op_soft_max(ggml_metal_op_t ctx, int idx) { /*.n_head_log2 =*/ n_head_log2, }; - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_soft_max(lib, op); + auto pipeline = ggml_metal_library_get_pipeline_soft_max(lib, op); int nth = 32; // SIMD width @@ -1266,7 +1266,7 @@ int ggml_metal_op_soft_max(ggml_metal_op_t ctx, int idx) { } } - const size_t smem = ggml_metal_pipeline_get_smem(pipeline); + const size_t smem = pipeline.smem; ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0); @@ -1322,7 +1322,7 @@ int ggml_metal_op_ssm_conv(ggml_metal_op_t ctx, int idx) { /*.nb2 =*/ nb2, }; - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_ssm_conv(lib, op); + auto pipeline = ggml_metal_library_get_pipeline_ssm_conv(lib, op); ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0); @@ -1409,11 +1409,11 @@ int ggml_metal_op_ssm_scan(ggml_metal_op_t ctx, int idx) { /*.nb0 =*/ nb0, }; - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_ssm_scan(lib, op); + auto pipeline = ggml_metal_library_get_pipeline_ssm_scan(lib, op); GGML_ASSERT(d_state <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); - const size_t sms = ggml_metal_pipeline_get_smem(pipeline); + const size_t smem = pipeline.smem; ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); @@ -1426,7 +1426,7 @@ int ggml_metal_op_ssm_scan(ggml_metal_op_t ctx, int idx) { ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[6]), 7); ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 8); - ggml_metal_encoder_set_threadgroup_memory_size(enc, sms, 0); + ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0); ggml_metal_encoder_dispatch_threadgroups(enc, d_inner, n_head, n_seqs, d_state, 1, 1); @@ -1449,7 +1449,7 @@ int ggml_metal_op_rwkv(ggml_metal_op_t ctx, int idx) { const int64_t C = op->ne[0]; const int64_t H = op->src[0]->ne[1]; - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_rwkv(lib, op); + auto pipeline = ggml_metal_library_get_pipeline_rwkv(lib, op); int ida = 0; @@ -1485,7 +1485,7 @@ int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) { GGML_TENSOR_LOCALS( int32_t, ne, op, ne); GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_cpy(lib, op->src[0]->type, op->type); + auto pipeline = ggml_metal_library_get_pipeline_cpy(lib, op->src[0]->type, op->type); GGML_ASSERT(ne00 % ggml_blck_size(op->src[0]->type) == 0); @@ -1592,7 +1592,7 @@ int ggml_metal_op_pool_2d(ggml_metal_op_t ctx, int idx) { /* .np = */ np }; - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_pool_2d(lib, op, op_pool); + auto pipeline = ggml_metal_library_get_pipeline_pool_2d(lib, op, op_pool); const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), (int) np); const int ntg = (np + nth - 1) / nth; @@ -1701,7 +1701,7 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) { GGML_ABORT("unsupported ne11"); }; - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mv_ext(lib, op->src[0]->type, op->src[1]->type, nsg, nxpsg, r1ptg); + auto pipeline = ggml_metal_library_get_pipeline_mul_mv_ext(lib, op->src[0]->type, op->src[1]->type, nsg, nxpsg, r1ptg); ggml_metal_kargs_mul_mv_ext args = { /*.ne00 =*/ ne00, @@ -1748,7 +1748,7 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) { // default: break; //} - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mm(lib, op); + auto pipeline = ggml_metal_library_get_pipeline_mul_mm(lib, op); ggml_metal_kargs_mul_mm args = { /*.ne00 =*/ ne00, @@ -1773,18 +1773,18 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) { ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2); ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3); - const size_t smem = ggml_metal_pipeline_get_smem(pipeline); + const size_t smem = pipeline.smem; ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0); ggml_metal_encoder_dispatch_threadgroups(enc, ((ne11 + 31)/32), ((ne01 + 63)/64), ne12*ne13, 128, 1, 1); } else { - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mv(lib, op); + auto pipeline = ggml_metal_library_get_pipeline_mul_mv(lib, op); - const int nr0 = ggml_metal_pipeline_get_nr0(pipeline); - const int nr1 = ggml_metal_pipeline_get_nr1(pipeline); - const int nsg = ggml_metal_pipeline_get_nsg(pipeline); + const int nr0 = pipeline.nr0; + const int nr1 = pipeline.nr1; + const int nsg = pipeline.nsg; - const size_t smem = ggml_metal_pipeline_get_smem(pipeline); + const size_t smem = pipeline.smem; ggml_metal_kargs_mul_mv args = { /*.ne00 =*/ ne00, @@ -1915,9 +1915,9 @@ int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) { nb21, }; - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mm_id_map0(lib, ne02, ne20); + auto pipeline = ggml_metal_library_get_pipeline_mul_mm_id_map0(lib, ne02, ne20); - const size_t smem = ggml_metal_pipeline_get_smem(pipeline); + const size_t smem = pipeline.smem; GGML_ASSERT(ne02 <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); @@ -1938,7 +1938,7 @@ int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) { ggml_metal_op_concurrency_reset(ctx); { - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mm_id(lib, op); + auto pipeline = ggml_metal_library_get_pipeline_mul_mm_id(lib, op); ggml_metal_kargs_mul_mm_id args = { /*.ne00 =*/ ne00, @@ -1967,20 +1967,20 @@ int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) { ggml_metal_encoder_set_buffer (enc, bid_ids, 4); ggml_metal_encoder_set_buffer (enc, bid_dst, 5); - const size_t smem = ggml_metal_pipeline_get_smem(pipeline); + const size_t smem = pipeline.smem; ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0); ggml_metal_encoder_dispatch_threadgroups(enc, (ne21 + 31)/32, (ne01 + 63)/64, ne02, 128, 1, 1); } } else { - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mv_id(lib, op); + auto pipeline = ggml_metal_library_get_pipeline_mul_mv_id(lib, op); - const int nr0 = ggml_metal_pipeline_get_nr0(pipeline); - const int nr1 = ggml_metal_pipeline_get_nr1(pipeline); - const int nsg = ggml_metal_pipeline_get_nsg(pipeline); + const int nr0 = pipeline.nr0; + const int nr1 = pipeline.nr1; + const int nsg = pipeline.nsg; - const size_t smem = ggml_metal_pipeline_get_smem(pipeline); + const size_t smem = pipeline.smem; ggml_metal_kargs_mul_mv_id args = { /*.nei0 =*/ ne20, @@ -2064,7 +2064,7 @@ int ggml_metal_op_add_id(ggml_metal_op_t ctx, int idx) { /*.nb21 =*/ nb21, }; - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_base(lib, GGML_OP_ADD_ID); + auto pipeline = ggml_metal_library_get_pipeline_base(lib, GGML_OP_ADD_ID); ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); @@ -2308,7 +2308,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) { /*.nb33 =*/nb33, }; - ggml_metal_pipeline_t pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_pad(lib, op, has_mask, ncpsg); + auto pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_pad(lib, op, has_mask, ncpsg); ggml_metal_encoder_set_pipeline(enc, pipeline0); ggml_metal_encoder_set_bytes (enc, &args0, sizeof(args0), 0); @@ -2339,7 +2339,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) { /*.nb33 =*/ nb33, }; - ggml_metal_pipeline_t pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_blk(lib, op, nqptg, ncpsg); + auto pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_blk(lib, op, nqptg, ncpsg); ggml_metal_encoder_set_pipeline(enc, pipeline0); ggml_metal_encoder_set_bytes (enc, &args0, sizeof(args0), 0); @@ -2424,7 +2424,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) { /*.logit_softcap =*/ logit_softcap, }; - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_flash_attn_ext(lib, op, has_mask, has_sinks, has_bias, has_scap, has_kvpad, nsg); + auto pipeline = ggml_metal_library_get_pipeline_flash_attn_ext(lib, op, has_mask, has_sinks, has_bias, has_scap, has_kvpad, nsg); ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); @@ -2476,7 +2476,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) { /*.nb33 =*/nb33, }; - ggml_metal_pipeline_t pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_pad(lib, op, has_mask, ncpsg); + auto pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_pad(lib, op, has_mask, ncpsg); ggml_metal_encoder_set_pipeline(enc, pipeline0); ggml_metal_encoder_set_bytes (enc, &args0, sizeof(args0), 0); @@ -2578,7 +2578,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) { /*.logit_softcap =*/ logit_softcap, }; - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_flash_attn_ext_vec(lib, op, has_mask, has_sinks, has_bias, has_scap, has_kvpad, nsg, nwg); + auto pipeline = ggml_metal_library_get_pipeline_flash_attn_ext_vec(lib, op, has_mask, has_sinks, has_bias, has_scap, has_kvpad, nsg, nwg); GGML_ASSERT(nsg*32 <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); @@ -2630,7 +2630,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) { nrows, }; - ggml_metal_pipeline_t pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_vec_reduce(lib, op, ne20, nwg); + auto pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_vec_reduce(lib, op, ne20, nwg); ggml_metal_encoder_set_pipeline(enc, pipeline0); ggml_metal_encoder_set_bytes (enc, &args0, sizeof(args0), 0); @@ -2762,7 +2762,7 @@ int ggml_metal_op_bin(ggml_metal_op_t ctx, int idx) { // the offsets of src1 and all fused buffers are relative to the start of the src1 buffer bid_src1.offs = 0; - ggml_metal_pipeline_t pipeline = nullptr; + struct ggml_metal_pipeline_with_params pipeline; if (ggml_nelements(op->src[1]) == ne10 && ggml_is_contiguous(op->src[1]) && ne00 % 4 == 0 && ne10 % 4 == 0) { GGML_ASSERT(ggml_is_contiguous(op->src[0])); @@ -2835,7 +2835,7 @@ int ggml_metal_op_l2_norm(ggml_metal_op_t ctx, int idx) { /*.eps =*/ eps, }; - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_l2_norm(lib, op); + auto pipeline = ggml_metal_library_get_pipeline_l2_norm(lib, op); while (nth < ne00/4 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) { nth *= 2; @@ -2844,7 +2844,7 @@ int ggml_metal_op_l2_norm(ggml_metal_op_t ctx, int idx) { nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); nth = std::min(nth, ne00/4); - const size_t smem = ggml_metal_pipeline_get_smem(pipeline); + const size_t smem = pipeline.smem; const int64_t nrows = ggml_nrows(op->src[0]); @@ -2887,7 +2887,7 @@ int ggml_metal_op_group_norm(ggml_metal_op_t ctx, int idx) { /*.eps =*/ eps, }; - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_group_norm(lib, op); + auto pipeline = ggml_metal_library_get_pipeline_group_norm(lib, op); int nth = 32; // SIMD width //while (nth < ne00/4 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) { @@ -2897,7 +2897,7 @@ int ggml_metal_op_group_norm(ggml_metal_op_t ctx, int idx) { //nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); //nth = std::min(nth, ne00/4); - const size_t smem = ggml_metal_pipeline_get_smem(pipeline); + const size_t smem = pipeline.smem; ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); @@ -3022,7 +3022,7 @@ int ggml_metal_op_norm(ggml_metal_op_t ctx, int idx) { } } - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_norm(lib, op, n_fuse); + auto pipeline = ggml_metal_library_get_pipeline_norm(lib, op, n_fuse); int nth = 32; // SIMD width @@ -3033,7 +3033,7 @@ int ggml_metal_op_norm(ggml_metal_op_t ctx, int idx) { nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); nth = std::min(nth, args.ne00_t); - const size_t smem = ggml_metal_pipeline_get_smem(pipeline); + const size_t smem = pipeline.smem; ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); @@ -3127,7 +3127,7 @@ int ggml_metal_op_rope(ggml_metal_op_t ctx, int idx) { /* src2 =*/ op->src[2] != nullptr, }; - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_rope(lib, op); + auto pipeline = ggml_metal_library_get_pipeline_rope(lib, op); ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); @@ -3199,7 +3199,7 @@ int ggml_metal_op_im2col(ggml_metal_op_t ctx, int idx) { /*.KHW =*/ KH * KW, }; - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_im2col(lib, op); + auto pipeline = ggml_metal_library_get_pipeline_im2col(lib, op); GGML_ASSERT(KH*KW <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); @@ -3270,7 +3270,7 @@ int ggml_metal_op_conv_2d(ggml_metal_op_t ctx, int idx) { /*.d1 =*/ d1, }; - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_conv_2d(lib, op); + auto pipeline = ggml_metal_library_get_pipeline_conv_2d(lib, op); int nth = ggml_metal_pipeline_max_theads_per_threadgroup(pipeline); nth = std::min(nth, 256); @@ -3325,7 +3325,7 @@ int ggml_metal_op_conv_transpose_1d(ggml_metal_op_t ctx, int idx) { /*.nb1 =*/ nb1, }; - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_conv_transpose_1d(lib, op); + auto pipeline = ggml_metal_library_get_pipeline_conv_transpose_1d(lib, op); ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); @@ -3377,7 +3377,7 @@ int ggml_metal_op_conv_transpose_2d(ggml_metal_op_t ctx, int idx) { /*.nb2 =*/ nb2, }; - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_conv_transpose_2d(lib, op); + auto pipeline = ggml_metal_library_get_pipeline_conv_transpose_2d(lib, op); ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); @@ -3433,7 +3433,7 @@ int ggml_metal_op_upscale(ggml_metal_op_t ctx, int idx) { /*.sf3 =*/ sf3 }; - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_upscale(lib, op); + auto pipeline = ggml_metal_library_get_pipeline_upscale(lib, op); const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne0); @@ -3477,7 +3477,7 @@ int ggml_metal_op_pad(ggml_metal_op_t ctx, int idx) { /*.nb3 =*/ nb3 }; - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_pad(lib, op); + auto pipeline = ggml_metal_library_get_pipeline_pad(lib, op); const int nth = std::min(1024, ne0); @@ -3523,7 +3523,7 @@ int ggml_metal_op_pad_reflect_1d(ggml_metal_op_t ctx, int idx) { /*.p1 =*/ ((const int32_t *)(op->op_params))[1] }; - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_pad_reflect_1d(lib, op); + auto pipeline = ggml_metal_library_get_pipeline_pad_reflect_1d(lib, op); const int nth = std::min(1024, ne0); @@ -3560,7 +3560,7 @@ int ggml_metal_op_arange(ggml_metal_op_t ctx, int idx) { const int nth = std::min(1024, ne0); - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_arange(lib, op); + auto pipeline = ggml_metal_library_get_pipeline_arange(lib, op); ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); @@ -3591,7 +3591,7 @@ int ggml_metal_op_timestep_embedding(ggml_metal_op_t ctx, int idx) { /*.max_period =*/ max_period, }; - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_timestep_embedding(lib, op); + auto pipeline = ggml_metal_library_get_pipeline_timestep_embedding(lib, op); const int nth = std::max(1, std::min(1024, dim/2)); @@ -3621,7 +3621,7 @@ int ggml_metal_op_argmax(ggml_metal_op_t ctx, int idx) { /*.nb01 = */ nb01, }; - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_argmax(lib, op); + auto pipeline = ggml_metal_library_get_pipeline_argmax(lib, op); const int64_t nrows = ggml_nrows(op->src[0]); @@ -3630,7 +3630,7 @@ int ggml_metal_op_argmax(ggml_metal_op_t ctx, int idx) { nth *= 2; } - const size_t smem = ggml_metal_pipeline_get_smem(pipeline); + const size_t smem = pipeline.smem; ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); @@ -3657,7 +3657,7 @@ int ggml_metal_op_argsort(ggml_metal_op_t ctx, int idx) { GGML_TENSOR_LOCALS( int32_t, ne, op, ne); GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_argsort(lib, op); + auto pipeline = ggml_metal_library_get_pipeline_argsort(lib, op); // bitonic sort requires the number of elements to be power of 2 int nth = 1; @@ -3706,7 +3706,7 @@ int ggml_metal_op_argsort(ggml_metal_op_t ctx, int idx) { ggml_metal_encoder_dispatch_threadgroups(enc, npr*ne01, ne02, ne03, nth, 1, 1); - ggml_metal_pipeline_t pipeline_merge = ggml_metal_library_get_pipeline_argsort_merge(lib, op); + auto pipeline_merge = ggml_metal_library_get_pipeline_argsort_merge(lib, op); int len = nth; @@ -3764,7 +3764,7 @@ int ggml_metal_op_top_k(ggml_metal_op_t ctx, int idx) { GGML_TENSOR_LOCALS( int32_t, ne, op, ne); GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_top_k(lib, op); + auto pipeline = ggml_metal_library_get_pipeline_top_k(lib, op); // bitonic sort requires the number of elements to be power of 2 int nth = 1; @@ -3818,7 +3818,7 @@ int ggml_metal_op_top_k(ggml_metal_op_t ctx, int idx) { ggml_metal_encoder_dispatch_threadgroups(enc, npr*ne01, ne02, ne03, nth, 1, 1); - ggml_metal_pipeline_t pipeline_merge = ggml_metal_library_get_pipeline_top_k_merge(lib, op); + auto pipeline_merge = ggml_metal_library_get_pipeline_top_k_merge(lib, op); int len = args.top_k; @@ -3881,7 +3881,7 @@ int ggml_metal_op_leaky_relu(ggml_metal_op_t ctx, int idx) { /*.slope =*/ slope }; - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_unary(lib, op); + auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op); int64_t n = ggml_nelements(op); @@ -3910,7 +3910,7 @@ int ggml_metal_op_opt_step_adamw(ggml_metal_op_t ctx, int idx) { GGML_TENSOR_LOCALS( int32_t, ne, op, ne); GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_opt_step_adamw(lib, op); + auto pipeline = ggml_metal_library_get_pipeline_opt_step_adamw(lib, op); const int64_t np = ggml_nelements(op->src[0]); ggml_metal_kargs_opt_step_adamw args = { @@ -3946,7 +3946,7 @@ int ggml_metal_op_opt_step_sgd(ggml_metal_op_t ctx, int idx) { GGML_TENSOR_LOCALS( int32_t, ne, op, ne); GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_opt_step_sgd(lib, op); + auto pipeline = ggml_metal_library_get_pipeline_opt_step_sgd(lib, op); const int64_t np = ggml_nelements(op->src[0]); ggml_metal_kargs_opt_step_sgd args = {