From b98d0eb15bb68ca735030114ecbfff2a4f626e8f Mon Sep 17 00:00:00 2001
From: Ryan Mullins
Date: Fri, 17 Oct 2025 17:22:23 +0000
Subject: [PATCH 1/4] conversion: add include_vision_encoder flag (default true)

---
 .../models/gemma3/convert_gemma3_weights.py   | 39 ++++++++++++-------
 1 file changed, 24 insertions(+), 15 deletions(-)

diff --git a/src/transformers/models/gemma3/convert_gemma3_weights.py b/src/transformers/models/gemma3/convert_gemma3_weights.py
index b4b00dc22ec8..0d53a02b5a12 100644
--- a/src/transformers/models/gemma3/convert_gemma3_weights.py
+++ b/src/transformers/models/gemma3/convert_gemma3_weights.py
@@ -236,8 +236,6 @@
     ),
 }
 
-_TEXT_ONLY_VARIANTS = (_VARIANT_EMBEDDINGGEMMA, _VARIANT_GEMMA_3_270M, _VARIANT_GEMMA_3_1B)
-
 # ==== Flags ====
 
 _CHECKPOINT_PATH = flags.DEFINE_string(
@@ -251,6 +249,15 @@
     name="include_chat_template", default=False, help="If true, will save the default chat template with the tokenizer"
 )
 
+_INCLUDE_VISION_ENCODER = flags.DEFINE_bool(
+    name="include_vision_encoder",
+    default=True,
+    help=(
+        "If true, the model will expect vision weights in the checkpoint at `checkpoint_path` and, if not found,"
+        " loading the weights will throw a `RuntimeError`."
+    ),
+)
+
 _OUTPUT_PATH = flags.DEFINE_string(
     name="output_path",
     default=None,
@@ -407,7 +414,7 @@ def convert_transformer_weights(
             # Tied to language_model.lm_head.weight, assigned at the end.
             converted_paths = ["language_model.model.embed_tokens.weight"]
 
-            if _VARIANT.value not in _TEXT_ONLY_VARIANTS:
+            if _INCLUDE_VISION_ENCODER.value:
                 # Gemma3 model doesn't have image soft token in input and output embeddings, resize to avoid bugs we had with Mllama
                 pre_expansion_embeddings = weights
                 mu = np.mean(pre_expansion_embeddings, axis=0)
@@ -416,12 +423,12 @@ def convert_transformer_weights(
                 weights = np.vstack([pre_expansion_embeddings, new_embeddings])
 
             converted_weights = [weights]
-        elif _VARIANT.value in _TEXT_ONLY_VARIANTS or prop in ("mm_output_embedding", "mm_input_embedding_extra"):
+        elif not _INCLUDE_VISION_ENCODER.value or prop in ("mm_output_embedding", "mm_input_embedding_extra"):
             return zip([], [])
         else:
             raise ValueError(f"Unexpected member, {prop}, in Embedder.")
     elif path.startswith(f"{_TRANSFORMER_EMBEDDER}/mm"):
-        if _VARIANT.value in _TEXT_ONLY_VARIANTS:
+        if not _INCLUDE_VISION_ENCODER.value:
             return zip([], [])
 
         if path.endswith("/mm_input_projection"):
@@ -522,15 +529,16 @@ def update_tree(path: str, weights: np.ndarray, target_dtype: torch.dtype) -> No
 
     for paths, value in orbax_tree_flat:
         if paths[0].startswith("SigLiPFromPatches_"):
-            if config.vision_config is None:
+            if not _INCLUDE_VISION_ENCODER.value:
                 continue
             path, weights = convert_siglip_weight(config=config.vision_config, paths=paths, weights=value)
             update_tree(path, weights, config.vision_config.dtype)
         else:
             for path, weights in convert_transformer_weights(config=config.text_config, paths=paths, weights=value):
-                if variant in _TEXT_ONLY_VARIANTS:
+                if not _INCLUDE_VISION_ENCODER.value:
                     path = path[len("language_model.") :]
+
                 if variant == _VARIANT_EMBEDDINGGEMMA:
                     path = path[len("model.") :]
 
                 update_tree(path, weights, config.text_config.dtype)
@@ -538,7 +546,8 @@ def update_tree(path: str, weights: np.ndarray, target_dtype: torch.dtype) -> No
 
     if variant == _VARIANT_EMBEDDINGGEMMA:
         return hf_tree, [weight[1].T for weight in orbax_tree_flat[: _NUM_LINEAR_LAYERS.value]]
-    elif config.vision_config is None:
+
+    if not _INCLUDE_VISION_ENCODER.value:
         hf_tree["lm_head.weight"] = hf_tree["model.embed_tokens.weight"]
     else:
         hf_tree["language_model.lm_head.weight"] = hf_tree["language_model.model.embed_tokens.weight"]
hf_tree["language_model.model.embed_tokens.weight"] @@ -555,10 +564,10 @@ def main(*args): config = _VARIANTS[variant] config.text_config.dtype = getattr(torch, _TRANSFORMER_DTYPE.value) - if variant in _TEXT_ONLY_VARIANTS: - config.vision_config = None - else: + if _INCLUDE_VISION_ENCODER.value: config.vision_config.dtype = getattr(torch, _VISION_DTYPE.value) + else: + config.vision_config = None if _INCLUDE_CHAT_TEMPLATE.value: # Chat template is included for instruction tuned models, which treat @@ -577,10 +586,10 @@ def main(*args): with accelerate.init_empty_weights(): if variant == _VARIANT_EMBEDDINGGEMMA: model = Gemma3TextModel(config=config.text_config) - elif variant in _TEXT_ONLY_VARIANTS: - model = Gemma3ForCausalLM(config=config.text_config) - else: + elif _INCLUDE_VISION_ENCODER.value: model = Gemma3ForConditionalGeneration(config) + else: + model = Gemma3ForCausalLM(config=config.text_config) model.load_state_dict(state_tree, assign=True, strict=True) logging.info( @@ -613,7 +622,7 @@ def main(*args): tokenizer.save_pretrained(output_path) logging.info("Saved GemmaTokenizer for %s to %s", variant, output_path) - if variant not in _TEXT_ONLY_VARIANTS: + if _INCLUDE_VISION_ENCODER.value: image_processor = Gemma3ImageProcessor( image_seq_length=256, image_mean=(0.5,) * 3, From b733c33a5164b6a2b95af8523bfd1617ee49d169 Mon Sep 17 00:00:00 2001 From: Ryan Mullins Date: Fri, 17 Oct 2025 18:50:25 +0000 Subject: [PATCH 2/4] conversion: update for inverted model.language_model weight path --- .../models/gemma3/convert_gemma3_weights.py | 35 ++++++++++++------- 1 file changed, 22 insertions(+), 13 deletions(-) diff --git a/src/transformers/models/gemma3/convert_gemma3_weights.py b/src/transformers/models/gemma3/convert_gemma3_weights.py index 0d53a02b5a12..a6ecfb70aad1 100644 --- a/src/transformers/models/gemma3/convert_gemma3_weights.py +++ b/src/transformers/models/gemma3/convert_gemma3_weights.py @@ -21,6 +21,7 @@ --tokenizer_path="$HOME/gemma3/tokenizer/gemma3_cleaned_262144_v2.spiece.model" \ --checkpoint_path="$HOME/gemma3/gemma3_4b_pt_orbax/" \ --output_path="$HOME/gemma3/gemma3_4b_pt_safetensors/" + --include_vision_encoder """ from collections.abc import Iterator, Sequence @@ -182,7 +183,7 @@ ), _VARIANT_GEMMA_3_4B: Gemma3Config( text_config=Gemma3TextConfig( - vocab_size=262_208, + vocab_size=262_144, hidden_size=2560, intermediate_size=2560 * 8 // 2, num_attention_heads=8, @@ -200,7 +201,7 @@ ), _VARIANT_GEMMA_3_12B: Gemma3Config( text_config=Gemma3TextConfig( - vocab_size=262_208, + vocab_size=262_144, hidden_size=30 * 128, intermediate_size=30 * 128 * 8 // 2, num_attention_heads=16, @@ -218,7 +219,7 @@ ), _VARIANT_GEMMA_3_27B: Gemma3Config( text_config=Gemma3TextConfig( - vocab_size=262_208, + vocab_size=262_144, hidden_size=42 * 128, intermediate_size=42 * 128 * 8 // 2, num_attention_heads=32, @@ -251,7 +252,7 @@ _INCLUDE_VISION_ENCODER = flags.DEFINE_bool( name="include_vision_encoder", - default=True, + default=False, help=( "If true, the model will expect vision weights in the checkpoint at `checkpoint_path` an if not found loading" " the weights will throw a `RuntimeError`." @@ -412,7 +413,7 @@ def convert_transformer_weights( if path.endswith(_TRANSFORMER_EMBEDDER): if prop == "input_embedding": # Tied to language_model.lm_head.weight, assigned at the end. 
-            converted_paths = ["language_model.model.embed_tokens.weight"]
+            converted_paths = ["model.language_model.embed_tokens.weight"]
 
             if _INCLUDE_VISION_ENCODER.value:
                 # Gemma3 model doesn't have image soft token in input and output embeddings, resize to avoid bugs we had with Mllama
@@ -421,6 +422,7 @@ def convert_transformer_weights(
                 sigma = np.cov(pre_expansion_embeddings, rowvar=False, bias=True)
                 new_embeddings = np.random.multivariate_normal(mu, 1e-5 * sigma, size=64)
                 weights = np.vstack([pre_expansion_embeddings, new_embeddings])
+                config.vocab_size += 64
 
             converted_weights = [weights]
         elif not _INCLUDE_VISION_ENCODER.value or prop in ("mm_output_embedding", "mm_input_embedding_extra"):
@@ -432,15 +434,15 @@ def convert_transformer_weights(
             return zip([], [])
 
         if path.endswith("/mm_input_projection"):
-            converted_paths = ["multi_modal_projector.mm_input_projection_weight"]
+            converted_paths = ["model.multi_modal_projector.mm_input_projection_weight"]
             converted_weights = [weights]
         elif path.endswith("/mm_soft_embedding_norm"):
-            converted_paths = ["multi_modal_projector.mm_soft_emb_norm.weight"]
+            converted_paths = ["model.multi_modal_projector.mm_soft_emb_norm.weight"]
             converted_weights = [weights]
         else:
             raise ValueError(f"Unexpected subpath, `{path}`, in Embedder.")
     elif path.endswith(_TRANSFORMER_FINAL_NORM):
-        converted_paths = ["language_model.model.norm.weight"]
+        converted_paths = ["model.language_model.norm.weight"]
         converted_weights = [weights]
     elif _TRANSFORMER_DECODER_BLOCK in path:
         decoder_block_start = path.find(_TRANSFORMER_DECODER_BLOCK)
@@ -450,7 +452,7 @@ def convert_transformer_weights(
         layer_idx = decoder_block_path[:next_path_separator_idx]
         decoder_block_path = decoder_block_path[next_path_separator_idx:]
 
-        base_path = f"language_model.model.layers.{layer_idx}"
+        base_path = f"model.language_model.layers.{layer_idx}"
 
         if path.endswith("attn/attn_vec_einsum"):
             converted_paths = [f"{base_path}.self_attn.o_proj.weight"]
@@ -537,9 +539,16 @@ def update_tree(path: str, weights: np.ndarray, target_dtype: torch.dtype) -> No
         else:
             for path, weights in convert_transformer_weights(config=config.text_config, paths=paths, weights=value):
                 if not _INCLUDE_VISION_ENCODER.value:
-                    path = path[len("language_model.") :]
+                    # Paths generated during weights conversion assume they are targeting a Gemma3ForConditionalGeneration
+                    # model, which has a Gemma3TextModel at "model.language_model". If _INCLUDE_VISION_ENCODER.value is
+                    # False, then this is targeting a Gemma3ForCausalLM, which has its Gemma3TextModel at "model", so
+                    # the "language_model." portion of the path needs to be removed prior to calling load_state_dict().
+                    path = path.replace("language_model.", "")
 
                 if variant == _VARIANT_EMBEDDINGGEMMA:
+                    # EmbeddingGemma only uses the Gemma3TextModel instead of an LLM or VLM class for loading weights,
+                    # and defers final model construction to SentenceTransformers, so the "model." portion of the path
+                    # needs to be removed prior to calling load_state_dict().
                     path = path[len("model.") :]
 
                 update_tree(path, weights, config.text_config.dtype)
@@ -547,10 +556,10 @@ def update_tree(path: str, weights: np.ndarray, target_dtype: torch.dtype) -> No
     if variant == _VARIANT_EMBEDDINGGEMMA:
         return hf_tree, [weight[1].T for weight in orbax_tree_flat[: _NUM_LINEAR_LAYERS.value]]
 
-    if not _INCLUDE_VISION_ENCODER.value:
-        hf_tree["lm_head.weight"] = hf_tree["model.embed_tokens.weight"]
+    if _INCLUDE_VISION_ENCODER.value:
+        hf_tree["lm_head.weight"] = hf_tree["model.language_model.embed_tokens.weight"]
     else:
-        hf_tree["language_model.lm_head.weight"] = hf_tree["language_model.model.embed_tokens.weight"]
+        hf_tree["lm_head.weight"] = hf_tree["model.embed_tokens.weight"]
 
     return hf_tree, None
 

From 691db887b3560b6266dfa7c66291cd63c14f91c0 Mon Sep 17 00:00:00 2001
From: Ryan Mullins
Date: Fri, 17 Oct 2025 18:56:00 +0000
Subject: [PATCH 3/4] conversion: revert include_vision_encoder to True by default

---
 src/transformers/models/gemma3/convert_gemma3_weights.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/transformers/models/gemma3/convert_gemma3_weights.py b/src/transformers/models/gemma3/convert_gemma3_weights.py
index a6ecfb70aad1..0f3ea1bb7ebd 100644
--- a/src/transformers/models/gemma3/convert_gemma3_weights.py
+++ b/src/transformers/models/gemma3/convert_gemma3_weights.py
@@ -252,7 +252,7 @@ _INCLUDE_VISION_ENCODER = flags.DEFINE_bool(
     name="include_vision_encoder",
-    default=False,
+    default=True,
     help=(
         "If true, the model will expect vision weights in the checkpoint at `checkpoint_path` and, if not found,"
         " loading the weights will throw a `RuntimeError`."
     ),
 )

From 3f8fee352fa5b0a1fd8bad62753b8fb54c2ac282 Mon Sep 17 00:00:00 2001
From: Ryan Mullins
Date: Sat, 18 Oct 2025 17:19:07 +0000
Subject: [PATCH 4/4] conversion: add chat template path flag

---
 .../models/gemma3/convert_gemma3_weights.py   | 20 ++++++++++++++++++-
 1 file changed, 19 insertions(+), 1 deletion(-)

diff --git a/src/transformers/models/gemma3/convert_gemma3_weights.py b/src/transformers/models/gemma3/convert_gemma3_weights.py
index 0f3ea1bb7ebd..c11176d38637 100644
--- a/src/transformers/models/gemma3/convert_gemma3_weights.py
+++ b/src/transformers/models/gemma3/convert_gemma3_weights.py
@@ -239,6 +239,13 @@
 
 # ==== Flags ====
 
+_CHAT_TEMPLATE_PATH = flags.DEFINE_string(
+    name="chat_template_path",
+    default=None,
+    help="Path to the chat template.",
+    required=False,
+)
+
 _CHECKPOINT_PATH = flags.DEFINE_string(
     name="checkpoint_path",
     default=None,
@@ -307,6 +314,17 @@
 )
 
 
+def get_chat_template() -> Optional[str]:
+    if not _INCLUDE_CHAT_TEMPLATE.value:
+        return None
+
+    if _CHAT_TEMPLATE_PATH.value:
+        with open(_CHAT_TEMPLATE_PATH.value, "r") as f:
+            return f.read()
+
+    return _CHAT_TEMPLATE
+
+
 def convert_siglip_weight(
     config: SiglipVisionConfig,
     paths: Sequence[str],
@@ -626,7 +644,7 @@ def main(*args):
             "boi_token": "<start_of_image>",  # Should be ID=255_999
             "eoi_token": "<end_of_image>",  # Should be ID=256_000
         },
-        chat_template=_CHAT_TEMPLATE if _INCLUDE_CHAT_TEMPLATE.value else None,
+        chat_template=get_chat_template(),
    )
     tokenizer.save_pretrained(output_path)
     logging.info("Saved GemmaTokenizer for %s to %s", variant, output_path)
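
Note on the path remapping introduced in PATCH 2: conversion emits parameter paths for
Gemma3ForConditionalGeneration (under "model.language_model") and then trims them for the other
loading targets. The sketch below is illustrative only and is not part of the patches; the helper
name adjust_path is hypothetical.

    def adjust_path(path: str, include_vision_encoder: bool, is_embedding_gemma: bool) -> str:
        # Converted paths are emitted for Gemma3ForConditionalGeneration, e.g.
        # "model.language_model.layers.0.self_attn.o_proj.weight".
        if not include_vision_encoder:
            # Gemma3ForCausalLM keeps its Gemma3TextModel at "model", not "model.language_model".
            path = path.replace("language_model.", "")
        if is_embedding_gemma:
            # EmbeddingGemma loads a bare Gemma3TextModel, so the "model." prefix is dropped as well.
            path = path[len("model.") :]
        return path

    assert adjust_path("model.language_model.norm.weight", True, False) == "model.language_model.norm.weight"
    assert adjust_path("model.language_model.norm.weight", False, False) == "model.norm.weight"
    assert adjust_path("model.language_model.norm.weight", False, True) == "norm.weight"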