[Cpp Graph] Align Cpp Beam Search (#322)
zhentaoyu committed Sep 27, 2023
1 parent d36e30a commit 6ea8250
Showing 6 changed files with 365 additions and 158 deletions.
@@ -83,6 +83,10 @@ option(NE_PROFILING "neural_engine: use Profiling"
 if (NE_PROFILING)
 add_compile_definitions(NE_PERF)
 endif()
+option(NE_BEAM_SEARCH_VERBOSE "neural_engine: print beam search processing log" OFF)
+if (NE_BEAM_SEARCH_VERBOSE)
+add_compile_definitions(NE_BEAM_SEARCH_VERBOSE_ON)
+endif()
 option(NE_GELU_VEC "neural_engine: enable vec in gelu" ON)
 if (NE_GELU_VEC)
 add_compile_definitions(NE_GELU_USE_VEC)
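Note: the new NE_BEAM_SEARCH_VERBOSE option mirrors the NE_PROFILING pattern above; configuring with -DNE_BEAM_SEARCH_VERBOSE=ON makes the build define NE_BEAM_SEARCH_VERBOSE_ON. A minimal sketch of how such a compile definition is typically consumed follows; the NE_BS_LOG helper is hypothetical and not part of this commit.

// Hypothetical sketch (not from this commit): gating debug output on the
// NE_BEAM_SEARCH_VERBOSE_ON compile definition added above.
#include <cstdio>

#ifdef NE_BEAM_SEARCH_VERBOSE_ON
#define NE_BS_LOG(...) std::fprintf(stderr, __VA_ARGS__)
#else
#define NE_BS_LOG(...) ((void)0)
#endif

int main() {
  // Prints only when the project is configured with -DNE_BEAM_SEARCH_VERBOSE=ON.
  NE_BS_LOG("beam %d: cumulative log-prob %.4f\n", 0, -1.2345f);
  return 0;
}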
@@ -50,7 +50,8 @@ bool gptj_model_eval_ids(model_context* ctx, model_token* tokens, size_t n_eval,
 extern "C" {
 void* init_gptj(int seed, int n_predict, int n_batch, int top_k, float top_p, float temp, float repeat_penalty,
 bool perplexity, int n_ctx, const char* model_file, bool beam_search = false, int beam_size = 4,
-int batch_size = 1, int n_threads = 56, int min_new_tokens = 0, float length_penalty = 1.0) {
+int batch_size = 1, int n_threads = 56, int min_new_tokens = 0, float length_penalty = 1.0,
+bool do_early_stopping = false) {
 gpt_params params;
 params.n_threads = n_threads;
 params.seed = seed;
@@ -68,6 +69,7 @@ void* init_gptj(int seed, int n_predict, int n_batch, int top_k, float top_p, fl
 params.batch_size = batch_size;
 params.beam_search = beam_search;
 params.beam_size = beam_size;
+params.memory_type = KV_MEM_TYPE_F16; // TODO MEMORY_AUTO for MHA
 // params.use_mmap = false;
 // params.use_mlock= true;
 model_init_backend();
@@ -80,6 +82,7 @@ void* init_gptj(int seed, int n_predict, int n_batch, int top_k, float top_p, fl
 }
 ctx->generation_conf.min_new_tokens = min_new_tokens;
 ctx->generation_conf.length_penalty = length_penalty;
+ctx->generation_conf.do_early_stopping = do_early_stopping;
 return (void*)ctx;
 }

@@ -220,13 +223,17 @@ int main(int argc, char* argv[]) {
 return 1;
 }

-auto gptj_in_all_bs = init_gptj(1234, 32, 32, 40, 1.0, 0.8, 1.02, false, 2048, argv[1], true, 4, 1, 56, 30, 1.0);
+auto gptj_in_all_bs =
+init_gptj(1234, 32, 32, 40, 1.0, 0.8, 1.02, false, 2048, argv[1], true, 4, 1, 56, 30, 1.0, true);
 std::vector<void*> ctxs = {gptj_in_all_bs};
 for (auto gptj_in_all : ctxs) {
 auto res = eval_gptj_char(
 gptj_in_all,
-//"she opened the door and see",
+// "she opened the door and see",
 // "Once upon a time",
+// "Tell me 10 things about jazz music",
+// "A spaceship lands on the moon",
+// "What is the meaning of life?",
 "2017: It is done, and submitted. You can play 'Survival of the Tastiest' on Android, and on the web. Playing "
 "on the web works, but you have to simulate multiple touch for table moving and that can be a bit confusing. "
 "There is a lot I'd like to talk about. I will go through every topic, insted of making the typical what went "
@@ -225,14 +225,14 @@ static bool gptj_model_eval_internal(model_context& lctx, const model_token* tok
 std::vector<ne_tensor*> v_bs(batch_size);
 for (int i = 0; i < batch_size; ++i) {
 if (run_mha_fp16) {
-// batch K
+// batch V
 Vcur_bs[i] = ne_view_4d(ctx0, Vcur, n_embd / n_head, n_head, N, 1, ne_element_size(Vcur) * n_embd / n_head,
 ne_element_size(Vcur) * n_embd, ne_element_size(Vcur) * n_embd * N,
 i * ne_element_size(Vcur) * n_embd * N);
 v_bs[i] = ne_view_1d(ctx0, kv_self.v, n_embd * N * 1,
 (ne_element_size(kv_self.v) * n_embd) * (il * n_ctx * kv_n_ctx_block + n_past) +
 i * n_ctx * n_embd * ne_element_size(kv_self.v));
-// batch V
+// batch K
 Kcur_bs[i] = ne_permute(ctx0,
 ne_reshape_4d(ctx0,
 ne_view_2d(ctx0, Kcur, n_embd, N, ne_element_size(Kcur) * n_embd,
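The only change in this hunk swaps the two labels so they match the tensors they sit above: the first block builds views into the V cache, the second into the K cache. As a reading aid, the sketch below restates the byte offset computed inside the ne_view_1d call for batch i; the helper function and its name are illustrative, not code from the source tree.

// Illustrative only (not from the source tree): the byte offset that the
// ne_view_1d call above computes for the V-cache slice of batch i, given the
// layer index il and the n_past tokens already cached.
#include <cstddef>

inline std::size_t v_cache_byte_offset(std::size_t elem_size, std::size_t n_embd,
                                       std::size_t n_ctx, std::size_t kv_n_ctx_block,
                                       std::size_t il, std::size_t n_past,
                                       std::size_t i) {
  // Mirrors the expression in the hunk: skip il layer blocks of
  // kv_n_ctx_block * n_ctx rows (n_embd elements each), advance past the
  // n_past rows already filled, then step to batch i's own n_ctx-row slot.
  return elem_size * n_embd * (il * n_ctx * kv_n_ctx_block + n_past) +
         i * n_ctx * n_embd * elem_size;
}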
@@ -230,7 +230,7 @@ struct generation_config {
 // likelihood of the sequence (i.e. negative), `length_penalty` > 0.0 promotes longer sequences, while
 // `length_penalty` < 0.0 encourages shorter sequences. (default = 1.0)
 float length_penalty = 1.0f;
-bool do_early_stopping = false; // TODO
+bool do_early_stopping = false;
 };

 class beam_search_kv_cache_reorder; // forward declaration
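The comment kept above describes the usual beam-score normalization: beam scores are summed log-probabilities (negative), so dividing by length^length_penalty makes longer hypotheses relatively more attractive when length_penalty > 0, and do_early_stopping, in the usual convention, stops the search once beam_size finished hypotheses exist. The numeric sketch below assumes the common score = sum_log_prob / length^length_penalty formula; that formula is an assumption for illustration, not code taken from this repository.

// Illustration of the length-penalty convention described in the comment above.
// Assumes score = sum_log_prob / pow(length, length_penalty); this is the common
// convention and an assumption here, not code from this commit.
#include <cmath>
#include <cstdio>
#include <initializer_list>

int main() {
  const float sum_log_prob_short = -6.0f;  // 6-token hypothesis
  const float sum_log_prob_long = -9.0f;   // 12-token hypothesis
  for (float length_penalty : {0.0f, 1.0f, 2.0f}) {
    float score_short = sum_log_prob_short / std::pow(6.0f, length_penalty);
    float score_long = sum_log_prob_long / std::pow(12.0f, length_penalty);
    // With length_penalty = 0 the shorter hypothesis wins; with positive values
    // the longer one is increasingly favored, as the struct comment states.
    std::printf("length_penalty=%.1f  short=%.4f  long=%.4f\n", length_penalty,
                score_short, score_long);
  }
  return 0;
}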