From bbd92a8ae0295de4750f56ba753780d73da4d2bb Mon Sep 17 00:00:00 2001
From: Charles Chan
Date: Sat, 20 Apr 2024 15:32:45 +0800
Subject: [PATCH 1/2] Use lambda to split function

---
 gemma.cc | 141 ++++++++++++++++++++++++++++++-------------------------
 1 file changed, 77 insertions(+), 64 deletions(-)

diff --git a/gemma.cc b/gemma.cc
index 80f7c176..0b53cd1e 100644
--- a/gemma.cc
+++ b/gemma.cc
@@ -1033,76 +1033,89 @@ void GenerateImpl(GemmaImpl& gemma, size_t max_tokens,
   size_t pos_offset = 0;  // offset relative to pos

   const double prefill_start = hwy::platform::Now();
-  // Prefill stops before prompt_size - 1 since the last prompt token is the
-  // first input token for generation.
-  while (pos_offset < prompt_size - 1) {
-    const size_t batch_size =
-        std::min(kPrefillBatchSize, prompt_size - 1 - pos_offset);
-    HWY_DASSERT(batch_size <= kPrefillBatchSize);
-    HWY_DASSERT(pos_offset + batch_size <= prompt_size - 1);
-    const int* batch_tokens = prompt.data() + pos_offset;
-    Prefill(batch_tokens, batch_size, pos, weights,
-            prefill_activations, kv_cache, pool, inner_pool);
-    for (size_t idx = 0; idx < batch_size; ++idx) {
-      stream_token(batch_tokens[idx], 0.0f);
+  auto prefill_phase = [&]() {
+    bool keep_on = true;
+    // Prefill stops before prompt_size - 1 since the last prompt token is the
+    // first input token for generation.
+    while (pos_offset < prompt_size - 1 && keep_on) {
+      const size_t batch_size =
+          std::min(kPrefillBatchSize, prompt_size - 1 - pos_offset);
+      HWY_DASSERT(batch_size <= kPrefillBatchSize);
+      HWY_DASSERT(pos_offset + batch_size <= prompt_size - 1);
+      const int* batch_tokens = prompt.data() + pos_offset;
+      Prefill(batch_tokens, batch_size, pos, weights,
+              prefill_activations, kv_cache, pool, inner_pool);
+      for (size_t idx = 0; idx < batch_size; ++idx) {
+        keep_on = stream_token(batch_tokens[idx], 0.0f);
+        if (!keep_on) {
+          break;
+        }
+      }
+      pos += batch_size;
+      pos_offset += batch_size;
     }
-    pos += batch_size;
-    pos_offset += batch_size;
-  }
-
-  if (verbosity >= 2) {
-    // in the future this output should not occur in GenerateImpl but instead
-    // should be available as observable state for frontend code to handle I/O.
-    const double prefill_end = hwy::platform::Now();
-    const double prefill_tok_sec =
-        static_cast<double>(pos_offset) / (prefill_end - prefill_start);
-    std::cout << "\n[ Prefill tokens / sec = " << prefill_tok_sec << " ]";
-  }
-  const double gen_start = hwy::platform::Now();
-
-  HWY_DASSERT(pos_offset == prompt_size - 1);
+    if (verbosity >= 2) {
+      // in the future this output should not occur in GenerateImpl but instead
+      // should be available as observable state for frontend code to handle I/O.
+      const double prefill_end = hwy::platform::Now();
+      const double prefill_tok_sec =
+          static_cast<double>(pos_offset) / (prefill_end - prefill_start);
+      std::cout << "\n[ Prefill tokens / sec = " << prefill_tok_sec << " ]";
+    }
+    return keep_on;
+  };

-  size_t pos_gen_start = pos_offset;
-  int token = prompt.at(pos_offset);
-  stream_token(token, 0);
-  for (size_t generate_pos = 0;
-       pos < max_tokens && generate_pos < max_generated_tokens;
-       ++pos, ++pos_offset, ++generate_pos) {
-    Transformer(token, pos, weights, activations, kv_cache, pool, inner_pool);
-    float* final_activation = activations.x.data();
-    // The condition below is always true if we are doing Prefill above.
-    // We keep it here for clarity so that the code is correct even if Prefill
-    // is disabled.
-    if (pos_offset >= prompt_size - 1) {
-      PROFILER_ZONE("Gen.Embedding");
-      // Generation phase
-      MatVec(weights.embedder_input_embedding,
-             0, final_activation,
-             activations.logits.data(), pool);
-      // Barrier: must have all logits so we can subtract max.
-      Softmax(activations.logits.data(), kVocabSize);
-      token = SampleTopK(activations.logits.data(), kVocabSize,
-                         gen, temperature, accept_token);
-      if (!stream_token(token, activations.logits[token])) {
-        token = EOS_ID;
+  auto transform_phase = [&]() {
+    const double gen_start = hwy::platform::Now();
+
+    HWY_DASSERT(pos_offset == prompt_size - 1);
+
+    size_t pos_gen_start = pos_offset;
+    int token = prompt.at(pos_offset);
+    stream_token(token, 0);
+    for (size_t generate_pos = 0;
+         pos < max_tokens && generate_pos < max_generated_tokens;
+         ++pos, ++pos_offset, ++generate_pos) {
+      Transformer(token, pos, weights, activations, kv_cache, pool, inner_pool);
+      float* final_activation = activations.x.data();
+      // The condition below is always true if we are doing Prefill above.
+      // We keep it here for clarity so that the code is correct even if Prefill
+      // is disabled.
+      if (pos_offset >= prompt_size - 1) {
+        PROFILER_ZONE("Gen.Embedding");
+        // Generation phase
+        MatVec(weights.embedder_input_embedding,
+               0, final_activation,
+               activations.logits.data(), pool);
+        // Barrier: must have all logits so we can subtract max.
+        Softmax(activations.logits.data(), kVocabSize);
+        token = SampleTopK(activations.logits.data(), kVocabSize,
+                           gen, temperature, accept_token);
+        if (!stream_token(token, activations.logits[token])) {
+          token = EOS_ID;
+        }
+      } else {
+        // We would take this branch if we were not doing Prefill but would
+        // process the tokens of the prompt one at a time.
+        token = prompt.at(pos_offset + 1);
+        stream_token(token, 0);
       }
-    } else {
-      // We would take this branch if we were not doing Prefill but would
-      // process the tokens of the prompt one at a time.
-      token = prompt.at(pos_offset + 1);
-      stream_token(token, 0);
-    }
-    if (token == EOS_ID) {
-      if (verbosity >= 2) {
-        const double gen_end = hwy::platform::Now();
-        const double gen_tok_sec =
-            static_cast<double>(pos_offset - pos_gen_start) /
-            (gen_end - gen_start);
-        std::cout << "\n[ Generation tokens / sec = " << gen_tok_sec << " ]\n";
+      if (token == EOS_ID) {
+        if (verbosity >= 2) {
+          const double gen_end = hwy::platform::Now();
+          const double gen_tok_sec =
+              static_cast<double>(pos_offset - pos_gen_start) /
+              (gen_end - gen_start);
+          std::cout << "\n[ Generation tokens / sec = " << gen_tok_sec << " ]\n";
+        }
+        break;
       }
-      break;
     }
+  };
+
+  if (prefill_phase()) {
+    transform_phase();
   }
 }

From 62f69fe837c074b667bc880d9e5e3a131bdd5418 Mon Sep 17 00:00:00 2001
From: Charles Chan
Date: Mon, 22 Apr 2024 17:44:16 +0800
Subject: [PATCH 2/2] add HWY_ATTR for Windows

---
 gemma.cc | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/gemma.cc b/gemma.cc
index 0b53cd1e..47ad6430 100644
--- a/gemma.cc
+++ b/gemma.cc
@@ -1033,7 +1033,7 @@ void GenerateImpl(GemmaImpl& gemma, size_t max_tokens,
   size_t pos_offset = 0;  // offset relative to pos

   const double prefill_start = hwy::platform::Now();
-  auto prefill_phase = [&]() {
+  auto prefill_phase = [&]() HWY_ATTR {
     bool keep_on = true;
     // Prefill stops before prompt_size - 1 since the last prompt token is the
     // first input token for generation.
@@ -1066,7 +1066,7 @@ void GenerateImpl(GemmaImpl& gemma, size_t max_tokens,
     return keep_on;
   };

-  auto transform_phase = [&]() {
+  auto transform_phase = [&]() HWY_ATTR {
     const double gen_start = hwy::platform::Now();

     HWY_DASSERT(pos_offset == prompt_size - 1);
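
The refactoring in PATCH 1/2 boils down to a simple pattern: wrap each phase of the long
GenerateImpl body in a lambda that captures the surrounding locals by reference, let the first
lambda report through its bool return value whether stream_token asked to stop, and run the
second lambda only if it did not. The standalone sketch below shows just that control flow;
RunPipeline, StreamToken and the integer "tokens" are illustrative stand-ins, not code from
gemma.cc.

    #include <cstddef>
    #include <cstdio>
    #include <vector>

    // Stand-in for stream_token: print the token, return false to stop early.
    static bool StreamToken(int token) {
      std::printf("%d ", token);
      return token != 99;  // pretend 99 means "caller wants to stop"
    }

    static void RunPipeline(const std::vector<int>& prompt) {
      std::size_t pos = 0;

      // Phase 1: stream all but the last prompt token; remember whether we
      // were told to stop (mirrors prefill_phase and its keep_on flag).
      auto prefill_phase = [&]() {
        bool keep_on = true;
        while (pos + 1 < prompt.size() && keep_on) {
          keep_on = StreamToken(prompt[pos]);
          ++pos;
        }
        return keep_on;
      };

      // Phase 2: "generate" from the last prompt token onward
      // (mirrors transform_phase).
      auto generate_phase = [&]() {
        for (; pos < prompt.size(); ++pos) {
          if (!StreamToken(prompt[pos])) break;
        }
      };

      // Generation only runs when prefill was not interrupted.
      if (prefill_phase()) {
        generate_phase();
      }
    }

    int main() {
      RunPipeline({1, 2, 3, 4, 5});
      std::printf("\n");
      return 0;
    }

PATCH 2/2 then adds Highway's HWY_ATTR to both lambdas; the commit message only says it is
needed for Windows, presumably because the lambda bodies contain per-target SIMD code and do
not inherit the enclosing function's target attributes on that toolchain, so HWY_ATTR has to
be spelled out on each lambda.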