Return effective attention mask in Wav2Vec2BaseModelOutput #25471
Conversation
Thanks for the PR @gau-nernst - it's looking great already! Just two minor suggestions below regarding the dtype of the outputs, and forcing `return_dict=True`.
Could we also add a new test for Wav2Vec2 to make sure we get a downsampled attention mask on the outputs?
@@ -1750,12 +1752,6 @@ def forward(
         # 2. quantize all (unmasked) extracted features and project to final vq dim
         extract_features = self.dropout_features(outputs[1])

         if attention_mask is not None:
Ah good catch! This wasn't used here
@@ -1854,16 +1850,15 @@ def forward(
             input_values,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
+            return_dict=True,
Usually we try to avoid forcing `return_dict` in the call to any modules, and leave it to the setting input by the user -> can we respect the setting of `return_dict` used here? i.e. set `return_dict=return_dict` as it was previously?
I didn't change this at first, but faced a lot of problems when I tried to access the returned `attention_mask` in a tuple-indexing style. Following the existing code for `hidden_states`, where a global variable `_HIDDEN_STATES_START_POSITION = 2` is defined, I tried to define `_ATTENTION_MASK_START_POSITION = 4` and use that for indexing. This becomes a big problem because the tuple output might not include `hidden_states`, so index 4 is wrong. The cleaner solution for me was to force `return_dict=True` for the sub-model, so I can always access `outputs.attention_mask` reliably and cleanly.

The alternative solution is to check whether `output_hidden_states` is True (which can be overridden by `self.config.use_weighted_layer_sum`), then use index 4 for `attention_mask`, otherwise index 3. Let me know which solution you prefer.

(Also, I just noticed the `output_hidden_states` argument in `forward()` for derived models (ForMaskedLM, XVector, ...) does not do anything, since it won't affect the final output.)
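As a toy illustration of the index shift (dummy values only, not the real model outputs):

```python
# With return_dict=False, optional entries (hidden_states, attentions) are dropped
# from the output tuple when they are None, so everything after them shifts left.
def to_tuple(last_hidden_state, extract_features, hidden_states, attentions, attention_mask):
    return tuple(v for v in (last_hidden_state, extract_features, hidden_states, attentions, attention_mask) if v is not None)

with_hidden = to_tuple("lhs", "feat", "hidden", "attn", "mask")
without_hidden = to_tuple("lhs", "feat", None, None, "mask")

assert with_hidden[4] == "mask"      # index 4 is only correct when hidden_states/attentions are present
assert without_hidden[2] == "mask"   # otherwise the mask has shifted to a smaller index
```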
Alright, thanks for the explanation! In this case it does indeed seem easier to force `return_dict=True` and use the base model output classes. Let's leave this one open for the next reviewer!
I thought more about this and unfortunately I don't think we can force `return_dict` like this - setting `return_dict=True` means that we can't `torch.compile` the model, which is going to disrupt the ongoing PR #24668.

For `torch.compile` to work, we need to propagate `return_dict=False` throughout the whole model. Otherwise, we get a graph break here and the compilation fails.
What test did you run to show that `return_dict=True` breaks the graph? From my local testing it seems like `return_dict=True` is fine, but graph breaking occurs due to `getattr(self.config, "apply_spec_augment", True)` (I have replaced `hidden_states[mask/indices] = something` with `hidden_states.masked_fill()` to avoid in-place ops).
import torch
from transformers.models.wav2vec2 import Wav2Vec2Model, Wav2Vec2ForSequenceClassification
# model = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base")
model = Wav2Vec2ForSequenceClassification.from_pretrained("facebook/wav2vec2-base")
model = torch.compile(model, fullgraph=True)
x = torch.randn(2, 16000 * 2)
out = model(x)
Stack trace
File "/Users/thien/miniforge3/envs/pytorch/lib/python3.10/site-packages/torch/_dynamo/variables/base.py", line 225, in call_hasattr
unimplemented(f"hasattr: {repr(self)}")
File "/Users/thien/miniforge3/envs/pytorch/lib/python3.10/site-packages/torch/_dynamo/exc.py", line 71, in unimplemented
raise Unsupported(msg)
torch._dynamo.exc.Unsupported: hasattr: HFPretrainedConfigVariable()
from user code:
File "/Users/thien/github/transformers/src/transformers/models/wav2vec2/modeling_wav2vec2.py", line 1578, in forward
hidden_states = self._mask_hidden_states(
File "/Users/thien/github/transformers/src/transformers/models/wav2vec2/modeling_wav2vec2.py", line 1508, in _mask_hidden_states
if not getattr(self.config, "apply_spec_augment", True):
If I remove `fullgraph=True` in `torch.compile()`, the model compiles without error.
@sanchit-gandhi @gau-nernst Yes, we don't force `return_dict=True` as it messes things up for `torch.compile`, unfortunately. It's true that `return_dict=True` makes things more straightforward and less error-prone (we can access named attributes rather than indexing), but this is true for pretty much all of our models.
@amyeroberts What test did you run to show that `torch.compile()` fails when overriding `return_dict=True`? If you read the previous comments, in my setup overriding `return_dict=True` does not mess up `torch.compile()`. It would be good to make sure I'm not missing anything.
Hey @amyeroberts - just wondering if you had a chance to circle back here? As @gau-nernst mentioned in their prior comment, it looks like `torch.compile` is not affected by setting `return_dict=True` in this module: #25471 (comment)

Is it ok to force this value in this case?
Hi @gau-nernst, apologies for the delay. I thought there were tests for `torch.compile`, but it seems that there aren't yet (cc @ydshieh). I had to convert models last month because of setting `return_dict=True` and incompatibilities with `torch.compile`, and this has been a design pattern for a while. If you can successfully compile and use this model setting `return_dict=True`, then I don't see why we couldn't do this.
It's also the case in diffusers that setting `return_dict=False` is necessary to enable `torch.compile`, e.g. huggingface/diffusers#4738. But if we've comprehensively shown that `torch.compile` works with `return_dict=True`, I also don't see any issue in allowing this!
        return Wav2Vec2BaseModelOutput(
            last_hidden_state=hidden_states,
            extract_features=extract_features,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
            attention_mask=attention_mask,
We input a `torch.LongTensor` of 0/1's as the attention mask input. Currently, we're returning a boolean output for the downsampled attention mask -> IMO it makes sense to return the downsampled attention mask as a `torch.LongTensor` of 0/1's for consistency with the inputs. WDYT? Just requires you to convert this boolean tensor to a long tensor!
Ok, I will convert it to `torch.long`.
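(i.e. something like the following one-liner - the variable name is just an assumption:)

```python
# cast the boolean downsampled mask to 0/1 longs to match the dtype of the input attention mask
attention_mask = attention_mask.to(torch.long)
```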
Perfect, thanks @gau-nernst!
tests/test_modeling_common.py
Outdated
@@ -1729,6 +1729,10 @@ def recursive_check(tuple_object, dict_object):
             elif tuple_object is None:
                 return
             else:
+                # bool tensors do not support subtraction `-`
The above suggestion will fix this problem for you
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
    def _get_feature_vector_attention_mask(self, feature_vector_length: int, attention_mask: torch.LongTensor):
        # Effectively attention_mask.sum(-1), but not inplace to be able to run
@sanchit-gandhi Do you know if `.sum(-1)` still causes problems? It was introduced in #14260. I tried changing `.cumsum(-1)[:, -1]` to `.sum(-1)` in wav2vec2 and all tests still pass. I think it would be best to consolidate similar functions to the wav2vec2 version. Before my PR, Hubert's version used `.sum(-1)`.
Yes, I believe it does - ideally let's keep it as the `.cumsum` op to avoid the in-place operations.
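For reference, both forms give the same lengths for a 0/1 padding mask - a quick standalone check (not part of the PR itself):

```python
import torch

# a 0/1 attention mask with right padding
mask = torch.tensor([[1, 1, 1, 0, 0],
                     [1, 1, 1, 1, 1]])

lengths_sum = mask.sum(-1)               # tensor([3, 5])
lengths_cumsum = mask.cumsum(-1)[:, -1]  # tensor([3, 5]); the last cumulative value equals the sum
torch.testing.assert_close(lengths_sum, lengths_cumsum)
```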
Currently I see two variants of this function:

- Wav2Vec2, Wav2Vec2Conformer, WavLM, Data2Vec-Audio: there is an `add_adapter` argument
- HuBERT, SEW, SEW-D, UniSpeech, UniSpeech-SAT, and SpeechT5: no `add_adapter` argument

Do you know the reason why there are two variants? `add_adapter` is always False anyway from what I see. Ideally I think we should consolidate everything to the Wav2Vec2 version (make other models copy from Wav2Vec2), since all other models are modifications of Wav2Vec2.
`add_adapter` is used to account for the linear adapter layer that is optionally added on top of an audio encoder model. This adapter layer further downsamples along the temporal dimension of the `hidden_states`, c.f. https://github.com/huggingface/transformers/tree/main/examples/pytorch/speech-recognition#warm-started-speech-encoder-decoder-model

This adapter is only used in a handful of models, e.g. when the model is a candidate encoder in an encoder-decoder configuration.
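Schematically, the length computation looks something like this - a simplified sketch of the modeling code, so the exact details (config attribute names, rounding) may differ:

```python
import torch

def _conv_out_length(input_length, kernel_size, stride):
    # 1D conv output length (no padding), as used by the feature encoder
    return torch.div(input_length - kernel_size, stride, rounding_mode="floor") + 1

def feat_extract_output_lengths(input_lengths, config, add_adapter=False):
    # each conv layer of the feature encoder downsamples the raw waveform
    for kernel_size, stride in zip(config.conv_kernel, config.conv_stride):
        input_lengths = _conv_out_length(input_lengths, kernel_size, stride)
    # the optional adapter convs downsample the hidden states once more
    if add_adapter:
        for _ in range(config.num_adapter_layers):
            input_lengths = _conv_out_length(input_lengths, 1, config.adapter_stride)
    return input_lengths
```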
I see. Would it be okay if I use Wav2Vec2's `_get_feature_vector_attention_mask()` for all other models, even if they don't have adapters? Having one source of truth would be easier for maintenance.
I would have two sets: one for models with adapters (e.g. Wav2Vec2), one for models without (e.g. HuBERT). All the models with adapters should copy from Wav2Vec2. All the models without should copy from HuBERT. Otherwise, the models without adapters have redundant adapter code
I had done what you suggested before this comment :)
@sanchit-gandhi How should I add a test checking for the returned downsampled attention mask? Should it be under
Hey @gau-nernst - indeed it's quite an intimidating file! It's got quite a lot of tests, given Wav2Vec2 is a core model in the library and we want to ensure that any new changes don't break backwards compatibility. But once you know how it works it's quite straightforward! The model tester class defines all the functions we want to test.

What I would suggest doing here is defining a function in the model tester, e.g. `check_returned_attention_mask`, and then running the test in the model test class => this way you just have to focus on writing one new function in the model tester, and then execute it in the model test.

In this test, I think we can check that the returned attention mask is downsampled to the feature-vector length and returned as a LongTensor of 0/1's.
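A rough sketch of that split - the tester method name follows what this PR adds, but the exact assertions and helpers (`Wav2Vec2Model`, `torch_device`) being in scope are assumptions, not the PR's final test:

```python
# in Wav2Vec2ModelTester (tests/models/wav2vec2/test_modeling_wav2vec2.py) - sketch only
def check_returned_attention_mask(self, config, input_values, *args):
    model = Wav2Vec2Model(config=config)
    model.to(torch_device)
    model.eval()

    attention_mask = torch.ones_like(input_values, dtype=torch.long)
    outputs = model(input_values, attention_mask=attention_mask)

    # the returned mask should be downsampled to the feature length and kept as 0/1 longs
    self.parent.assertEqual(outputs.attention_mask.dtype, torch.long)
    self.parent.assertEqual(outputs.attention_mask.shape[1], outputs.last_hidden_state.shape[1])

# in the model test class - each test just forwards to the tester
def test_returned_attention_mask(self):
    config_and_inputs = self.model_tester.prepare_config_and_inputs()
    self.model_tester.check_returned_attention_mask(*config_and_inputs)
```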
@@ -775,8 +775,7 @@ def forward(

         if attention_mask is not None:
             # make sure padded tokens output 0
-            expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
-            hidden_states[~expand_attention_mask] = 0
+            hidden_states[~attention_mask.bool()] = 0
We don't need to expand the attention mask anymore?
It seems like it's not necessary. Let me double check this.
We can check that the slow tests pass with this change
From my snippet below, not expanding the attention mask is fine:
import torch
B, L, D = 4, 100, 256
hidden_states = torch.randn(B, L, D)
mask = torch.zeros(B, L, dtype=torch.long)
for i, length in enumerate(torch.randint(L, size=(B,))):
    mask[i, :length] = 1
out1 = hidden_states.clone()
out1[~mask.unsqueeze(-1).repeat(1, 1, D).bool()] = 0
out2 = hidden_states.clone()
out2[~mask.bool()] = 0
torch.testing.assert_close(out1, out2)
Regarding #24668, I think using `Tensor.masked_fill()` would be better than `torch.where()` to make sure it's an out-of-place op. In that case, we need to add an extra dimension to the attention mask, but repeating it is not necessary thanks to broadcasting.
out3 = hidden_states.masked_fill(~mask.unsqueeze(-1).bool(), 0)
torch.testing.assert_close(out1, out3)
Awesome, sounds good! Thanks for clarifying
        )

        hidden_states = outputs[0]
        hidden_states = self.dropout(hidden_states)
        logits = self.lm_head(hidden_states)

        if not return_dict:
            output = (logits,) + outputs[2:]
I think here we will need to revert to slicing the tuple as was done previously - IMO it's worth trying your solution where we index the tuple based on the settings of `output_hidden_states` / `output_attentions`, since we sadly cannot force `return_dict=True`.
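The index bookkeeping might look roughly like this - the helper below is hypothetical; only `_HIDDEN_STATES_START_POSITION` exists in the current modeling file:

```python
_HIDDEN_STATES_START_POSITION = 2  # existing constant in modeling_wav2vec2.py

def _attention_mask_position(output_hidden_states: bool, output_attentions: bool) -> int:
    # hypothetical helper: attention_mask sits after the optional hidden_states /
    # attentions entries in the tuple output, so its index shifts with those flags
    index = _HIDDEN_STATES_START_POSITION
    if output_hidden_states:
        index += 1
    if output_attentions:
        index += 1
    return index
```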
@@ -473,6 +473,28 @@ def check_labels_out_of_vocab(self, config, input_values, *args):
         with self.parent.assertRaises(ValueError):
             model(input_values, labels=labels)

+    def check_returned_attention_mask(self, config, input_values, *args):
Nice test!
It's looking very close to completion @gau-nernst! Just the
Let me know when you'd like a re-review here @gau-nernst! It's looking quite nice already!
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
What does this PR do?
Fixes #25307
Adds `attention_mask` to `Wav2Vec2BaseModelOutput`.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.
@sanchit-gandhi