[LLM Runtime] enable MHA fusion for gptneox&dolly&starcoder&llama2-70b
intellinjun committed Nov 1, 2023
1 parent 0c57d11 commit 81dde20
Showing 5 changed files with 290 additions and 162 deletions.
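
At a high level, each affected model's eval function now chooses between two attention implementations: when the KV cache is allocated with the jblas layout (NE_TYPE_JBLAS), key/value updates go through ne_flash_attn_update_k / ne_flash_attn_update_v and the attention itself is computed by the fused ne_flash_attn kernel; otherwise the original ne_mul_mat-based path is kept. The toy program below only illustrates that dispatch; the enum, function, and printed messages are invented for the sketch and are not code from this commit:

#include <cstdio>

// Toy sketch of the new control flow in *_model_eval_internal: the fused
// (reordered-KV) path is taken only when the KV cache uses the jblas layout.
enum class kv_layout { plain_f16_f32, jblas_reordered };

static void attention_step(kv_layout cache) {
  const bool run_mha_reordered = (cache == kv_layout::jblas_reordered);
  if (!run_mha_reordered) {
    // original path: copy K/V into the cache, then K*Q -> scale ->
    // causal mask -> softmax -> multiply by V (ne_mul_mat ops)
    std::printf("unfused MHA path\n");
  } else {
    // fused path: ne_flash_attn_update_k/v append into the reordered cache,
    // then ne_flash_attn computes the attention in a single kernel
    std::printf("fused MHA path (jblas reordered KV cache)\n");
  }
}

int main() {
  attention_step(kv_layout::plain_f16_f32);
  attention_step(kv_layout::jblas_reordered);
  return 0;
}
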
@@ -113,6 +113,28 @@ static bool gptneox_model_eval_internal(model_context& lctx, const model_token*
ne_cgraph gf = {};
gf.n_threads = N >= 32 && ne_cpu_has_blas() ? 1 : n_threads;

const bool run_mha_reordered = kv_self.k->type == NE_TYPE_JBLAS;
kv_cache_info_t kv_cache_info = {};
if (run_mha_reordered) {
NE_ASSERT(("kv cache should be the same dtype", kv_self.v->type == NE_TYPE_JBLAS));
attn_shape_t attn_shape = {
/* .batch_size = */ 1,
/* .head_num = */ n_head,
/* .heads_kv = */ n_head,
/* .head_size = */ head_dim,
/* .sl_q = */ N, // Note: make sure that jblas reordered attn supports next token inference
/* .sl_kv = */ n_past + N,
};

NE_ASSERT(("jblas managed kv-cache not supported; use `--memory-f16 / --memory-f32` instead",
jblas_reordered_attn_fp32_support(&attn_shape)));
kv_shape_t kv_shape{
/* .heads_kv = */ static_cast<uint32_t>(n_head),
/* .head_size = */ static_cast<uint32_t>(head_dim),
/* .sl_kv_max = */ static_cast<uint32_t>(n_ctx),
};
jblas_reordered_attn_fp32_batch_kv_info(&kv_shape, &kv_cache_info);
}
struct ne_tensor* embd = d_ne_new_tensor_1d(ctx0, NE_TYPE_I32, N * batch_size);
ne_set_name(embd, "embd");
for (int i = 0; i < batch_size; ++i) {
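
The attn_shape set up in the hunk above distinguishes the query length (sl_q = N) from the cached length (sl_kv = n_past + N); that is what lets the reordered kernels serve both the initial prompt pass and later single-token decoding. A self-contained illustration of how the two values evolve (the lengths below are assumed for the example, not taken from the commit):

#include <cstdio>

int main() {
  int n_past = 0;           // tokens already held in the KV cache
  const int prompt_len = 5;

  // prompt pass: N == prompt_len, and the causal flag is set (n_past == 0)
  int N = prompt_len;
  std::printf("prompt: sl_q=%d sl_kv=%d (causal)\n", N, n_past + N);
  n_past += N;

  // decode steps: N == 1, attending over everything cached so far
  for (int step = 0; step < 3; ++step) {
    N = 1;
    std::printf("decode: sl_q=%d sl_kv=%d\n", N, n_past + N);
    n_past += N;
  }
  return 0;
}
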
@@ -151,78 +173,120 @@ static bool gptneox_model_eval_internal(model_context& lctx, const model_token*
// using mode = 2 for GPT-NeoX mode
Qcur = ne_rope_inplace(ctx0, ne_reshape_4d(ctx0, Qcur, head_dim, n_head, N, batch_size), n_past, n_rot, 2, 0);
Kcur = ne_rope_inplace(ctx0, ne_reshape_4d(ctx0, Kcur, head_dim, n_head, N, batch_size), n_past, n_rot, 2, 0);

const float attn_scale = 1.0f / sqrtf(static_cast<float>(head_dim));
// store key and value to memory
{
std::vector<ne_tensor*> Kcur_bs(batch_size);
std::vector<ne_tensor*> Vcur_bs(batch_size);
std::vector<ne_tensor*> k_bs(batch_size);
std::vector<ne_tensor*> v_bs(batch_size);
for (int i = 0; i < batch_size; ++i) {
// batch K
Kcur_bs[i] = ne_permute(ctx0,
ne_view_4d(ctx0, Kcur, head_dim, n_head, N, 1, ne_element_size(Kcur) * head_dim,
ne_element_size(Kcur) * n_embd, ne_element_size(Kcur) * n_embd * N,
i * ne_element_size(Kcur) * n_embd * N),
0, 2, 1, 3);
k_bs[i] = ne_view_4d(
ctx0, kv_self.k, head_dim, N, n_head, 1, ne_element_size(kv_self.k) * head_dim,
ne_element_size(kv_self.k) * head_dim * n_ctx, ne_element_size(kv_self.k) * n_embd * n_ctx,
((il * n_ctx) * ne_element_size(kv_self.k) * n_embd * kv_n_ctx_block +
i * n_ctx * n_embd * ne_element_size(kv_self.k) + head_dim * n_past * ne_element_size(kv_self.k)));

// batch V
Vcur_bs[i] = ne_permute(ctx0,
ne_reshape_4d(ctx0,
ne_view_2d(ctx0, Vcur, n_embd, N, ne_element_size(Vcur) * n_embd,
i * ne_element_size(Vcur) * n_embd * N),
head_dim, n_head, N, 1),
1, 2, 0, 3);
v_bs[i] =
ne_view_4d(ctx0, kv_self.v, N, head_dim, n_head, 1, n_ctx * ne_element_size(kv_self.v),
n_ctx * ne_element_size(kv_self.v) * head_dim, n_ctx * ne_element_size(kv_self.v) * n_embd,
((il * n_ctx) * ne_element_size(kv_self.v) * n_embd * kv_n_ctx_block +
i * n_ctx * n_embd * ne_element_size(kv_self.v) + n_past * ne_element_size(kv_self.v)));
ne_build_forward_expand(&gf, ne_cpy(ctx0, Kcur_bs[i], k_bs[i]));
ne_build_forward_expand(&gf, ne_cpy(ctx0, Vcur_bs[i], v_bs[i]));
if (!run_mha_reordered) {
{
std::vector<ne_tensor*> Kcur_bs(batch_size);
std::vector<ne_tensor*> Vcur_bs(batch_size);
std::vector<ne_tensor*> k_bs(batch_size);
std::vector<ne_tensor*> v_bs(batch_size);
for (int i = 0; i < batch_size; ++i) {
// batch K
Kcur_bs[i] = ne_permute(ctx0,
ne_view_4d(ctx0, Kcur, head_dim, n_head, N, 1, ne_element_size(Kcur) * head_dim,
ne_element_size(Kcur) * n_embd, ne_element_size(Kcur) * n_embd * N,
i * ne_element_size(Kcur) * n_embd * N),
0, 2, 1, 3);
k_bs[i] = ne_view_4d(
ctx0, kv_self.k, head_dim, N, n_head, 1, ne_element_size(kv_self.k) * head_dim,
ne_element_size(kv_self.k) * head_dim * n_ctx, ne_element_size(kv_self.k) * n_embd * n_ctx,
((il * n_ctx) * ne_element_size(kv_self.k) * n_embd * kv_n_ctx_block +
i * n_ctx * n_embd * ne_element_size(kv_self.k) + head_dim * n_past * ne_element_size(kv_self.k)));

// batch V
Vcur_bs[i] = ne_permute(ctx0,
ne_reshape_4d(ctx0,
ne_view_2d(ctx0, Vcur, n_embd, N, ne_element_size(Vcur) * n_embd,
i * ne_element_size(Vcur) * n_embd * N),
head_dim, n_head, N, 1),
1, 2, 0, 3);
v_bs[i] =
ne_view_4d(ctx0, kv_self.v, N, head_dim, n_head, 1, n_ctx * ne_element_size(kv_self.v),
n_ctx * ne_element_size(kv_self.v) * head_dim, n_ctx * ne_element_size(kv_self.v) * n_embd,
((il * n_ctx) * ne_element_size(kv_self.v) * n_embd * kv_n_ctx_block +
i * n_ctx * n_embd * ne_element_size(kv_self.v) + n_past * ne_element_size(kv_self.v)));
ne_build_forward_expand(&gf, ne_cpy(ctx0, Kcur_bs[i], k_bs[i]));
ne_build_forward_expand(&gf, ne_cpy(ctx0, Vcur_bs[i], v_bs[i]));
}
}
// Q = Qcur.contiguous().view(n_embd/n_head, n_head, N).permute(0, 2, 1, 3)
struct ne_tensor* Q = ne_permute(ctx0, ne_reshape_4d(ctx0, Qcur, head_dim, n_head, N, batch_size), 0, 2, 1, 3);

// K = Kmem.view(n_embd/n_head, n_head, n_past + N).permute(0, 2, 1, 3)
struct ne_tensor* K =
ne_view_4d(ctx0, kv_self.k, head_dim, n_past + N, n_head, batch_size, ne_element_size(kv_self.k) * head_dim,
ne_element_size(kv_self.k) * head_dim * n_ctx, ne_element_size(kv_self.k) * n_embd * n_ctx,
il * n_ctx * ne_element_size(kv_self.k) * n_embd * kv_n_ctx_block);

// K * Q
struct ne_tensor* KQ = ne_mul_mat(ctx0, K, Q);

// KQ_scaled = KQ / sqrt(n_embd/n_head)
struct ne_tensor* KQ_scaled = ne_scale_inplace(ctx0, KQ, ne_new_f32(ctx0, 1.0f / sqrt(float(n_embd) / n_head)));

// KQ_masked = mask_past(KQ_scaled)
struct ne_tensor* KQ_masked = ne_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past);

// KQ = soft_max(KQ_masked)
struct ne_tensor* KQ_soft_max = ne_soft_max_inplace(ctx0, KQ_masked);

// V_trans = Vmem.view(n_embd/n_head, n_head, n_past + N).permute(1, 2, 0, 3).contiguous()
struct ne_tensor* V =
ne_view_4d(ctx0, kv_self.v, n_past + N, head_dim, n_head, batch_size, n_ctx * ne_element_size(kv_self.v),
n_ctx * ne_element_size(kv_self.v) * head_dim, n_ctx * ne_element_size(kv_self.v) * n_embd,
il * n_ctx * ne_element_size(kv_self.v) * n_embd * kv_n_ctx_block);

// KQV = transpose(V) * KQ_soft_max
struct ne_tensor* KQV = ne_mul_mat(ctx0, V, KQ_soft_max);

// KQV_merged = KQV.permute(0, 2, 1, 3)
struct ne_tensor* KQV_merged = ne_permute(ctx0, KQV, 0, 2, 1, 3);

// cur = KQV_merged.contiguous().view(n_embd, N)
cur = ne_cpy(ctx0, KQV_merged, ne_new_tensor_2d(ctx0, NE_TYPE_F32, n_embd, N * batch_size, NE_SIZE_CALC));
} else {
const auto seq_kv = n_past + N;
const auto k_size = kv_cache_info.k_bytes;
const auto v_size = kv_cache_info.v_bytes;

// store key and value to memory
{
const auto k_cache = ne_view_3d(ctx0, kv_self.k, // tensor
head_dim, n_ctx, n_head, // ne
0, 0, // nb (jblas managed)
il * k_size); // offset
ne_build_forward_expand(&gf, ne_flash_attn_update_k(ctx0, k_cache, Kcur, n_past));
const auto v_cache = ne_view_3d(ctx0, kv_self.v, // tensor
head_dim, n_ctx, n_head, // ne
0, 0, // nb (jblas managed)
il * v_size); // offset
ne_build_forward_expand(&gf, ne_flash_attn_update_v(ctx0, v_cache, Vcur, n_past));
}
}
// Q = Qcur.contiguous().view(n_embd/n_head, n_head, N).permute(0, 2, 1, 3)
struct ne_tensor* Q = ne_permute(ctx0, ne_reshape_4d(ctx0, Qcur, head_dim, n_head, N, batch_size), 0, 2, 1, 3);

// K = Kmem.view(n_embd/n_head, n_head, n_past + N).permute(0, 2, 1, 3)
struct ne_tensor* K =
ne_view_4d(ctx0, kv_self.k, head_dim, n_past + N, n_head, batch_size, ne_element_size(kv_self.k) * head_dim,
ne_element_size(kv_self.k) * head_dim * n_ctx, ne_element_size(kv_self.k) * n_embd * n_ctx,
il * n_ctx * ne_element_size(kv_self.k) * n_embd * kv_n_ctx_block);

// K * Q
struct ne_tensor* KQ = ne_mul_mat(ctx0, K, Q);

// KQ_scaled = KQ / sqrt(n_embd/n_head)
struct ne_tensor* KQ_scaled = ne_scale_inplace(ctx0, KQ, ne_new_f32(ctx0, 1.0f / sqrt(float(n_embd) / n_head)));

// KQ_masked = mask_past(KQ_scaled)
struct ne_tensor* KQ_masked = ne_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past);

// KQ = soft_max(KQ_masked)
struct ne_tensor* KQ_soft_max = ne_soft_max_inplace(ctx0, KQ_masked);

// V_trans = Vmem.view(n_embd/n_head, n_head, n_past + N).permute(1, 2, 0, 3).contiguous()
struct ne_tensor* V =
ne_view_4d(ctx0, kv_self.v, n_past + N, head_dim, n_head, batch_size, n_ctx * ne_element_size(kv_self.v),
n_ctx * ne_element_size(kv_self.v) * head_dim, n_ctx * ne_element_size(kv_self.v) * n_embd,
il * n_ctx * ne_element_size(kv_self.v) * n_embd * kv_n_ctx_block);

// KQV = transpose(V) * KQ_soft_max
struct ne_tensor* KQV = ne_mul_mat(ctx0, V, KQ_soft_max);

// KQV_merged = KQV.permute(0, 2, 1, 3)
struct ne_tensor* KQV_merged = ne_permute(ctx0, KQV, 0, 2, 1, 3);

// cur = KQV_merged.contiguous().view(n_embd, N)
cur = ne_cpy(ctx0, KQV_merged, ne_new_tensor_2d(ctx0, NE_TYPE_F32, n_embd, N * batch_size, NE_SIZE_CALC));

struct ne_tensor* Q = ne_permute(ctx0, Qcur, 0, 2, 1, 3);
ne_set_name(Q, "Q");

struct ne_tensor* K =
ne_view_3d(ctx0, kv_self.k, // tensor
head_dim, seq_kv, n_head, // ne
kv_cache_info.stride_k_sl, kv_cache_info.stride_k_head_num, // nb (jblas managed)
il * k_size); // offset
*reinterpret_cast<ATTN_FWD_LAYOUT*>(&K->nb[0]) = kv_cache_info.k_layout; // use nb0 for layout
ne_set_name(K, "K");
struct ne_tensor* V =
ne_view_3d(ctx0, kv_self.v, // tensor
seq_kv, head_dim, n_head, // ne
kv_cache_info.stride_v_head_size, kv_cache_info.stride_v_head_num, // nb (jblas managed)
il * v_size); // offset
*reinterpret_cast<ATTN_FWD_LAYOUT*>(&V->nb[0]) = kv_cache_info.v_layout; // use nb0 for layout
ne_set_name(V, "V");

ne_attn_flags_t attn_flags = 0;
if (n_past == 0) attn_flags |= NE_ATTN_FLAG_IS_CAUSAL; // no causal mask on next-token cases
struct ne_tensor* KQV_Out = ne_flash_attn(ctx0, Q, K, V, attn_scale, attn_flags);
cur = ne_view_2d(ctx0, KQV_Out, n_embd, N, n_embd * ne_element_size(KQV_Out), 0);
}
// projection
{
cur = ne_mul_mat(ctx0, model.layers[il].attn[2], cur);
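
Both branches of the attention hunk above compute the same scaled dot-product attention: the unfused branch spells it out operation by operation (K*Q, scale by 1/sqrt(n_embd/n_head), causal mask, softmax, multiply by V), while the fused branch passes the same attn_scale and a causal flag to ne_flash_attn. In the notation of the inline comments:

\[
\mathrm{Attn}(Q,K,V) = \operatorname{softmax}\!\left(\frac{Q K^{\top}}{\sqrt{d_{\mathrm{head}}}} + M\right) V,
\qquad d_{\mathrm{head}} = \frac{n_{\mathrm{embd}}}{n_{\mathrm{head}}},
\qquad \mathtt{attn\_scale} = \frac{1}{\sqrt{d_{\mathrm{head}}}},
\]

where \(M\) is the causal mask, applied only on the prompt pass (n_past == 0, NE_ATTN_FLAG_IS_CAUSAL); a single-token decode step may attend to every cached position, so no mask is needed.
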
@@ -285,6 +285,7 @@ void model_load_internal(const std::string& fname, model_archs arch, model_conte
std::unique_ptr<GPTNEOX> ms(new GPTNEOX());
ms->init(fname.c_str(), lctx, n_gpu_layers, use_mmap, use_mlock, vocab_only);
ms->load(lctx, progress_callback, progress_callback_user_data);
lctx.support_jblas_kv = true;
if (lctx.beam_search) {
lctx.bs_kv_reorder = std::make_shared<gptneox_beam_search_kv_cache_reorder>(&lctx);
#ifdef NE_BEAM_SEARCH_VERBOSE_ON
@@ -264,17 +264,18 @@ static bool llama_model_eval_internal(model_context& lctx, const model_token* to
const auto v_size = kv_cache_info.v_bytes;
// store key and value to memory
{
const auto k_cache = ne_view_3d(ctx0, kv_self.k, // tensor
head_size, n_ctx, n_head, // ne
0, 0, // nb (jblas managed)
il * k_size); // offset
const auto k_cache = ne_view_3d(ctx0, kv_self.k, // tensor
head_size, n_ctx, n_head_kv, // ne
0, 0, // nb (jblas managed)
il * k_size); // offset
ne_build_forward_expand(&gf, ne_flash_attn_update_k(ctx0, k_cache, Kcur, n_past));
const auto v_cache = ne_view_3d(ctx0, kv_self.v, // tensor
head_size, n_ctx, n_head, // ne
0, 0, // nb (jblas managed)
il * v_size); // offset
const auto v_cache = ne_view_3d(ctx0, kv_self.v, // tensor
head_size, n_ctx, n_head_kv, // ne
0, 0, // nb (jblas managed)
il * v_size); // offset
// jblas always views V as (D, n_head, seq)
const auto Vcur_plain = ne_reshape_3d(ctx0, ne_view_1d(ctx0, Vcur, n_embd * N, 0), n_embd / n_head, n_head, N);
const auto Vcur_plain =
ne_reshape_3d(ctx0, ne_view_1d(ctx0, Vcur, n_embd_gqa * N, 0), n_embd_gqa / n_head_kv, n_head_kv, N);
ne_build_forward_expand(&gf, ne_flash_attn_update_v(ctx0, v_cache, Vcur_plain, n_past));
}

@@ -283,14 +284,14 @@ static bool llama_model_eval_internal(model_context& lctx, const model_token* to

struct ne_tensor* K =
ne_view_3d(ctx0, kv_self.k, // tensor
head_size, n_cached, n_head, // ne
head_size, n_cached, n_head_kv, // ne
kv_cache_info.stride_k_sl, kv_cache_info.stride_k_head_num, // nb (jblas managed)
il * k_size); // offset
*reinterpret_cast<ATTN_FWD_LAYOUT*>(&K->nb[0]) = kv_cache_info.k_layout; // use nb0 for layout
ne_set_name(K, "K");
struct ne_tensor* V =
ne_view_3d(ctx0, kv_self.v, // tensor
n_cached, head_size, n_head, // ne
n_cached, head_size, n_head_kv, // ne
kv_cache_info.stride_v_head_size, kv_cache_info.stride_v_head_num, // nb (jblas managed)
il * v_size); // offset
*reinterpret_cast<ATTN_FWD_LAYOUT*>(&V->nb[0]) = kv_cache_info.v_layout; // use nb0 for layout
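
The llama hunks above switch the K/V cache views from n_head to n_head_kv (and from n_embd to n_embd_gqa) because llama2-70b uses grouped-query attention: only n_head_kv key/value heads are stored and shared across groups of query heads. A small, self-contained illustration using Llama-2-70B's published head configuration (the concrete numbers are an assumption of this note, not part of the diff):

#include <cstdio>

int main() {
  const int n_head = 64;      // query heads
  const int n_head_kv = 8;    // key/value heads (grouped-query attention)
  const int head_size = 128;  // per-head dimension

  const int n_embd = n_head * head_size;         // 8192: width of Q
  const int n_embd_gqa = n_head_kv * head_size;  // 1024: width of the cached K/V

  std::printf("n_embd=%d n_embd_gqa=%d -> KV cache is %dx smaller than with full MHA\n",
              n_embd, n_embd_gqa, n_head / n_head_kv);
  return 0;
}
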
