Skip to content

Commit

Permalink
[CPP Graph] MPT MHA support (#453)
Browse files Browse the repository at this point in the history
  • Loading branch information
DDEle committed Oct 13, 2023
1 parent 698e589 commit 7b73b1b
Show file tree
Hide file tree
Showing 9 changed files with 181 additions and 82 deletions.

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -7542,6 +7542,7 @@ static void ne_compute_forward_alibi_f16(const struct ne_compute_params* params,
const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);

NE_ASSERT(("OP_ALIBI may not be able handle multi-batch cases", src0->ne[3] == 1));
for (int i = 0; i < ne0; i++) {
for (int j = 0; j < ne1; j++) {
for (int k = 0; k < ne2_ne3; k++) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,9 @@ extern "C" {

// Attention flags
typedef enum NE_ATTN_FLAG {
NE_ATTN_FLAG_NONE = 0,
NE_ATTN_FLAG_IS_CAUSAL = 1 << 1,
NE_ATTN_FLAG_IS_ALIBI = 1 << 2,
NE_ATTN_FLAG_IS_ALIBI8 = 1 << 2,
} NE_ATTN_FLAG;
typedef uint32_t ne_attn_flags_t;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ static bool chatglm_model_eval_internal(model_context& lctx, const model_token*
0); // offset
*reinterpret_cast<ATTN_FWD_LAYOUT*>(&value_layer->nb[0]) = kv_cache_info.v_layout; // us nb0 for layout

ne_attn_flags_t attn_flags = 0;
ne_attn_flags_t attn_flags = NE_ATTN_FLAG_NONE;
if (n_past == 0) attn_flags |= NE_ATTN_FLAG_IS_CAUSAL; // no causal mask on next-token cases
struct ne_tensor* KQV_Out = ne_flash_attn(ctx0, query_layer, key_layer, value_layer, attn_scale, attn_flags);
cur = ne_view_2d(ctx0, KQV_Out, n_embd, N, n_embd * ne_element_size(KQV_Out), 0);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ static bool falcon_model_eval_internal(model_context& lctx, const model_token* t
*reinterpret_cast<ATTN_FWD_LAYOUT*>(&V->nb[0]) = kv_cache_info.v_layout; // us nb0 for layout
ne_set_name(V, "V");

ne_attn_flags_t attn_flags = 0;
ne_attn_flags_t attn_flags = NE_ATTN_FLAG_NONE;
if (n_past == 0) attn_flags |= NE_ATTN_FLAG_IS_CAUSAL; // no causal mask on next-token cases
struct ne_tensor* KQV_Out = ne_flash_attn(ctx0, Q, K, V, attn_scale, attn_flags);
cur = ne_view_2d(ctx0, KQV_Out, n_embd, N, n_embd * ne_element_size(KQV_Out), 0);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,7 @@ static bool gptj_model_eval_internal(model_context& lctx, const model_token* tok
struct ne_tensor* KQV_merged_contiguous;

const float attn_scale = 1.0f / sqrtf(static_cast<float>(n_embd) / n_head);
ne_attn_flags_t attn_flags = 0;
ne_attn_flags_t attn_flags = NE_ATTN_FLAG_NONE;
if (n_past == 0) attn_flags |= NE_ATTN_FLAG_IS_CAUSAL; // no causal mask on next-token cases
if (run_mha_reordered) { // reordered kv-cache bf16 mha must be used if run_mha_reordered
struct ne_tensor* KQV_Out = ne_flash_attn(ctx0, Q, K, V, attn_scale, attn_flags);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@ static bool llama_model_eval_internal(model_context& lctx, const model_token* to
*reinterpret_cast<ATTN_FWD_LAYOUT*>(&V->nb[0]) = kv_cache_info.v_layout; // us nb0 for layout
ne_set_name(V, "V");

ne_attn_flags_t attn_flags = 0;
ne_attn_flags_t attn_flags = NE_ATTN_FLAG_NONE;
if (n_past == 0) attn_flags |= NE_ATTN_FLAG_IS_CAUSAL; // no causal mask on next-token cases
struct ne_tensor* KQV_Out = ne_flash_attn(ctx0, Q, K, V, attn_scale, attn_flags);
struct ne_tensor* KQV_merged_contiguous =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ static bool mpt_model_eval_internal(model_context& lctx, const model_token* toke
const int n_ctx = hparams.n_ctx;
const int n_head = hparams.n_head;
const int n_vocab = hparams.n_vocab;
// const int n_rot = hparams.n_embd / hparams.n_head;
const int head_dim = hparams.n_embd / hparams.n_head;

auto& mem_per_token = lctx.mem_per_token;
auto& buf_compute = lctx.buf_compute;
Expand All @@ -88,8 +88,28 @@ static bool mpt_model_eval_internal(model_context& lctx, const model_token* toke
ne_cgraph gf = {};
gf.n_threads = N >= 32 && ne_cpu_has_blas() ? 1 : n_threads;

const bool kv_mem_jblas = kv_self.k->type == NE_TYPE_JBLAS;
NE_ASSERT(("jblas managed kv-cache is not yet supported; use `--memory-f16 / --memory-f32` instead", !kv_mem_jblas));
const bool run_mha_reordered = kv_self.k->type == NE_TYPE_JBLAS;
kv_cache_info_t kv_cache_info = {};
if (run_mha_reordered) {
NE_ASSERT(("kv cache should be the same dtype", kv_self.v->type == NE_TYPE_JBLAS));
attn_shape_t attn_shape = {
/* .batch_size = */ 1,
/* .head_num = */ n_head,
/* .heads_kv = */ n_head,
/* .head_size = */ head_dim,
/* .sl_q = */ N, // Note: make sure that jblas reordered attn supports next token inference
/* .sl_kv = */ n_past + N,
};

NE_ASSERT(("jblas managed kv-cache not supported; use `--memory-f16 / --memory-f32` instead",
jblas_reordered_attn_fp32_support(&attn_shape)));
kv_shape_t kv_shape{
/* .heads_kv = */ static_cast<uint32_t>(n_head),
/* .head_size = */ static_cast<uint32_t>(head_dim),
/* .sl_kv_max = */ static_cast<uint32_t>(n_ctx),
};
jblas_reordered_attn_fp32_batch_kv_info(&kv_shape, &kv_cache_info);
}

struct ne_tensor* embd = d_ne_new_tensor_1d(ctx0, NE_TYPE_I32, N);
ne_set_name(embd, "embd");
Expand All @@ -109,20 +129,23 @@ static bool mpt_model_eval_internal(model_context& lctx, const model_token* toke
cur = ne_mul(ctx0, ne_repeat(ctx0, model.layers[il].norm[0], cur), cur);
}

{
cur = ne_mul_mat(ctx0, model.layers[il].attn[0], cur);
cur = ne_mul_mat(ctx0, model.layers[il].attn[0], cur);

if (model.hparams.clip_qkv > 0.0f) {
cur = ne_clamp(ctx0, cur, -model.hparams.clip_qkv, model.hparams.clip_qkv);
}
if (model.hparams.clip_qkv > 0.0f) {
cur = ne_clamp(ctx0, cur, -model.hparams.clip_qkv, model.hparams.clip_qkv);
}

struct ne_tensor* Qcur = ne_view_2d(ctx0, cur, n_embd, N, cur->nb[1], 0 * sizeof(float) * n_embd);
struct ne_tensor* Kcur = ne_view_2d(ctx0, cur, n_embd, N, cur->nb[1], 1 * sizeof(float) * n_embd);
struct ne_tensor* Vcur = ne_view_2d(ctx0, cur, n_embd, N, cur->nb[1], 2 * sizeof(float) * n_embd);

struct ne_tensor* Qcur = ne_view_2d(ctx0, cur, n_embd, N, cur->nb[1], 0 * sizeof(float) * n_embd);
struct ne_tensor* Kcur = ne_view_2d(ctx0, cur, n_embd, N, cur->nb[1], 1 * sizeof(float) * n_embd);
// self-attention
const float attn_scale = 1.0f / sqrtf(static_cast<float>(head_dim));

if (!run_mha_reordered) {
// store key and value to memory
{
struct ne_tensor* Vcur =
ne_transpose(ctx0, ne_view_2d(ctx0, cur, n_embd, N, cur->nb[1], 2 * sizeof(float) * n_embd));
Vcur = ne_transpose(ctx0, Vcur);
struct ne_tensor* k =
ne_view_1d(ctx0, kv_self.k, N * n_embd, (ne_element_size(kv_self.k) * n_embd) * (il * n_ctx + n_past));
struct ne_tensor* v =
Expand Down Expand Up @@ -151,7 +174,7 @@ static bool mpt_model_eval_internal(model_context& lctx, const model_token* toke
struct ne_tensor* KQ = ne_mul_mat(ctx0, K, Q);

// KQ_scaled = KQ / sqrt(n_embd/n_head)
struct ne_tensor* KQ_scaled = ne_scale(ctx0, KQ, ne_new_f32(ctx0, 1.0f / sqrt(float(n_embd) / n_head)));
struct ne_tensor* KQ_scaled = ne_scale(ctx0, KQ, ne_new_f32(ctx0, attn_scale));

struct ne_tensor* KQ_scaled_alibi = ne_alibi(ctx0, KQ_scaled, n_past, n_head, model.hparams.alibi_bias_max);

Expand All @@ -175,10 +198,52 @@ static bool mpt_model_eval_internal(model_context& lctx, const model_token* toke

// cur = KQV_merged.contiguous().view(n_embd, N)
cur = ne_cpy(ctx0, KQV_merged, ne_new_tensor_2d(ctx0, NE_TYPE_F32, n_embd, N, NE_SIZE_CALC));
} else {
const auto seq_kv = n_past + N;
const auto k_size = kv_cache_info.k_bytes;
const auto v_size = kv_cache_info.v_bytes;
// store key and value to memory
{
const auto k_cache = ne_view_3d(ctx0, kv_self.k, // tensor
head_dim, n_ctx, n_head, // ne
0, 0, // nb (jblas managed)
il * k_size); // offset
Kcur = ne_view_3d(ctx0, Kcur, head_dim, n_head, N, Kcur->nb[0] * head_dim, Kcur->nb[1], 0);
ne_build_forward_expand(&gf, ne_flash_attn_update_k(ctx0, k_cache, Kcur, n_past));
const auto v_cache = ne_view_3d(ctx0, kv_self.v, // tensor
head_dim, n_ctx, n_head, // ne
0, 0, // nb (jblas managed)
il * v_size); // offset
Vcur = ne_view_3d(ctx0, Vcur, head_dim, n_head, N, Vcur->nb[0] * head_dim, Vcur->nb[1], 0);
ne_build_forward_expand(&gf, ne_flash_attn_update_v(ctx0, v_cache, Vcur, n_past));
}

// projection
{ cur = ne_mul_mat(ctx0, model.layers[il].attn[1], cur); }
struct ne_tensor* Q = ne_view_3d(ctx0, Qcur, head_dim, n_head, N, Qcur->nb[0] * head_dim, Qcur->nb[1], 0);
Q = ne_permute(ctx0, Q, 0, 2, 1, 3);
ne_set_name(Q, "Q");
struct ne_tensor* K =
ne_view_3d(ctx0, kv_self.k, // tensor
head_dim, seq_kv, n_head, // ne
kv_cache_info.stride_k_sl, kv_cache_info.stride_k_head_num, // nb (jblas managed)
il * k_size); // offset
*reinterpret_cast<ATTN_FWD_LAYOUT*>(&K->nb[0]) = kv_cache_info.k_layout; // us nb0 for layout
ne_set_name(K, "K");
struct ne_tensor* V =
ne_view_3d(ctx0, kv_self.v, // tensor
seq_kv, head_dim, n_head, // ne
kv_cache_info.stride_v_head_size, kv_cache_info.stride_v_head_num, // nb (jblas managed)
il * v_size); // offset
*reinterpret_cast<ATTN_FWD_LAYOUT*>(&V->nb[0]) = kv_cache_info.v_layout; // us nb0 for layout
ne_set_name(V, "V");

ne_attn_flags_t attn_flags = NE_ATTN_FLAG_IS_ALIBI8; // mpt uses alibi operation
if (n_past == 0) attn_flags |= NE_ATTN_FLAG_IS_CAUSAL; // no causal mask on next-token cases
struct ne_tensor* KQV_Out = ne_flash_attn(ctx0, Q, K, V, attn_scale, attn_flags);
cur = ne_view_2d(ctx0, KQV_Out, n_embd, N, n_embd * ne_element_size(KQV_Out), 0);
}

// projection
{ cur = ne_mul_mat(ctx0, model.layers[il].attn[1], cur); }
inpL = ne_add(ctx0, inpL, cur);

lctx.use_buf(ctx0, 1);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ void model_load_internal(const std::string& fname, model_archs arch, model_conte
ms->init(fname.c_str(), lctx, n_ctx, n_gpu_layers, use_mmap, use_mlock, vocab_only);
ms->load(lctx, progress_callback, progress_callback_user_data);

lctx.support_jblas_kv = true;
lctx.t_load_us = ne_time_us() - lctx.t_start_us;
}

Expand Down

0 comments on commit 7b73b1b

Please sign in to comment.