4 changes: 4 additions & 0 deletions src/transformers/models/gemma3/configuration_gemma3.py
@@ -136,6 +136,8 @@ class Gemma3TextConfig(PretrainedConfig):
Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
rope_local_base_freq (float, *optional*, defaults to 10000.0):
The base period of the RoPE embeddings for local attention.
use_bidirectional_attention (`bool`, *optional*, defaults to `False`):
If `True`, text tokens attend bidirectionally instead of with a causal mask: full-attention layers attend to all
tokens, and sliding-attention layers attend to tokens within their window. Behavior for vision tokens is unchanged.

```python
>>> from transformers import Gemma3TextModel, Gemma3TextConfig
@@ -193,6 +195,7 @@ def __init__(
attn_logit_softcapping=None,
rope_scaling=None,
rope_local_base_freq=10_000.0,
use_bidirectional_attention=False,
**kwargs,
):
super().__init__(
Expand Down Expand Up @@ -222,6 +225,7 @@ def __init__(
self.final_logit_softcapping = final_logit_softcapping
self.attn_logit_softcapping = attn_logit_softcapping
self.layer_types = layer_types
self.use_bidirectional_attention = use_bidirectional_attention

self.rope_local_base_freq = rope_local_base_freq
self.rope_scaling = rope_scaling
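For context on the new flag, here is a minimal sketch of constructing a bidirectional text config and model; the hyperparameters mirror the embedding variant defined later in this diff and are illustrative only, not part of the change itself.

```python
from transformers import Gemma3TextConfig, Gemma3TextModel

# Illustrative sketch: an encoder-style Gemma3 text config with the new flag enabled.
# The sizes below mirror the embedding variant added in the conversion script.
config = Gemma3TextConfig(
    hidden_size=768,
    intermediate_size=1152,
    num_hidden_layers=24,
    num_attention_heads=3,
    num_key_value_heads=1,
    head_dim=256,
    sliding_window=512,
    use_bidirectional_attention=True,  # attend bidirectionally instead of causally
)
model = Gemma3TextModel(config)  # randomly initialized; real weights come from the converter below
```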
@@ -16,15 +16,15 @@

r"""Utility to convert Gemma models from Orbax to HF Transformers checkpoint.

python -m transformers.models.gemma3.convert_gemma3_weights_orbax_to_hf \
python src/transformers/models/gemma3/convert_gemma3_weights.py \
--variant='gemma3_4b' \
--tokenizer_path="$HOME/gemma3/tokenizer/gemma3_cleaned_262144_v2.spiece.model" \
--checkpoint_path="$HOME/gemma3/gemma3_4b_pt_orbax/" \
--output_path="$HOME/gemma3/gemma3_4b_pt_safetensors/"
"""

from collections.abc import Iterator, Sequence
from typing import Any
from typing import Any, Optional

import accelerate
import numpy as np
@@ -40,6 +40,7 @@
Gemma3ImageProcessor,
Gemma3Processor,
Gemma3TextConfig,
Gemma3TextModel,
GemmaTokenizerFast,
GenerationConfig,
SiglipVisionConfig,
@@ -100,10 +101,10 @@
_SIGLIP_TRANSFORMER_ENCODER_BLOCK_LEN = len(_SIGLIP_TRANSFORMER_ENCODER_BLOCK)
_SIGLIP_TRANSFORMER_ENCODER_NORM = "SigLiPFromPatches_0/siglip_encoder/Transformer/encoder_norm"

_TRANSFORMER_DECODER_BLOCK = "transformer/layer_"
_TRANSFORMER_DECODER_BLOCK = "/layer_"
_TRANSFORMER_DECODER_BLOCK_LEN = len(_TRANSFORMER_DECODER_BLOCK)
_TRANSFORMER_EMBEDDER = "transformer/embedder"
_TRANSFORMER_FINAL_NORM = "transformer/final_norm"
_TRANSFORMER_EMBEDDER = "/embedder"
_TRANSFORMER_FINAL_NORM = "/final_norm"
_TRANSFORMER_POST_TRAINING_PREFIX = "rlx_networks/policy_network/"
_TRANSFORMER_POST_TRAINING_PREFIX_LEN = len(_TRANSFORMER_POST_TRAINING_PREFIX)

@@ -121,11 +122,46 @@
"vision_use_head": False,
}

_VARIANT_EMBEDDINGGEMMA = "embedding"
_VARIANT_GEMMA_3_270M = "gemma3_270m"
_VARIANT_GEMMA_3_1B = "gemma3_1b"
_VARIANT_GEMMA_3_4B = "gemma3_4b"
_VARIANT_GEMMA_3_12B = "gemma3_12b"
_VARIANT_GEMMA_3_27B = "gemma3_27b"
_VARIANTS = {
_VARIANT_EMBEDDINGGEMMA: Gemma3Config(
text_config=Gemma3TextConfig(
vocab_size=262_144,
hidden_size=768,
intermediate_size=1152,
num_hidden_layers=24,
num_attention_heads=3,
num_key_value_heads=1,
head_dim=256,
max_position_embeddings=1024,
query_pre_attn_scalar=256,
sliding_window=512,
rope_scaling=None,
use_bidirectional_attention=True,
),
vision_config=None,
),
_VARIANT_GEMMA_3_270M: Gemma3Config(
text_config=Gemma3TextConfig(
vocab_size=262_144,
hidden_size=640,
intermediate_size=2048,
num_hidden_layers=18,
num_attention_heads=4,
num_key_value_heads=1,
head_dim=256,
max_position_embeddings=32768,
query_pre_attn_scalar=256,
sliding_window=512,
rope_scaling=None,
),
vision_config=None,
),
_VARIANT_GEMMA_3_1B: Gemma3Config(
text_config=Gemma3TextConfig(
vocab_size=262_144,
@@ -200,6 +236,8 @@
),
}

_TEXT_ONLY_VARIANTS = (_VARIANT_EMBEDDINGGEMMA, _VARIANT_GEMMA_3_270M, _VARIANT_GEMMA_3_1B)

# ==== Flags ====

_CHECKPOINT_PATH = flags.DEFINE_string(
@@ -220,6 +258,12 @@
required=True,
)

_NUM_LINEAR_LAYERS = flags.DEFINE_integer(
name="num_linear_layers",
default=2,
help="Number of linear projection layers at the end of the Sentence Transformer.",
)

_TRANSFORMER_DTYPE = flags.DEFINE_enum(
name="text_dtype",
default="bfloat16",
@@ -358,12 +402,12 @@ def convert_transformer_weights(
attn_head_dim = config.num_attention_heads * config.head_dim
kv_head_dim = config.num_key_value_heads * config.head_dim

if path == _TRANSFORMER_EMBEDDER:
if path.endswith(_TRANSFORMER_EMBEDDER):
if prop == "input_embedding":
# Tied to language_model.lm_head.weight, assigned at the end.
converted_paths = ["language_model.model.embed_tokens.weight"]

if _VARIANT.value != _VARIANT_GEMMA_3_1B:
if _VARIANT.value not in _TEXT_ONLY_VARIANTS:
# Gemma3 model doesn't have image soft token in input and output embeddings, resize to avoid bugs we had with Mllama
pre_expansion_embeddings = weights
mu = np.mean(pre_expansion_embeddings, axis=0)
@@ -372,12 +416,12 @@
weights = np.vstack([pre_expansion_embeddings, new_embeddings])

converted_weights = [weights]
elif _VARIANT.value == _VARIANT_GEMMA_3_1B or prop in ("mm_output_embedding", "mm_input_embedding_extra"):
elif _VARIANT.value in _TEXT_ONLY_VARIANTS or prop in ("mm_output_embedding", "mm_input_embedding_extra"):
return zip([], [])
else:
raise ValueError(f"Unexpected member, {prop}, in Embedder.")
elif path.startswith(f"{_TRANSFORMER_EMBEDDER}/mm"):
if _VARIANT.value == _VARIANT_GEMMA_3_1B:
if _VARIANT.value in _TEXT_ONLY_VARIANTS:
return zip([], [])

if path.endswith("/mm_input_projection"):
@@ -388,14 +432,16 @@
converted_weights = [weights]
else:
raise ValueError(f"Unexpected subpath, `{path}`, in Embedder.")
elif path == _TRANSFORMER_FINAL_NORM:
elif path.endswith(_TRANSFORMER_FINAL_NORM):
converted_paths = ["language_model.model.norm.weight"]
converted_weights = [weights]
elif path.startswith(_TRANSFORMER_DECODER_BLOCK):
decoder_block_path = path[_TRANSFORMER_DECODER_BLOCK_LEN:]
next_path_separator_idx = decoder_block_path.find("/")
layer_idx = decoder_block_path[:next_path_separator_idx]
decoder_block_path = decoder_block_path[next_path_separator_idx:]
elif _TRANSFORMER_DECODER_BLOCK in path:
decoder_block_start = path.find(_TRANSFORMER_DECODER_BLOCK)
decoder_block_offset = decoder_block_start + _TRANSFORMER_DECODER_BLOCK_LEN
decoder_block_path = path[decoder_block_offset:]
next_path_separator_idx = decoder_block_path.find("/")
layer_idx = decoder_block_path[:next_path_separator_idx]
decoder_block_path = decoder_block_path[next_path_separator_idx:]

base_path = f"language_model.model.layers.{layer_idx}"

@@ -445,8 +491,6 @@ def convert_transformer_weights(
converted_weights = [weights]
else:
raise ValueError(f"Unexpected path `{path}` in Decoder Block.")
else:
raise ValueError(f"Unexpected path `{path}`.")

if (cpl := len(converted_paths)) != (cwl := len(converted_weights)):
raise ValueError(
@@ -457,11 +501,14 @@
return zip(converted_paths, converted_weights)


def convert(checkpoint_path: str, config: Gemma3Config) -> dict[str, torch.Tensor]:
def convert(
checkpoint_path: str, config: Gemma3Config, variant: str
) -> tuple[dict[str, torch.Tensor], Optional[Sequence[np.ndarray]]]:
"""Loads Orbax checkpoint from `input_path` and converts it to HF tree."""
checkpointer = obc.PyTreeCheckpointer()
ckpt = checkpointer.restore(checkpoint_path)
hf_tree: dict[str, torch.Tensor] = {}
orbax_tree_flat = tree.flatten_with_path(ckpt)

def update_tree(path: str, weights: np.ndarray, target_dtype: torch.dtype) -> None:
hf_tree[path] = torch.from_numpy(weights.astype("float32")).type(target_dtype)
@@ -473,7 +520,7 @@ def update_tree(path: str, weights: np.ndarray, target_dtype: torch.dtype) -> None:
target_dtype,
)

for paths, value in tree.flatten_with_path(ckpt):
for paths, value in orbax_tree_flat:
if paths[0].startswith("SigLiPFromPatches_"):
if config.vision_config is None:
continue
@@ -482,17 +529,21 @@ def update_tree(path: str, weights: np.ndarray, target_dtype: torch.dtype) -> None:
update_tree(path, weights, config.vision_config.dtype)
else:
for path, weights in convert_transformer_weights(config=config.text_config, paths=paths, weights=value):
if config.vision_config is None:
if variant in _TEXT_ONLY_VARIANTS:
path = path[len("language_model.") :]
if variant == _VARIANT_EMBEDDINGGEMMA:
path = path[len("model.") :]

update_tree(path, weights, config.text_config.dtype)

if config.vision_config is None:
if variant == _VARIANT_EMBEDDINGGEMMA:
return hf_tree, [weight[1].T for weight in orbax_tree_flat[: _NUM_LINEAR_LAYERS.value]]
elif config.vision_config is None:
hf_tree["lm_head.weight"] = hf_tree["model.embed_tokens.weight"]
else:
hf_tree["language_model.lm_head.weight"] = hf_tree["language_model.model.embed_tokens.weight"]

return hf_tree
return hf_tree, None


def main(*args):
@@ -504,7 +555,7 @@ def main(*args):
config = _VARIANTS[variant]
config.text_config.dtype = getattr(torch, _TRANSFORMER_DTYPE.value)

if variant == _VARIANT_GEMMA_3_1B:
if variant in _TEXT_ONLY_VARIANTS:
config.vision_config = None
else:
config.vision_config.dtype = getattr(torch, _VISION_DTYPE.value)
@@ -520,11 +571,13 @@ def main(*args):
_TRANSFORMER_DTYPE.value,
_VISION_DTYPE.value,
)
state_tree = convert(_CHECKPOINT_PATH.value, config)
state_tree, st_linears = convert(_CHECKPOINT_PATH.value, config, variant)
logging.info("Converted Gemma 3 (%s) state tree from Orbax to Hugging Face.", variant)

with accelerate.init_empty_weights():
if variant == _VARIANT_GEMMA_3_1B:
if variant == _VARIANT_EMBEDDINGGEMMA:
model = Gemma3TextModel(config=config.text_config)
elif variant in _TEXT_ONLY_VARIANTS:
model = Gemma3ForCausalLM(config=config.text_config)
else:
model = Gemma3ForConditionalGeneration(config)
@@ -548,6 +601,8 @@ def main(*args):
tokenizer = GemmaTokenizerFast(
_TOKENIZER_PATH.value,
add_bos_token=True,
add_eos_token=variant == _VARIANT_EMBEDDINGGEMMA,
padding_side="right" if variant == _VARIANT_EMBEDDINGGEMMA else "left",
extra_special_tokens={
"image_token": "<image_soft_token>", # Should be ID=262_144
"boi_token": "<start_of_image>", # Should be ID=255_999
@@ -558,7 +613,7 @@
tokenizer.save_pretrained(output_path)
logging.info("Saved GemmaTokenizer for %s to %s", variant, output_path)

if variant != _VARIANT_GEMMA_3_1B:
if variant not in _TEXT_ONLY_VARIANTS:
image_processor = Gemma3ImageProcessor(
image_seq_length=256,
image_mean=(0.5,) * 3,
@@ -589,6 +644,46 @@
)
generation_config.save_pretrained(output_path)

if variant == _VARIANT_EMBEDDINGGEMMA:
from sentence_transformers import SentenceTransformer, models

# TODO: Support Retrieval tasks where we use `"title: {title} | text: {passage}"` internally and construct this
# from split-records cached data, but externally these come through as a single string with components
# separated by a newline. This should be used for `passage` for SentenceTransformers and the relevant MTEB
# Retrieval tasks.
# https://github.com/embeddings-benchmark/mteb/blob/main/docs/usage/usage.md#running-sentencetransformer-model-with-prompts
task_prompts = {
"query": "task: search result | query: ",
"document": "title: none | text: ",
"BitextMining": "task: search result | query: ",
"Clustering": "task: clustering | query: ",
"Classification": "task: classification | query: ",
"InstructionRetrieval": "task: code retrieval | query: ",
"MultilabelClassification": "task: classification | query: ",
"PairClassification": "task: sentence similarity | query: ",
"Reranking": "task: search result | query: ",
"Retrieval": "task: search result | query: ",
"Retrieval-query": "task: search result | query: ",
"Retrieval-document": "title: none | text: ",
"STS": "task: sentence similarity | query: ",
"Summarization": "task: summarization | query: ",
}

transformer = models.Transformer(output_path)
pooling = models.Pooling(config.text_config.hidden_size, pooling_mode="mean")
normalize = models.Normalize()
linears = []

for linear_weight in st_linears:
out_size, in_size = linear_weight.shape[:2]
dense = models.Dense(in_size, out_size, bias=False, activation_function=None)
dense.linear.weight.data = torch.from_numpy(linear_weight.astype("float32"))
linears.append(dense)

model = SentenceTransformer(modules=[transformer, pooling, *linears, normalize], prompts=task_prompts)
model = model.to(getattr(torch, _TRANSFORMER_DTYPE.value))
model.save_pretrained(output_path)


if __name__ == "__main__":
app.run(main)
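A minimal usage sketch for the exported EmbeddingGemma model, assuming the script above was run with `--variant='embedding'`; the output directory and example texts are placeholders, and the prompt names come from the `task_prompts` dict defined in `main`.

```python
from sentence_transformers import SentenceTransformer

# Placeholder path: whatever --output_path pointed to during conversion.
model = SentenceTransformer("/path/to/embeddinggemma_safetensors")

# "query" and "document" are two of the prompt names registered by the converter.
query_emb = model.encode("How does rotary position embedding work?", prompt_name="query")
doc_emb = model.encode("RoPE rotates query and key vectors by their positions.", prompt_name="document")

# Embeddings pass through the final Normalize module, so the dot product is a cosine similarity.
print(query_emb.shape, float(query_emb @ doc_emb))
```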
21 changes: 20 additions & 1 deletion src/transformers/models/gemma3/modeling_gemma3.py
@@ -443,6 +443,19 @@ def _init_weights(self, module):
module.mm_input_projection_weight.data.zero_()


def _bidirectional_window_overlay(sliding_window: int) -> Callable[[int, int, int, int], bool]:
"""
Enables a bidirectional mask within the sliding window.
"""

def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
"""A token can attend to any other token if their absolute distance is within
half the sliding window size (distance <= sliding_window // 2)."""
return abs(q_idx - kv_idx) <= sliding_window // 2

return inner_mask


@auto_docstring
class Gemma3TextModel(Gemma3PreTrainedModel):
config: Gemma3TextConfig
@@ -531,10 +544,16 @@ def forward(
"past_key_values": past_key_values,
"position_ids": position_ids,
}
sliding_mask_kwargs = mask_kwargs.copy()

if self.config.use_bidirectional_attention:
mask_kwargs["or_mask_function"] = lambda *args: torch.tensor(True, dtype=torch.bool)
sliding_mask_kwargs["or_mask_function"] = _bidirectional_window_overlay(self.config.sliding_window)

# Create the masks
causal_mask_mapping = {
"full_attention": create_causal_mask(**mask_kwargs),
"sliding_attention": create_sliding_window_causal_mask(**mask_kwargs),
"sliding_attention": create_sliding_window_causal_mask(**sliding_mask_kwargs),
}

# embed positions
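To make the sliding-attention change concrete, here is a self-contained sketch of the overlay rule that `or_mask_function` adds to the sliding-window mask; the window size and token indices are toy values for illustration.

```python
# Re-statement of the rule from _bidirectional_window_overlay: a query may attend to any
# key whose absolute distance is at most sliding_window // 2, in both directions.
def bidirectional_window_overlay(sliding_window: int):
    def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
        return abs(q_idx - kv_idx) <= sliding_window // 2
    return inner_mask

mask = bidirectional_window_overlay(sliding_window=4)
for q_idx in range(6):
    visible = [kv_idx for kv_idx in range(6) if mask(0, 0, q_idx, kv_idx)]
    print(q_idx, visible)  # e.g. 0 -> [0, 1, 2]; 3 -> [1, 2, 3, 4, 5]
```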