4 changes: 4 additions & 0 deletions keras_hub/src/models/causal_lm.py
@@ -326,6 +326,10 @@ def generate(
             )
         elif stop_token_ids == "auto":
             stop_token_ids = [self.preprocessor.tokenizer.end_token_id]
+            # Some models like Llama3 use two end tokens: <|eot_id|> in
+            # "instruct" versions and <|end_of_text|> in others.
+            if hasattr(self.preprocessor.tokenizer, "end_token2_id"):
+                stop_token_ids.append(self.preprocessor.tokenizer.end_token2_id)
 
         def preprocess(x):
             return self.preprocessor.generate_preprocess(
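With this change, any tokenizer that defines `end_token2_id` contributes a second stop token when `stop_token_ids="auto"`. A minimal sketch of the resulting behavior; the preset name is hypothetical and used only for illustration:

```python
from keras_hub.models import Llama3CausalLM

# Hypothetical preset name, for illustration only.
causal_lm = Llama3CausalLM.from_preset("llama3_instruct_8b_en")

# stop_token_ids="auto" is the default. Generation stops on the
# tokenizer's end token, and also on <|eot_id|> whenever the tokenizer
# registers end_token2_id (Keras checkpoints with the default eos_token).
causal_lm.generate("What is Keras?", max_length=64)
```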
2 changes: 2 additions & 0 deletions keras_hub/src/models/llama3/llama3_causal_lm_preprocessor_test.py
@@ -11,6 +11,8 @@ class Llama3CausalLMPreprocessorTest(TestCase):
     def setUp(self):
         self.vocab = ["!", "air", "Ġair", "plane", "Ġat", "port"]
         self.vocab += ["<|begin_of_text|>", "<|end_of_text|>"]
+        self.vocab += ["<|start_header_id|>", "<|end_header_id|>"]
+        self.vocab += ["<|eot_id|>"]
         self.vocab = dict([(token, i) for i, token in enumerate(self.vocab)])
         self.merges = ["Ġ a", "Ġ t", "Ġ i", "Ġ b", "a i", "p l", "n e"]
         self.merges += ["Ġa t", "p o", "r t", "Ġt h", "ai r", "pl a", "po rt"]
4 changes: 3 additions & 1 deletion keras_hub/src/models/llama3/llama3_causal_lm_test.py
@@ -16,6 +16,8 @@ class Llama3CausalLMTest(TestCase):
     def setUp(self):
         self.vocab = ["!", "air", "Ġair", "plane", "Ġat", "port"]
         self.vocab += ["<|begin_of_text|>", "<|end_of_text|>"]
+        self.vocab += ["<|start_header_id|>", "<|end_header_id|>"]
+        self.vocab += ["<|eot_id|>"]
         self.vocab = dict([(token, i) for i, token in enumerate(self.vocab)])
         self.merges = ["Ġ a", "Ġ t", "Ġ i", "Ġ b", "a i", "p l", "n e"]
         self.merges += ["Ġa t", "p o", "r t", "Ġt h", "ai r", "pl a", "po rt"]
@@ -44,7 +46,7 @@ def test_causal_lm_basics(self):
             cls=Llama3CausalLM,
             init_kwargs=self.init_kwargs,
             train_data=self.train_data,
-            expected_output_shape=(2, 7, 8),
+            expected_output_shape=(2, 7, 11),
         )
 
     def test_generate(self):
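The expected logit dimension tracks the test vocabulary size, which grows from 8 to 11 entries with the three new special tokens.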
27 changes: 25 additions & 2 deletions keras_hub/src/models/llama3/llama3_tokenizer.py
@@ -16,10 +16,33 @@ def __init__(
         self,
         vocabulary=None,
         merges=None,
+        bos_token="<|begin_of_text|>",
+        eos_token="<|end_of_text|>",
+        misc_special_tokens={"<|start_header_id|>", "<|end_header_id|>"},
         **kwargs,
     ):
-        self._add_special_token("<|begin_of_text|>", "start_token")
-        self._add_special_token("<|end_of_text|>", "end_token")
+        # Note: all special tokens must also appear in "vocabulary"
+
+        self._add_special_token(bos_token, "start_token")
+        misc_special_tokens -= {bos_token}
+        self._add_special_token(eos_token, "end_token")
+        misc_special_tokens -= {eos_token}
+        for i, token in enumerate(misc_special_tokens):
+            self._add_special_token(token, f"special_token_{i:03d}")
+
+        # Hack:
+        # Llama models use either <|end_of_text|> or <|eot_id|> as the stop
+        # token. This info can be read from config when loading a Hugging Face
+        # checkpoint, but no such config exists for Keras checkpoints.
+        # Setting both probable end tokens when no config is available
+        # makes text generation work in all cases, since generation will
+        # stop on either end token. However, the packer will always use
+        # "<|end_of_text|>", which will be the wrong eos_token for "instruct"
+        # variants of Llama3.
+        # TODO: load this correctly from a Keras tokenizer config.
+        if eos_token == "<|end_of_text|>":
+            self._add_special_token("<|eot_id|>", "end_token2")
+
         self.pad_token_id = 0
         super().__init__(
             vocabulary=vocabulary,
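To see what the new constructor arguments do in practice, here is a minimal sketch using the toy vocabulary from this PR's tests; the `keras_hub.tokenizers.Llama3Tokenizer` export path is assumed:

```python
from keras_hub.tokenizers import Llama3Tokenizer

# Toy vocabulary and merges mirroring the unit tests in this PR.
vocab = ["!", "air", "Ġair", "plane", "Ġat", "port"]
vocab += ["<|begin_of_text|>", "<|end_of_text|>"]
vocab += ["<|start_header_id|>", "<|end_header_id|>", "<|eot_id|>"]
vocab = {token: i for i, token in enumerate(vocab)}
merges = ["Ġ a", "Ġ t", "Ġ i", "Ġ b", "a i", "p l", "n e"]
merges += ["Ġa t", "p o", "r t", "Ġt h", "ai r", "pl a", "po rt"]

# Default eos_token ("<|end_of_text|>"): the fallback above also
# registers <|eot_id|> as end_token2, so generation stops on either.
tokenizer = Llama3Tokenizer(vocabulary=vocab, merges=merges)
assert tokenizer.end_token2_id == vocab["<|eot_id|>"]

# Converted "instruct" checkpoint: eos_token comes from the Hugging
# Face config, so only <|eot_id|> is registered as the end token.
instruct_tokenizer = Llama3Tokenizer(
    vocabulary=vocab, merges=merges, eos_token="<|eot_id|>"
)
```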
2 changes: 2 additions & 0 deletions keras_hub/src/models/llama3/llama3_tokenizer_test.py
@@ -8,6 +8,8 @@ class Llama3TokenizerTest(TestCase):
     def setUp(self):
         self.vocab = ["!", "air", "Ġair", "plane", "Ġat", "port"]
         self.vocab += ["<|end_of_text|>", "<|begin_of_text|>"]
+        self.vocab += ["<|start_header_id|>", "<|end_header_id|>"]
+        self.vocab += ["<|eot_id|>"]
         self.vocab = dict([(token, i) for i, token in enumerate(self.vocab)])
         self.merges = ["Ġ a", "Ġ t", "Ġ i", "Ġ b", "a i", "p l", "n e"]
         self.merges += ["Ġa t", "p o", "r t", "Ġt h", "ai r", "pl a", "po rt"]
26 changes: 21 additions & 5 deletions keras_hub/src/utils/transformers/convert_llama3.py
@@ -107,10 +107,26 @@ def convert_tokenizer(cls, preset, **kwargs):
     vocab = tokenizer_config["model"]["vocab"]
     merges = tokenizer_config["model"]["merges"]
 
-    bot = tokenizer_config["added_tokens"][0]  # begin of text
-    eot = tokenizer_config["added_tokens"][1]  # end of text
-
-    vocab[bot["content"]] = bot["id"]
-    vocab[eot["content"]] = eot["id"]
+    # Load all special tokens with the exception of "reserved" ones.
+    special_tokens = set()
+    for token in tokenizer_config["added_tokens"]:
+        if not token["content"].startswith("<|reserved_special_token_"):
+            vocab[token["content"]] = token["id"]
+            special_tokens.add(token["content"])
+
+    # Load text start and stop tokens from the config.
+    # Llama3 uses the <|end_of_text|> end token for regular models
+    # but uses <|eot_id|> for instruction-tuned variants.
+    tokenizer_config2 = load_json(preset, "tokenizer_config.json")
+    bos_token = tokenizer_config2["bos_token"]
+    eos_token = tokenizer_config2["eos_token"]
+
+    kwargs.update(
+        {
+            "bos_token": bos_token,
+            "eos_token": eos_token,
+            "misc_special_tokens": special_tokens,
+        }
+    )
+
     return cls(vocabulary=vocab, merges=merges, **kwargs)
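For reference, the two Hugging Face files read above have roughly the following shape on a Llama 3 "instruct" checkpoint; the token ids are abridged and illustrative:

```python
# Abridged sketch of the files read by convert_tokenizer(). Real
# checkpoints list many more added tokens, including the "reserved"
# ones that the converter skips.
tokenizer_json = {
    "added_tokens": [
        {"id": 128000, "content": "<|begin_of_text|>"},
        {"id": 128001, "content": "<|end_of_text|>"},
        {"id": 128002, "content": "<|reserved_special_token_0|>"},  # skipped
        {"id": 128009, "content": "<|eot_id|>"},
    ],
}
tokenizer_config_json = {
    "bos_token": "<|begin_of_text|>",
    "eos_token": "<|eot_id|>",  # "<|end_of_text|>" on base models
}
```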