259 changes: 148 additions & 111 deletions examples/perplexity/perplexity.cpp
@@ -470,7 +470,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
prompt_lines.push_back(line);
}

if( prompt_lines.size() % 6 != 0) {
if (prompt_lines.size() % 6 != 0) {
fprintf(stderr, "%s : number of lines in prompt not a multiple of 6.\n", __func__);
return;
}
@@ -485,7 +485,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx));

// Number of tasks to use when computing the score
if ( params.hellaswag_tasks < hs_task_count ) {
if (params.hellaswag_tasks < hs_task_count) {
hs_task_count = params.hellaswag_tasks;
}

@@ -502,178 +502,215 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
std::string ending[4];
size_t ending_logprob_count[4];
double ending_logprob[4];

size_t i_batch; // starting index in the llama_batch
size_t common_prefix; // max number of initial tokens that are the same in all sentences
size_t required_tokens; // needed number of tokens to evaluate all 4 endings
std::vector<llama_token> seq_tokens[4];
};

fprintf(stderr, "%s : selecting %zu %s tasks.\n", __func__, hs_task_count, (randomize_tasks?"randomized":"the first") );

// Select and read data from prompt lines
hs_data_t *hs_data = new hs_data_t[hs_task_count];
for (size_t i=0; i < hs_task_count; i++) {
std::vector<hs_data_t> hs_data(hs_task_count);
for (size_t i = 0; i < hs_task_count; i++) {
size_t idx = i;

auto & hs_cur = hs_data[i];

// Select a random example of those left in the prompt
if (randomize_tasks) {
std::uniform_int_distribution<size_t> dist(0, prompt_lines.size()/6-1 ) ;
idx = dist(rng);
}

hs_data[i].context = prompt_lines[idx*6];
hs_data[i].gold_ending_idx = std::stoi( prompt_lines[idx*6+1] );
for (size_t j=0; j < 4; j++) {
hs_data[i].ending[j] = prompt_lines[idx*6+2+j];
hs_cur.context = prompt_lines[idx*6];
hs_cur.gold_ending_idx = std::stoi( prompt_lines[idx*6+1] );
for (size_t j = 0; j < 4; j++) {
hs_cur.ending[j] = prompt_lines[idx*6+2+j];
hs_cur.seq_tokens[j] = ::llama_tokenize(ctx, hs_cur.context + " " + hs_cur.ending[j], add_bos);
}
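// note: each ending is tokenized together with the context (rather than concatenating
// separately tokenized pieces), so the tokens at the context/ending boundary agree across
// all 4 sequences and the common-prefix detection below operates on the sequences that
// are actually evaluated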

// determine the common prefix of the endings
hs_cur.common_prefix = 0;
hs_cur.required_tokens = 0;
for (size_t k = 0; k < hs_cur.seq_tokens[0].size(); k++) {
if (hs_cur.seq_tokens[0][k] != hs_cur.seq_tokens[1][k] ||
hs_cur.seq_tokens[0][k] != hs_cur.seq_tokens[2][k] ||
hs_cur.seq_tokens[0][k] != hs_cur.seq_tokens[3][k]) {
break;
}
hs_cur.common_prefix++;
}
hs_cur.required_tokens = hs_cur.common_prefix +
hs_cur.seq_tokens[0].size() - hs_cur.common_prefix +
hs_cur.seq_tokens[1].size() - hs_cur.common_prefix +
hs_cur.seq_tokens[2].size() - hs_cur.common_prefix +
hs_cur.seq_tokens[3].size() - hs_cur.common_prefix;
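// in other words, required_tokens is the combined length of the 4 tokenized sequences minus
// 3 copies of the shared prefix; with an illustrative 12-token prefix and sequence lengths of
// 20, 18, 22 and 19 tokens, the task needs 12 + 8 + 6 + 10 + 7 = 43 slots in the batch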

//GGML_ASSERT(hs_cur.common_prefix >= ::llama_tokenize(ctx, hs_cur.context, add_bos).size());

// Delete the selected random example from the prompt
if (randomize_tasks) {
prompt_lines.erase( std::next(prompt_lines.begin(),idx*6) , std::next(prompt_lines.begin(),idx*6+6) );
}
}

fprintf(stderr, "%s : calculating hellaswag score over selected tasks.\n", __func__);

printf("\ntask\tacc_norm\n");

double acc = 0.0f;

const int n_vocab = llama_n_vocab(llama_get_model(ctx));
const int n_ctx = llama_n_ctx(ctx);
const int n_ctx = llama_n_ctx(ctx);
const int n_batch = params.n_batch;

std::vector<std::vector<int>> ending_tokens(4);
const int max_tasks_per_batch = params.n_parallel;
const int max_seq = 4*max_tasks_per_batch;

std::vector<float> tok_logits(n_vocab);
llama_batch batch = llama_batch_init(n_ctx, 0, max_seq);

for (size_t task_idx = 0; task_idx < hs_task_count; task_idx++) {
// Tokenize the context to count tokens
std::vector<int> context_embd = ::llama_tokenize(ctx, hs_data[task_idx].context, add_bos);
size_t context_size = context_embd.size();

for (int i = 0; i < 4; ++i) {
ending_tokens[i] = ::llama_tokenize(ctx, hs_data[task_idx].context + " " + hs_data[task_idx].ending[i], add_bos);
for (int k = 0; k < int(context_size); ++k) {
if (ending_tokens[i][k] != context_embd[k]) {
fprintf(stderr, "Oops: ending %d of task %d differs from context at position %d\n",i,int(task_idx),k);
break;
}
std::vector<float> tok_logits(n_vocab);
std::vector<float> batch_logits(n_ctx*n_vocab);

auto decode_helper = [&](llama_context * ctx, llama_batch & batch, int32_t n_batch) {
for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) {
const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i));

llama_batch batch_view = {
n_tokens,
batch.token + i,
nullptr,
batch.pos + i,
batch.n_seq_id + i,
batch.seq_id + i,
batch.logits + i,
0, 0, 0, // unused
};

const int ret = llama_decode(ctx, batch_view);
if (ret != 0) {
LOG_TEE("failed to decode the batch, n_batch = %d, ret = %d\n", n_batch, ret);
return false;
}
}

// Do the 1st ending
// In this case we include the context when evaluating
//auto query_embd = ::llama_tokenize(ctx, hs_data[task_idx].context + hs_data[task_idx].ending[0], add_bos);
auto query_embd = ending_tokens[0];
auto query_size = query_embd.size();

// Stop if query wont fit the ctx window
if (query_size > (size_t)n_ctx) {
fprintf(stderr, "%s : number of tokens in query %zu > n_ctxl\n", __func__, query_size);
return;
memcpy(batch_logits.data() + i*n_vocab, llama_get_logits(ctx), n_tokens*n_vocab*sizeof(float));
}

// Speedup small evaluations by evaluating atleast 32 tokens
if (query_size < 32) {
query_embd.resize(32);
}
return true;
};
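// decode_helper (above) splits the filled batch into views of at most n_batch tokens, decodes
// each view with llama_decode and copies the returned logits into batch_logits at the view's
// token offset, so per-token logits can later be looked up by absolute position in the batch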

// clear the KV cache
llama_kv_cache_clear(ctx);
for (size_t i0 = 0; i0 < hs_task_count; i0++) {
int n_cur = 0;

auto logits = evaluate_tokens(ctx, query_embd, 0, params.n_batch, n_vocab);
if (logits.empty()) {
fprintf(stderr, "%s : failed to eval\n", __func__);
return;
}
size_t i1 = i0;
size_t i_batch = 0; // this tells us where in `llama_batch` we are currently

llama_batch_clear(batch);

std::memcpy(tok_logits.data(), logits.data() + (context_size-1)*n_vocab, n_vocab*sizeof(float));
const auto first_probs = softmax(tok_logits);
// batch as many tasks as possible into the available context
// each task has 4 unique sequence ids - one for each ending
// the common prefix is shared among the 4 sequences to save tokens
// we extract logits only from the last common token and from all ending tokens of each sequence
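// e.g. the first task packed into the batch uses sequence ids 0..3, the next one 4..7, and so
// on, up to max_seq sequences or until the required tokens no longer fit into n_ctx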
while (n_cur + (int) hs_data[i1].required_tokens <= n_ctx) {
auto & hs_cur = hs_data[i1];

hs_data[task_idx].ending_logprob_count[0] = 1;
hs_data[task_idx].ending_logprob[0] = std::log(first_probs[query_embd[context_size]]);
const int s0 = 4*(i1 - i0);
if (s0 + 4 > max_seq) {
break;
}

// Calculate the logprobs over the ending
for (size_t j = context_size; j < query_size - 1; j++) {
for (size_t i = 0; i < hs_cur.common_prefix; ++i) {
llama_batch_add(batch, hs_cur.seq_tokens[0][i], i, { s0 + 0, s0 + 1, s0 + 2, s0 + 3}, false);
}
batch.logits[batch.n_tokens - 1] = true; // we need logits for the last token of the common prefix

std::memcpy(tok_logits.data(), logits.data() + j*n_vocab, n_vocab*sizeof(float));
for (int s = 0; s < 4; ++s) {
for (size_t i = hs_cur.common_prefix; i < hs_cur.seq_tokens[s].size(); ++i) {
llama_batch_add(batch, hs_cur.seq_tokens[s][i], i, { s0 + s }, true);
}
}

const float prob = softmax(tok_logits)[query_embd[j + 1]];
hs_cur.i_batch = i_batch;
i_batch += hs_cur.required_tokens;

hs_data[task_idx].ending_logprob[0] += std::log(prob);
hs_data[task_idx].ending_logprob_count[0]++;
n_cur += hs_data[i1].required_tokens;
if (++i1 == hs_task_count) {
break;
}
}

// Calculate the mean token logprob for acc_norm
hs_data[task_idx].ending_logprob[0] /= hs_data[task_idx].ending_logprob_count[0];
if (i0 == i1) {
fprintf(stderr, "%s : task %zu does not fit in the context window\n", __func__, i0);
return;
}

// Do the remaining endings
// For these, we use the bare ending with n_past = context_size
//
for (size_t ending_idx = 1; ending_idx < 4; ending_idx++) {
llama_kv_cache_clear(ctx);

// Tokenize the query
query_embd.resize(ending_tokens[ending_idx].size() - context_size);
std::memcpy(query_embd.data(), ending_tokens[ending_idx].data() + context_size, query_embd.size()*sizeof(int));
query_size = query_embd.size();
// decode all tasks [i0, i1)
if (!decode_helper(ctx, batch, n_batch)) {
fprintf(stderr, "%s: llama_decode() failed\n", __func__);
return;
}

// Stop if query wont fit the ctx window
if (context_size + query_size > (size_t)n_ctx) {
fprintf(stderr, "%s : number of tokens in query %zu > n_ctxl\n", __func__, query_size);
return;
}
// compute the logprobs for each ending of the decoded tasks
for (size_t i = i0; i < i1; ++i) {
auto & hs_cur = hs_data[i];

// Speedup small evaluations by evaluating atleast 32 tokens
// No, resizing to 32 is actually slightly slower (at least on CUDA)
//if (query_size < 32) {
// query_embd.resize(32);
//}
std::memcpy(tok_logits.data(), batch_logits.data() + n_vocab*(hs_cur.i_batch + hs_cur.common_prefix - 1), n_vocab*sizeof(float));

// Evaluate the query
logits = evaluate_tokens(ctx, query_embd, context_size, params.n_batch, n_vocab);
if (logits.empty()) {
fprintf(stderr, "%s : failed to eval\n", __func__);
return;
}
const auto first_probs = softmax(tok_logits);

hs_data[task_idx].ending_logprob_count[ending_idx] = 1;
hs_data[task_idx].ending_logprob[ending_idx] = std::log(first_probs[query_embd[0]]);
size_t li = hs_cur.common_prefix; // logits index in the batch

// Calculate the logprobs over the ending
for (size_t j = 0; j < query_size - 1; j++) {
std::memcpy(tok_logits.data(), logits.data() + j*n_vocab, n_vocab*sizeof(float));
for (int s = 0; s < 4; ++s) {
hs_cur.ending_logprob_count[s] = 1;
hs_cur.ending_logprob[s] = std::log(first_probs[hs_cur.seq_tokens[s][hs_cur.common_prefix]]);

const float prob = softmax(tok_logits)[query_embd[j + 1]];
// Calculate the logprobs over the ending
for (size_t j = hs_cur.common_prefix; j < hs_cur.seq_tokens[s].size() - 1; j++) {
std::memcpy(tok_logits.data(), batch_logits.data() + n_vocab*(hs_cur.i_batch + li++), n_vocab*sizeof(float));

hs_data[task_idx].ending_logprob[ending_idx] += std::log(prob);
hs_data[task_idx].ending_logprob_count[ending_idx]++;
}
const float prob = softmax(tok_logits)[hs_cur.seq_tokens[s][j + 1]];

// Calculate the mean token logprob for acc_norm
hs_data[task_idx].ending_logprob[ending_idx] /= hs_data[task_idx].ending_logprob_count[ending_idx];
hs_cur.ending_logprob[s] += std::log(prob);
hs_cur.ending_logprob_count[s]++;
}

// account that we skip the last token in the ending
++li;

// printf("task %lu, ending %lu, whole_len %lu, context_len %lu, ending_logprob_count %lu, ending_logprob %.4f\n",
// task_idx,ending_idx,whole_size,context_size, hs_data[task_idx].ending_logprob_count[ending_idx], hs_data[task_idx].ending_logprob[ending_idx] );
}
// Calculate the mean token logprob for acc_norm
hs_cur.ending_logprob[s] /= hs_cur.ending_logprob_count[s];
}
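// acc_norm compares the endings by their mean per-token logprob, so an ending is not
// penalized merely for being tokenized into more tokens than the alternatives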

// Find the ending with maximum logprob
size_t ending_logprob_max_idx = 0;
double ending_logprob_max_val = hs_data[task_idx].ending_logprob[0];
for (size_t j = 1; j < 4; j++) {
if (hs_data[task_idx].ending_logprob[j] > ending_logprob_max_val) {
ending_logprob_max_idx = j;
ending_logprob_max_val = hs_data[task_idx].ending_logprob[j];
// Find the ending with maximum logprob
size_t ending_logprob_max_idx = 0;
double ending_logprob_max_val = hs_cur.ending_logprob[0];
for (size_t s = 1; s < 4; s++) {
if (hs_cur.ending_logprob[s] > ending_logprob_max_val) {
ending_logprob_max_idx = s;
ending_logprob_max_val = hs_cur.ending_logprob[s];
}
}
}

// printf("max logprob ending idx %lu, gold ending idx %lu\n", ending_logprob_max_idx, hs_data[task_idx].gold_ending_idx);
//printf("max logprob ending idx %lu, gold ending idx %lu\n", ending_logprob_max_idx, hs_cur.gold_ending_idx);

// If the gold ending got the maximum logprob, add one accuracy point
if (ending_logprob_max_idx == hs_cur.gold_ending_idx) {
acc += 1.0;
}

// If the gold ending got the maximum logprob, add one accuracy point
if (ending_logprob_max_idx == hs_data[task_idx].gold_ending_idx) {
acc += 1.0;
// Print the accumulated accuracy mean x 100
printf("%zu\t%.8lf\n", i + 1, acc/double(i + 1)*100.0);
fflush(stdout);
}

// Print the accumulated accuracy mean x 100
printf("%zu\t%.8lf\n",task_idx+1, acc/double(task_idx+1)*100.0);
fflush(stdout);
i0 = i1 - 1;
}

delete [] hs_data;
llama_batch_free(batch);

printf("\n");
}