diff --git a/keras_hub/src/models/gpt2/gpt2_tokenizer.py b/keras_hub/src/models/gpt2/gpt2_tokenizer.py index 8bb431ee8b..216f9405e7 100644 --- a/keras_hub/src/models/gpt2/gpt2_tokenizer.py +++ b/keras_hub/src/models/gpt2/gpt2_tokenizer.py @@ -71,3 +71,29 @@ def __init__( merges=merges, **kwargs, ) + + def save_assets(self, dir_path): + import json + import os + import shutil + + # Save vocabulary. + if isinstance(self.vocabulary, str): + # If `vocabulary` is a file path, copy it. + shutil.copy( + self.vocabulary, os.path.join(dir_path, "vocabulary.json") + ) + else: + # Otherwise, `vocabulary` is a dict. Save it to a JSON file. + with open(os.path.join(dir_path, "vocabulary.json"), "w") as f: + json.dump(self.vocabulary, f) + + # Save merges. + if isinstance(self.merges, str): + # If `merges` is a file path, copy it. + shutil.copy(self.merges, os.path.join(dir_path, "merges.txt")) + else: + # Otherwise, `merges` is a list. Save it to a text file. + with open(os.path.join(dir_path, "merges.txt"), "w") as f: + for merge in self.merges: + f.write(f"{merge}\n") diff --git a/keras_hub/src/utils/transformers/export/gemma.py b/keras_hub/src/utils/transformers/export/gemma.py index 846e391937..cc4f31a937 100644 --- a/keras_hub/src/utils/transformers/export/gemma.py +++ b/keras_hub/src/utils/transformers/export/gemma.py @@ -1,9 +1,10 @@ import keras.ops as ops +import transformers def get_gemma_config(backbone): token_embedding_layer = backbone.get_layer("token_embedding") - hf_config = { + hf_config_dict = { "vocab_size": backbone.vocabulary_size, "num_hidden_layers": backbone.num_layers, "num_attention_heads": backbone.num_query_heads, @@ -18,13 +19,14 @@ def get_gemma_config(backbone): "eos_token_id": 1, "model_type": "gemma", } + hf_config = transformers.GemmaConfig(**hf_config_dict) return hf_config def get_gemma_weights_map(backbone, include_lm_head=False): weights_dict = {} - # Map token embedding + # Map token embeddings token_embedding_layer = 
backbone.get_layer("token_embedding") weights_dict["model.embed_tokens.weight"] = token_embedding_layer.weights[0] diff --git a/keras_hub/src/utils/transformers/export/gpt2.py b/keras_hub/src/utils/transformers/export/gpt2.py new file mode 100644 index 0000000000..daae5df3e6 --- /dev/null +++ b/keras_hub/src/utils/transformers/export/gpt2.py @@ -0,0 +1,146 @@ +import keras.ops as ops +import transformers + + +def get_gpt2_config(keras_model): + """Convert Keras GPT-2 config to Hugging Face GPT2Config.""" + return transformers.GPT2Config( + vocab_size=keras_model.vocabulary_size, + n_positions=keras_model.max_sequence_length, + n_embd=keras_model.hidden_dim, + n_layer=keras_model.num_layers, + n_head=keras_model.num_heads, + n_inner=keras_model.intermediate_dim, + activation_function="gelu_new", + resid_pdrop=0.1, + embd_pdrop=0.1, + attn_pdrop=0.1, + layer_norm_epsilon=1e-5, + initializer_range=0.02, + summary_type="cls_index", + summary_use_proj=True, + summary_activation=None, + summary_proj_to_labels=True, + summary_first_dropout=0.1, + scale_attn_weights=True, + use_cache=True, + bos_token_id=50256, + eos_token_id=50256, + ) + + +def get_gpt2_weights_map(keras_model, include_lm_head=False): + """Create a weights map for a given GPT-2 model.""" + weights_map = {} + + # Token and position embeddings + weights_map["transformer.wte.weight"] = keras_model.get_layer( + "token_embedding" + ).embeddings + weights_map["transformer.wpe.weight"] = keras_model.get_layer( + "position_embedding" + ).position_embeddings + + for i in range(keras_model.num_layers): + # Attention weights + q_w = keras_model.get_layer( + f"transformer_layer_{i}" + )._self_attention_layer._query_dense.kernel + k_w = keras_model.get_layer( + f"transformer_layer_{i}" + )._self_attention_layer._key_dense.kernel + v_w = keras_model.get_layer( + f"transformer_layer_{i}" + )._self_attention_layer._value_dense.kernel + q_b = keras_model.get_layer( + f"transformer_layer_{i}" + 
)._self_attention_layer._query_dense.bias
+        k_b = keras_model.get_layer(
+            f"transformer_layer_{i}"
+        )._self_attention_layer._key_dense.bias
+        v_b = keras_model.get_layer(
+            f"transformer_layer_{i}"
+        )._self_attention_layer._value_dense.bias
+
+        # Collapse the per-head (hidden, heads, head_dim) kernels to 2D so
+        # they match HF GPT-2's fused Conv1D c_attn layout.
+        q_w = ops.reshape(q_w, (keras_model.hidden_dim, keras_model.hidden_dim))
+        k_w = ops.reshape(k_w, (keras_model.hidden_dim, keras_model.hidden_dim))
+        v_w = ops.reshape(v_w, (keras_model.hidden_dim, keras_model.hidden_dim))
+
+        # Fix: `tf` is never imported in this module (only `keras.ops as ops`);
+        # use keras.ops so the export works on any Keras backend.
+        c_attn_w = ops.concatenate([q_w, k_w, v_w], axis=-1)
+        weights_map[f"transformer.h.{i}.attn.c_attn.weight"] = c_attn_w
+
+        q_b = ops.reshape(q_b, [-1])
+        k_b = ops.reshape(k_b, [-1])
+        v_b = ops.reshape(v_b, [-1])
+
+        c_attn_b = ops.concatenate([q_b, k_b, v_b], axis=-1)
+        weights_map[f"transformer.h.{i}.attn.c_attn.bias"] = c_attn_b
+
+        # Attention projection
+        c_proj_w = keras_model.get_layer(
+            f"transformer_layer_{i}"
+        )._self_attention_layer._output_dense.kernel
+        c_proj_w = ops.reshape(
+            c_proj_w, (keras_model.hidden_dim, keras_model.hidden_dim)
+        )
+        weights_map[f"transformer.h.{i}.attn.c_proj.weight"] = c_proj_w
+        weights_map[f"transformer.h.{i}.attn.c_proj.bias"] = (
+            keras_model.get_layer(
+                f"transformer_layer_{i}"
+            )._self_attention_layer._output_dense.bias
+        )
+
+        # Layer norms
+        weights_map[f"transformer.h.{i}.ln_1.weight"] = keras_model.get_layer(
+            f"transformer_layer_{i}"
+        )._self_attention_layer_norm.gamma
+        weights_map[f"transformer.h.{i}.ln_1.bias"] = keras_model.get_layer(
+            f"transformer_layer_{i}"
+        )._self_attention_layer_norm.beta
+        weights_map[f"transformer.h.{i}.ln_2.weight"] = keras_model.get_layer(
+            f"transformer_layer_{i}"
+        )._feedforward_layer_norm.gamma
+        weights_map[f"transformer.h.{i}.ln_2.bias"] = keras_model.get_layer(
+            f"transformer_layer_{i}"
+        )._feedforward_layer_norm.beta
+
+        # MLP
+        c_fc_w = keras_model.get_layer(
+            f"transformer_layer_{i}"
+        )._feedforward_intermediate_dense.kernel
+        weights_map[f"transformer.h.{i}.mlp.c_fc.weight"] = c_fc_w
+
weights_map[f"transformer.h.{i}.mlp.c_fc.bias"] = keras_model.get_layer( + f"transformer_layer_{i}" + )._feedforward_intermediate_dense.bias + c_proj_w_mlp = keras_model.get_layer( + f"transformer_layer_{i}" + )._feedforward_output_dense.kernel + weights_map[f"transformer.h.{i}.mlp.c_proj.weight"] = c_proj_w_mlp + weights_map[f"transformer.h.{i}.mlp.c_proj.bias"] = ( + keras_model.get_layer( + f"transformer_layer_{i}" + )._feedforward_output_dense.bias + ) + + # Final layer norm + weights_map["transformer.ln_f.weight"] = keras_model.get_layer( + "layer_norm" + ).gamma + weights_map["transformer.ln_f.bias"] = keras_model.get_layer( + "layer_norm" + ).beta + + if include_lm_head: + # lm_head is tied to token embeddings + weights_map["lm_head.weight"] = weights_map["transformer.wte.weight"] + + return weights_map + + +def get_gpt2_tokenizer_config(tokenizer): + return { + "model_type": "gpt2", + "bos_token": "<|endoftext|>", + "eos_token": "<|endoftext|>", + "unk_token": "<|endoftext|>", + } diff --git a/keras_hub/src/utils/transformers/export/gpt2_test.py b/keras_hub/src/utils/transformers/export/gpt2_test.py new file mode 100644 index 0000000000..2bc5c29e44 --- /dev/null +++ b/keras_hub/src/utils/transformers/export/gpt2_test.py @@ -0,0 +1,122 @@ +import os +import shutil +import sys +import tempfile +from os.path import abspath +from os.path import dirname + +# import keras +import numpy as np +import keras.ops as ops + +# import torch +from absl.testing import parameterized +from transformers import AutoModelForCausalLM +from transformers import AutoTokenizer + +# Add the project root to the Python path. 
+sys.path.insert(
+    0, dirname(dirname(dirname(dirname(dirname(abspath(__file__))))))
+)
+
+# Fix: `tf` is used unguarded at module level below (tf.test.TestCase,
+# tf.ones_like, tf.test.main) but was never imported, so importing this
+# test module would raise NameError. Import TensorFlow here, after the
+# sys.path bootstrap and before the test class is defined.
+import tensorflow as tf
+
+from keras_hub.src.models.gpt2.gpt2_causal_lm import GPT2CausalLM
+from keras_hub.src.utils.transformers.export.hf_exporter import (
+    export_to_safetensors,
+)
+
+
+def to_numpy(x):
+    """Best-effort conversion of torch/TF/Keras/NumPy values to numpy."""
+    # Torch tensor
+    if hasattr(x, "detach") and hasattr(x, "cpu"):
+        return x.detach().cpu().numpy()
+
+    # TF tensor
+    if hasattr(x, "numpy"):
+        return x.numpy()
+
+    # Numpy
+    if isinstance(x, np.ndarray):
+        return x
+
+    # KerasTensor or ragged wrapper → convert to TF → numpy
+    try:
+        import tensorflow as tf
+
+        return tf.convert_to_tensor(x).numpy()
+    except Exception:
+        pass
+
+    raise TypeError(f"Cannot convert value of type {type(x)} to numpy")
+
+
+class GPT2ExportTest(tf.test.TestCase, parameterized.TestCase):
+    @parameterized.named_parameters(
+        ("gpt2_base_en", "gpt2_base_en"),
+    )
+    def test_gpt2_export(self, preset):
+        # Create a temporary directory to save the converted model.
+        temp_dir = tempfile.mkdtemp()
+        output_path = os.path.join(temp_dir, preset)
+
+        # Load Keras model.
+        keras_model = GPT2CausalLM.from_preset(preset)
+
+        # Export to Hugging Face format.
+        export_to_safetensors(keras_model, output_path)
+
+        # Load the converted model with Hugging Face Transformers.
+        hf_model = AutoModelForCausalLM.from_pretrained(output_path)
+        hf_tokenizer = AutoTokenizer.from_pretrained(output_path)
+
+        # Assertions for config parameters.
+        self.assertEqual(
+            keras_model.backbone.hidden_dim, hf_model.config.hidden_size
+        )
+        self.assertEqual(
+            keras_model.backbone.num_layers, hf_model.config.n_layer
+        )
+        self.assertEqual(keras_model.backbone.num_heads, hf_model.config.n_head)
+        self.assertEqual(
+            keras_model.backbone.intermediate_dim, hf_model.config.n_inner
+        )
+        self.assertEqual(
+            keras_model.backbone.vocabulary_size, hf_model.config.vocab_size
+        )
+        self.assertEqual(
+            keras_model.backbone.max_sequence_length,
+            hf_model.config.n_positions,
+        )
+
+        # Test logits.
+ prompt = "Hello, my name is" + token_ids = ops.array( + keras_model.preprocessor.tokenizer(ops.array([prompt])) + ) + padding_mask = tf.ones_like(token_ids, dtype=tf.int32) + keras_inputs = {"token_ids": token_ids, "padding_mask": padding_mask} + keras_logits = keras_model(keras_inputs) + + hf_inputs = hf_tokenizer(prompt, return_tensors="pt") + hf_logits = hf_model(**hf_inputs).logits + print(hf_logits) + + # Compare logits. + # Keras logits are (batch_size, sequence_length, vocab_size) + # HF logits are (batch_size, sequence_length, vocab_size) + # We need to convert Keras logits to numpy and then to torch tensor + # for comparison. + + # Convert Keras logits (TF) -> numpy + keras_logits_np = to_numpy(keras_logits) + + # Convert HF logits (Torch, possibly MPS) -> numpy + hf_logits_np = to_numpy(hf_logits) + + self.assertAllClose(keras_logits_np, hf_logits_np, atol=1e-3, rtol=1e-3) + + # Clean up the temporary directory. + shutil.rmtree(temp_dir) + + +if __name__ == "__main__": + tf.test.main() diff --git a/keras_hub/src/utils/transformers/export/hf_exporter.py b/keras_hub/src/utils/transformers/export/hf_exporter.py index 1593987ca9..2263427905 100644 --- a/keras_hub/src/utils/transformers/export/hf_exporter.py +++ b/keras_hub/src/utils/transformers/export/hf_exporter.py @@ -10,19 +10,27 @@ get_gemma_tokenizer_config, ) from keras_hub.src.utils.transformers.export.gemma import get_gemma_weights_map +from keras_hub.src.utils.transformers.export.gpt2 import get_gpt2_config +from keras_hub.src.utils.transformers.export.gpt2 import ( + get_gpt2_tokenizer_config, +) +from keras_hub.src.utils.transformers.export.gpt2 import get_gpt2_weights_map MODEL_CONFIGS = { "GemmaBackbone": get_gemma_config, + "GPT2Backbone": get_gpt2_config, # Add for future models, e.g., "MistralBackbone": get_mistral_config } MODEL_EXPORTERS = { "GemmaBackbone": get_gemma_weights_map, + "GPT2Backbone": get_gpt2_weights_map, # Add for future models, e.g., "MistralBackbone": 
get_mistral_weights_map } MODEL_TOKENIZER_CONFIGS = { "GemmaTokenizer": get_gemma_tokenizer_config, + "GPT2Tokenizer": get_gpt2_tokenizer_config, # Add for future models, e.g., "MistralTokenizer": # get_mistral_tokenizer_config } @@ -58,19 +66,48 @@ def export_backbone(backbone, path, include_lm_head=False): os.makedirs(path, exist_ok=True) config_path = os.path.join(path, "config.json") with open(config_path, "w") as f: - json.dump(hf_config, f) + json.dump(hf_config.to_dict(), f) # Save weights based on backend weights_path = os.path.join(path, "model.safetensors") if backend == "torch": + import torch from safetensors.torch import save_file - weights_dict_contiguous = { - k: v.value.contiguous() if hasattr(v, "value") else v.contiguous() - for k, v in weights_dict.items() - } - save_file( - weights_dict_contiguous, weights_path, metadata={"format": "pt"} - ) + weights_dict_torch = {} + + for k, v in weights_dict.items(): + tensor = v.value if hasattr(v, "value") else v + + # Torch tensor -> move to CPU + if isinstance(tensor, torch.Tensor): + t = tensor.detach().to("cpu") + + # TensorFlow / JAX -> convert via numpy() + elif hasattr(tensor, "numpy"): + t = torch.tensor(tensor.numpy()) + + # numpy array + elif hasattr(tensor, "__array__"): + t = torch.tensor(tensor) + + else: + raise TypeError(f"Unsupported tensor type: {type(tensor)}") + + weights_dict_torch[k] = t.contiguous() + + # ---- GPT-2 tied weights ---- + if ( + "lm_head.weight" in weights_dict_torch + and "transformer.wte.weight" in weights_dict_torch + ): + wte = weights_dict_torch["transformer.wte.weight"] + lm = weights_dict_torch["lm_head.weight"] + + if wte.data_ptr() == lm.data_ptr(): + weights_dict_torch["lm_head.weight"] = lm.clone().contiguous() + + save_file(weights_dict_torch, weights_path, metadata={"format": "pt"}) + elif backend == "tensorflow": from safetensors.tensorflow import save_file @@ -104,18 +141,33 @@ def export_tokenizer(tokenizer, path): tokenizer_config_path = 
os.path.join(path, "tokenizer_config.json") with open(tokenizer_config_path, "w") as f: json.dump(tokenizer_config, f, indent=4) - # Rename vocabulary file - vocab_spm_path = os.path.join(path, "vocabulary.spm") - tokenizer_model_path = os.path.join(path, "tokenizer.model") - if os.path.exists(vocab_spm_path): - shutil.move(vocab_spm_path, tokenizer_model_path) - else: - warnings.warn( - f"{vocab_spm_path} not found. Tokenizer may not load " - "correctly. Ensure that the tokenizer configuration " - "is correct and that the vocabulary file is present " - "in the original model." - ) + + if tokenizer_type == "GemmaTokenizer": + # Rename vocabulary file + vocab_spm_path = os.path.join(path, "vocabulary.spm") + tokenizer_model_path = os.path.join(path, "tokenizer.model") + if os.path.exists(vocab_spm_path): + shutil.move(vocab_spm_path, tokenizer_model_path) + else: + warnings.warn( + f"{vocab_spm_path} not found. Tokenizer may not load " + "correctly. Ensure that the tokenizer configuration " + "is correct and that the vocabulary file is present " + "in the original model." + ) + elif tokenizer_type == "GPT2Tokenizer": + # Rename vocabulary file + vocab_json_path = os.path.join(path, "vocabulary.json") + renamed_vocab_json_path = os.path.join(path, "vocab.json") + if os.path.exists(vocab_json_path): + shutil.move(vocab_json_path, renamed_vocab_json_path) + else: + warnings.warn( + f"{vocab_json_path} not found. Tokenizer may not load " + "correctly. Ensure that the tokenizer configuration " + "is correct and that the vocabulary file is present " + "in the original model." 
+ ) def export_to_safetensors(keras_model, path): diff --git a/tools/checkpoint_conversion/convert_gpt2_checkpoints.py b/tools/checkpoint_conversion/convert_gpt2_checkpoints.py index 00bdba477e..3c54fc2c9e 100644 --- a/tools/checkpoint_conversion/convert_gpt2_checkpoints.py +++ b/tools/checkpoint_conversion/convert_gpt2_checkpoints.py @@ -1,5 +1,8 @@ import json import os +import sys +from os.path import abspath +from os.path import dirname import numpy as np import requests @@ -9,9 +12,16 @@ from absl import flags from checkpoint_conversion_utils import get_md5_checksum -# Temporarily directly import gpt2 until we expose it. -from keras_hub.models.gpt2.gpt2_backbone import GPT2Backbone -from keras_hub.models.gpt2.gpt2_tokenizer import GPT2Tokenizer +# Add the project root to the Python path. +sys.path.insert(0, dirname(dirname(dirname(abspath(__file__))))) + + +import keras_hub +from keras_hub.src.models.gpt2.gpt2_backbone import GPT2Backbone +from keras_hub.src.models.gpt2.gpt2_tokenizer import GPT2Tokenizer +from keras_hub.src.utils.transformers.export.hf_exporter import ( + export_to_safetensors, +) PRESET_MAP = { "gpt2_base_en": ("124M", "gpt2"), @@ -30,6 +40,17 @@ flags.DEFINE_string( "preset", None, f"Must be one of {','.join(PRESET_MAP.keys())}" ) +flags.DEFINE_string( + "output_path", + None, + "The path to save the converted model.", +) + +flags.DEFINE_bool( + "export_safetensors", + False, + "Whether to export the model in Safetensors format.", +) def download_model(num_params): @@ -37,7 +58,7 @@ def download_model(num_params): response = requests.get(DOWNLOAD_SCRIPT_URL) open("download_model.py", "wb").write(response.content) - os.system(f"python download_model.py {num_params}") + os.system(f"{sys.executable} download_model.py {num_params}") def convert_checkpoints(num_params): @@ -176,8 +197,8 @@ def convert_checkpoints(num_params): keras_hub_model.get_layer("layer_norm").beta.assign(weights["model/ln_f/b"]) # Save the model. 
- print(f"\n-> Save KerasHub model weights to `{FLAGS.preset}.h5`.") - keras_hub_model.save_weights(f"{FLAGS.preset}.h5") + print(f"\n-> Save KerasHub model weights to {FLAGS.preset}.weights.h5.") + keras_hub_model.save_weights(f"{FLAGS.preset}.weights.h5") return keras_hub_model @@ -211,7 +232,8 @@ def check_output( input_str = ["the quick brown fox ran, galloped and jumped."] # KerasHub - token_ids = keras_hub_tokenizer(input_str) + token_ids_list = keras_hub_tokenizer(ops.array(input_str)) + token_ids = ops.convert_to_tensor(token_ids_list) padding_mask = token_ids != 0 keras_hub_inputs = { @@ -229,7 +251,7 @@ def check_output( print("Difference:", np.mean(keras_hub_output - hf_output.detach().numpy())) # Show the MD5 checksum of the model weights. - print("Model md5sum: ", get_md5_checksum(f"./{FLAGS.preset}.h5")) + print("Model md5sum: ", get_md5_checksum(f"./{FLAGS.preset}.weights.h5")) return keras_hub_output @@ -242,9 +264,26 @@ def main(_): num_params = PRESET_MAP[FLAGS.preset][0] hf_model_name = PRESET_MAP[FLAGS.preset][1] + os.system("pip install requests") download_model(num_params) - keras_hub_model = convert_checkpoints(num_params) + # keras_hub_model_backbone = convert_checkpoints(num_params) + # # This saves .h5 + + # Load the KerasHub GPT2CausalLM model from the preset + # (which will use the .h5) + # This is the model we want to export to SafeTensors. + keras_hub_causal_lm_model = keras_hub.models.GPT2CausalLM.from_preset( + FLAGS.preset + ) + + if FLAGS.export_safetensors: + output_path = FLAGS.output_path or FLAGS.preset + export_to_safetensors( + keras_hub_causal_lm_model, + output_path, + ) + print(f"\n-> Exported GPT-2 model to SafeTensors at `{output_path}`.") print("\n-> Load HF model.") hf_model = transformers.AutoModel.from_pretrained(hf_model_name) @@ -255,7 +294,7 @@ def main(_): ) check_output( - keras_hub_model, + keras_hub_causal_lm_model.backbone, keras_hub_tokenizer, hf_model, hf_tokenizer,