Skip to content

Commit

Permalink
Fix missing bos token for some models (including Llama-3) (#6050)
Browse files (browse the repository at this point in the history)
  • Loading branch information
belladoreai committed May 27, 2024
1 parent 8df68b0 commit a363cdf
Showing 1 changed file with 15 additions and 3 deletions.
18 changes: 15 additions & 3 deletions modules/text_generation.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -138,9 +138,21 @@ def encode(prompt, add_special_tokens=True, add_bos_token=True, truncation_lengt
input_ids = np.array(input_ids).reshape(1, len(input_ids))
else:
input_ids = shared.tokenizer.encode(str(prompt), return_tensors='pt', add_special_tokens=add_special_tokens)
if not add_bos_token:
while len(input_ids[0]) > 0 and input_ids[0][0] == shared.tokenizer.bos_token_id:
input_ids = input_ids[:, 1:]

if hasattr(shared.tokenizer, 'bos_token_id'):
if add_bos_token:
if (len(input_ids[0]) > 0 and input_ids[0][0] != shared.tokenizer.bos_token_id) or len(input_ids[0]) == 0:
# Add a missing bos token (it may not have been added due to faulty model metadata)
bos_tensor = torch.tensor([[shared.tokenizer.bos_token_id]])
input_ids = torch.cat((bos_tensor, input_ids), 1)

# Prevent double bos token due to jinja templates with <s> somewhere
while len(input_ids[0]) > 1 and input_ids[0][0] == shared.tokenizer.bos_token_id and input_ids[0][1] == shared.tokenizer.bos_token_id:
input_ids = input_ids[:, 1:]
else:
# Remove any bos token that may have been added
while len(input_ids[0]) > 0 and input_ids[0][0] == shared.tokenizer.bos_token_id:
input_ids = input_ids[:, 1:]

# Handling truncation
if truncation_length is not None:
Expand Down

0 comments on commit a363cdf

Please sign in to comment.