Skip to content

Commit

Permalink
starcoder : add repeat penalty (#311)
Browse files Browse the repository at this point in the history
* implement repeat penalty processing for starcoder

* show effective parameters at starcoder startup

---------

Co-authored-by: Mike Ravkine <kryptk@gmail.com>
  • Loading branch information
the-crypt-keeper and KryptK420 committed Jul 2, 2023
1 parent cad56f5 commit dfef9c6
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 2 deletions.
6 changes: 6 additions & 0 deletions examples/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
params.top_p = std::stof(argv[++i]);
} else if (arg == "--temp") {
params.temp = std::stof(argv[++i]);
} else if (arg == "--repeat-last-n") {
params.repeat_last_n = std::stof(argv[++i]);
} else if (arg == "--repeat-penalty") {
params.repeat_penalty = std::stof(argv[++i]);
} else if (arg == "-b" || arg == "--batch_size") {
params.n_batch = std::stoi(argv[++i]);
} else if (arg == "-m" || arg == "--model") {
Expand Down Expand Up @@ -90,6 +94,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
fprintf(stderr, " --top_k N top-k sampling (default: %d)\n", params.top_k);
fprintf(stderr, " --top_p N top-p sampling (default: %.1f)\n", params.top_p);
fprintf(stderr, " --temp N temperature (default: %.1f)\n", params.temp);
fprintf(stderr, " --repeat-last-n N last n tokens to consider for penalize (default: %d, 0 = disabled)\n", params.repeat_last_n);
fprintf(stderr, " --repeat-penalty N penalize repeat sequence of tokens (default: %.2f, 1.0 = disabled)\n", (double)params.repeat_penalty);
fprintf(stderr, " -b N, --batch_size N batch size for prompt processing (default: %d)\n", params.n_batch);
fprintf(stderr, " -m FNAME, --model FNAME\n");
fprintf(stderr, " model path (default: %s)\n", params.model.c_str());
Expand Down
2 changes: 2 additions & 0 deletions examples/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ struct gpt_params {
int32_t top_k = 40;
float top_p = 0.9f;
float temp = 0.9f;
int32_t repeat_last_n = 64;
float repeat_penalty = 1.00f;

int32_t n_batch = 8; // batch size for prompt processing

Expand Down
23 changes: 21 additions & 2 deletions examples/starcoder/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -782,13 +782,26 @@ int main(int argc, char ** argv) {
test_gpt_tokenizer(vocab, params.token_test);
}

if (params.repeat_last_n == -1) {
params.repeat_last_n = model.hparams.n_ctx;
}
printf("\n");
printf("%s: temp = %.3f\n", __func__, params.temp);
printf("%s: top_k = %d\n", __func__, params.top_k);
printf("%s: top_p = %.3f\n", __func__, params.top_p);
printf("%s: repeat_last_n = %d\n", __func__, params.repeat_last_n);
printf("%s: repeat_penalty = %.3f\n", __func__, params.repeat_penalty);

int n_past = 0;

int64_t t_sample_us = 0;
int64_t t_predict_us = 0;

std::vector<float> logits;

std::vector<int32_t> last_n_tokens(model.hparams.n_ctx);
std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0);

// tokenize the prompt
std::vector<gpt_vocab::id> embd_inp = ::gpt_tokenize(vocab, params.prompt);

Expand Down Expand Up @@ -847,17 +860,23 @@ int main(int argc, char ** argv) {
{
const int64_t t_start_sample_us = ggml_time_us();

id = gpt_sample_top_k_top_p(vocab, logits.data() + (logits.size() - n_vocab), top_k, top_p, temp, rng);

id = gpt_sample_top_k_top_p_repeat(vocab, logits.data() + (logits.size() - n_vocab), last_n_tokens.data(), last_n_tokens.size(), top_k, top_p, temp, params.repeat_last_n, params.repeat_penalty, rng);
t_sample_us += ggml_time_us() - t_start_sample_us;
}

// add it to the context
embd.push_back(id);

last_n_tokens.erase(last_n_tokens.begin());
last_n_tokens.push_back(id);
} else {
// if here, it means we are still processing the input prompt
for (int k = i; k < embd_inp.size(); k++) {
embd.push_back(embd_inp[k]);

last_n_tokens.erase(last_n_tokens.begin());
last_n_tokens.push_back(embd_inp[k]);

if (embd.size() >= params.n_batch) {
break;
}
Expand Down

0 comments on commit dfef9c6

Please sign in to comment.