Merged
Changes from all commits (32 commits)
0484787
tmp
gante Sep 16, 2025
4990c46
fix modular inheritance
gante Sep 16, 2025
5424b99
nit
gante Sep 16, 2025
e2f6550
paligemma 1 doesn't have swa
gante Sep 16, 2025
d8d02ff
use same pattern as in models with hybrid layers
gante Sep 17, 2025
93ff456
PR comments
gante Sep 17, 2025
132e35f
helium also needs layer_typed (bc it relies on gemma)
gante Sep 17, 2025
e28a3ed
paligemma/gemma3: same mask creation fn in fwd and generate
gante Sep 17, 2025
76613f0
propagate changes to helium (gemma-based)
gante Sep 17, 2025
2c188f9
tmp commit
gante Sep 17, 2025
5beb976
slow paligemma tests passing, let's see what breaks
gante Sep 17, 2025
6f2d326
fix test_left_padding_compatibility
gante Sep 17, 2025
bef1beb
tmp commit
gante Sep 18, 2025
8b0f34d
tmp commit
gante Sep 18, 2025
90f165a
rebase error
gante Sep 22, 2025
8979531
docs
gante Sep 22, 2025
aac2956
reduce diff
gante Sep 22, 2025
15d99a2
like this?
gante Sep 22, 2025
46362b4
t5gemma
gante Sep 22, 2025
a9d71e1
better comment
gante Sep 22, 2025
8686d39
shorter diff
gante Sep 22, 2025
a3ac80c
exception
gante Sep 22, 2025
40eed3d
ffs type
gante Sep 22, 2025
c725120
optional
gante Sep 22, 2025
a74fb93
shorter modular_gemma.py
gante Sep 23, 2025
b79e312
helium model actually needs no changes -- the tester is the issue
gante Sep 23, 2025
f916d7c
t5gemma modular config
gante Sep 23, 2025
5e518c8
a few more modular; paligemma BC
gante Sep 23, 2025
b0a9d50
fix processor issues?
gante Sep 23, 2025
c0c89b2
rm config exception
gante Sep 23, 2025
3fcf7a7
lift warning in gemma
gante Sep 23, 2025
9bf860d
Merge branch 'main' into flaky_assisted_gen_tests
gante Sep 23, 2025
src/transformers/masking_utils.py (4 changes: 2 additions & 2 deletions)
@@ -1073,8 +1073,8 @@ def create_masks_for_generate(
**kwargs,
):
"""
This function mimics how we create the masks in the `modeling_xxx.py` files, and is used in `generate` in order
to easily create the masks in advance, when we compile the forwards with Static caches.
This function mimics how we create the masks in the `modeling_xxx.py` files, and is used in places like `generate`
in order to easily create the masks in advance, when we compile the forwards with Static caches.

Args:
config (`PretrainedConfig`):
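As a rough illustration of how a caller such as `generate` can use this helper to prepare masks before a compiled forward pass, here is a hedged sketch; the keyword arguments are mirrored from the `create_causal_mask` calls later in this diff, and the real signature of `create_masks_for_generate` may differ.

```python
# Sketch only: build the attention masks once, before entering a compiled decoding loop.
# Assumes `create_masks_for_generate` accepts the same keywords as the `create_causal_mask` calls in this PR.
from transformers.masking_utils import create_masks_for_generate


def precompute_masks(model, inputs_embeds, attention_mask, cache_position, past_key_values):
    return create_masks_for_generate(
        config=model.config,
        input_embeds=inputs_embeds,
        attention_mask=attention_mask,
        cache_position=cache_position,
        past_key_values=past_key_values,
        position_ids=None,
    )


# The returned mask(s) can then be fed back to the model as `attention_mask`; the Gemma/PaliGemma
# forwards touched in this PR accept an already-prepared mask mapping in that argument.
```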
src/transformers/models/colpali/modular_colpali.py (6 changes: 3 additions & 3 deletions)
@@ -136,7 +136,7 @@ def __call__(
)
suffix = output_kwargs["text_kwargs"].pop("suffix", None)

return_token_type_ids = suffix is not None
return_token_type_ids = True

if text is None and images is None:
raise ValueError("Either text or images must be provided")
@@ -167,7 +167,7 @@ def __call__(

inputs = self.tokenizer(
input_strings,
return_token_type_ids=False,
return_token_type_ids=return_token_type_ids,
**output_kwargs["text_kwargs"],
)

@@ -197,7 +197,7 @@ def __call__(

batch_query = self.tokenizer(
texts_query,
return_token_type_ids=False,
return_token_type_ids=return_token_type_ids,
**output_kwargs["text_kwargs"],
)

src/transformers/models/colpali/processing_colpali.py (12 changes: 9 additions & 3 deletions)
@@ -177,7 +177,7 @@ def __call__(
)
suffix = output_kwargs["text_kwargs"].pop("suffix", None)

return_token_type_ids = suffix is not None
return_token_type_ids = True

if text is None and images is None:
raise ValueError("Either text or images must be provided")
@@ -208,7 +208,7 @@ def __call__(

inputs = self.tokenizer(
input_strings,
return_token_type_ids=False,
return_token_type_ids=return_token_type_ids,
**output_kwargs["text_kwargs"],
)

@@ -238,7 +238,7 @@ def __call__(

batch_query = self.tokenizer(
texts_query,
return_token_type_ids=False,
return_token_type_ids=return_token_type_ids,
**output_kwargs["text_kwargs"],
)

@@ -262,6 +262,12 @@ def _get_num_multimodal_tokens(self, image_sizes=None, **kwargs):
vision_data.update({"num_image_tokens": num_image_tokens, "num_image_patches": num_image_patches})
return MultiModalData(**vision_data)

@property
def model_input_names(self):
tokenizer_input_names = self.tokenizer.model_input_names + ["token_type_ids", "labels"]
image_processor_input_names = self.image_processor.model_input_names
return list(tokenizer_input_names + image_processor_input_names)

@property
def query_augmentation_token(self) -> str:
"""
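To make the processor-side change concrete, a short usage sketch follows; the checkpoint name is only an example, and the exact keys depend on the tokenizer and image processor that ship with it.

```python
# Illustrative only: with `return_token_type_ids=True`, ColPali batches now carry `token_type_ids`,
# and the new `model_input_names` property advertises them (plus "labels") alongside the image inputs.
from PIL import Image
from transformers import ColPaliProcessor

processor = ColPaliProcessor.from_pretrained("vidore/colpali-v1.2-hf")  # example checkpoint
batch = processor(images=Image.new("RGB", (448, 448)))
print(sorted(batch.keys()))         # expected to include "token_type_ids"
print(processor.model_input_names)  # tokenizer names + ["token_type_ids", "labels"] + image processor names
```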
src/transformers/models/colqwen2/processing_colqwen2.py (24 changes: 12 additions & 12 deletions)
@@ -247,6 +247,18 @@ def _get_num_multimodal_tokens(self, image_sizes=None, **kwargs):

return MultiModalData(**vision_data)

@property
def model_input_names(self):
tokenizer_input_names = self.tokenizer.model_input_names
image_processor_input_names = self.image_processor.model_input_names

# ColQwen doesn't process videos. Make a copy of list when removing
# otherwise `self.feature_extractor.model_input_names` is also modified
image_processor_input_names = [
name for name in image_processor_input_names if name not in ["pixel_values_videos", "video_grid_thw"]
]
return tokenizer_input_names + image_processor_input_names

@property
def query_augmentation_token(self) -> str:
"""
@@ -385,17 +397,5 @@ def score_retrieval(

return torch.cat(scores, dim=0)

@property
def model_input_names(self):
tokenizer_input_names = self.tokenizer.model_input_names
image_processor_input_names = self.image_processor.model_input_names

# ColQwen doesn't process videos. Make a copy of list when removing
# otherwise `self.feature_extractor.model_input_names` is also modified
image_processor_input_names = [
name for name in image_processor_input_names if name not in ["pixel_values_videos", "video_grid_thw"]
]
return tokenizer_input_names + image_processor_input_names


__all__ = ["ColQwen2Processor"]
src/transformers/models/gemma/configuration_gemma.py (16 changes: 15 additions & 1 deletion)
@@ -19,7 +19,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ...configuration_utils import PretrainedConfig
from ...configuration_utils import PretrainedConfig, layer_type_validation


class GemmaConfig(PretrainedConfig):
@@ -30,6 +30,7 @@ class GemmaConfig(PretrainedConfig):
e.g. [google/gemma-7b](https://huggingface.co/google/gemma-7b)
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.

Args:
vocab_size (`int`, *optional*, defaults to 256000):
Vocabulary size of the Gemma model. Defines the number of different tokens that can be represented by the
@@ -77,6 +78,11 @@ class GemmaConfig(PretrainedConfig):
Whether to use a bias in the query, key, value and output projection layers during self-attention.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
layer_types (`list`, *optional*):
Attention pattern for each layer.
use_bidirectional_attention (`bool`, *optional*):
If True, the model will attend to all text tokens instead of using a causal mask.

Comment on lines +83 to +85

Reviewer (Member): wow, this makes so much sense. I wonder how gemma3 worked prev, afair we didn't have a flag for defining bidirectional attention at release time

Author (gante): I actually took it from gemma3 🤗 Most of the changes here are gemma3-inspired

Reviewer (Member): looks like it was added recently. Prev it used is_causal = True 🙈

```python
>>> from transformers import GemmaModel, GemmaConfig
>>> # Initializing a Gemma gemma-7b style configuration
@@ -125,6 +131,8 @@ def __init__(
rope_theta=10000.0,
attention_bias=False,
attention_dropout=0.0,
layer_types=None,
use_bidirectional_attention=None,
**kwargs,
):
self.vocab_size = vocab_size
@@ -142,6 +150,12 @@ def __init__(
self.rope_theta = rope_theta
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout
self.use_bidirectional_attention = use_bidirectional_attention

self.layer_types = layer_types
if self.layer_types is None:
self.layer_types = ["full_attention" for _ in range(self.num_hidden_layers)]
layer_type_validation(self.layer_types, self.num_hidden_layers)

super().__init__(
pad_token_id=pad_token_id,
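A minimal sketch of the new config fields; the values below are illustrative and not taken from any released checkpoint.

```python
from transformers import GemmaConfig

# When `layer_types` is omitted it defaults to "full_attention" for every layer, and
# `use_bidirectional_attention` defaults to None (i.e. the usual causal behavior).
config = GemmaConfig(
    num_hidden_layers=4,
    layer_types=["full_attention"] * 4,
    use_bidirectional_attention=True,  # attention modules then set is_causal = False
)
print(config.layer_types)  # ['full_attention', 'full_attention', 'full_attention', 'full_attention']
```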
src/transformers/models/gemma/modeling_gemma.py (25 changes: 15 additions & 10 deletions)
@@ -198,7 +198,7 @@ def __init__(self, config: GemmaConfig, layer_idx: int):
self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
self.scaling = self.head_dim**-0.5
self.attention_dropout = config.attention_dropout
self.is_causal = True
self.is_causal = not getattr(config, "use_bidirectional_attention", False)

self.q_proj = nn.Linear(
config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
@@ -268,6 +268,7 @@ def __init__(self, config: GemmaConfig, layer_idx: int):
self.mlp = GemmaMLP(config)
self.input_layernorm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.attention_type = config.layer_types[layer_idx]

@deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
def forward(
@@ -379,14 +380,18 @@ def forward(
if position_ids is None:
position_ids = cache_position.unsqueeze(0)

causal_mask = create_causal_mask(
config=self.config,
input_embeds=inputs_embeds,
attention_mask=attention_mask,
cache_position=cache_position,
past_key_values=past_key_values,
position_ids=position_ids,
)
# It may already have been prepared by e.g. `generate`
if not isinstance(causal_mask_mapping := attention_mask, dict):
causal_mask_mapping = {
"full_attention": create_causal_mask(
config=self.config,
input_embeds=inputs_embeds,
attention_mask=attention_mask,
cache_position=cache_position,
past_key_values=past_key_values,
position_ids=position_ids,
)
}

# embed positions
hidden_states = inputs_embeds
@@ -403,7 +408,7 @@
for decoder_layer in self.layers[: self.config.num_hidden_layers]:
hidden_states = decoder_layer(
hidden_states,
attention_mask=causal_mask,
attention_mask=causal_mask_mapping[decoder_layer.attention_type],
position_ids=position_ids,
past_key_values=past_key_values,
use_cache=use_cache,
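Gemma itself only ever needs the `"full_attention"` entry, but the same pattern extends to hybrid models with one mask per layer type. The sketch below assumes `create_sliding_window_causal_mask` mirrors `create_causal_mask`'s keyword interface; it is not the library's exact code.

```python
from transformers.masking_utils import create_causal_mask, create_sliding_window_causal_mask


def build_mask_mapping(config, inputs_embeds, attention_mask, cache_position, past_key_values, position_ids):
    # Already prepared by e.g. `generate`: pass it through unchanged.
    if isinstance(attention_mask, dict):
        return attention_mask
    mask_kwargs = dict(
        config=config,
        input_embeds=inputs_embeds,
        attention_mask=attention_mask,
        cache_position=cache_position,
        past_key_values=past_key_values,
        position_ids=position_ids,
    )
    mapping = {"full_attention": create_causal_mask(**mask_kwargs)}
    if "sliding_attention" in getattr(config, "layer_types", []):
        mapping["sliding_attention"] = create_sliding_window_causal_mask(**mask_kwargs)
    return mapping


# Each decoder layer then indexes the mapping with its own `attention_type`, exactly like
# `causal_mask_mapping[decoder_layer.attention_type]` in the forward above.
```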
src/transformers/models/gemma/modular_gemma.py (54 changes: 44 additions & 10 deletions)
@@ -20,14 +20,16 @@
from torch import nn

from ...cache_utils import Cache, DynamicCache
from ...configuration_utils import PretrainedConfig
from ...configuration_utils import PretrainedConfig, layer_type_validation
from ...masking_utils import create_causal_mask
from ...modeling_outputs import BaseModelOutputWithPast
from ...modeling_utils import PreTrainedModel
from ...processing_utils import Unpack
from ...tokenization_utils import AddedToken, PreTrainedTokenizer
from ...utils import TransformersKwargs, logging
from ..llama.modeling_llama import (
LlamaAttention,
LlamaDecoderLayer,
LlamaForCausalLM,
LlamaForSequenceClassification,
LlamaForTokenClassification,
@@ -58,6 +60,7 @@ class GemmaConfig(PretrainedConfig):
e.g. [google/gemma-7b](https://huggingface.co/google/gemma-7b)
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.

Args:
vocab_size (`int`, *optional*, defaults to 256000):
Vocabulary size of the Gemma model. Defines the number of different tokens that can be represented by the
@@ -105,6 +108,11 @@ class GemmaConfig(PretrainedConfig):
Whether to use a bias in the query, key, value and output projection layers during self-attention.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
layer_types (`list`, *optional*):
Attention pattern for each layer.
use_bidirectional_attention (`bool`, *optional*):
If True, the model will attend to all text tokens instead of using a causal mask.

```python
>>> from transformers import GemmaModel, GemmaConfig
>>> # Initializing a Gemma gemma-7b style configuration
@@ -153,6 +161,8 @@ def __init__(
rope_theta=10000.0,
attention_bias=False,
attention_dropout=0.0,
layer_types=None,
use_bidirectional_attention=None,
**kwargs,
):
self.vocab_size = vocab_size
@@ -170,6 +180,12 @@ def __init__(
self.rope_theta = rope_theta
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout
self.use_bidirectional_attention = use_bidirectional_attention

self.layer_types = layer_types
if self.layer_types is None:
self.layer_types = ["full_attention" for _ in range(self.num_hidden_layers)]
layer_type_validation(self.layer_types, self.num_hidden_layers)

super().__init__(
pad_token_id=pad_token_id,
@@ -368,6 +384,20 @@ class GemmaRotaryEmbedding(LlamaRotaryEmbedding):
pass


class GemmaAttention(LlamaAttention):
"""Multi-headed attention from 'Attention Is All You Need' paper"""

def __init__(self, config: GemmaConfig, layer_idx: int):
super().__init__()
self.is_causal = not getattr(config, "use_bidirectional_attention", False)


class GemmaDecoderLayer(LlamaDecoderLayer):
def __init__(self, config: GemmaConfig, layer_idx: int):
super().__init__()
self.attention_type = config.layer_types[layer_idx]


class GemmaPreTrainedModel(LlamaPreTrainedModel):
def _init_weights(self, module):
PreTrainedModel._init_weights(self, module)
@@ -407,14 +437,18 @@ def forward(
if position_ids is None:
position_ids = cache_position.unsqueeze(0)

causal_mask = create_causal_mask(
config=self.config,
input_embeds=inputs_embeds,
attention_mask=attention_mask,
cache_position=cache_position,
past_key_values=past_key_values,
position_ids=position_ids,
)
# It may already have been prepared by e.g. `generate`
if not isinstance(causal_mask_mapping := attention_mask, dict):
causal_mask_mapping = {
"full_attention": create_causal_mask(
config=self.config,
input_embeds=inputs_embeds,
attention_mask=attention_mask,
cache_position=cache_position,
past_key_values=past_key_values,
position_ids=position_ids,
)
}

# embed positions
hidden_states = inputs_embeds
@@ -431,7 +465,7 @@
for decoder_layer in self.layers[: self.config.num_hidden_layers]:
hidden_states = decoder_layer(
hidden_states,
attention_mask=causal_mask,
attention_mask=causal_mask_mapping[decoder_layer.attention_type],
position_ids=position_ids,
past_key_values=past_key_values,
use_cache=use_cache,
src/transformers/models/gemma2/configuration_gemma2.py (5 changes: 5 additions & 0 deletions)
@@ -30,6 +30,7 @@ class Gemma2Config(PretrainedConfig):
e.g. [google/gemma2-7b](https://huggingface.co/google/gemma2-7b)
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.

Args:
vocab_size (`int`, *optional*, defaults to 256000):
Vocabulary size of the Gemma2 model. Defines the number of different tokens that can be represented by the
@@ -88,6 +89,8 @@ class Gemma2Config(PretrainedConfig):
scaling factor when applying tanh softcapping on the logits.
attn_logit_softcapping (`float`, *optional*, defaults to 50.0):
scaling factor when applying tanh softcapping on the attention scores.
use_bidirectional_attention (`bool`, *optional*):
If True, the model will attend to all text tokens instead of using a causal mask.

```python
>>> from transformers import Gemma2Model, Gemma2Config
@@ -142,6 +145,7 @@ def __init__(
layer_types=None,
final_logit_softcapping=30.0,
attn_logit_softcapping=50.0,
use_bidirectional_attention=None,
**kwargs,
):
super().__init__(
@@ -171,6 +175,7 @@ def __init__(
self.final_logit_softcapping = final_logit_softcapping
self.attn_logit_softcapping = attn_logit_softcapping
self.layer_types = layer_types
self.use_bidirectional_attention = use_bidirectional_attention

if self.layer_types is None:
self.layer_types = [
src/transformers/models/gemma2/modeling_gemma2.py (2 changes: 1 addition & 1 deletion)
@@ -211,7 +211,7 @@ def __init__(self, config: Gemma2Config, layer_idx: int):
self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
self.scaling = config.query_pre_attn_scalar**-0.5
self.attention_dropout = self.config.attention_dropout
self.is_causal = True
self.is_causal = not getattr(config, "use_bidirectional_attention", False)

self.q_proj = nn.Linear(
config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias