examples/perplexity/perplexity.cpp (22 changes: 15 additions & 7 deletions)
@@ -325,6 +325,13 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
     double nll = 0.0;
     double nll2 = 0.0;
 
+    const int num_batches = (n_ctx + n_batch - 1) / n_batch;
+
+    std::vector<float> logits;
+    if (num_batches > 1) {
+        logits.reserve((size_t)n_ctx * n_vocab);
+    }
+
     fprintf(stderr, "%s: calculating perplexity over %d chunks, batch_size=%d\n", __func__, n_chunk, n_batch);
 
     std::vector<std::thread> workers(std::thread::hardware_concurrency() - 1);
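
Note on the hoisted buffer: `num_batches` and `logits` now live outside the per-chunk loop, and the `reserve` sizes the vector for a full chunk (`n_ctx * n_vocab` floats) up front, so the per-batch `insert` calls later in the loop never reallocate. A minimal standalone sketch of the pattern (all sizes below are hypothetical placeholders, not values from the PR):

    #include <cstddef>
    #include <cstdio>
    #include <vector>

    int main() {
        const int n_ctx   = 512;    // hypothetical chunk length
        const int n_vocab = 32000;  // hypothetical vocabulary size
        const int n_batch = 128;    // hypothetical batch size
        const int num_batches = (n_ctx + n_batch - 1) / n_batch;

        std::vector<float> logits;
        logits.reserve((size_t)n_ctx * n_vocab);  // one allocation for the whole chunk

        // stand-in for the rows a decode call would produce
        std::vector<float> batch((size_t)n_batch * n_vocab, 0.0f);
        for (int i = 0; i < num_batches; ++i) {
            // capacity was reserved above, so none of these inserts reallocates
            logits.insert(logits.end(), batch.begin(), batch.end());
        }
        std::printf("rows: %zu of %d\n", logits.size() / n_vocab, n_ctx);
        return 0;
    }
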
@@ -333,10 +340,6 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
         const int start = i * n_ctx;
         const int end = start + n_ctx;
 
-        const int num_batches = (n_ctx + n_batch - 1) / n_batch;
-
-        std::vector<float> logits;
-
         const auto t_start = std::chrono::high_resolution_clock::now();
 
         // clear the KV cache
@@ -362,8 +365,10 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
             // restore the original token in case it was set to BOS
             tokens[batch_start] = token_org;
 
-            const auto * batch_logits = llama_get_logits(ctx);
-            logits.insert(logits.end(), batch_logits, batch_logits + batch_size * n_vocab);
+            if (num_batches > 1) {
+                const auto * batch_logits = llama_get_logits(ctx);
+                logits.insert(logits.end(), batch_logits, batch_logits + batch_size * n_vocab);
+            }
         }
 
         const auto t_end = std::chrono::high_resolution_clock::now();
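
Note on the conditional copy: when a chunk fits in a single batch (n_ctx <= n_batch), the rows returned by `llama_get_logits` are still valid at scoring time, so copying them into `logits` is pure overhead and is now skipped. A hedged sketch of the idea, with `get_logits` as an illustrative stub for the context-owned buffer (not the library API):

    #include <cstddef>
    #include <vector>

    // illustrative stub: returns a pointer into a buffer the "context" owns,
    // overwritten on every decode (mimics the lifetime of llama_get_logits)
    static std::vector<float> ctx_buffer;
    static const float * get_logits() { return ctx_buffer.data(); }

    int main() {
        const int n_ctx = 512, n_batch = 512, n_vocab = 32000;  // hypothetical sizes
        const int num_batches = (n_ctx + n_batch - 1) / n_batch;
        ctx_buffer.assign((size_t)n_batch * n_vocab, 0.0f);      // "decode" one batch

        std::vector<float> logits;
        if (num_batches > 1) {
            // multi-batch: each decode overwrites the context buffer, so the
            // rows must be copied out batch by batch, as the PR does
            logits.reserve((size_t)n_ctx * n_vocab);
        }
        // single-batch: no copy was made; read the context buffer in place
        const float * all_logits = num_batches > 1 ? logits.data() : get_logits();
        (void)all_logits;  // scoring would read rows from here
        return 0;
    }
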
@@ -392,7 +397,8 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
         // last 256 tokens. Then, we split the input up into context window size chunks to
         // process the entire prompt.
         const int first = n_ctx/2;
-        process_logits(n_vocab, logits.data() + first*n_vocab, tokens.data() + start + first, n_ctx - 1 - first,
+        const float * all_logits = num_batches > 1 ? logits.data() : llama_get_logits(ctx);
+        process_logits(n_vocab, all_logits + first*n_vocab, tokens.data() + start + first, n_ctx - 1 - first,
             workers, nll, nll2, logit_history.data() + start + first, prob_history.data() + start + first);
         count += n_ctx - first - 1;
 
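Note on the scoring window: with `first = n_ctx/2`, only the second half of each chunk is scored, because the logits row at position j parameterizes the prediction of token j+1; rows first .. n_ctx-2 therefore score tokens first+1 .. n_ctx-1, which is the `n_ctx - 1 - first` passed above. A toy sketch of that indexing (`nll_of` is a hypothetical stand-in for the softmax-based scoring in `process_logits`):

    #include <cstddef>
    #include <cstdio>

    // hypothetical scorer: negative log-probability that the logits row at
    // position j assigns to the observed next token (softmax omitted here)
    static double nll_of(const float * /*row*/, int /*next_token*/) { return 0.69; }

    int main() {
        const int n_ctx = 8, n_vocab = 4;   // toy sizes
        float logits[8 * 4] = {};           // one row of n_vocab logits per position
        int   tokens[8]     = {1, 2, 3, 0, 1, 2, 3, 0};

        const int first = n_ctx / 2;        // score the second half only
        double nll = 0.0;
        int count = 0;
        for (int j = first; j < n_ctx - 1; ++j) {
            // row j predicts token j+1
            nll += nll_of(logits + (size_t)j * n_vocab, tokens[j + 1]);
            ++count;
        }
        std::printf("scored %d tokens (n_ctx - 1 - first = %d)\n", count, n_ctx - 1 - first);
        return 0;
    }
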
@@ -406,6 +412,8 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
             printf("%8d %.4lf %4lf %4lf\n", i*n_ctx, std::exp(nll / count), av, av2);
         }
         fflush(stdout);
+
+        logits.clear();
     }
     printf("\n");
 
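Note on the running report: the printed perplexity is exp of the mean per-token NLL, and `av`/`av2` appear to be that mean and a standard-error-style spread derived from `nll2`, the running sum of squared NLLs. That is a hedged reading of the printf, not something the diff states; a sketch of those statistics with made-up running sums:

    #include <cmath>
    #include <cstdio>

    int main() {
        // assume, per scored token with NLL x:  nll += x;  nll2 += x * x;  ++count;
        double nll = 693.1, nll2 = 500.0;  // made-up running sums
        int count = 1000;

        const double av = nll / count;                              // mean NLL per token
        double var = nll2 / count - av * av;                        // variance of per-token NLL
        const double err = var > 0 ? std::sqrt(var / count) : 0.0;  // spread of the mean
        std::printf("PPL = %.4lf (mean NLL %.4lf +/- %.4lf)\n", std::exp(av), av, err);
        return 0;
    }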