diff --git a/keras_nlp/src/utils/transformers/convert.py b/keras_nlp/src/utils/transformers/convert.py
index 21c54bd1fb..568bd8e79d 100644
--- a/keras_nlp/src/utils/transformers/convert.py
+++ b/keras_nlp/src/utils/transformers/convert.py
@@ -18,6 +18,8 @@ from keras_nlp.src.utils.transformers.convert_bert import load_bert_tokenizer
 from keras_nlp.src.utils.transformers.convert_gemma import load_gemma_backbone
 from keras_nlp.src.utils.transformers.convert_gemma import load_gemma_tokenizer
+from keras_nlp.src.utils.transformers.convert_gpt2 import load_gpt2_backbone
+from keras_nlp.src.utils.transformers.convert_gpt2 import load_gpt2_tokenizer
 from keras_nlp.src.utils.transformers.convert_llama3 import load_llama3_backbone
 from keras_nlp.src.utils.transformers.convert_llama3 import (
     load_llama3_tokenizer,
 )
@@ -52,6 +54,8 @@ def load_transformers_backbone(cls, preset, load_weights):
         return load_llama3_backbone(cls, preset, load_weights)
     if cls.__name__ == "PaliGemmaBackbone":
         return load_pali_gemma_backbone(cls, preset, load_weights)
+    if cls.__name__ == "GPT2Backbone":
+        return load_gpt2_backbone(cls, preset, load_weights)
     raise ValueError(
         f"{cls} has not been ported from the Hugging Face format yet. "
         "Please check Hugging Face Hub for the Keras model. "
@@ -79,6 +83,8 @@ def load_transformers_tokenizer(cls, preset):
         return load_llama3_tokenizer(cls, preset)
     if cls.__name__ == "PaliGemmaTokenizer":
         return load_pali_gemma_tokenizer(cls, preset)
+    if cls.__name__ == "GPT2Tokenizer":
+        return load_gpt2_tokenizer(cls, preset)
     raise ValueError(
         f"{cls} has not been ported from the Hugging Face format yet. "
         "Please check Hugging Face Hub for the Keras model. "
diff --git a/keras_nlp/src/utils/transformers/convert_gpt2.py b/keras_nlp/src/utils/transformers/convert_gpt2.py
new file mode 100644
index 0000000000..2ac8a9a8a2
--- /dev/null
+++ b/keras_nlp/src/utils/transformers/convert_gpt2.py
@@ -0,0 +1,186 @@
+# Copyright 2024 The KerasNLP Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
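+# Hugging Face GPT-2 stores its projection weights in `Conv1D` layout,
+# i.e. kernels are shaped (input_dim, output_dim), so they can be reshaped
+# straight into the per-head kernels of the Keras attention layers without
+# a transpose. The query, key, and value projections are fused into a
+# single `c_attn` tensor of shape (n_embd, 3 * n_embd), which is sliced
+# into the three separate projections during conversion below.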
+import numpy as np
+
+from keras_nlp.src.utils.preset_utils import HF_CONFIG_FILE
+from keras_nlp.src.utils.preset_utils import get_file
+from keras_nlp.src.utils.preset_utils import jax_memory_cleanup
+from keras_nlp.src.utils.preset_utils import load_config
+from keras_nlp.src.utils.transformers.safetensor_utils import SafetensorLoader
+
+
+def convert_backbone_config(transformers_config):
+    return {
+        "vocabulary_size": transformers_config["vocab_size"],
+        "num_layers": transformers_config["n_layer"],
+        "num_heads": transformers_config["n_head"],
+        "hidden_dim": transformers_config["n_embd"],
+        "intermediate_dim": transformers_config["n_embd"] * 4,
+        "dropout": transformers_config["resid_pdrop"],
+        "max_sequence_length": transformers_config["n_positions"],
+    }
+
+
+def convert_weights(backbone, loader, transformers_config):
+    # Embeddings
+    loader.port_weight(
+        keras_variable=backbone.token_embedding.embeddings,
+        hf_weight_key="wte.weight",
+    )
+    loader.port_weight(
+        keras_variable=backbone.position_embedding.position_embeddings,
+        hf_weight_key="wpe.weight",
+    )
+
+    # Attention blocks
+    for index in range(backbone.num_layers):
+        decoder_layer = backbone.transformer_layers[index]
+
+        # Norm layers
+        loader.port_weight(
+            keras_variable=decoder_layer._self_attention_layer_norm.gamma,
+            hf_weight_key=f"h.{index}.ln_1.weight",
+        )
+        loader.port_weight(
+            keras_variable=decoder_layer._self_attention_layer_norm.beta,
+            hf_weight_key=f"h.{index}.ln_1.bias",
+        )
+        loader.port_weight(
+            keras_variable=decoder_layer._feedforward_layer_norm.gamma,
+            hf_weight_key=f"h.{index}.ln_2.weight",
+        )
+        loader.port_weight(
+            keras_variable=decoder_layer._feedforward_layer_norm.beta,
+            hf_weight_key=f"h.{index}.ln_2.bias",
+        )
+
+        # Attention layers
+        n_embd = transformers_config["n_embd"]
+
+        # Query
+        loader.port_weight(
+            keras_variable=decoder_layer._self_attention_layer.query_dense.kernel,
+            hf_weight_key=f"h.{index}.attn.c_attn.weight",
+            hook_fn=lambda hf_tensor, keras_shape: np.reshape(
+                hf_tensor[:, :n_embd], keras_shape
+            ),
+        )
+        loader.port_weight(
+            keras_variable=decoder_layer._self_attention_layer.query_dense.bias,
+            hf_weight_key=f"h.{index}.attn.c_attn.bias",
+            hook_fn=lambda hf_tensor, keras_shape: np.reshape(
+                hf_tensor[:n_embd], keras_shape
+            ),
+        )
+
+        # Key
+        loader.port_weight(
+            keras_variable=decoder_layer._self_attention_layer.key_dense.kernel,
+            hf_weight_key=f"h.{index}.attn.c_attn.weight",
+            hook_fn=lambda hf_tensor, keras_shape: np.reshape(
+                hf_tensor[:, n_embd : 2 * n_embd], keras_shape
+            ),
+        )
+        loader.port_weight(
+            keras_variable=decoder_layer._self_attention_layer.key_dense.bias,
+            hf_weight_key=f"h.{index}.attn.c_attn.bias",
+            hook_fn=lambda hf_tensor, keras_shape: np.reshape(
+                hf_tensor[n_embd : 2 * n_embd], keras_shape
+            ),
+        )
+
+        # Value
+        loader.port_weight(
+            keras_variable=decoder_layer._self_attention_layer.value_dense.kernel,
+            hf_weight_key=f"h.{index}.attn.c_attn.weight",
+            hook_fn=lambda hf_tensor, keras_shape: np.reshape(
+                hf_tensor[:, 2 * n_embd :], keras_shape
+            ),
+        )
+        loader.port_weight(
+            keras_variable=decoder_layer._self_attention_layer.value_dense.bias,
+            hf_weight_key=f"h.{index}.attn.c_attn.bias",
+            hook_fn=lambda hf_tensor, keras_shape: np.reshape(
+                hf_tensor[2 * n_embd :], keras_shape
+            ),
+        )
+
+        # Output
+        loader.port_weight(
+            keras_variable=decoder_layer._self_attention_layer.output_dense.kernel,
+            hf_weight_key=f"h.{index}.attn.c_proj.weight",
+            hook_fn=lambda hf_tensor, keras_shape: np.reshape(
+                hf_tensor, keras_shape
+            ),
+        )
+        loader.port_weight(
+            keras_variable=decoder_layer._self_attention_layer.output_dense.bias,
+            hf_weight_key=f"h.{index}.attn.c_proj.bias",
+        )
+
+        # MLP layers
+        loader.port_weight(
+            keras_variable=decoder_layer._feedforward_intermediate_dense.kernel,
+            hf_weight_key=f"h.{index}.mlp.c_fc.weight",
+            hook_fn=lambda hf_tensor, keras_shape: np.reshape(
+                hf_tensor, keras_shape
+            ),
+        )
+        loader.port_weight(
+            keras_variable=decoder_layer._feedforward_intermediate_dense.bias,
+            hf_weight_key=f"h.{index}.mlp.c_fc.bias",
+        )
+        loader.port_weight(
+            keras_variable=decoder_layer._feedforward_output_dense.kernel,
+            hf_weight_key=f"h.{index}.mlp.c_proj.weight",
+            hook_fn=lambda hf_tensor, keras_shape: np.reshape(
+                hf_tensor, keras_shape
+            ),
+        )
+        loader.port_weight(
+            keras_variable=decoder_layer._feedforward_output_dense.bias,
+            hf_weight_key=f"h.{index}.mlp.c_proj.bias",
+        )
+
+    # Normalization
+    loader.port_weight(
+        keras_variable=backbone.layer_norm.gamma,
+        hf_weight_key="ln_f.weight",
+    )
+    loader.port_weight(
+        keras_variable=backbone.layer_norm.beta,
+        hf_weight_key="ln_f.bias",
+    )
+
+    return backbone
+
+
+def load_gpt2_backbone(cls, preset, load_weights):
+    transformers_config = load_config(preset, HF_CONFIG_FILE)
+    keras_config = convert_backbone_config(transformers_config)
+    backbone = cls(**keras_config)
+    if load_weights:
+        jax_memory_cleanup(backbone)
+        with SafetensorLoader(preset) as loader:
+            convert_weights(backbone, loader, transformers_config)
+    return backbone
+
+
+def load_gpt2_tokenizer(cls, preset):
+    vocab_file = get_file(preset, "vocab.json")
+    merges_file = get_file(preset, "merges.txt")
+    return cls(
+        vocabulary=vocab_file,
+        merges=merges_file,
+    )
diff --git a/keras_nlp/src/utils/transformers/convert_gpt2_test.py b/keras_nlp/src/utils/transformers/convert_gpt2_test.py
new file mode 100644
index 0000000000..c7b65eb87f
--- /dev/null
+++ b/keras_nlp/src/utils/transformers/convert_gpt2_test.py
@@ -0,0 +1,27 @@
+# Copyright 2024 The KerasNLP Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import pytest
+
+from keras_nlp.src.models.gpt2.gpt2_causal_lm import GPT2CausalLM
+from keras_nlp.src.tests.test_case import TestCase
+
+
+class TestTask(TestCase):
+    @pytest.mark.large
+    def test_convert_tiny_preset(self):
+        model = GPT2CausalLM.from_preset("hf://openai-community/gpt2")
+        prompt = "What is your favorite condiment?"
+        model.generate([prompt], max_length=15)
+
+        # TODO: compare numerics with huggingface model
diff --git a/keras_nlp/src/utils/transformers/safetensor_utils.py b/keras_nlp/src/utils/transformers/safetensor_utils.py
index 2376b3ba68..40ef473ff3 100644
--- a/keras_nlp/src/utils/transformers/safetensor_utils.py
+++ b/keras_nlp/src/utils/transformers/safetensor_utils.py
@@ -42,12 +42,41 @@ def __init__(self, preset):
         else:
             self.safetensor_config = None
         self.safetensor_files = {}
+        self.prefix = None
+
+    def get_prefixed_key(self, hf_weight_key, dict_like):
+        """
+        Determine and return a prefixed key for a given hf weight key.
+
+        This method checks whether the checkpoint's weight keys share a
+        common prefix and caches that prefix for subsequent lookups.
+
+        Args:
+            hf_weight_key (str): The hf weight key to check for a prefix.
+            dict_like (object): An object exposing a `keys()` method over
+                the safetensor weight keys.
+
+        Returns:
+            str: The full key, including the prefix if one was found.
+        """
+        if self.prefix is not None:
+            return self.prefix + hf_weight_key
+
+        for full_key in dict_like.keys():
+            if full_key.endswith(hf_weight_key) and full_key != hf_weight_key:
+                self.prefix = full_key[: -len(hf_weight_key)]
+                return full_key
+
+        self.prefix = ""
+        return hf_weight_key

     def get_tensor(self, hf_weight_key):
         if self.safetensor_config is None:
             fname = SAFETENSOR_FILE
         else:
-            fname = self.safetensor_config["weight_map"][hf_weight_key]
+            full_key = self.get_prefixed_key(
+                hf_weight_key, self.safetensor_config["weight_map"]
+            )
+            fname = self.safetensor_config["weight_map"][full_key]

         if fname in self.safetensor_files:
             file = self.safetensor_files[fname]
@@ -58,7 +87,8 @@ def get_tensor(self, hf_weight_key):
             )
             self.safetensor_files[fname] = file

-        return file.get_tensor(hf_weight_key)
+        full_key = self.get_prefixed_key(hf_weight_key, file)
+        return file.get_tensor(full_key)

     def port_weight(self, keras_variable, hf_weight_key, hook_fn=None):
         hf_tensor = self.get_tensor(hf_weight_key)