From f7b113d3b5f975a2f627b182efd42cb8d2df792e Mon Sep 17 00:00:00 2001 From: "Yu, Zhentao" Date: Wed, 15 Nov 2023 09:31:30 +0000 Subject: [PATCH 01/18] chatglm batch infer prototype Signed-off-by: Yu, Zhentao --- .../runtime/graph/models/chatglm/chatglm.cpp | 133 +++++++++++------- .../graph/models/chatglm/chatglm_utils.cpp | 4 +- .../graph/models/model_utils/model_types.h | 2 +- .../graph/models/model_utils/model_utils.cpp | 4 +- 4 files changed, 91 insertions(+), 52 deletions(-) diff --git a/intel_extension_for_transformers/llm/runtime/graph/models/chatglm/chatglm.cpp b/intel_extension_for_transformers/llm/runtime/graph/models/chatglm/chatglm.cpp index 131311eac9d..c01a468f1d8 100644 --- a/intel_extension_for_transformers/llm/runtime/graph/models/chatglm/chatglm.cpp +++ b/intel_extension_for_transformers/llm/runtime/graph/models/chatglm/chatglm.cpp @@ -55,9 +55,18 @@ static bool chatglm_model_eval_internal(model_context& lctx, const model_input* const int N = inputs->n_tokens; const int n_past = inputs->n_past; const int n_total = inputs->n_total; - + const int beam_size = lctx.beam_search ? lctx.beam_size : 1; const int batch_size = lctx.batch_size; MODEL_ASSERT(batch_size == n_input); + std::vector block_ids; + std::vector n_padding; + bool no_padding = true; + for (int i = 0; i < batch_size; ++i) { + block_ids.push_back((inputs + i)->request_idx * beam_size + (inputs + i)->beam_idx); + n_padding.push_back((inputs + i)->n_padding); + if (no_padding && (inputs + i)->n_padding != 0) no_padding = false; + } + const auto& model = lctx.model; const auto& hparams = model.hparams; @@ -96,7 +105,7 @@ static bool chatglm_model_eval_internal(model_context& lctx, const model_input* ne_cgraph gf = {}; gf.n_threads = N >= 32 && ne_cpu_has_blas() ? 1 : n_threads; - struct ne_tensor* embd = d_ne_new_tensor_1d(ctx0, NE_TYPE_I32, N); + struct ne_tensor* embd = d_ne_new_tensor_1d(ctx0, NE_TYPE_I32, N * batch_size); ne_set_name(embd, "embd"); for (int i = 0; i < batch_size; ++i) { memcpy(static_cast(embd->data) + i * N, (inputs + i)->tokens, N * ne_element_size(embd)); @@ -127,61 +136,86 @@ static bool chatglm_model_eval_internal(model_context& lctx, const model_input* cur = ne_mul_mat(ctx0, model.layers[il].attn[0], cur); cur = ne_add(ctx0, cur, model.layers[il].attn[1]); - ne_tensor* query_layer = ne_view_3d(ctx0, cur, head_size, n_head, N, 3 * head_size * ne_element_size(cur), - cur->nb[1], 0); // [qlen, 3 * hidden] + ne_tensor* query_layer = + ne_view_3d(ctx0, cur, head_size, n_head, N * batch_size, 3 * head_size * ne_element_size(cur), cur->nb[1], + 0); // [qlen * bs, 3 * hidden] ne_set_name(query_layer, "query_layer"); query_layer = ne_rope_inplace(ctx0, query_layer, n_past, rope_dim, 4, first_tokens_size); - query_layer = ne_permute(ctx0, query_layer, 0, 2, 1, 3); // [heads, qlen, head_size] + query_layer = ne_permute(ctx0, ne_reshape_4d(ctx0, query_layer, head_size, n_head, N, batch_size), 0, 2, 1, + 3); // [bs, heads, qlen, head_size] ne_tensor* key_layer = - ne_view_3d(ctx0, cur, head_size, num_attention_heads, qlen, 3 * head_size * ne_element_size(cur), cur->nb[1], - head_size * ne_element_size(cur)); + ne_view_3d(ctx0, cur, head_size, num_attention_heads, qlen * batch_size, 3 * head_size * ne_element_size(cur), + cur->nb[1], head_size * ne_element_size(cur)); key_layer = ne_rope_inplace(ctx0, key_layer, n_past, rope_dim, 4, first_tokens_size); // [qlen, heads, head_size] - key_layer = ne_permute(ctx0, key_layer, 0, 2, 1, 3); // [heads, qlen, head_size] + key_layer = ne_permute(ctx0, 
ne_reshape_4d(ctx0, key_layer, head_size, num_attention_heads, qlen, batch_size), 0, + 2, 1, 3); // [bs, heads, qlen, head_size] - ne_tensor* value_layer = - ne_view_3d(ctx0, cur, head_size, num_attention_heads, qlen, 3 * head_size * ne_element_size(cur), cur->nb[1], - 2 * head_size * ne_element_size(cur)); // [qlen, heads, head_size] - value_layer = ne_permute(ctx0, value_layer, 1, 2, 0, 3); // [heads, head_size, qlen] + ne_tensor* value_layer = ne_view_3d(ctx0, cur, head_size, num_attention_heads, qlen * batch_size, + 3 * head_size * ne_element_size(cur), cur->nb[1], + 2 * head_size * ne_element_size(cur)); // [bs, qlen, heads, head_size] + value_layer = ne_permute(ctx0, ne_reshape_4d(ctx0, value_layer, head_size, num_attention_heads, qlen, batch_size), + 1, 2, 0, 3); // [bs, heads, head_size, qlen] // store key and value to memory { - struct ne_tensor* k_cache_view = - ne_view_3d(ctx0, model.layers[il].k_cache, head_size, qlen, num_attention_heads, - model.layers[il].k_cache->nb[1], model.layers[il].k_cache->nb[2], - n_past * head_size * ne_element_size(model.layers[il].k_cache)); // [kv_heads, qlen, head_size] - ne_set_name(k_cache_view, "k_cache_view"); - - struct ne_tensor* v_cache_view = - ne_view_3d(ctx0, model.layers[il].v_cache, qlen, head_size, num_attention_heads, - model.layers[il].v_cache->nb[1], model.layers[il].v_cache->nb[2], - n_past * ne_element_size(model.layers[il].v_cache)); // [kv_heads, head_size, qlen] - ne_set_name(v_cache_view, "v_cache_view"); - - ne_build_forward_expand(&gf, ne_cpy(ctx0, key_layer, k_cache_view)); - ne_build_forward_expand(&gf, ne_cpy(ctx0, value_layer, v_cache_view)); + std::vector Kcur_bs(batch_size); + std::vector Vcur_bs(batch_size); + std::vector k_bs(batch_size); + std::vector v_bs(batch_size); + for (int i = 0; i < batch_size; ++i) { + const int block_idx = block_ids[i]; + Kcur_bs[i] = ne_view_4d(ctx0, key_layer, head_size, qlen, num_attention_heads, 1, + ne_element_size(key_layer) * head_size, ne_element_size(key_layer) * head_size * qlen, + ne_element_size(key_layer) * head_size * qlen * num_attention_heads, + i * ne_element_size(key_layer) * head_size * qlen * num_attention_heads); + k_bs[i] = ne_view_4d( + ctx0, model.layers[il].k_cache, head_size, qlen, num_attention_heads, 1, model.layers[il].k_cache->nb[1], + model.layers[il].k_cache->nb[2], model.layers[il].k_cache->nb[3], + block_idx * n_ctx * n_embd * ne_element_size(model.layers[il].k_cache) + + n_past * head_size * ne_element_size(model.layers[il].k_cache)); // [kv_heads, qlen, head_size] + + Vcur_bs[i] = ne_view_4d(ctx0, value_layer, qlen, head_size, num_attention_heads, 1, + ne_element_size(value_layer) * qlen, ne_element_size(value_layer) * head_size * qlen, + ne_element_size(value_layer) * head_size * qlen * num_attention_heads, + i * ne_element_size(value_layer) * head_size * qlen * num_attention_heads); + v_bs[i] = ne_view_4d(ctx0, model.layers[il].v_cache, qlen, head_size, num_attention_heads, 1, + model.layers[il].v_cache->nb[1], model.layers[il].v_cache->nb[2], + model.layers[il].v_cache->nb[3], + block_idx * n_ctx * n_embd * ne_element_size(model.layers[il].v_cache) + + n_past * ne_element_size(model.layers[il].v_cache)); // [kv_heads, head_size, qlen] + + ne_build_forward_expand(&gf, ne_cpy(ctx0, Kcur_bs[i], k_bs[i])); + ne_build_forward_expand(&gf, ne_cpy(ctx0, Vcur_bs[i], v_bs[i])); + } } // concat key & value with past kv - key_layer = ne_view_3d(ctx0, model.layers[il].k_cache, head_size, n_past + qlen, num_attention_heads, - model.layers[il].k_cache->nb[1], 
model.layers[il].k_cache->nb[2], - 0); // [kv_heads, klen, head_size] - value_layer = ne_view_3d(ctx0, model.layers[il].v_cache, n_past + qlen, head_size, num_attention_heads, - model.layers[il].v_cache->nb[1], model.layers[il].v_cache->nb[2], - 0); // [kv_heads, head_size, klen] + key_layer = + ne_view_4d(ctx0, model.layers[il].k_cache, head_size, n_past + qlen, num_attention_heads, batch_size, + model.layers[il].k_cache->nb[1], model.layers[il].k_cache->nb[2], model.layers[il].k_cache->nb[3], + 0); // [bs, kv_heads, klen, head_size] + value_layer = + ne_view_4d(ctx0, model.layers[il].v_cache, n_past + qlen, head_size, num_attention_heads, batch_size, + model.layers[il].v_cache->nb[1], model.layers[il].v_cache->nb[2], model.layers[il].v_cache->nb[3], + 0); // [bs, kv_heads, head_size, klen] // attention - struct ne_tensor* attn_scores = ne_mul_mat(ctx0, key_layer, query_layer); // [kv_heads, mqa_scale * qlen, klen] + struct ne_tensor* attn_scores = + ne_mul_mat(ctx0, key_layer, query_layer); // [bs, kv_heads, mqa_scale * qlen, klen] ne_set_name(attn_scores, "attn_scores"); if (n_past == 0) { // build attention mask for context input - ne_tensor* inf = ne_new_tensor_3d(ctx0, attn_scores->type, 1, qlen - 1, num_attention_heads, NE_SIZE_CALC); + ne_tensor* inf = + ne_new_tensor_4d(ctx0, attn_scores->type, 1, qlen - 1, num_attention_heads, batch_size, NE_SIZE_CALC); ne_set_f32(inf, -INFINITY); ne_tensor* masked_attn_scores = - ne_view_3d(ctx0, attn_scores, 1, qlen - 1, num_attention_heads, qlen * ne_element_size(attn_scores), - qlen * qlen * ne_element_size(attn_scores), (qlen - 1) * ne_element_size(attn_scores)); + ne_view_4d(ctx0, attn_scores, 1, qlen - 1, num_attention_heads, batch_size, + qlen * ne_element_size(attn_scores), qlen * qlen * ne_element_size(attn_scores), + qlen * qlen * ne_element_size(attn_scores) * num_attention_heads, + (qlen - 1) * ne_element_size(attn_scores)); ne_set_name(masked_attn_scores, "masked_attn_scores"); ne_build_forward_expand(&gf, ne_cpy(ctx0, inf, masked_attn_scores)); @@ -190,13 +224,13 @@ static bool chatglm_model_eval_internal(model_context& lctx, const model_input* attn_scores = ne_scale_inplace(ctx0, attn_scores, ne_new_f32(ctx0, 1.f / std::sqrt(head_size))); ne_set_name(attn_scores, "attn_scores"); - ne_tensor* attn_probs = ne_soft_max_inplace(ctx0, attn_scores); // [heads, qlen, klen] + ne_tensor* attn_probs = ne_soft_max_inplace(ctx0, attn_scores); // [bs, heads, qlen, klen] - ne_tensor* context_layer = ne_mul_mat(ctx0, value_layer, attn_probs); // [heads, qlen, head_size] + ne_tensor* context_layer = ne_mul_mat(ctx0, value_layer, attn_probs); // [bs, heads, qlen, head_size] context_layer = ne_cont(ctx0, ne_permute(ctx0, context_layer, 0, 2, 1, 3)); - context_layer = ne_reshape_2d(ctx0, context_layer, hidden_size, qlen); + context_layer = ne_reshape_2d(ctx0, context_layer, hidden_size, qlen * batch_size); cur = ne_mul_mat(ctx0, model.layers[il].attn[2], context_layer); cur = ne_add(ctx0, cur, model.layers[il].attn[3]); @@ -248,8 +282,9 @@ static bool chatglm_model_eval_internal(model_context& lctx, const model_input* } lctx.use_buf(ctx0, -1); - if (embd->ne[0] > 1) { - inpL = ne_view_1d(ctx0, inpL, hidden_size, (embd->ne[0] - 1) * hidden_size * ne_element_size(inpL)); + if (!lctx.logits_all && embd->ne[0] / batch_size > 1) { + inpL = ne_view_2d(ctx0, inpL, hidden_size, batch_size, ne_element_size(inpL) * hidden_size * N, + (N - 1) * hidden_size * ne_element_size(inpL)); } // lm_head inpL = ne_mul_mat(ctx0, model.others[3], inpL); @@ -272,12 +307,12 
@@ static bool chatglm_model_eval_internal(model_context& lctx, const model_input* auto& logits_out = lctx.logits; if (lctx.logits_all) { - logits_out.resize(n_vocab * N); - memcpy(logits_out.data(), (float*)ne_get_data(inpL), sizeof(float) * n_vocab * N); + logits_out.resize(n_vocab * N * batch_size); + memcpy(logits_out.data(), (float*)ne_get_data(inpL), sizeof(float) * n_vocab * N * batch_size); } else { // return result for just the last token - logits_out.resize(n_vocab); - memcpy(logits_out.data(), (float*)ne_get_data(inpL), sizeof(float) * n_vocab); + logits_out.resize(n_vocab * batch_size); + memcpy(logits_out.data(), (float*)ne_get_data(inpL), sizeof(float) * n_vocab * batch_size); } } @@ -285,8 +320,12 @@ static bool chatglm_model_eval_internal(model_context& lctx, const model_input* if (!lctx.embedding.empty()) { auto& embedding_out = lctx.embedding; - embedding_out.resize(n_embd); - memcpy(embedding_out.data(), (float*)ne_get_data(embeddings) + (n_embd * (N - 1)), sizeof(float) * n_embd); + embedding_out.resize(n_embd * batch_size); +#pragma omp parallel for + for (int i = 0; i < batch_size; ++i) { + memcpy(embedding_out.data() + (i * n_embd), + (float*)ne_get_data(embeddings) + (i * n_embd * N) + (n_embd * (N - 1)), sizeof(float) * n_embd); + } } if (mem_per_token == 0) { diff --git a/intel_extension_for_transformers/llm/runtime/graph/models/chatglm/chatglm_utils.cpp b/intel_extension_for_transformers/llm/runtime/graph/models/chatglm/chatglm_utils.cpp index df285fd1c24..f3dfc994f5e 100644 --- a/intel_extension_for_transformers/llm/runtime/graph/models/chatglm/chatglm_utils.cpp +++ b/intel_extension_for_transformers/llm/runtime/graph/models/chatglm/chatglm_utils.cpp @@ -141,8 +141,8 @@ void CHATGLM::load(model_context& lctx, model_progress_callback progress_callbac layer.ffn[2] = ml->get_tensor(layers_i + ".mlp.dense_4h_to_h.weight", {4 * n_embd, n_embd}, backend); layer.ffn[3] = ml->get_tensor(layers_i + ".mlp.dense_4h_to_h.bias", {n_embd}, backend); - layer.k_cache = d_ne_new_tensor_3d(model.ctx, NE_TYPE_F16, 4096 / 32, 2048, 32); - layer.v_cache = d_ne_new_tensor_3d(model.ctx, NE_TYPE_F16, 2048, 4096 / 32, 32); + layer.k_cache = nullptr; // kv-cache will be init later in model_utils + layer.v_cache = nullptr; // kv-cache will be init later in model_utils // if (backend != NE_BACKEND_CPU) { // vram_total += ne_nbytes(layer.norm[0]) + ne_nbytes(layer.norm[1]) + // ne_nbytes(layer.norm[2]) + ne_nbytes(layer.norm[3]) + diff --git a/intel_extension_for_transformers/llm/runtime/graph/models/model_utils/model_types.h b/intel_extension_for_transformers/llm/runtime/graph/models/model_utils/model_types.h index c01cca1b3d2..74baacf6adf 100644 --- a/intel_extension_for_transformers/llm/runtime/graph/models/model_utils/model_types.h +++ b/intel_extension_for_transformers/llm/runtime/graph/models/model_utils/model_types.h @@ -299,7 +299,7 @@ struct model_context { std::vector logits; bool logits_all = false; - // input embedding (1-dimensional array: [n_embd]) + // input embedding (1-dimensional array: [n_embd * batch_size]) std::vector embedding; // memory buffers used to evaluate the model diff --git a/intel_extension_for_transformers/llm/runtime/graph/models/model_utils/model_utils.cpp b/intel_extension_for_transformers/llm/runtime/graph/models/model_utils/model_utils.cpp index c29972f9796..b8ff4c0a2e1 100644 --- a/intel_extension_for_transformers/llm/runtime/graph/models/model_utils/model_utils.cpp +++ 
b/intel_extension_for_transformers/llm/runtime/graph/models/model_utils/model_utils.cpp @@ -104,8 +104,8 @@ static bool kv_cache_init(const struct model_hparams& hparams, struct model_kv_c if (wtype == NE_TYPE_F16) { // chatglm does not support fp32 kv-cache in original impl of chatglm_util.cpp const int head_size = hparams.n_embd / hparams.n_head; const int heads_kv = hparams.multi_query_group_num > 0 ? hparams.multi_query_group_num : hparams.n_head; - k_cache = d_ne_new_tensor_3d(model->ctx, NE_TYPE_F16, head_size, n_ctx, heads_kv); - v_cache = d_ne_new_tensor_3d(model->ctx, NE_TYPE_F16, n_ctx, head_size, heads_kv); + k_cache = d_ne_new_tensor_4d(model->ctx, NE_TYPE_F16, head_size, n_ctx, heads_kv, batch_size * beam_size); + v_cache = d_ne_new_tensor_4d(model->ctx, NE_TYPE_F16, n_ctx, head_size, heads_kv, batch_size * beam_size); } else if (wtype == NE_TYPE_JBLAS) { k_cache = ne_new_tensor_1d(model->ctx, wtype_alloc, layer_ne_k + NE_ALIGNMENT, NE_SIZE_CALC); const auto k_align_off = reinterpret_cast(k_cache->data) % NE_ALIGNMENT; From bef8f24318a1ef554246682a51cea663a57a139f Mon Sep 17 00:00:00 2001 From: "Yu, Zhentao" Date: Thu, 16 Nov 2023 08:42:42 +0000 Subject: [PATCH 02/18] fix batch=1 Signed-off-by: Yu, Zhentao --- .../runtime/graph/models/chatglm/chatglm.cpp | 40 +++++++++---------- 1 file changed, 20 insertions(+), 20 deletions(-) diff --git a/intel_extension_for_transformers/llm/runtime/graph/models/chatglm/chatglm.cpp b/intel_extension_for_transformers/llm/runtime/graph/models/chatglm/chatglm.cpp index c01a468f1d8..7e36fbd5adc 100644 --- a/intel_extension_for_transformers/llm/runtime/graph/models/chatglm/chatglm.cpp +++ b/intel_extension_for_transformers/llm/runtime/graph/models/chatglm/chatglm.cpp @@ -114,7 +114,7 @@ static bool chatglm_model_eval_internal(model_context& lctx, const model_input* struct ne_tensor* inpL = ne_get_rows(ctx0, model.others[0], embd); int hidden_size = inpL->ne[0]; - int qlen = inpL->ne[1]; + int qlen = inpL->ne[1] / batch_size; int head_size = hidden_size / num_attention_heads; int rope_dim = head_size / 2; for (int il = 0; il < n_layer; ++il) { @@ -137,26 +137,22 @@ static bool chatglm_model_eval_internal(model_context& lctx, const model_input* cur = ne_add(ctx0, cur, model.layers[il].attn[1]); ne_tensor* query_layer = - ne_view_3d(ctx0, cur, head_size, n_head, N * batch_size, 3 * head_size * ne_element_size(cur), cur->nb[1], - 0); // [qlen * bs, 3 * hidden] + ne_view_4d(ctx0, cur, head_size, n_head, N, batch_size, 3 * head_size * ne_element_size(cur), cur->nb[1], + cur->nb[1] * N, 0); // [qlen * bs, 3 * hidden] ne_set_name(query_layer, "query_layer"); query_layer = ne_rope_inplace(ctx0, query_layer, n_past, rope_dim, 4, first_tokens_size); - query_layer = ne_permute(ctx0, ne_reshape_4d(ctx0, query_layer, head_size, n_head, N, batch_size), 0, 2, 1, + query_layer = ne_permute(ctx0, query_layer, 0, 2, 1, 3); // [bs, heads, qlen, head_size] ne_tensor* key_layer = - ne_view_3d(ctx0, cur, head_size, num_attention_heads, qlen * batch_size, 3 * head_size * ne_element_size(cur), - cur->nb[1], head_size * ne_element_size(cur)); + ne_view_4d(ctx0, cur, head_size, num_attention_heads, qlen, batch_size, 3 * head_size * ne_element_size(cur), + cur->nb[1], cur->nb[1] * qlen, head_size * ne_element_size(cur)); key_layer = ne_rope_inplace(ctx0, key_layer, n_past, rope_dim, 4, first_tokens_size); // [qlen, heads, head_size] - key_layer = ne_permute(ctx0, ne_reshape_4d(ctx0, key_layer, head_size, num_attention_heads, qlen, batch_size), 0, - 2, 1, 3); // [bs, heads, 
qlen, head_size] - ne_tensor* value_layer = ne_view_3d(ctx0, cur, head_size, num_attention_heads, qlen * batch_size, - 3 * head_size * ne_element_size(cur), cur->nb[1], + ne_tensor* value_layer = ne_view_4d(ctx0, cur, head_size, num_attention_heads, qlen, batch_size, + 3 * head_size * ne_element_size(cur), cur->nb[1], cur->nb[1] * qlen, 2 * head_size * ne_element_size(cur)); // [bs, qlen, heads, head_size] - value_layer = ne_permute(ctx0, ne_reshape_4d(ctx0, value_layer, head_size, num_attention_heads, qlen, batch_size), - 1, 2, 0, 3); // [bs, heads, head_size, qlen] // store key and value to memory { @@ -166,20 +162,24 @@ static bool chatglm_model_eval_internal(model_context& lctx, const model_input* std::vector v_bs(batch_size); for (int i = 0; i < batch_size; ++i) { const int block_idx = block_ids[i]; - Kcur_bs[i] = ne_view_4d(ctx0, key_layer, head_size, qlen, num_attention_heads, 1, - ne_element_size(key_layer) * head_size, ne_element_size(key_layer) * head_size * qlen, - ne_element_size(key_layer) * head_size * qlen * num_attention_heads, - i * ne_element_size(key_layer) * head_size * qlen * num_attention_heads); + // [bs, heads, qlen, head_size] + Kcur_bs[i] = ne_permute( + ctx0, + ne_view_4d(ctx0, key_layer, head_size, num_attention_heads, qlen, 1, key_layer->nb[1], key_layer->nb[2], + key_layer->nb[3], i * ne_element_size(key_layer) * head_size * qlen * num_attention_heads), + 0, 2, 1, 3); k_bs[i] = ne_view_4d( ctx0, model.layers[il].k_cache, head_size, qlen, num_attention_heads, 1, model.layers[il].k_cache->nb[1], model.layers[il].k_cache->nb[2], model.layers[il].k_cache->nb[3], block_idx * n_ctx * n_embd * ne_element_size(model.layers[il].k_cache) + n_past * head_size * ne_element_size(model.layers[il].k_cache)); // [kv_heads, qlen, head_size] - Vcur_bs[i] = ne_view_4d(ctx0, value_layer, qlen, head_size, num_attention_heads, 1, - ne_element_size(value_layer) * qlen, ne_element_size(value_layer) * head_size * qlen, - ne_element_size(value_layer) * head_size * qlen * num_attention_heads, - i * ne_element_size(value_layer) * head_size * qlen * num_attention_heads); + // [bs, heads, head_size, qlen] + Vcur_bs[i] = ne_permute(ctx0, + ne_view_4d(ctx0, value_layer, head_size, num_attention_heads, qlen, 1, + value_layer->nb[1], value_layer->nb[2], value_layer->nb[3], + i * ne_element_size(value_layer) * head_size * qlen * num_attention_heads), + 1, 2, 0, 3); v_bs[i] = ne_view_4d(ctx0, model.layers[il].v_cache, qlen, head_size, num_attention_heads, 1, model.layers[il].v_cache->nb[1], model.layers[il].v_cache->nb[2], model.layers[il].v_cache->nb[3], From ea136cd6db2b62444139403e9b153be8cab48ee1 Mon Sep 17 00:00:00 2001 From: "Yu, Zhentao" Date: Thu, 16 Nov 2023 09:43:18 +0000 Subject: [PATCH 03/18] fix batch=2 Signed-off-by: Yu, Zhentao --- .../runtime/graph/models/chatglm/chatglm.cpp | 28 +++++++++---------- 1 file changed, 13 insertions(+), 15 deletions(-) diff --git a/intel_extension_for_transformers/llm/runtime/graph/models/chatglm/chatglm.cpp b/intel_extension_for_transformers/llm/runtime/graph/models/chatglm/chatglm.cpp index 7e36fbd5adc..38cc79c1ef5 100644 --- a/intel_extension_for_transformers/llm/runtime/graph/models/chatglm/chatglm.cpp +++ b/intel_extension_for_transformers/llm/runtime/graph/models/chatglm/chatglm.cpp @@ -142,8 +142,7 @@ static bool chatglm_model_eval_internal(model_context& lctx, const model_input* ne_set_name(query_layer, "query_layer"); query_layer = ne_rope_inplace(ctx0, query_layer, n_past, rope_dim, 4, first_tokens_size); - query_layer = ne_permute(ctx0, 
query_layer, 0, 2, 1, - 3); // [bs, heads, qlen, head_size] + query_layer = ne_permute(ctx0, query_layer, 0, 2, 1, 3); // [bs, heads, qlen, head_size] ne_tensor* key_layer = ne_view_4d(ctx0, cur, head_size, num_attention_heads, qlen, batch_size, 3 * head_size * ne_element_size(cur), @@ -163,11 +162,10 @@ static bool chatglm_model_eval_internal(model_context& lctx, const model_input* for (int i = 0; i < batch_size; ++i) { const int block_idx = block_ids[i]; // [bs, heads, qlen, head_size] - Kcur_bs[i] = ne_permute( - ctx0, - ne_view_4d(ctx0, key_layer, head_size, num_attention_heads, qlen, 1, key_layer->nb[1], key_layer->nb[2], - key_layer->nb[3], i * ne_element_size(key_layer) * head_size * qlen * num_attention_heads), - 0, 2, 1, 3); + Kcur_bs[i] = ne_permute(ctx0, + ne_view_4d(ctx0, key_layer, head_size, num_attention_heads, qlen, 1, key_layer->nb[1], + key_layer->nb[2], key_layer->nb[3], i * key_layer->nb[3]), + 0, 2, 1, 3); k_bs[i] = ne_view_4d( ctx0, model.layers[il].k_cache, head_size, qlen, num_attention_heads, 1, model.layers[il].k_cache->nb[1], model.layers[il].k_cache->nb[2], model.layers[il].k_cache->nb[3], @@ -175,11 +173,11 @@ static bool chatglm_model_eval_internal(model_context& lctx, const model_input* n_past * head_size * ne_element_size(model.layers[il].k_cache)); // [kv_heads, qlen, head_size] // [bs, heads, head_size, qlen] - Vcur_bs[i] = ne_permute(ctx0, - ne_view_4d(ctx0, value_layer, head_size, num_attention_heads, qlen, 1, - value_layer->nb[1], value_layer->nb[2], value_layer->nb[3], - i * ne_element_size(value_layer) * head_size * qlen * num_attention_heads), - 1, 2, 0, 3); + Vcur_bs[i] = + ne_permute(ctx0, + ne_view_4d(ctx0, value_layer, head_size, num_attention_heads, qlen, 1, value_layer->nb[1], + value_layer->nb[2], value_layer->nb[3], i * value_layer->nb[3]), + 1, 2, 0, 3); v_bs[i] = ne_view_4d(ctx0, model.layers[il].v_cache, qlen, head_size, num_attention_heads, 1, model.layers[il].v_cache->nb[1], model.layers[il].v_cache->nb[2], model.layers[il].v_cache->nb[3], @@ -282,9 +280,9 @@ static bool chatglm_model_eval_internal(model_context& lctx, const model_input* } lctx.use_buf(ctx0, -1); - if (!lctx.logits_all && embd->ne[0] / batch_size > 1) { - inpL = ne_view_2d(ctx0, inpL, hidden_size, batch_size, ne_element_size(inpL) * hidden_size * N, - (N - 1) * hidden_size * ne_element_size(inpL)); + if (!lctx.logits_all && qlen > 1) { + inpL = ne_cont(ctx0, ne_view_2d(ctx0, inpL, hidden_size, batch_size, ne_element_size(inpL) * hidden_size * N, + (N - 1) * hidden_size * ne_element_size(inpL))); } // lm_head inpL = ne_mul_mat(ctx0, model.others[3], inpL); From 1e46ba307483928de6108893845b2a97cae6daf3 Mon Sep 17 00:00:00 2001 From: "Yu, Zhentao" Date: Fri, 17 Nov 2023 02:38:02 +0000 Subject: [PATCH 04/18] fix jblas mul_mat ne1 stride Signed-off-by: Yu, Zhentao --- .../llm/runtime/graph/core/ne_layers.c | 3 ++- .../llm/runtime/graph/models/chatglm/chatglm.cpp | 4 ++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/intel_extension_for_transformers/llm/runtime/graph/core/ne_layers.c b/intel_extension_for_transformers/llm/runtime/graph/core/ne_layers.c index bc2becbff22..8384fb20262 100644 --- a/intel_extension_for_transformers/llm/runtime/graph/core/ne_layers.c +++ b/intel_extension_for_transformers/llm/runtime/graph/core/ne_layers.c @@ -6680,7 +6680,8 @@ static void ne_compute_forward_mul_mat_q_f32_jblas(const struct ne_compute_param if (params->type == NE_TASK_FINALIZE) { return; } - jblas_f32f32_forward((float*)src1->data, src0->data, (float*)dst->data, 
ne1, ne0, ne10, ne10, ne0, params->wdata); + jblas_f32f32_forward((float*)src1->data, src0->data, (float*)dst->data, ne1, ne0, ne10, nb11 / ne_element_size(src1), + nb1 / ne_element_size(dst), params->wdata); } static void ne_compute_forward_mul_mat(const struct ne_compute_params* params, const struct ne_tensor* src0, diff --git a/intel_extension_for_transformers/llm/runtime/graph/models/chatglm/chatglm.cpp b/intel_extension_for_transformers/llm/runtime/graph/models/chatglm/chatglm.cpp index 38cc79c1ef5..30e826bf5dc 100644 --- a/intel_extension_for_transformers/llm/runtime/graph/models/chatglm/chatglm.cpp +++ b/intel_extension_for_transformers/llm/runtime/graph/models/chatglm/chatglm.cpp @@ -281,8 +281,8 @@ static bool chatglm_model_eval_internal(model_context& lctx, const model_input* lctx.use_buf(ctx0, -1); if (!lctx.logits_all && qlen > 1) { - inpL = ne_cont(ctx0, ne_view_2d(ctx0, inpL, hidden_size, batch_size, ne_element_size(inpL) * hidden_size * N, - (N - 1) * hidden_size * ne_element_size(inpL))); + inpL = ne_view_2d(ctx0, inpL, hidden_size, batch_size, ne_element_size(inpL) * hidden_size * N, + (N - 1) * hidden_size * ne_element_size(inpL)); } // lm_head inpL = ne_mul_mat(ctx0, model.others[3], inpL); From 7e593f2de7bfb2c9b19e4833e94d81fcb4c392ba Mon Sep 17 00:00:00 2001 From: "Yu, Zhentao" Date: Fri, 17 Nov 2023 03:30:33 +0000 Subject: [PATCH 05/18] add ne_padding_mask_inf Signed-off-by: Yu, Zhentao --- .../llm/runtime/graph/core/layers/Ops.h | 1 + .../llm/runtime/graph/core/ne_layers.c | 116 ++++++++++++++++++ .../llm/runtime/graph/core/ne_layers.h | 7 ++ .../runtime/graph/models/chatglm/chatglm.cpp | 3 +- 4 files changed, 126 insertions(+), 1 deletion(-) diff --git a/intel_extension_for_transformers/llm/runtime/graph/core/layers/Ops.h b/intel_extension_for_transformers/llm/runtime/graph/core/layers/Ops.h index dce9c98d683..dc3e28a8d45 100644 --- a/intel_extension_for_transformers/llm/runtime/graph/core/layers/Ops.h +++ b/intel_extension_for_transformers/llm/runtime/graph/core/layers/Ops.h @@ -61,6 +61,7 @@ enum ne_op { NE_OP_DIAG, NE_OP_DIAG_MASK_INF, NE_OP_DIAG_MASK_ZERO, + NE_OP_PADDING_MASK_INF, NE_OP_SOFT_MAX, NE_OP_ROPE, NE_OP_ROPE_BACK, diff --git a/intel_extension_for_transformers/llm/runtime/graph/core/ne_layers.c b/intel_extension_for_transformers/llm/runtime/graph/core/ne_layers.c index 8384fb20262..0ab465d7187 100644 --- a/intel_extension_for_transformers/llm/runtime/graph/core/ne_layers.c +++ b/intel_extension_for_transformers/llm/runtime/graph/core/ne_layers.c @@ -406,6 +406,7 @@ static const char* NE_OP_LABEL[NE_OP_COUNT] = { "DIAG", "DIAG_MASK_INF", "DIAG_MASK_ZERO", + "PADDING_MASK_INF", "SOFT_MAX", "ROPE", "ROPE_BACK", @@ -478,6 +479,7 @@ static const char* NE_OP_SYMBOL[NE_OP_COUNT] = { "diag(x)", "diag_mask_inf(x)", "diag_mask_zero(x)", + "padding_mask_inf(x)", "soft_max(x)", "rope(x)", "rope_back(x)", @@ -2889,6 +2891,52 @@ struct ne_tensor* ne_diag_mask_zero_inplace(struct ne_context* ctx, struct ne_te return ne_diag_mask_zero_impl(ctx, a, n_past, true); } +// ne_padding_mask_inf + +struct ne_tensor* ne_padding_mask_inf_impl(struct ne_context* ctx, struct ne_tensor* a, int n_past, int* n_padding, + bool padding_left, bool inplace) { + NE_ASSERT(padding_left); + bool is_node = false; + + if (a->grad) { + is_node = true; + } + + struct ne_tensor* result = inplace ? 
ne_view_tensor(ctx, a) : ne_dup_tensor(ctx, a); + + ne_scratch_save(ctx); + + const int bs = a->ne[3]; + struct ne_tensor* b = ne_new_tensor_1d(ctx, NE_TYPE_I32, 2 + bs, NE_SIZE_CALC); + + ((int32_t*)b->data)[0] = n_past; + ((int32_t*)b->data)[1] = inplace ? 1 : 0; + for (int i = 0; i < bs; ++i) { + if (n_padding == NULL) { + ((int32_t*)b->data)[2 + i] = 0; + } else { + ((int32_t*)b->data)[2 + i] = *(n_padding + i); + } + } + + ne_scratch_load(ctx); + + result->op = NE_OP_PADDING_MASK_INF; + result->grad = is_node ? ne_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = b; + + return result; +} + +struct ne_tensor* ne_padding_left_mask_inf(struct ne_context* ctx, struct ne_tensor* a, int* n_padding) { + return ne_padding_mask_inf_impl(ctx, a, 0, n_padding, false, true); +} + +struct ne_tensor* ne_padding_left_mask_inf_inplace(struct ne_context* ctx, struct ne_tensor* a, int* n_padding) { + return ne_padding_mask_inf_impl(ctx, a, 0, n_padding, true, true); +} + // ne_soft_max struct ne_tensor* ne_soft_max_impl(struct ne_context* ctx, struct ne_tensor* a, bool inplace) { @@ -7416,6 +7464,70 @@ static void ne_compute_forward_diag_mask_zero(const struct ne_compute_params* pa } } +// ne_compute_forward_padding_mask_inf + +static void ne_compute_forward_padding_mask_f32(const struct ne_compute_params* params, const struct ne_tensor* src0, + const struct ne_tensor* src1, struct ne_tensor* dst, + const float value) { + assert(src1->type == NE_TYPE_I32); + const int bs = src0->ne[3]; + assert(ne_nelements(src1) == (2 + bs)); + + const int ith = params->ith; + const int nth = params->nth; + + const int n_past = ((int32_t*)src1->data)[0]; + const bool inplace = (bool)((int32_t*)src1->data)[1]; + + assert(n_past >= 0); + + if (!inplace && (params->type == NE_TASK_INIT)) { + // memcpy needs to be synchronized across threads to avoid race conditions. 
+ // => do it in INIT phase + NE_ASSERT(ne_nelements(dst) == ne_nelements(src0)); + NE_ASSERT(ne_is_contiguous(dst) && ne_is_contiguous(src0)); + memcpy(((char*)dst->data), ((char*)src0->data), ne_nbytes(dst)); + } + + if (params->type == NE_TASK_INIT || params->type == NE_TASK_FINALIZE) { + return; + } + + // TODO: handle transposed/permuted matrices + + const int n = ne_nrows(src0); + const int nc = src0->ne[0]; + const int nr = src0->ne[1]; + const int nz = n / nr; + + assert(dst->nb[0] == sizeof(float)); + assert(src0->nb[0] == sizeof(float)); + + // mask padding token (padding left) + for (int b = 0; b < bs; b++) { + const int n_padding = ((int32_t*)src1->data)[2 + b]; + if (n_padding == 0) continue; + for (int k = 0; k < (nz / bs); k++) { + for (int j = ith; j < nr; j += nth) { + // it will not affect next token if don't mask the pad_token row + ne_vec_set_f32(n_padding, (float*)((char*)dst->data + b * dst->nb[3] + k * dst->nb[2] + j * dst->nb[1]), value); + } + } + } +} + +static void ne_compute_forward_padding_mask_inf(const struct ne_compute_params* params, const struct ne_tensor* src0, + const struct ne_tensor* src1, struct ne_tensor* dst) { + switch (src0->type) { + case NE_TYPE_F32: { + ne_compute_forward_padding_mask_f32(params, src0, src1, dst, -INFINITY); + } break; + default: { + NE_ASSERT(false); + } break; + } +} + // ne_compute_forward_soft_max static void ne_compute_forward_soft_max_f32(const struct ne_compute_params* params, const struct ne_tensor* src0, @@ -9339,6 +9451,9 @@ static void ne_compute_forward(struct ne_compute_params* params, struct ne_tenso case NE_OP_DIAG_MASK_ZERO: { ne_compute_forward_diag_mask_zero(params, tensor->src0, tensor->src1, tensor); } break; + case NE_OP_PADDING_MASK_INF: { + ne_compute_forward_padding_mask_inf(params, tensor->src0, tensor->src1, tensor); + } break; case NE_OP_SOFT_MAX: { ne_compute_forward_soft_max(params, tensor->src0, tensor); } break; @@ -10353,6 +10468,7 @@ void ne_graph_compute(struct ne_context* ctx, struct ne_cgraph* cgraph) { node->n_tasks = 1; } break; case NE_OP_DIAG_MASK_INF: + case NE_OP_PADDING_MASK_INF: case NE_OP_ROPE: if (node->type == NE_TYPE_JBLAS) { node->n_tasks = 1; diff --git a/intel_extension_for_transformers/llm/runtime/graph/core/ne_layers.h b/intel_extension_for_transformers/llm/runtime/graph/core/ne_layers.h index 5d920861b70..193c22f7e0d 100644 --- a/intel_extension_for_transformers/llm/runtime/graph/core/ne_layers.h +++ b/intel_extension_for_transformers/llm/runtime/graph/core/ne_layers.h @@ -371,6 +371,13 @@ NE_API struct ne_tensor* ne_diag_mask_inf(struct ne_context* ctx, struct ne_tens // in-place, returns view(a) NE_API struct ne_tensor* ne_diag_mask_inf_inplace(struct ne_context* ctx, struct ne_tensor* a, int n_past); +// set padding tokens to -INF +// only support padding left for now +NE_API struct ne_tensor* ne_padding_left_mask_inf(struct ne_context* ctx, struct ne_tensor* a, int* n_padding); + +// in-place, returns view(a) +NE_API struct ne_tensor* ne_padding_left_mask_inf_inplace(struct ne_context* ctx, struct ne_tensor* a, int* n_padding); + // set elements above the diagonal and padding tokens to -INF NE_API struct ne_tensor* ne_diag_mask_inf_with_padding(struct ne_context* ctx, struct ne_tensor* a, int n_past, int* n_padding); diff --git a/intel_extension_for_transformers/llm/runtime/graph/models/chatglm/chatglm.cpp b/intel_extension_for_transformers/llm/runtime/graph/models/chatglm/chatglm.cpp index 30e826bf5dc..763dd7db137 100644 --- 
a/intel_extension_for_transformers/llm/runtime/graph/models/chatglm/chatglm.cpp +++ b/intel_extension_for_transformers/llm/runtime/graph/models/chatglm/chatglm.cpp @@ -218,7 +218,8 @@ static bool chatglm_model_eval_internal(model_context& lctx, const model_input* ne_set_name(masked_attn_scores, "masked_attn_scores"); ne_build_forward_expand(&gf, ne_cpy(ctx0, inf, masked_attn_scores)); } - + // mask left pad token + attn_scores = ne_padding_left_mask_inf_inplace(ctx0, attn_scores, n_padding.data()); attn_scores = ne_scale_inplace(ctx0, attn_scores, ne_new_f32(ctx0, 1.f / std::sqrt(head_size))); ne_set_name(attn_scores, "attn_scores"); From e193b8b2c2755a3faa722e3fcd21ed9354a0932f Mon Sep 17 00:00:00 2001 From: "Yu, Zhentao" Date: Fri, 17 Nov 2023 03:34:40 +0000 Subject: [PATCH 06/18] fix typo Signed-off-by: Yu, Zhentao --- .../llm/runtime/graph/core/ne_layers.c | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/intel_extension_for_transformers/llm/runtime/graph/core/ne_layers.c b/intel_extension_for_transformers/llm/runtime/graph/core/ne_layers.c index 0ab465d7187..5be1efdfee5 100644 --- a/intel_extension_for_transformers/llm/runtime/graph/core/ne_layers.c +++ b/intel_extension_for_transformers/llm/runtime/graph/core/ne_layers.c @@ -2930,7 +2930,7 @@ struct ne_tensor* ne_padding_mask_inf_impl(struct ne_context* ctx, struct ne_ten } struct ne_tensor* ne_padding_left_mask_inf(struct ne_context* ctx, struct ne_tensor* a, int* n_padding) { - return ne_padding_mask_inf_impl(ctx, a, 0, n_padding, false, true); + return ne_padding_mask_inf_impl(ctx, a, 0, n_padding, true, false); } struct ne_tensor* ne_padding_left_mask_inf_inplace(struct ne_context* ctx, struct ne_tensor* a, int* n_padding) { From 24f4fc6d927ee736fb3b29a30f4e62b7a61a5d1e Mon Sep 17 00:00:00 2001 From: "Yu, Zhentao" Date: Fri, 17 Nov 2023 07:28:41 +0000 Subject: [PATCH 07/18] pybind draft Signed-off-by: Yu, Zhentao --- .../llm/runtime/graph/__init__.py | 27 +- .../runtime/graph/application/main_pybind.cpp | 281 ++++++++++++------ 2 files changed, 213 insertions(+), 95 deletions(-) diff --git a/intel_extension_for_transformers/llm/runtime/graph/__init__.py b/intel_extension_for_transformers/llm/runtime/graph/__init__.py index a0a21cef77b..c29d52fa049 100644 --- a/intel_extension_for_transformers/llm/runtime/graph/__init__.py +++ b/intel_extension_for_transformers/llm/runtime/graph/__init__.py @@ -145,8 +145,9 @@ def quant_model(self, model_type, model_path, out_path, **quant_kwargs): def generate(self, input_ids, streamer=None, interactive=False, ignore_prompt=False, stopping_criteria=None, **generate_kwargs): max_new_tokens = generate_kwargs.get("max_new_tokens", -1) + self.batch_size = input_ids.shape[0] if self.model is None: - self.init_from_bin(self.model_type, self.bin_file, batch_size=input_ids.shape[0], + self.init_from_bin(self.model_type, self.bin_file, batch_size=self.batch_size, **generate_kwargs) self.generate_round = 0 elif not interactive: @@ -160,9 +161,9 @@ def generate(self, input_ids, streamer=None, interactive=False, ignore_prompt=Fa beam_search = False if (generate_kwargs.get("num_beams", 1) > 1) and not generate_kwargs.get("do_sample", False): beam_search = True - if not beam_search: - # TODO support multi batch - assert input_ids.shape[0] == 1, "Unsupport multi-batch input ids." + # if not beam_search: + # # TODO support multi batch + # assert input_ids.shape[0] == 1, "Unsupport multi-batch input ids." if streamer: assert input_ids.shape[0] == 1, "Streamer only supports batch size 1." 
@@ -190,9 +191,14 @@ def generate(self, input_ids, streamer=None, interactive=False, ignore_prompt=Fa if stopping_criteria is not None: if stopping_criteria(torch.tensor(ret), None): break - elif ret[0][-1] == self.eos_token_id() or \ - (max_new_tokens != -1 and out_count >= max_new_tokens): + elif (max_new_tokens != -1 and out_count > max_new_tokens): break + else: + all_done = [(r[-1] in [self.eos_token_id(), self.pad_token_id()]) + for r in ret] + if False not in all_done: + break + out_count += 1 if streamer: streamer.end() @@ -206,6 +212,15 @@ def eos_token_id(self): if self.model_type == 'qwen': return self.tokenizer.special_tokens['<|endoftext|>'] return self.tokenizer.eos_token_id + + def pad_token_id(self): + if self.tokenizer.pad_token_id == None: + if self.batch_size == 1: + return None + else: + raise ValueError("Please set pad_token_id when doing multi batch inference"\ + " with padding!") + return self.tokenizer.pad_token_id def __call__(self, input_ids, reinit=False, **kwargs): if self.model is None: diff --git a/intel_extension_for_transformers/llm/runtime/graph/application/main_pybind.cpp b/intel_extension_for_transformers/llm/runtime/graph/application/main_pybind.cpp index f67027d6486..1b99563571a 100644 --- a/intel_extension_for_transformers/llm/runtime/graph/application/main_pybind.cpp +++ b/intel_extension_for_transformers/llm/runtime/graph/application/main_pybind.cpp @@ -120,12 +120,12 @@ class Model { private: model_context* ctx = nullptr; gpt_params params; - std::vector curr_input_ids; + std::vector> curr_input_ids; int n_past = 0; int n_total = 0; int n_vocab = 0; int n_ctx = 0; - std::vector last_n_tokens; + std::vector> last_n_tokens; bool token_eos = false; long int generate_count = 0; @@ -180,10 +180,14 @@ void Model::init_model(const std::string& model_path, int max_new_tokens, int n_ n_total = 0; token_eos = false; curr_input_ids.clear(); + curr_input_ids.resize(params.batch_size); ctx = model_init_from_gpt_params(params); n_vocab = model_n_vocab(ctx); n_ctx = model_n_ctx(ctx); - last_n_tokens.resize(n_ctx, 0); + last_n_tokens.resize(params.batch_size); + for (int i = 0; i < params.batch_size; ++i) { + last_n_tokens[i].resize(n_ctx, 0); + } ctx->generation_conf.min_new_tokens = min_new_tokens; ctx->generation_conf.length_penalty = length_penalty; ctx->generation_conf.do_early_stopping = early_stopping; @@ -194,9 +198,13 @@ void Model::reinit() { n_past = 0; n_total = 0; last_n_tokens.clear(); - last_n_tokens.resize(n_ctx, 0); + last_n_tokens.resize(params.batch_size); + for (int i = 0; i < params.batch_size; ++i) { + last_n_tokens[i].resize(n_ctx, 0); + } token_eos = false; curr_input_ids.clear(); + curr_input_ids.resize(params.batch_size); ctx->n_sample = 0; ctx->t_sample_us = 0; generate_count = 0; @@ -318,19 +326,7 @@ std::vector> Model::generate_tokens(const std::vector> rets; if (ctx->beam_search) { - MODEL_ASSERT(input_ids.size() == ctx->batch_size); - if (ctx->batch_size > 1 && ctx->vocab.pad_token_id == -1) { - fprintf(stderr, "\nERROR: please set pad_token for beam search multi-batch generation!\n"); - return rets; - } - std::vector inputs; for (int bs = 0; bs < input_ids.size(); ++bs) { - uint32_t count = 0; - model_vocab::id pad_token_id = ctx->vocab.pad_token_id; - auto iter = std::find_if(input_ids[bs].begin(), input_ids[bs].end(), - [&pad_token_id](model_token t) { return (t != pad_token_id); }); - if (iter == input_ids[bs].end()) fprintf(stderr, "\nERROR: there are all pad tokens in batch %d!\n", bs); - count = 
std::distance(input_ids[bs].begin(), iter); inputs.push_back(model_input{ /*.tokens =*/input_ids[bs].data(), /*.n_tokens =*/(uint32_t)input_ids[bs].size(), @@ -340,50 +336,45 @@ std::vector> Model::generate_tokens(const std::vector 1) { - fprintf(stderr, "\nERROR: Only beam search supports multi-batch generation!\n"); - return rets; - } - - if (curr_input_ids.empty()) { - if (input_ids[0].size() > n_ctx - 4) { - fprintf(stderr, "\n%s: Warning: prompt is too long (%d tokens, max %d), will be truncated\n", __func__, - input_ids[0].size(), n_ctx - 4); - curr_input_ids.resize(n_ctx - 4); - std::copy(input_ids[0].end() - n_ctx - 4, input_ids[0].end(), curr_input_ids.begin()); - } else { - curr_input_ids = input_ids[0]; + for (int bs = 0; bs < input_ids.size(); ++bs) { + if (curr_input_ids[bs].empty()) { + if (input_ids[bs].size() > n_ctx - 4) { + fprintf(stderr, "\n%s: Warning: prompt is too long (%d tokens, max %d), will be truncated\n", __func__, + input_ids[bs].size(), n_ctx - 4); + curr_input_ids[bs].resize(n_ctx - 4); + std::copy(input_ids[bs].end() - n_ctx - 4, input_ids[bs].end(), curr_input_ids[bs].begin()); + } else { + curr_input_ids[bs] = input_ids[bs]; + } } - } - while (output_ids.size() < n_remain) { - for (auto item : curr_input_ids) { - last_n_tokens.erase(last_n_tokens.begin()); - last_n_tokens.push_back(item); + for (auto item : curr_input_ids[bs]) { + last_n_tokens[bs].erase(last_n_tokens[bs].begin()); + last_n_tokens[bs].push_back(item); } // infinite text generation via context swapping - if (n_past + curr_input_ids.size() > n_ctx) { + if (n_past + curr_input_ids[bs].size() > n_ctx) { // always keep the first token n_past = std::max(1, params.n_keep); int n_discard = params.n_discard; if (!params.shift_roped_k) { // shift_roped_k can use ring-buffer and thus does not need re-computing - if (n_discard == -1) n_discard = (n_ctx - curr_input_ids.size() - params.n_keep) / 2; + if (n_discard == -1) n_discard = (n_ctx - curr_input_ids[bs].size() - params.n_keep) / 2; // drop n_discard tokens - curr_input_ids.insert(curr_input_ids.begin(), last_n_tokens.begin() + params.n_keep + n_discard, - last_n_tokens.end() - curr_input_ids.size()); + curr_input_ids[bs].insert(curr_input_ids[bs].begin(), last_n_tokens[bs].begin() + params.n_keep + n_discard, + last_n_tokens[bs].end() - curr_input_ids[bs].size()); } else { NE_ASSERT(("n_discard cannot be used with shift_roped_k!", n_discard == -1 || n_discard == 1)); } } - std::vector inputs = {model_input{ - /*.tokens =*/curr_input_ids.data(), - /*.n_tokens =*/(uint32_t)curr_input_ids.size(), + inputs.push_back(model_input{ + /*.tokens =*/curr_input_ids[bs].data(), + /*.n_tokens =*/(uint32_t)curr_input_ids[bs].size(), /*.n_prompt_tokens =*/0, /*.n_past =*/(uint32_t)n_past, /*.n_total =*/(uint32_t)n_total, @@ -391,32 +382,139 @@ std::vector> Model::generate_tokens(const std::vectorvocab.eos_token_id) { - token_eos = true; - break; - } - if (params.n_predict > 0 && generate_count >= params.n_predict) { - token_eos = true; - break; + }); + } + model_eval(ctx, inputs.data(), inputs.size(), params.n_threads); + // static batching inference should have same input length and context window length + n_past += curr_input_ids[0].size(); + n_total += curr_input_ids[0].size(); + + float* logits = model_get_logits(ctx); + std::vector next_token_ids = post_process(logits); + for (int bs = 0; bs < next_token_ids.size(); ++bs) { + // padding eos seq for continuous batched kv cache + // TODO batch reduction after for-loop attention implementation + if 
(curr_input_ids[bs].back() == ctx->vocab.eos_token_id || curr_input_ids[bs].back() == ctx->vocab.pad_token_id) { + curr_input_ids[bs] = {ctx->vocab.pad_token_id}; + } else { + curr_input_ids[bs] = {next_token_ids[bs]}; } } - rets.push_back(output_ids); - return rets; + + generate_count++; + return {next_token_ids}; } -model_token Model::post_greedy_search(const float* logits) { - model_token id = std::max_element(logits, logits + n_vocab) - logits; - return id; +// std::vector> Model::generate_tokens(const std::vector>& input_ids) +// { +// int n_remain = params.n_predict; +// std::vector output_ids; +// std::vector> rets; + +// if (ctx->beam_search) { +// MODEL_ASSERT(input_ids.size() == ctx->batch_size); +// if (ctx->batch_size > 1 && ctx->vocab.pad_token_id == -1) { +// fprintf(stderr, "\nERROR: please set pad_token for beam search multi-batch generation!\n"); +// return rets; +// } +// std::vector inputs; +// for (int bs = 0; bs < input_ids.size(); ++bs) { +// uint32_t count = 0; +// model_vocab::id pad_token_id = ctx->vocab.pad_token_id; +// auto iter = std::find_if(input_ids[bs].begin(), input_ids[bs].end(), +// [&pad_token_id](model_token t) { return (t != pad_token_id); }); +// if (iter == input_ids[bs].end()) fprintf(stderr, "\nERROR: there are all pad tokens in batch %d!\n", bs); +// count = std::distance(input_ids[bs].begin(), iter); +// inputs.push_back(model_input{ +// /*.tokens =*/input_ids[bs].data(), +// /*.n_tokens =*/(uint32_t)input_ids[bs].size(), +// /*.n_prompt_tokens =*/0, +// /*.n_past =*/0, +// /*.n_total =*/0, +// /*.request_idx =*/bs, +// /*.beam_idx =*/0, +// /*.padding_side =*/0, +// /*n_padding =*/count, +// }); +// } +// return post_beam_search(ctx, n_remain, inputs, params.n_threads); +// } +// if (input_ids.size() > 1) { +// fprintf(stderr, "\nERROR: Only beam search supports multi-batch generation!\n"); +// return rets; +// } + +// if (curr_input_ids.empty()) { +// if (input_ids[0].size() > n_ctx - 4) { +// fprintf(stderr, "\n%s: Warning: prompt is too long (%d tokens, max %d), will be truncated\n", __func__, +// input_ids[0].size(), n_ctx - 4); +// curr_input_ids.resize(n_ctx - 4); +// std::copy(input_ids[0].end() - n_ctx - 4, input_ids[0].end(), curr_input_ids.begin()); +// } else { +// curr_input_ids = input_ids[0]; +// } +// } + +// while (output_ids.size() < n_remain) { +// for (auto item : curr_input_ids) { +// last_n_tokens.erase(last_n_tokens.begin()); +// last_n_tokens.push_back(item); +// } +// // infinite text generation via context swapping +// if (n_past + curr_input_ids.size() > n_ctx) { +// // always keep the first token +// n_past = std::max(1, params.n_keep); + +// int n_discard = params.n_discard; +// if (!params.shift_roped_k) { // shift_roped_k can use ring-buffer and thus does not need re-computing +// if (n_discard == -1) n_discard = (n_ctx - curr_input_ids.size() - params.n_keep) / 2; +// // drop n_discard tokens +// curr_input_ids.insert(curr_input_ids.begin(), last_n_tokens.begin() + params.n_keep + n_discard, +// last_n_tokens.end() - curr_input_ids.size()); +// } else { +// NE_ASSERT(("n_discard cannot be used with shift_roped_k!", n_discard == -1 || n_discard == 1)); +// } +// } +// std::vector inputs = {model_input{ +// /*.tokens =*/curr_input_ids.data(), +// /*.n_tokens =*/(uint32_t)curr_input_ids.size(), +// /*.n_prompt_tokens =*/0, +// /*.n_past =*/(uint32_t)n_past, +// /*.n_total =*/(uint32_t)n_total, +// /*.request_idx =*/0, +// /*.beam_idx =*/0, +// /*.padding_side =*/0, +// /*n_padding =*/0, +// }}; +// model_eval(ctx, 
inputs.data(), inputs.size(), params.n_threads); +// n_past += curr_input_ids.size(); +// n_total += curr_input_ids.size(); + +// float* logits = model_get_logits(ctx); +// model_token next_token_id = post_process(logits); +// curr_input_ids = {next_token_id}; +// output_ids.push_back(next_token_id); +// generate_count++; +// if (next_token_id == ctx->vocab.eos_token_id) { +// token_eos = true; +// break; +// } +// if (params.n_predict > 0 && generate_count >= params.n_predict) { +// token_eos = true; +// break; +// } +// } +// rets.push_back(output_ids); +// return rets; +// } + +std::vector Model::post_greedy_search(float* logits) { + std::vector ids(ctx->batch_size); +#pragma omp parallel for + for (int bs = 0; bs < ctx->batch_size; ++bs) { + ids[bs] = std::max_element(logits + bs * n_vocab, logits + (bs + 1) * n_vocab) - (logits + bs * n_vocab); + } + return ids; } std::vector> Model::post_beam_search(model_context* lctx, const int& n_predict, @@ -432,7 +530,7 @@ std::vector> Model::post_beam_search(model_context* lct } } -model_token Model::post_sample_top_k_top_p_repeat(const float* logits) { +std::vector Model::post_sample_top_k_top_p_repeat(float* logits) { int alpha_frequency = 0; int alpha_presence = 0; int repeat_last_n = 64; @@ -441,33 +539,38 @@ model_token Model::post_sample_top_k_top_p_repeat(const float* logits) { float typical_p = 1.00f; float top_p = params.top_p; float temp = params.temp; - std::vector candidates; - candidates.reserve(n_vocab); - for (model_token token_id = 0; token_id < n_vocab; token_id++) { - candidates.emplace_back(model_token_data{token_id, logits[token_id], 0.0f}); + std::vector ids(ctx->batch_size); +#pragma omp parallel for + for (int bs = 0; bs < ctx->batch_size; ++bs) { + std::vector candidates; + candidates.reserve(n_vocab); + for (model_token token_id = 0; token_id < n_vocab; token_id++) { + candidates.emplace_back(model_token_data{token_id, logits[bs * n_vocab + token_id], 0.0f}); + } + model_token_data_array candidates_p = {candidates.data(), candidates.size(), false}; + + // Apply penalties + float nl_logit = logits[bs * n_vocab + model_token_nl()]; + auto last_n_repeat = std::min(std::min((int)last_n_tokens[bs].size(), repeat_last_n), n_ctx); + model_sample_repetition_penalty(ctx, &candidates_p, + last_n_tokens[bs].data() + last_n_tokens[bs].size() - last_n_repeat, last_n_repeat, + params.repeat_penalty); + model_sample_frequency_and_presence_penalties(ctx, &candidates_p, + last_n_tokens[bs].data() + last_n_tokens[bs].size() - last_n_repeat, + last_n_repeat, alpha_frequency, alpha_presence); + // int id = model_sample_token_greedy(ctx, &candidates_p); + // Temperature sampling + model_sample_top_k(ctx, &candidates_p, top_k, 1); + model_sample_tail_free(ctx, &candidates_p, tfs_z, 1); + model_sample_typical(ctx, &candidates_p, typical_p, 1); + model_sample_top_p(ctx, &candidates_p, top_p, 1); + model_sample_temperature(ctx, &candidates_p, temp); + ids[bs] = model_sample_token(ctx, &candidates_p); } - model_token_data_array candidates_p = {candidates.data(), candidates.size(), false}; - - // Apply penalties - float nl_logit = logits[model_token_nl()]; - auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), repeat_last_n), n_ctx); - model_sample_repetition_penalty(ctx, &candidates_p, last_n_tokens.data() + last_n_tokens.size() - last_n_repeat, - last_n_repeat, params.repeat_penalty); - model_sample_frequency_and_presence_penalties(ctx, &candidates_p, - last_n_tokens.data() + last_n_tokens.size() - last_n_repeat, - last_n_repeat, 
alpha_frequency, alpha_presence); - // int id = model_sample_token_greedy(ctx, &candidates_p); - // Temperature sampling - model_sample_top_k(ctx, &candidates_p, top_k, 1); - model_sample_tail_free(ctx, &candidates_p, tfs_z, 1); - model_sample_typical(ctx, &candidates_p, typical_p, 1); - model_sample_top_p(ctx, &candidates_p, top_p, 1); - model_sample_temperature(ctx, &candidates_p, temp); - int id = model_sample_token(ctx, &candidates_p); - return id; + return ids; } -model_token Model::post_process(const float* logits) { +std::vector Model::post_process(float* logits) { assert(("Beam search does not support streaming.", params.beam_size == 1)); if (params.do_sample == false) { return post_greedy_search(logits); From d27119646ea8c36e62f4ac97f6f822527b8c1196 Mon Sep 17 00:00:00 2001 From: "Yu, Zhentao" Date: Mon, 20 Nov 2023 06:40:23 +0000 Subject: [PATCH 08/18] merge new main_pybind Signed-off-by: Yu, Zhentao --- .../runtime/graph/application/main_pybind.cpp | 378 ++++++++---------- 1 file changed, 160 insertions(+), 218 deletions(-) diff --git a/intel_extension_for_transformers/llm/runtime/graph/application/main_pybind.cpp b/intel_extension_for_transformers/llm/runtime/graph/application/main_pybind.cpp index 1b99563571a..ccec7feaf34 100644 --- a/intel_extension_for_transformers/llm/runtime/graph/application/main_pybind.cpp +++ b/intel_extension_for_transformers/llm/runtime/graph/application/main_pybind.cpp @@ -70,13 +70,11 @@ class Model { bool shift_roped_k, int batch_size, model_vocab::id pad_token, const std::string& memory_dtype); void reinit(); std::vector> generate(const std::vector>& input_ids); + // deprecated API std::vector> generate_tokens(const std::vector>& input_ids); const std::vector& evaluate_(const std::vector>& input_ids); py::array_t evaluate(const std::vector>& input_ids) { - if (input_ids.size() != 1) { - fprintf(stderr, "\nERROR: only support batch == 1 input!\n"); - return py::array_t(); - } + if (!check_input_and_count_padding(input_ids)) return py::array_t(); const auto& logits = evaluate_(input_ids); return py::array_t(logits.size(), logits.data()) .reshape({py::ssize_t(-1), static_cast(ctx->model.hparams.n_vocab)}); @@ -128,13 +126,15 @@ class Model { std::vector> last_n_tokens; bool token_eos = false; long int generate_count = 0; + std::vector padding_count; std::vector> beam_generate(const std::vector>& input_ids); - model_token post_process(const float* logits); - model_token post_greedy_search(const float* logits); + std::vector post_process(const float* logits); + std::vector post_greedy_search(const float* logits); std::vector> post_beam_search(model_context* lctx, const int& n_predict, const std::vector& inputs, const int& n_threads); - model_token post_sample_top_k_top_p_repeat(const float* logits); + std::vector post_sample_top_k_top_p_repeat(const float* logits); + bool check_input_and_count_padding(const std::vector>& input_ids); }; void Model::init_model(const std::string& model_path, int max_new_tokens, int n_batch, int ctx_size, int seed, @@ -208,22 +208,40 @@ void Model::reinit() { ctx->n_sample = 0; ctx->t_sample_us = 0; generate_count = 0; + padding_count.clear(); } -std::vector> Model::beam_generate(const std::vector>& input_ids) { +bool Model::check_input_and_count_padding(const std::vector>& input_ids) { + if (input_ids.empty()) return false; + if (input_ids.size() == 1) { + padding_count = {0}; + return true; + } + // multi-batch inputs MODEL_ASSERT(input_ids.size() == ctx->batch_size); - if (ctx->batch_size > 1 && ctx->vocab.pad_token_id 
== -1) { - fprintf(stderr, "\nERROR: please set pad_token for beam search multi-batch generation!\n"); - return {}; + static std::set batched_model_archs = {MODEL_GPTJ, MODEL_GPTNEOX, MODEL_CHATGLM}; + if (batched_model_archs.count(params.model_arch) == 0) { + fprintf(stderr, "\nERROR: Only gpt-j, gpt-neox, chatglm support multi-batch generation!\n"); + return false; } - std::vector inputs; + if (ctx->vocab.pad_token_id == -1) { + fprintf(stderr, "\nERROR: please set pad_token for static multi-batch generation (tokenizer.pad_token_id)!\n"); + return false; + } + if (!padding_count.empty()) padding_count.clear(); for (int bs = 0; bs < input_ids.size(); ++bs) { - uint32_t count = 0; model_vocab::id pad_token_id = ctx->vocab.pad_token_id; auto iter = std::find_if(input_ids[bs].begin(), input_ids[bs].end(), [&pad_token_id](model_token t) { return (t != pad_token_id); }); if (iter == input_ids[bs].end()) fprintf(stderr, "\nERROR: there are all pad tokens in batch %d!\n", bs); - count = std::distance(input_ids[bs].begin(), iter); + padding_count.push_back(std::distance(input_ids[bs].begin(), iter)); + } + return true; +} + +std::vector> Model::beam_generate(const std::vector>& input_ids) { + std::vector inputs; + for (int bs = 0; bs < input_ids.size(); ++bs) { inputs.push_back(model_input{ /*.tokens =*/input_ids[bs].data(), /*.n_tokens =*/(uint32_t)input_ids[bs].size(), @@ -233,7 +251,7 @@ std::vector> Model::beam_generate(const std::vector> Model::beam_generate(const std::vector& Model::evaluate_(const std::vector>& input_ids) { static const std::vector empty_ret{}; - if (input_ids.size() > 1) { - fprintf(stderr, "\nERROR: Only beam search supports multi-batch generation!\n"); - return empty_ret; - } static const std::vector empty_id{}; - const auto& input_id0 = input_ids.empty() ? empty_id : input_ids[0]; // currently only support single batch - if (input_id0.empty()) { // use internel input id - if (curr_input_ids.empty()) { - fprintf(stderr, "%s: error: no input\n", __func__); + std::vector inputs; + for (int bs = 0; bs < input_ids.size(); ++bs) { + const auto& input_id_cb = input_ids.empty() ? 
empty_id : input_ids[bs]; + if (input_id_cb.empty()) { // use internel input id + if (curr_input_ids[bs].empty()) { + fprintf(stderr, "%s: error: no input\n", __func__); + return empty_ret; + } + } else if (!curr_input_ids[bs].empty()) { + fprintf(stderr, "%s: error: prompt confliction\n", __func__); return empty_ret; + } else if (input_id_cb.size() > n_ctx - 4) { // long input_id_cb and empty curr_input_ids[bs] + fprintf(stderr, "\n%s: Warning: prompt is too long (%d tokens, max %d), will be truncated\n", __func__, + input_id_cb.size(), n_ctx - 4); + curr_input_ids[bs].resize(n_ctx - 4); + std::copy(input_id_cb.end() - n_ctx - 4, input_id_cb.end(), curr_input_ids[bs].begin()); + } else { // good input_id_cb and empty curr_input_ids[bs] + curr_input_ids[bs] = input_id_cb; } - } else if (!curr_input_ids.empty()) { - fprintf(stderr, "%s: error: prompt confliction\n", __func__); - return empty_ret; - } else if (input_id0.size() > n_ctx - 4) { // long input_id0 and empty curr_input_ids - fprintf(stderr, "\n%s: Warning: prompt is too long (%d tokens, max %d), will be truncated\n", __func__, - input_id0.size(), n_ctx - 4); - curr_input_ids.resize(n_ctx - 4); - std::copy(input_id0.end() - n_ctx - 4, input_id0.end(), curr_input_ids.begin()); - } else { // good input_id0 and empty curr_input_ids - curr_input_ids = input_id0; - } - // push elements in curr_input_ids to the last_n_tokens queue - last_n_tokens.erase(last_n_tokens.begin(), last_n_tokens.begin() + curr_input_ids.size()); - last_n_tokens.insert(last_n_tokens.end(), curr_input_ids.begin(), curr_input_ids.end()); - - // infinite text generation via context swapping - if (n_past + curr_input_ids.size() > n_ctx) { - // always keep the first token - n_past = std::max(1, params.n_keep); - - int n_discard = params.n_discard; - if (!params.shift_roped_k) { // shift_roped_k can use ring-buffer and thus does not need re-computing - if (n_discard == -1) n_discard = (n_ctx - curr_input_ids.size() - params.n_keep) / 2; - // drop n_discard tokens - curr_input_ids.insert(curr_input_ids.begin(), last_n_tokens.begin() + params.n_keep + n_discard, - last_n_tokens.end() - curr_input_ids.size()); - } else { - NE_ASSERT(("n_discard cannot be used with shift_roped_k!", n_discard == -1 || n_discard == 1)); + // push elements in curr_input_ids[bs] to the last_n_tokens[bs] queue + last_n_tokens[bs].erase(last_n_tokens[bs].begin(), last_n_tokens[bs].begin() + curr_input_ids[bs].size()); + last_n_tokens[bs].insert(last_n_tokens[bs].end(), curr_input_ids[bs].begin(), curr_input_ids[bs].end()); + + // infinite text generation via context swapping + if (n_past + curr_input_ids[bs].size() > n_ctx) { + // always keep the first token + n_past = std::max(1, params.n_keep); + + int n_discard = params.n_discard; + if (!params.shift_roped_k) { // shift_roped_k can use ring-buffer and thus does not need re-computing + if (n_discard == -1) n_discard = (n_ctx - curr_input_ids[bs].size() - params.n_keep) / 2; + // drop n_discard tokens + curr_input_ids[bs].insert(curr_input_ids[bs].begin(), last_n_tokens[bs].begin() + params.n_keep + n_discard, + last_n_tokens[bs].end() - curr_input_ids[bs].size()); + } else { + NE_ASSERT(("n_discard cannot be used with shift_roped_k!", n_discard == -1 || n_discard == 1)); + } } - } - std::vector inputs{{ - /*.tokens =*/curr_input_ids.data(), - /*.n_tokens =*/(uint32_t)curr_input_ids.size(), - /*.n_prompt_tokens =*/0, - /*.n_past =*/(uint32_t)n_past, - /*.n_total =*/(uint32_t)n_total, - /*.request_idx =*/0, - /*.beam_idx =*/0, - /*.padding_side 
=*/0, - /*n_padding =*/0, - }}; + inputs.push_back({ + /*.tokens =*/curr_input_ids[bs].data(), + /*.n_tokens =*/(uint32_t)curr_input_ids[bs].size(), + /*.n_prompt_tokens =*/0, + /*.n_past =*/(uint32_t)n_past, + /*.n_total =*/(uint32_t)n_total, + /*.request_idx =*/bs, + /*.beam_idx =*/0, + /*.padding_side =*/0, + /*n_padding =*/padding_count[bs], + }); + } model_eval(ctx, inputs.data(), inputs.size(), params.n_threads); - n_past += curr_input_ids.size(); - n_total += curr_input_ids.size(); + // static batching inference should have same input length and context window length + n_past += curr_input_ids[0].size(); + n_total += curr_input_ids[0].size(); curr_input_ids.clear(); // add new tok to curr_input_ids if necessary after post processing return ctx->logits; } std::vector> Model::generate(const std::vector>& input_ids) { + if (!check_input_and_count_padding(input_ids)) return {}; if (ctx->beam_search) return beam_generate(input_ids); - if (input_ids.size() > 1) { - fprintf(stderr, "\nERROR: Only beam search supports multi-batch generation!\n"); - return {}; - } const auto& logits = evaluate_(input_ids); if (logits.empty()) return {}; - model_token next_token_id = post_process(logits.data()); - curr_input_ids = {next_token_id}; + std::vector next_token_ids = post_process(logits.data()); + std::vector> ret_next_tokens; + for (int bs = 0; bs < next_token_ids.size(); ++bs) { + // padding eos seq for continuous batched kv cache + // TODO batch reduction after for-loop attention implementation + if (curr_input_ids[bs].back() == ctx->vocab.eos_token_id || curr_input_ids[bs].back() == ctx->vocab.pad_token_id) { + curr_input_ids[bs] = {ctx->vocab.pad_token_id}; + ret_next_tokens.push_back({ctx->vocab.pad_token_id}); + } else { + curr_input_ids[bs] = {next_token_ids[bs]}; + ret_next_tokens.push_back({next_token_ids[bs]}); + } + } generate_count++; - return {{next_token_id}}; + return ret_next_tokens; } +// deprecated API std::vector> Model::generate_tokens(const std::vector>& input_ids) { int n_remain = params.n_predict; std::vector output_ids; std::vector> rets; if (ctx->beam_search) { + MODEL_ASSERT(input_ids.size() == ctx->batch_size); + if (ctx->batch_size > 1 && ctx->vocab.pad_token_id == -1) { + fprintf(stderr, "\nERROR: please set pad_token for beam search multi-batch generation!\n"); + return rets; + } + std::vector inputs; for (int bs = 0; bs < input_ids.size(); ++bs) { + uint32_t count = 0; + model_vocab::id pad_token_id = ctx->vocab.pad_token_id; + auto iter = std::find_if(input_ids[bs].begin(), input_ids[bs].end(), + [&pad_token_id](model_token t) { return (t != pad_token_id); }); + if (iter == input_ids[bs].end()) fprintf(stderr, "\nERROR: there are all pad tokens in batch %d!\n", bs); + count = std::distance(input_ids[bs].begin(), iter); inputs.push_back(model_input{ /*.tokens =*/input_ids[bs].data(), /*.n_tokens =*/(uint32_t)input_ids[bs].size(), @@ -336,45 +375,50 @@ std::vector> Model::generate_tokens(const std::vector n_ctx - 4) { - fprintf(stderr, "\n%s: Warning: prompt is too long (%d tokens, max %d), will be truncated\n", __func__, - input_ids[bs].size(), n_ctx - 4); - curr_input_ids[bs].resize(n_ctx - 4); - std::copy(input_ids[bs].end() - n_ctx - 4, input_ids[bs].end(), curr_input_ids[bs].begin()); - } else { - curr_input_ids[bs] = input_ids[bs]; - } + if (input_ids.size() > 1) { + fprintf(stderr, "\nERROR: Only beam search supports multi-batch generation!\n"); + return rets; + } + + if (curr_input_ids[0].empty()) { + if (input_ids[0].size() > n_ctx - 4) { + fprintf(stderr, 
"\n%s: Warning: prompt is too long (%d tokens, max %d), will be truncated\n", __func__, + input_ids[0].size(), n_ctx - 4); + curr_input_ids[0].resize(n_ctx - 4); + std::copy(input_ids[0].end() - n_ctx - 4, input_ids[0].end(), curr_input_ids[0].begin()); + } else { + curr_input_ids[0] = input_ids[0]; } + } - for (auto item : curr_input_ids[bs]) { - last_n_tokens[bs].erase(last_n_tokens[bs].begin()); - last_n_tokens[bs].push_back(item); + while (output_ids.size() < n_remain) { + for (auto item : curr_input_ids[0]) { + last_n_tokens[0].erase(last_n_tokens[0].begin()); + last_n_tokens[0].push_back(item); } // infinite text generation via context swapping - if (n_past + curr_input_ids[bs].size() > n_ctx) { + if (n_past + curr_input_ids[0].size() > n_ctx) { // always keep the first token n_past = std::max(1, params.n_keep); int n_discard = params.n_discard; if (!params.shift_roped_k) { // shift_roped_k can use ring-buffer and thus does not need re-computing - if (n_discard == -1) n_discard = (n_ctx - curr_input_ids[bs].size() - params.n_keep) / 2; + if (n_discard == -1) n_discard = (n_ctx - curr_input_ids[0].size() - params.n_keep) / 2; // drop n_discard tokens - curr_input_ids[bs].insert(curr_input_ids[bs].begin(), last_n_tokens[bs].begin() + params.n_keep + n_discard, - last_n_tokens[bs].end() - curr_input_ids[bs].size()); + curr_input_ids[0].insert(curr_input_ids[0].begin(), last_n_tokens[0].begin() + params.n_keep + n_discard, + last_n_tokens[0].end() - curr_input_ids[0].size()); } else { NE_ASSERT(("n_discard cannot be used with shift_roped_k!", n_discard == -1 || n_discard == 1)); } } - inputs.push_back(model_input{ - /*.tokens =*/curr_input_ids[bs].data(), - /*.n_tokens =*/(uint32_t)curr_input_ids[bs].size(), + std::vector inputs = {model_input{ + /*.tokens =*/curr_input_ids[0].data(), + /*.n_tokens =*/(uint32_t)curr_input_ids[0].size(), /*.n_prompt_tokens =*/0, /*.n_past =*/(uint32_t)n_past, /*.n_total =*/(uint32_t)n_total, @@ -382,133 +426,30 @@ std::vector> Model::generate_tokens(const std::vector next_token_ids = post_process(logits); - for (int bs = 0; bs < next_token_ids.size(); ++bs) { - // padding eos seq for continuous batched kv cache - // TODO batch reduction after for-loop attention implementation - if (curr_input_ids[bs].back() == ctx->vocab.eos_token_id || curr_input_ids[bs].back() == ctx->vocab.pad_token_id) { - curr_input_ids[bs] = {ctx->vocab.pad_token_id}; - } else { - curr_input_ids[bs] = {next_token_ids[bs]}; + }}; + model_eval(ctx, inputs.data(), inputs.size(), params.n_threads); + n_past += curr_input_ids[0].size(); + n_total += curr_input_ids[0].size(); + + float* logits = model_get_logits(ctx); + std::vector next_token_id = post_process(logits); + curr_input_ids[0] = {next_token_id[0]}; + output_ids.push_back(next_token_id[0]); + generate_count++; + if (next_token_id[0] == ctx->vocab.eos_token_id) { + token_eos = true; + break; + } + if (params.n_predict > 0 && generate_count >= params.n_predict) { + token_eos = true; + break; } } - - generate_count++; - return {next_token_ids}; + rets.push_back(output_ids); + return rets; } -// std::vector> Model::generate_tokens(const std::vector>& input_ids) -// { -// int n_remain = params.n_predict; -// std::vector output_ids; -// std::vector> rets; - -// if (ctx->beam_search) { -// MODEL_ASSERT(input_ids.size() == ctx->batch_size); -// if (ctx->batch_size > 1 && ctx->vocab.pad_token_id == -1) { -// fprintf(stderr, "\nERROR: please set pad_token for beam search multi-batch generation!\n"); -// return rets; -// } -// std::vector 
inputs; -// for (int bs = 0; bs < input_ids.size(); ++bs) { -// uint32_t count = 0; -// model_vocab::id pad_token_id = ctx->vocab.pad_token_id; -// auto iter = std::find_if(input_ids[bs].begin(), input_ids[bs].end(), -// [&pad_token_id](model_token t) { return (t != pad_token_id); }); -// if (iter == input_ids[bs].end()) fprintf(stderr, "\nERROR: there are all pad tokens in batch %d!\n", bs); -// count = std::distance(input_ids[bs].begin(), iter); -// inputs.push_back(model_input{ -// /*.tokens =*/input_ids[bs].data(), -// /*.n_tokens =*/(uint32_t)input_ids[bs].size(), -// /*.n_prompt_tokens =*/0, -// /*.n_past =*/0, -// /*.n_total =*/0, -// /*.request_idx =*/bs, -// /*.beam_idx =*/0, -// /*.padding_side =*/0, -// /*n_padding =*/count, -// }); -// } -// return post_beam_search(ctx, n_remain, inputs, params.n_threads); -// } -// if (input_ids.size() > 1) { -// fprintf(stderr, "\nERROR: Only beam search supports multi-batch generation!\n"); -// return rets; -// } - -// if (curr_input_ids.empty()) { -// if (input_ids[0].size() > n_ctx - 4) { -// fprintf(stderr, "\n%s: Warning: prompt is too long (%d tokens, max %d), will be truncated\n", __func__, -// input_ids[0].size(), n_ctx - 4); -// curr_input_ids.resize(n_ctx - 4); -// std::copy(input_ids[0].end() - n_ctx - 4, input_ids[0].end(), curr_input_ids.begin()); -// } else { -// curr_input_ids = input_ids[0]; -// } -// } - -// while (output_ids.size() < n_remain) { -// for (auto item : curr_input_ids) { -// last_n_tokens.erase(last_n_tokens.begin()); -// last_n_tokens.push_back(item); -// } -// // infinite text generation via context swapping -// if (n_past + curr_input_ids.size() > n_ctx) { -// // always keep the first token -// n_past = std::max(1, params.n_keep); - -// int n_discard = params.n_discard; -// if (!params.shift_roped_k) { // shift_roped_k can use ring-buffer and thus does not need re-computing -// if (n_discard == -1) n_discard = (n_ctx - curr_input_ids.size() - params.n_keep) / 2; -// // drop n_discard tokens -// curr_input_ids.insert(curr_input_ids.begin(), last_n_tokens.begin() + params.n_keep + n_discard, -// last_n_tokens.end() - curr_input_ids.size()); -// } else { -// NE_ASSERT(("n_discard cannot be used with shift_roped_k!", n_discard == -1 || n_discard == 1)); -// } -// } -// std::vector inputs = {model_input{ -// /*.tokens =*/curr_input_ids.data(), -// /*.n_tokens =*/(uint32_t)curr_input_ids.size(), -// /*.n_prompt_tokens =*/0, -// /*.n_past =*/(uint32_t)n_past, -// /*.n_total =*/(uint32_t)n_total, -// /*.request_idx =*/0, -// /*.beam_idx =*/0, -// /*.padding_side =*/0, -// /*n_padding =*/0, -// }}; -// model_eval(ctx, inputs.data(), inputs.size(), params.n_threads); -// n_past += curr_input_ids.size(); -// n_total += curr_input_ids.size(); - -// float* logits = model_get_logits(ctx); -// model_token next_token_id = post_process(logits); -// curr_input_ids = {next_token_id}; -// output_ids.push_back(next_token_id); -// generate_count++; -// if (next_token_id == ctx->vocab.eos_token_id) { -// token_eos = true; -// break; -// } -// if (params.n_predict > 0 && generate_count >= params.n_predict) { -// token_eos = true; -// break; -// } -// } -// rets.push_back(output_ids); -// return rets; -// } - -std::vector Model::post_greedy_search(float* logits) { +std::vector Model::post_greedy_search(const float* logits) { std::vector ids(ctx->batch_size); #pragma omp parallel for for (int bs = 0; bs < ctx->batch_size; ++bs) { @@ -530,7 +471,7 @@ std::vector> Model::post_beam_search(model_context* lct } } -std::vector 
Model::post_sample_top_k_top_p_repeat(float* logits) { +std::vector Model::post_sample_top_k_top_p_repeat(const float* logits) { int alpha_frequency = 0; int alpha_presence = 0; int repeat_last_n = 64; @@ -570,7 +511,7 @@ std::vector Model::post_sample_top_k_top_p_repeat(float* logits) { return ids; } -std::vector Model::post_process(float* logits) { +std::vector Model::post_process(const float* logits) { assert(("Beam search does not support streaming.", params.beam_size == 1)); if (params.do_sample == false) { return post_greedy_search(logits); @@ -738,6 +679,7 @@ PYBIND11_MODULE(qwen_cpp, m) .def("generate", &Model::generate, "Generate token with input ids", py::arg("input_ids")) .def("evaluate", &Model::evaluate, "Evaluate token with input ids and output logits", py::arg("input_ids") = std::vector>{}) + // deprecated API .def("generate_tokens", &Model::generate_tokens, "Generate tokens with input ids", py::arg("input_ids")) .def_static("quant_model", &Model::quant_model, "Quantize model", py::arg("model_path"), py::arg("out_path"), py::arg("weight_dtype") = "int4", py::arg("alg") = "sym", py::arg("group_size") = 32, From 0fb170cbcd70351bff42e0553fcbbbd786a9756b Mon Sep 17 00:00:00 2001 From: "Yu, Zhentao" Date: Mon, 20 Nov 2023 08:24:19 +0000 Subject: [PATCH 09/18] fix batch infernce after merge Signed-off-by: Yu, Zhentao --- .../runtime/graph/application/main_pybind.cpp | 55 ++++++++++--------- 1 file changed, 30 insertions(+), 25 deletions(-) diff --git a/intel_extension_for_transformers/llm/runtime/graph/application/main_pybind.cpp b/intel_extension_for_transformers/llm/runtime/graph/application/main_pybind.cpp index ccec7feaf34..6ab5cef07a2 100644 --- a/intel_extension_for_transformers/llm/runtime/graph/application/main_pybind.cpp +++ b/intel_extension_for_transformers/llm/runtime/graph/application/main_pybind.cpp @@ -212,31 +212,36 @@ void Model::reinit() { } bool Model::check_input_and_count_padding(const std::vector>& input_ids) { - if (input_ids.empty()) return false; - if (input_ids.size() == 1) { + if (input_ids.empty()) { // next token generation (internal) + if (curr_input_ids.empty()) { + fprintf(stderr, "%s: error: no input\n", __func__); + return false; + } + return true; + } else if (input_ids.size() == 1) { padding_count = {0}; return true; + } else { // multi-batch inputs (first token) + MODEL_ASSERT(input_ids.size() == ctx->batch_size); + static std::set batched_model_archs = {MODEL_GPTJ, MODEL_GPTNEOX, MODEL_CHATGLM}; + if (batched_model_archs.count(params.model_arch) == 0) { + fprintf(stderr, "\nERROR: Only gpt-j, gpt-neox, chatglm support multi-batch generation!\n"); + return false; + } + if (ctx->vocab.pad_token_id == -1) { + fprintf(stderr, "\nERROR: please set pad_token for static multi-batch generation (tokenizer.pad_token_id)!\n"); + return false; + } + if (!padding_count.empty()) padding_count.clear(); + for (int bs = 0; bs < input_ids.size(); ++bs) { + model_vocab::id pad_token_id = ctx->vocab.pad_token_id; + auto iter = std::find_if(input_ids[bs].begin(), input_ids[bs].end(), + [&pad_token_id](model_token t) { return (t != pad_token_id); }); + if (iter == input_ids[bs].end()) fprintf(stderr, "\nERROR: there are all pad tokens in batch %d!\n", bs); + padding_count.push_back(std::distance(input_ids[bs].begin(), iter)); + } + return true; } - // multi-batch inputs - MODEL_ASSERT(input_ids.size() == ctx->batch_size); - static std::set batched_model_archs = {MODEL_GPTJ, MODEL_GPTNEOX, MODEL_CHATGLM}; - if (batched_model_archs.count(params.model_arch) == 0) { - 
fprintf(stderr, "\nERROR: Only gpt-j, gpt-neox, chatglm support multi-batch generation!\n"); - return false; - } - if (ctx->vocab.pad_token_id == -1) { - fprintf(stderr, "\nERROR: please set pad_token for static multi-batch generation (tokenizer.pad_token_id)!\n"); - return false; - } - if (!padding_count.empty()) padding_count.clear(); - for (int bs = 0; bs < input_ids.size(); ++bs) { - model_vocab::id pad_token_id = ctx->vocab.pad_token_id; - auto iter = std::find_if(input_ids[bs].begin(), input_ids[bs].end(), - [&pad_token_id](model_token t) { return (t != pad_token_id); }); - if (iter == input_ids[bs].end()) fprintf(stderr, "\nERROR: there are all pad tokens in batch %d!\n", bs); - padding_count.push_back(std::distance(input_ids[bs].begin(), iter)); - } - return true; } std::vector> Model::beam_generate(const std::vector>& input_ids) { @@ -262,7 +267,7 @@ const std::vector& Model::evaluate_(const std::vector empty_id{}; std::vector inputs; - for (int bs = 0; bs < input_ids.size(); ++bs) { + for (int bs = 0; bs < ctx->batch_size; ++bs) { const auto& input_id_cb = input_ids.empty() ? empty_id : input_ids[bs]; if (input_id_cb.empty()) { // use internel input id if (curr_input_ids[bs].empty()) { @@ -318,7 +323,6 @@ const std::vector& Model::evaluate_(const std::vectorlogits; } @@ -330,6 +334,7 @@ std::vector> Model::generate(const std::vector next_token_ids = post_process(logits.data()); + MODEL_ASSERT(next_token_ids.size() == ctx->batch_size); std::vector> ret_next_tokens; for (int bs = 0; bs < next_token_ids.size(); ++bs) { // padding eos seq for continuous batched kv cache @@ -679,7 +684,7 @@ PYBIND11_MODULE(qwen_cpp, m) .def("generate", &Model::generate, "Generate token with input ids", py::arg("input_ids")) .def("evaluate", &Model::evaluate, "Evaluate token with input ids and output logits", py::arg("input_ids") = std::vector>{}) - // deprecated API + // deprecated API .def("generate_tokens", &Model::generate_tokens, "Generate tokens with input ids", py::arg("input_ids")) .def_static("quant_model", &Model::quant_model, "Quantize model", py::arg("model_path"), py::arg("out_path"), py::arg("weight_dtype") = "int4", py::arg("alg") = "sym", py::arg("group_size") = 32, From d9ddc43470e81de2c3a38fc1852efb4a64143286 Mon Sep 17 00:00:00 2001 From: "Yu, Zhentao" Date: Mon, 20 Nov 2023 09:02:41 +0000 Subject: [PATCH 10/18] fix chatglm convert Signed-off-by: Yu, Zhentao --- .../llm/runtime/graph/scripts/convert_chatglm.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/intel_extension_for_transformers/llm/runtime/graph/scripts/convert_chatglm.py b/intel_extension_for_transformers/llm/runtime/graph/scripts/convert_chatglm.py index 245a80fd676..dfa7aa9ec03 100644 --- a/intel_extension_for_transformers/llm/runtime/graph/scripts/convert_chatglm.py +++ b/intel_extension_for_transformers/llm/runtime/graph/scripts/convert_chatglm.py @@ -262,10 +262,10 @@ def chatglm1_convert(model, tokenizer, dir_model, fname_out, ftype, hparams): fout.write(struct.pack("i", 0)) fout.write(struct.pack("i", hparams["inner_hidden_size"])) - fout.write(struct.pack("i", int(hparams.get("bos_token_id", -1)))) - fout.write(struct.pack("i", int(hparams.get("eos_token_id", -1)))) - fout.write(struct.pack("i", 0)) - fout.write(struct.pack("i", 0)) + fout.write(struct.pack("i", tokenizer.bos_token_id if tokenizer.bos_token_id is not None else -1)) + fout.write(struct.pack("i", tokenizer.eos_token_id if tokenizer.eos_token_id is not None else -1)) + fout.write(struct.pack("i", tokenizer.pad_token_id if 
tokenizer.pad_token_id is not None else -1)) + fout.write(struct.pack("i", tokenizer.sep_token_id if tokenizer.sep_token_id is not None else -1)) vocab = load_vocab_for_glm1(Path(dir_model)) counter = 0 From b0e5339080b847d4b9e5add60c7bfd049fce0230 Mon Sep 17 00:00:00 2001 From: "Yu, Zhentao" Date: Wed, 22 Nov 2023 03:01:59 +0000 Subject: [PATCH 11/18] ne_cont for q4j Signed-off-by: Yu, Zhentao --- .../llm/runtime/graph/models/chatglm/chatglm.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/intel_extension_for_transformers/llm/runtime/graph/models/chatglm/chatglm.cpp b/intel_extension_for_transformers/llm/runtime/graph/models/chatglm/chatglm.cpp index 763dd7db137..c27dd1443c7 100644 --- a/intel_extension_for_transformers/llm/runtime/graph/models/chatglm/chatglm.cpp +++ b/intel_extension_for_transformers/llm/runtime/graph/models/chatglm/chatglm.cpp @@ -282,8 +282,8 @@ static bool chatglm_model_eval_internal(model_context& lctx, const model_input* lctx.use_buf(ctx0, -1); if (!lctx.logits_all && qlen > 1) { - inpL = ne_view_2d(ctx0, inpL, hidden_size, batch_size, ne_element_size(inpL) * hidden_size * N, - (N - 1) * hidden_size * ne_element_size(inpL)); + inpL = ne_cont(ctx0, ne_view_2d(ctx0, inpL, hidden_size, batch_size, ne_element_size(inpL) * hidden_size * N, + (N - 1) * hidden_size * ne_element_size(inpL))); } // lm_head inpL = ne_mul_mat(ctx0, model.others[3], inpL); From 30eaf13f3e6dbadd07f141ee80f8062fa1c885b2 Mon Sep 17 00:00:00 2001 From: "Yu, Zhentao" Date: Fri, 24 Nov 2023 08:34:12 +0000 Subject: [PATCH 12/18] fix chatglm rope embd with padding Signed-off-by: Yu, Zhentao --- .../runtime/graph/application/main_pybind.cpp | 7 +++- .../llm/runtime/graph/core/ne_layers.c | 39 +++++++++++++++---- .../llm/runtime/graph/core/ne_layers.h | 7 ++++ .../runtime/graph/models/chatglm/chatglm.cpp | 16 +++----- .../graph/models/model_utils/model_utils.cpp | 2 + 5 files changed, 53 insertions(+), 18 deletions(-) diff --git a/intel_extension_for_transformers/llm/runtime/graph/application/main_pybind.cpp b/intel_extension_for_transformers/llm/runtime/graph/application/main_pybind.cpp index 6ab5cef07a2..4e04c8eca8e 100644 --- a/intel_extension_for_transformers/llm/runtime/graph/application/main_pybind.cpp +++ b/intel_extension_for_transformers/llm/runtime/graph/application/main_pybind.cpp @@ -127,6 +127,7 @@ class Model { bool token_eos = false; long int generate_count = 0; std::vector padding_count; + uint32_t n_prompt_tokens = 0; std::vector> beam_generate(const std::vector>& input_ids); std::vector post_process(const float* logits); @@ -209,6 +210,7 @@ void Model::reinit() { ctx->t_sample_us = 0; generate_count = 0; padding_count.clear(); + n_prompt_tokens = 0; } bool Model::check_input_and_count_padding(const std::vector>& input_ids) { @@ -220,6 +222,7 @@ bool Model::check_input_and_count_padding(const std::vectorbatch_size); @@ -240,6 +243,8 @@ bool Model::check_input_and_count_padding(const std::vector& Model::evaluate_(const std::vector= 0 || n_keep >= 0); + NE_ASSERT(padding_left); bool is_node = false; if (!inplace && a->grad) { @@ -2979,13 +2981,22 @@ struct ne_tensor* ne_rope_impl(struct ne_context* ctx, struct ne_tensor* a, int ne_scratch_save(ctx); - struct ne_tensor* b = ne_new_tensor_1d(ctx, NE_TYPE_I32, 5, NE_SIZE_CALC); + const int bs = a->ne[3]; + struct ne_tensor* b = ne_new_tensor_1d(ctx, NE_TYPE_I32, 5 + bs, NE_SIZE_CALC); ((int32_t*)b->data)[0] = n_past; ((int32_t*)b->data)[1] = n_dims; ((int32_t*)b->data)[2] = mode; ((int32_t*)b->data)[3] = prompt_size; 
((int32_t*)b->data)[4] = n_keep; // set to non-negative value to enable shift mode + // store n_padding (chatglm position ids) + for (int i = 0; i < bs; ++i) { + if (n_padding == NULL) { + ((int32_t*)b->data)[5 + i] = 0; + } else { + ((int32_t*)b->data)[5 + i] = *(n_padding + i); + } + } ne_scratch_load(ctx); @@ -3000,17 +3011,17 @@ struct ne_tensor* ne_rope_impl(struct ne_context* ctx, struct ne_tensor* a, int struct ne_tensor* ne_rope(struct ne_context* ctx, struct ne_tensor* a, int n_past, int n_dims, int mode, int prompt_size) { - return ne_rope_impl(ctx, a, n_past, n_dims, mode, prompt_size, false, -1, NULL); + return ne_rope_impl(ctx, a, n_past, n_dims, mode, prompt_size, false, -1, NULL, NULL, true); } struct ne_tensor* ne_rope_inplace(struct ne_context* ctx, struct ne_tensor* a, int n_past, int n_dims, int mode, int prompt_size) { - return ne_rope_impl(ctx, a, n_past, n_dims, mode, prompt_size, true, -1, NULL); + return ne_rope_impl(ctx, a, n_past, n_dims, mode, prompt_size, true, -1, NULL, NULL, true); } struct ne_tensor* ne_rope_shift_inplace(struct ne_context* ctx, struct ne_tensor* a, int n_shift, int n_dims, int mode, int prompt_size, int n_keep, struct ne_tensor* cossin) { - return ne_rope_impl(ctx, a, n_shift, n_dims, mode, prompt_size, true, n_keep, cossin); + return ne_rope_impl(ctx, a, n_shift, n_dims, mode, prompt_size, true, n_keep, cossin, NULL, true); } // ne_rope_back @@ -3045,6 +3056,16 @@ struct ne_tensor* ne_rope_back(struct ne_context* ctx, struct ne_tensor* a, int return result; } +struct ne_tensor* ne_rope_with_padding(struct ne_context* ctx, struct ne_tensor* a, int n_past, int n_dims, int mode, + int prompt_size, int* n_padding) { + return ne_rope_impl(ctx, a, n_past, n_dims, mode, prompt_size, false, -1, NULL, n_padding, true); +} + +struct ne_tensor* ne_rope_with_padding_inplace(struct ne_context* ctx, struct ne_tensor* a, int n_past, int n_dims, + int mode, int prompt_size, int* n_padding) { + return ne_rope_impl(ctx, a, n_past, n_dims, mode, prompt_size, true, -1, NULL, n_padding, true); +} + // ne_alibi struct ne_tensor* ne_alibi(struct ne_context* ctx, struct ne_tensor* a, int n_past, int n_head, float bias_max) { @@ -7835,8 +7856,9 @@ static void ne_compute_forward_rope_f32(const struct ne_compute_params* params, if (params->type == NE_TASK_INIT || params->type == NE_TASK_FINALIZE) { return; } + const int bs = src0->ne[3]; NE_ASSERT(src1->type == NE_TYPE_I32); - NE_ASSERT(ne_nelements(src1) == 5); // 5 params + NE_ASSERT(ne_nelements(src1) == 5 + bs); // 5 + bs params static const float freq_base = 10000.0f; static const float freq_scale = 1.0f; @@ -7879,6 +7901,7 @@ static void ne_compute_forward_rope_f32(const struct ne_compute_params* params, const bool is_shift = n_keep >= 0; NE_ASSERT(("RoPE shift not supported!", !is_shift)); + NE_ASSERT(ne3 == bs); for (int64_t i3 = 0; i3 < ne3; i3++) { for (int64_t i2 = (skip ? n_past : 0); i2 < ne2; i2++) { const int64_t p = skip ? 
i2 : n_past + i2; @@ -7890,7 +7913,9 @@ static void ne_compute_forward_rope_f32(const struct ne_compute_params* params, // only for glm when mode == 4 if (is_glm) { - theta = MIN(p, prompt_size - 2); + const int64_t n_padding = ((int32_t*)src1->data)[5 + i3]; + // position ids + theta = MIN(MAX(p - n_padding, 0), prompt_size - 2 - n_padding); float block_theta = MAX(p - (prompt_size - 2), 0); for (int64_t i0 = 0; i0 < ne0 / 4; i0++) { const float cos_theta = cosf(theta); diff --git a/intel_extension_for_transformers/llm/runtime/graph/core/ne_layers.h b/intel_extension_for_transformers/llm/runtime/graph/core/ne_layers.h index 193c22f7e0d..731680f2ad6 100644 --- a/intel_extension_for_transformers/llm/runtime/graph/core/ne_layers.h +++ b/intel_extension_for_transformers/llm/runtime/graph/core/ne_layers.h @@ -418,6 +418,13 @@ NE_API struct ne_tensor* ne_rope_shift_inplace(struct ne_context* ctx, struct ne // a - dy NE_API struct ne_tensor* ne_rope_back(struct ne_context* ctx, struct ne_tensor* a, int n_past, int n_dims, int mode); +NE_API struct ne_tensor* ne_rope_with_padding(struct ne_context* ctx, struct ne_tensor* a, int n_past, int n_dims, + int mode, int prompt_size, int* n_padding); + +// in-place, returns view(a) +NE_API struct ne_tensor* ne_rope_with_padding_inplace(struct ne_context* ctx, struct ne_tensor* a, int n_past, + int n_dims, int mode, int prompt_size, int* n_padding); + // alibi position embedding // in-place, returns view(a) struct ne_tensor* ne_alibi(struct ne_context* ctx, struct ne_tensor* a, int n_past, int n_head, float bias_max); diff --git a/intel_extension_for_transformers/llm/runtime/graph/models/chatglm/chatglm.cpp b/intel_extension_for_transformers/llm/runtime/graph/models/chatglm/chatglm.cpp index c27dd1443c7..0e57b691de4 100644 --- a/intel_extension_for_transformers/llm/runtime/graph/models/chatglm/chatglm.cpp +++ b/intel_extension_for_transformers/llm/runtime/graph/models/chatglm/chatglm.cpp @@ -45,8 +45,6 @@ // - n_threads: number of threads to use // -static int flag = 0; -static int first_tokens_size = 0; static bool chatglm_model_eval_internal(model_context& lctx, const model_input* inputs, const int n_input, const int n_threads) { const int64_t t_start_us = ne_time_us(); @@ -66,6 +64,7 @@ static bool chatglm_model_eval_internal(model_context& lctx, const model_input* n_padding.push_back((inputs + i)->n_padding); if (no_padding && (inputs + i)->n_padding != 0) no_padding = false; } + const int first_tokens_size = inputs->n_prompt_tokens; const auto& model = lctx.model; const auto& hparams = model.hparams; @@ -79,11 +78,6 @@ static bool chatglm_model_eval_internal(model_context& lctx, const model_input* const int n_ctx = lctx.n_ctx; const int n_keep = lctx.n_keep; - if (flag == 0) { - first_tokens_size = N; - flag++; - } - const int n_head = hparams.n_head; const int n_vocab = hparams.n_vocab; const int n_rot = n_embd / n_head / 2; @@ -141,13 +135,15 @@ static bool chatglm_model_eval_internal(model_context& lctx, const model_input* cur->nb[1] * N, 0); // [qlen * bs, 3 * hidden] ne_set_name(query_layer, "query_layer"); - query_layer = ne_rope_inplace(ctx0, query_layer, n_past, rope_dim, 4, first_tokens_size); + query_layer = + ne_rope_with_padding_inplace(ctx0, query_layer, n_past, rope_dim, 4, first_tokens_size, n_padding.data()); query_layer = ne_permute(ctx0, query_layer, 0, 2, 1, 3); // [bs, heads, qlen, head_size] ne_tensor* key_layer = ne_view_4d(ctx0, cur, head_size, num_attention_heads, qlen, batch_size, 3 * head_size * ne_element_size(cur), - 
cur->nb[1], cur->nb[1] * qlen, head_size * ne_element_size(cur)); - key_layer = ne_rope_inplace(ctx0, key_layer, n_past, rope_dim, 4, first_tokens_size); // [qlen, heads, head_size] + cur->nb[1], cur->nb[1] * qlen, head_size * ne_element_size(cur)); // [bs, qlen, heads, head_size] + key_layer = + ne_rope_with_padding_inplace(ctx0, key_layer, n_past, rope_dim, 4, first_tokens_size, n_padding.data()); ne_tensor* value_layer = ne_view_4d(ctx0, cur, head_size, num_attention_heads, qlen, batch_size, 3 * head_size * ne_element_size(cur), cur->nb[1], cur->nb[1] * qlen, diff --git a/intel_extension_for_transformers/llm/runtime/graph/models/model_utils/model_utils.cpp b/intel_extension_for_transformers/llm/runtime/graph/models/model_utils/model_utils.cpp index b8ff4c0a2e1..1dc9e5fc1ea 100644 --- a/intel_extension_for_transformers/llm/runtime/graph/models/model_utils/model_utils.cpp +++ b/intel_extension_for_transformers/llm/runtime/graph/models/model_utils/model_utils.cpp @@ -1175,6 +1175,8 @@ struct model_context* model_init_from_file(const char* path_model, struct model_ ctx->beam_search = true; ctx->beam_size = params.beam_size; ctx->kv_n_ctx_block = ctx->batch_size * ctx->beam_size; + } else { + ctx->kv_n_ctx_block = ctx->batch_size; } const model_archs arch = params.arch; From 0b0be5ffd2073554b63e105b0666e4d7c64a933f Mon Sep 17 00:00:00 2001 From: "Yu, Zhentao" Date: Wed, 29 Nov 2023 08:32:47 +0000 Subject: [PATCH 13/18] add disable_vec_dot_fp16_simd option Signed-off-by: Yu, Zhentao --- .../llm/runtime/graph/CMakeLists.txt | 6 +++++- .../llm/runtime/graph/application/main_pybind.cpp | 3 ++- .../llm/runtime/graph/core/layers/vec_dot.h | 3 ++- 3 files changed, 9 insertions(+), 3 deletions(-) diff --git a/intel_extension_for_transformers/llm/runtime/graph/CMakeLists.txt b/intel_extension_for_transformers/llm/runtime/graph/CMakeLists.txt index 80d4a800273..e54d7ead742 100644 --- a/intel_extension_for_transformers/llm/runtime/graph/CMakeLists.txt +++ b/intel_extension_for_transformers/llm/runtime/graph/CMakeLists.txt @@ -91,7 +91,11 @@ option(NE_GELU_VEC "neural_engine: enable vec in gelu" if (NE_GELU_VEC) add_compile_definitions(NE_GELU_USE_VEC) endif() -option(NE_PYTHON_API "neural_engine: use python api" OFF) +option(NE_PYTHON_API "neural_engine: use python api" OFF) +option(NE_SIMD_VEC_DOT_F16 "neural_engine: enable vec_dot_fp16 SIMD optimization" ON) +if (NE_SIMD_VEC_DOT_F16) + add_compile_definitions(NE_SIMD_VEC_DOT_F16) +endif() if(NE_BUILD_TESTS) enable_testing() diff --git a/intel_extension_for_transformers/llm/runtime/graph/application/main_pybind.cpp b/intel_extension_for_transformers/llm/runtime/graph/application/main_pybind.cpp index 4e04c8eca8e..a2da4f3a674 100644 --- a/intel_extension_for_transformers/llm/runtime/graph/application/main_pybind.cpp +++ b/intel_extension_for_transformers/llm/runtime/graph/application/main_pybind.cpp @@ -491,7 +491,8 @@ std::vector Model::post_sample_top_k_top_p_repeat(const float* logi float top_p = params.top_p; float temp = params.temp; std::vector ids(ctx->batch_size); -#pragma omp parallel for + // #pragma omp parallel for // omp will affect sampling positions in batch infer + // TODO (make sample functions support batch processing) for (int bs = 0; bs < ctx->batch_size; ++bs) { std::vector candidates; candidates.reserve(n_vocab); diff --git a/intel_extension_for_transformers/llm/runtime/graph/core/layers/vec_dot.h b/intel_extension_for_transformers/llm/runtime/graph/core/layers/vec_dot.h index 954c382340b..2529320175a 100644 --- 
a/intel_extension_for_transformers/llm/runtime/graph/core/layers/vec_dot.h +++ b/intel_extension_for_transformers/llm/runtime/graph/core/layers/vec_dot.h @@ -94,7 +94,8 @@ static void ne_vec_dot_f32(const int n, float* restrict s, const float* restrict static void ne_vec_dot_f16(const int n, float* restrict s, ne_fp16_t* restrict x, ne_fp16_t* restrict y) { ne_float sumf = 0.0; -#if defined(NE_SIMD) +// NE_SIMD_VEC_DOT_F16 (sum order may affect logits, like padding and no padding) +#if defined(NE_SIMD) && defined(NE_SIMD_VEC_DOT_F16) const int np = (n & ~(NE_F16_STEP - 1)); NE_F16_VEC sum[NE_F16_ARR] = {NE_F16_VEC_ZERO}; From 44cedba01483176590a0b6d1ae8338d7ef4b3cff Mon Sep 17 00:00:00 2001 From: "Yu, Zhentao" Date: Wed, 29 Nov 2023 09:06:14 +0000 Subject: [PATCH 14/18] fix compiler error Signed-off-by: Yu, Zhentao --- .../llm/runtime/graph/__init__.py | 6 ++---- .../llm/runtime/graph/core/ne_layers.c | 2 +- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/intel_extension_for_transformers/llm/runtime/graph/__init__.py b/intel_extension_for_transformers/llm/runtime/graph/__init__.py index c29d52fa049..e4e8b7b68a4 100644 --- a/intel_extension_for_transformers/llm/runtime/graph/__init__.py +++ b/intel_extension_for_transformers/llm/runtime/graph/__init__.py @@ -191,14 +191,12 @@ def generate(self, input_ids, streamer=None, interactive=False, ignore_prompt=Fa if stopping_criteria is not None: if stopping_criteria(torch.tensor(ret), None): break - elif (max_new_tokens != -1 and out_count > max_new_tokens): + elif (max_new_tokens != -1 and out_count >= max_new_tokens): break else: - all_done = [(r[-1] in [self.eos_token_id(), self.pad_token_id()]) - for r in ret] + all_done = [(r[-1] in [self.eos_token_id(), self.pad_token_id()]) for r in ret] if False not in all_done: break - out_count += 1 if streamer: streamer.end() diff --git a/intel_extension_for_transformers/llm/runtime/graph/core/ne_layers.c b/intel_extension_for_transformers/llm/runtime/graph/core/ne_layers.c index d35702d706b..b2fd9230e48 100644 --- a/intel_extension_for_transformers/llm/runtime/graph/core/ne_layers.c +++ b/intel_extension_for_transformers/llm/runtime/graph/core/ne_layers.c @@ -433,7 +433,7 @@ static const char* NE_OP_LABEL[NE_OP_COUNT] = { "DEBUG", }; -static_assert(NE_OP_COUNT == 63, "NE_OP_COUNT != 63"); +static_assert(NE_OP_COUNT == 64, "NE_OP_COUNT != 64"); static const char* NE_OP_SYMBOL[NE_OP_COUNT] = { "none", From 78dcb0fd103dd0da297a7ab4b82cbf167d35dc98 Mon Sep 17 00:00:00 2001 From: "Yu, Zhentao" Date: Thu, 30 Nov 2023 02:34:12 +0000 Subject: [PATCH 15/18] clean code Signed-off-by: Yu, Zhentao --- intel_extension_for_transformers/llm/runtime/graph/__init__.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/intel_extension_for_transformers/llm/runtime/graph/__init__.py b/intel_extension_for_transformers/llm/runtime/graph/__init__.py index e4e8b7b68a4..d7b2b6bd217 100644 --- a/intel_extension_for_transformers/llm/runtime/graph/__init__.py +++ b/intel_extension_for_transformers/llm/runtime/graph/__init__.py @@ -161,9 +161,6 @@ def generate(self, input_ids, streamer=None, interactive=False, ignore_prompt=Fa beam_search = False if (generate_kwargs.get("num_beams", 1) > 1) and not generate_kwargs.get("do_sample", False): beam_search = True - # if not beam_search: - # # TODO support multi batch - # assert input_ids.shape[0] == 1, "Unsupport multi-batch input ids." if streamer: assert input_ids.shape[0] == 1, "Streamer only supports batch size 1." 
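Note on the NE_SIMD_VEC_DOT_F16 toggle introduced in PATCH 13: the in-code comment "sum order may affect logits, like padding and no padding" refers to the fact that a vectorized dot product keeps independent partial sums per SIMD lane and only reduces them at the end, so padded and unpadded batches can accumulate in a different order and yield slightly different fp16/fp32 logits. Below is a minimal standalone C++ sketch of that effect; the file name and the values are purely illustrative and are not part of any patch in this series.

// sum_order_demo.cpp — illustrative only; values chosen to make the rounding effect obvious.
#include <cstdio>
#include <vector>

int main() {
  const std::vector<float> v = {1e8f, 1.0f, -1e8f, 1.0f};

  // Scalar accumulation, strictly left to right: 1e8f + 1.0f rounds back to 1e8f,
  // so the final result is 1.
  float scalar = 0.0f;
  for (float x : v) scalar += x;

  // SIMD-style accumulation: two "lanes" keep independent partial sums that are
  // only combined at the end, giving 0 + 2 = 2.
  float lane0 = v[0] + v[2];
  float lane1 = v[1] + v[3];
  float simd = lane0 + lane1;

  printf("scalar=%g simd=%g\n", scalar, simd);  // prints: scalar=1 simd=2
  return 0;
}

This is why building with the SIMD fp16 dot path disabled can make padded and unpadded runs produce matching logits, at some cost in throughput.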
From 6faae24477b418c33316ea0790997983f08cc356 Mon Sep 17 00:00:00 2001 From: "Yu, Zhentao" Date: Thu, 30 Nov 2023 07:40:04 +0000 Subject: [PATCH 16/18] add memory.h Signed-off-by: Yu, Zhentao --- .../runtime/graph/application/main_pybind.cpp | 52 ++++++++++--------- .../llm/runtime/graph/core/layers/layers.h | 1 + .../llm/runtime/graph/core/layers/memory.cpp | 31 +++++++++++ .../llm/runtime/graph/core/layers/memory.h | 29 +++++++++++ .../llm/runtime/graph/core/ne_layers.c | 51 +++++++++--------- 5 files changed, 116 insertions(+), 48 deletions(-) create mode 100644 intel_extension_for_transformers/llm/runtime/graph/core/layers/memory.cpp create mode 100644 intel_extension_for_transformers/llm/runtime/graph/core/layers/memory.h diff --git a/intel_extension_for_transformers/llm/runtime/graph/application/main_pybind.cpp b/intel_extension_for_transformers/llm/runtime/graph/application/main_pybind.cpp index a2da4f3a674..b186bf207f6 100644 --- a/intel_extension_for_transformers/llm/runtime/graph/application/main_pybind.cpp +++ b/intel_extension_for_transformers/llm/runtime/graph/application/main_pybind.cpp @@ -58,6 +58,7 @@ std::shared_ptr get_model_quant_layer(const std::string model_ return ql_registry::create_ql(model_name); } +#define STATIC_INPUT_HEAD_IDX 0 class Model { public: Model() { model_init_backend(); } @@ -222,7 +223,7 @@ bool Model::check_input_and_count_padding(const std::vectorbatch_size); @@ -244,7 +245,7 @@ bool Model::check_input_and_count_padding(const std::vector& Model::evaluate_(const std::vectorlogits; } @@ -395,40 +396,43 @@ std::vector> Model::generate_tokens(const std::vector n_ctx - 4) { + if (curr_input_ids[STATIC_INPUT_HEAD_IDX].empty()) { + if (input_ids[STATIC_INPUT_HEAD_IDX].size() > n_ctx - 4) { fprintf(stderr, "\n%s: Warning: prompt is too long (%d tokens, max %d), will be truncated\n", __func__, - input_ids[0].size(), n_ctx - 4); - curr_input_ids[0].resize(n_ctx - 4); - std::copy(input_ids[0].end() - n_ctx - 4, input_ids[0].end(), curr_input_ids[0].begin()); + input_ids[STATIC_INPUT_HEAD_IDX].size(), n_ctx - 4); + curr_input_ids[STATIC_INPUT_HEAD_IDX].resize(n_ctx - 4); + std::copy(input_ids[STATIC_INPUT_HEAD_IDX].end() - n_ctx - 4, input_ids[STATIC_INPUT_HEAD_IDX].end(), + curr_input_ids[STATIC_INPUT_HEAD_IDX].begin()); } else { - curr_input_ids[0] = input_ids[0]; + curr_input_ids[STATIC_INPUT_HEAD_IDX] = input_ids[STATIC_INPUT_HEAD_IDX]; } } while (output_ids.size() < n_remain) { - for (auto item : curr_input_ids[0]) { - last_n_tokens[0].erase(last_n_tokens[0].begin()); - last_n_tokens[0].push_back(item); + for (auto item : curr_input_ids[STATIC_INPUT_HEAD_IDX]) { + last_n_tokens[STATIC_INPUT_HEAD_IDX].erase(last_n_tokens[STATIC_INPUT_HEAD_IDX].begin()); + last_n_tokens[STATIC_INPUT_HEAD_IDX].push_back(item); } // infinite text generation via context swapping - if (n_past + curr_input_ids[0].size() > n_ctx) { + if (n_past + curr_input_ids[STATIC_INPUT_HEAD_IDX].size() > n_ctx) { // always keep the first token n_past = std::max(1, params.n_keep); int n_discard = params.n_discard; if (!params.shift_roped_k) { // shift_roped_k can use ring-buffer and thus does not need re-computing - if (n_discard == -1) n_discard = (n_ctx - curr_input_ids[0].size() - params.n_keep) / 2; + if (n_discard == -1) n_discard = (n_ctx - curr_input_ids[STATIC_INPUT_HEAD_IDX].size() - params.n_keep) / 2; // drop n_discard tokens - curr_input_ids[0].insert(curr_input_ids[0].begin(), last_n_tokens[0].begin() + params.n_keep + n_discard, - last_n_tokens[0].end() - 
curr_input_ids[0].size()); + curr_input_ids[STATIC_INPUT_HEAD_IDX].insert( + curr_input_ids[STATIC_INPUT_HEAD_IDX].begin(), + last_n_tokens[STATIC_INPUT_HEAD_IDX].begin() + params.n_keep + n_discard, + last_n_tokens[STATIC_INPUT_HEAD_IDX].end() - curr_input_ids[STATIC_INPUT_HEAD_IDX].size()); } else { NE_ASSERT(("n_discard cannot be used with shift_roped_k!", n_discard == -1 || n_discard == 1)); } } std::vector inputs = {model_input{ - /*.tokens =*/curr_input_ids[0].data(), - /*.n_tokens =*/(uint32_t)curr_input_ids[0].size(), + /*.tokens =*/curr_input_ids[STATIC_INPUT_HEAD_IDX].data(), + /*.n_tokens =*/(uint32_t)curr_input_ids[STATIC_INPUT_HEAD_IDX].size(), /*.n_prompt_tokens =*/0, /*.n_past =*/(uint32_t)n_past, /*.n_total =*/(uint32_t)n_total, @@ -438,15 +442,15 @@ std::vector> Model::generate_tokens(const std::vector next_token_id = post_process(logits); - curr_input_ids[0] = {next_token_id[0]}; - output_ids.push_back(next_token_id[0]); + curr_input_ids[STATIC_INPUT_HEAD_IDX] = {next_token_id[STATIC_INPUT_HEAD_IDX]}; + output_ids.push_back(next_token_id[STATIC_INPUT_HEAD_IDX]); generate_count++; - if (next_token_id[0] == ctx->vocab.eos_token_id) { + if (next_token_id[STATIC_INPUT_HEAD_IDX] == ctx->vocab.eos_token_id) { token_eos = true; break; } diff --git a/intel_extension_for_transformers/llm/runtime/graph/core/layers/layers.h b/intel_extension_for_transformers/llm/runtime/graph/core/layers/layers.h index 144073095af..ba7ba081618 100644 --- a/intel_extension_for_transformers/llm/runtime/graph/core/layers/layers.h +++ b/intel_extension_for_transformers/llm/runtime/graph/core/layers/layers.h @@ -15,3 +15,4 @@ #include "ele_reduce.h" #include "conv.h" +#include "memory.h" diff --git a/intel_extension_for_transformers/llm/runtime/graph/core/layers/memory.cpp b/intel_extension_for_transformers/llm/runtime/graph/core/layers/memory.cpp new file mode 100644 index 00000000000..3bf8b83e761 --- /dev/null +++ b/intel_extension_for_transformers/llm/runtime/graph/core/layers/memory.cpp @@ -0,0 +1,31 @@ +// Copyright (c) 2023 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "memory.h" + +void ne_attention_padding_mask_f32_forward(const int bs, const int nr_qk, const int qlen, const int ith, const int nth, + const void* padding, const float p_value, struct ne_tensor* dst) { + // mask padding token (padding left) + for (int b = 0; b < bs; b++) { + const int n_padding = ((int32_t*)padding)[b]; + if (n_padding == 0) continue; + for (int k = 0; k < (nr_qk / bs); k++) { + for (int j = ith; j < qlen; j += nth) { + // it will not affect next token if don't mask the pad_token row + ne_vec_set_f32(n_padding, (float*)((char*)dst->data + b * dst->nb[3] + k * dst->nb[2] + j * dst->nb[1]), + p_value); + } + } + } +} diff --git a/intel_extension_for_transformers/llm/runtime/graph/core/layers/memory.h b/intel_extension_for_transformers/llm/runtime/graph/core/layers/memory.h new file mode 100644 index 00000000000..f37a408c578 --- /dev/null +++ b/intel_extension_for_transformers/llm/runtime/graph/core/layers/memory.h @@ -0,0 +1,29 @@ +// Copyright (c) 2023 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "ele_wise.h" +#include "core/ne.h" + +#ifdef __cplusplus +extern "C" { +#endif + +void ne_attention_padding_mask_f32_forward(const int bs, const int nr_qk, const int qlen, const int ith, const int nth, + const void* padding, const float p_value, struct ne_tensor* dst); + +#ifdef __cplusplus +} +#endif diff --git a/intel_extension_for_transformers/llm/runtime/graph/core/ne_layers.c b/intel_extension_for_transformers/llm/runtime/graph/core/ne_layers.c index b2fd9230e48..34c15785521 100644 --- a/intel_extension_for_transformers/llm/runtime/graph/core/ne_layers.c +++ b/intel_extension_for_transformers/llm/runtime/graph/core/ne_layers.c @@ -2906,16 +2906,21 @@ struct ne_tensor* ne_padding_mask_inf_impl(struct ne_context* ctx, struct ne_ten ne_scratch_save(ctx); +#define PM_PARAMS_NUM 2 +#define PM_NPAST_IDX 0 +#define PM_INPLACE_IDX 1 +#define PM_PADDING_IDX 2 + const int bs = a->ne[3]; - struct ne_tensor* b = ne_new_tensor_1d(ctx, NE_TYPE_I32, 2 + bs, NE_SIZE_CALC); + struct ne_tensor* b = ne_new_tensor_1d(ctx, NE_TYPE_I32, PM_PARAMS_NUM + bs, NE_SIZE_CALC); - ((int32_t*)b->data)[0] = n_past; - ((int32_t*)b->data)[1] = inplace ? 1 : 0; + ((int32_t*)b->data)[PM_NPAST_IDX] = n_past; + ((int32_t*)b->data)[PM_INPLACE_IDX] = inplace ? 
1 : 0; for (int i = 0; i < bs; ++i) { if (n_padding == NULL) { - ((int32_t*)b->data)[2 + i] = 0; + ((int32_t*)b->data)[PM_PADDING_IDX + i] = 0; } else { - ((int32_t*)b->data)[2 + i] = *(n_padding + i); + ((int32_t*)b->data)[PM_PADDING_IDX + i] = *(n_padding + i); } } @@ -2981,20 +2986,28 @@ struct ne_tensor* ne_rope_impl(struct ne_context* ctx, struct ne_tensor* a, int ne_scratch_save(ctx); +#define ROPE_PARAMS_NUM 5 +#define ROPE_NPAST_IDX 0 +#define ROPE_NDIMS_IDX 1 +#define ROPE_MODE_IDX 2 +#define ROPE_PROMPTSIZE_IDX 3 +#define ROPE_NKEEP_IDX 4 +#define ROPE_PADDING_IDX 5 + const int bs = a->ne[3]; - struct ne_tensor* b = ne_new_tensor_1d(ctx, NE_TYPE_I32, 5 + bs, NE_SIZE_CALC); + struct ne_tensor* b = ne_new_tensor_1d(ctx, NE_TYPE_I32, ROPE_PARAMS_NUM + bs, NE_SIZE_CALC); - ((int32_t*)b->data)[0] = n_past; - ((int32_t*)b->data)[1] = n_dims; - ((int32_t*)b->data)[2] = mode; - ((int32_t*)b->data)[3] = prompt_size; - ((int32_t*)b->data)[4] = n_keep; // set to non-negative value to enable shift mode + ((int32_t*)b->data)[ROPE_NPAST_IDX] = n_past; + ((int32_t*)b->data)[ROPE_NDIMS_IDX] = n_dims; + ((int32_t*)b->data)[ROPE_MODE_IDX] = mode; + ((int32_t*)b->data)[ROPE_PROMPTSIZE_IDX] = prompt_size; + ((int32_t*)b->data)[ROPE_NKEEP_IDX] = n_keep; // set to non-negative value to enable shift mode // store n_padding (chatglm position ids) for (int i = 0; i < bs; ++i) { if (n_padding == NULL) { - ((int32_t*)b->data)[5 + i] = 0; + ((int32_t*)b->data)[ROPE_PADDING_IDX + i] = 0; } else { - ((int32_t*)b->data)[5 + i] = *(n_padding + i); + ((int32_t*)b->data)[ROPE_PADDING_IDX + i] = *(n_padding + i); } } @@ -7524,17 +7537,7 @@ static void ne_compute_forward_padding_mask_f32(const struct ne_compute_params* assert(dst->nb[0] == sizeof(float)); assert(src0->nb[0] == sizeof(float)); - // mask padding token (padding left) - for (int b = 0; b < bs; b++) { - const int n_padding = ((int32_t*)src1->data)[2 + b]; - if (n_padding == 0) continue; - for (int k = 0; k < (nz / bs); k++) { - for (int j = ith; j < nr; j += nth) { - // it will not affect next token if don't mask the pad_token row - ne_vec_set_f32(n_padding, (float*)((char*)dst->data + b * dst->nb[3] + k * dst->nb[2] + j * dst->nb[1]), value); - } - } - } + ne_attention_padding_mask_f32_forward(bs, nz, nr, ith, nth, src1->data + 2 * ne_element_size(src1), value, dst); } static void ne_compute_forward_padding_mask_inf(const struct ne_compute_params* params, const struct ne_tensor* src0, From bb2b88e504a72c17186d9cabf12e5afb1156df40 Mon Sep 17 00:00:00 2001 From: "Yu, Zhentao" Date: Fri, 1 Dec 2023 03:04:18 +0000 Subject: [PATCH 17/18] update omp Signed-off-by: Yu, Zhentao --- .../runtime/graph/application/main_pybind.cpp | 21 +++++++++++++++++-- .../llm/runtime/graph/core/ne_layers.c | 19 +++++++++-------- 2 files changed, 29 insertions(+), 11 deletions(-) diff --git a/intel_extension_for_transformers/llm/runtime/graph/application/main_pybind.cpp b/intel_extension_for_transformers/llm/runtime/graph/application/main_pybind.cpp index b186bf207f6..ae53be1063f 100644 --- a/intel_extension_for_transformers/llm/runtime/graph/application/main_pybind.cpp +++ b/intel_extension_for_transformers/llm/runtime/graph/application/main_pybind.cpp @@ -129,6 +129,7 @@ class Model { long int generate_count = 0; std::vector padding_count; uint32_t n_prompt_tokens = 0; + std::vector times; std::vector> beam_generate(const std::vector>& input_ids); std::vector post_process(const float* logits); @@ -465,9 +466,25 @@ std::vector> Model::generate_tokens(const std::vector 
Model::post_greedy_search(const float* logits) { std::vector ids(ctx->batch_size); -#pragma omp parallel for + static int n_vocab_segment = 1024; + int num_segments = (n_vocab + n_vocab_segment - 1) / n_vocab_segment; + std::vector candidate_tokens(ctx->batch_size * num_segments); + std::vector candidate_logits(ctx->batch_size * num_segments); +#pragma omp parallel for collapse(2) for (int bs = 0; bs < ctx->batch_size; ++bs) { - ids[bs] = std::max_element(logits + bs * n_vocab, logits + (bs + 1) * n_vocab) - (logits + bs * n_vocab); + for (int vocab = 0; vocab < n_vocab; vocab += n_vocab_segment) { + auto max_e = + std::max_element(logits + bs * n_vocab + vocab, vocab + n_vocab_segment > n_vocab + ? logits + bs * n_vocab + n_vocab + : logits + bs * n_vocab + vocab + n_vocab_segment); + candidate_tokens[bs * num_segments + vocab / n_vocab_segment] = max_e - (logits + bs * n_vocab); + candidate_logits[bs * num_segments + vocab / n_vocab_segment] = *max_e; + } + } + for (int bs = 0; bs < ctx->batch_size; ++bs) { + ids[bs] = candidate_tokens[std::distance(candidate_logits.begin(), + std::max_element(candidate_logits.begin() + bs * num_segments, + candidate_logits.begin() + (bs + 1) * num_segments))]; } return ids; } diff --git a/intel_extension_for_transformers/llm/runtime/graph/core/ne_layers.c b/intel_extension_for_transformers/llm/runtime/graph/core/ne_layers.c index 34c15785521..d3f888cd3fd 100644 --- a/intel_extension_for_transformers/llm/runtime/graph/core/ne_layers.c +++ b/intel_extension_for_transformers/llm/runtime/graph/core/ne_layers.c @@ -7510,8 +7510,8 @@ static void ne_compute_forward_padding_mask_f32(const struct ne_compute_params* const int ith = params->ith; const int nth = params->nth; - const int n_past = ((int32_t*)src1->data)[0]; - const bool inplace = (bool)((int32_t*)src1->data)[1]; + const int n_past = ((int32_t*)src1->data)[PM_NPAST_IDX]; + const bool inplace = (bool)((int32_t*)src1->data)[PM_INPLACE_IDX]; assert(n_past >= 0); @@ -7537,7 +7537,8 @@ static void ne_compute_forward_padding_mask_f32(const struct ne_compute_params* assert(dst->nb[0] == sizeof(float)); assert(src0->nb[0] == sizeof(float)); - ne_attention_padding_mask_f32_forward(bs, nz, nr, ith, nth, src1->data + 2 * ne_element_size(src1), value, dst); + ne_attention_padding_mask_f32_forward(bs, nz, nr, ith, nth, src1->data + PM_PARAMS_NUM * ne_element_size(src1), value, + dst); } static void ne_compute_forward_padding_mask_inf(const struct ne_compute_params* params, const struct ne_tensor* src0, @@ -7866,11 +7867,11 @@ static void ne_compute_forward_rope_f32(const struct ne_compute_params* params, static const float freq_base = 10000.0f; static const float freq_scale = 1.0f; - const int64_t n_past = ((int32_t*)src1->data)[0]; - const int64_t n_dims = ((int32_t*)src1->data)[1]; - const int64_t mode = ((int32_t*)src1->data)[2]; - const int64_t prompt_size = ((int32_t*)src1->data)[3]; - const int64_t n_keep = ((int32_t*)src1->data)[4]; + const int64_t n_past = ((int32_t*)src1->data)[ROPE_NPAST_IDX]; + const int64_t n_dims = ((int32_t*)src1->data)[ROPE_NDIMS_IDX]; + const int64_t mode = ((int32_t*)src1->data)[ROPE_MODE_IDX]; + const int64_t prompt_size = ((int32_t*)src1->data)[ROPE_PROMPTSIZE_IDX]; + const int64_t n_keep = ((int32_t*)src1->data)[ROPE_NKEEP_IDX]; assert(n_past >= 0); @@ -7916,7 +7917,7 @@ static void ne_compute_forward_rope_f32(const struct ne_compute_params* params, // only for glm when mode == 4 if (is_glm) { - const int64_t n_padding = ((int32_t*)src1->data)[5 + i3]; + const int64_t n_padding 
= ((int32_t*)src1->data)[ROPE_PARAMS_NUM + i3]; // position ids theta = MIN(MAX(p - n_padding, 0), prompt_size - 2 - n_padding); float block_theta = MAX(p - (prompt_size - 2), 0); From 1dab032e4dc1e8f35d4ef460f956fc360d082e1e Mon Sep 17 00:00:00 2001 From: "Yu, Zhentao" Date: Fri, 1 Dec 2023 06:07:47 +0000 Subject: [PATCH 18/18] update reset_token_end() Signed-off-by: Yu, Zhentao --- .../llm/runtime/graph/application/main_pybind.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/intel_extension_for_transformers/llm/runtime/graph/application/main_pybind.cpp b/intel_extension_for_transformers/llm/runtime/graph/application/main_pybind.cpp index ae53be1063f..be72603397b 100644 --- a/intel_extension_for_transformers/llm/runtime/graph/application/main_pybind.cpp +++ b/intel_extension_for_transformers/llm/runtime/graph/application/main_pybind.cpp @@ -87,6 +87,7 @@ class Model { void reset_token_end() { token_eos = false; curr_input_ids.clear(); + curr_input_ids.resize(params.batch_size); generate_count = 0; }
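For reference, the segmented greedy search added in PATCH 17 can be read as the condensed sketch below. Names and the driver in main() are simplified placeholders (the real code lives in Model::post_greedy_search and takes ctx->batch_size and n_vocab from the model context): each (batch, segment) pair computes a local maximum in parallel, then a short serial pass reduces the per-segment candidates to one token id per batch row.

// Condensed sketch of the segmented arg-max pattern used by post_greedy_search.
#include <algorithm>
#include <cstdio>
#include <vector>

std::vector<int> greedy_ids(const float* logits, int batch, int n_vocab) {
  const int seg = 1024;                         // segment width, as in the patch
  const int n_seg = (n_vocab + seg - 1) / seg;  // segments per vocab row
  std::vector<int> cand_tok(batch * n_seg);
  std::vector<float> cand_val(batch * n_seg);
#pragma omp parallel for collapse(2)
  for (int b = 0; b < batch; ++b) {
    for (int s = 0; s < n_seg; ++s) {
      const float* row = logits + b * n_vocab;
      const int lo = s * seg;
      const int hi = std::min(lo + seg, n_vocab);
      const float* m = std::max_element(row + lo, row + hi);  // local arg-max
      cand_tok[b * n_seg + s] = static_cast<int>(m - row);
      cand_val[b * n_seg + s] = *m;
    }
  }
  std::vector<int> ids(batch);
  for (int b = 0; b < batch; ++b) {  // cheap serial reduction over n_seg candidates
    auto begin = cand_val.begin() + b * n_seg;
    const int best = static_cast<int>(std::max_element(begin, begin + n_seg) - begin);
    ids[b] = cand_tok[b * n_seg + best];
  }
  return ids;
}

int main() {
  const std::vector<float> logits = {0.1f, 2.5f, -1.0f,   // batch row 0
                                     3.0f, 0.0f, 0.2f};   // batch row 1
  const auto ids = greedy_ids(logits.data(), /*batch=*/2, /*n_vocab=*/3);
  printf("%d %d\n", ids[0], ids[1]);  // expected: 1 0
  return 0;
}

A fixed segment width (1024 in the patch) gives collapse(2) enough independent work items to parallelize across both the batch and the vocabulary, instead of relying on the per-batch loop alone, which is too coarse for small batch sizes.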