[CPP Graph] ChatGLM2 MHA support (#435)
DDEle committed Oct 11, 2023
1 parent 69dc64b commit 692fde3
Showing 3 changed files with 203 additions and 95 deletions.
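The change gives the ChatGLM2 eval graph a second attention path: when a layer's kv-cache tensors are typed NE_TYPE_JBLAS, the graph updates the jblas-managed cache with ne_flash_attn_update_k / ne_flash_attn_update_v and computes attention with ne_flash_attn; otherwise it keeps the original ne_mul_mat + soft-max pipeline. Both paths share the grouped-query bookkeeping sketched below; the snippet is a standalone illustration only, with N and n_past as made-up example values (the 4096 / 32 head size and 2 kv-heads mirror constants visible in the old loader code further down).

// Standalone sketch of the grouped-query shape arithmetic both attention paths rely on.
// Head counts mirror the hard-coded 4096 / 32 head size and 2 kv-heads in the old loader
// code below; N and n_past are made-up example values.
#include <cstdio>

int main() {
  const int n_head = 32;            // query heads
  const int num_kv_heads = 2;       // multi_query_group_num
  const int head_size = 4096 / 32;  // 128
  const int N = 8;                  // tokens evaluated in this call (sl_q)
  const int n_past = 24;            // tokens already in the kv-cache

  // Each kv head serves n_head / num_kv_heads query heads, which is why the naive path
  // folds the query tensor to [kv_heads, mqa_scale * N, head_size] before its mat-muls.
  const int mqa_scale = n_head / num_kv_heads;

  // The jblas path instead describes the same problem to the kernel through attn_shape_t:
  // {batch_size = 1, head_num, heads_kv, head_size, sl_q = N, sl_kv = n_past + N}.
  std::printf("mqa_scale = %d, sl_q = %d, sl_kv = %d\n", mqa_scale, N, n_past + N);
  return 0;
}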
@@ -11,6 +11,8 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "chatglm2.h"

#include <math.h>
#include <stdio.h>
#include <stdlib.h>
@@ -31,6 +33,7 @@
#include "core/data_types.h"
#include "core/ne.h"
#include "core/ne_layers.h"
#include "core/layers/mha_dense.h"
#include "models/model_utils/model_config.h"
#include "models/model_utils/model_utils.h"
#include "models/model_utils/util.h"
@@ -86,13 +89,35 @@ static bool chatglm_model_eval_internal(model_context& lctx, const model_token*
ne_cgraph gf = {};
gf.n_threads = N >= 32 && ne_cpu_has_blas() ? 1 : n_threads;

+ const bool run_mha_reordered = model.layers[0].k_cache->type == NE_TYPE_JBLAS;
+ kv_cache_info_t kv_cache_info = {};
+ if (run_mha_reordered) {
+ NE_ASSERT(("kv cache should be the same dtype", model.layers[0].v_cache->type == NE_TYPE_JBLAS));
+ attn_shape_t attn_shape = {
+ /* .batch_size = */ 1,
+ /* .head_num = */ n_head,
+ /* .heads_kv = */ num_kv_heads,
+ /* .head_size = */ head_size,
+ /* .sl_q = */ N, // Note: make sure that jblas reordered attn supports next token inference
+ /* .sl_kv = */ n_past + N,
+ };
+
+ NE_ASSERT(("jblas managed kv-cache not supported; use `--memory-f16 / --memory-f32` instead",
+ jblas_reordered_attn_fp32_support(&attn_shape)));
+ kv_shape_t kv_shape{
+ /* .heads_kv = */ static_cast<uint32_t>(num_kv_heads),
+ /* .head_size = */ static_cast<uint32_t>(head_size),
+ /* .sl_kv_max = */ static_cast<uint32_t>(n_ctx),
+ };
+ jblas_reordered_attn_fp32_batch_kv_info(&kv_shape, &kv_cache_info);
+ }

struct ne_tensor* embd = d_ne_new_tensor_1d(ctx0, NE_TYPE_I32, N);
ne_set_name(embd, "embd");
memcpy(embd->data, tokens, N * ne_element_size(embd));
// int qlen = embd->ne[1];
struct ne_tensor* inpL = ne_get_rows(ctx0, model.others[0], embd);

- int qlen = inpL->ne[1];
+ NE_ASSERT(N == inpL->ne[1]);
for (int il = 0; il < n_layer; ++il) {
struct ne_tensor* cur;

@@ -108,75 +133,113 @@ static bool chatglm_model_eval_internal(model_context& lctx, const model_token*

struct ne_tensor* query_layer =
ne_view_3d(ctx0, cur, head_size, n_head, N, head_size * ne_element_size(cur), cur->nb[1],
- 0); // [qlen, heads, head_size]
+ 0); // [N, heads, head_size]
ne_set_name(query_layer, "query_layer");
query_layer = ne_rope_inplace(ctx0, query_layer, n_past, rope_dim, 0, 0);
- query_layer = ne_cont(ctx0, ne_permute(ctx0, query_layer, 0, 2, 1, 3)); // [heads, qlen, head_size]
- query_layer = ne_reshape_3d(ctx0, query_layer, head_size, mqa_scale * qlen,
- num_kv_heads); // [kv_heads, mqa_scale * qlen, head_size]

struct ne_tensor* key_layer =
- ne_view_3d(ctx0, cur, head_size, num_kv_heads, qlen, head_size * ne_element_size(cur), cur->nb[1],
- hidden_size * ne_element_size(cur)); // [qlen, kv_heads, head_size]
+ ne_view_3d(ctx0, cur, head_size, num_kv_heads, N, head_size * ne_element_size(cur), cur->nb[1],
+ hidden_size * ne_element_size(cur)); // [N, kv_heads, head_size]
ne_set_name(key_layer, "key_layer");
key_layer = ne_rope_inplace(ctx0, key_layer, n_past, rope_dim, 0, 0);
- key_layer = ne_permute(ctx0, key_layer, 0, 2, 1, 3); // [kv_heads, qlen, head_size]

struct ne_tensor* value_layer =
- ne_view_3d(ctx0, cur, head_size, num_kv_heads, qlen, head_size * ne_element_size(cur), cur->nb[1],
- (hidden_size + head_size * num_kv_heads) * ne_element_size(cur)); // [qlen, kv_heads, head_size]
+ ne_view_3d(ctx0, cur, head_size, num_kv_heads, N, head_size * ne_element_size(cur), cur->nb[1],
+ (hidden_size + head_size * num_kv_heads) * ne_element_size(cur)); // [N, kv_heads, head_size]
ne_set_name(value_layer, "value_layer");
- value_layer = ne_permute(ctx0, value_layer, 1, 2, 0, 3); // [kv_heads, head_size, qlen]

- // store key and value to memory
- {
- struct ne_tensor* k_cache_view =
- ne_view_3d(ctx0, model.layers[il].k_cache, head_size, qlen, num_kv_heads, model.layers[il].k_cache->nb[1],
- model.layers[il].k_cache->nb[2],
- n_past * head_size * ne_element_size(model.layers[il].k_cache)); // [kv_heads, qlen, head_size]
- ne_set_name(k_cache_view, "k_cache_view");
- struct ne_tensor* v_cache_view =
- ne_view_3d(ctx0, model.layers[il].v_cache, qlen, head_size, num_kv_heads, model.layers[il].v_cache->nb[1],
- model.layers[il].v_cache->nb[2],
- n_past * ne_element_size(model.layers[il].v_cache)); // [kv_heads, head_size, qlen]
- ne_set_name(v_cache_view, "v_cache_view");
-
- ne_build_forward_expand(&gf, ne_cpy(ctx0, key_layer, k_cache_view));
- ne_build_forward_expand(&gf, ne_cpy(ctx0, value_layer, v_cache_view));
- }
-
- // concat key & value with past kv
- key_layer = ne_view_3d(ctx0, model.layers[il].k_cache, head_size, n_past + qlen, num_kv_heads,
- model.layers[il].k_cache->nb[1], model.layers[il].k_cache->nb[2],
- 0); // [kv_heads, klen, head_size]
- value_layer = ne_view_3d(ctx0, model.layers[il].v_cache, n_past + qlen, head_size, num_kv_heads,
- model.layers[il].v_cache->nb[1], model.layers[il].v_cache->nb[2],
- 0); // [kv_heads, head_size, klen]
-
- // attention
- struct ne_tensor* attn_scores = ne_mul_mat(ctx0, key_layer, query_layer); // [kv_heads, mqa_scale * qlen, klen]
- ne_set_name(attn_scores, "attn_scores");
- attn_scores = ne_scale_inplace(ctx0, attn_scores, ne_new_f32(ctx0, 1.f / std::sqrt(head_size)));
-
- if (n_past == 0) {
- // build attention mask for context input
- attn_scores = ne_reshape_3d(ctx0, attn_scores, n_past + qlen, qlen,
- num_attention_heads); // [heads, qlen, klen]
- attn_scores = ne_diag_mask_inf_inplace(ctx0, attn_scores, n_past);
- attn_scores = ne_reshape_3d(ctx0, attn_scores, n_past + qlen, mqa_scale * qlen,
- num_kv_heads); // [kv_heads, mqa_scale * qlen, klen]
+ const float attn_scale = 1.f / std::sqrt(head_size);
+ if (!run_mha_reordered) {
+ query_layer = ne_cont(ctx0, ne_permute(ctx0, query_layer, 0, 2, 1, 3)); // [heads, N, head_size]
+ query_layer = ne_reshape_3d(ctx0, query_layer, head_size, mqa_scale * N,
+ num_kv_heads); // [kv_heads, mqa_scale * N, head_size]
+ key_layer = ne_permute(ctx0, key_layer, 0, 2, 1, 3); // [kv_heads, N, head_size]
+ value_layer = ne_permute(ctx0, value_layer, 1, 2, 0, 3); // [kv_heads, head_size, N]
+ // store key and value to memory
+ {
+ struct ne_tensor* k_cache_view =
+ ne_view_3d(ctx0, model.layers[il].k_cache, head_size, N, num_kv_heads, model.layers[il].k_cache->nb[1],
+ model.layers[il].k_cache->nb[2],
+ n_past * head_size * ne_element_size(model.layers[il].k_cache)); // [kv_heads, N, head_size]
+ ne_set_name(k_cache_view, "k_cache_view");
+ struct ne_tensor* v_cache_view =
+ ne_view_3d(ctx0, model.layers[il].v_cache, N, head_size, num_kv_heads, model.layers[il].v_cache->nb[1],
+ model.layers[il].v_cache->nb[2],
+ n_past * ne_element_size(model.layers[il].v_cache)); // [kv_heads, head_size, N]
+ ne_set_name(v_cache_view, "v_cache_view");
+
+ ne_build_forward_expand(&gf, ne_cpy(ctx0, key_layer, k_cache_view));
+ ne_build_forward_expand(&gf, ne_cpy(ctx0, value_layer, v_cache_view));
+ }
+
+ // concat key & value with past kv
+ key_layer = ne_view_3d(ctx0, model.layers[il].k_cache, head_size, n_past + N, num_kv_heads,
+ model.layers[il].k_cache->nb[1], model.layers[il].k_cache->nb[2],
+ 0); // [kv_heads, klen, head_size]
+ value_layer = ne_view_3d(ctx0, model.layers[il].v_cache, n_past + N, head_size, num_kv_heads,
+ model.layers[il].v_cache->nb[1], model.layers[il].v_cache->nb[2],
+ 0); // [kv_heads, head_size, klen]
+
+ // attention
+ struct ne_tensor* attn_scores = ne_mul_mat(ctx0, key_layer, query_layer); // [kv_heads, mqa_scale * N, klen]
+ ne_set_name(attn_scores, "attn_scores");
+ attn_scores = ne_scale_inplace(ctx0, attn_scores, ne_new_f32(ctx0, attn_scale));
+
+ if (n_past == 0) {
+ // build attention mask for context input
+ attn_scores = ne_reshape_3d(ctx0, attn_scores, n_past + N, N,
+ num_attention_heads); // [heads, N, klen]
+ attn_scores = ne_diag_mask_inf_inplace(ctx0, attn_scores, n_past);
+ attn_scores = ne_reshape_3d(ctx0, attn_scores, n_past + N, mqa_scale * N,
+ num_kv_heads); // [kv_heads, mqa_scale * N, klen]
+ }
+
+ struct ne_tensor* attn_probs = ne_soft_max_inplace(ctx0, attn_scores); // [kv_heads, mqa_scale * N, klen]
+
+ cur = ne_mul_mat(ctx0, value_layer, attn_probs); // [kv_heads, mqa_scale * N, head_size]
+ cur = ne_reshape_3d(ctx0, cur, head_size, N,
+ num_attention_heads); // [heads, N, head_size]
+ cur = ne_cont(ctx0, ne_permute(ctx0, cur, 0, 2, 1, 3)); // [N, heads, head_size]
+ cur = ne_reshape_2d(ctx0, cur, hidden_size, N); // [N, hidden]
+ } else { // Using MHA (GQA/MQA) managed kv-cache
+ const auto seq_kv = n_past + N;
+ const auto k_size = kv_cache_info.k_bytes;
+ const auto v_size = kv_cache_info.v_bytes;
+
+ // store key and value to memory
+ {
+ const auto k_cache = ne_view_3d(ctx0, model.layers[il].k_cache, // tensor
+ head_size, n_ctx, num_kv_heads, // ne
+ 0, 0, // nb (jblas managed)
+ 0); // offset
+ ne_build_forward_expand(&gf, ne_flash_attn_update_k(ctx0, k_cache, key_layer, n_past));
+ const auto v_cache = ne_view_3d(ctx0, model.layers[il].v_cache, // tensor
+ head_size, n_ctx, num_kv_heads, // ne
+ 0, 0, // nb (jblas managed)
+ 0); // offset
+ ne_build_forward_expand(&gf, ne_flash_attn_update_v(ctx0, v_cache, value_layer, n_past));
+ }
+
+ query_layer = ne_permute(ctx0, query_layer, 0, 2, 1, 3); // [heads, N, head_size]
+ key_layer = //
+ ne_view_3d(ctx0, model.layers[il].k_cache, // tensor
+ head_size, seq_kv, num_kv_heads, // ne
+ kv_cache_info.stride_k_sl, kv_cache_info.stride_k_head_num, // nb (jblas managed)
+ 0); // offset
+ *reinterpret_cast<ATTN_FWD_LAYOUT*>(&key_layer->nb[0]) = kv_cache_info.k_layout; // use nb0 for layout
+ value_layer = //
+ ne_view_3d(ctx0, model.layers[il].v_cache, // tensor
+ seq_kv, head_size, num_kv_heads, // ne
+ kv_cache_info.stride_v_head_size, kv_cache_info.stride_v_head_num, // nb (jblas managed)
+ 0); // offset
+ *reinterpret_cast<ATTN_FWD_LAYOUT*>(&value_layer->nb[0]) = kv_cache_info.v_layout; // use nb0 for layout
+
+ ne_attn_flags_t attn_flags = 0;
+ if (n_past == 0) attn_flags |= NE_ATTN_FLAG_IS_CAUSAL; // no causal mask on next-token cases
+ struct ne_tensor* KQV_Out = ne_flash_attn(ctx0, query_layer, key_layer, value_layer, attn_scale, attn_flags);
+ cur = ne_view_2d(ctx0, KQV_Out, n_embd, N, n_embd * ne_element_size(KQV_Out), 0);
+ }

- struct ne_tensor* attn_probs = ne_soft_max_inplace(ctx0, attn_scores); // [kv_heads, mqa_scale * qlen, klen]
-
- struct ne_tensor* context_layer =
- ne_mul_mat(ctx0, value_layer, attn_probs); // [kv_heads, mqa_scale * qlen, head_size]
- context_layer = ne_reshape_3d(ctx0, context_layer, head_size, qlen,
- num_attention_heads); // [heads, qlen, head_size]
- context_layer = ne_cont(ctx0, ne_permute(ctx0, context_layer, 0, 2, 1, 3)); // [qlen, heads, head_size]
- context_layer = ne_reshape_2d(ctx0, context_layer, hidden_size, qlen); // [qlen, hidden]

- cur = ne_mul_mat(ctx0, model.layers[il].attn[2], context_layer);
+ cur = ne_mul_mat(ctx0, model.layers[il].attn[2], cur);
}

lctx.use_buf(ctx0, 1);
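
Note that the dispatch above hinges entirely on the dtype the loader picks for the kv-cache (see the second file below): NE_TYPE_JBLAS selects the reordered flash-attention path, anything else keeps the naive pipeline. A toy standalone version of that gate, with stand-in types rather than the real core/ne.h definitions, could look like this:

// Toy illustration of the path selection: the real code checks
// model.layers[0].k_cache->type == NE_TYPE_JBLAS. Types here are stand-ins.
#include <cstdio>

enum ToyType { TOY_F16, TOY_F32, TOY_JBLAS };
struct ToyTensor { ToyType type; };

static const char* pick_attention_path(const ToyTensor& k_cache) {
  const bool run_mha_reordered = (k_cache.type == TOY_JBLAS);
  return run_mha_reordered ? "jblas reordered flash-attention" : "ne_mul_mat + soft-max";
}

int main() {
  ToyTensor f16_cache{TOY_F16}, jblas_cache{TOY_JBLAS};
  std::printf("%s\n", pick_attention_path(f16_cache));    // naive path
  std::printf("%s\n", pick_attention_path(jblas_cache));  // reordered path
  return 0;
}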
@@ -29,6 +29,7 @@
#include <vector>

#include "core/data_types.h"
#include "core/layers/mha_dense.h"
#include "core/ne.h"
#include "core/ne_layers.h"
#include "models/chatglm/chatglm2.h"
@@ -48,6 +49,7 @@ void model_load_internal(const std::string& fname, model_archs arch, model_conte
ms->init(fname.c_str(), lctx, n_ctx, n_gpu_layers, use_mmap, use_mlock, vocab_only);
ms->load(lctx, progress_callback, progress_callback_user_data);

+ lctx.support_jblas_kv = true;
lctx.t_load_us = ne_time_us() - lctx.t_start_us;
}

@@ -94,9 +96,8 @@ void CHATGLM2::load(model_context& lctx, model_progress_callback progress_callba
fprintf(stderr, "%s: ne ctx size = %7.2f MB\n", __func__, ctx_size / 1024.0 / 1024.0);

const auto& hparams = model.hparams;
- const int head_dim = n_embd / hparams.n_head;
- const int kv_heads = hparams.n_head; // 1 if MQA else hparams.n_head
- const int kv_dim = kv_heads * head_dim;
+ MODEL_ASSERT(("chatglm uses multi_query_group_num rather than n_head_kv",
+ hparams.n_head_kv == 0 && hparams.multi_query_group_num != 0));

// create the ne context
lctx.model.buf.resize(ctx_size);
@@ -135,28 +136,28 @@ void CHATGLM2::load(model_context& lctx, model_progress_callba
layer.norm[1] = ml->get_tensor(layers_i + ".post_attention_layernorm.weight", {n_embd}, backend);

// qkv GEMM
- layer.attn[0] = ml->get_tensor(
- layers_i + ".self_attention.query_key_value.weight",
- {n_embd, n_embd + 2 * (n_embd / model.hparams.n_head) * model.hparams.multi_query_group_num}, backend);
- layer.attn[1] =
- ml->get_tensor(layers_i + ".self_attention.query_key_value.bias",
- {n_embd + 2 * (n_embd / model.hparams.n_head) * model.hparams.multi_query_group_num}, backend);
+ layer.attn[0] =
+ ml->get_tensor(layers_i + ".self_attention.query_key_value.weight",
+ {n_embd, n_embd + 2 * (n_embd / hparams.n_head) * hparams.multi_query_group_num}, backend);
+ layer.attn[1] = ml->get_tensor(layers_i + ".self_attention.query_key_value.bias",
+ {n_embd + 2 * (n_embd / hparams.n_head) * hparams.multi_query_group_num}, backend);
layer.attn[2] = ml->get_tensor(layers_i + ".self_attention.dense.weight", {n_embd, n_embd}, backend);

// ffn GEMM
layer.ffn[0] = ml->get_tensor(layers_i + ".mlp.dense_h_to_4h.weight",
- {n_embd, uint32_t(model.hparams.ffn_hidden_size * 2)}, backend);
- layer.ffn[1] = ml->get_tensor(layers_i + ".mlp.dense_4h_to_h.weight",
- {uint32_t(model.hparams.ffn_hidden_size), n_embd}, backend);
+ {n_embd, uint32_t(hparams.ffn_hidden_size * 2)}, backend);
+ layer.ffn[1] =
+ ml->get_tensor(layers_i + ".mlp.dense_4h_to_h.weight", {uint32_t(hparams.ffn_hidden_size), n_embd}, backend);

// kv-cache
+ layer.k_cache = nullptr; // kv-cache will be init later in model_utils
+ layer.v_cache = nullptr; // kv-cache will be init later in model_utils

- layer.k_cache = d_ne_new_tensor_3d(model.ctx, NE_TYPE_F16, 4096 / 32, 32768, 2);
- layer.v_cache = d_ne_new_tensor_3d(model.ctx, NE_TYPE_F16, 32768, 4096 / 32, 2);
if (backend != NE_BACKEND_CPU) {
vram_total += ne_nbytes(layer.norm[0]) + ne_nbytes(layer.norm[1]) + ne_nbytes(layer.norm[2]) +
ne_nbytes(layer.norm[3]) + ne_nbytes(layer.attn[0]) + ne_nbytes(layer.attn[1]) +
ne_nbytes(layer.attn[2]) + ne_nbytes(layer.attn[3]) + ne_nbytes(layer.ffn[0]) +
- ne_nbytes(layer.ffn[1]) + ne_nbytes(layer.k_cache) + ne_nbytes(layer.v_cache) +
- ne_nbytes(layer.ffn[2]) + ne_nbytes(layer.ffn[3]);
+ ne_nbytes(layer.ffn[1]) + ne_nbytes(layer.ffn[2]) + ne_nbytes(layer.ffn[3]);
}
}
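
For reference, both the removed ne_mul_mat pipeline and the new ne_flash_attn call compute the same masked, scaled dot-product attention, softmax(Q·Kᵀ/√d)·V, per kv-head group. The following is a minimal single-head C++ reference of that computation with a causal mask — illustrative only, not code from this repository, with tiny made-up dimensions.

// Minimal single-head reference for softmax(Q K^T / sqrt(d)) V with a causal mask,
// matching what the graph builds per kv-head group. Illustrative only.
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
  const int d = 4, sl_q = 2, sl_kv = 3, n_past = sl_kv - sl_q;  // 1 cached token + 2 new
  std::vector<float> Q(sl_q * d, 0.1f), K(sl_kv * d, 0.2f), V(sl_kv * d, 0.3f);
  std::vector<float> out(sl_q * d, 0.0f);
  const float scale = 1.0f / std::sqrt(static_cast<float>(d));

  for (int i = 0; i < sl_q; ++i) {
    std::vector<float> score(sl_kv, -1e30f);  // -inf stands in for masked positions
    float max_s = -1e30f;
    for (int j = 0; j <= n_past + i; ++j) {   // causal: query i attends to keys 0..n_past+i
      float s = 0.0f;
      for (int k = 0; k < d; ++k) s += Q[i * d + k] * K[j * d + k];
      score[j] = s * scale;
      if (score[j] > max_s) max_s = score[j];
    }
    float denom = 0.0f;
    for (int j = 0; j < sl_kv; ++j) {
      score[j] = std::exp(score[j] - max_s);  // masked entries underflow to 0
      denom += score[j];
    }
    for (int j = 0; j < sl_kv; ++j) {
      const float p = score[j] / denom;       // softmax probability
      for (int k = 0; k < d; ++k) out[i * d + k] += p * V[j * d + k];
    }
  }
  std::printf("out[0][0] = %f\n", out[0]);
  return 0;
}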

