gemma.cc: 141 changes (77 additions & 64 deletions)
@@ -1033,76 +1033,89 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
   size_t pos_offset = 0;  // offset relative to pos
   const double prefill_start = hwy::platform::Now();
 
-  // Prefill stops before prompt_size - 1 since the last prompt token is the
-  // first input token for generation.
-  while (pos_offset < prompt_size - 1) {
-    const size_t batch_size =
-        std::min(kPrefillBatchSize, prompt_size - 1 - pos_offset);
-    HWY_DASSERT(batch_size <= kPrefillBatchSize);
-    HWY_DASSERT(pos_offset + batch_size <= prompt_size - 1);
-    const int* batch_tokens = prompt.data() + pos_offset;
-    Prefill<kPrefillBatchSize>(batch_tokens, batch_size, pos, weights,
-                               prefill_activations, kv_cache, pool, inner_pool);
-    for (size_t idx = 0; idx < batch_size; ++idx) {
-      stream_token(batch_tokens[idx], 0.0f);
-    }
-    pos += batch_size;
-    pos_offset += batch_size;
-  }
-
-  if (verbosity >= 2) {
-    // in the future this output should not occur in GenerateImpl but instead
-    // should be available as observable state for frontend code to handle I/O.
-    const double prefill_end = hwy::platform::Now();
-    const double prefill_tok_sec =
-        static_cast<double>(pos_offset) / (prefill_end - prefill_start);
-    std::cout << "\n[ Prefill tokens / sec = " << prefill_tok_sec << " ]";
-  }
-
-  const double gen_start = hwy::platform::Now();
-
-  HWY_DASSERT(pos_offset == prompt_size - 1);
+  auto prefill_phase = [&]() HWY_ATTR {
+    bool keep_on = true;
+    // Prefill stops before prompt_size - 1 since the last prompt token is the
+    // first input token for generation.
+    while (pos_offset < prompt_size - 1 && keep_on) {
+      const size_t batch_size =
+          std::min(kPrefillBatchSize, prompt_size - 1 - pos_offset);
+      HWY_DASSERT(batch_size <= kPrefillBatchSize);
+      HWY_DASSERT(pos_offset + batch_size <= prompt_size - 1);
+      const int* batch_tokens = prompt.data() + pos_offset;
+      Prefill<kPrefillBatchSize>(batch_tokens, batch_size, pos, weights,
+                                 prefill_activations, kv_cache, pool, inner_pool);
+      for (size_t idx = 0; idx < batch_size; ++idx) {
+        keep_on = stream_token(batch_tokens[idx], 0.0f);
+        if(!keep_on) {
+          break;
+        }
+      }
+      pos += batch_size;
+      pos_offset += batch_size;
+    }
+
+    if (verbosity >= 2) {
+      // in the future this output should not occur in GenerateImpl but instead
+      // should be available as observable state for frontend code to handle I/O.
+      const double prefill_end = hwy::platform::Now();
+      const double prefill_tok_sec =
+          static_cast<double>(pos_offset) / (prefill_end - prefill_start);
+      std::cout << "\n[ Prefill tokens / sec = " << prefill_tok_sec << " ]";
+    }
+    return keep_on;
+  };
 
-  size_t pos_gen_start = pos_offset;
-  int token = prompt.at(pos_offset);
-  stream_token(token, 0);
-  for (size_t generate_pos = 0;
-       pos < max_tokens && generate_pos < max_generated_tokens;
-       ++pos, ++pos_offset, ++generate_pos) {
-    Transformer(token, pos, weights, activations, kv_cache, pool, inner_pool);
-    float* final_activation = activations.x.data();
-    // The condition below is always true if we are doing Prefill above.
-    // We keep it here for clarity so that the code is correct even if Prefill
-    // is disabled.
-    if (pos_offset >= prompt_size - 1) {
-      PROFILER_ZONE("Gen.Embedding");
-      // Generation phase
-      MatVec<kVocabSize, TConfig::kModelDim>(weights.embedder_input_embedding,
-                                             0, final_activation,
-                                             activations.logits.data(), pool);
-      // Barrier: must have all logits so we can subtract max.
-      Softmax(activations.logits.data(), kVocabSize);
-      token = SampleTopK<TConfig::kTopK>(activations.logits.data(), kVocabSize,
-                                         gen, temperature, accept_token);
-      if (!stream_token(token, activations.logits[token])) {
-        token = EOS_ID;
-      }
-    } else {
-      // We would take this branch if we were not doing Prefill but would
-      // process the tokens of the prompt one at a time.
-      token = prompt.at(pos_offset + 1);
-      stream_token(token, 0);
-    }
-    if (token == EOS_ID) {
-      if (verbosity >= 2) {
-        const double gen_end = hwy::platform::Now();
-        const double gen_tok_sec =
-            static_cast<double>(pos_offset - pos_gen_start) /
-            (gen_end - gen_start);
-        std::cout << "\n[ Generation tokens / sec = " << gen_tok_sec << " ]\n";
-      }
-      break;
-    }
-  }
+  auto transform_phase = [&]() HWY_ATTR {
+    const double gen_start = hwy::platform::Now();
+
+    HWY_DASSERT(pos_offset == prompt_size - 1);
+
+    size_t pos_gen_start = pos_offset;
+    int token = prompt.at(pos_offset);
+    stream_token(token, 0);
+    for (size_t generate_pos = 0;
+         pos < max_tokens && generate_pos < max_generated_tokens;
+         ++pos, ++pos_offset, ++generate_pos) {
+      Transformer(token, pos, weights, activations, kv_cache, pool, inner_pool);
+      float* final_activation = activations.x.data();
+      // The condition below is always true if we are doing Prefill above.
+      // We keep it here for clarity so that the code is correct even if Prefill
+      // is disabled.
+      if (pos_offset >= prompt_size - 1) {
+        PROFILER_ZONE("Gen.Embedding");
+        // Generation phase
+        MatVec<kVocabSize, TConfig::kModelDim>(weights.embedder_input_embedding,
+                                               0, final_activation,
+                                               activations.logits.data(), pool);
+        // Barrier: must have all logits so we can subtract max.
+        Softmax(activations.logits.data(), kVocabSize);
+        token = SampleTopK<TConfig::kTopK>(activations.logits.data(), kVocabSize,
+                                           gen, temperature, accept_token);
+        if (!stream_token(token, activations.logits[token])) {
+          token = EOS_ID;
+        }
+      } else {
+        // We would take this branch if we were not doing Prefill but would
+        // process the tokens of the prompt one at a time.
+        token = prompt.at(pos_offset + 1);
+        stream_token(token, 0);
+      }
+      if (token == EOS_ID) {
+        if (verbosity >= 2) {
+          const double gen_end = hwy::platform::Now();
+          const double gen_tok_sec =
+              static_cast<double>(pos_offset - pos_gen_start) /
+              (gen_end - gen_start);
+          std::cout << "\n[ Generation tokens / sec = " << gen_tok_sec << " ]\n";
+        }
+        break;
+      }
+    }
+  };
+
+  if(prefill_phase()) {
+    transform_phase();
+  }
 }
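
Note on the behavioral change: before this diff, only the generation loop honored a false return from stream_token; during prefill the return value was discarded. Wrapping the two phases in the prefill_phase and transform_phase lambdas lets an early stop requested during prefill short-circuit generation entirely, via if(prefill_phase()) { transform_phase(); }. Below is a minimal caller-side sketch of what this enables, assuming a callback signature of bool(int token, float probability) as implied by the stream_token calls above; RunGeneration is a hypothetical stand-in for the real caller of GenerateImpl.

#include <cstdio>
#include <functional>

// Callback type inferred from the stream_token(token, prob) calls in the
// diff; the concrete alias in gemma.cpp may differ.
using StreamFunc = std::function<bool(int token, float probability)>;

// Stub driver so the sketch compiles on its own; the real loop is
// GenerateImpl above. It honors a false return, as prefill now does too.
void RunGeneration(const StreamFunc& stream_token) {
  for (int t = 0; t < 100; ++t) {
    if (!stream_token(t, 0.0f)) break;
  }
}

int main() {
  int budget = 32;  // stop streaming after 32 tokens, prompt included
  const StreamFunc stream_token = [&](int token, float probability) {
    std::printf("token=%d p=%.3f\n", token, probability);
    // Before this change, returning false during prefill was ignored;
    // now it ends prefill and skips transform_phase.
    return --budget > 0;
  };
  RunGeneration(stream_token);
  return 0;
}

One detail worth flagging in review: the stream_token(token, 0) call at the top of transform_phase still ignores its return value, so a stop requested at that exact token takes effect one step later.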

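For context on the sampling step in transform_phase: Softmax turns the vocabulary logits into a probability distribution (subtracting the max logit first for numerical stability, which is why the comment calls it a barrier: every logit must exist before the max is known), and SampleTopK then draws the next token from the k most probable entries. The following is a minimal self-contained sketch of that pipeline, not gemma.cc's implementation; it omits the temperature and accept_token parameters of the real SampleTopK.

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdio>
#include <numeric>
#include <random>
#include <vector>

// Normalize logits into probabilities, subtracting the max for numerical
// stability; this step needs every logit, hence the "barrier" in the diff.
void SoftmaxSketch(std::vector<float>& logits) {
  const float max_logit = *std::max_element(logits.begin(), logits.end());
  float sum = 0.0f;
  for (float& v : logits) {
    v = std::exp(v - max_logit);
    sum += v;
  }
  for (float& v : logits) v /= sum;
}

// Draw a token id from the k most probable entries, proportionally to their
// probabilities within the top k.
int SampleTopKSketch(const std::vector<float>& probs, size_t k,
                     std::mt19937& gen) {
  std::vector<int> idx(probs.size());
  std::iota(idx.begin(), idx.end(), 0);
  // Indices of the k largest probabilities, in descending order.
  std::partial_sort(idx.begin(), idx.begin() + k, idx.end(),
                    [&](int a, int b) { return probs[a] > probs[b]; });
  std::vector<float> top(k);
  for (size_t i = 0; i < k; ++i) top[i] = probs[idx[i]];
  std::discrete_distribution<int> dist(top.begin(), top.end());
  return idx[dist(gen)];
}

int main() {
  std::vector<float> logits = {2.0f, 1.0f, 0.5f, -1.0f, 3.0f};
  SoftmaxSketch(logits);
  std::mt19937 gen(42);
  const int token = SampleTopKSketch(logits, /*k=*/2, gen);
  std::printf("sampled token id: %d\n", token);
  return 0;
}

Restricting sampling to the top k truncates the long tail of unlikely tokens, trading a little diversity for more robust output.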