[LLM Runtime] Baichuan FFN & MHA support (#497)
Zhenzhong1 committed Oct 23, 2023
1 parent 5eb325d commit 6599bdc
Showing 4 changed files with 152 additions and 86 deletions.
@@ -46,8 +46,6 @@
// - n_threads: number of threads to use
//

static int flag = 0;
static int first_tokens_size = 0;
static bool baichuan_model_eval_internal(model_context& lctx, const model_token* tokens, const int n_tokens,
const int n_past, const int n_threads) {
const int64_t t_start_us = ne_time_us();
@@ -66,15 +64,11 @@ static bool baichuan_model_eval_internal(model_context& lctx, const model_token*
const int n_layer = hparams.n_layer;
const int n_ctx = hparams.n_ctx;

if (flag == 0) {
first_tokens_size = n_tokens;
flag++;
}

const int n_head = hparams.n_head;
const int n_vocab = hparams.n_vocab;
const int head_size = n_embd / n_head;
const int n_rot = n_embd / n_head / 2;
const int num_attention_heads = n_head;
const float attn_scale = 1.f / std::sqrt(head_size);

auto& mem_per_token = lctx.mem_per_token;
auto& buf_compute = lctx.buf_compute;
@@ -92,13 +86,36 @@ static bool baichuan_model_eval_internal(model_context& lctx, const model_token*
ne_cgraph gf = {};
gf.n_threads = N >= 32 && ne_cpu_has_blas() ? 1 : n_threads;

const bool run_mha_reordered = model.layers[0].k_cache->type == NE_TYPE_JBLAS;
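// If the KV cache was created in the jblas packed format, verify that the reordered attention
// kernels support this attention shape and query the packed-cache geometry before building the graph.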
kv_cache_info_t kv_cache_info = {};
if (run_mha_reordered) {
NE_ASSERT(("kv cache should be the same dtype", model.layers[0].v_cache->type == NE_TYPE_JBLAS));
attn_shape_t attn_shape = {
/* .batch_size = */ 1,
/* .head_num = */ n_head,
/* .heads_kv = */ n_head,
/* .head_size = */ head_size,
/* .sl_q = */ N, // Note: make sure that jblas reordered attn supports next token inference
/* .sl_kv = */ n_past + N,
};

NE_ASSERT(("jblas managed kv-cache not supported; use `--memory-f16 / --memory-f32` instead",
jblas_reordered_attn_fp32_support(&attn_shape)));
kv_shape_t kv_shape{
/* .heads_kv = */ static_cast<uint32_t>(n_head),
/* .head_size = */ static_cast<uint32_t>(head_size),
/* .sl_kv_max = */ static_cast<uint32_t>(n_ctx),
};
jblas_reordered_attn_fp32_batch_kv_info(&kv_shape, &kv_cache_info);
}
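// kv_cache_info now carries the packed-cache byte sizes, strides, and layout tags consumed by the
// reordered attention path below.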

struct ne_tensor* embd = d_ne_new_tensor_1d(ctx0, NE_TYPE_I32, N);
memcpy(embd->data, tokens, N * ne_element_size(embd));

struct ne_tensor* inpL = ne_get_rows(ctx0, model.others[0], embd);
int hidden_size = inpL->ne[0];
int qlen = inpL->ne[1];
int head_size = hidden_size / num_attention_heads;
NE_ASSERT(N == inpL->ne[1]);

for (int il = 0; il < n_layer; ++il) {
struct ne_tensor* cur;

@@ -115,58 +132,99 @@ static bool baichuan_model_eval_internal(model_context& lctx, const model_token*
cur = ne_mul_mat(ctx0, model.layers[il].attn[0], cur);

ne_tensor* query_layer = ne_view_3d(ctx0, cur, head_size, n_head, N, head_size * ne_element_size(cur), cur->nb[1],
0); // [qlen, hidden]
query_layer = ne_permute(ctx0, query_layer, 0, 2, 1, 3); // [heads, qlen, head_size]

ne_tensor* key_layer =
ne_view_3d(ctx0, cur, head_size, num_attention_heads, qlen, head_size * ne_element_size(cur), cur->nb[1],
hidden_size * ne_element_size(cur));
key_layer = ne_permute(ctx0, key_layer, 0, 2, 1, 3); // [heads, qlen, head_size]

ne_tensor* value_layer =
ne_view_3d(ctx0, cur, head_size, num_attention_heads, qlen, head_size * ne_element_size(cur), cur->nb[1],
2 * hidden_size * ne_element_size(cur)); // [qlen, heads, head_size]
value_layer = ne_permute(ctx0, value_layer, 1, 2, 0, 3); // [heads, head_size, qlen]

// store key and value to memory
{
struct ne_tensor* k_cache_view =
ne_view_3d(ctx0, model.layers[il].k_cache, head_size, qlen, num_attention_heads,
model.layers[il].k_cache->nb[1], model.layers[il].k_cache->nb[2],
n_past * head_size * ne_element_size(model.layers[il].k_cache)); // [kv_heads, qlen, head_size]

struct ne_tensor* v_cache_view =
ne_view_3d(ctx0, model.layers[il].v_cache, qlen, head_size, num_attention_heads,
model.layers[il].v_cache->nb[1], model.layers[il].v_cache->nb[2],
n_past * ne_element_size(model.layers[il].v_cache)); // [kv_heads, head_size, qlen]

ne_build_forward_expand(&gf, ne_cpy(ctx0, key_layer, k_cache_view));
ne_build_forward_expand(&gf, ne_cpy(ctx0, value_layer, v_cache_view));
0); // [N, hidden]

ne_tensor* key_layer = ne_view_3d(ctx0, cur, head_size, n_head, N, head_size * ne_element_size(cur), cur->nb[1],
hidden_size * ne_element_size(cur));

ne_tensor* value_layer = ne_view_3d(ctx0, cur, head_size, n_head, N, head_size * ne_element_size(cur), cur->nb[1],
2 * hidden_size * ne_element_size(cur)); // [N, heads, head_size]

if (!run_mha_reordered) {
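// Non-reordered path: explicit permutes into the plain KV cache, with ALiBi bias, the causal mask,
// and softmax issued as separate graph ops.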
query_layer = ne_permute(ctx0, query_layer, 0, 2, 1, 3); // [heads, N, head_size]
key_layer = ne_permute(ctx0, key_layer, 0, 2, 1, 3); // [heads, N, head_size]
value_layer = ne_permute(ctx0, value_layer, 1, 2, 0, 3); // [heads, head_size, N]

// store key and value to memory
{
struct ne_tensor* k_cache_view =
ne_view_3d(ctx0, model.layers[il].k_cache, head_size, N, n_head, model.layers[il].k_cache->nb[1],
model.layers[il].k_cache->nb[2],
n_past * head_size * ne_element_size(model.layers[il].k_cache)); // [kv_heads, N, head_size]

struct ne_tensor* v_cache_view =
ne_view_3d(ctx0, model.layers[il].v_cache, N, head_size, n_head, model.layers[il].v_cache->nb[1],
model.layers[il].v_cache->nb[2],
n_past * ne_element_size(model.layers[il].v_cache)); // [kv_heads, head_size, N]

ne_build_forward_expand(&gf, ne_cpy(ctx0, key_layer, k_cache_view));
ne_build_forward_expand(&gf, ne_cpy(ctx0, value_layer, v_cache_view));
}
// concat key & value with past kv
key_layer = ne_view_3d(ctx0, model.layers[il].k_cache, head_size, n_past + N, n_head,
model.layers[il].k_cache->nb[1], model.layers[il].k_cache->nb[2],
0); // [kv_heads, klen, head_size]
value_layer = ne_view_3d(ctx0, model.layers[il].v_cache, n_past + N, head_size, n_head,
model.layers[il].v_cache->nb[1], model.layers[il].v_cache->nb[2],
0); // [kv_heads, head_size, klen]

// attention
struct ne_tensor* attn_scores = ne_mul_mat(ctx0, key_layer, query_layer); // [heads, N, klen]
attn_scores = ne_scale_inplace(ctx0, attn_scores, ne_new_f32(ctx0, attn_scale));
attn_scores = ne_alibi(ctx0, attn_scores, n_past, n_head, 8);
if (n_past == 0) {
attn_scores = ne_diag_mask_inf_inplace(ctx0, attn_scores, n_past);
}
ne_tensor* attn_probs = ne_soft_max_inplace(ctx0, attn_scores); // [heads, N, klen]

// ne_compute_forward_mul_mat_f16_f32
ne_tensor* context_layer = ne_mul_mat(ctx0, value_layer, attn_probs); // [heads, N, head_size]
context_layer = ne_cont(ctx0, ne_permute(ctx0, context_layer, 0, 2, 1, 3));
context_layer = ne_reshape_2d(ctx0, context_layer, hidden_size, N);

// F32 mul_mat
cur = ne_mul_mat(ctx0, model.layers[il].attn[1], context_layer);
} else {
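// Reordered path: new keys/values are appended into the jblas-packed cache and attention runs as a
// single fused flash-attention op, with ALiBi and the causal mask expressed as flags.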
const auto k_size = kv_cache_info.k_bytes;
const auto v_size = kv_cache_info.v_bytes;

// store key and value to memory
{
const auto k_cache = ne_view_3d(ctx0, model.layers[il].k_cache, // tensor
head_size, n_ctx, n_head, // ne
0, 0, // nb (jblas managed)
0); // offset
ne_build_forward_expand(&gf, ne_flash_attn_update_k(ctx0, k_cache, key_layer, n_past));
const auto v_cache = ne_view_3d(ctx0, model.layers[il].v_cache, // tensor
head_size, n_ctx, n_head, // ne
0, 0, // nb (jblas managed)
0); // offset
ne_build_forward_expand(&gf, ne_flash_attn_update_v(ctx0, v_cache, value_layer, n_past));
}

query_layer = ne_permute(ctx0, query_layer, 0, 2, 1, 3); // [heads, N, head_size]
key_layer = //
ne_view_3d(ctx0, model.layers[il].k_cache, // tensor
head_size, n_past + N, n_head, // ne
kv_cache_info.stride_k_sl, kv_cache_info.stride_k_head_num, // nb (jblas managed)
0); // offset
*reinterpret_cast<ATTN_FWD_LAYOUT*>(&key_layer->nb[0]) = kv_cache_info.k_layout; // use nb0 for layout

value_layer = //
ne_view_3d(ctx0, model.layers[il].v_cache, // tensor
n_past + N, head_size, n_head, // ne
kv_cache_info.stride_v_head_size, kv_cache_info.stride_v_head_num, // nb (jblas managed)
0); // offset
*reinterpret_cast<ATTN_FWD_LAYOUT*>(&value_layer->nb[0]) = kv_cache_info.v_layout; // use nb0 for layout
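// nb[0] of these cache views is repurposed to carry the packed layout tag that the flash-attention
// kernel reads (hence the "use nb0 for layout" notes above).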

ne_attn_flags_t attn_flags = NE_ATTN_FLAG_IS_ALIBI8;
if (n_past == 0) attn_flags |= NE_ATTN_FLAG_IS_CAUSAL; // no causal mask on next-token cases
struct ne_tensor* KQV_Out = ne_flash_attn(ctx0, query_layer, key_layer, value_layer, attn_scale, attn_flags);
cur = ne_view_2d(ctx0, KQV_Out, n_embd, N, n_embd * ne_element_size(KQV_Out), 0);

// F32 mul_mat
cur = ne_mul_mat(ctx0, model.layers[il].attn[1], cur);
}
// concat key & value with past kv
key_layer = ne_view_3d(ctx0, model.layers[il].k_cache, head_size, n_past + qlen, num_attention_heads,
model.layers[il].k_cache->nb[1], model.layers[il].k_cache->nb[2],
0); // [kv_heads, klen, head_size]
value_layer = ne_view_3d(ctx0, model.layers[il].v_cache, n_past + qlen, head_size, num_attention_heads,
model.layers[il].v_cache->nb[1], model.layers[il].v_cache->nb[2],
0); // [kv_heads, head_size, klen]

// attention
struct ne_tensor* attn_scores = ne_mul_mat(ctx0, key_layer, query_layer); // [heads, qlen, klen]
attn_scores = ne_scale_inplace(ctx0, attn_scores, ne_new_f32(ctx0, 1.f / std::sqrt(head_size)));
attn_scores = ne_alibi(ctx0, attn_scores, n_past, num_attention_heads, 8);
if (n_past == 0) {
attn_scores = ne_diag_mask_inf_inplace(ctx0, attn_scores, n_past);
}
ne_tensor* attn_probs = ne_soft_max_inplace(ctx0, attn_scores); // [heads, qlen, klen]

// ne_compute_forward_mul_mat_f16_f32
ne_tensor* context_layer = ne_mul_mat(ctx0, value_layer, attn_probs); // [heads, qlen, head_size]
context_layer = ne_cont(ctx0, ne_permute(ctx0, context_layer, 0, 2, 1, 3));
context_layer = ne_reshape_2d(ctx0, context_layer, hidden_size, qlen);

// F32 mul_mat
cur = ne_mul_mat(ctx0, model.layers[il].attn[1], context_layer);
}

lctx.use_buf(ctx0, 1);
@@ -178,11 +236,19 @@ static bool baichuan_model_eval_internal(model_context& lctx, const model_token*
hidden_states = ne_mul(ctx0, hidden_states, model.layers[il].norm[1]);

// mlp.forward
struct ne_tensor* gate = ne_mul_mat(ctx0, model.layers[il].ffn[0], hidden_states);
gate = ne_silu(ctx0, gate);
struct ne_tensor* up = ne_mul_mat(ctx0, model.layers[il].ffn[1], hidden_states);
struct ne_tensor* mlp_output = ne_mul(ctx0, gate, up);
mlp_output = ne_mul_mat(ctx0, model.layers[il].ffn[2], mlp_output);
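// Prefer the fused jblas SiLU-FFN kernel when the projection shapes qualify; otherwise fall back to
// separate gate/up matmuls, an explicit SiLU, and a final down projection.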
struct ne_tensor* mlp_output;
if (jblas_fusion_FFN_SiLu_f32f32_support(model.layers[il].ffn[0]->data, model.layers[il].ffn[1]->data,
model.layers[il].ffn[2]->data, N, hidden_states->ne[0],
model.layers[il].ffn[0]->ne[1], model.layers[il].ffn[1]->ne[1])) {
mlp_output =
ne_ffn_silu(ctx0, model.layers[il].ffn[0], model.layers[il].ffn[1], model.layers[il].ffn[2], hidden_states);
} else {
struct ne_tensor* up = ne_mul_mat(ctx0, model.layers[il].ffn[2], hidden_states);
struct ne_tensor* gate = ne_mul_mat(ctx0, model.layers[il].ffn[0], hidden_states);
gate = ne_silu(ctx0, gate);
mlp_output = ne_mul(ctx0, gate, up);
mlp_output = ne_mul_mat(ctx0, model.layers[il].ffn[1], mlp_output);
}

inpL = ne_add_inplace(ctx0, mlp_output, residual);
}
@@ -48,6 +48,7 @@ void model_load_internal(const std::string& fname, model_archs arch, model_conte
ms->init(fname.c_str(), lctx, n_ctx, n_gpu_layers, use_mmap, use_mlock, vocab_only);
ms->load(lctx, progress_callback, progress_callback_user_data);

lctx.support_jblas_kv = true;
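// Advertise jblas-managed KV-cache support for this architecture; presumably consulted when the cache buffers are allocated.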
lctx.t_load_us = ne_time_us() - lctx.t_start_us;
}

@@ -141,15 +142,14 @@ void BAICHUAN::load(model_context& lctx, model_progress_callback progress_callba
// ffn GEMM
layer.ffn[0] = ml->get_tensor(layers_i + ".mlp.gate_proj.weight",
{n_embd, uint32_t(model.hparams.inner_hidden_size)}, backend);
layer.ffn[1] =
ml->get_tensor(layers_i + ".mlp.up_proj.weight", {n_embd, uint32_t(model.hparams.inner_hidden_size)}, backend);
layer.ffn[2] = ml->get_tensor(layers_i + ".mlp.down_proj.weight",

layer.ffn[1] = ml->get_tensor(layers_i + ".mlp.down_proj.weight",
{uint32_t(model.hparams.inner_hidden_size), n_embd}, backend);
layer.ffn[2] =
ml->get_tensor(layers_i + ".mlp.up_proj.weight", {n_embd, uint32_t(model.hparams.inner_hidden_size)}, backend);

layer.k_cache = d_ne_new_tensor_3d(model.ctx, NE_TYPE_F16, n_embd / hparams.n_head, max_len,
hparams.n_head); // [n_head, maxlen, head_size]
layer.v_cache = d_ne_new_tensor_3d(model.ctx, NE_TYPE_F16, max_len, n_embd / hparams.n_head,
hparams.n_head); // [n_head, head_size, maxlen]
layer.v_cache = nullptr;
layer.k_cache = nullptr;
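// The per-layer F16 cache tensors are no longer allocated here; with support_jblas_kv set, the caches
// are presumably created later according to the selected memory type.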
}

// print memory requirements
@@ -306,14 +306,9 @@ static bool llama_model_eval_internal(model_context& lctx, const model_token* to
cur = ne_ffn_silu(ctx0, model.layers[il].ffn[0], model.layers[il].ffn[1], model.layers[il].ffn[2], cur);
} else {
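// Unfused fallback: two parallel projections, SiLU applied to one branch, elementwise product, then the output projection.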
struct ne_tensor* tmp = ne_mul_mat(ctx0, model.layers[il].ffn[2], cur);

cur = ne_mul_mat(ctx0, model.layers[il].ffn[0], cur);

// SILU activation
cur = ne_silu(ctx0, cur);

cur = ne_mul(ctx0, cur, tmp);

cur = ne_mul_mat(ctx0, model.layers[il].ffn[1], cur);
}
}