1 change: 1 addition & 0 deletions src/transformers/models/idefics3/modeling_idefics3.py
@@ -481,6 +481,7 @@ def forward(
self,
pixel_values,
patch_attention_mask: Optional[torch.BoolTensor] = None,
**kwargs: Unpack[TransformersKwargs],
Member:
Oh wow, I wonder how none of the tests caught it. Ideally, I think we have to pass them over to the vision attention for FA2. But that definitely might open a can of worms; I'll take note of it for now.
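For illustration, here is a minimal toy sketch of what "passing them over to the vision attention" could look like — these classes are hypothetical stand-ins, not the actual Idefics3 modules:

```python
import torch
from torch import nn

# Hypothetical toy (not the real Idefics3 classes): the vision transformer forwards
# **kwargs down to its encoder so attention-related flags (e.g. FlashAttention-2
# related options) are not silently dropped at this level.
class ToyVisionEncoder(nn.Module):
    def forward(self, inputs_embeds, attention_mask=None, **kwargs):
        # a real encoder would hand these kwargs to its attention layers
        return inputs_embeds

class ToyVisionTransformer(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = ToyVisionEncoder()

    def forward(self, pixel_values, patch_attention_mask=None, **kwargs):
        return self.encoder(
            inputs_embeds=pixel_values,
            attention_mask=patch_attention_mask,
            **kwargs,  # propagated rather than swallowed
        )

out = ToyVisionTransformer()(torch.randn(1, 4, 8), output_attentions=False)
```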

Member Author:

I confess I am not familiar with the full consequences of adding this line -- I saw that SiglipVisionTransformer had them and it made CI green, so it should be fine :D

) -> Union[tuple, BaseModelOutput]:
batch_size = pixel_values.size(0)
if patch_attention_mask is None:
1 change: 1 addition & 0 deletions src/transformers/models/smolvlm/modeling_smolvlm.py
@@ -368,6 +368,7 @@ def forward(
self,
pixel_values,
patch_attention_mask: Optional[torch.BoolTensor] = None,
**kwargs: Unpack[TransformersKwargs],
) -> Union[tuple, BaseModelOutput]:
batch_size = pixel_values.size(0)
if patch_attention_mask is None:
22 changes: 18 additions & 4 deletions src/transformers/utils/generic.py
@@ -836,12 +836,11 @@ def wrapper(*args, **kwargs):

class TransformersKwargs(TypedDict, total=False):
"""
Keyword arguments to be passed to the loss function
Keyword arguments to be passed to the forward pass of a `PreTrainedModel`.
Member Author:

Note: this class is not complete, e.g. it doesn't contain output_vision_hidden_states. We may want to create variants of this class for documentation purposes 🤔
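One possible shape for such a variant, purely as a sketch — the class names here are hypothetical and the base class is a simplified stand-in for the real TransformersKwargs:

```python
from typing import Optional
from typing_extensions import TypedDict, Unpack

class TransformersKwargsSketch(TypedDict, total=False):
    # simplified stand-in for the real TransformersKwargs
    output_hidden_states: Optional[bool]
    output_attentions: Optional[bool]

class VisionTransformersKwargsSketch(TransformersKwargsSketch, total=False):
    # hypothetical variant adding the vision-specific flag mentioned above
    output_vision_hidden_states: Optional[bool]

def forward(**kwargs: Unpack[VisionTransformersKwargsSketch]) -> None:
    print(kwargs.get("output_vision_hidden_states"))

forward(output_vision_hidden_states=True)  # prints True
```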


Attributes:
num_items_in_batch (`Optional[torch.Tensor]`, *optional*):
Number of items in the batch. It is recommended to pass it when
you are doing gradient accumulation.
Number of items in the batch. It is recommended to pass it when you are doing gradient accumulation.
output_hidden_states (`Optional[bool]`, *optional*):
Most of the models support outputting all hidden states computed during the forward pass.
output_attentions (`Optional[bool]`, *optional*):
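For reference, a minimal usage sketch of these kwargs being forwarded through a model's forward call; the tiny checkpoint name is just an example:

```python
from transformers import AutoModel, AutoTokenizer

# Example checkpoint; any small model works the same way.
name = "hf-internal-testing/tiny-random-bert"
tokenizer = AutoTokenizer.from_pretrained(name)
model = AutoModel.from_pretrained(name)

inputs = tokenizer("hello world", return_tensors="pt")
# These flags are passed through the forward pass and shape the returned outputs.
outputs = model(**inputs, output_hidden_states=True, output_attentions=True)
print(len(outputs.hidden_states), len(outputs.attentions))
```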
@@ -1082,7 +1081,22 @@ def wrapped_forward(*args, **kwargs):
module.forward = make_capture_wrapper(module, original_forward, key, specs.index)
monkey_patched_layers.append((module, original_forward))

outputs = func(self, *args, **kwargs)
try:
outputs = func(self, *args, **kwargs)
except TypeError as original_exception:
# If we get a TypeError, it's possible that the model is not receiving the recordable kwargs correctly.
# Get a TypeError even after removing the recordable kwargs -> re-raise the original exception
# Otherwise -> we're probably missing `**kwargs` in the decorated function
kwargs_without_recordable = {k: v for k, v in kwargs.items() if k not in recordable_keys}
try:
outputs = func(self, *args, **kwargs_without_recordable)
except TypeError:
raise original_exception
raise TypeError(
"Missing `**kwargs` in the signature of the `@check_model_inputs`-decorated function "
f"({func.__qualname__})"
)
Comment on lines +1084 to +1098
Member Author:

It took me a few minutes to realize that these lines were failing AND that the correct fix was simply to add **kwargs, so I added an informative exception so that external contributors can quickly fix related problems 🤗
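As a standalone illustration of the retry-then-explain pattern added above (simplified names, not the actual decorator internals):

```python
# Simplified sketch: retry without the recordable kwargs; if that succeeds,
# the decorated forward is evidently missing **kwargs, so raise an explicit error.
RECORDABLE_KEYS = {"output_attentions", "output_hidden_states"}

def call_with_informative_error(func, *args, **kwargs):
    try:
        return func(*args, **kwargs)
    except TypeError as original_exception:
        stripped = {k: v for k, v in kwargs.items() if k not in RECORDABLE_KEYS}
        try:
            func(*args, **stripped)
        except TypeError:
            # still failing without the recordable kwargs -> unrelated problem
            raise original_exception
        # succeeds without them -> the function signature is missing **kwargs
        raise TypeError(
            f"Missing `**kwargs` in the signature of the decorated function ({func.__qualname__})"
        )

def forward_without_kwargs(pixel_values):  # would need **kwargs to accept the flags
    return pixel_values

try:
    call_with_informative_error(forward_without_kwargs, 1, output_attentions=False)
except TypeError as e:
    print(e)  # Missing `**kwargs` in the signature of the decorated function (...)
```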


# Restore original forward methods
for module, original_forward in monkey_patched_layers:
module.forward = original_forward
6 changes: 3 additions & 3 deletions tests/generation/test_utils.py
@@ -2131,12 +2131,12 @@ def _check_encoder_hidden_states_for_generate(self, hidden_states, batch_size, c
def _check_past_key_values_for_generate(self, batch_size, decoder_past_key_values, cache_length, config):
self.assertIsInstance(decoder_past_key_values, (tuple, Cache))

# (batch, head, seq_length, head_features)
# (batch, # kv heads, seq_length, head_features)
expected_shape = (
batch_size,
config.num_key_value_heads if hasattr(config, "num_key_value_heads") else config.num_attention_heads,
getattr(config, "num_key_value_heads", None) or config.num_attention_heads,
cache_length,
config.hidden_size // config.num_attention_heads,
getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads,
)

if isinstance(decoder_past_key_values, Cache):
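A small illustration of why the `getattr(..., None) or default` form is used here: unlike `hasattr`, it also falls back when the attribute exists but is set to None. The config values below are hypothetical:

```python
from types import SimpleNamespace

# Hypothetical GQA-style config: explicit num_key_value_heads, head_dim left unset.
config = SimpleNamespace(
    num_attention_heads=8,
    num_key_value_heads=2,
    head_dim=None,            # present but None -> hasattr() alone would not help
    hidden_size=32,
)

num_kv_heads = getattr(config, "num_key_value_heads", None) or config.num_attention_heads
head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
print(num_kv_heads, head_dim)  # 2 4
```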
87 changes: 21 additions & 66 deletions tests/models/aria/test_modeling_aria.py
@@ -15,7 +15,6 @@

import unittest

import pytest
import requests

from transformers import (
@@ -61,6 +60,10 @@ class AriaVisionText2TextModelTester:
def __init__(
self,
parent,
batch_size=13,
num_channels=3,
image_size=16,
num_image_tokens=4,
ignore_index=-100,
image_token_index=9,
projector_hidden_act="gelu",
@@ -83,32 +86,32 @@ def __init__(
num_choices=4,
pad_token_id=1,
hidden_size=32,
intermediate_size=64,
intermediate_size=16,
max_position_embeddings=60,
model_type="aria_moe_lm",
moe_intermediate_size=4,
moe_num_experts=4,
moe_num_experts=3,
moe_topk=2,
num_attention_heads=8,
num_attention_heads=2,
num_experts_per_tok=3,
num_hidden_layers=2,
num_key_value_heads=8,
num_key_value_heads=2,
rope_theta=5000000,
vocab_size=99,
eos_token_id=2,
head_dim=4,
),
is_training=True,
vision_config=Idefics3VisionConfig(
image_size=358,
patch_size=10,
image_size=16,
patch_size=8,
num_channels=3,
is_training=True,
hidden_size=32,
projection_dim=20,
projection_dim=4,
num_hidden_layers=2,
num_attention_heads=16,
intermediate_size=10,
num_attention_heads=2,
intermediate_size=4,
dropout=0.1,
attention_dropout=0.1,
initializer_range=0.02,
@@ -130,11 +133,14 @@ def __init__(
self.num_attention_heads = text_config.num_attention_heads
self.is_training = is_training

self.batch_size = 10
self.num_channels = 3
self.image_size = 358
self.num_image_tokens = 128
self.batch_size = batch_size
self.num_channels = num_channels
self.image_size = image_size
self.num_image_tokens = num_image_tokens
self.seq_length = seq_length + self.num_image_tokens
self.projector_patch_to_query_dict = {
vision_config.image_size**2 // vision_config.patch_size**2: vision_config.projection_dim
}

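For reference, the dict built above works out as follows with the tester's shrunk vision config (a worked example using the default values from this tester):

```python
# With image_size=16 and patch_size=8, the vision tower yields 16**2 // 8**2 = 4
# patches per image, so the dict built in the tester becomes {4: 4}
# (projection_dim=4), instead of relying on Aria's full-size defaults.
image_size, patch_size, projection_dim = 16, 8, 4
projector_patch_to_query_dict = {image_size**2 // patch_size**2: projection_dim}
assert projector_patch_to_query_dict == {4: 4}
```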
def get_config(self):
return AriaConfig(
@@ -146,6 +152,7 @@ def get_config(self):
vision_feature_select_strategy=self.vision_feature_select_strategy,
vision_feature_layer=self.vision_feature_layer,
eos_token_id=self.eos_token_id,
projector_patch_to_query_dict=self.projector_patch_to_query_dict,
)

def prepare_config_and_inputs(self):
@@ -176,7 +183,6 @@ def prepare_config_and_inputs_for_common(self):
return config, inputs_dict


@slow
@require_torch
class AriaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
"""
@@ -193,61 +199,10 @@ def setUp(self):
self.model_tester = AriaVisionText2TextModelTester(self)
self.config_tester = ConfigTester(self, config_class=AriaConfig, has_text_modality=False)

@unittest.skip(
reason="This architecture seems to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
)
def test_training_gradient_checkpointing(self):
pass

@unittest.skip(
reason="This architecture seems to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
)
def test_training_gradient_checkpointing_use_reentrant(self):
pass

@unittest.skip(
reason="This architecture seems to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
)
def test_training_gradient_checkpointing_use_reentrant_false(self):
pass

@unittest.skip(reason="Compile not yet supported because in LLava models")
@pytest.mark.torch_compile_test
def test_sdpa_can_compile_dynamic(self):
pass

@unittest.skip(reason="Compile not yet supported because in LLava models")
def test_sdpa_can_dispatch_on_flash(self):
pass

@unittest.skip(reason="Feedforward chunking is not yet supported")
def test_feed_forward_chunking(self):
pass

@unittest.skip(reason="Unstable test")
def test_initialization(self):
pass

@unittest.skip(reason="Dynamic control flow due to MoE")
def test_generate_with_static_cache(self):
pass

@unittest.skip(reason="Dynamic control flow due to MoE")
def test_generate_from_inputs_embeds_with_static_cache(self):
pass

@unittest.skip(reason="Aria uses nn.MHA which is not compatible with offloading")
def test_cpu_offload(self):
pass

@unittest.skip(reason="Aria uses nn.MHA which is not compatible with offloading")
def test_disk_offload_bin(self):
pass

@unittest.skip(reason="Aria uses nn.MHA which is not compatible with offloading")
def test_disk_offload_safetensors(self):
pass


SKIP = False
torch_accelerator_module = getattr(torch, torch_device)