-
Notifications
You must be signed in to change notification settings - Fork 309
Gpt2safetensors #2459
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Gpt2safetensors #2459
Changes from all commits
5add1a3
7146a5c
a8b6d07
3c59953
adcc364
d8e71af
45e2e9d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,146 @@ | ||
| import keras.ops as ops | ||
| import transformers | ||
|
|
||
|
|
||
| def get_gpt2_config(keras_model): | ||
| """Convert Keras GPT-2 config to Hugging Face GPT2Config.""" | ||
| return transformers.GPT2Config( | ||
| vocab_size=keras_model.vocabulary_size, | ||
| n_positions=keras_model.max_sequence_length, | ||
| n_embd=keras_model.hidden_dim, | ||
| n_layer=keras_model.num_layers, | ||
| n_head=keras_model.num_heads, | ||
| n_inner=keras_model.intermediate_dim, | ||
| activation_function="gelu_new", | ||
| resid_pdrop=0.1, | ||
| embd_pdrop=0.1, | ||
| attn_pdrop=0.1, | ||
| layer_norm_epsilon=1e-5, | ||
| initializer_range=0.02, | ||
| summary_type="cls_index", | ||
| summary_use_proj=True, | ||
| summary_activation=None, | ||
| summary_proj_to_labels=True, | ||
| summary_first_dropout=0.1, | ||
| scale_attn_weights=True, | ||
| use_cache=True, | ||
| bos_token_id=50256, | ||
| eos_token_id=50256, | ||
| ) | ||
|
|
||
|
|
||
| def get_gpt2_weights_map(keras_model, include_lm_head=False): | ||
| """Create a weights map for a given GPT-2 model.""" | ||
| weights_map = {} | ||
|
|
||
| # Token and position embeddings | ||
| weights_map["transformer.wte.weight"] = keras_model.get_layer( | ||
| "token_embedding" | ||
| ).embeddings | ||
| weights_map["transformer.wpe.weight"] = keras_model.get_layer( | ||
| "position_embedding" | ||
| ).position_embeddings | ||
|
|
||
| for i in range(keras_model.num_layers): | ||
| # Attention weights | ||
| q_w = keras_model.get_layer( | ||
| f"transformer_layer_{i}" | ||
| )._self_attention_layer._query_dense.kernel | ||
| k_w = keras_model.get_layer( | ||
| f"transformer_layer_{i}" | ||
| )._self_attention_layer._key_dense.kernel | ||
| v_w = keras_model.get_layer( | ||
| f"transformer_layer_{i}" | ||
| )._self_attention_layer._value_dense.kernel | ||
|
Comment on lines
+46
to
+54
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Accessing private layer attributes like |
||
| q_b = keras_model.get_layer( | ||
| f"transformer_layer_{i}" | ||
| )._self_attention_layer._query_dense.bias | ||
| k_b = keras_model.get_layer( | ||
| f"transformer_layer_{i}" | ||
| )._self_attention_layer._key_dense.bias | ||
| v_b = keras_model.get_layer( | ||
| f"transformer_layer_{i}" | ||
| )._self_attention_layer._value_dense.bias | ||
|
|
||
| q_w = ops.reshape(q_w, (keras_model.hidden_dim, keras_model.hidden_dim)) | ||
| k_w = ops.reshape(k_w, (keras_model.hidden_dim, keras_model.hidden_dim)) | ||
| v_w = ops.reshape(v_w, (keras_model.hidden_dim, keras_model.hidden_dim)) | ||
|
|
||
| c_attn_w = tf.concat([q_w, k_w, v_w], axis=-1) | ||
| weights_map[f"transformer.h.{i}.attn.c_attn.weight"] = c_attn_w | ||
|
|
||
| q_b = tf.reshape(q_b, [-1]) | ||
| k_b = tf.reshape(k_b, [-1]) | ||
| v_b = tf.reshape(v_b, [-1]) | ||
|
|
||
| c_attn_b = tf.concat([q_b, k_b, v_b], axis=-1) | ||
| weights_map[f"transformer.h.{i}.attn.c_attn.bias"] = c_attn_b | ||
|
|
||
| # Attention projection | ||
| c_proj_w = keras_model.get_layer( | ||
| f"transformer_layer_{i}" | ||
| )._self_attention_layer._output_dense.kernel | ||
| c_proj_w = tf.reshape( | ||
| c_proj_w, (keras_model.hidden_dim, keras_model.hidden_dim) | ||
| ) | ||
| weights_map[f"transformer.h.{i}.attn.c_proj.weight"] = c_proj_w | ||
| weights_map[f"transformer.h.{i}.attn.c_proj.bias"] = ( | ||
| keras_model.get_layer( | ||
| f"transformer_layer_{i}" | ||
| )._self_attention_layer._output_dense.bias | ||
| ) | ||
|
|
||
| # Layer norms | ||
| weights_map[f"transformer.h.{i}.ln_1.weight"] = keras_model.get_layer( | ||
| f"transformer_layer_{i}" | ||
| )._self_attention_layer_norm.gamma | ||
| weights_map[f"transformer.h.{i}.ln_1.bias"] = keras_model.get_layer( | ||
| f"transformer_layer_{i}" | ||
| )._self_attention_layer_norm.beta | ||
| weights_map[f"transformer.h.{i}.ln_2.weight"] = keras_model.get_layer( | ||
| f"transformer_layer_{i}" | ||
| )._feedforward_layer_norm.gamma | ||
| weights_map[f"transformer.h.{i}.ln_2.bias"] = keras_model.get_layer( | ||
| f"transformer_layer_{i}" | ||
| )._feedforward_layer_norm.beta | ||
|
|
||
| # MLP | ||
| c_fc_w = keras_model.get_layer( | ||
| f"transformer_layer_{i}" | ||
| )._feedforward_intermediate_dense.kernel | ||
| weights_map[f"transformer.h.{i}.mlp.c_fc.weight"] = c_fc_w | ||
| weights_map[f"transformer.h.{i}.mlp.c_fc.bias"] = keras_model.get_layer( | ||
| f"transformer_layer_{i}" | ||
| )._feedforward_intermediate_dense.bias | ||
| c_proj_w_mlp = keras_model.get_layer( | ||
| f"transformer_layer_{i}" | ||
| )._feedforward_output_dense.kernel | ||
| weights_map[f"transformer.h.{i}.mlp.c_proj.weight"] = c_proj_w_mlp | ||
| weights_map[f"transformer.h.{i}.mlp.c_proj.bias"] = ( | ||
| keras_model.get_layer( | ||
| f"transformer_layer_{i}" | ||
| )._feedforward_output_dense.bias | ||
| ) | ||
|
|
||
| # Final layer norm | ||
| weights_map["transformer.ln_f.weight"] = keras_model.get_layer( | ||
| "layer_norm" | ||
| ).gamma | ||
| weights_map["transformer.ln_f.bias"] = keras_model.get_layer( | ||
| "layer_norm" | ||
| ).beta | ||
|
|
||
| if include_lm_head: | ||
| # lm_head is tied to token embeddings | ||
| weights_map["lm_head.weight"] = weights_map["transformer.wte.weight"] | ||
|
|
||
| return weights_map | ||
|
|
||
|
|
||
| def get_gpt2_tokenizer_config(tokenizer): | ||
| return { | ||
| "model_type": "gpt2", | ||
| "bos_token": "<|endoftext|>", | ||
| "eos_token": "<|endoftext|>", | ||
| "unk_token": "<|endoftext|>", | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,122 @@ | ||
| import os | ||
| import shutil | ||
| import sys | ||
| import tempfile | ||
| from os.path import abspath | ||
| from os.path import dirname | ||
|
|
||
| # import keras | ||
| import numpy as np | ||
| import keras.ops as ops | ||
|
|
||
| # import torch | ||
| from absl.testing import parameterized | ||
| from transformers import AutoModelForCausalLM | ||
| from transformers import AutoTokenizer | ||
|
|
||
| # Add the project root to the Python path. | ||
| sys.path.insert( | ||
| 0, dirname(dirname(dirname(dirname(dirname(abspath(__file__)))))) | ||
| ) | ||
|
|
||
| from keras_hub.src.models.gpt2.gpt2_causal_lm import GPT2CausalLM | ||
| from keras_hub.src.utils.transformers.export.hf_exporter import ( | ||
| export_to_safetensors, | ||
| ) | ||
|
|
||
|
|
||
| def to_numpy(x): | ||
| # Torch tensor | ||
| if hasattr(x, "detach") and hasattr(x, "cpu"): | ||
| return x.detach().cpu().numpy() | ||
|
|
||
| # TF tensor | ||
| if hasattr(x, "numpy"): | ||
| return x.numpy() | ||
|
|
||
| # Numpy | ||
| if isinstance(x, np.ndarray): | ||
| return x | ||
|
|
||
| # KerasTensor or ragged wrapper → convert to TF → numpy | ||
| try: | ||
| import tensorflow as tf | ||
|
|
||
| return tf.convert_to_tensor(x).numpy() | ||
| except Exception: | ||
| pass | ||
|
|
||
| raise TypeError(f"Cannot convert value of type {type(x)} to numpy") | ||
|
|
||
|
|
||
| class GPT2ExportTest(tf.test.TestCase, parameterized.TestCase): | ||
| @parameterized.named_parameters( | ||
| ("gpt2_base_en", "gpt2_base_en"), | ||
| ) | ||
| def test_gpt2_export(self, preset): | ||
| # Create a temporary directory to save the converted model. | ||
| temp_dir = tempfile.mkdtemp() | ||
| output_path = os.path.join(temp_dir, preset) | ||
|
|
||
| # Load Keras model. | ||
| keras_model = GPT2CausalLM.from_preset(preset) | ||
|
|
||
| # Export to Hugging Face format. | ||
| export_to_safetensors(keras_model, output_path) | ||
|
|
||
| # Load the converted model with Hugging Face Transformers. | ||
| hf_model = AutoModelForCausalLM.from_pretrained(output_path) | ||
| hf_tokenizer = AutoTokenizer.from_pretrained(output_path) | ||
|
|
||
| # Assertions for config parameters. | ||
| self.assertEqual( | ||
| keras_model.backbone.hidden_dim, hf_model.config.hidden_size | ||
| ) | ||
| self.assertEqual( | ||
| keras_model.backbone.num_layers, hf_model.config.n_layer | ||
| ) | ||
| self.assertEqual(keras_model.backbone.num_heads, hf_model.config.n_head) | ||
| self.assertEqual( | ||
| keras_model.backbone.intermediate_dim, hf_model.config.n_inner | ||
| ) | ||
| self.assertEqual( | ||
| keras_model.backbone.vocabulary_size, hf_model.config.vocab_size | ||
| ) | ||
| self.assertEqual( | ||
| keras_model.backbone.max_sequence_length, | ||
| hf_model.config.n_positions, | ||
| ) | ||
|
|
||
| # Test logits. | ||
| prompt = "Hello, my name is" | ||
| token_ids = ops.array( | ||
| keras_model.preprocessor.tokenizer(ops.array([prompt])) | ||
| ) | ||
| padding_mask = tf.ones_like(token_ids, dtype=tf.int32) | ||
| keras_inputs = {"token_ids": token_ids, "padding_mask": padding_mask} | ||
| keras_logits = keras_model(keras_inputs) | ||
|
|
||
| hf_inputs = hf_tokenizer(prompt, return_tensors="pt") | ||
| hf_logits = hf_model(**hf_inputs).logits | ||
| print(hf_logits) | ||
|
|
||
| # Compare logits. | ||
| # Keras logits are (batch_size, sequence_length, vocab_size) | ||
| # HF logits are (batch_size, sequence_length, vocab_size) | ||
| # We need to convert Keras logits to numpy and then to torch tensor | ||
| # for comparison. | ||
|
|
||
| # Convert Keras logits (TF) -> numpy | ||
| keras_logits_np = to_numpy(keras_logits) | ||
|
|
||
| # Convert HF logits (Torch, possibly MPS) -> numpy | ||
| hf_logits_np = to_numpy(hf_logits) | ||
|
|
||
| self.assertAllClose(keras_logits_np, hf_logits_np, atol=1e-3, rtol=1e-3) | ||
|
|
||
| # Clean up the temporary directory. | ||
| shutil.rmtree(temp_dir) | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| tf.test.main() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For better readability and consistency, imports should be at the top of the file. Please move
import json,import os, andimport shutilto the top-level of the module to follow standard Python conventions.