src/transformers/models/gemma3/convert_gemma3_weights.py: 90 changes (63 additions, 27 deletions)
@@ -21,6 +21,7 @@
--tokenizer_path="$HOME/gemma3/tokenizer/gemma3_cleaned_262144_v2.spiece.model" \
--checkpoint_path="$HOME/gemma3/gemma3_4b_pt_orbax/" \
-    --output_path="$HOME/gemma3/gemma3_4b_pt_safetensors/"
+    --output_path="$HOME/gemma3/gemma3_4b_pt_safetensors/" \
+    --include_vision_encoder
"""

from collections.abc import Iterator, Sequence
@@ -182,7 +183,7 @@
),
_VARIANT_GEMMA_3_4B: Gemma3Config(
text_config=Gemma3TextConfig(
-            vocab_size=262_208,
+            vocab_size=262_144,
hidden_size=2560,
intermediate_size=2560 * 8 // 2,
num_attention_heads=8,
@@ -200,7 +201,7 @@
),
_VARIANT_GEMMA_3_12B: Gemma3Config(
text_config=Gemma3TextConfig(
-            vocab_size=262_208,
+            vocab_size=262_144,
hidden_size=30 * 128,
intermediate_size=30 * 128 * 8 // 2,
num_attention_heads=16,
@@ -218,7 +219,7 @@
),
_VARIANT_GEMMA_3_27B: Gemma3Config(
text_config=Gemma3TextConfig(
-            vocab_size=262_208,
+            vocab_size=262_144,
hidden_size=42 * 128,
intermediate_size=42 * 128 * 8 // 2,
num_attention_heads=32,
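A note on the `vocab_size` edits above: the Gemma 3 tokenizer defines 262_144 tokens, while the previous configs hard-coded 262_208, silently baking in the 64 extra rows reserved for multimodal placeholder tokens. After this change the text config starts from the tokenizer's true size and the converter grows it by 64 only when the vision encoder is included (see `config.vocab_size += 64` further down). A minimal check of the arithmetic:

```python
# Reading aid, not converter code: the multimodal embedding table is the
# text vocabulary plus 64 placeholder rows appended during conversion.
TEXT_VOCAB_SIZE = 262_144
EXTRA_MULTIMODAL_TOKENS = 64
assert TEXT_VOCAB_SIZE + EXTRA_MULTIMODAL_TOKENS == 262_208
```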
@@ -236,10 +237,15 @@
),
}

-_TEXT_ONLY_VARIANTS = (_VARIANT_EMBEDDINGGEMMA, _VARIANT_GEMMA_3_270M, _VARIANT_GEMMA_3_1B)
-
# ==== Flags ====

+_CHAT_TEMPLATE_PATH = flags.DEFINE_string(
+    name="chat_template_path",
+    default=None,
+    help="Path to the chat template.",
+    required=False,
+)
+
_CHECKPOINT_PATH = flags.DEFINE_string(
name="checkpoint_path",
default=None,
@@ -251,6 +257,15 @@
name="include_chat_template", default=False, help="If true, will save the default chat template with the tokenizer"
)

+_INCLUDE_VISION_ENCODER = flags.DEFINE_bool(
+    name="include_vision_encoder",
+    default=True,
+    help=(
+        "If true, the model will expect vision weights in the checkpoint at `checkpoint_path`, and if they are"
+        " not found, loading the weights will throw a `RuntimeError`."
+    ),
+)
+
_OUTPUT_PATH = flags.DEFINE_string(
name="output_path",
default=None,
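Since `_INCLUDE_VISION_ENCODER` defaults to True, text-only conversions must disable it explicitly. absl generates a `--no<name>` negation for every `DEFINE_bool` flag, so the converter can be run with `--noinclude_vision_encoder` (or `--include_vision_encoder=false`). A standalone sketch of the pattern; the flag name mirrors the converter's, but this demo script is hypothetical:

```python
# Minimal absl bool-flag demo, assuming the same name as the converter's flag.
from absl import app, flags

_INCLUDE_VISION_ENCODER = flags.DEFINE_bool(
    name="include_vision_encoder",
    default=True,
    help="Demo stand-in for the converter's flag.",
)


def main(argv):
    del argv  # Unused.
    # `python demo.py --noinclude_vision_encoder` prints False here.
    print(f"include_vision_encoder={_INCLUDE_VISION_ENCODER.value}")


if __name__ == "__main__":
    app.run(main)
```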
@@ -299,6 +314,17 @@
)


+def get_chat_template() -> Optional[str]:
+    if not _INCLUDE_CHAT_TEMPLATE.value:
+        return None
+
+    if _CHAT_TEMPLATE_PATH.value:
+        with open(_CHAT_TEMPLATE_PATH.value, "r") as f:
+            return f.read()
+
+    return _CHAT_TEMPLATE
+
+
def convert_siglip_weight(
config: SiglipVisionConfig,
paths: Sequence[str],
@@ -405,35 +431,36 @@ def convert_transformer_weights(
if path.endswith(_TRANSFORMER_EMBEDDER):
if prop == "input_embedding":
# Tied to language_model.lm_head.weight, assigned at the end.
converted_paths = ["language_model.model.embed_tokens.weight"]
converted_paths = ["model.language_model.embed_tokens.weight"]

-            if _VARIANT.value not in _TEXT_ONLY_VARIANTS:
+            if _INCLUDE_VISION_ENCODER.value:
# Gemma3 model doesn't have image soft token in input and output embeddings, resize to avoid bugs we had with Mllama
pre_expansion_embeddings = weights
mu = np.mean(pre_expansion_embeddings, axis=0)
sigma = np.cov(pre_expansion_embeddings, rowvar=False, bias=True)
new_embeddings = np.random.multivariate_normal(mu, 1e-5 * sigma, size=64)
weights = np.vstack([pre_expansion_embeddings, new_embeddings])
+                config.vocab_size += 64

converted_weights = [weights]
-        elif _VARIANT.value in _TEXT_ONLY_VARIANTS or prop in ("mm_output_embedding", "mm_input_embedding_extra"):
+        elif not _INCLUDE_VISION_ENCODER.value or prop in ("mm_output_embedding", "mm_input_embedding_extra"):
return zip([], [])
else:
raise ValueError(f"Unexpected member, {prop}, in Embedder.")
elif path.startswith(f"{_TRANSFORMER_EMBEDDER}/mm"):
-        if _VARIANT.value in _TEXT_ONLY_VARIANTS:
+        if not _INCLUDE_VISION_ENCODER.value:
return zip([], [])

if path.endswith("/mm_input_projection"):
converted_paths = ["multi_modal_projector.mm_input_projection_weight"]
converted_paths = ["model.multi_modal_projector.mm_input_projection_weight"]
converted_weights = [weights]
elif path.endswith("/mm_soft_embedding_norm"):
converted_paths = ["multi_modal_projector.mm_soft_emb_norm.weight"]
converted_paths = ["model.multi_modal_projector.mm_soft_emb_norm.weight"]
converted_weights = [weights]
else:
raise ValueError(f"Unexpected subpath, `{path}`, in Embedder.")
elif path.endswith(_TRANSFORMER_FINAL_NORM):
converted_paths = ["language_model.model.norm.weight"]
converted_paths = ["model.language_model.norm.weight"]
converted_weights = [weights]
elif _TRANSFORMER_DECODER_BLOCK in path:
decoder_block_start = path.find(_TRANSFORMER_DECODER_BLOCK)
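The embedder branch above resizes the input embedding table when the vision encoder is included: rows for the 64 new multimodal tokens are drawn from a Gaussian fitted to the existing embeddings (with scaled-down covariance), so the new rows stay in-distribution rather than starting at zero. A standalone sketch with toy shapes, not the converter's real dimensions:

```python
# Toy-sized reproduction of the expansion step; vocab and hidden sizes are
# illustrative only.
import numpy as np

vocab_size, hidden_size, extra_tokens = 1_000, 32, 64
embeddings = np.random.default_rng(0).normal(size=(vocab_size, hidden_size))

mu = np.mean(embeddings, axis=0)                     # (hidden_size,)
sigma = np.cov(embeddings, rowvar=False, bias=True)  # (hidden_size, hidden_size)
new_rows = np.random.multivariate_normal(mu, 1e-5 * sigma, size=extra_tokens)
expanded = np.vstack([embeddings, new_rows])

assert expanded.shape == (vocab_size + extra_tokens, hidden_size)
```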
@@ -443,7 +470,7 @@ def convert_transformer_weights(
layer_idx = decoder_block_path[:next_path_separator_idx]
decoder_block_path = decoder_block_path[next_path_separator_idx:]

base_path = f"language_model.model.layers.{layer_idx}"
base_path = f"model.language_model.layers.{layer_idx}"

if path.endswith("attn/attn_vec_einsum"):
converted_paths = [f"{base_path}.self_attn.o_proj.weight"]
@@ -522,26 +549,35 @@ def update_tree(path: str, weights: np.ndarray, target_dtype: torch.dtype) -> None:

for paths, value in orbax_tree_flat:
if paths[0].startswith("SigLiPFromPatches_"):
-            if config.vision_config is None:
+            if not _INCLUDE_VISION_ENCODER.value:
continue

path, weights = convert_siglip_weight(config=config.vision_config, paths=paths, weights=value)
update_tree(path, weights, config.vision_config.dtype)
else:
for path, weights in convert_transformer_weights(config=config.text_config, paths=paths, weights=value):
-                if variant in _TEXT_ONLY_VARIANTS:
-                    path = path[len("language_model.") :]
+                if not _INCLUDE_VISION_ENCODER.value:
+                    # Paths generated during weights conversion assume the target is a Gemma3ForConditionalGeneration
+                    # model, which has its Gemma3TextModel at "model.language_model". If _INCLUDE_VISION_ENCODER.value
+                    # is False, this is targeting a Gemma3ForCausalLM, which has its Gemma3TextModel at "model", so
+                    # the "language_model." portion of the path needs to be removed prior to calling load_state_dict().
+                    path = path.replace("language_model.", "")
+
                if variant == _VARIANT_EMBEDDINGGEMMA:
+                    # EmbeddingGemma uses only the Gemma3TextModel instead of an LLM or VLM class for loading
+                    # weights, and defers final model construction to SentenceTransformers, so the "model." portion
+                    # of the path needs to be removed prior to calling load_state_dict().
                    path = path[len("model.") :]

update_tree(path, weights, config.text_config.dtype)

if variant == _VARIANT_EMBEDDINGGEMMA:
return hf_tree, [weight[1].T for weight in orbax_tree_flat[: _NUM_LINEAR_LAYERS.value]]
-    elif config.vision_config is None:
-        hf_tree["lm_head.weight"] = hf_tree["model.embed_tokens.weight"]
+
+    if _INCLUDE_VISION_ENCODER.value:
+        hf_tree["lm_head.weight"] = hf_tree["model.language_model.embed_tokens.weight"]
     else:
-        hf_tree["language_model.lm_head.weight"] = hf_tree["language_model.model.embed_tokens.weight"]
+        hf_tree["lm_head.weight"] = hf_tree["model.embed_tokens.weight"]

return hf_tree, None
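The remapping and tying logic above follows one convention: conversion always produces Gemma3ForConditionalGeneration-style keys rooted at `model.language_model.*`, the `language_model.` segment is stripped for text-only Gemma3ForCausalLM targets, and `lm_head.weight` is tied to the input embeddings in both cases. A small sketch of the key transform (the helper name is hypothetical):

```python
# Illustrative only; mirrors the path.replace(...) call in the loop above.
def remap_for_causal_lm(path: str) -> str:
    """Drop the 'language_model.' segment for text-only Gemma3ForCausalLM."""
    return path.replace("language_model.", "")


assert (
    remap_for_causal_lm("model.language_model.embed_tokens.weight")
    == "model.embed_tokens.weight"
)
assert (
    remap_for_causal_lm("model.language_model.layers.0.self_attn.o_proj.weight")
    == "model.layers.0.self_attn.o_proj.weight"
)
```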

@@ -555,10 +591,10 @@ def main(*args):
config = _VARIANTS[variant]
config.text_config.dtype = getattr(torch, _TRANSFORMER_DTYPE.value)

-    if variant in _TEXT_ONLY_VARIANTS:
-        config.vision_config = None
-    else:
+    if _INCLUDE_VISION_ENCODER.value:
         config.vision_config.dtype = getattr(torch, _VISION_DTYPE.value)
+    else:
+        config.vision_config = None

if _INCLUDE_CHAT_TEMPLATE.value:
# Chat template is included for instruction tuned models, which treat
@@ -577,10 +613,10 @@ def main(*args):
with accelerate.init_empty_weights():
if variant == _VARIANT_EMBEDDINGGEMMA:
model = Gemma3TextModel(config=config.text_config)
-        elif variant in _TEXT_ONLY_VARIANTS:
-            model = Gemma3ForCausalLM(config=config.text_config)
-        else:
+        elif _INCLUDE_VISION_ENCODER.value:
             model = Gemma3ForConditionalGeneration(config)
+        else:
+            model = Gemma3ForCausalLM(config=config.text_config)

model.load_state_dict(state_tree, assign=True, strict=True)
logging.info(
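Once converted, the checkpoint loads with the class matching the branch above. A usage sketch with hypothetical output paths:

```python
# Hypothetical paths; pick the class that matches how the checkpoint was
# converted (--include_vision_encoder vs. --noinclude_vision_encoder).
from transformers import Gemma3ForCausalLM, Gemma3ForConditionalGeneration

vlm = Gemma3ForConditionalGeneration.from_pretrained(
    "path/to/gemma3_4b_pt_safetensors"
)
text_only = Gemma3ForCausalLM.from_pretrained(
    "path/to/gemma3_4b_pt_safetensors_text_only"
)
```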
@@ -608,12 +644,12 @@ def main(*args):
"boi_token": "<start_of_image>", # Should be ID=255_999
"eoi_token": "<end_of_image>", # Should be ID=256_000
},
-        chat_template=_CHAT_TEMPLATE if _INCLUDE_CHAT_TEMPLATE.value else None,
+        chat_template=get_chat_template(),
)
tokenizer.save_pretrained(output_path)
logging.info("Saved GemmaTokenizer for %s to %s", variant, output_path)

-    if variant not in _TEXT_ONLY_VARIANTS:
+    if _INCLUDE_VISION_ENCODER.value:
image_processor = Gemma3ImageProcessor(
image_seq_length=256,
image_mean=(0.5,) * 3,