Merged
30 commits
5ec3789
allow `check_model_inputs` in core VLMs
zucchini-nlp Aug 21, 2025
6db7a44
address comments
zucchini-nlp Aug 22, 2025
7e1d1a5
fix style
zucchini-nlp Aug 22, 2025
268bd49
why this didnt fail prev?
zucchini-nlp Aug 22, 2025
0be7f5e
chec for Noneness instead
zucchini-nlp Aug 22, 2025
9696c83
Merge branch 'main' into check-model-inputs
zucchini-nlp Aug 27, 2025
ae9c66a
batch update vlms
zucchini-nlp Aug 28, 2025
c6ee459
fix some tests
zucchini-nlp Aug 28, 2025
84c4178
Merge remote-tracking branch 'upstream/main' into check-model-inputs
zucchini-nlp Aug 28, 2025
9a3d9bd
fix copies
zucchini-nlp Aug 28, 2025
fb59341
oops delete
zucchini-nlp Aug 28, 2025
05104d9
fix efficientloftr
zucchini-nlp Aug 29, 2025
202bf6b
fix copies
zucchini-nlp Aug 29, 2025
59895d8
i am stupid, fix idefics
zucchini-nlp Aug 29, 2025
20ae443
fix GC
zucchini-nlp Aug 29, 2025
7e64094
return type and other comments
zucchini-nlp Sep 1, 2025
5929681
we shouldn't manually change attention anymore
zucchini-nlp Sep 1, 2025
41d8f92
Merge remote-tracking branch 'upstream/main' into check-model-inputs
zucchini-nlp Sep 1, 2025
b8648c3
fix style
zucchini-nlp Sep 1, 2025
59b74c8
fix copies
zucchini-nlp Sep 1, 2025
fe5b522
fix the test
zucchini-nlp Sep 1, 2025
208424f
Merge branch 'main' into check-model-inputs
zucchini-nlp Sep 2, 2025
e22da39
Merge branch 'main' into check-model-inputs
zucchini-nlp Sep 2, 2025
ae302ea
Merge branch 'main' into check-model-inputs
zucchini-nlp Sep 3, 2025
9175d99
Merge branch 'main' into check-model-inputs
zucchini-nlp Sep 3, 2025
db58424
Merge branch 'main' into check-model-inputs
zucchini-nlp Sep 4, 2025
a753e86
vision model shouldn't need attention, see e.g. CLIP/Siglip
zucchini-nlp Sep 4, 2025
3de818f
Merge branch 'main' into check-model-inputs
zucchini-nlp Sep 4, 2025
3613648
Merge branch 'main' into check-model-inputs
zucchini-nlp Sep 5, 2025
37d4cbd
Merge branch 'main' into check-model-inputs
zucchini-nlp Sep 5, 2025
130 changes: 35 additions & 95 deletions src/transformers/models/aimv2/modeling_aimv2.py
@@ -34,7 +34,10 @@
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...utils import ModelOutput, auto_docstring, can_return_tuple, filter_out_non_signature_kwargs
from ...processing_utils import Unpack
from ...utils import ModelOutput, TransformersKwargs, auto_docstring, can_return_tuple, filter_out_non_signature_kwargs
from ...utils.deprecation import deprecate_kwarg
from ...utils.generic import check_model_inputs
from .configuration_aimv2 import Aimv2Config, Aimv2TextConfig, Aimv2VisionConfig


@@ -300,17 +303,17 @@ def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = False,
) -> tuple[torch.Tensor, torch.Tensor]:
**kwargs: Unpack[TransformersKwargs],
) -> torch.Tensor:
norm_hidden_states = self.rms_norm1(hidden_states)
attn_output, attn_weights = self.attention(hidden_states=norm_hidden_states, attention_mask=attention_mask)
attn_output, _ = self.attention(hidden_states=norm_hidden_states, attention_mask=attention_mask, **kwargs)

hidden_states = hidden_states + attn_output
norm_hidden_states = self.rms_norm2(hidden_states)
mlp_output = self.ffn(norm_hidden_states)

hidden_states = hidden_states + mlp_output
return (hidden_states, attn_weights) if output_attentions else (hidden_states, None)
return hidden_states


class Aimv2Encoder(nn.Module):
@@ -329,68 +332,22 @@ def __init__(self, config: Aimv2Config):
self.gradient_checkpointing = False

# Ignore copy
@can_return_tuple
@auto_docstring
def forward(
self,
inputs_embeds,
attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
**kwargs: Unpack[TransformersKwargs],
) -> BaseModelOutput:
r"""
Args:
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
than the model's internal embedding lookup matrix.
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.

[What are attention masks?](../glossary#attention-mask)
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
for more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)

encoder_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None

hidden_states = inputs_embeds
for encoder_layer in self.layers:
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)

layer_outputs = encoder_layer(
hidden_states = encoder_layer(
hidden_states,
attention_mask,
output_attentions=output_attentions,
**kwargs,
)

hidden_states = layer_outputs[0]

if output_attentions:
all_attentions = all_attentions + (layer_outputs[1],)

if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)

return BaseModelOutput(
last_hidden_state=hidden_states,
hidden_states=encoder_states,
attentions=all_attentions,
)
return BaseModelOutput(last_hidden_state=hidden_states)


class Aimv2AttentionPoolingHead(nn.Module):
@@ -464,6 +421,10 @@ def _init_weights(self, module):
class Aimv2VisionModel(Aimv2PreTrainedModel):
config: Aimv2VisionConfig
main_input_name = "pixel_values"
_can_record_outputs = {
"hidden_states": Aimv2EncoderLayer,
"attentions": Aimv2Attention,
}

def __init__(self, config: Aimv2VisionConfig):
super().__init__(config)
@@ -482,14 +443,14 @@ def __init__(self, config: Aimv2VisionConfig):
def get_input_embeddings(self) -> nn.Module:
return self.embeddings.patch_embed

@can_return_tuple
@deprecate_kwarg("attention_mask", version="v4.58.0")
@check_model_inputs
@auto_docstring
def forward(
self,
pixel_values,
attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
**kwargs: Unpack[TransformersKwargs],
) -> BaseModelOutputWithPooling:
r"""
Examples:
@@ -511,29 +472,21 @@ def forward(
>>> last_hidden_state = outputs.last_hidden_state
>>> pooled_output = outputs.pooler_output # pooled features
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)

hidden_states = self.embeddings(pixel_values)

encoder_outputs = self.encoder(
encoder_outputs: BaseModelOutput = self.encoder(
inputs_embeds=hidden_states,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
**kwargs,
)

last_hidden_state = encoder_outputs[0]
last_hidden_state = encoder_outputs.last_hidden_state
last_hidden_state = self.rms_norm(last_hidden_state)

pooler_output = self.head(last_hidden_state) if self.use_head else None

return BaseModelOutputWithPooling(
last_hidden_state=last_hidden_state,
pooler_output=pooler_output,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
)


@@ -545,6 +498,11 @@ def forward(
class Aimv2TextModel(Aimv2PreTrainedModel):
main_input_name = "input_ids"

_can_record_outputs = {
"hidden_states": Aimv2EncoderLayer,
"attentions": Aimv2Attention,
}

def __init__(self, config: Aimv2TextConfig):
super().__init__(config)
self.config = config
@@ -562,20 +520,14 @@ def get_input_embeddings(self) -> nn.Module:
def set_input_embeddings(self, value):
self.embeddings.token_embedding = value

@can_return_tuple
@check_model_inputs
@auto_docstring
def forward(
self,
input_ids,
attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
**kwargs: Unpack[TransformersKwargs],
) -> BaseModelOutputWithPooling:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)

hidden_states = self.embeddings(input_ids)
batch_size, seq_len, _ = hidden_states.shape

@@ -594,11 +546,10 @@ def forward(
encoder_outputs = self.encoder(
inputs_embeds=hidden_states,
attention_mask=attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
**kwargs,
)

last_hidden_state = encoder_outputs[0]
last_hidden_state = encoder_outputs.last_hidden_state
last_hidden_state = self.rms_norm(last_hidden_state)

# Get pooled output
@@ -610,8 +561,6 @@ def forward(
return BaseModelOutputWithPooling(
last_hidden_state=last_hidden_state,
pooler_output=pooled_output,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
)


@@ -733,8 +682,7 @@ def forward(
input_ids: Optional[torch.LongTensor] = None,
pixel_values: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
**kwargs: Unpack[TransformersKwargs],
) -> Aimv2Output:
r"""
Examples:
@@ -758,23 +706,15 @@ def forward(
>>> logits_per_image = outputs.logits_per_image # this is the image-text similarity score
>>> probs = logits_per_image.softmax(dim=1) # we can take the softmax to get the label probabilities
```"""

output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)

vision_outputs: BaseModelOutputWithPooling = self.vision_model(
pixel_values=pixel_values,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
**kwargs,
)

text_outputs: BaseModelOutputWithPooling = self.text_model(
input_ids=input_ids,
attention_mask=attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
**kwargs,
)

image_embeds = vision_outputs.pooler_output
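For readers skimming the diff, here is a minimal usage sketch of the new calling convention introduced by this change. It assumes `Aimv2VisionModel` is exported from `transformers` and uses a placeholder checkpoint name; it is not part of the diff. The point: `output_attentions` and `output_hidden_states` are no longer explicit `forward()` parameters threaded through every layer. They travel through `**kwargs` and are recorded by the `check_model_inputs` decorator via the model's `_can_record_outputs` mapping.

```python
# Usage sketch under the refactored API (placeholder checkpoint name).
import torch
from transformers import Aimv2VisionModel

model = Aimv2VisionModel.from_pretrained("apple/aimv2-large-patch14-224")  # placeholder

pixel_values = torch.randn(1, 3, 224, 224)  # dummy image batch

with torch.no_grad():
    # Flags are plain kwargs; `check_model_inputs` intercepts them and
    # records outputs from the modules listed in `_can_record_outputs`
    # (here, each Aimv2EncoderLayer and Aimv2Attention).
    outputs = model(pixel_values, output_hidden_states=True, output_attentions=True)

print(outputs.last_hidden_state.shape)  # (1, num_patches, hidden_size)
print(len(outputs.hidden_states))       # one entry per recorded encoder layer
```

The design choice this illustrates: instead of each encoder manually accumulating `encoder_states` and `all_attentions` tuples, the decorator hooks the listed module classes and collects their outputs, which is why the hand-rolled accumulation loops are deleted in this file.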