Skip to content

Commit

Permalink
infill. : fix tokenization (ggerganov#3508)
Browse files Browse the repository at this point in the history
* infill tokens correction

* serverinfill tokens correction

* removing any leading whitespace from infill suffix and removing leeading space token from suffix when params.escape

* removing any leading whitespace from infill suffix and removing leeading space token from suffix when params.escape

* only rm when params.escape, rm space if possible which is added back or rm added space token

* only rm when params.escape, rm space if possible which is added back or rm added space token

* Revert "only rm when params.escape, rm space if possible which is added back or rm added space token"

This reverts commit 63ba0b6.

* fix interactive prompt escaping and fix server infill leading space handling

* rm unnecessary bool check
  • Loading branch information
vvhg1 committed Oct 10, 2023
1 parent 95bd60a commit 11ea5c7
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 6 deletions.
37 changes: 33 additions & 4 deletions examples/infill/infill.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -233,10 +233,22 @@ int main(int argc, char ** argv) {
const bool add_bos = llama_vocab_type(model) == LLAMA_VOCAB_TYPE_SPM;
LOG("add_bos: %d\n", add_bos);

bool suff_rm_leading_spc = params.escape;
if (suff_rm_leading_spc && params.input_suffix.find_first_of(" ") == 0 && params.input_suffix.size() > 1) {
params.input_suffix.erase(0, 1);
suff_rm_leading_spc = false;
}
std::vector<llama_token> embd_inp;
std::vector<llama_token> inp_pfx = ::llama_tokenize(ctx, params.input_prefix, add_bos);
std::vector<llama_token> inp_sfx = ::llama_tokenize(ctx, params.input_suffix, add_bos);
std::vector<llama_token> inp_pfx = ::llama_tokenize(ctx, params.input_prefix, false);
std::vector<llama_token> inp_sfx = ::llama_tokenize(ctx, params.input_suffix, false);
const int space_token = 29871;
if (suff_rm_leading_spc && inp_sfx[0] == space_token) {
inp_sfx.erase(inp_sfx.begin());
}
inp_pfx.insert(inp_pfx.begin(), llama_token_prefix(ctx));
if (add_bos) {
inp_pfx.insert(inp_pfx.begin(), llama_token_bos(ctx));
}
inp_sfx.insert(inp_sfx.begin(), llama_token_suffix(ctx));
embd_inp = inp_pfx;
embd_inp.insert(embd_inp.end(), inp_sfx.begin(), inp_sfx.end());
Expand Down Expand Up @@ -627,10 +639,27 @@ int main(int argc, char ** argv) {
buffer.clear();
// done taking input, reset color
console::set_display(console::reset);

if (params.escape) {
//process escape sequences, for the initial prompt this is done in common.cpp when we load the params, but for the interactive mode we need to do it here
process_escapes(params.input_prefix);
process_escapes(params.input_suffix);
}
suff_rm_leading_spc = params.escape;
if (suff_rm_leading_spc && params.input_suffix.find_first_of(" ") == 0 && params.input_suffix.size() > 1) {
params.input_suffix.erase(0, 1);
suff_rm_leading_spc = false;
}
// tokenize new prefix and suffix
std::vector<llama_token> inp_pfx = ::llama_tokenize(ctx, params.input_prefix, add_bos);
std::vector<llama_token> inp_sfx = ::llama_tokenize(ctx, params.input_suffix, add_bos);
std::vector<llama_token> inp_pfx = ::llama_tokenize(ctx, params.input_prefix, false);
std::vector<llama_token> inp_sfx = ::llama_tokenize(ctx, params.input_suffix, false);
if (suff_rm_leading_spc && inp_sfx[0] == space_token) {
inp_sfx.erase(inp_sfx.begin());
}
inp_pfx.insert(inp_pfx.begin(), llama_token_prefix(ctx));
if (add_bos) {
inp_pfx.insert(inp_pfx.begin(), llama_token_bos(ctx));
}
inp_sfx.insert(inp_sfx.begin(), llama_token_suffix(ctx));
embd_inp = inp_pfx;
embd_inp.insert(embd_inp.end(), inp_sfx.begin(), inp_sfx.end());
Expand Down
15 changes: 13 additions & 2 deletions examples/server/server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -344,9 +344,20 @@ struct llama_server_context

void loadInfill()
{
auto prefix_tokens = tokenize(params.input_prefix, true); // always add BOS
auto suffix_tokens = tokenize(params.input_suffix, true); // always add BOS
bool suff_rm_leading_spc = true;
if (params.input_suffix.find_first_of(" ") == 0 && params.input_suffix.size() > 1) {
params.input_suffix.erase(0, 1);
suff_rm_leading_spc = false;
}

auto prefix_tokens = tokenize(params.input_prefix, false);
auto suffix_tokens = tokenize(params.input_suffix, false);
const int space_token = 29871;
if (suff_rm_leading_spc && suffix_tokens[0] == space_token) {
suffix_tokens.erase(suffix_tokens.begin());
}
prefix_tokens.insert(prefix_tokens.begin(), llama_token_prefix(ctx));
prefix_tokens.insert(prefix_tokens.begin(), llama_token_bos(ctx)); // always add BOS
prefix_tokens.insert(prefix_tokens.end(), llama_token_suffix(ctx));
prefix_tokens.insert(prefix_tokens.end(), suffix_tokens.begin(), suffix_tokens.end());
prefix_tokens.push_back(llama_token_middle(ctx));
Expand Down

0 comments on commit 11ea5c7

Please sign in to comment.