diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp index 8ee09ccfbedd1..d83777411ddab 100644 --- a/examples/perplexity/perplexity.cpp +++ b/examples/perplexity/perplexity.cpp @@ -999,6 +999,8 @@ struct winogrande_entry { size_t i_logits; size_t common_prefix; size_t required_tokens; + size_t n_base1; // number of tokens for context + choice 1 + size_t n_base2; // number of tokens for context + choice 2 std::vector seq_tokens[2]; }; @@ -1038,38 +1040,6 @@ static std::vector load_winogrande_from_csv(const std::string& auto choice2 = line.substr(comma_pos[2]+1, comma_pos[3] - comma_pos[2] - 1); auto answer = line.substr(comma_pos[3]+1, line.size() - comma_pos[3] - 1); auto index = line.substr(0, comma_pos[0]); - if ('a' <= sentence[0] && sentence[0] <= 'z') { - // make the first letter a capital letter - sentence[0] -= 'a' - 'A'; - } - for (int i = 0; i < (int) sentence.size() - 1; ++i) { - // trim repeated spaces and spaces before punctuation - if (sentence[i] == ' ') { - char next = sentence[i+1]; - if (next == ' ' || next == ',' || next == '.' || next == '\'') { - char r[2] = { next, 0 }; - sentence.replace(i, 2, r); - --i; // stay at the same index for repeated spaces - } - } else if (sentence[i] == ',' || sentence[i] == '.') { - if (sentence[i] == sentence[i+1]) { - // trim repeated punctuation (forward to work at the end of sentences) - char r[2] = { sentence[i], 0 }; - sentence.replace(i, 2, r); - --i; // same index to then run the other checks on that punctuation - } else if (0 < i && sentence[i-1] == sentence[i]) { - // trim repeated punctuation (looks back to work with the space trim) - char r[2] = { sentence[i], 0 }; - sentence.replace(i-1, 2, r); - i -= 2; // go back because content was shifted - } else if (sentence[i+1] != ' ') { - // add missing space after punctuation - // (since the loop stops before the end, this adds no trailing space) - char r[3] = { sentence[i], ' ', 0 }; - sentence.replace(i, 1, r); - } - } - } int where = 0; for ( ; where < int(sentence.size()); ++where) { if (sentence[where] == '_') break; @@ -1106,6 +1076,8 @@ static std::vector load_winogrande_from_csv(const std::string& */ static void winogrande_score(llama_context * ctx, const gpt_params & params) { + constexpr int k_min_trailing_ctx = 3; + auto data = load_winogrande_from_csv(params.prompt); if (data.empty()) { fprintf(stderr, "%s: no tasks\n", __func__); @@ -1150,11 +1122,13 @@ static void winogrande_score(llama_context * ctx, const gpt_params & params) { task.common_prefix++; } + // TODO: the last token of each of the sequences don't need to be evaluated task.required_tokens = task.common_prefix + task.seq_tokens[0].size() - task.common_prefix + - task.seq_tokens[1].size() - task.common_prefix - // the last tokens don't need to be evaluated - - 2; + task.seq_tokens[1].size() - task.common_prefix; + + task.n_base1 = ::llama_tokenize(ctx, task.first + task.choices[0], add_bos).size(); + task.n_base2 = ::llama_tokenize(ctx, task.first + task.choices[1], add_bos).size(); } fprintf(stderr, "%s : calculating winogrande score over selected tasks.\n", __func__); @@ -1201,8 +1175,8 @@ static void winogrande_score(llama_context * ctx, const gpt_params & params) { n_logits += 1; for (int s = 0; s < 2; ++s) { - // end before the last token, no need to predict past the end of the sequences - for (size_t i = data[i1].common_prefix; i < data[i1].seq_tokens[s].size() - 1; ++i) { + // TODO: end before the last token, no need to predict past the end of the sequences + for (size_t i = data[i1].common_prefix; i < data[i1].seq_tokens[s].size(); ++i) { llama_batch_add(batch, data[i1].seq_tokens[s][i], i, { s0 + s }, true); n_logits += 1; } @@ -1234,20 +1208,23 @@ static void winogrande_score(llama_context * ctx, const gpt_params & params) { for (size_t i = i0; i < i1; ++i) { auto & task = data[i]; - // start from the end of the common prefix - size_t li = 0; - for (size_t j = task.common_prefix-1; j < task.seq_tokens[0].size()-1; ++j) { + const bool skip_choice = + task.seq_tokens[0].size() - task.common_prefix > k_min_trailing_ctx && + task.seq_tokens[1].size() - task.common_prefix > k_min_trailing_ctx; + + const auto& n_base1 = skip_choice ? task.n_base1 : task.common_prefix; + const int last_1st = task.seq_tokens[0].size() - n_base1 > 1 ? 1 : 0; + size_t li = n_base1 - task.common_prefix; + for (size_t j = n_base1-1; j < task.seq_tokens[0].size()-1-last_1st; ++j) { eval_pairs.emplace_back(task.i_logits + li++, task.seq_tokens[0][j+1]); } - // first token of the second choice is predicted by the end of the common prefix - eval_pairs.emplace_back(task.i_logits, task.seq_tokens[1][task.common_prefix]); - for (size_t j = task.common_prefix; j < task.seq_tokens[1].size()-1; ++j) { + const auto& n_base2 = skip_choice ? task.n_base2 : task.common_prefix; + const int last_2nd = task.seq_tokens[1].size() - n_base2 > 1 ? 1 : 0; + // FIXME: this uses the wrong first logits when not skipping the choice word + li = task.seq_tokens[0].size() - task.common_prefix + n_base2 - task.common_prefix; + for (size_t j = n_base2-1; j < task.seq_tokens[1].size()-1-last_2nd; ++j) { eval_pairs.emplace_back(task.i_logits + li++, task.seq_tokens[1][j+1]); } - if (i < i1 - 1) { - // make sure all logits have been processed as expected - GGML_ASSERT(task.i_logits + li == data[i+1].i_logits); - } } compute_logprobs(batch_logits.data(), n_vocab, workers, eval_pairs, eval_results); @@ -1255,17 +1232,25 @@ static void winogrande_score(llama_context * ctx, const gpt_params & params) { for (size_t i = i0; i < i1; ++i) { auto & task = data[i]; + const bool skip_choice = + task.seq_tokens[0].size() - task.common_prefix > k_min_trailing_ctx && + task.seq_tokens[1].size() - task.common_prefix > k_min_trailing_ctx; + float score_1st = 0; - for (size_t j = task.common_prefix-1; j < task.seq_tokens[0].size()-1; ++j) { + const auto& n_base1 = skip_choice ? task.n_base1 : task.common_prefix; + const int last_1st = task.seq_tokens[0].size() - n_base1 > 1 ? 1 : 0; + for (size_t j = n_base1-1; j < task.seq_tokens[0].size()-1-last_1st; ++j) { score_1st += eval_results[ir++]; } - score_1st /= (task.seq_tokens[0].size() - task.common_prefix); + score_1st /= (task.seq_tokens[0].size() - n_base1 - last_1st); float score_2nd = 0; - for (size_t j = task.common_prefix-1; j < task.seq_tokens[1].size()-1; ++j) { + const auto& n_base2 = skip_choice ? task.n_base2 : task.common_prefix; + const int last_2nd = task.seq_tokens[1].size() - n_base2 > 1 ? 1 : 0; + for (size_t j = n_base2-1; j < task.seq_tokens[1].size()-1-last_2nd; ++j) { score_2nd += eval_results[ir++]; } - score_2nd /= (task.seq_tokens[1].size() - task.common_prefix); + score_2nd /= (task.seq_tokens[1].size() - n_base2 - last_2nd); int result = score_1st > score_2nd ? 1 : 2;