Return effective attention mask in Wav2Vec2BaseModelOutput #25471

Closed
wants to merge 27 commits into from

Conversation

@gau-nernst (Contributor) commented Aug 12, 2023

What does this PR do?

Fixes #25307

  • Add field attention_mask to Wav2Vec2BaseModelOutput
  • Return the updated attention mask for Wav2Vec2Model, Data2VecAudioModel, HubertModel, SEWModel, SEWDModel, WavLMModel, Wav2Vec2ConformerModel, UniSpeechModel, UniSpeechSatModel
  • Change model output from BaseModelOutput to Wav2Vec2BaseModelOutput for HubertModel, SEWModel, SEWDModel
  • Fix tensor comparison functions to accept bool tensors

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@sanchit-gandhi

@gau-nernst gau-nernst marked this pull request as ready for review August 13, 2023 04:04
@sanchit-gandhi (Contributor) left a comment

Thanks for the PR @gau-nernst - it's looking great already! Just two minor suggestions below regarding the dtype of the outputs, and forcing return_dict=True

Could we also add a new test for Wav2Vec2 to make sure we get a downsampled attention mask on the outputs?

@@ -1750,12 +1752,6 @@ def forward(
# 2. quantize all (unmasked) extracted features and project to final vq dim
extract_features = self.dropout_features(outputs[1])

if attention_mask is not None:
Contributor

Ah good catch! This wasn't used here

@@ -1854,16 +1850,15 @@ def forward(
input_values,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
return_dict=True,
@sanchit-gandhi (Contributor) Aug 15, 2023

Usually we try to avoid forcing return_dict in calls to any sub-modules and leave it to the setting passed by the user -> can we respect the return_dict setting used here, i.e. set return_dict=return_dict as it was previously?

gau-nernst (Contributor Author)

I didn't change this at first, but I faced a lot of problems when I tried to access the returned attention_mask in a tuple-indexing style. Following the existing code for hidden_states, where a global variable _HIDDEN_STATES_START_POSITION = 2 is defined, I tried to define _ATTENTION_MASK_START_POSITION = 4 and use that for indexing. This becomes a big problem because the tuple output might not include hidden_states, so index 4 would be wrong. The cleaner solution for me was to force return_dict=True for the sub-model, so I can always access outputs.attention_mask reliably and cleanly.

The alternative solution is to check whether output_hidden_states is True (which can be overridden by self.config.use_weighted_layer_sum), then use index=4 for attention_mask, otherwise use index=3. Let me know which solution you prefer.

(Also, I just noticed that the output_hidden_states argument in forward() for derived models (ForMaskedLM, XVector, ...) does not do anything, since it won't affect the final output.)
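
For illustration, here is a minimal sketch of the index-based alternative described above (the constant follows the existing _HIDDEN_STATES_START_POSITION pattern; the exact offsets are assumptions, not code from this PR):

# Hypothetical sketch of the tuple-indexing alternative (not the PR's final code).
# With return_dict=False, the returned tuple only contains outputs that were actually
# produced, so the position of attention_mask shifts with output_hidden_states
# (and would shift again with output_attentions, which is exactly why this is fragile).
_ATTENTION_MASK_START_POSITION = 4  # only correct when hidden_states is in the tuple

mask_index = _ATTENTION_MASK_START_POSITION if output_hidden_states else 3
attention_mask = outputs[mask_index]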

Contributor

Alright thanks for the explanation! In this case it does indeed seem easier to force return_dict=True and use the base model output classes. Let's leave this one open for the next reviewer!

Contributor

I thought more about this and unfortunately I don't think we can force return_dict like this - setting return_dict=True means that we can't torch.compile the model, which is going to disrupt the ongoing PR #24668

For torch.compile to work, we need to propagate return_dict=False throughout the whole model. Otherwise, we get a graph break here and the compilation fails

@gau-nernst (Contributor Author) Aug 26, 2023

What test do you run to show that return_dict=True breaks the graph? From my local testing, return_dict=True seems fine; the graph break instead comes from getattr(self.config, "apply_spec_augment", True) (I have replaced hidden_states[mask/indices] = something with hidden_states.masked_fill() to avoid in-place ops).

import torch
from transformers.models.wav2vec2 import Wav2Vec2Model, Wav2Vec2ForSequenceClassification

# model = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base")
model = Wav2Vec2ForSequenceClassification.from_pretrained("facebook/wav2vec2-base")
model = torch.compile(model, fullgraph=True)

x = torch.randn(2, 16000 * 2)  # batch of 2 clips, 2 seconds of 16 kHz audio each

out = model(x)

Stack trace

  File "/Users/thien/miniforge3/envs/pytorch/lib/python3.10/site-packages/torch/_dynamo/variables/base.py", line 225, in call_hasattr
    unimplemented(f"hasattr: {repr(self)}")
  File "/Users/thien/miniforge3/envs/pytorch/lib/python3.10/site-packages/torch/_dynamo/exc.py", line 71, in unimplemented
    raise Unsupported(msg)
torch._dynamo.exc.Unsupported: hasattr: HFPretrainedConfigVariable()

from user code:
   File "/Users/thien/github/transformers/src/transformers/models/wav2vec2/modeling_wav2vec2.py", line 1578, in forward
    hidden_states = self._mask_hidden_states(
  File "/Users/thien/github/transformers/src/transformers/models/wav2vec2/modeling_wav2vec2.py", line 1508, in _mask_hidden_states
    if not getattr(self.config, "apply_spec_augment", True):

If I remove fullgraph=True in torch.compile(), the model compiles without error.

amyeroberts (Collaborator)

@sanchit-gandhi @gau-nernst Yes, we don't force return_dict=True as it messes things up for torch.compile, unfortunately. It's true that return_dict=True makes things more straightforward and less error-prone (we can access named attributes rather than indexing), but this is true for pretty much all of our models.

gau-nernst (Contributor Author)

@amyeroberts What test do you run to show torch.compile() fails when overriding return_dict=True? If you read the previous comments, in my setup, overriding return_dict=True does not mess up torch.compile(). It will be good to make sure I don't miss anything.

Contributor

Hey @amyeroberts - just wondering if you had a chance to circle back here? As @gau-nernst mentioned in their prior comment, it looks like torch.compile is not affected by setting return_dict=True in this module: #25471 (comment)

Is it ok to force this value in this case?

@amyeroberts (Collaborator) Sep 11, 2023

Hi @gau-nernst, apologies for the delay. I thought there were tests for torch.compile but it seems that there aren't yet (cc @ydshieh). I had to convert models last month because setting return_dict=True was incompatible with torch.compile, and this has been a design pattern for a while. If you can successfully compile and use this model with return_dict=True set, then I don't see why we couldn't do this.

Contributor

It's also the case in diffusers that setting return_dict=False is necessary to enable torch.compile, e.g. huggingface/diffusers#4738. But if we've comprehensively shown that torch.compile works with return_dict=True, I also don't see any issue with allowing this!

src/transformers/modeling_outputs.py (outdated review comments, resolved)

return Wav2Vec2BaseModelOutput(
last_hidden_state=hidden_states,
extract_features=extract_features,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
attention_mask=attention_mask,
@sanchit-gandhi (Contributor) Aug 15, 2023

We pass a torch.LongTensor of 0/1s as the attention mask input. Currently, we're returning a boolean tensor for the downsampled attention mask -> IMO it makes sense to return the downsampled attention mask as a torch.LongTensor of 0/1s for consistency with the input. WDYT? It just requires converting this boolean tensor to a long tensor!
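
A minimal sketch of the suggested conversion (variable name assumed):

# convert the downsampled boolean mask to a LongTensor of 0/1s,
# matching the dtype of the attention mask passed as input
attention_mask = attention_mask.to(torch.long)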

gau-nernst (Contributor Author)

Ok, I will convert it to torch.long.

Contributor

Perfect, thanks @gau-nernst!

@@ -1729,6 +1729,10 @@ def recursive_check(tuple_object, dict_object):
elif tuple_object is None:
return
else:
# bool tensors do not support subtraction `-`
Contributor

The above suggestion will fix this problem for you

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

def _get_feature_vector_attention_mask(self, feature_vector_length: int, attention_mask: torch.LongTensor):
# Effectively attention_mask.sum(-1), but not inplace to be able to run
gau-nernst (Contributor Author)

@sanchit-gandhi Do you know if .sum(-1) still causes problems? It was introduced in #14260. I tried changing .cumsum(-1)[:, -1] to .sum(-1) in wav2vec2 and all tests still pass. I think it would be best to consolidate similar functions to the wav2vec2 version. Before my PR, Hubert's version used .sum(-1).
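
A quick standalone check of the equivalence in question (a sketch, not code from the PR):

import torch

# For a 0/1 attention mask, the last element of the cumulative sum along the time
# axis equals the plain sum, i.e. the number of non-padded frames per example.
attention_mask = torch.tensor([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]])
lengths_via_cumsum = attention_mask.cumsum(-1)[:, -1]
lengths_via_sum = attention_mask.sum(-1)
assert torch.equal(lengths_via_cumsum, lengths_via_sum)  # tensor([3, 5]) both ways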

Contributor

Yes I believe it does - ideally let's keep it as the .cumsum op to avoid the in-place operations

gau-nernst (Contributor Author)

Currently I see two variants of this function:

  1. Wav2Vec2, Wav2Vec2Conformer, WavLM, Data2Vec-Audio: there is an add_adapter argument
  2. HuBERT, SEW, SEW-D, UniSpeech, UniSpeech-SAT, and SpeechT5: no add_adapter argument

Do you know why there are two variants? From what I see, add_adapter is always False anyway. Ideally I think we should consolidate everything to the Wav2Vec2 version (make the other models copy from Wav2Vec2), since all the other models are modifications of Wav2Vec2.

Contributor

add_adapter is used to account for the linear adapter layer that is optionally added on top of an audio encoder model. This adapter layer further downsamples along the temporal dimension of the hidden_states, c.f. https://github.com/huggingface/transformers/tree/main/examples/pytorch/speech-recognition#warm-started-speech-encoder-decoder-model

This adapter is only used in a handful of models, e.g. when the model is a candidate encoder in an encoder-decoder configuration
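
To make the extra downsampling concrete, here is a rough sketch of the output-length computation with and without the adapter (the kernel/stride values are Wav2Vec2Config defaults and should be treated as assumptions):

# Rough sketch, assuming the Wav2Vec2Config defaults:
#   conv_kernel=(10, 3, 3, 3, 3, 2, 2), conv_stride=(5, 2, 2, 2, 2, 2, 2),
#   num_adapter_layers=3, adapter_stride=2
def feat_extract_output_length(input_length: int, add_adapter: bool = False) -> int:
    def conv_out_length(length, kernel_size, stride):
        return (length - kernel_size) // stride + 1

    for kernel, stride in zip((10, 3, 3, 3, 3, 2, 2), (5, 2, 2, 2, 2, 2, 2)):
        input_length = conv_out_length(input_length, kernel, stride)

    if add_adapter:
        # each adapter conv layer roughly halves the sequence length again
        for _ in range(3):
            input_length = conv_out_length(input_length, 1, 2)

    return input_length

print(feat_extract_output_length(32000))                    # ~2 s of 16 kHz audio -> 99 frames
print(feat_extract_output_length(32000, add_adapter=True))  # -> 13 frames after the adapter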

gau-nernst (Contributor Author)

I see. Would it be okay if I use Wav2Vec2's _get_feature_vector_attention_mask() for all other models, even if they don't have adapters? Having one source of truth would be easier for maintenance.

Contributor

I would have two sets: one for models with adapters (e.g. Wav2Vec2), one for models without (e.g. HuBERT). All the models with adapters should copy from Wav2Vec2. All the models without should copy from HuBERT. Otherwise, the models without adapters have redundant adapter code

gau-nernst (Contributor Author)

I had done what you suggested before this comment :)

@gau-nernst (Contributor Author)

@sanchit-gandhi How should I add a test checking for the returned downsampled attention mask? Should it be under Wav2Vec2ModelTest? I'm not familiar with HF tests, and it looks kinda overwhelming.

@sanchit-gandhi (Contributor) commented Aug 21, 2023

Hey @gau-nernst - indeed it's quite an intimidating file! It's got quite a lot of tests given Wav2Vec2 is a core model in the library and we want to ensure that any new changes don't break backwards compatibility. But once you know how it works it's quite straightforward!

The model tester class defines all the functions we want to test (check_xxx). The model test class then executes them (test_xxx).

What I would suggest doing here is defining a function in the model tester, e.g.

def check_labels_out_of_vocab(self, config, input_values, *args):

And then running the test in the model test, e.g.

def test_labels_out_of_vocab(self):

=> this way you just have to focus on writing one new function in the model tester, and then execute it in the model test

In this test, I think we can check that:

  1. We return an attention mask from the model output
  2. This attention mask has the correct downsampled length (which we can get from the private method _get_feat_extract_output_lengths if required)
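
Following that structure, a rough sketch of what the tester/test pair could look like (method bodies here are assumptions; the actual check_returned_attention_mask added by this PR appears in the diff below):

# Sketch only; these methods would live inside the existing Wav2Vec2ModelTester /
# Wav2Vec2ModelTest classes in tests/models/wav2vec2/test_modeling_wav2vec2.py.
import torch
from transformers import Wav2Vec2Model
from transformers.testing_utils import torch_device

# In the model tester (check_xxx):
def check_returned_attention_mask(self, config, input_values, *args):
    model = Wav2Vec2Model(config=config)
    model.to(torch_device)
    model.eval()

    attention_mask = torch.ones_like(input_values, dtype=torch.long)
    outputs = model(input_values, attention_mask=attention_mask)

    # 1. the model output contains a (downsampled) attention mask
    self.parent.assertIsNotNone(outputs.attention_mask)

    # 2. its length matches the feature extractor's downsampled output length
    expected_length = model._get_feat_extract_output_lengths(
        torch.tensor(input_values.shape[1])
    ).item()
    self.parent.assertEqual(outputs.attention_mask.shape[1], expected_length)

# In the model test (test_xxx):
def test_returned_attention_mask(self):
    config_and_inputs = self.model_tester.prepare_config_and_inputs()
    self.model_tester.check_returned_attention_mask(*config_and_inputs)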

@@ -775,8 +775,7 @@ def forward(

if attention_mask is not None:
# make sure padded tokens output 0
expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
hidden_states[~expand_attention_mask] = 0
hidden_states[~attention_mask.bool()] = 0
Contributor

We don't need to expand the attention mask anymore?

gau-nernst (Contributor Author)

It seems like it's not necessary. Let me double check this.

Contributor

We can check that the slow tests pass with this change

gau-nernst (Contributor Author)

From my snippet below, not expanding the attention mask is fine:

import torch

B, L, D = 4, 100, 256

hidden_states = torch.randn(B, L, D)
mask = torch.zeros(B, L, dtype=torch.long)
for i, length in enumerate(torch.randint(L, size=(B,))):
    mask[i, :length] = 1

# current approach: expand the mask to (B, L, D) before indexing
out1 = hidden_states.clone()
out1[~mask.unsqueeze(-1).repeat(1, 1, D).bool()] = 0

# proposed approach: index directly with the 2D (B, L) mask
out2 = hidden_states.clone()
out2[~mask.bool()] = 0

torch.testing.assert_close(out1, out2)

Regarding #24668, I think using Tensor.masked_fill() would be better than torch.where() to make sure it's an out-of-place op. In that case, we need to add an extra dimension to the attention mask, but repeating it is not necessary thanks to broadcasting:

out3 = hidden_states.masked_fill(~mask.unsqueeze(-1).bool(), 0)
torch.testing.assert_close(out1, out3)

Contributor

Awesome, sounds good! Thanks for clarifying

)

hidden_states = outputs[0]
hidden_states = self.dropout(hidden_states)
logits = self.lm_head(hidden_states)

if not return_dict:
output = (logits,) + outputs[2:]
Contributor

I think here we will need to revert to slicing the tuple as was done previously - IMO it's worth trying your solution where we index the tuple based on the settings of output_hidden_states / output_attentions, since we sadly cannot force return_dict=True.

@@ -473,6 +473,28 @@ def check_labels_out_of_vocab(self, config, input_values, *args):
with self.parent.assertRaises(ValueError):
model(input_values, labels=labels)

def check_returned_attention_mask(self, config, input_values, *args):
Contributor

Nice test!

@sanchit-gandhi (Contributor)

It's looking very close to completion @gau-nernst! Just the return_dict situation that needs addressing, otherwise in good shape 🤗

@sanchit-gandhi (Contributor)

Let me know when you'd like a re-review here @gau-nernst! It's looking quite nice already!

@sanchit-gandhi sanchit-gandhi mentioned this pull request Aug 31, 2023
@huggingface deleted comments from the github-actions bot (Oct 11, 2023 – Jan 22, 2024)
github-actions (bot)

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@github-actions github-actions bot closed this Feb 25, 2024

Successfully merging this pull request may close these issues.

Return updated attention mask from Wav2Vec 2.0
4 participants