diff --git a/src/transformers/models/gemma3/convert_gemma3_weights.py b/src/transformers/models/gemma3/convert_gemma3_weights.py
index b4b00dc22ec8..c11176d38637 100644
--- a/src/transformers/models/gemma3/convert_gemma3_weights.py
+++ b/src/transformers/models/gemma3/convert_gemma3_weights.py
@@ -21,6 +21,7 @@
     --tokenizer_path="$HOME/gemma3/tokenizer/gemma3_cleaned_262144_v2.spiece.model" \
     --checkpoint_path="$HOME/gemma3/gemma3_4b_pt_orbax/" \
-    --output_path="$HOME/gemma3/gemma3_4b_pt_safetensors/"
+    --output_path="$HOME/gemma3/gemma3_4b_pt_safetensors/" \
+    --include_vision_encoder
 """

 from collections.abc import Iterator, Sequence
@@ -182,7 +183,7 @@
     ),
     _VARIANT_GEMMA_3_4B: Gemma3Config(
         text_config=Gemma3TextConfig(
-            vocab_size=262_208,
+            vocab_size=262_144,
             hidden_size=2560,
             intermediate_size=2560 * 8 // 2,
             num_attention_heads=8,
@@ -200,7 +201,7 @@
     ),
     _VARIANT_GEMMA_3_12B: Gemma3Config(
         text_config=Gemma3TextConfig(
-            vocab_size=262_208,
+            vocab_size=262_144,
             hidden_size=30 * 128,
             intermediate_size=30 * 128 * 8 // 2,
             num_attention_heads=16,
@@ -218,7 +219,7 @@
     ),
     _VARIANT_GEMMA_3_27B: Gemma3Config(
         text_config=Gemma3TextConfig(
-            vocab_size=262_208,
+            vocab_size=262_144,
             hidden_size=42 * 128,
             intermediate_size=42 * 128 * 8 // 2,
             num_attention_heads=32,
@@ -236,10 +237,15 @@
     ),
 }

-_TEXT_ONLY_VARIANTS = (_VARIANT_EMBEDDINGGEMMA, _VARIANT_GEMMA_3_270M, _VARIANT_GEMMA_3_1B)
-
 # ==== Flags ====

+_CHAT_TEMPLATE_PATH = flags.DEFINE_string(
+    name="chat_template_path",
+    default=None,
+    help="Path to the chat template.",
+    required=False,
+)
+
 _CHECKPOINT_PATH = flags.DEFINE_string(
     name="checkpoint_path",
     default=None,
@@ -251,6 +257,15 @@
     name="include_chat_template", default=False, help="If true, will save the default chat template with the tokenizer"
 )

+_INCLUDE_VISION_ENCODER = flags.DEFINE_bool(
+    name="include_vision_encoder",
+    default=True,
+    help=(
+        "If true, the model will expect vision weights in the checkpoint at `checkpoint_path`, and if they are not"
+        " found, loading the weights will throw a `RuntimeError`."
+    ),
+)
+
 _OUTPUT_PATH = flags.DEFINE_string(
     name="output_path",
     default=None,
@@ -299,6 +314,17 @@
 )


+def get_chat_template() -> Optional[str]:
+    if not _INCLUDE_CHAT_TEMPLATE.value:
+        return None
+
+    if _CHAT_TEMPLATE_PATH.value:
+        with open(_CHAT_TEMPLATE_PATH.value, "r") as f:
+            return f.read()
+
+    return _CHAT_TEMPLATE
+
+
 def convert_siglip_weight(
     config: SiglipVisionConfig,
     paths: Sequence[str],
@@ -405,35 +431,36 @@ def convert_transformer_weights(
     if path.endswith(_TRANSFORMER_EMBEDDER):
         if prop == "input_embedding":
             # Tied to language_model.lm_head.weight, assigned at the end.
- converted_paths = ["language_model.model.embed_tokens.weight"] + converted_paths = ["model.language_model.embed_tokens.weight"] - if _VARIANT.value not in _TEXT_ONLY_VARIANTS: + if _INCLUDE_VISION_ENCODER.value: # Gemma3 model doesn't have image soft token in input and output embeddings, resize to avoid bugs we had with Mllama pre_expansion_embeddings = weights mu = np.mean(pre_expansion_embeddings, axis=0) sigma = np.cov(pre_expansion_embeddings, rowvar=False, bias=True) new_embeddings = np.random.multivariate_normal(mu, 1e-5 * sigma, size=64) weights = np.vstack([pre_expansion_embeddings, new_embeddings]) + config.vocab_size += 64 converted_weights = [weights] - elif _VARIANT.value in _TEXT_ONLY_VARIANTS or prop in ("mm_output_embedding", "mm_input_embedding_extra"): + elif not _INCLUDE_VISION_ENCODER.value or prop in ("mm_output_embedding", "mm_input_embedding_extra"): return zip([], []) else: raise ValueError(f"Unexpected member, {prop}, in Embedder.") elif path.startswith(f"{_TRANSFORMER_EMBEDDER}/mm"): - if _VARIANT.value in _TEXT_ONLY_VARIANTS: + if not _INCLUDE_VISION_ENCODER.value: return zip([], []) if path.endswith("/mm_input_projection"): - converted_paths = ["multi_modal_projector.mm_input_projection_weight"] + converted_paths = ["model.multi_modal_projector.mm_input_projection_weight"] converted_weights = [weights] elif path.endswith("/mm_soft_embedding_norm"): - converted_paths = ["multi_modal_projector.mm_soft_emb_norm.weight"] + converted_paths = ["model.multi_modal_projector.mm_soft_emb_norm.weight"] converted_weights = [weights] else: raise ValueError(f"Unexpected subpath, `{path}`, in Embedder.") elif path.endswith(_TRANSFORMER_FINAL_NORM): - converted_paths = ["language_model.model.norm.weight"] + converted_paths = ["model.language_model.norm.weight"] converted_weights = [weights] elif _TRANSFORMER_DECODER_BLOCK in path: decoder_block_start = path.find(_TRANSFORMER_DECODER_BLOCK) @@ -443,7 +470,7 @@ def convert_transformer_weights( layer_idx = decoder_block_path[:next_path_separator_idx] decoder_block_path = decoder_block_path[next_path_separator_idx:] - base_path = f"language_model.model.layers.{layer_idx}" + base_path = f"model.language_model.layers.{layer_idx}" if path.endswith("attn/attn_vec_einsum"): converted_paths = [f"{base_path}.self_attn.o_proj.weight"] @@ -522,26 +549,35 @@ def update_tree(path: str, weights: np.ndarray, target_dtype: torch.dtype) -> No for paths, value in orbax_tree_flat: if paths[0].startswith("SigLiPFromPatches_"): - if config.vision_config is None: + if not _INCLUDE_VISION_ENCODER.value: continue path, weights = convert_siglip_weight(config=config.vision_config, paths=paths, weights=value) update_tree(path, weights, config.vision_config.dtype) else: for path, weights in convert_transformer_weights(config=config.text_config, paths=paths, weights=value): - if variant in _TEXT_ONLY_VARIANTS: - path = path[len("language_model.") :] + if not _INCLUDE_VISION_ENCODER.value: + # Paths generated during weights conversion assume it is targeting a Gemma3ForConditionalGeneration + # model, which has a Gemma3TextModel at "model.language_model". If _INCLUDE_VISION_ENCODER.value is + # False, then this is targeting a Gemma3ForCausalLM, which has its Gemma3TextModel at "model", so + # the "language_model." portion of the path needs to be removed prior to calling load_state_dict(). 
+                    path = path.replace("language_model.", "")
+                if variant == _VARIANT_EMBEDDINGGEMMA:
+                    # EmbeddingGemma loads only the bare Gemma3TextModel instead of an LLM or VLM class, and defers
+                    # final model construction to SentenceTransformers, so the "model." portion of the path also
+                    # needs to be removed prior to calling load_state_dict().
                     path = path[len("model.") :]
                 update_tree(path, weights, config.text_config.dtype)

     if variant == _VARIANT_EMBEDDINGGEMMA:
         return hf_tree, [weight[1].T for weight in orbax_tree_flat[: _NUM_LINEAR_LAYERS.value]]
-    elif config.vision_config is None:
-        hf_tree["lm_head.weight"] = hf_tree["model.embed_tokens.weight"]
+
+    if _INCLUDE_VISION_ENCODER.value:
+        hf_tree["lm_head.weight"] = hf_tree["model.language_model.embed_tokens.weight"]
     else:
-        hf_tree["language_model.lm_head.weight"] = hf_tree["language_model.model.embed_tokens.weight"]
+        hf_tree["lm_head.weight"] = hf_tree["model.embed_tokens.weight"]

     return hf_tree, None
@@ -555,10 +591,10 @@ def main(*args):
     config = _VARIANTS[variant]
     config.text_config.dtype = getattr(torch, _TRANSFORMER_DTYPE.value)

-    if variant in _TEXT_ONLY_VARIANTS:
-        config.vision_config = None
-    else:
+    if _INCLUDE_VISION_ENCODER.value:
         config.vision_config.dtype = getattr(torch, _VISION_DTYPE.value)
+    else:
+        config.vision_config = None

     if _INCLUDE_CHAT_TEMPLATE.value:
         # Chat template is included for instruction tuned models, which treat
@@ -577,10 +613,10 @@ def main(*args):

     with accelerate.init_empty_weights():
         if variant == _VARIANT_EMBEDDINGGEMMA:
             model = Gemma3TextModel(config=config.text_config)
-        elif variant in _TEXT_ONLY_VARIANTS:
-            model = Gemma3ForCausalLM(config=config.text_config)
-        else:
+        elif _INCLUDE_VISION_ENCODER.value:
             model = Gemma3ForConditionalGeneration(config)
+        else:
+            model = Gemma3ForCausalLM(config=config.text_config)

     model.load_state_dict(state_tree, assign=True, strict=True)
     logging.info(
@@ -608,12 +644,12 @@ def main(*args):
             "boi_token": "<start_of_image>",  # Should be ID=255_999
             "eoi_token": "<end_of_image>",  # Should be ID=256_000
         },
-        chat_template=_CHAT_TEMPLATE if _INCLUDE_CHAT_TEMPLATE.value else None,
+        chat_template=get_chat_template(),
     )
     tokenizer.save_pretrained(output_path)
     logging.info("Saved GemmaTokenizer for %s to %s", variant, output_path)

-    if variant not in _TEXT_ONLY_VARIANTS:
+    if _INCLUDE_VISION_ENCODER.value:
         image_processor = Gemma3ImageProcessor(
             image_seq_length=256,
             image_mean=(0.5,) * 3,
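
Note on the embedding resize in `convert_transformer_weights` above: the 64 new rows added for the
image soft tokens are drawn from a normal distribution fitted to the existing embedding table. The
following is a minimal standalone sketch of that logic; `expand_embeddings` and the small shapes are
illustrative assumptions, not part of the converter.

import numpy as np

def expand_embeddings(weights: np.ndarray, num_new_tokens: int = 64) -> np.ndarray:
    """Append rows sampled from N(mu, 1e-5 * sigma) fitted to the existing rows."""
    mu = np.mean(weights, axis=0)                     # per-dimension mean, shape (hidden,)
    sigma = np.cov(weights, rowvar=False, bias=True)  # covariance, shape (hidden, hidden)
    new_rows = np.random.multivariate_normal(mu, 1e-5 * sigma, size=num_new_tokens)
    return np.vstack([weights, new_rows])

# Small illustrative shapes; the real table is (262_144, hidden_size) and grows to 262_208,
# matching the vocab_size the configs previously hardcoded (262_144 + 64).
embed = np.random.randn(1_000, 16)
expanded = expand_embeddings(embed)
assert expanded.shape == (1_064, 16)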
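The state-dict key remapping performed before load_state_dict() can be summarized by the sketch
below. `remap_checkpoint_path` is a hypothetical helper named here for illustration only; the
converter inlines this logic rather than defining such a function.

def remap_checkpoint_path(path: str, include_vision: bool, is_embedding_gemma: bool) -> str:
    """Map a Gemma3ForConditionalGeneration-style key onto the target class's key layout."""
    if not include_vision:
        # Gemma3ForCausalLM holds its Gemma3TextModel at "model", not "model.language_model".
        path = path.replace("language_model.", "")
    if is_embedding_gemma:
        # EmbeddingGemma loads the bare Gemma3TextModel, so the leading "model." is dropped too.
        path = path[len("model.") :]
    return path

key = "model.language_model.layers.0.self_attn.o_proj.weight"
assert remap_checkpoint_path(key, include_vision=False, is_embedding_gemma=False) == (
    "model.layers.0.self_attn.o_proj.weight"
)
assert remap_checkpoint_path(key, include_vision=False, is_embedding_gemma=True) == (
    "layers.0.self_attn.o_proj.weight"
)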