From 1ebb3bb3604fb4b75b76a3b7d9e31ff4d34362f9 Mon Sep 17 00:00:00 2001 From: Ryan Mullins Date: Thu, 17 Jul 2025 19:48:30 +0000 Subject: [PATCH 01/19] Gemma 3 for Embeddings --- .../models/gemma3/configuration_gemma3.py | 2 + .../convert_gemma3_weights_orbax_to_hf.py | 123 ++++++++++++++---- .../models/gemma3/modeling_gemma3.py | 21 ++- .../models/gemma3/modular_gemma3.py | 23 +++- 4 files changed, 141 insertions(+), 28 deletions(-) diff --git a/src/transformers/models/gemma3/configuration_gemma3.py b/src/transformers/models/gemma3/configuration_gemma3.py index c0184c1993d3..bc8f62013f13 100644 --- a/src/transformers/models/gemma3/configuration_gemma3.py +++ b/src/transformers/models/gemma3/configuration_gemma3.py @@ -193,6 +193,7 @@ def __init__( attn_logit_softcapping=None, rope_scaling=None, rope_local_base_freq=10_000.0, + use_bidirectional_attention=False, **kwargs, ): super().__init__( @@ -222,6 +223,7 @@ def __init__( self.final_logit_softcapping = final_logit_softcapping self.attn_logit_softcapping = attn_logit_softcapping self.layer_types = layer_types + self.use_bidirectional_attention = use_bidirectional_attention self.rope_local_base_freq = rope_local_base_freq self.rope_scaling = rope_scaling diff --git a/src/transformers/models/gemma3/convert_gemma3_weights_orbax_to_hf.py b/src/transformers/models/gemma3/convert_gemma3_weights_orbax_to_hf.py index 6bd2b7da4cc0..257864728b3e 100644 --- a/src/transformers/models/gemma3/convert_gemma3_weights_orbax_to_hf.py +++ b/src/transformers/models/gemma3/convert_gemma3_weights_orbax_to_hf.py @@ -16,7 +16,7 @@ r"""Utility to convert Gemma models from Orbax to HF Transformers checkpoint. 
-python -m transformers.models.gemma3.convert_gemma3_weights_orbax_to_hf \ +python src/transformers/models/gemma3/convert_gemma3_weights_orbax_to_hf.py \ --variant='gemma3_4b' \ --tokenizer_path="$HOME/gemma3/tokenizer/gemma3_cleaned_262144_v2.spiece.model" \ --checkpoint_path="$HOME/gemma3/gemma3_4b_pt_orbax/" \ @@ -24,7 +24,7 @@ """ from collections.abc import Iterator, Sequence -from typing import Any +from typing import Any, Optional import accelerate import numpy as np @@ -40,6 +40,7 @@ Gemma3ImageProcessor, Gemma3Processor, Gemma3TextConfig, + Gemma3TextModel, GemmaTokenizerFast, GenerationConfig, SiglipVisionConfig, @@ -100,10 +101,10 @@ _SIGLIP_TRANSFORMER_ENCODER_BLOCK_LEN = len(_SIGLIP_TRANSFORMER_ENCODER_BLOCK) _SIGLIP_TRANSFORMER_ENCODER_NORM = "SigLiPFromPatches_0/siglip_encoder/Transformer/encoder_norm" -_TRANSFORMER_DECODER_BLOCK = "transformer/layer_" +_TRANSFORMER_DECODER_BLOCK = "/layer_" _TRANSFORMER_DECODER_BLOCK_LEN = len(_TRANSFORMER_DECODER_BLOCK) -_TRANSFORMER_EMBEDDER = "transformer/embedder" -_TRANSFORMER_FINAL_NORM = "transformer/final_norm" +_TRANSFORMER_EMBEDDER = "/embedder" +_TRANSFORMER_FINAL_NORM = "/final_norm" _TRANSFORMER_POST_TRAINING_PREFIX = "rlx_networks/policy_network/" _TRANSFORMER_POST_TRAINING_PREFIX_LEN = len(_TRANSFORMER_POST_TRAINING_PREFIX) @@ -121,11 +122,46 @@ "vision_use_head": False, } +_VARIANT_GEMMA_3_EMBEDDING = "gemma3_embedding" +_VARIANT_GEMMA_3_500M = "gemma3_500m" _VARIANT_GEMMA_3_1B = "gemma3_1b" _VARIANT_GEMMA_3_4B = "gemma3_4b" _VARIANT_GEMMA_3_12B = "gemma3_12b" _VARIANT_GEMMA_3_27B = "gemma3_27b" _VARIANTS = { + _VARIANT_GEMMA_3_EMBEDDING: Gemma3Config( + text_config=Gemma3TextConfig( + vocab_size=262_144, + hidden_size=768, + intermediate_size=1152, + num_hidden_layers=24, + num_attention_heads=3, + num_key_value_heads=1, + head_dim=256, + max_position_embeddings=1024, + query_pre_attn_scalar=256, + sliding_window=512, + rope_scaling=None, + use_bidirectional_attention=True, + ), + 
vision_config=None, + ), + _VARIANT_GEMMA_3_500M: Gemma3Config( + text_config=Gemma3TextConfig( + vocab_size=262_144, + hidden_size=768, + intermediate_size=1152, + num_hidden_layers=24, + num_attention_heads=3, + num_key_value_heads=1, + head_dim=256, + max_position_embeddings=32768, + query_pre_attn_scalar=256, + sliding_window=512, + rope_scaling=None, + ), + vision_config=None, + ), _VARIANT_GEMMA_3_1B: Gemma3Config( text_config=Gemma3TextConfig( vocab_size=262_144, @@ -200,6 +236,8 @@ ), } +_TEXT_ONLY_VARIANTS = (_VARIANT_GEMMA_3_EMBEDDING, _VARIANT_GEMMA_3_500M, _VARIANT_GEMMA_3_1B) + # ==== Flags ==== _CHECKPOINT_PATH = flags.DEFINE_string( @@ -220,6 +258,12 @@ required=True, ) +_NUM_LINEAR_LAYERS = flags.DEFINE_integer( + name="num_linear_layers", + default=1, + help="Number of linear projection layers at the end of the Sentence Transformer.", +) + _TRANSFORMER_DTYPE = flags.DEFINE_enum( name="text_dtype", default="bfloat16", @@ -358,12 +402,12 @@ def convert_transformer_weights( attn_head_dim = config.num_attention_heads * config.head_dim kv_head_dim = config.num_key_value_heads * config.head_dim - if path == _TRANSFORMER_EMBEDDER: + if path.endswith(_TRANSFORMER_EMBEDDER): if prop == "input_embedding": # Tied to language_model.lm_head.weight, assigned at the end. 
converted_paths = ["language_model.model.embed_tokens.weight"] - if _VARIANT.value != _VARIANT_GEMMA_3_1B: + if _VARIANT.value not in _TEXT_ONLY_VARIANTS: # Gemma3 model doesn't have image soft token in input and output embeddings, resize to avoid bugs we had with Mllama pre_expansion_embeddings = weights mu = np.mean(pre_expansion_embeddings, axis=0) @@ -372,12 +416,12 @@ def convert_transformer_weights( weights = np.vstack([pre_expansion_embeddings, new_embeddings]) converted_weights = [weights] - elif _VARIANT.value == _VARIANT_GEMMA_3_1B or prop in ("mm_output_embedding", "mm_input_embedding_extra"): + elif _VARIANT.value in _TEXT_ONLY_VARIANTS or prop in ("mm_output_embedding", "mm_input_embedding_extra"): return zip([], []) else: raise ValueError(f"Unexpected member, {prop}, in Embedder.") elif path.startswith(f"{_TRANSFORMER_EMBEDDER}/mm"): - if _VARIANT.value == _VARIANT_GEMMA_3_1B: + if _VARIANT.value in _TEXT_ONLY_VARIANTS: return zip([], []) if path.endswith("/mm_input_projection"): @@ -388,14 +432,16 @@ def convert_transformer_weights( converted_weights = [weights] else: raise ValueError(f"Unexpected subpath, `{path}`, in Embedder.") - elif path == _TRANSFORMER_FINAL_NORM: + elif path.endswith(_TRANSFORMER_FINAL_NORM): converted_paths = ["language_model.model.norm.weight"] converted_weights = [weights] - elif path.startswith(_TRANSFORMER_DECODER_BLOCK): - decoder_block_path = path[_TRANSFORMER_DECODER_BLOCK_LEN:] - next_path_separator_idx = decoder_block_path.find("/") - layer_idx = decoder_block_path[:next_path_separator_idx] - decoder_block_path = decoder_block_path[next_path_separator_idx:] + elif _TRANSFORMER_DECODER_BLOCK in path: + decoder_block_start = path.find(_TRANSFORMER_DECODER_BLOCK) + decoder_block_offset = decoder_block_start + _TRANSFORMER_DECODER_BLOCK_LEN + decoder_block_path = path[decoder_block_offset:] + next_path_seperator_idx = decoder_block_path.find("/") + layer_idx = decoder_block_path[:next_path_seperator_idx] + 
decoder_block_path = decoder_block_path[next_path_seperator_idx:] base_path = f"language_model.model.layers.{layer_idx}" @@ -445,8 +491,6 @@ def convert_transformer_weights( converted_weights = [weights] else: raise ValueError(f"Unexpected path `{path}` in Decoder Block.") - else: - raise ValueError(f"Unexpected path `{path}`.") if (cpl := len(converted_paths)) != (cwl := len(converted_weights)): raise ValueError( @@ -457,11 +501,14 @@ def convert_transformer_weights( return zip(converted_paths, converted_weights) -def convert(checkpoint_path: str, config: Gemma3Config) -> dict[str, torch.Tensor]: +def convert( + checkpoint_path: str, config: Gemma3Config, variant: str +) -> tuple[dict[str, torch.Tensor], Optional[Sequence[np.ndarray]]]: """Loads Orbax checkpoint from `input_path` and converts it to HF tree.""" checkpointer = obc.PyTreeCheckpointer() ckpt = checkpointer.restore(checkpoint_path) hf_tree: dict[str, torch.Tensor] = {} + orbax_tree_flat = tree.flatten_with_path(ckpt) def update_tree(path: str, weights: np.ndarray, target_dtype: torch.dtype) -> None: hf_tree[path] = torch.from_numpy(weights.astype("float32")).type(target_dtype) @@ -473,7 +520,7 @@ def update_tree(path: str, weights: np.ndarray, target_dtype: torch.dtype) -> No target_dtype, ) - for paths, value in tree.flatten_with_path(ckpt): + for paths, value in orbax_tree_flat: if paths[0].startswith("SigLiPFromPatches_"): if config.vision_config is None: continue @@ -482,17 +529,21 @@ def update_tree(path: str, weights: np.ndarray, target_dtype: torch.dtype) -> No update_tree(path, weights, config.vision_config.dtype) else: for path, weights in convert_transformer_weights(config=config.text_config, paths=paths, weights=value): - if config.vision_config is None: + if variant in _TEXT_ONLY_VARIANTS: path = path[len("language_model.") :] + if variant == _VARIANT_GEMMA_3_EMBEDDING: + path = path[len("model.") :] update_tree(path, weights, config.text_config.dtype) - if config.vision_config is None: + 
if variant == _VARIANT_GEMMA_3_EMBEDDING: + return hf_tree, [weight[1].T for weight in orbax_tree_flat[:_NUM_LINEAR_LAYERS.value]] + elif config.vision_config is None: hf_tree["lm_head.weight"] = hf_tree["model.embed_tokens.weight"] else: hf_tree["language_model.lm_head.weight"] = hf_tree["language_model.model.embed_tokens.weight"] - return hf_tree + return hf_tree, None def main(*args): @@ -504,7 +555,7 @@ def main(*args): config = _VARIANTS[variant] config.text_config.dtype = getattr(torch, _TRANSFORMER_DTYPE.value) - if variant == _VARIANT_GEMMA_3_1B: + if variant in _TEXT_ONLY_VARIANTS: config.vision_config = None else: config.vision_config.dtype = getattr(torch, _VISION_DTYPE.value) @@ -520,11 +571,13 @@ def main(*args): _TRANSFORMER_DTYPE.value, _VISION_DTYPE.value, ) - state_tree = convert(_CHECKPOINT_PATH.value, config) + state_tree, st_linears = convert(_CHECKPOINT_PATH.value, config, variant) logging.info("Converted Gemma 3 (%s) state tree from Orbax to Hugging Face.", variant) with accelerate.init_empty_weights(): - if variant == _VARIANT_GEMMA_3_1B: + if variant == _VARIANT_GEMMA_3_EMBEDDING: + model = Gemma3TextModel(config=config.text_config) + elif variant in _TEXT_ONLY_VARIANTS: model = Gemma3ForCausalLM(config=config.text_config) else: model = Gemma3ForConditionalGeneration(config) @@ -558,7 +611,7 @@ def main(*args): tokenizer.save_pretrained(output_path) logging.info("Saved GemmaTokenizer for %s to %s", variant, output_path) - if variant != _VARIANT_GEMMA_3_1B: + if variant not in _TEXT_ONLY_VARIANTS: image_processor = Gemma3ImageProcessor( image_seq_length=256, image_mean=(0.5,) * 3, @@ -589,6 +642,24 @@ def main(*args): ) generation_config.save_pretrained(output_path) + if _VARIANT.value == _VARIANT_GEMMA_3_EMBEDDING: + from sentence_transformers import SentenceTransformer, models + + transformer = models.Transformer(output_path) + pooling = models.Pooling(config.text_config.hidden_size, pooling_mode="mean") + linears = [] + + for linear_weight 
in st_linears: + in_size, out_size = linear_weight.shape[:2] + dense = models.Dense(in_size, out_size, bias=False, activation_function=None) + dense.linear.weight.data = torch.from_numpy( + linear_weight.astype("float32") + ).type(getattr(torch, _TRANSFORMER_DTYPE.value)) + linears.append(dense) + + model = SentenceTransformer(modules=[transformer, pooling, *linears]) + model.save_pretrained(output_path) + if __name__ == "__main__": app.run(main) diff --git a/src/transformers/models/gemma3/modeling_gemma3.py b/src/transformers/models/gemma3/modeling_gemma3.py index 2b60466d7ff1..eec46bee5473 100644 --- a/src/transformers/models/gemma3/modeling_gemma3.py +++ b/src/transformers/models/gemma3/modeling_gemma3.py @@ -443,6 +443,19 @@ def _init_weights(self, module): module.mm_input_projection_weight.data.zero_() +def _bidirectional_window_overlay(sliding_window: int) -> Callable[[int, int, int, int], bool]: + """ + Enables a bidirectional mask within the sliding window. + """ + + def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool: + """A token can attend to any other token if their absolute distance is within + half the sliding window size (distance <= sliding_window // 2).""" + return abs(q_idx - kv_idx) <= sliding_window // 2 + + return inner_mask + + @auto_docstring class Gemma3TextModel(Gemma3PreTrainedModel): config: Gemma3TextConfig @@ -531,10 +544,16 @@ def forward( "past_key_values": past_key_values, "position_ids": position_ids, } + sliding_mask_kwargs = mask_kwargs.copy() + + if self.config.use_bidirectional_attention: + mask_kwargs["or_mask_function"] = lambda *args, **kwargs: True + sliding_mask_kwargs["or_mask_function"] = _bidirectional_window_overlay(self.config.sliding_window) + # Create the masks causal_mask_mapping = { "full_attention": create_causal_mask(**mask_kwargs), - "sliding_attention": create_sliding_window_causal_mask(**mask_kwargs), + "sliding_attention": create_sliding_window_causal_mask(**sliding_mask_kwargs), } 
# embed positions diff --git a/src/transformers/models/gemma3/modular_gemma3.py b/src/transformers/models/gemma3/modular_gemma3.py index 947a22ab8eaa..7c2df31ba337 100644 --- a/src/transformers/models/gemma3/modular_gemma3.py +++ b/src/transformers/models/gemma3/modular_gemma3.py @@ -204,6 +204,7 @@ def __init__( attn_logit_softcapping=None, rope_scaling=None, rope_local_base_freq=10_000.0, + use_bidirectional_attention=False, **kwargs, ): PretrainedConfig.__init__( @@ -233,6 +234,7 @@ def __init__( self.final_logit_softcapping = final_logit_softcapping self.attn_logit_softcapping = attn_logit_softcapping self.layer_types = layer_types + self.use_bidirectional_attention = use_bidirectional_attention self.rope_local_base_freq = rope_local_base_freq self.rope_scaling = rope_scaling @@ -535,6 +537,19 @@ def _init_weights(self, module): module.mm_input_projection_weight.data.zero_() +def _bidirectional_window_overlay(sliding_window: int) -> Callable[[int, int, int, int], bool]: + """ + Enables a bidirectional mask within the sliding window. 
+ """ + + def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool: + """A token can attend to any other token if their absolute distance is within + half the sliding window size (distance <= sliding_window // 2).""" + return abs(q_idx - kv_idx) <= sliding_window // 2 + + return inner_mask + + class Gemma3TextModel(Gemma2Model): config: Gemma3TextConfig @@ -609,10 +624,16 @@ def forward( "past_key_values": past_key_values, "position_ids": position_ids, } + sliding_mask_kwargs = mask_kwargs.copy() + + if self.config.use_bidirectional_attention: + mask_kwargs["or_mask_function"] = lambda *args, **kwargs: True + sliding_mask_kwargs["or_mask_function"] = _bidirectional_window_overlay(self.config.sliding_window) + # Create the masks causal_mask_mapping = { "full_attention": create_causal_mask(**mask_kwargs), - "sliding_attention": create_sliding_window_causal_mask(**mask_kwargs), + "sliding_attention": create_sliding_window_causal_mask(**sliding_mask_kwargs), } # embed positions From 6a8eb6983444c300ce3bb0e03448c9b95d57b551 Mon Sep 17 00:00:00 2001 From: Ryan Mullins Date: Wed, 23 Jul 2025 13:27:24 +0000 Subject: [PATCH 02/19] Style fixes --- src/transformers/models/gemma3/configuration_gemma3.py | 2 ++ .../models/gemma3/convert_gemma3_weights_orbax_to_hf.py | 8 ++++---- src/transformers/models/gemma3/modular_gemma3.py | 2 ++ 3 files changed, 8 insertions(+), 4 deletions(-) diff --git a/src/transformers/models/gemma3/configuration_gemma3.py b/src/transformers/models/gemma3/configuration_gemma3.py index bc8f62013f13..b1ec3311ba66 100644 --- a/src/transformers/models/gemma3/configuration_gemma3.py +++ b/src/transformers/models/gemma3/configuration_gemma3.py @@ -136,6 +136,8 @@ class Gemma3TextConfig(PretrainedConfig): Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE rope_local_base_freq (float, *optional*, defaults to 10000.0): The base period of the RoPE embeddings for local attention. 
+ use_bidirectional_attention (`bool`, *optional*, defaults to `False`): If True, the model will attend to all + text tokens instead of using a causal mask. This does not change behavior for vision tokens. ```python >>> from transformers import Gemma3TextModel, Gemma3TextConfig diff --git a/src/transformers/models/gemma3/convert_gemma3_weights_orbax_to_hf.py b/src/transformers/models/gemma3/convert_gemma3_weights_orbax_to_hf.py index 257864728b3e..d6af44417a65 100644 --- a/src/transformers/models/gemma3/convert_gemma3_weights_orbax_to_hf.py +++ b/src/transformers/models/gemma3/convert_gemma3_weights_orbax_to_hf.py @@ -537,7 +537,7 @@ def update_tree(path: str, weights: np.ndarray, target_dtype: torch.dtype) -> No update_tree(path, weights, config.text_config.dtype) if variant == _VARIANT_GEMMA_3_EMBEDDING: - return hf_tree, [weight[1].T for weight in orbax_tree_flat[:_NUM_LINEAR_LAYERS.value]] + return hf_tree, [weight[1].T for weight in orbax_tree_flat[: _NUM_LINEAR_LAYERS.value]] elif config.vision_config is None: hf_tree["lm_head.weight"] = hf_tree["model.embed_tokens.weight"] else: @@ -652,9 +652,9 @@ def main(*args): for linear_weight in st_linears: in_size, out_size = linear_weight.shape[:2] dense = models.Dense(in_size, out_size, bias=False, activation_function=None) - dense.linear.weight.data = torch.from_numpy( - linear_weight.astype("float32") - ).type(getattr(torch, _TRANSFORMER_DTYPE.value)) + dense.linear.weight.data = torch.from_numpy(linear_weight.astype("float32")).type( + getattr(torch, _TRANSFORMER_DTYPE.value) + ) linears.append(dense) model = SentenceTransformer(modules=[transformer, pooling, *linears]) diff --git a/src/transformers/models/gemma3/modular_gemma3.py b/src/transformers/models/gemma3/modular_gemma3.py index 7c2df31ba337..3625d1ff058e 100644 --- a/src/transformers/models/gemma3/modular_gemma3.py +++ b/src/transformers/models/gemma3/modular_gemma3.py @@ -162,6 +162,8 @@ class Gemma3TextConfig(Gemma2Config, PretrainedConfig): Only used 
with 'llama3'. Scaling factor applied to high frequency components of the RoPE rope_local_base_freq (float, *optional*, defaults to 10000.0): The base period of the RoPE embeddings for local attention. + use_bidirectional_attention (`bool`, *optional*, defaults to `False`): If True, the model will attend to all + text tokens instead of using a causal mask. This does not change behavior for vision tokens. ```python >>> from transformers import Gemma3TextModel, Gemma3TextConfig From f640a94a0c4c8d5a45c0183288727b5360b3b17a Mon Sep 17 00:00:00 2001 From: Ryan Mullins Date: Wed, 23 Jul 2025 19:57:25 +0000 Subject: [PATCH 03/19] Rename conversion file for consistency --- ..._gemma3_weights_orbax_to_hf.py => convert_gemma3_weights.py} | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) rename src/transformers/models/gemma3/{convert_gemma3_weights_orbax_to_hf.py => convert_gemma3_weights.py} (99%) diff --git a/src/transformers/models/gemma3/convert_gemma3_weights_orbax_to_hf.py b/src/transformers/models/gemma3/convert_gemma3_weights.py similarity index 99% rename from src/transformers/models/gemma3/convert_gemma3_weights_orbax_to_hf.py rename to src/transformers/models/gemma3/convert_gemma3_weights.py index d6af44417a65..60d108e3eee8 100644 --- a/src/transformers/models/gemma3/convert_gemma3_weights_orbax_to_hf.py +++ b/src/transformers/models/gemma3/convert_gemma3_weights.py @@ -16,7 +16,7 @@ r"""Utility to convert Gemma models from Orbax to HF Transformers checkpoint. 
-python src/transformers/models/gemma3/convert_gemma3_weights_orbax_to_hf.py \ +python src/transformers/models/gemma3/convert_gemma3_weights.py \ --variant='gemma3_4b' \ --tokenizer_path="$HOME/gemma3/tokenizer/gemma3_cleaned_262144_v2.spiece.model" \ --checkpoint_path="$HOME/gemma3/gemma3_4b_pt_orbax/" \ From 9390afeff3825e7bb3d99190a9335fedb65176bc Mon Sep 17 00:00:00 2001 From: Ryan Mullins Date: Wed, 30 Jul 2025 12:44:37 +0000 Subject: [PATCH 04/19] Default padding side emb vs gen --- src/transformers/models/gemma3/convert_gemma3_weights.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/gemma3/convert_gemma3_weights.py b/src/transformers/models/gemma3/convert_gemma3_weights.py index 60d108e3eee8..7b3ea1617b64 100644 --- a/src/transformers/models/gemma3/convert_gemma3_weights.py +++ b/src/transformers/models/gemma3/convert_gemma3_weights.py @@ -601,6 +601,7 @@ def main(*args): tokenizer = GemmaTokenizerFast( _TOKENIZER_PATH.value, add_bos_token=True, + padding_side="right" if variant == _VARIANT_GEMMA_3_EMBEDDING else "left" extra_special_tokens={ "image_token": "", # Should be ID=262_144 "boi_token": "", # Should be ID=255_999 @@ -642,7 +643,7 @@ def main(*args): ) generation_config.save_pretrained(output_path) - if _VARIANT.value == _VARIANT_GEMMA_3_EMBEDDING: + if variant == _VARIANT_GEMMA_3_EMBEDDING: from sentence_transformers import SentenceTransformer, models transformer = models.Transformer(output_path) From e0d65e341b9f00423c5296b9940c19a22cbe8cb5 Mon Sep 17 00:00:00 2001 From: Ryan Mullins Date: Wed, 30 Jul 2025 18:55:09 +0000 Subject: [PATCH 05/19] Corrected 270m config --- .../models/gemma3/convert_gemma3_weights.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/src/transformers/models/gemma3/convert_gemma3_weights.py b/src/transformers/models/gemma3/convert_gemma3_weights.py index 7b3ea1617b64..bfb72fb0afa7 100644 --- 
a/src/transformers/models/gemma3/convert_gemma3_weights.py +++ b/src/transformers/models/gemma3/convert_gemma3_weights.py @@ -123,7 +123,7 @@ } _VARIANT_GEMMA_3_EMBEDDING = "gemma3_embedding" -_VARIANT_GEMMA_3_500M = "gemma3_500m" +_VARIANT_GEMMA_3_270M = "gemma3_270m" _VARIANT_GEMMA_3_1B = "gemma3_1b" _VARIANT_GEMMA_3_4B = "gemma3_4b" _VARIANT_GEMMA_3_12B = "gemma3_12b" @@ -146,13 +146,13 @@ ), vision_config=None, ), - _VARIANT_GEMMA_3_500M: Gemma3Config( + _VARIANT_GEMMA_3_270M: Gemma3Config( text_config=Gemma3TextConfig( vocab_size=262_144, - hidden_size=768, - intermediate_size=1152, - num_hidden_layers=24, - num_attention_heads=3, + hidden_size=640, + intermediate_size=2048, + num_hidden_layers=18, + num_attention_heads=4, num_key_value_heads=1, head_dim=256, max_position_embeddings=32768, @@ -236,7 +236,7 @@ ), } -_TEXT_ONLY_VARIANTS = (_VARIANT_GEMMA_3_EMBEDDING, _VARIANT_GEMMA_3_500M, _VARIANT_GEMMA_3_1B) +_TEXT_ONLY_VARIANTS = (_VARIANT_GEMMA_3_EMBEDDING, _VARIANT_GEMMA_3_270M, _VARIANT_GEMMA_3_1B) # ==== Flags ==== @@ -468,6 +468,7 @@ def convert_transformer_weights( converted_paths = [f"{base_path}.self_attn.q_norm.weight"] converted_weights = [weights] elif path.endswith("mlp/gating_einsum"): + converted_paths = [ f"{base_path}.mlp.gate_proj.weight", f"{base_path}.mlp.up_proj.weight", @@ -601,7 +602,7 @@ def main(*args): tokenizer = GemmaTokenizerFast( _TOKENIZER_PATH.value, add_bos_token=True, - padding_side="right" if variant == _VARIANT_GEMMA_3_EMBEDDING else "left" + padding_side="right" if variant == _VARIANT_GEMMA_3_EMBEDDING else "left", extra_special_tokens={ "image_token": "", # Should be ID=262_144 "boi_token": "", # Should be ID=255_999 From f338c42ce520a38fe044480fb3761b6720a90e38 Mon Sep 17 00:00:00 2001 From: Ryan Mullins Date: Wed, 30 Jul 2025 19:04:08 +0000 Subject: [PATCH 06/19] style fixes --- src/transformers/models/gemma3/convert_gemma3_weights.py | 1 - 1 file changed, 1 deletion(-) diff --git 
a/src/transformers/models/gemma3/convert_gemma3_weights.py b/src/transformers/models/gemma3/convert_gemma3_weights.py index bfb72fb0afa7..74ef0e4137b2 100644 --- a/src/transformers/models/gemma3/convert_gemma3_weights.py +++ b/src/transformers/models/gemma3/convert_gemma3_weights.py @@ -468,7 +468,6 @@ def convert_transformer_weights( converted_paths = [f"{base_path}.self_attn.q_norm.weight"] converted_weights = [weights] elif path.endswith("mlp/gating_einsum"): - converted_paths = [ f"{base_path}.mlp.gate_proj.weight", f"{base_path}.mlp.up_proj.weight", From 6aaa60e3aa269252e1c659122dcf22076088373f Mon Sep 17 00:00:00 2001 From: Ryan Mullins Date: Sat, 2 Aug 2025 17:13:25 +0000 Subject: [PATCH 07/19] EmbeddingGemma config --- .../models/gemma3/convert_gemma3_weights.py | 33 +++++-------------- 1 file changed, 8 insertions(+), 25 deletions(-) diff --git a/src/transformers/models/gemma3/convert_gemma3_weights.py b/src/transformers/models/gemma3/convert_gemma3_weights.py index 74ef0e4137b2..4d59d666e2eb 100644 --- a/src/transformers/models/gemma3/convert_gemma3_weights.py +++ b/src/transformers/models/gemma3/convert_gemma3_weights.py @@ -122,14 +122,13 @@ "vision_use_head": False, } -_VARIANT_GEMMA_3_EMBEDDING = "gemma3_embedding" -_VARIANT_GEMMA_3_270M = "gemma3_270m" +_VARIANT_EMBEDDINGGEMMA = "embedding" _VARIANT_GEMMA_3_1B = "gemma3_1b" _VARIANT_GEMMA_3_4B = "gemma3_4b" _VARIANT_GEMMA_3_12B = "gemma3_12b" _VARIANT_GEMMA_3_27B = "gemma3_27b" _VARIANTS = { - _VARIANT_GEMMA_3_EMBEDDING: Gemma3Config( + _VARIANT_EMBEDDINGGEMMA: Gemma3Config( text_config=Gemma3TextConfig( vocab_size=262_144, hidden_size=768, @@ -146,22 +145,6 @@ ), vision_config=None, ), - _VARIANT_GEMMA_3_270M: Gemma3Config( - text_config=Gemma3TextConfig( - vocab_size=262_144, - hidden_size=640, - intermediate_size=2048, - num_hidden_layers=18, - num_attention_heads=4, - num_key_value_heads=1, - head_dim=256, - max_position_embeddings=32768, - query_pre_attn_scalar=256, - sliding_window=512, - 
rope_scaling=None, - ), - vision_config=None, - ), _VARIANT_GEMMA_3_1B: Gemma3Config( text_config=Gemma3TextConfig( vocab_size=262_144, @@ -236,7 +219,7 @@ ), } -_TEXT_ONLY_VARIANTS = (_VARIANT_GEMMA_3_EMBEDDING, _VARIANT_GEMMA_3_270M, _VARIANT_GEMMA_3_1B) +_TEXT_ONLY_VARIANTS = (_VARIANT_EMBEDDINGGEMMA, _VARIANT_GEMMA_3_1B) # ==== Flags ==== @@ -531,12 +514,12 @@ def update_tree(path: str, weights: np.ndarray, target_dtype: torch.dtype) -> No for path, weights in convert_transformer_weights(config=config.text_config, paths=paths, weights=value): if variant in _TEXT_ONLY_VARIANTS: path = path[len("language_model.") :] - if variant == _VARIANT_GEMMA_3_EMBEDDING: + if variant == _VARIANT_EMBEDDINGGEMMA: path = path[len("model.") :] update_tree(path, weights, config.text_config.dtype) - if variant == _VARIANT_GEMMA_3_EMBEDDING: + if variant == _VARIANT_EMBEDDINGGEMMA: return hf_tree, [weight[1].T for weight in orbax_tree_flat[: _NUM_LINEAR_LAYERS.value]] elif config.vision_config is None: hf_tree["lm_head.weight"] = hf_tree["model.embed_tokens.weight"] @@ -575,7 +558,7 @@ def main(*args): logging.info("Converted Gemma 3 (%s) state tree from Orbax to Hugging Face.", variant) with accelerate.init_empty_weights(): - if variant == _VARIANT_GEMMA_3_EMBEDDING: + if variant == _VARIANT_EMBEDDINGGEMMA: model = Gemma3TextModel(config=config.text_config) elif variant in _TEXT_ONLY_VARIANTS: model = Gemma3ForCausalLM(config=config.text_config) @@ -601,7 +584,7 @@ def main(*args): tokenizer = GemmaTokenizerFast( _TOKENIZER_PATH.value, add_bos_token=True, - padding_side="right" if variant == _VARIANT_GEMMA_3_EMBEDDING else "left", + padding_side="right" if variant == _VARIANT_EMBEDDINGGEMMA else "left", extra_special_tokens={ "image_token": "", # Should be ID=262_144 "boi_token": "", # Should be ID=255_999 @@ -643,7 +626,7 @@ def main(*args): ) generation_config.save_pretrained(output_path) - if variant == _VARIANT_GEMMA_3_EMBEDDING: + if variant == _VARIANT_EMBEDDINGGEMMA: from 
sentence_transformers import SentenceTransformer, models transformer = models.Transformer(output_path) From 52c520d77a449aaa53ab188dc9888d74a645bbe2 Mon Sep 17 00:00:00 2001 From: Ryan Mullins Date: Sat, 2 Aug 2025 17:19:04 +0000 Subject: [PATCH 08/19] TODO for built-in prompts --- src/transformers/models/gemma3/convert_gemma3_weights.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/transformers/models/gemma3/convert_gemma3_weights.py b/src/transformers/models/gemma3/convert_gemma3_weights.py index 4d59d666e2eb..1a9450f24f7d 100644 --- a/src/transformers/models/gemma3/convert_gemma3_weights.py +++ b/src/transformers/models/gemma3/convert_gemma3_weights.py @@ -641,6 +641,10 @@ def main(*args): ) linears.append(dense) + # TODO - sindhusindhuraghuram: Add `prompts` to cover: 1) the `query` and `document` default options for + # SentenceTransformers; and 2) any MTEB tasks we want to specifically include for reproducibility purposes, + # following the docs at + # https://github.com/embeddings-benchmark/mteb/blob/main/docs/usage/usage.md#running-sentencetransformer-model-with-prompts model = SentenceTransformer(modules=[transformer, pooling, *linears]) model.save_pretrained(output_path) From 352c6862638a56d243079d09e8cc32d6ca7d3760 Mon Sep 17 00:00:00 2001 From: Sindhu Raghuram Date: Fri, 15 Aug 2025 19:19:32 +0000 Subject: [PATCH 09/19] Resolving the sentence similarity bug and updating the architecture --- .../models/gemma3/convert_gemma3_weights.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/src/transformers/models/gemma3/convert_gemma3_weights.py b/src/transformers/models/gemma3/convert_gemma3_weights.py index 1a9450f24f7d..9f6020bb7e27 100644 --- a/src/transformers/models/gemma3/convert_gemma3_weights.py +++ b/src/transformers/models/gemma3/convert_gemma3_weights.py @@ -584,6 +584,7 @@ def main(*args): tokenizer = GemmaTokenizerFast( _TOKENIZER_PATH.value, add_bos_token=True, + add_eos_token=True if variant == 
_VARIANT_EMBEDDINGGEMMA else False, padding_side="right" if variant == _VARIANT_EMBEDDINGGEMMA else "left", extra_special_tokens={ "image_token": "", # Should be ID=262_144 @@ -629,12 +630,23 @@ def main(*args): if variant == _VARIANT_EMBEDDINGGEMMA: from sentence_transformers import SentenceTransformer, models + task_prompts = { + 'clustering': 'task: clustering | query: ', + 'classification': 'task: classification | query: ', + 'question_answering': 'task: question answering | query: ', + 'search_result': 'task: search result | query: ', + 'sentence_similarity': 'task: sentence similarity | query: ', + 'fact_checking': 'task: fact checking | query: ', + 'retrieval_document': 'title: | text: ' + } + transformer = models.Transformer(output_path) pooling = models.Pooling(config.text_config.hidden_size, pooling_mode="mean") + normalize = models.Normalize() linears = [] - + for linear_weight in st_linears: - in_size, out_size = linear_weight.shape[:2] + out_size, in_size = linear_weight.shape[:2] dense = models.Dense(in_size, out_size, bias=False, activation_function=None) dense.linear.weight.data = torch.from_numpy(linear_weight.astype("float32")).type( getattr(torch, _TRANSFORMER_DTYPE.value) @@ -645,7 +657,7 @@ def main(*args): # SentenceTransformers; and 2) any MTEB tasks we want to specifically include for reproducibility purposes, # following the docs at # https://github.com/embeddings-benchmark/mteb/blob/main/docs/usage/usage.md#running-sentencetransformer-model-with-prompts - model = SentenceTransformer(modules=[transformer, pooling, *linears]) + model = SentenceTransformer(modules=[transformer, pooling, *linears, normalize], prompts=task_prompts) model.save_pretrained(output_path) From d0d87e95e28d5bb5db745be26040aeba21eef4ec Mon Sep 17 00:00:00 2001 From: Ryan Mullins Date: Mon, 18 Aug 2025 15:18:57 +0000 Subject: [PATCH 10/19] code style --- .../models/gemma3/convert_gemma3_weights.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) 
diff --git a/src/transformers/models/gemma3/convert_gemma3_weights.py b/src/transformers/models/gemma3/convert_gemma3_weights.py index 9f6020bb7e27..13e05428066c 100644 --- a/src/transformers/models/gemma3/convert_gemma3_weights.py +++ b/src/transformers/models/gemma3/convert_gemma3_weights.py @@ -631,20 +631,20 @@ def main(*args): from sentence_transformers import SentenceTransformer, models task_prompts = { - 'clustering': 'task: clustering | query: ', - 'classification': 'task: classification | query: ', - 'question_answering': 'task: question answering | query: ', - 'search_result': 'task: search result | query: ', - 'sentence_similarity': 'task: sentence similarity | query: ', - 'fact_checking': 'task: fact checking | query: ', - 'retrieval_document': 'title: | text: ' + "clustering": "task: clustering | query: ", + "classification": "task: classification | query: ", + "question_answering": "task: question answering | query: ", + "search_result": "task: search result | query: ", + "sentence_similarity": "task: sentence similarity | query: ", + "fact_checking": "task: fact checking | query: ", + "retrieval_document": "title: | text: ", } transformer = models.Transformer(output_path) pooling = models.Pooling(config.text_config.hidden_size, pooling_mode="mean") normalize = models.Normalize() linears = [] - + for linear_weight in st_linears: out_size, in_size = linear_weight.shape[:2] dense = models.Dense(in_size, out_size, bias=False, activation_function=None) From 2e0d4d953410dc0899ce221e36f42ce6a8bb89a4 Mon Sep 17 00:00:00 2001 From: Ryan Mullins Date: Mon, 18 Aug 2025 15:27:52 +0000 Subject: [PATCH 11/19] Add query prompt for SentenceTransformers --- .../models/gemma3/convert_gemma3_weights.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/transformers/models/gemma3/convert_gemma3_weights.py b/src/transformers/models/gemma3/convert_gemma3_weights.py index 13e05428066c..1eeb9b527cfe 100644 --- 
a/src/transformers/models/gemma3/convert_gemma3_weights.py +++ b/src/transformers/models/gemma3/convert_gemma3_weights.py @@ -630,6 +630,11 @@ def main(*args): if variant == _VARIANT_EMBEDDINGGEMMA: from sentence_transformers import SentenceTransformer, models + # TODO: Support Retrieval tasks where we use `"title: {title} | text: {passage}"` internally and construct this + # from split-records cached data, but externally these come through as a single string with components + # separated by a newline. This should be used for `passage` for SentenceTransformers and the relevant MTEB + # Retrieval tasks. + # https://github.com/embeddings-benchmark/mteb/blob/main/docs/usage/usage.md#running-sentencetransformer-model-with-prompts task_prompts = { "clustering": "task: clustering | query: ", "classification": "task: classification | query: ", @@ -638,6 +643,7 @@ def main(*args): "sentence_similarity": "task: sentence similarity | query: ", "fact_checking": "task: fact checking | query: ", "retrieval_document": "title: | text: ", + "query": "task: search result | query: ", } transformer = models.Transformer(output_path) @@ -653,10 +659,6 @@ def main(*args): ) linears.append(dense) - # TODO - sindhusindhuraghuram: Add `prompts` to cover: 1) the `query` and `document` default options for - # SentenceTransformers; and 2) any MTEB tasks we want to specifically include for reproducibility purposes, - # following the docs at - # https://github.com/embeddings-benchmark/mteb/blob/main/docs/usage/usage.md#running-sentencetransformer-model-with-prompts model = SentenceTransformer(modules=[transformer, pooling, *linears, normalize], prompts=task_prompts) model.save_pretrained(output_path) From 3eaae78dfaa99407d9ba8343ee8ecf1da5860210 Mon Sep 17 00:00:00 2001 From: Ryan Mullins Date: Mon, 18 Aug 2025 15:56:33 +0000 Subject: [PATCH 12/19] Code quality --- src/transformers/models/gemma3/convert_gemma3_weights.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git
a/src/transformers/models/gemma3/convert_gemma3_weights.py b/src/transformers/models/gemma3/convert_gemma3_weights.py index 1eeb9b527cfe..9f39819ed1ef 100644 --- a/src/transformers/models/gemma3/convert_gemma3_weights.py +++ b/src/transformers/models/gemma3/convert_gemma3_weights.py @@ -584,7 +584,7 @@ def main(*args): tokenizer = GemmaTokenizerFast( _TOKENIZER_PATH.value, add_bos_token=True, - add_eos_token=True if variant == _VARIANT_EMBEDDINGGEMMA else False, + add_eos_token=variant == _VARIANT_EMBEDDINGGEMMA, padding_side="right" if variant == _VARIANT_EMBEDDINGGEMMA else "left", extra_special_tokens={ "image_token": "", # Should be ID=262_144 From 56721a35801728cf9316136f98bd76218f770905 Mon Sep 17 00:00:00 2001 From: Ryan Mullins Date: Tue, 19 Aug 2025 14:59:09 +0000 Subject: [PATCH 13/19] Fixing or_mask_function return types --- src/transformers/models/gemma3/modeling_gemma3.py | 2 +- src/transformers/models/gemma3/modular_gemma3.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/gemma3/modeling_gemma3.py b/src/transformers/models/gemma3/modeling_gemma3.py index eec46bee5473..d2ba04298dec 100644 --- a/src/transformers/models/gemma3/modeling_gemma3.py +++ b/src/transformers/models/gemma3/modeling_gemma3.py @@ -547,7 +547,7 @@ def forward( sliding_mask_kwargs = mask_kwargs.copy() if self.config.use_bidirectional_attention: - mask_kwargs["or_mask_function"] = lambda *args, **kwargs: True + mask_kwargs["or_mask_function"] = lambda *args: torch.tensor(True, dtype=torch.bool) sliding_mask_kwargs["or_mask_function"] = _bidirectional_window_overlay(self.config.sliding_window) # Create the masks diff --git a/src/transformers/models/gemma3/modular_gemma3.py b/src/transformers/models/gemma3/modular_gemma3.py index 3625d1ff058e..fc70fa6e9d8e 100644 --- a/src/transformers/models/gemma3/modular_gemma3.py +++ b/src/transformers/models/gemma3/modular_gemma3.py @@ -629,7 +629,7 @@ def forward( sliding_mask_kwargs = 
mask_kwargs.copy() if self.config.use_bidirectional_attention: - mask_kwargs["or_mask_function"] = lambda *args, **kwargs: True + mask_kwargs["or_mask_function"] = lambda *args: torch.tensor(True, dtype=torch.bool) sliding_mask_kwargs["or_mask_function"] = _bidirectional_window_overlay(self.config.sliding_window) # Create the masks From 86a7572fde18bb492cde4d6d9e81be5cd15ca2e4 Mon Sep 17 00:00:00 2001 From: Ryan Mullins Date: Tue, 19 Aug 2025 15:17:55 +0000 Subject: [PATCH 14/19] Adding placeholder prompts for document and passage --- src/transformers/models/gemma3/convert_gemma3_weights.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/transformers/models/gemma3/convert_gemma3_weights.py b/src/transformers/models/gemma3/convert_gemma3_weights.py index 9f39819ed1ef..7dc488b490e3 100644 --- a/src/transformers/models/gemma3/convert_gemma3_weights.py +++ b/src/transformers/models/gemma3/convert_gemma3_weights.py @@ -644,6 +644,8 @@ def main(*args): "fact_checking": "task: fact checking | query: ", "retrieval_document": "title: | text: ", "query": "task: search result | query: ", + "document": "title: | text: ", + "passage": "title: | text: ", } transformer = models.Transformer(output_path) From 70aed03c26f2aebf98e871f3c87d9d559ce1a018 Mon Sep 17 00:00:00 2001 From: Ryan Mullins Date: Thu, 21 Aug 2025 14:01:20 +0000 Subject: [PATCH 15/19] Finalizing prompt templates --- .../models/gemma3/convert_gemma3_weights.py | 21 +++++++++++-------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/src/transformers/models/gemma3/convert_gemma3_weights.py b/src/transformers/models/gemma3/convert_gemma3_weights.py index 7dc488b490e3..328c4df322ca 100644 --- a/src/transformers/models/gemma3/convert_gemma3_weights.py +++ b/src/transformers/models/gemma3/convert_gemma3_weights.py @@ -636,16 +636,19 @@ def main(*args): # Retrieval tasks. 
# https://github.com/embeddings-benchmark/mteb/blob/main/docs/usage/usage.md#running-sentencetransformer-model-with-prompts task_prompts = { - "clustering": "task: clustering | query: ", - "classification": "task: classification | query: ", - "question_answering": "task: question answering | query: ", - "search_result": "task: search result | query: ", - "sentence_similarity": "task: sentence similarity | query: ", - "fact_checking": "task: fact checking | query: ", - "retrieval_document": "title: | text: ", "query": "task: search result | query: ", - "document": "title: | text: ", - "passage": "title: | text: ", + "document": "title: none | text: ", + "BitextMining": "task: search result | query: ", + "Clustering": "task: clustering | query: ", + "Classification": "task: classification | query: ", + "InstructionRetrieval": "task: code retrieval | query: ", + "MultilabelClassification": "task: classification | query: ", + "PairClassification": "task: sentence similarity | query: ", + "Reranking": "task: search result | query: ", + "Retrieval-query": "task: search result | query: ", + "Retrieval-document": "title: none | text: ", + "STS": "task: sentence similarity | query: ", + "Summarization": "task: summarization | query: ", } transformer = models.Transformer(output_path) From 6e3c8c76b084cfd075645ff207918bb0ed03c5e6 Mon Sep 17 00:00:00 2001 From: Ryan Mullins Date: Fri, 22 Aug 2025 13:43:45 +0000 Subject: [PATCH 16/19] Adding Retrieval to preconfigured prompts --- src/transformers/models/gemma3/convert_gemma3_weights.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/transformers/models/gemma3/convert_gemma3_weights.py b/src/transformers/models/gemma3/convert_gemma3_weights.py index 328c4df322ca..6363b5552f90 100644 --- a/src/transformers/models/gemma3/convert_gemma3_weights.py +++ b/src/transformers/models/gemma3/convert_gemma3_weights.py @@ -645,6 +645,7 @@ def main(*args): "MultilabelClassification": "task: classification | query: ",
"PairClassification": "task: sentence similarity | query: ", "Reranking": "task: search result | query: ", + "Retrieval": "task: search result | query: ", "Retrieval-query": "task: search result | query: ", "Retrieval-document": "title: none | text: ", "STS": "task: sentence similarity | query: ", From f92857a8fe091db72204c2fb00788b2e386c2aae Mon Sep 17 00:00:00 2001 From: Ryan Mullins Date: Fri, 22 Aug 2025 09:51:10 -0400 Subject: [PATCH 17/19] Add Gemma 3 270M Config --- .../models/gemma3/convert_gemma3_weights.py | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/gemma3/convert_gemma3_weights.py b/src/transformers/models/gemma3/convert_gemma3_weights.py index 6363b5552f90..ec7f63449f31 100644 --- a/src/transformers/models/gemma3/convert_gemma3_weights.py +++ b/src/transformers/models/gemma3/convert_gemma3_weights.py @@ -123,6 +123,7 @@ } _VARIANT_EMBEDDINGGEMMA = "embedding" +_VARIANT_GEMMA_3_270M = "gemma3_270m" _VARIANT_GEMMA_3_1B = "gemma3_1b" _VARIANT_GEMMA_3_4B = "gemma3_4b" _VARIANT_GEMMA_3_12B = "gemma3_12b" @@ -145,6 +146,22 @@ ), vision_config=None, ), + _VARIANT_GEMMA_3_270M: Gemma3Config( + text_config=Gemma3TextConfig( + vocab_size=262_144, + hidden_size=640, + intermediate_size=2048, + num_hidden_layers=18, + num_attention_heads=4, + num_key_value_heads=1, + head_dim=256, + max_position_embeddings=32768, + query_pre_attn_scalar=256, + sliding_window=512, + rope_scaling=None, + ), + vision_config=None, + ), _VARIANT_GEMMA_3_1B: Gemma3Config( text_config=Gemma3TextConfig( vocab_size=262_144, @@ -219,7 +236,7 @@ ), } -_TEXT_ONLY_VARIANTS = (_VARIANT_EMBEDDINGGEMMA, _VARIANT_GEMMA_3_1B) +_TEXT_ONLY_VARIANTS = (_VARIANT_EMBEDDINGGEMMA, _VARIANT_GEMMA_3_270M, _VARIANT_GEMMA_3_1B) # ==== Flags ==== From a4631dddab696fe832e1d3b23c30be0174140d3a Mon Sep 17 00:00:00 2001 From: Ryan Mullins Date: Wed, 27 Aug 2025 00:26:36 +0000 Subject: [PATCH 18/19] Correcting num_linear_layers flag default --- 
src/transformers/models/gemma3/convert_gemma3_weights.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/gemma3/convert_gemma3_weights.py b/src/transformers/models/gemma3/convert_gemma3_weights.py index ec7f63449f31..bc5b94c4e8dd 100644 --- a/src/transformers/models/gemma3/convert_gemma3_weights.py +++ b/src/transformers/models/gemma3/convert_gemma3_weights.py @@ -260,7 +260,7 @@ _NUM_LINEAR_LAYERS = flags.DEFINE_integer( name="num_linear_layers", - default=1, + default=2, help="Number of linear projection layers at the end of the Sentence Transformer.", ) From 6b4fd39a222817b71281ffb11c51bb1536534be8 Mon Sep 17 00:00:00 2001 From: Ryan Mullins Date: Wed, 3 Sep 2025 19:40:35 +0000 Subject: [PATCH 19/19] Export Sentence Transformer in correct dtype --- src/transformers/models/gemma3/convert_gemma3_weights.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/transformers/models/gemma3/convert_gemma3_weights.py b/src/transformers/models/gemma3/convert_gemma3_weights.py index bc5b94c4e8dd..8d7a21219197 100644 --- a/src/transformers/models/gemma3/convert_gemma3_weights.py +++ b/src/transformers/models/gemma3/convert_gemma3_weights.py @@ -677,12 +677,11 @@ def main(*args): for linear_weight in st_linears: out_size, in_size = linear_weight.shape[:2] dense = models.Dense(in_size, out_size, bias=False, activation_function=None) - dense.linear.weight.data = torch.from_numpy(linear_weight.astype("float32")).type( - getattr(torch, _TRANSFORMER_DTYPE.value) - ) + dense.linear.weight.data = torch.from_numpy(linear_weight.astype("float32")) linears.append(dense) model = SentenceTransformer(modules=[transformer, pooling, *linears, normalize], prompts=task_prompts) + model = model.to(getattr(torch, _TRANSFORMER_DTYPE.value)) model.save_pretrained(output_path)