llamafile : improve moe prompt eval speed on cpu
This change introduces a llamafile_mixmul() API that allows tinyBLAS to
speed up "Mixture of Experts" models. On my Threadripper, Mixtral 8x7b
now processes prompts 2x faster with F16 weights. I'm also seeing a 60
percent improvement with Mixtral 8x22b Q4_0. The same applies to Q8_0,
which is also supported by tinyBLAS. MoE models spend the majority of
their time inside MUL_MAT_ID rather than MUL_MAT, which is why
llamafile_sgemm was not able to help them before. llamafile_mixmul
works by decomposing the mixmul operation into sgemm calls.
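
The last sentence is the heart of the change. As illustration only, here is a
minimal self-contained sketch of that gather/GEMM/scatter decomposition. Every
name in it (gemm_nt, mixmul, the fixed shapes) is invented for this example, a
naive triple loop stands in for tinyBLAS, and real MUL_MAT_ID routing sends
each token to several experts with gating weights, which the sketch omits.

/*
 * Standalone sketch of the idea (not the commit's code): a
 * mixture-of-experts matmul decomposed into one dense GEMM per expert.
 * gemm_nt() is a naive stand-in for tinyBLAS, and each token is routed
 * to a single expert here for simplicity.
 */
#include <stdio.h>
#include <string.h>

#define N_EXPERTS 4
#define N_TOKENS  8
#define D_IN      16  /* columns of x and of each weight row */
#define D_OUT     16  /* rows of each expert's weight matrix */

/* C[i][j] = dot(A[i], B[j]); A is m x k, B is n x k, C is m x n. */
static void gemm_nt(int m, int n, int k,
                    const float *A, const float *B, float *C) {
    for (int i = 0; i < m; i++)
        for (int j = 0; j < n; j++) {
            float acc = 0;
            for (int l = 0; l < k; l++)
                acc += A[i*k + l] * B[j*k + l];
            C[i*n + j] = acc;
        }
}

/* y[t] = W[ids[t]] * x[t], computed expert by expert so that each
 * expert's tokens form one contiguous GEMM instead of many dot products. */
static void mixmul(const float *W,   /* N_EXPERTS x D_OUT x D_IN */
                   const float *x,   /* N_TOKENS x D_IN */
                   const int *ids,   /* expert chosen for each token */
                   float *y) {       /* N_TOKENS x D_OUT */
    float gathered[N_TOKENS * D_IN];
    float out[N_TOKENS * D_OUT];
    int row_of[N_TOKENS];
    for (int e = 0; e < N_EXPERTS; e++) {
        int n = 0;  /* gather the tokens routed to expert e */
        for (int t = 0; t < N_TOKENS; t++)
            if (ids[t] == e) {
                memcpy(gathered + n*D_IN, x + t*D_IN, D_IN * sizeof(float));
                row_of[n++] = t;
            }
        if (n == 0)
            continue;
        gemm_nt(n, D_OUT, D_IN, gathered, W + e*D_OUT*D_IN, out);
        for (int i = 0; i < n; i++)  /* scatter results back into y */
            memcpy(y + row_of[i]*D_OUT, out + i*D_OUT, D_OUT * sizeof(float));
    }
}

int main(void) {
    static float W[N_EXPERTS * D_OUT * D_IN], x[N_TOKENS * D_IN];
    static float y[N_TOKENS * D_OUT];
    int ids[N_TOKENS];
    for (int i = 0; i < N_EXPERTS * D_OUT * D_IN; i++) W[i] = (i % 7) / 8.0f;
    for (int i = 0; i < N_TOKENS * D_IN; i++) x[i] = (i % 5) / 4.0f;
    for (int t = 0; t < N_TOKENS; t++) ids[t] = t % N_EXPERTS;
    mixmul(W, x, ids, y);
    printf("y[0][0] = %g\n", y[0]);
    return 0;
}

The win during prompt evaluation comes from batching: with many prefill tokens
landing on each expert, every gathered block is tall enough for a blocked GEMM
kernel to beat token-by-token dot products, which is why the speedup shows up
in prompt processing rather than single-token generation.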
jart committed Apr 23, 2024
1 parent 4e96a81 commit 00fa7cd
Showing 3 changed files with 450 additions and 40 deletions.
ggml.c (8 changes: 7 additions & 1 deletion)
@@ -11003,11 +11003,14 @@ static void ggml_compute_forward_mul_mat_id(
     const struct ggml_tensor * src1 = dst->src[1];
     const struct ggml_tensor * ids = dst->src[2];
 
-    GGML_TENSOR_BINARY_OP_LOCALS
+    if (llamafile_mixmul(params, src0, src1, ids, dst))
+        return;
 
     const int ith = params->ith;
     const int nth = params->nth;
 
+    GGML_TENSOR_BINARY_OP_LOCALS
+
     const enum ggml_type type = src0->type;
 
     const bool src1_cont = ggml_is_contiguous(src1);
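
The hook is wired in with an early return: llamafile_mixmul() is expected to
return true only when it actually performed the operation, so the generic code
below it remains the fallback. A hypothetical sketch of that pattern follows;
only the early-return shape comes from the diff, while all names and the
dtype gate are invented for the example.

/* Pattern sketch with hypothetical names: the accelerated path reports
 * whether it handled the operation, so the generic loop stays as the
 * fallback for dtypes and shapes it does not support. */
#include <stdbool.h>
#include <stdio.h>

enum dtype { F32, F16, Q4_0, Q8_0 };

static bool fast_mixmul(enum dtype weights) {
    switch (weights) {
    case F16:
    case Q4_0:
    case Q8_0:
        puts("tinyBLAS-style kernel would run here");
        return true;   /* handled: caller must not run its own loop */
    default:
        return false;  /* not handled: caller falls through */
    }
}

static void compute_mul_mat_id(enum dtype weights) {
    if (fast_mixmul(weights))
        return;        /* mirrors the early return added in the diff */
    puts("generic reference implementation");
}

int main(void) {
    compute_mul_mat_id(Q8_0);  /* takes the fast path */
    compute_mul_mat_id(F32);   /* falls back */
    return 0;
}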
@@ -18504,6 +18507,7 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threads)
                     cur = 0;
                     const struct ggml_tensor * src0 = node->src[0];
                     const struct ggml_tensor * src1 = node->src[1];
+                    const struct ggml_tensor * src2 = node->src[2];
                     const enum ggml_type vec_dot_type = type_traits[src0->type].vec_dot_type;
                     if (src1->type != vec_dot_type) {
                         cur += ggml_row_size(vec_dot_type, ggml_nelements(src1));
@@ -18512,6 +18516,8 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threads)
                     cur += GGML_PAD(cur, sizeof(int64_t)); // align
                     cur += n_as * sizeof(int64_t); // matrix_row_counts
                     cur += n_as * src1->ne[2] * sizeof(int64_t); // matrix_rows
+                    size_t cur2 = llamafile_mixmul_needs(src0, src1, src2);
+                    cur = cur > cur2 ? cur : cur2;
                 } break;
             case GGML_OP_OUT_PROD:
                 {
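
The planner side follows the same opt-in design: the generic MUL_MAT_ID path
and llamafile_mixmul draw scratch memory from the same per-graph work buffer,
and since only one of them runs per op, the plan reserves the larger of the
two requirements rather than their sum. A tiny sketch of that sizing rule,
with an invented helper name:

/* Sketch of the sizing rule (hypothetical helper): two alternative
 * implementations share one scratch buffer, so the plan takes the max
 * of their needs, not the sum. */
#include <assert.h>
#include <stddef.h>

static size_t work_size_for_mul_mat_id(size_t generic_need,
                                       size_t mixmul_need) {
    return generic_need > mixmul_need ? generic_need : mixmul_need;
}

int main(void) {
    assert(work_size_for_mul_mat_id(4096, 16384) == 16384);
    assert(work_size_for_mul_mat_id(4096, 0) == 4096); /* mixmul declined */
    return 0;
}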
