diff --git a/tools/gemma/export_gemma_to_hf.py b/tools/gemma/export_gemma_to_hf.py
index b90e8f3c6a..e9b2dea5d7 100644
--- a/tools/gemma/export_gemma_to_hf.py
+++ b/tools/gemma/export_gemma_to_hf.py
@@ -1,15 +1,18 @@
 import contextlib
 import os
+from typing import Optional
 
 import torch
 import transformers
 from absl import app
 from absl import flags
 
-import keras_hub
-
 os.environ["KERAS_BACKEND"] = "torch"
+import keras  # noqa: F401,E402
+
+import keras_hub  # noqa: E402
+
 
 """
 Sample usage:
@@ -42,18 +45,29 @@ PRESET_MAP = {
+    # Gemma 1
     "gemma_2b_en": "gg-hf/gemma-2b",
     "gemma_instruct_2b_en": "gg-hf/gemma-2b",
     "gemma_7b_en": "gg-hf/gemma-7b",
     "gemma_instruct_7b_en": "gg-hf/gemma-7b",
+    # Gemma 2
+    "gemma2_2b_en": "gg-hf/gemma-2-2b",
+    "gemma2_instruct_2b_en": "gg-hf/gemma-2-2b-it",
+    "gemma2_9b_en": "gg-hf/gemma-2-9b",
+    "gemma2_instruct_9b_en": "gg-hf/gemma-2-9b-it",
+    "gemma2_27b_en": "gg-hf/gemma-2-27b",
+    "gemma2_instruct_27b_en": "gg-hf/gemma-2-27b-it",
 }
 
 SIZE_MAP = {
-    "2b": ("gg-hf/gemma-2b", "gemma_2b_en"),
-    "7b": ("gg-hf/gemma-7b", "gemma_7b_en"),
+    "v1_2b": ("gg-hf/gemma-2b", "gemma_2b_en"),
+    "v1_7b": ("gg-hf/gemma-7b", "gemma_7b_en"),
+    "v2_2b": ("gg-hf/gemma-2-2b", "gemma2_2b_en"),
+    "v2_9b": ("gg-hf/gemma-2-9b", "gemma2_9b_en"),
+    "v2_27b": ("gg-hf/gemma-2-27b", "gemma2_27b_en"),
 }
 
-gemma_2b_config = transformers.GemmaConfig(
+gemma1_2b_config = transformers.GemmaConfig(
     num_hidden_layers=18,
     num_attention_heads=8,
     num_key_value_heads=1,
@@ -61,9 +75,49 @@
     intermediate_size=16384,
 )
 
-gemma_7b_config = transformers.GemmaConfig()
+gemma1_7b_config = transformers.GemmaConfig()
 
-CONFIG_MAPPING = {"2b": gemma_2b_config, "7b": gemma_7b_config}
+gemma2_2b_config = transformers.Gemma2Config(
+    num_hidden_layers=26,
+    num_attention_heads=8,
+    num_key_value_heads=4,
+    hidden_size=2304,
+    intermediate_size=9216,
+)
+
+gemma2_9b_config = transformers.Gemma2Config(
+    num_hidden_layers=42,
+    num_attention_heads=16,
+    num_key_value_heads=8,
+    hidden_size=3584,
+    intermediate_size=14336,
+    final_logit_softcapping=30.0,
+    attn_logit_softcapping=50.0,
+    head_dim=256,
+    sliding_window=4096,
+    query_pre_attn_scalar=224,
+)
+
+gemma2_27b_config = transformers.Gemma2Config(
+    num_hidden_layers=46,
+    num_attention_heads=32,
+    num_key_value_heads=16,
+    hidden_size=4608,
+    intermediate_size=36864,
+    final_logit_softcapping=30.0,
+    attn_logit_softcapping=50.0,
+    head_dim=128,
+    sliding_window=4096,
+    query_pre_attn_scalar=144,
+)
+
+CONFIG_MAPPING = {
+    "v1_2b": gemma1_2b_config,
+    "v1_7b": gemma1_7b_config,
+    "v2_2b": gemma2_2b_config,
+    "v2_9b": gemma2_9b_config,
+    "v2_27b": gemma2_27b_config,
+}
 
 FLAGS = flags.FLAGS
 flags.DEFINE_string(
@@ -107,6 +161,11 @@
     "float32",
     "Set the precision of the converted checkpoint. Must be a valid PyTorch dtype.",
 )
+flags.DEFINE_integer(
+    "gemma_version",
+    None,
+    "Integer denoting the Gemma version (e.g. 1, 2).",
1, 2).", +) @contextlib.contextmanager @@ -117,13 +176,22 @@ def _set_default_tensor_type(dtype: torch.dtype): torch.set_default_dtype(torch.float) -def convert_checkpoints(preset, weights_file, size, output_dir, vocab_path): +def convert_checkpoints( + preset: str, + weights_file: str, + gemma_version: int, + size: str, + output_dir: str, + vocab_path: Optional[str] = None, +): if preset is not None: hf_id = PRESET_MAP[preset] print(f"\n-> Loading KerasHub Gemma model with preset `{preset}`...") keras_hub_model = keras_hub.models.GemmaCausalLM.from_preset(preset) else: - hf_id, keras_preset = SIZE_MAP[size.lower()] + hf_id, keras_preset = SIZE_MAP[ + f"v{gemma_version.lower()}_{size.lower()}" + ] print(f"\n-> Loading Keras weights from file `{weights_file}`...") keras_hub_model = keras_hub.models.GemmaCausalLM.from_preset( keras_preset @@ -131,7 +199,11 @@ def convert_checkpoints(preset, weights_file, size, output_dir, vocab_path): keras_hub_model.load_weights(weights_file) print(f"\n-> Loading HuggingFace Gemma `{size.upper()}` model...") - hf_model = transformers.GemmaForCausalLM(CONFIG_MAPPING[size.lower()]) + config = CONFIG_MAPPING[f"v{gemma_version}_{size.lower()}"] + if isinstance(config, transformers.GemmaConfig): + hf_model = transformers.GemmaForCausalLM(config) + elif isinstance(config, transformers.Gemma2Config): + hf_model = transformers.Gemma2ForCausalLM(config) print("\n✅ Model loading complete.") print("\n-> Converting weights from KerasHub Gemma to HuggingFace Gemma...") @@ -322,11 +394,12 @@ def main(_): flag_error_handler() with _set_default_tensor_type(getattr(torch, FLAGS.dtype)): convert_checkpoints( - FLAGS.preset, - FLAGS.weights_file, - FLAGS.size, - FLAGS.output_dir, - FLAGS.vocab_path, + preset=FLAGS.preset, + weights_file=FLAGS.weights_file, + gemma_version=FLAGS.gemma_version, + size=FLAGS.size, + output_dir=FLAGS.output_dir, + vocab_path=FLAGS.vocab_path, )