diff --git a/intel_extension_for_transformers/llm/runtime/graph/CMakeLists.txt b/intel_extension_for_transformers/llm/runtime/graph/CMakeLists.txt index 27b4ff63452..39ada4b27a5 100644 --- a/intel_extension_for_transformers/llm/runtime/graph/CMakeLists.txt +++ b/intel_extension_for_transformers/llm/runtime/graph/CMakeLists.txt @@ -83,6 +83,10 @@ option(NE_PROFILING "neural_engine: use Profiling" if (NE_PROFILING) add_compile_definitions(NE_PERF) endif() +option(NE_BEAM_SEARCH_VERBOSE "neural_engine: print beam search processing log" OFF) +if (NE_BEAM_SEARCH_VERBOSE) + add_compile_definitions(NE_BEAM_SEARCH_VERBOSE_ON) +endif() option(NE_GELU_VEC "neural_engine: enable vec in gelu" ON) if (NE_GELU_VEC) add_compile_definitions(NE_GELU_USE_VEC) diff --git a/intel_extension_for_transformers/llm/runtime/graph/application/pybind_gptj.cpp b/intel_extension_for_transformers/llm/runtime/graph/application/pybind_gptj.cpp index 50db74ddaf0..79de24aa495 100644 --- a/intel_extension_for_transformers/llm/runtime/graph/application/pybind_gptj.cpp +++ b/intel_extension_for_transformers/llm/runtime/graph/application/pybind_gptj.cpp @@ -50,7 +50,8 @@ bool gptj_model_eval_ids(model_context* ctx, model_token* tokens, size_t n_eval, extern "C" { void* init_gptj(int seed, int n_predict, int n_batch, int top_k, float top_p, float temp, float repeat_penalty, bool perplexity, int n_ctx, const char* model_file, bool beam_search = false, int beam_size = 4, - int batch_size = 1, int n_threads = 56, int min_new_tokens = 0, float length_penalty = 1.0) { + int batch_size = 1, int n_threads = 56, int min_new_tokens = 0, float length_penalty = 1.0, + bool do_early_stopping = false) { gpt_params params; params.n_threads = n_threads; params.seed = seed; @@ -68,6 +69,7 @@ void* init_gptj(int seed, int n_predict, int n_batch, int top_k, float top_p, fl params.batch_size = batch_size; params.beam_search = beam_search; params.beam_size = beam_size; + params.memory_type = KV_MEM_TYPE_F16; // TODO MEMORY_AUTO for MHA // params.use_mmap = false; // params.use_mlock= true; model_init_backend(); @@ -80,6 +82,7 @@ void* init_gptj(int seed, int n_predict, int n_batch, int top_k, float top_p, fl } ctx->generation_conf.min_new_tokens = min_new_tokens; ctx->generation_conf.length_penalty = length_penalty; + ctx->generation_conf.do_early_stopping = do_early_stopping; return (void*)ctx; } @@ -220,13 +223,17 @@ int main(int argc, char* argv[]) { return 1; } - auto gptj_in_all_bs = init_gptj(1234, 32, 32, 40, 1.0, 0.8, 1.02, false, 2048, argv[1], true, 4, 1, 56, 30, 1.0); + auto gptj_in_all_bs = + init_gptj(1234, 32, 32, 40, 1.0, 0.8, 1.02, false, 2048, argv[1], true, 4, 1, 56, 30, 1.0, true); std::vector ctxs = {gptj_in_all_bs}; for (auto gptj_in_all : ctxs) { auto res = eval_gptj_char( gptj_in_all, - //"she opened the door and see", + // "she opened the door and see", + // "Once upon a time", + // "Tell me 10 things about jazz music", // "A spaceship lands on the moon", + // "What is the meaning of life?", "2017: It is done, and submitted. You can play 'Survival of the Tastiest' on Android, and on the web. Playing " "on the web works, but you have to simulate multiple touch for table moving and that can be a bit confusing. " "There is a lot I'd like to talk about. 
I will go through every topic, insted of making the typical what went " diff --git a/intel_extension_for_transformers/llm/runtime/graph/models/gptj/gptj.cpp b/intel_extension_for_transformers/llm/runtime/graph/models/gptj/gptj.cpp index 5960217860d..995e4f8d3e5 100644 --- a/intel_extension_for_transformers/llm/runtime/graph/models/gptj/gptj.cpp +++ b/intel_extension_for_transformers/llm/runtime/graph/models/gptj/gptj.cpp @@ -225,14 +225,14 @@ static bool gptj_model_eval_internal(model_context& lctx, const model_token* tok std::vector v_bs(batch_size); for (int i = 0; i < batch_size; ++i) { if (run_mha_fp16) { - // batch K + // batch V Vcur_bs[i] = ne_view_4d(ctx0, Vcur, n_embd / n_head, n_head, N, 1, ne_element_size(Vcur) * n_embd / n_head, ne_element_size(Vcur) * n_embd, ne_element_size(Vcur) * n_embd * N, i * ne_element_size(Vcur) * n_embd * N); v_bs[i] = ne_view_1d(ctx0, kv_self.v, n_embd * N * 1, (ne_element_size(kv_self.v) * n_embd) * (il * n_ctx * kv_n_ctx_block + n_past) + i * n_ctx * n_embd * ne_element_size(kv_self.v)); - // batch V + // batch K Kcur_bs[i] = ne_permute(ctx0, ne_reshape_4d(ctx0, ne_view_2d(ctx0, Kcur, n_embd, N, ne_element_size(Kcur) * n_embd, diff --git a/intel_extension_for_transformers/llm/runtime/graph/models/model_utils/model_types.h b/intel_extension_for_transformers/llm/runtime/graph/models/model_utils/model_types.h index 809fea7cd34..581feaff785 100644 --- a/intel_extension_for_transformers/llm/runtime/graph/models/model_utils/model_types.h +++ b/intel_extension_for_transformers/llm/runtime/graph/models/model_utils/model_types.h @@ -230,7 +230,7 @@ struct generation_config { // likelihood of the sequence (i.e. negative), `length_penalty` > 0.0 promotes longer sequences, while // `length_penalty` < 0.0 encourages shorter sequences. (default = 1.0) float length_penalty = 1.0f; - bool do_early_stopping = false; // TODO + bool do_early_stopping = false; }; class beam_search_kv_cache_reorder; // forward declaration diff --git a/intel_extension_for_transformers/llm/runtime/graph/models/model_utils/model_utils.cpp b/intel_extension_for_transformers/llm/runtime/graph/models/model_utils/model_utils.cpp index b1ac55eb4ae..53f85d2a36a 100644 --- a/intel_extension_for_transformers/llm/runtime/graph/models/model_utils/model_utils.cpp +++ b/intel_extension_for_transformers/llm/runtime/graph/models/model_utils/model_utils.cpp @@ -1935,7 +1935,7 @@ std::vector>& model_internal_get_tenso // A struct for calculating logits-related info. struct logits_info { const model_context* const ctx = nullptr; - // (batch, seq_len * vocab_size) + // (batch, seq_len * vocab_size) batch = input_prompt_bs* beam_size const float* const logits = nullptr; const int batch_size; const int32_t n_vocab; @@ -1971,42 +1971,41 @@ struct logits_info { } } - model_token_data get_token_data(const int& batch_idx, const int32_t& token_idx) const { - return {token_idx, *(logits + batch_idx * bs_stride + offset + token_idx), 0.0f}; + beam_next_token get_token_data(const int& batch_idx, const int32_t& token_idx) const { + return {token_idx, *(logits + batch_idx * bs_stride + offset + token_idx), -1}; } - // Return top k token_data by logit. 
(batch, top_k) - std::vector> top_k(const int& k) { - std::vector> min_heap(batch_size); // min-heap by logit + float probability_from_logit(const int& batch_idx, const float& logit) { + return normalizers[batch_idx] * std::exp(logit - max_ls[batch_idx]); + } + + float log_probability_from_logit(const int& batch_idx, const float& logit) { + return std::log(probability_from_logit(batch_idx, logit)); + } + + // Return top k token_data by raw logit in n_vocab dim. (request_bs*num_beam, top_k) + std::vector> vocab_top_k(const int& k) { + std::vector> min_heap(batch_size); // min-heap by logit int tk = std::min(k, n_vocab); - // min_heap.reserve(batch_size * tk); for (int idx = 0; idx < batch_size; ++idx) { for (int32_t token_idx = 0; token_idx < tk; ++token_idx) { min_heap[idx].push_back(get_token_data(idx, token_idx)); } } - auto comp = [](const model_token_data& a, const model_token_data& b) { return a.logit > b.logit; }; + auto comp = [](const beam_next_token& a, const beam_next_token& b) { return a.score > b.score; }; for (int idx = 0; idx < batch_size; ++idx) { std::make_heap(min_heap[idx].begin(), min_heap[idx].end(), comp); for (int32_t token_idx = tk; token_idx < n_vocab; ++token_idx) { - if (min_heap[idx].front().logit < get_token_data(idx, token_idx).logit) { + if (min_heap[idx].front().score < get_token_data(idx, token_idx).score) { std::pop_heap(min_heap[idx].begin(), min_heap[idx].end(), comp); min_heap[idx].back().id = token_idx; - min_heap[idx].back().logit = get_token_data(idx, token_idx).logit; + min_heap[idx].back().score = get_token_data(idx, token_idx).score; std::push_heap(min_heap[idx].begin(), min_heap[idx].end(), comp); } } } return min_heap; } - - float probability_from_logit(const int& batch_idx, const float& logit) { - return normalizers[batch_idx] * std::exp(logit - max_ls[batch_idx]); - } - - float log_probability_from_logit(const int& batch_idx, const float& logit) { - return std::log(probability_from_logit(batch_idx, logit)); - } }; void logits_processor::min_new_tokens_logits_process(const uint32_t& cur_len, const model_vocab::id& eos_token_id) { @@ -2019,7 +2018,7 @@ void logits_processor::min_new_tokens_logits_process(const uint32_t& cur_len, co size_t bs_stride = ctx->logits.size() / ctx->batch_size; for (int i = 0; i < batch_size; ++i) { // forbidden to choose eos_token if cur_len < min_new_tokens - *(model_get_logits(ctx) + i * bs_stride + offset + eos_token_id) = 0.0f; + *(model_get_logits(ctx) + i * bs_stride + offset + eos_token_id) = NEG_INF; } } } @@ -2033,7 +2032,7 @@ void logits_processor::process(const uint32_t& cur_len, const model_vocab::id& e // TODO dispatch JBLAS kv cache manager void beam_search_kv_cache_reorder::update(const uint32_t& n_past, const uint32_t& n_prompt_tokens, - const std::unordered_map& kv_reorder_indices, + const std::vector>& kv_reorder_indices, const std::vector& next_beams) { // first step if (n_past == n_prompt_tokens) { @@ -2065,9 +2064,11 @@ void beam_search_kv_cache_reorder::update(const uint32_t& n_past, const uint32_t } } else if (n_past > n_prompt_tokens) { // next setp - for (auto it : kv_reorder_indices) { - if (it.first != it.second) { - uint32_t len = next_beams[it.first].token_ids.size() - 1; + for (auto t : kv_reorder_indices) { + int cur_id = std::get<0>(t); + int cpy_id = std::get<1>(t); + if (cur_id != cpy_id) { + uint32_t len = next_beams[cur_id].token_ids.size() - 1; // last token in beam is for next step inference MODEL_ASSERT(len == n_past - n_prompt_tokens); size_t input_token_offset_k = 
n_prompt_tokens * ne_element_size(ctx->model.kv_self.k) * n_embd;
@@ -2083,22 +2084,22 @@ void beam_search_kv_cache_reorder::update(const uint32_t& n_past, const uint32_t
           // [n_embd, N]
           memcpy(static_cast(ctx->model.kv_self.k->data) +
                      (i * n_ctx * ne_element_size(ctx->model.kv_self.k) * n_embd * kv_n_ctx_block +
-                      it.first * n_ctx * ne_element_size(ctx->model.kv_self.k) * n_embd) +
+                      cur_id * n_ctx * ne_element_size(ctx->model.kv_self.k) * n_embd) +
                      input_token_offset_k,
                  static_cast(ctx->model.kv_self.k->data) +
                      i * n_ctx * ne_element_size(ctx->model.kv_self.k) * n_embd * kv_n_ctx_block +
-                     it.second * n_ctx * ne_element_size(ctx->model.kv_self.k) * n_embd + input_token_offset_k,
+                     cpy_id * n_ctx * ne_element_size(ctx->model.kv_self.k) * n_embd + input_token_offset_k,
                  ne_element_size(ctx->model.kv_self.k) * n_embd * len);
           // [N, n_embd]
           for (int k = 0; k < n_embd; ++k) {
             memcpy(static_cast(ctx->model.kv_self.v->data) +
                        (i * n_ctx * ne_element_size(ctx->model.kv_self.v) * n_embd * kv_n_ctx_block +
-                        it.first * n_ctx * ne_element_size(ctx->model.kv_self.v) * n_embd +
+                        cur_id * n_ctx * ne_element_size(ctx->model.kv_self.v) * n_embd +
                         n_ctx * ne_element_size(ctx->model.kv_self.v) * k + input_token_offset_v),
                    static_cast(ctx->model.kv_self.v->data) +
                        (i * n_ctx * ne_element_size(ctx->model.kv_self.v) * n_embd * kv_n_ctx_block +
-                        it.second * n_ctx * ne_element_size(ctx->model.kv_self.v) * n_embd +
-                        n_ctx * ne_element_size(ctx->model.kv_self.v) + input_token_offset_v),
+                        cpy_id * n_ctx * ne_element_size(ctx->model.kv_self.v) * n_embd +
+                        n_ctx * ne_element_size(ctx->model.kv_self.v) * k + input_token_offset_v),
                    ne_element_size(ctx->model.kv_self.v) * len);
           }
         }
@@ -2109,111 +2110,157 @@ void beam_search_kv_cache_reorder::update(const uint32_t& n_past, const uint32_t
   }
 }

+// Return top k token_data by score. (prompt_bs * sample_scale * num_beam)
+// each beam gives top_k results --> + prev_scores --> sort num_beam out of (num_beam * top_k)
+// for example, the huggingface transformers repo implements it like this:
+// log_softmax(num_beam * n_vocab) --> + prev_scores --> sort num_beam
+// that is straightforward, but computing log_softmax over the full vocabulary brings overhead
+// we instead sample the top_k logits of each beam, then compute scores only at those positions,
+// and finally sample the top_k results among all beams.
+// this approach will accelerate sampling speed by log_softmax times reduction +std::vector beam_search_flow::beam_top_k_next_tokens(model_context* ctx, const uint32_t& cur_len, + const std::vector& beams_score, + const std::vector& num_beams, + const std::vector beam_indices, + const int& sample_scale, const int& dim) { + MODEL_ASSERT(dim == -1); // raise unimplemented error + const int request_bs = 1; // TODO ctx->request_running_num + logits_info li(ctx); + lp.process(cur_len, ctx->vocab.eos_token_id); + const int raw_k = sample_scale * beam_size; + // raw logits top_k + std::vector> raw_top_k = li.vocab_top_k(raw_k); + MODEL_ASSERT(raw_top_k.size() == ctx->batch_size); // request_bs * num_beam + MODEL_ASSERT(raw_top_k[0].size() == raw_k); + MODEL_ASSERT(beams_score.size() == ctx->batch_size); + // compute score: log_softmax + prev_score +#pragma omp parallel for + for (int i = 0; i < ctx->batch_size; ++i) { + std::for_each(raw_top_k[i].begin(), raw_top_k[i].end(), + [&](beam_next_token& r) { r.score = li.log_probability_from_logit(i, r.score) + beams_score[i]; }); + } + MODEL_ASSERT(num_beams.size() == request_bs); + std::vector res; + res.reserve(sample_scale * std::accumulate(num_beams.begin(), num_beams.end(), 0)); + std::vector min_heap; + const uint32_t n_vocab = ctx->model.hparams.n_vocab; + size_t row_off = 0; + auto comp = [](const beam_next_token& a, const beam_next_token& b) { return a.score > b.score; }; + for (int i = 0; i < request_bs; ++i) { + const int num_beam = num_beams[i]; + const int sample_k = sample_scale * num_beam; + MODEL_ASSERT(raw_k >= sample_k); + min_heap.clear(); + min_heap.reserve(sample_k); + for (int j = 0; j < num_beam; ++j) { + int n = 0; + if (j == 0) { // init heap + for (; n < sample_k; ++n) { + min_heap.push_back(beam_next_token( + {raw_top_k[row_off + j][n].id, raw_top_k[row_off + j][n].score, beam_indices[row_off + j]})); + } + std::make_heap(min_heap.begin(), min_heap.end(), comp); + } + MODEL_ASSERT(min_heap.size() == sample_k); + for (; n < raw_k; ++n) { + beam_next_token nr({raw_top_k[row_off + j][n].id, raw_top_k[row_off + j][n].score, beam_indices[row_off + j]}); + if (min_heap.front().score < nr.score) { + std::pop_heap(min_heap.begin(), min_heap.end(), comp); + min_heap.back().id = nr.id; + min_heap.back().score = nr.score; + min_heap.back().beam_idx = nr.beam_idx; + std::push_heap(min_heap.begin(), min_heap.end(), comp); + } + } + } + row_off += i * num_beam; + std::sort(min_heap.begin(), min_heap.end(), + [](const beam_next_token& a, const beam_next_token& b) { return a.score > b.score; }); + for (const auto b : min_heap) { + res.push_back(b); + } + } + return res; +} + // TODO debug info unify (function ptr?) 
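// ---------------------------------------------------------------------------
// Illustrative sketch (not part of the patch): the two-stage top-k idea that the
// comment above and beam_top_k_next_tokens() describe. All names below are
// hypothetical standalone helpers, not the runtime's API.
// Stage 1: per-beam partial top-k on raw logits via a min-heap, so log_softmax
// is only evaluated at k candidate positions instead of over the whole vocab.
// Stage 2: add the beam's running score; a global top-k across beams follows.
#include <algorithm>
#include <cmath>
#include <vector>

struct cand_tok {
  int id;         // vocab index
  float score;    // raw logit (stage 1) or log-prob + beam score (stage 2)
  int beam_idx;   // which beam produced this candidate
};

// Stage 1: the k largest raw logits of one beam, O(n_vocab * log k).
inline std::vector<cand_tok> top_k_raw(const std::vector<float>& logits, int k, int beam_idx) {
  std::vector<cand_tok> heap;
  auto cmp = [](const cand_tok& a, const cand_tok& b) { return a.score > b.score; };  // min-heap by score
  for (int i = 0; i < static_cast<int>(logits.size()); ++i) {
    if (static_cast<int>(heap.size()) < k) {
      heap.push_back({i, logits[i], beam_idx});
      if (static_cast<int>(heap.size()) == k) std::make_heap(heap.begin(), heap.end(), cmp);
    } else if (logits[i] > heap.front().score) {
      std::pop_heap(heap.begin(), heap.end(), cmp);
      heap.back() = {i, logits[i], beam_idx};
      std::push_heap(heap.begin(), heap.end(), cmp);
    }
  }
  return heap;
}

// Stage 2: log_softmax only at the selected positions, plus the beam's running score.
// The normalizer still sums over all logits, but log() and the later sort touch only k items.
inline void rescore(std::vector<cand_tok>& cands, const std::vector<float>& logits, float beam_score) {
  const float max_l = *std::max_element(logits.begin(), logits.end());
  double denom = 0.0;
  for (float l : logits) denom += std::exp(l - max_l);
  for (auto& c : cands) c.score = (c.score - max_l) - static_cast<float>(std::log(denom)) + beam_score;
}
// ---------------------------------------------------------------------------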
-void beam_search_flow::fill_next_beams_by_top_probabilities() { +void beam_search_flow::fill_next_beams_by_top_scores() { auto const comp = [](const beam& a, const beam& b) { return a.score > b.score; }; std::vector embd_inp; - std::vector infer_beam_ids(beam_size); int record = 0; int batch_size = 0; uint32_t cur_len = 0; + std::vector beam_indices; + std::vector beams_score; for (int i = 0; i < beam_size; ++i) { - // is done or not - if (!cur_beams[i].eos()) { - if (cur_len != 0) { - MODEL_ASSERT(cur_len == cur_beams[i].token_ids.size()); - } else { - cur_len = cur_beams[i].token_ids.size(); - } - // (batch, 1) - // ordered by infer_bs_id - embd_inp.push_back(cur_beams[i].token_ids.back()); - infer_beam_ids[i] = record++; - batch_size++; + MODEL_ASSERT(!cur_beams[i].eos()); + if (cur_len != 0) { + MODEL_ASSERT(cur_len == cur_beams[i].token_ids.size()); + } else { + cur_len = cur_beams[i].token_ids.size(); } + // (batch, 1) + // ordered by infer_bs_id + embd_inp.push_back(cur_beams[i].token_ids.back()); + batch_size++; + beam_indices.push_back(i); + beams_score.push_back(cur_beams[i].score); } // DEBUG -#if 0 - printf("====================== \n"); +#ifdef NE_BEAM_SEARCH_VERBOSE_ON + printf("========================================================================================= \n"); + printf("next_tokens for inference: \n"); for (auto kk : embd_inp) { - printf("%s \n", (ctx->vocab.id_to_token.at(kk).tok).c_str()); + printf("%d: %s \n", kk, (ctx->vocab.id_to_token.at(kk).tok).c_str()); } + printf("+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ \n"); #endif ctx->batch_size = batch_size; int n_tokens = 1; model_eval(ctx, embd_inp.data(), n_tokens, n_past, num_threads); - // DEBUG -#if 0 - size_t bs_stride = n_tokens * ctx->model.hparams.n_vocab; - for (int k = 0; k < batch_size; ++k) { - printf("====================== \n"); - for (int kk = 0; kk < 10; ++kk) { - printf("%4.5f \n", model_get_logits(ctx) + k * bs_stride + kk); - } - } -#endif - lp.process(cur_len, 50256); // TODO ctx->model.eos_id; - logits_info li(ctx); - // sample 2 - const int sample_num = 2; - std::vector> next_tokens = li.top_k(sample_num); + const int sample_scale = 2; + std::vector next_tokens = + beam_top_k_next_tokens(ctx, cur_len, beams_score, {batch_size}, beam_indices, sample_scale); + // DEBUG -#if 0 - for (int k = 0; k < next_tokens.size(); ++k) { - printf("====================== \n"); - for (auto kk : next_tokens[k]) { - printf("%s, l: %3.6f, p: %0.6f \n", (ctx->vocab.id_to_token.at(kk.id).tok).c_str(), kk.logit, - li.log_probability_from_logit(k, kk.logit)); - } +#ifdef NE_BEAM_SEARCH_VERBOSE_ON + printf("top_k next_tokens: \n"); + for (auto kk : next_tokens) { + printf("%d: %s, score: %10.6f, beam_idx: %d \n", kk.id, (ctx->vocab.id_to_token.at(kk.id).tok).c_str(), kk.score, + kk.beam_idx); } + printf("+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ \n"); #endif - MODEL_ASSERT(next_tokens.size() == batch_size); - for (int i = 0; i < beam_size; ++i) { - beam b = cur_beams[i]; - if (b.eos()) { - // b is at end-of-sentence, so just copy it to next_beams if its - // probability is high enough. 
- if (next_beams.size() < beam_size) { - next_beams.push_back(b); - if (next_beams.size() == beam_size) { - std::make_heap(next_beams.begin(), next_beams.end(), comp); - } - } else if (next_beams.front().score < b.score) { - std::pop_heap(next_beams.begin(), next_beams.end(), comp); - next_beams.back() = b; - std::push_heap(next_beams.begin(), next_beams.end(), comp); + MODEL_ASSERT(next_tokens.size() == batch_size * sample_scale); + MODEL_ASSERT(next_beams.empty()); + for (int i = 0; i < next_tokens.size(); ++i) { + if (next_tokens[i].id == ctx->vocab.eos_token_id) { + // if beam_token does not belong to top num_beams tokens, it should not be added + bool is_beam_token_worse_than_top_num_beams = i >= beam_size ? true : false; + if (is_beam_token_worse_than_top_num_beams) { + continue; } + // update score with eos next token + cur_beams[next_tokens[i].beam_idx].score = next_tokens[i].score; + beam_hypos[0].add(cur_beams[next_tokens[i].beam_idx], n_prompt_tokens); } else { - int j = 0; - if (next_beams.size() < beam_size) { - for (; next_beams.size() < beam_size && j < sample_num; ++j) { - beam next_beam = b; - next_beam.token_ids.push_back(next_tokens[infer_beam_ids[i]][j].id); - next_beam.score += li.log_probability_from_logit(infer_beam_ids[i], next_tokens[infer_beam_ids[i]][j].logit); - next_beams.push_back(std::move(next_beam)); - } - std::make_heap(next_beams.begin(), next_beams.end(), comp); - } - for (; j < sample_num; ++j) { - float const next_score = - b.score + li.log_probability_from_logit(infer_beam_ids[i], next_tokens[infer_beam_ids[i]][j].logit); - if (next_beams.front().score < next_score) { - std::pop_heap(next_beams.begin(), next_beams.end(), comp); - next_beams.back() = b; - next_beams.back().token_ids.push_back(next_tokens[infer_beam_ids[i]][j].id); - next_beams.back().score = next_score; - std::push_heap(next_beams.begin(), next_beams.end(), comp); - } - } + beam next_beam = cur_beams[next_tokens[i].beam_idx]; + next_beam.token_ids.push_back(next_tokens[i].id); + next_beam.score = next_tokens[i].score; + next_beams.push_back(std::move(next_beam)); + } + if (next_beams.size() == beam_size) { + break; } } + std::sort(next_beams.begin(), next_beams.end(), [](beam& a, beam& b) { return a.infer_bs_id < b.infer_bs_id; }); } // get kv cache reorder indices, -// k: dst_beam batch idx, v: src_beam batch idx +// idx_0: dst_beam batch idx, idx_1: src_beam batch idx // for copy predicted past token kv cache // for example: // - c @@ -2223,13 +2270,14 @@ void beam_search_flow::fill_next_beams_by_top_probabilities() { // - f | - ad // b -| ---------->| // - g -// kv_cache_reorder_indices = {0:0, 1:0} +// kv_cache_reorder_indices = {{0,0}, {1,0}} // if kv_cache_reorder_indices = {0:0, 1:1}, then do not need reorder (cpy) -std::unordered_map beam_search_flow::update_kv_cache_reorder_indices() { +std::vector> beam_search_flow::update_kv_cache_reorder_indices() { MODEL_ASSERT(next_beams.size() == beam_size); MODEL_ASSERT(cur_beams.size() == beam_size); // DEBUG -#if 0 +#ifdef NE_BEAM_SEARCH_VERBOSE_ON + printf("kv cache update indices info: \n"); printf("cur_beams: "); for (int i = 0; i < beam_size; ++i) { printf("%d, ", cur_beams[i].infer_bs_id); @@ -2241,7 +2289,7 @@ std::unordered_map beam_search_flow::update_kv_cache_reorder_indices() } printf("\n"); #endif - std::unordered_map kv_reorder_indices; + std::vector> kv_reorder_indices; kv_reorder_indices.reserve(beam_size); // shuffle beams which are early stopped (eos) // keep them behind beams which have non-eos @@ -2266,13 +2314,43 @@ 
std::unordered_map beam_search_flow::update_kv_cache_reorder_indices()
   // update indices and batch ids
   for (int i = 0; i < beam_size; ++i) {
-    kv_reorder_indices[i] = cpy_final_bs_ids[i];
     // update infer_bs_id before next beam generation
     next_beams[nb_shuffle_ids[i]].infer_bs_id = i;
   }
   // beams should be ordered by batch id
   std::sort(next_beams.begin(), next_beams.end(), [](beam& a, beam& b) { return a.infer_bs_id < b.infer_bs_id; });
-#if 0  // DEBUG
+
+  // we arrange beams by inference batch index rather than by score, to reduce memcpy time
+  // ignoring the no-copy case (0,1,2,3 --> 0,1,2,3), there are 2 circumstances:
+  // 1. copy former beams into latter beams, like: 0,1,2,3 --> 0,0,0,1
+  // 2. copy latter beams into former beams, like: 0,1,2,3 --> 1,2,2,3
+  // the kv cache memcpy is done in place, so copying in the wrong order would overwrite a source
+  // before it is read; we therefore traverse the beam indices in the direction opposite to the data movement:
+  // if 1, copy order is from tail to head
+  // if 2, copy order is from head to tail
+  bool cpy_from_head = true;
+  int dst_idx_sum = 0;
+  int src_idx_sum = 0;
+  for (int i = 0; i < cpy_final_bs_ids.size(); ++i) {
+    dst_idx_sum += i;
+    src_idx_sum += cpy_final_bs_ids[i];
+    if (src_idx_sum < dst_idx_sum) {
+      cpy_from_head = false;
+      break;
+    }
+  }
+  if (cpy_from_head) {
+    for (int i = 0; i < cpy_final_bs_ids.size(); ++i) {
+      kv_reorder_indices.push_back({i, cpy_final_bs_ids[i]});
+    }
+  } else {
+    for (int i = cpy_final_bs_ids.size() - 1; i >= 0; --i) {
+      kv_reorder_indices.push_back({i, cpy_final_bs_ids[i]});
+    }
+  }
+
+  // DEBUG
+#ifdef NE_BEAM_SEARCH_VERBOSE_ON
   printf("cpy_final_bs_ids: ");
   for (int i = 0; i < beam_size; ++i) {
     printf("%d, ", cpy_final_bs_ids[i]);
   }
   printf("\n");
@@ -2288,20 +2366,36 @@ std::unordered_map beam_search_flow::update_kv_cache_reorder_indices()
     printf("%d, ", next_beams[i].infer_bs_id);
   }
   printf("\n");
+  printf("+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ \n");
 #endif
   return kv_reorder_indices;
 }

-void beam_search_flow::beam_score_length_penalize() {
-  float length_penalty = ctx->generation_conf.length_penalty;
-  std::for_each(cur_beams.begin(), cur_beams.end(),
-                [&](beam& b) { b.score /= std::pow(b.token_ids.size(), length_penalty); });
-}
-
 // Return beam with highest probability.
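// ---------------------------------------------------------------------------
// Illustrative sketch (not the patch's exact prefix-sum heuristic): the copy-order
// rule used by update_kv_cache_reorder_indices() above. Reordering cache rows in
// place is memmove-like -- the traversal direction must read every source row
// before it gets overwritten. reorder_rows() is a hypothetical helper working on
// a toy row vector rather than the real kv cache layout.
#include <vector>

// dst_to_src[i] == j means "row i must receive the old contents of row j".
inline void reorder_rows(std::vector<std::vector<float>>& rows, const std::vector<int>& dst_to_src) {
  const int n = static_cast<int>(dst_to_src.size());
  bool src_not_after_dst = true;  // every source index <= its destination index
  for (int i = 0; i < n; ++i) {
    if (dst_to_src[i] > i) src_not_after_dst = false;
  }
  if (src_not_after_dst) {
    // sources sit at or before their destinations (e.g. 0,1,2,3 --> 0,0,0,1): copy tail to head
    for (int i = n - 1; i >= 0; --i) {
      if (dst_to_src[i] != i) rows[i] = rows[dst_to_src[i]];
    }
  } else {
    // otherwise assume sources sit at or after their destinations (e.g. 0,1,2,3 --> 1,2,2,3): copy head to tail
    for (int i = 0; i < n; ++i) {
      if (dst_to_src[i] != i) rows[i] = rows[dst_to_src[i]];
    }
  }
  // a mapping that moves rows in both directions would need a scratch copy instead
}
// ---------------------------------------------------------------------------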
-const beam& beam_search_flow::top_beam() { - auto const by_score = [](beam const& a, beam const& b) { return a.score < b.score; }; - return *std::max_element(cur_beams.begin(), cur_beams.end(), by_score); +const beam& beam_search_flow::finalize() { +#ifdef NE_BEAM_SEARCH_VERBOSE_ON + printf("========================================================================================= \n"); + printf("finalize: \n"); + printf("before: \n"); + for (auto b : beam_hypos[0].beams) { + b.print(); + } + printf("+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ \n"); +#endif + if (!requests_done[0]) { + for (const auto b : cur_beams) { + beam_hypos[0].add(b, n_prompt_tokens); + } +#ifdef NE_BEAM_SEARCH_VERBOSE_ON + printf("after (adding more beams from outside): \n"); + for (auto b : beam_hypos[0].beams) { + b.print(); + } + printf("+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ \n"); + printf("========================================================================================= \n"); +#endif + } + return beam_hypos[0].top1(); } // TODO batch_size = 1 only @@ -2313,66 +2407,91 @@ std::vector beam_search_flow::loop(const model_token* tokens_inp, c return std::vector(); } num_threads = n_threads; + n_prompt_tokens = n_tokens; std::vector beam_search_response; std::vector embd(tokens_inp, tokens_inp + n_tokens); ctx->batch_size = 1; const uint32_t max_new_tokens = ctx->generation_conf.max_new_tokens; - // Loop while there are any beams that have not yet reached end-of-sentence. - // If the top beam is at end-of-sentence, then finish since all other - // beam score can only decrease. + // Loop ends in: 1. all requests done; or 2. reach max_new_tokens length auto const eos = [](const beam& b) { return b.eos(); }; kv_reorder = ctx->bs_kv_reorder; if (kv_reorder == nullptr) { kv_reorder = std::make_shared(ctx); } - for (int n = 0; n < max_new_tokens && !eos(top_beam()) && !std::all_of(cur_beams.begin(), cur_beams.end(), eos); - ++n) { + beam_hypos.push_back(beam_hypotheses(ctx)); // TODO ctx->request_running_bs; + requests_done.push_back(false); + for (int n = 0; n < max_new_tokens; ++n) { // first step if (n_past == 0) { model_eval(ctx, embd.data(), n_tokens, n_past, num_threads); n_past += n_tokens; kv_reorder->update(n_past, n_tokens); - lp.process(0, 50256); // TODO ctx->model.eos_id; - logits_info li(ctx); - std::vector> next_tokens = li.top_k(beam_size); - MODEL_ASSERT(next_tokens.size() == 1); + std::vector next_tokens = beam_top_k_next_tokens(ctx, 0, {0.0f}, {1}, {0}, beam_size); + MODEL_ASSERT(next_tokens.size() == beam_size); cur_beams.clear(); + // DEBUG +#ifdef NE_BEAM_SEARCH_VERBOSE_ON + printf("========================================================================================== \n"); + printf("top_k next_tokens: \n"); + for (auto kk : next_tokens) { + printf("%d: %s, score: %12.6f, beam_idx: %d \n", kk.id, (ctx->vocab.id_to_token.at(kk.id).tok).c_str(), + kk.score, kk.beam_idx); + } + printf("+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ \n"); +#endif for (int i = 0; i < beam_size; ++i) { beam b; b.ctx = ctx; - b.token_ids.push_back(next_tokens[0][i].id); - b.score = li.log_probability_from_logit(0, next_tokens[0][i].logit); + b.token_ids.push_back(next_tokens[i].id); + b.score = next_tokens[i].score; b.infer_bs_id = i; cur_beams.push_back(b); } - beam_score_length_penalize(); } else { - fill_next_beams_by_top_probabilities(); - std::unordered_map 
kv_reorder_indices = update_kv_cache_reorder_indices(); + fill_next_beams_by_top_scores(); + std::vector> kv_reorder_indices = update_kv_cache_reorder_indices(); n_past += 1; kv_reorder->update(n_past, n_tokens, kv_reorder_indices, next_beams); cur_beams.swap(next_beams); next_beams.clear(); - beam_score_length_penalize(); } -#if 0 // DEBUG: print current beams for this iteration - printf("\n\nCurrent beams:\n"); - for (size_t j = 0; j < beams.size(); ++j) { + // DEBUG: print current beams for this iteration +#ifdef NE_BEAM_SEARCH_VERBOSE_ON + printf("current beams:\n"); + for (size_t j = 0; j < cur_beams.size(); ++j) { printf("beams[%d]: ", j); - beams[j].print(); + cur_beams[j].print(); fflush(stdout); } + printf("+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ \n"); #endif + + // check if done + for (int h = 0; h < beam_hypos.size(); ++h) { + if (requests_done[h]) { + continue; + } + if (beam_hypos[h].is_done()) { + requests_done[h] = true; + } + } + auto const done_or_not = [](const bool& flag) { return flag; }; + if (std::all_of(requests_done.begin(), requests_done.end(), done_or_not)) { + break; + } } - const beam& top_b = top_beam(); + const beam& top_b = finalize(); -#if 0 // DEBUG: print final beam result - printf("\n\nFinal beam:\n"); - top_b.print(); +#ifdef NE_BEAM_SEARCH_VERBOSE_ON // DEBUG: print final beam result + printf("========================================================================================= \n"); + printf("final beam:\n"); + top_b.print(); + printf("+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ \n"); + printf("========================================================================================= \n"); #endif beam_search_response.clear(); diff --git a/intel_extension_for_transformers/llm/runtime/graph/models/model_utils/model_utils.h b/intel_extension_for_transformers/llm/runtime/graph/models/model_utils/model_utils.h index f7456bc0593..1fb02cfae1e 100644 --- a/intel_extension_for_transformers/llm/runtime/graph/models/model_utils/model_utils.h +++ b/intel_extension_for_transformers/llm/runtime/graph/models/model_utils/model_utils.h @@ -17,6 +17,7 @@ #include #include #include +#include #include "application/common.h" #include "models/model_utils/model_config.h" @@ -259,6 +260,14 @@ MODEL_API const char* model_print_system_info(void); #endif /* beam search utils */ +#define NEG_INF -std::numeric_limits::max() + +typedef struct beam_next_token { + model_token id; // token id + float score; // score of the token + int beam_idx; // token in which beam (-1 means unknown) +} beam_next_token; + struct beam { const model_context* ctx = nullptr; std::vector token_ids; @@ -267,16 +276,77 @@ struct beam { // record inference batch indice int infer_bs_id; // end-of-text - const bool eos() const { return !token_ids.empty() && token_ids.back() == 50256; } // TODO ctx->vocab.eos_id + const bool eos() const { return !token_ids.empty() && token_ids.back() == ctx->vocab.eos_token_id; } void print() const { - printf("score: %0.6f, eos: %d, tokens: ", score, eos()); + printf("length: %d, score: %12.6f, eos: %d, tokens:\n", token_ids.size(), score, eos()); for (const auto& id : token_ids) { - printf("%s", model_token_to_str(ctx, id)); + printf("%d: %s, ", id, model_token_to_str(ctx, id)); } printf("\n"); } }; +struct beam_hypotheses { + const model_context* const ctx = nullptr; + const int num_beams; + const float length_penalty = 1.0f; + const bool early_stopping = false; + std::vector 
beams; + + beam_hypotheses(model_context* lctx) + : ctx(lctx), + num_beams(lctx->beam_size), + length_penalty(lctx->generation_conf.length_penalty), + early_stopping(lctx->generation_conf.do_early_stopping) { + beams.reserve(lctx->beam_size); + } + + int len() { return beams.size(); } + + void add(beam b, const uint32_t& n_prompt_tokens) { + auto comp = [](const beam& a, const beam& b) { return a.score > b.score; }; + uint32_t cur_len = b.eos() ? b.token_ids.size() - 1 : b.token_ids.size(); + float score = b.score / std::pow(cur_len + n_prompt_tokens, length_penalty); +#ifdef NE_BEAM_SEARCH_VERBOSE_ON + printf("add beam hypos: \n"); + b.print(); + printf("origin_score: %12.6f, new_score: %12.6f, sentence_len: %d \n", b.score, score, cur_len + n_prompt_tokens); + printf("+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ \n"); +#endif + b.score = score; + if (beams.size() < num_beams) { + beams.push_back(std::move(b)); + if (beams.size() == num_beams) { + std::make_heap(beams.begin(), beams.end(), comp); + } + } else { + MODEL_ASSERT(beams.size() == num_beams); + if (beams.front().score > b.score) { + return; + } + std::pop_heap(beams.begin(), beams.end(), comp); + beams.back() = b; + std::push_heap(beams.begin(), beams.end(), comp); + } + } + + const bool is_done() const { + if (beams.size() < num_beams) { + return false; + } + // stop as soon as at least `num_beams` hypotheses are finished + if (early_stopping) { + return true; + } + return false; + } + + const beam& top1() const { + auto const by_score = [](beam const& a, beam const& b) { return a.score < b.score; }; + return *std::max_element(beams.begin(), beams.end(), by_score); + } +}; + struct logits_info; class logits_processor { @@ -307,7 +377,7 @@ class beam_search_kv_cache_reorder { ~beam_search_kv_cache_reorder() {} virtual void update(const uint32_t& n_past, const uint32_t& n_prompt_tokens, - const std::unordered_map& kv_reorder_indices = {}, + const std::vector>& kv_reorder_indices = {}, const std::vector& next_beams = {}); private: @@ -324,7 +394,7 @@ class beam_search_flow { explicit beam_search_flow(model_context* lctx) : ctx(lctx), beam_size(lctx->beam_size), lp(logits_processor(lctx)) { cur_beams.reserve(beam_size); next_beams.reserve(beam_size); - cur_beams.push_back({ctx, {}, 1.0f}); + cur_beams.push_back({ctx, {}, 0.0f}); } ~beam_search_flow() {} @@ -332,16 +402,23 @@ class beam_search_flow { std::vector loop(const model_token* tokens_inp, const int& n_tokens, const int& n_threads); private: - void fill_next_beams_by_top_probabilities(); - std::unordered_map update_kv_cache_reorder_indices(); - void beam_score_length_penalize(); - const beam& top_beam(); + std::vector beam_top_k_next_tokens(model_context* ctx, const uint32_t& cur_len, + const std::vector& beams_score, + const std::vector& num_beams, + const std::vector beam_indices, const int& sample_scale = 2, + const int& dim = -1); + void fill_next_beams_by_top_scores(); + std::vector> update_kv_cache_reorder_indices(); + const beam& finalize(); model_context* ctx = nullptr; const int beam_size; std::vector cur_beams; std::vector next_beams; - size_t n_past = 0; + std::vector beam_hypos; + std::vector requests_done; + uint32_t n_past = 0; + uint32_t n_prompt_tokens = 0; int num_threads = 4; // default by 4 logits_processor lp; std::shared_ptr kv_reorder;
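// ---------------------------------------------------------------------------
// Illustrative sketch (not part of the patch): the length-penalised score applied
// in beam_hypotheses::add() above. Dividing the cumulative log-probability by
// sentence_len^length_penalty (prompt plus generated tokens) makes beams of
// different lengths comparable; since the cumulative log-probability is negative,
// a positive length_penalty promotes longer sequences and a negative one
// encourages shorter ones, matching the generation_config documentation.
#include <cmath>
#include <cstddef>

inline float length_penalized_score(float sum_log_prob, std::size_t sentence_len, float length_penalty) {
  return sum_log_prob / std::pow(static_cast<float>(sentence_len), length_penalty);
}

// e.g. with length_penalty = 1.0:
//   beam A: sum_log_prob = -6.0 over 10 tokens -> -0.60
//   beam B: sum_log_prob = -5.0 over  6 tokens -> -0.83
// beam A ranks higher even though its raw cumulative score is lower.
// ---------------------------------------------------------------------------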