[tests] re-enable aria fast tests #40846
@@ -836,12 +836,11 @@ def wrapper(*args, **kwargs):
 
 class TransformersKwargs(TypedDict, total=False):
     """
-    Keyword arguments to be passed to the loss function
+    Keyword arguments to be passed to the forward pass of a `PreTrainedModel`.
 
     Attributes:
         num_items_in_batch (`Optional[torch.Tensor]`, *optional*):
-            Number of items in the batch. It is recommended to pass it when
-            you are doing gradient accumulation.
+            Number of items in the batch. It is recommended to pass it when you are doing gradient accumulation.
         output_hidden_states (`Optional[bool]`, *optional*):
             Most of the models support outputting all hidden states computed during the forward pass.
         output_attentions (`Optional[bool]`, *optional*):

Review comment: note: this class is not complete, e.g. it doesn't contain …
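For readers less familiar with this pattern, the snippet below is a minimal, hypothetical sketch (the names `ForwardKwargs` and `ToyModel` are invented for illustration and are not part of transformers) of how a `TypedDict` declared with `total=False` can be paired with `Unpack` to type the keyword arguments accepted by a model's forward pass, which is the role `TransformersKwargs` plays in the hunk above.

```python
# Illustrative only: a total=False TypedDict means every key is optional,
# so callers may pass any subset of these keyword arguments.
from typing import Optional, TypedDict

try:  # `Unpack` lives in `typing` from Python 3.11 onwards
    from typing import Unpack
except ImportError:
    from typing_extensions import Unpack


class ForwardKwargs(TypedDict, total=False):
    num_items_in_batch: Optional[int]
    output_hidden_states: Optional[bool]
    output_attentions: Optional[bool]


class ToyModel:
    def forward(self, input_ids, **kwargs: Unpack[ForwardKwargs]):
        # A type checker can now flag unknown or misspelled keyword arguments.
        if kwargs.get("output_hidden_states"):
            print("caller also wants the hidden states")
        return input_ids


ToyModel().forward([1, 2, 3], output_hidden_states=True)
```

Declaring the class with `total=False` is what lets every key be omitted, mirroring how each attribute in the real docstring is marked *optional*.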
@@ -1082,7 +1081,22 @@ def wrapped_forward(*args, **kwargs):
             module.forward = make_capture_wrapper(module, original_forward, key, specs.index)
             monkey_patched_layers.append((module, original_forward))
 
-        outputs = func(self, *args, **kwargs)
+        try:
+            outputs = func(self, *args, **kwargs)
+        except TypeError as original_exception:
+            # If we get a TypeError, it's possible that the model is not receiving the recordable kwargs correctly.
+            # Get a TypeError even after removing the recordable kwargs -> re-raise the original exception
+            # Otherwise -> we're probably missing `**kwargs` in the decorated function
+            kwargs_without_recordable = {k: v for k, v in kwargs.items() if k not in recordable_keys}
+            try:
+                outputs = func(self, *args, **kwargs_without_recordable)
+            except TypeError:
+                raise original_exception
+            raise TypeError(
+                "Missing `**kwargs` in the signature of the `@check_model_inputs`-decorated function "
+                f"({func.__qualname__})"
+            )
 
         # Restore original forward methods
         for module, original_forward in monkey_patched_layers:
            module.forward = original_forward

Review comment on lines +1084 to +1098: Took me a few minutes to detect that these lines were failing AND that the correct solution was simply to add …
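To make the intent of the added error handling easier to follow, here is a self-contained toy version of the same pattern (invented names such as `check_inputs` and `RECORDABLE_KEYS`; a sketch, not the transformers implementation): when the wrapped call raises a `TypeError`, retry it without the injected kwargs so that a forward signature missing `**kwargs` can be told apart from an unrelated `TypeError`.

```python
import functools

RECORDABLE_KEYS = ("output_attentions", "output_hidden_states")


def check_inputs(func):
    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        try:
            return func(self, *args, **kwargs)
        except TypeError as original_exception:
            # Retry without the kwargs that the decorator injects.
            stripped = {k: v for k, v in kwargs.items() if k not in RECORDABLE_KEYS}
            try:
                func(self, *args, **stripped)
            except TypeError:
                # Still failing without the injected kwargs: the error is unrelated.
                raise original_exception
            # Succeeding without them means the decorated forward lacks `**kwargs`.
            raise TypeError(
                f"Missing `**kwargs` in the signature of {func.__qualname__}"
            )

    return wrapper


class GoodModel:
    @check_inputs
    def forward(self, x, **kwargs):
        return x


class BadModel:
    @check_inputs
    def forward(self, x):  # no **kwargs -> extra keyword arguments blow up
        return x


print(GoodModel().forward(1, output_attentions=True))  # works: prints 1
try:
    BadModel().forward(1, output_attentions=True)
except TypeError as exc:
    print(exc)  # Missing `**kwargs` in the signature of BadModel.forward
```

Running the snippet leaves `GoodModel` untouched and turns the opaque "unexpected keyword argument" failure in `BadModel` into an explanatory message, which is the debugging pain the review comment above describes.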
Review comment: oh wow, I wonder how none of the tests caught it. Ideally I think we have to pass them over to vision attention for FA2. But that definitely might open a can of worms, I'll take note of it for now.
Review comment: I confess I am not familiar with the full consequences of adding this line -- I saw that `SiglipVisionTransformer` had them and it made CI green, so it should be fine :D
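As an aside on the "pass them over to vision attention" point, the following is only a hypothetical sketch (toy module names, no relation to the actual Siglip or Aria code) of the kwargs-propagation pattern being discussed: the vision transformer forwards `**kwargs` down to its attention module instead of dropping them, so backend-specific arguments such as FlashAttention 2 flags can actually reach it.

```python
import torch
from torch import nn


class ToyVisionAttention(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, hidden_states: torch.Tensor, **kwargs) -> torch.Tensor:
        # kwargs such as attention-backend flags would be consumed here.
        return self.proj(hidden_states)


class ToyVisionTransformer(nn.Module):
    def __init__(self, dim: int = 8):
        super().__init__()
        self.attn = ToyVisionAttention(dim)

    def forward(self, pixel_values: torch.Tensor, **kwargs) -> torch.Tensor:
        # Propagate the kwargs instead of silently dropping them.
        return self.attn(pixel_values, **kwargs)


model = ToyVisionTransformer()
out = model(torch.randn(1, 4, 8), output_attentions=False)
print(out.shape)
```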