diff --git a/examples/common.cpp b/examples/common.cpp
index fe00278c2..7b01089b0 100644
--- a/examples/common.cpp
+++ b/examples/common.cpp
@@ -39,6 +39,10 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
             params.top_p = std::stof(argv[++i]);
         } else if (arg == "--temp") {
             params.temp = std::stof(argv[++i]);
+        } else if (arg == "--repeat-last-n") {
+            params.repeat_last_n = std::stoi(argv[++i]);
+        } else if (arg == "--repeat-penalty") {
+            params.repeat_penalty = std::stof(argv[++i]);
         } else if (arg == "-b" || arg == "--batch_size") {
             params.n_batch = std::stoi(argv[++i]);
         } else if (arg == "-m" || arg == "--model") {
@@ -90,6 +94,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     fprintf(stderr, "  --top_k N             top-k sampling (default: %d)\n", params.top_k);
     fprintf(stderr, "  --top_p N             top-p sampling (default: %.1f)\n", params.top_p);
     fprintf(stderr, "  --temp N              temperature (default: %.1f)\n", params.temp);
+    fprintf(stderr, "  --repeat-last-n N     last n tokens to consider for penalizing (default: %d, 0 = disabled)\n", params.repeat_last_n);
+    fprintf(stderr, "  --repeat-penalty N    penalize repeated sequences of tokens (default: %.2f, 1.0 = disabled)\n", (double) params.repeat_penalty);
     fprintf(stderr, "  -b N, --batch_size N  batch size for prompt processing (default: %d)\n", params.n_batch);
     fprintf(stderr, "  -m FNAME, --model FNAME\n");
     fprintf(stderr, "                        model path (default: %s)\n", params.model.c_str());
diff --git a/examples/common.h b/examples/common.h
index 7e9b867d3..12b2b339d 100644
--- a/examples/common.h
+++ b/examples/common.h
@@ -23,6 +23,8 @@ struct gpt_params {
     int32_t top_k = 40;
     float   top_p = 0.9f;
     float   temp  = 0.9f;
+    int32_t repeat_last_n  = 64;
+    float   repeat_penalty = 1.00f;

     int32_t n_batch = 8; // batch size for prompt processing

diff --git a/examples/starcoder/main.cpp b/examples/starcoder/main.cpp
index 2016f8974..5c6065980 100644
--- a/examples/starcoder/main.cpp
+++ b/examples/starcoder/main.cpp
@@ -782,6 +782,16 @@ int main(int argc, char ** argv) {
         test_gpt_tokenizer(vocab, params.token_test);
     }

+    if (params.repeat_last_n == -1) {
+        params.repeat_last_n = model.hparams.n_ctx;
+    }
+    printf("\n");
+    printf("%s: temp           = %.3f\n", __func__, params.temp);
+    printf("%s: top_k          = %d\n",   __func__, params.top_k);
+    printf("%s: top_p          = %.3f\n", __func__, params.top_p);
+    printf("%s: repeat_last_n  = %d\n",   __func__, params.repeat_last_n);
+    printf("%s: repeat_penalty = %.3f\n", __func__, params.repeat_penalty);
+
     int n_past = 0;

     int64_t t_sample_us = 0;
@@ -789,6 +799,9 @@ int main(int argc, char ** argv) {

     std::vector<float> logits;

+    std::vector<int32_t> last_n_tokens(model.hparams.n_ctx);
+    std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0);
+
     // tokenize the prompt
     std::vector<gpt_vocab::id> embd_inp = ::gpt_tokenize(vocab, params.prompt);

@@ -847,17 +860,23 @@ int main(int argc, char ** argv) {
             {
                 const int64_t t_start_sample_us = ggml_time_us();

-                id = gpt_sample_top_k_top_p(vocab, logits.data() + (logits.size() - n_vocab), top_k, top_p, temp, rng);
-
+                id = gpt_sample_top_k_top_p_repeat(vocab, logits.data() + (logits.size() - n_vocab), last_n_tokens.data(), last_n_tokens.size(), top_k, top_p, temp, params.repeat_last_n, params.repeat_penalty, rng);
                 t_sample_us += ggml_time_us() - t_start_sample_us;
             }

             // add it to the context
             embd.push_back(id);
+
+            last_n_tokens.erase(last_n_tokens.begin());
+            last_n_tokens.push_back(id);
         } else {
             // if here, it means we are still processing the input prompt
             for (int k = i; k < embd_inp.size(); k++) {
                 embd.push_back(embd_inp[k]);
+
+                last_n_tokens.erase(last_n_tokens.begin());
+                last_n_tokens.push_back(embd_inp[k]);
+
                 if (embd.size() >= params.n_batch) {
                     break;
                 }
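
The patch only shows the call site; the body of gpt_sample_top_k_top_p_repeat lives in examples/common.cpp and is not part of this diff. As a rough sketch of what the extra parameters are for, the hypothetical helper below applies a llama.cpp-style repeat penalty to the logits before top-k/top-p sampling: every token id found in the recent-token ring buffer has its logit scaled down. The helper name and the exact scaling rule are assumptions for illustration, not code from this patch.

```cpp
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <unordered_set>
#include <vector>

// Hypothetical helper, not part of the patch: assumed llama.cpp-style
// repeat penalty applied to raw logits before top-k/top-p sampling.
static void apply_repeat_penalty(std::vector<float> & logits,
                                 const int32_t * last_n_tokens,
                                 size_t last_n_size,
                                 int repeat_last_n,
                                 float repeat_penalty) {
    // repeat_penalty == 1.0 and repeat_last_n == 0 are the "disabled" values
    if (repeat_penalty == 1.0f || repeat_last_n == 0) {
        return;
    }

    // only the most recent repeat_last_n entries of the ring buffer are considered
    const size_t n = std::min((size_t) std::max(repeat_last_n, 0), last_n_size);
    const std::unordered_set<int32_t> recent(last_n_tokens + (last_n_size - n),
                                             last_n_tokens + last_n_size);

    for (const int32_t id : recent) {
        if (id < 0 || (size_t) id >= logits.size()) {
            continue; // ignore ids outside the vocabulary range
        }
        // scale the logit of a recently seen token towards "less likely":
        // divide when positive, multiply when negative
        if (logits[id] > 0.0f) {
            logits[id] /= repeat_penalty;
        } else {
            logits[id] *= repeat_penalty;
        }
    }
}
```

Dividing positive logits and multiplying negative ones by the same factor pushes both toward "less likely", which is why repeat_penalty = 1.0 (and repeat_last_n = 0) act as the disabled settings mentioned in the usage text above.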