142 changes: 141 additions & 1 deletion ggml/src/ggml-cuda/ggml-cuda.cu
@@ -2007,6 +2007,97 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
}
}

static bool ggml_cuda_should_fuse_mul_mat(const ggml_tensor * ffn_up,
                                          const ggml_tensor * ffn_gate,
                                          const ggml_tensor * glu) {
    GGML_ASSERT(ffn_up && ffn_gate && glu);

    const bool is_mul_mat = ffn_up->op == GGML_OP_MUL_MAT && ffn_gate->op == GGML_OP_MUL_MAT && glu->op == GGML_OP_GLU;
    const bool is_mul_mat_id =
        ffn_up->op == GGML_OP_MUL_MAT_ID && ffn_gate->op == GGML_OP_MUL_MAT_ID && glu->op == GGML_OP_GLU;

    if (!is_mul_mat && !is_mul_mat_id) {
        return false;
    }

    // the weight matrices must match in type, shape, and stride, and both mat-muls must
    // take the same activations (src[1]) and, for MUL_MAT_ID, the same expert ids (src[2])
    if (ffn_up->src[0]->type != ffn_gate->src[0]->type || !ggml_are_same_shape(ffn_up->src[0], ffn_gate->src[0]) ||
        !ggml_are_same_stride(ffn_up->src[0], ffn_gate->src[0])) {
        return false;
    }

    if (ffn_up->src[1] != ffn_gate->src[1]) {
        return false;
    }

    if (ffn_up->src[2] && (ffn_up->src[2] != ffn_gate->src[2])) {
        return false;
    }

    // the GLU must consume exactly these two mat-mul results: gate in src[0], up in src[1]
    if (glu->src[0] != ffn_gate || glu->src[1] != ffn_up) {
        return false;
    }

    static constexpr std::array<ggml_glu_op, 2> valid_glu_ops = { GGML_GLU_OP_SWIGLU, GGML_GLU_OP_GEGLU };

    if (std::find(valid_glu_ops.begin(), valid_glu_ops.end(), ggml_get_glu_op(glu)) == valid_glu_ops.end()) {
        return false;
    }

    if (const bool swapped = ggml_get_op_params_i32(glu, 1); swapped) {
        return false;
    }

    const bool split = ggml_backend_buft_is_cuda_split(ffn_up->src[0]->buffer->buft) ||
                       ggml_backend_buft_is_cuda_split(ffn_gate->src[0]->buffer->buft);

    // TODO: add support for fusion for split buffers
    if (split) {
        return false;
    }

    return true;
}
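
For orientation, here is a minimal sketch (not part of the PR) of the subgraph shape this predicate accepts, built with ggml's public API; `ggml_mul_mat` and `ggml_swiglu_split` are assumed to match their declarations in ggml.h:

// illustration only: gate/up projections sharing the same activations, feeding a split GLU
static ggml_tensor * build_fusable_ffn(ggml_context * ctx,
                                       ggml_tensor * w_gate,  // [n_embd, n_ff] weights
                                       ggml_tensor * w_up,    // [n_embd, n_ff] weights
                                       ggml_tensor * x) {     // [n_embd, n_tokens] activations
    ggml_tensor * gate = ggml_mul_mat(ctx, w_gate, x);  // GGML_OP_MUL_MAT -> ffn_gate
    ggml_tensor * up   = ggml_mul_mat(ctx, w_up,   x);  // GGML_OP_MUL_MAT -> ffn_up, same src[1]
    return ggml_swiglu_split(ctx, gate, up);             // GGML_OP_GLU: src[0] = gate, src[1] = up
}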

static bool ggml_cuda_should_fuse_mul_mat_vec_f(const ggml_tensor * tensor) {
    const ggml_tensor * src0 = tensor->src[0];
    const ggml_tensor * src1 = tensor->src[1];
    const ggml_tensor * dst  = tensor;

    const bool is_mul_mat_id = tensor->op == GGML_OP_MUL_MAT_ID;

    bool use_mul_mat_vec_f =
        (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16) &&
        src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;

    // the batch size seen by the kernel is src1->ne[2] for MUL_MAT_ID and src1->ne[1] for plain MUL_MAT
    const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
    use_mul_mat_vec_f = use_mul_mat_vec_f &&
        ggml_cuda_should_use_mmvf(src0->type, cc, src0->ne, is_mul_mat_id ? src1->ne[2] : src1->ne[1]);

    if (is_mul_mat_id) {
        use_mul_mat_vec_f = use_mul_mat_vec_f && dst->ne[2] == 1;
    }

    return use_mul_mat_vec_f;
}

static bool ggml_cuda_should_fuse_mul_mat_vec_q(const ggml_tensor * tensor) {
    const ggml_tensor * src0 = tensor->src[0];
    const ggml_tensor * src1 = tensor->src[1];
    const ggml_tensor * dst  = tensor;

    // MMVQ kernels read src0 past ggml_nbytes(src0) into its padding; for a view inside a
    // compute buffer that padding is not guaranteed to have been cleared
    const bool bad_padding_clear = ggml_backend_buffer_get_usage(src0->buffer) == GGML_BACKEND_BUFFER_USAGE_COMPUTE &&
                                   ggml_nbytes(src0) != ggml_backend_buffer_get_alloc_size(src0->buffer, src0) &&
                                   src0->view_src;

    bool use_mul_mat_vec_q = ggml_is_quantized(src0->type) && !bad_padding_clear && src1->type == GGML_TYPE_F32 &&
                             dst->type == GGML_TYPE_F32 && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE;

    if (tensor->op == GGML_OP_MUL_MAT_ID) {
        use_mul_mat_vec_q = use_mul_mat_vec_q && dst->ne[2] == 1;
    }

    return use_mul_mat_vec_q;
}
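
As a reference for what the two accepted GLU variants compute per output element, here is a scalar sketch of the fused epilogue's math (illustration only, not the CUDA kernel; the formulas follow ggml's silu and tanh-approximated gelu):

#include <cmath>

// swiglu(gate, up) = silu(gate) * up, with silu(x) = x * sigmoid(x)
static float glu_epilogue_swiglu(float gate, float up) {
    return (gate / (1.0f + expf(-gate))) * up;
}

// geglu(gate, up) = gelu(gate) * up, gelu via the tanh approximation
static float glu_epilogue_geglu(float gate, float up) {
    const float g3   = gate * gate * gate;
    const float gelu = 0.5f * gate * (1.0f + tanhf(0.7978845608f * (gate + 0.044715f * g3)));
    return gelu * up;
}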

static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
    const bool split = ggml_backend_buft_is_cuda_split(src0->buffer->buft);

@@ -2745,7 +2836,7 @@ static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_gra
}
}

-   if (node->op == GGML_OP_SCALE &&
+   if ((node->op == GGML_OP_SCALE || node->op == GGML_OP_GLU) &&
        memcmp(graph_node_properties->op_params, node->op_params, GGML_MAX_OP_PARAMS) != 0) {
        return false;
    }
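
A hedged aside on why GLU nodes join this op_params comparison: ggml stores the GLU variant in op_params[0] (read back via ggml_get_glu_op) and the swapped flag in op_params[1] (as the ggml_get_op_params_i32(glu, 1) call earlier in this diff suggests), so two graphs differing only in those parameters must not be treated as matching when deciding whether a captured CUDA graph can be reused. Roughly:

// hypothetical helper, illustration only: the GLU properties this memcmp distinguishes
static bool glu_props_match(const ggml_tensor * a, const ggml_tensor * b) {
    return ggml_get_glu_op(a) == ggml_get_glu_op(b) &&                    // op_params[0]: GLU variant
           ggml_get_op_params_i32(a, 1) == ggml_get_op_params_i32(b, 1);  // op_params[1]: swapped flag
}
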
@@ -2855,6 +2946,25 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
}
}

    std::initializer_list<enum ggml_op> mul_mat_id_glu_ops = { GGML_OP_MUL_MAT_ID, GGML_OP_MUL_MAT_ID, GGML_OP_GLU };

    std::initializer_list<enum ggml_op> mul_mat_glu_ops = { GGML_OP_MUL_MAT, GGML_OP_MUL_MAT, GGML_OP_GLU };

    // candidate 3-node pattern: two mat-muls (plain or indirect) followed by a GLU that
    // consumes both of their results
    if (ops.size() == 3 && (std::equal(ops.begin(), ops.end(), mul_mat_id_glu_ops.begin()) ||
                            std::equal(ops.begin(), ops.end(), mul_mat_glu_ops.begin()))) {
        if (node_idx + 2 >= cgraph->n_nodes) {
            return false;
        }

        const ggml_tensor * ffn_gate = cgraph->nodes[node_idx];
        const ggml_tensor * ffn_up   = cgraph->nodes[node_idx + 1];
        const ggml_tensor * glu      = cgraph->nodes[node_idx + 2];

        if (ggml_cuda_should_fuse_mul_mat(ffn_up, ffn_gate, glu)) {
            return true;
        }
    }

    if (!ggml_can_fuse(cgraph, node_idx, ops)) {
        return false;
    }
@@ -2992,6 +3102,36 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
}
}

    bool fused_mul_mat_vec = false;

    for (ggml_op op : { GGML_OP_MUL_MAT, GGML_OP_MUL_MAT_ID }) {
        if (ggml_cuda_can_fuse(cgraph, i, { op, op, GGML_OP_GLU }, {})) {
            ggml_tensor * glu  = cgraph->nodes[i + 2];
            ggml_tensor * gate = glu->src[0];
            ggml_tensor * up   = glu->src[1];

            // ggml_cuda_should_fuse_mul_mat already validated that gate and up share src1/ids
            const ggml_tensor * src0 = up->src[0];
            const ggml_tensor * src1 = up->src[1];
            const ggml_tensor * ids  = up->src[2];

            if (ggml_cuda_should_fuse_mul_mat_vec_f(up)) {
                ggml_cuda_mul_mat_vec_f(*cuda_ctx, src0, src1, ids, glu, gate->src[0], glu);
                fused_mul_mat_vec = true;
                break;
            }

            if (ggml_cuda_should_fuse_mul_mat_vec_q(up)) {
                ggml_cuda_mul_mat_vec_q(*cuda_ctx, src0, src1, ids, glu, gate->src[0], glu);
                fused_mul_mat_vec = true;
                break;
            }
        }
    }

    if (fused_mul_mat_vec) {
        // the gate mat-mul, up mat-mul, and GLU were handled in one fused launch; skip the two extra nodes
        i += 2;
        continue;
    }

    if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL, GGML_OP_ADD}, {})) {
        ggml_cuda_op_rms_norm_fused_add(*cuda_ctx, node, cgraph->nodes[i+1], cgraph->nodes[i+2]);