data2vec-audio returns different results with padded input #25621

Closed
gau-nernst opened this issue Aug 21, 2023 · 15 comments · Fixed by #27116
Comments

@gau-nernst
Contributor

gau-nernst commented Aug 21, 2023

System Info

  • transformers version: 4.31.0
  • Platform: Linux-5.15.0-79-generic-x86_64-with-glibc2.31
  • Python version: 3.10.9
  • Huggingface_hub version: 0.16.4
  • Safetensors version: 0.3.1
  • Accelerate version: not installed
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.0.1 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using GPU in script?: NA
  • Using distributed or parallel set-up in script?: NA

Who can help?

@sanchit-gandhi

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

from transformers import AutoModel
import torch
import torch.nn.functional as F

name = "facebook/data2vec-audio-base"
model = AutoModel.from_pretrained(name).eval()

# 1 second of audio at 16 kHz, normalized the way the feature extractor would
x = torch.randn(1, 16_000)
x = F.layer_norm(x, (16_000,))

out1 = model(x)
print(out1)

# the same audio zero-padded to 20k samples, with a matching attention mask
x_padded = torch.zeros(1, 20_000)
mask = torch.zeros(1, 20_000, dtype=torch.long)

x_padded[:, :16_000] = x
mask[:, :16_000] = 1

out2 = model(x_padded, mask)
print(out2)

# the unpadded outputs should match the first `length` frames of the padded ones
length = out1.last_hidden_state.shape[1]
torch.testing.assert_close(out1.last_hidden_state, out2.last_hidden_state[:, :length])

extract_features is the same, but last_hidden_state is not.

Expected behavior

The two outputs should be the same.

Note that when I change the model to facebook/wav2vec2-xls-r-300m, the outputs are identical. I would expect data2vec and wav2vec 2.0 to have similar behavior, since their architectures are very similar. A quick glance at the source code also suggests there is no reason why data2vec cannot use the attention mask correctly.

The preprocessor config here also indicates that the model should be able to use an attention mask:

https://huggingface.co/facebook/data2vec-audio-base/blob/main/preprocessor_config.json

@gau-nernst
Contributor Author

I just noticed this in the documentation: https://huggingface.co/docs/transformers/model_doc/data2vec#transformers.Data2VecAudioModel.forward.attention_mask

For all models whose processor has config.return_attention_mask == False, such as data2vec-audio-base, attention_mask should not be passed to avoid degraded performance when doing batched inference.

Does it mean the preprocessor config is wrong?

@ArthurZucker
Collaborator

cc @sanchit-gandhi and @ylacombe

@sanchit-gandhi
Contributor

sanchit-gandhi commented Sep 7, 2023

Apologies @gau-nernst, this slipped the net previously! @ylacombe, are you able to take a look? It would be worth running some side-by-side debugging with padded and unpadded inputs to see if there's a divergence.

@ylacombe
Contributor

Hey @gau-nernst,
First of all, thanks for opening the issue!

I've looked into the matter, and you rightfully highlighted two shortcomings:

  1. As you suggested, the preprocessor config does seem wrong, since the attention_mask should be passed through the data2vec encoder to ensure correctness. Padding with zeros alone definitely won't work.
  2. The outputs should definitely be the same. The hidden states start to differ at the beginning of the encoder Data2VecAudioEncoder. I will elaborate below.

I've studied a bit more where the computation starts to diverge, and it happens right here, when computing the positional embeddings.
This seems to be the only difference, since the outputs match when those two lines are commented out.

To address this issue, we should thus:

  1. Correct the model documentation regarding Data2VecAudioModel.forward.attention_mask.
  2. Correct the behavior of Data2VecAudioPositionalConvEmbedding, and most probably of its inner Data2VecAudioPositionalConvLayer layers, so that padded inputs are computed correctly.

This could be a great PR for you @gau-nernst, WDYT about working on this? Of course, I'll support you in the matter if you have any questions!

@gau-nernst
Contributor Author

Thank you for the detailed investigation and explanation. Do you know why Data2VecAudioPositionalConvLayer computes different results for padded input? From what I understand, PyTorch's convolution uses zero padding by default, so zero-padded inputs should produce the same outputs. And do you know if fairseq's implementation has this problem?

@ylacombe
Contributor

ylacombe commented Sep 11, 2023

As I understand it, when passing padding zeros through PyTorch's Conv1d, the padding zeros do not influence the output up to the length of the output sequence without padding. Values after this length, however, are not necessarily zero.

This poses a problem because those non-zero values past the valid length are then fed into the next Data2VecAudioPositionalConvLayer's Conv1d, so the errors accumulate.

Note that it wouldn't be a problem if there were only one Data2VecAudioPositionalConvLayer with no layer norm, since the rest of the encoder works with an attention mask.
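To illustrate the effect, here is a minimal standalone sketch (my own toy example, not the actual model code): with a single Conv1d the valid positions still match, but stacking a second conv lets the non-zero tail leak back into the valid region.

import torch
import torch.nn as nn

torch.manual_seed(0)

# two stacked conv layers, loosely mimicking the stacked
# Data2VecAudioPositionalConvLayer blocks
conv1 = nn.Conv1d(1, 1, kernel_size=5, padding=2)
conv2 = nn.Conv1d(1, 1, kernel_size=5, padding=2)

x = torch.randn(1, 1, 10)          # unpadded sequence of length 10
x_padded = torch.zeros(1, 1, 14)   # the same sequence zero-padded to 14
x_padded[..., :10] = x

with torch.no_grad():
    # after one conv, the first 10 outputs still match: the explicit
    # padding zeros behave exactly like Conv1d's implicit zero padding
    y, y_padded = conv1(x), conv1(x_padded)
    print(torch.allclose(y, y_padded[..., :10]))  # True

    # but y_padded is non-zero past position 10 (real values and the bias
    # leak into the tail), so a second conv mixes that tail back into the
    # last valid positions and the outputs diverge
    z, z_padded = conv2(y), conv2(y_padded)
    print(torch.allclose(z, z_padded[..., :10]))  # False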

@gau-nernst
Contributor Author

I see, that makes sense. That's why Wav2Vec2 doesn't have this issue: it uses only one convolution layer for positional encoding.

I think the way to fix this is to fill the values after the attention mask with zeros. This has to be done after every conv layer in the positional encoding; see the sketch below. I'm not sure if there is a more elegant way.

Another note: the original fairseq implementation appears to have the same problem (padded input gives different results), since they don't seem to do any special processing (I haven't actually run the code to check). I'm not sure we should deviate from the official implementation if that's the case.
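For illustration, a minimal sketch of that zero-filling idea (hypothetical code, assuming a (batch, channels, time) conv layout; not the actual transformers implementation):

import torch
import torch.nn as nn

class MaskedPositionalConvLayer(nn.Module):
    # hypothetical sketch: a positional conv layer that re-zeroes the
    # padded positions after the convolution, so that the next layer
    # sees the same values it would see without padding
    def __init__(self, dim: int, kernel_size: int = 5):
        super().__init__()
        self.conv = nn.Conv1d(dim, dim, kernel_size, padding=kernel_size // 2)

    def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        # hidden_states: (batch, dim, time); attention_mask: (batch, time),
        # assumed already downsampled to the feature-frame resolution
        hidden_states = self.conv(hidden_states)
        hidden_states = hidden_states * attention_mask.unsqueeze(1).to(hidden_states.dtype)
        return hidden_states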

@ylacombe
Contributor

I think the way to fix this is to fill the values after attention mask with zeros. This has to be done after every conv layers in positional encoding. Not sure if there is a more elegant way.

That's exactly what is done here.
And I would agree that's the way to go.

I think we need to discuss it further with @sanchit-gandhi, since batching (and thus padding) seems to still give correct results in the integration tests.

It could be interesting to experiment a bit with your solution and check that it gives correct and coherent results. Would you be OK to experiment with this? You could pass the attention mask through those layers, or you could do something with the sequence lengths, and then compare the results with what's in the integration tests.

@ylacombe
Contributor

Note that the integration tests use "facebook/data2vec-audio-base-960h" and not "facebook/data2vec-audio-base".

@gau-nernst
Contributor Author

The model is probably robust enough that the final predictions are not affected.

Do you have any thoughts about not replicating the exact fairseq implementation? For reference, this is the fairseq code and the config file.

@ylacombe
Contributor

Hey @gau-nernst, I had the occasion to discuss the matter internally with @sanchit-gandhi and @patrickvonplaten, and here are our thoughts!

At the moment we have strict equivalence with the fairseq implementation, which leads us to believe that the current behavior might be intended, or that it is simply an oversight on their part. In any case, we'd like to keep the default behavior, since it doesn't seem to impact the outputs much according to the integration tests!

However, if you are really interested in the matter, you can still drive a PR to correct this behavior, provided that the current behavior stays the default, and provided that the fix is really useful in terms of quality!
To do so, if you want, you can test the different options a bit more, and even run a benchmark measuring WER degradation in the different setups (w/o batching and padding, w/ batching and padding with the default behavior, and w/ batching and padding with the fix). See this comment for an example of how to measure WER.
Would you like to do that?
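(As a rough illustration, here is a minimal sketch of one way to compute WER with the Hugging Face evaluate library — my own example, not the one from the linked comment:)

import evaluate  # assumes `pip install evaluate jiwer`

wer_metric = evaluate.load("wer")
predictions = ["the cat sat on the mat"]
references = ["the cat sat on a mat"]
# WER = (substitutions + insertions + deletions) / reference word count
print(wer_metric.compute(predictions=predictions, references=references))  # 1/6 ≈ 0.167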

BTW, could you also tell me the intended use of the model and how you encountered this problem? Many thanks! If you encountered this issue while fine-tuning the model, you might want to group samples by length, since it appears that your issue was over-amplified by a large padding-to-length ratio!

@gau-nernst
Contributor Author

Sadly, I don't have the bandwidth to do that experiment. I mainly use audio models for audio classification, so I'm interested in encoder-only models. I was evaluating which model to use and found the same strange behaviour (different results for padded inputs) in the Wav2Vec2 Base and HuBERT Base models, which is due to their use of group norm; see the sketch below. I then checked whether other models had this problem, and found it in data2vec-audio.
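(A minimal standalone illustration of the group-norm sensitivity — my own toy example, not the actual Wav2Vec2 feature extractor: GroupNorm computes its statistics over the full padded length, so zero padding shifts the normalization of the valid positions.)

import torch
import torch.nn as nn

torch.manual_seed(0)

gn = nn.GroupNorm(num_groups=1, num_channels=4)
x = torch.randn(1, 4, 10)                            # unpadded, length 10
x_padded = torch.cat([x, torch.zeros(1, 4, 4)], -1)  # zero-padded to 14

# the mean/variance are computed over all 14 positions, padding included,
# so even the valid positions are normalized differently
print(torch.allclose(gn(x), gn(x_padded)[..., :10]))  # False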

Currently I don't use data2vec-audio models, since I think Wav2Vec2 XLS-R is much better thanks to its large pre-training data.

I believe the solution for now is to update the documentation:

  1. Remove the warning about the attention mask:

For all models whose processor has config.return_attention_mask == False, such as data2vec-audio-base, attention_mask should not be passed

  2. Add a warning that padded inputs will have different outputs, even with an attention mask, due to the convolution layers in the positional encodings.

@gau-nernst
Contributor Author

@ylacombe What do you think of the solution I proposed above? I can submit a PR if you are OK with it. It's mainly a documentation fix, since I won't have the bandwidth to experiment with the model code.

@ylacombe
Contributor

ylacombe commented Oct 2, 2023

Hey @gau-nernst, thanks for the reminder! It will be nice to have your contribution in a PR here! I agree with the solution you proposed for now; feel free to ping me on the PR!

@github-actions

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.
