
wav2vec2 with feat_extract_norm == "group" normalizes over channels, not tokens, which causes issues when padding #21534

Closed
itayhubara opened this issue Feb 9, 2023 · 8 comments


itayhubara commented Feb 9, 2023

System Info

Any GPU/CPU system

Who can help?

No response

Information

While there is support for padding, I believe it affects accuracy (even when running a single sample with and without padding). My central claim is that because the first normalization is done over the channels rather than over the tokens, the mean and variance change when the sequence is padded.
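To make the claim concrete, here is a minimal sketch with random weights (not the Transformers code itself; the shapes roughly mirror the first feature-extractor conv layer of the base config: 512 channels, kernel 10, stride 5, group-norm with one group per channel). Appending zero padding changes the group-norm statistics and therefore the normalized features of the unpadded frames:

import torch
import torch.nn as nn

torch.manual_seed(0)

# Random weights; shapes only roughly mirror the first conv layer of the base config.
conv = nn.Conv1d(1, 512, kernel_size=10, stride=5)
norm = nn.GroupNorm(num_groups=512, num_channels=512)  # one group per channel, stats taken over time

wave = torch.randn(1, 1, 16_000)                                # ~1 s of "audio"
padded = torch.cat([wave, torch.zeros(1, 1, 16_000)], dim=-1)   # same audio + 1 s of zero padding

out_plain = norm(conv(wave))
out_padded = norm(conv(padded))[..., : out_plain.shape[-1]]     # frames covering the real audio

# The padded zeros enter the per-channel mean/variance, so even the unpadded frames change.
print(torch.max(torch.abs(out_plain - out_padded)))             # nonzero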

I used code similar to the given example, except that in the processor I passed padding="max_length", truncation=True, max_length=1182340.
Note that with small enough padding, for instance max_length=100000, you won't see any issue.
Am I missing something? Can anyone help?

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Use the example code with and without padding.

With padding:

from transformers import AutoProcessor, Wav2Vec2ForCTC
from datasets import load_dataset
import torch

dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation")
dataset = dataset.sort("id")
sampling_rate = dataset.features["audio"].sampling_rate

processor = AutoProcessor.from_pretrained("facebook/wav2vec2-base-960h")
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")


inputs = processor(dataset[0]["audio"]["array"], sampling_rate=sampling_rate, return_tensors="pt", padding="max_length", truncation=True, max_length=1182340)
with torch.no_grad():
    logits = model(**inputs).logits
predicted_ids = torch.argmax(logits, dim=-1)

transcription = processor.batch_decode(predicted_ids)
print(transcription[0])

output: ISTE COIS THE COL OF T I CLASES AND WE RLITO O HIS GOSPLE

Without padding:

from transformers import AutoProcessor, Wav2Vec2ForCTC
from datasets import load_dataset
import torch

dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation")
dataset = dataset.sort("id")
sampling_rate = dataset.features["audio"].sampling_rate

processor = AutoProcessor.from_pretrained("facebook/wav2vec2-base-960h")
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")

inputs = processor(dataset[0]["audio"]["array"], sampling_rate=sampling_rate, return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits
predicted_ids = torch.argmax(logits, dim=-1)

transcription = processor.batch_decode(predicted_ids)
print(transcription[0]) 

output: 'MISTER QUILTER IS THE APOSTLE OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPEL'

Expected behavior

I expect that padding won't affect the output of the model.


sgugger commented Feb 9, 2023

cc @sanchit-gandhi


sanchit-gandhi commented Feb 10, 2023

Hey @itayhubara!

Thanks for the nice write-up 🤗 I've edited your OP to format the reproducible code as a markdown code snippet!

The phenomenon you're describing is due to a bug in the original Wav2Vec2 implementation, where layer-norm was applied after the attention layer (instead of before, as it should have been). This bug was present in the original fairseq codebase for the wav2vec2-base model. The model was trained and released with this bug, and it was therefore copied over to Transformers when the model was added.

What's interesting is that the wav2vec2-large model applies layer-norm in the correct way! Layer-norm is applied before the attention layer.

We differentiate between applying layer-norm before/after the attention layer with the config parameter do_stable_layer_norm.

If False (like the base model), we apply layer-norm after the attention layer: wav2vec2-base-960h/config.json#L43

If True (like the large model), we apply layer-norm correctly before the attention layer: wav2vec2-large-960h-lv60-self/config.json#L43
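As a quick sanity check (a small sketch, assuming both checkpoints are reachable on the Hub), you can load the configs and inspect do_stable_layer_norm directly:

from transformers import Wav2Vec2Config

for ckpt in ("facebook/wav2vec2-base-960h", "facebook/wav2vec2-large-960h-lv60-self"):
    config = Wav2Vec2Config.from_pretrained(ckpt)
    print(ckpt, "-> do_stable_layer_norm =", config.do_stable_layer_norm)

# Expected (per the config links above): False for the base checkpoint, True for the large checkpoint.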

Running the code snippet with the large model (which has correct layer-norm) and placing the model on the GPU (if available) gives the correct output, even with extreme padding:

from transformers import AutoProcessor, Wav2Vec2ForCTC
from datasets import load_dataset
import torch

dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation")
dataset = dataset.sort("id")
sampling_rate = dataset.features["audio"].sampling_rate

processor = AutoProcessor.from_pretrained("facebook/wav2vec2-large-960h-lv60-self")
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-large-960h-lv60-self")
device = "cuda:0" if torch.cuda.is_available() else "cpu"
model.to(device)

inputs = processor(dataset[0]["audio"]["array"], sampling_rate=sampling_rate, return_tensors="pt", padding="max_length", truncation=True, max_length=1182340)
inputs = {key: value.to(device) for key, value in inputs.items()}

with torch.no_grad():
    logits = model(**inputs).logits

predicted_ids = torch.argmax(logits, dim=-1)

transcription = processor.batch_decode(predicted_ids)
print(transcription[0])

Print Output:

MISTER QUILTER IS THE APOSTLE OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPEL


schoi-habana commented Feb 11, 2023

Hi @sanchit-gandhi, thanks for the response!

Unfortunately, we are still seeing the accuracy issue with padding on the large model. Please find the code snippet below; we modified your snippet for the dataset we are using, LibriSpeech test-clean.

WER=0.051 when the input sequence is padded with 60000 samples, whereas WER=0.077 with no padding.
Can you please take a look at our example?

from transformers import AutoProcessor, Wav2Vec2ForCTC
from datasets import load_dataset
import torch
from jiwer import wer

dataset = load_dataset("librispeech_asr", "clean", split="test")
sampling_rate = dataset.features["audio"].sampling_rate

processor = AutoProcessor.from_pretrained("facebook/wav2vec2-large-960h-lv60-self")
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-large-960h-lv60-self")
device = "cuda:0" if torch.cuda.is_available() else "cpu"
model.to(device)
sample_indices = [571]  # dataset indices to evaluate
for i in sample_indices:
    inputs_long_pad = processor(dataset[i]["audio"]["array"], sampling_rate=sampling_rate, return_tensors="pt", padding="max_length", truncation=True, max_length=len(dataset[i]["audio"]["array"]) + 60000)
    inputs_long_pad = {key: value.to(device) for key, value in inputs_long_pad.items()}

    actual = dataset[i]["text"]
    with torch.no_grad():
        logits_long_pad = model(**inputs_long_pad).logits

    predicted_ids_long_pad = torch.argmax(logits_long_pad, dim=-1)
    transcription_long_pad = processor.batch_decode(predicted_ids_long_pad)
    wer_long_pad = wer(actual, transcription_long_pad[0])

    inputs_short_pad = processor(dataset[i]["audio"]["array"], sampling_rate=sampling_rate, return_tensors="pt", padding="max_length", truncation=True, max_length=len(dataset[i]["audio"]["array"]))
    inputs_short_pad = {key: value.to(device) for key, value in inputs_short_pad.items()}

    with torch.no_grad():
        logits_short_pad = model(**inputs_short_pad).logits

    predicted_ids_short_pad = torch.argmax(logits_short_pad, dim=-1)
    transcription_short_pad = processor.batch_decode(predicted_ids_short_pad)
    wer_short_pad = wer(actual, transcription_short_pad[0])
    print("long pad: ", transcription_long_pad[0], ", long pad WER: ", wer_long_pad)
    print("no pad: ", transcription_short_pad[0], ", no pad WER: ", wer_short_pad)
    print("ground truth: ", actual)

Outputs from the run:

long pad:  FOR MANY THEN THIS BOOK HAS BEEN A SOURCE OF FASCINATION SURELY ONE OF THE MOST INFLUENTIAL NOVELS EVER WRITTEN AN INSPIRATION FOR SUCH SCIENTISTS AND DISCOVERERS AS ENGINEER SIMON LAKE OCEANOGRAPHER WILLIAM BB POLAR TRAVELLER SIR ERNEST SHACKLETON , long pad WER:  0.05128205128205128
no pad:  FOR MANY THEN THIS BOOK HAS BEEN A SOURCE OF FASCINATION SURELY ONE OF THE MOST INFLUENTIAL NOVELS EVER WRITTEN AN INSPIRATION FOR SUCH SCIENTISTS AND DISCOVERERS AS ENGINEER SIMON LAKE OCEANOGRAPHER WILLIAM B B POLAR TRAVELLER SIR ERNEST SHACKLETON , no pad WER:  0.07692307692307693
ground truth:  FOR MANY THEN THIS BOOK HAS BEEN A SOURCE OF FASCINATION SURELY ONE OF THE MOST INFLUENTIAL NOVELS EVER WRITTEN AN INSPIRATION FOR SUCH SCIENTISTS AND DISCOVERERS AS ENGINEER SIMON LAKE OCEANOGRAPHER WILLIAM BEEBE POLAR TRAVELER SIR ERNEST SHACKLETON


sanchit-gandhi commented Feb 17, 2023

Hey @schoi-habana!

I think what we're seeing in this particular example is the effect of numerical precision on our model outputs. The padding mask is not perfect: we set the attention values to a very large negative number for the padded inputs; however, they are not entirely masked to minus infinity (due to numerical precision constraints).

Essentially, we set the attention mask for padded inputs to the most negative value permitted by our dtype:

attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min

So in our attention computation, the padded attention values are masked as much as possible. This is as good as we can get for padded inputs: we are bounded by numerical precision and can't go any lower. The workaround would be to perform unbatched inference, but then you pay the penalty of slower inference.
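For illustration, here is a small sketch (not the exact Transformers code) of the additive masking described above: padded key positions receive the most negative value representable in the dtype, which drives their softmax weights to effectively zero.

import torch

scores = torch.randn(1, 4, 6)                          # (batch, query, key) attention scores
attention_mask = torch.tensor([[1, 1, 1, 1, 0, 0]])    # last two key positions are padding

# 1 -> keep, 0 -> mask; masked positions get the most negative value the dtype can represent
additive_mask = (1.0 - attention_mask[:, None, :].float()) * torch.finfo(scores.dtype).min
probs = torch.softmax(scores + additive_mask, dim=-1)

print(probs[0, 0])  # the weights on the two padded keys are effectively zero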

We can see that the transcription is by and large correct; we've just got an extra space in a name:

  • Pad: BB
  • No pad: B B
  • Ground truth: BEEBE

The extra space is giving an additional insertion error for the no-padding case. But we can see that the transcription is pretty much identical in all other respects.
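As a toy illustration of why the extra space costs a word error (jiwer treats whitespace-separated tokens as words; the strings below are cut down from the outputs above):

from jiwer import wer

reference = "WILLIAM BEEBE"
print(wer(reference, "WILLIAM BB"))   # 0.5 -> one substitution over two reference words
print(wer(reference, "WILLIAM B B"))  # 1.0 -> one substitution plus one insertion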

If we evaluate the model on the full LibriSpeech test-clean corpus, we find that the results are the same to within numerical precision:

from datasets import load_dataset
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
import torch
from jiwer import wer

librispeech_eval = load_dataset("librispeech_asr", "clean", split="test")

model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-large-960h-lv60-self").to("cuda")
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-large-960h-lv60-self")


def map_to_pred(batch):
    inputs = processor(batch["audio"]["array"], return_tensors="pt", padding="longest", sampling_rate=16000)
    input_values = inputs.input_values.to("cuda")
    attention_mask = inputs.attention_mask.to("cuda")
    
    with torch.no_grad():
        logits = model(input_values, attention_mask=attention_mask).logits

    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = processor.batch_decode(predicted_ids)[0]
    batch["transcription"] = transcription
    return batch

result = librispeech_eval.map(map_to_pred, remove_columns=["audio"])

print("WER no pad:", wer(result["text"], result["transcription"]))


def map_to_pred_with_pad(batch):
    audio = batch["audio"]["array"]
    max_length = len(audio) + 60000

    inputs = processor(audio, return_tensors="pt", padding="max_length", max_length=max_length, sampling_rate=16000)
    input_values = inputs.input_values.to("cuda")
    attention_mask = inputs.attention_mask.to("cuda")
    
    with torch.no_grad():
        logits = model(input_values, attention_mask=attention_mask).logits

    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = processor.batch_decode(predicted_ids)[0]
    batch["transcription"] = transcription
    return batch

pad_result = librispeech_eval.map(map_to_pred_with_pad, remove_columns=["audio"])

print("WER with pad:", wer(pad_result["text"], pad_result["transcription"]))

Print Output:

100%|███████████████████████████████████████████████████████████████████████████████| 2620/2620 [01:53<00:00, 23.12ex/s]
WER no pad: 0.018620663420572125
100%|███████████████████████████████████████████████████████████████████████████████| 2620/2620 [02:31<00:00, 17.29ex/s]
WER with pad: 0.018620663420572125

@schoi-habana

@sanchit-gandhi thanks, we were able to reproduce the results with the wav2vec2-large model. Are there any plans to fix the bug in the wav2vec2-base model and publish a retrained model?


sanchit-gandhi commented Mar 17, 2023

Hey @schoi-habana! Very glad to hear that! Unfortunately, keeping the bug in the wav2vec2-base model in the Hugging Face code was somewhat deliberate, and it is not something we can really change. What we're trying to provide with Transformers is an easy-to-use codebase built on top of the official weights/code. Transformers seeks to match the official Facebook implementation of Wav2Vec2 as closely as possible. Since the official weights have the bug, our implementation also has to have it, so that we get the same results as the official one. So we can't really change this for the base model! If Facebook releases a variant of the base model with the bug fixed, you can be sure we'll adapt the code to accommodate it and host the weights on the HF Hub, but this seems quite unlikely.

The padding situation you've described seems very extreme! Is it really affecting your results that drastically? From the code snippet you shared, it seems to be quite localised to one or two samples in the dataset, no?

I think at this point we need to view it as a flaw with the model rather than a flaw with the codebase! I guess the options are:

  1. Group your audio samples such that they have similar lengths and minimise padding (see https://huggingface.co/docs/transformers/main_classes/trainer#transformers.TrainingArguments.group_by_length; a minimal sketch follows this list)
  2. Use the large checkpoint since it doesn't have the padding bug
  3. Acknowledge that the base checkpoint has the flaw with padding and accept the small degradation to WER!
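For option 1, a minimal sketch of enabling length grouping in a training run (the output_dir and the length column name "input_length" are placeholders, not something from this thread):

from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="./wav2vec2-base-ft",     # hypothetical output path
    per_device_train_batch_size=8,
    group_by_length=True,                # bucket samples of similar length together to minimise padding
    length_column_name="input_length",   # assumed precomputed per-sample length column in the dataset
)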

@github-actions

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@sanchit-gandhi

Leaving this closed since the 'bug' is baked into the official implementation and thus propagated into the 🤗 Transformers implementation.
