# Finetuning Wav2Vec2
## Overview
We now import a pre-trained model from Hugging Face 🤗 and fine-tune it on a labelled dataset. Then, we test its performance on our usual benchmark dataset (LibriSpeech). In particular, the model ([Wav2Vec2](https://huggingface.co/docs/transformers/model_doc/wav2vec2)) was pre-trained in a completely self-supervised fashion: spans of the latent speech representations are masked, and the model solves a contrastive task over them. As suggested in the paper, we do not fine-tune the convolutional feature extractor, because its pre-trained weights are already good. For fine-tuning, we pick different splits from both the [LIBRISPEECH](https://huggingface.co/datasets/librispeech_asr) (train-clean-100) and the [FLEURS](https://huggingface.co/datasets/google/fleurs) labelled speech datasets. We then assess the performance of the fine-tuned model on the LibriSpeech benchmark (dev-clean, dev-other, test-clean, test-other). We believe that, given a sufficiently well pre-trained model, fine-tuning on a few hours of labelled speech (or even less than 1 hour) is enough to obtain good performance on unseen audio data.

## Observations
Fine-tuning on FLEURS and evaluating on LibriSpeech dev-clean seems to lead to poorer generalization than fine-tuning on splits of LibriSpeech. We believe that this behaviour is due to the distribution mismatch between the training and evaluation sets:
- [LibriSpeech](https://huggingface.co/datasets/librispeech_asr) is read speech from audiobooks; its text is uncased and contains no punctuation
- [FLEURS](https://huggingface.co/datasets/google/fleurs) consists of read sentences from the FLoRes machine translation benchmark; its text is case-sensitive and contains punctuation

Taking a closer look at the data, we notice that FLEURS speech contains nuances not present in LibriSpeech, namely false starts, corrections, misspellings, and so on. Conversely, LibriSpeech is taken from carefully pronounced audiobook recordings made in controlled audio settings. We hence believe that this mismatch explains why the model generalizes well or poorly when evaluated on the LibriSpeech dev set.

## Libraries
- If using a Colab session, create a `utils` folder and add the `preprocessing.py` file.
- Create an `audioset` folder.

In [1]:
%%capture
! pip install accelerate
! pip install datasets
! pip install jiwer
! pip install inflect

In [2]:
# Main libraries
import numpy as np
import torch
import pandas as pd
import accelerate

# Datasets
from datasets import load_dataset, load_metric
from torch.utils.data import DataLoader
import torchaudio
from datasets import Dataset as HFDataset
from datasets import DatasetDict
from datasets import Audio

# Hugging Face
from transformers import Wav2Vec2CTCTokenizer
from transformers import Wav2Vec2FeatureExtractor
from transformers import Wav2Vec2Processor
from transformers import Wav2Vec2ForCTC
from transformers import TrainingArguments
from transformers import Trainer

# Others
import re
import inflect
import json
import IPython.display as ipd
import matplotlib.pyplot as plt
import textwrap
from tqdm import tqdm
import os

## Datasets
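For training, we use different splits of the LibriSpeech and FLEURS datasets. In this notebook, we fine-tune on 10% of LibriSpeech train-clean-100; the FLEURS cell below was only used in early experiments.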



In [None]:
# FLEURS DATASET - USED IN EARLY EXPERIMENTS
# fleurs = load_dataset("google/fleurs", "en_us", split="train", trust_remote_code=True)
# useless_columns = ["id", "path", "raw_transcription", "gender", "lang_id", "language", "lang_group_id"]
# fleurs = fleurs.remove_columns(useless_columns)
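
Part of the text-side mismatch discussed in the Observations could be reduced by mapping FLEURS transcriptions to LibriSpeech conventions before fine-tuning. The helper below is a hypothetical sketch, not the preprocessing used in our experiments: it lowercases the text and replaces every character outside `a-z`, the apostrophe and the space with a space. Note that it silently drops digits and accented letters; verbalizing numbers (e.g. with `inflect`) would be a separate step.

```python
import re

# Hypothetical helper (not used in the experiments above): bring a FLEURS-style
# transcription closer to LibriSpeech conventions.
_NON_LIBRI_CHARS = re.compile(r"[^a-z' ]+")


def to_librispeech_style(text: str) -> str:
    text = text.lower()
    # replace every run of characters outside a-z, apostrophe and space with a space
    text = _NON_LIBRI_CHARS.sub(" ", text)
    # collapse repeated whitespace and trim both ends
    return " ".join(text.split())


print(to_librispeech_style("Hello, World! It's fine."))  # hello world it's fine
```

With FLEURS loaded as above, it could be applied with `fleurs.map(lambda b: {"transcription": to_librispeech_style(b["transcription"])})`.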

In [3]:
folder_name = "datasets"

if not os.path.exists(folder_name):
    os.makedirs(folder_name)

train_clean_100 = torchaudio.datasets.LIBRISPEECH(folder_name, url="train-clean-100", download=True)
dev_clean = torchaudio.datasets.LIBRISPEECH(folder_name, url="dev-clean", download=True)


100%|██████████| 5.95G/5.95G [08:00<00:00, 13.3MB/s]
100%|██████████| 322M/322M [00:26<00:00, 12.6MB/s]


In [4]:
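def get_data(dataset, fraction):
    # keep the first `fraction` of the utterances (in torchaudio's file order)
    num_samples = int(fraction * len(dataset))
    data = []
    for i in tqdm(range(num_samples)):
        # metadata: (relative path, sample rate, transcript, speaker id, chapter id, utterance id);
        # reading it avoids loading every waveform just to get the transcript
        rel_path, _, transcript, *_ = dataset.get_metadata(i)
        data.append({"audio": os.path.join(folder_name, "LibriSpeech", rel_path),
                     "transcription": transcript})
    return data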

In [5]:
hf_libri_train = HFDataset.from_list(get_data(train_clean_100, fraction=0.10))
hf_libri_train = hf_libri_train.cast_column("audio", Audio(sampling_rate=16_000))

100%|██████████| 2853/2853 [00:16<00:00, 174.38it/s]
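
Note that `get_data` keeps the first utterances in file order, and torchaudio's `LIBRISPEECH` lists files sorted by their id (speaker, chapter, utterance), so a 10% prefix covers only part of the speakers. A hypothetical alternative, not used for the results reported here, is to draw a seeded random subset of indices:

```python
import random


def random_subset_indices(num_items: int, fraction: float, seed: int = 0) -> list[int]:
    """Hypothetical helper: sorted, seeded random sample of int(fraction * num_items) indices."""
    num_samples = int(fraction * num_items)
    rng = random.Random(seed)  # local generator, so the global random state is untouched
    return sorted(rng.sample(range(num_items), num_samples))


# e.g. 10% of the 28,539 utterances of train-clean-100
indices = random_subset_indices(28_539, fraction=0.10, seed=42)
print(len(indices), indices[:5])
```

Inside `get_data`, one would then iterate over these indices instead of `range(num_samples)`.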


In [6]:
print(hf_libri_train)
print(hf_libri_train.features)

Dataset({
    features: ['audio', 'transcription'],
    num_rows: 2853
})
{'audio': Audio(sampling_rate=16000, mono=True, decode=True, id=None), 'transcription': Value(dtype='string', id=None)}


In [7]:
hf_libri_test = HFDataset.from_list(get_data(dev_clean, 1.0))
hf_libri_test = hf_libri_test.cast_column("audio", Audio(sampling_rate=16_000))

100%|██████████| 2703/2703 [00:09<00:00, 286.11it/s]


In [8]:
print(hf_libri_test)
print(hf_libri_test.features)

Dataset({
    features: ['audio', 'transcription'],
    num_rows: 2703
})
{'audio': Audio(sampling_rate=16000, mono=True, decode=True, id=None), 'transcription': Value(dtype='string', id=None)}


## Text Normalization
LibriSpeech transcriptions are all uppercase and contain no punctuation symbols other than the apostrophe, so the normalization step only needs to convert everything to lower case. Later, we check whether any special character accidentally appears in the dataset, and finally build the vocabulary to be fed to the tokenizer.

In [9]:
def normalize(batch):
    batch["transcription"] = batch["transcription"].lower()
    return batch

hf_libri_train = hf_libri_train.map(normalize)
hf_libri_test = hf_libri_test.map(normalize)
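
The normalization above assumes that, after lowercasing, only letters, apostrophes and spaces remain. A quick way to verify this is to count every other character; the helper below is a sketch of such a check (the expected alphabet is our assumption):

```python
import string
from collections import Counter

# Characters we expect after lowercasing LibriSpeech transcriptions (assumption:
# lowercase English letters, apostrophe and space).
EXPECTED_CHARS = set(string.ascii_lowercase + "' ")


def unexpected_chars(transcriptions):
    """Count every character that falls outside EXPECTED_CHARS."""
    counts = Counter()
    for text in transcriptions:
        counts.update(ch for ch in text if ch not in EXPECTED_CHARS)
    return counts


print(unexpected_chars(["chapter one", "it's fine", "room 42!"]))
```

Calling `unexpected_chars(hf_libri_train["transcription"])` should return an empty `Counter` if the assumption holds; selecting the column avoids decoding the audio.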


We pick a random audio sample and play it.

In [10]:
dummy_audioset = hf_libri_train  # or hf_libri_test

idx = np.random.randint(low=0, high=len(dummy_audioset))  # avoid shadowing the built-in `id`
sample = dummy_audioset[idx]  # decode the audio only once

print("TRANSCRIPTION")
print(textwrap.fill(sample["transcription"], 40))

display(ipd.Audio(data=np.asarray(sample["audio"]["array"]), rate=16_000))

TRANSCRIPTION
she had hardly spoken when the horse
appeared and mounting on his back she
started for the village where the
wedding was to be held at first she was
so delighted with the chance of a
holiday from the work which she hated
that she noticed nothing


For convenience, we merge the train and test splits into a single `DatasetDict` object.

In [11]:
libri_speech = DatasetDict()
libri_speech["train"] = hf_libri_train
libri_speech["test"] = hf_libri_test

libri_speech

DatasetDict({
    train: Dataset({
        features: ['audio', 'transcription'],
        num_rows: 2853
    })
    test: Dataset({
        features: ['audio', 'transcription'],
        num_rows: 2703
    })
})

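We are now ready to create the vocabulary. It consists of the characters found in the training transcriptions (the English alphabet, the apostrophe and the space, which is replaced by the word delimiter `|`), plus the special tokens `[UNK]` and `[PAD]`.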

In [12]:
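def get_vocabulary(dataset):
    # collect the distinct characters of all transcriptions; selecting the
    # column avoids decoding the audio of every example
    unique_chars = set()
    for text in tqdm(dataset["transcription"]):
        unique_chars.update(text)
    return sorted(unique_chars)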

In [13]:
vocab_list = get_vocabulary(libri_speech["train"])
vocab_dict = {v: k for k, v in enumerate(vocab_list)}

vocab_dict["|"] = vocab_dict[" "]
del vocab_dict[" "]

vocab_dict["[UNK]"] = len(vocab_dict)
vocab_dict["[PAD]"] = len(vocab_dict)

vocab_name = "vocab.json"
with open(vocab_name, 'w') as file:
    json.dump(vocab_dict, file, indent=4)
    print('JSON dumped!')

100%|██████████| 2853/2853 [00:32<00:00, 87.44it/s] 

JSON dumped!
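
To make the mapping concrete, the hypothetical helper below repeats the vocabulary-building steps of the previous cell on two toy transcriptions (it takes an in-memory list instead of a dataset):

```python
def build_vocab(transcriptions):
    """Same steps as the cell above, applied to a list of strings (toy version)."""
    chars = sorted(set("".join(transcriptions)))
    vocab = {ch: i for i, ch in enumerate(chars)}
    # make the space visible as the word delimiter "|" (keeps the space's id)
    vocab["|"] = vocab.pop(" ")
    # special tokens take the next free ids
    vocab["[UNK]"] = len(vocab)
    vocab["[PAD]"] = len(vocab)
    return vocab


print(build_vocab(["a cab", "it's"]))
```

Because the space sorts before the apostrophe and the letters, `|` receives id 0 and `a` id 2. With a training vocabulary made of the space, the apostrophe and the letters a to z, the label ids shown in the dataset preview further down follow this layout: `[4, 9, 2, 17, 21, 6, 19, 0, ...]` reads `chapter|`.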





## Tokenization
We use the Hugging Face feature extractor, which takes raw audio waveforms and preprocesses them. We also need the Hugging Face tokenizer, which maps transcriptions to token ids and decodes predicted token ids back into text. Both are wrapped in a single processor object.

In [14]:
tokenizer = Wav2Vec2CTCTokenizer(f"./{vocab_name}",
                                 unk_token="[UNK]",
                                 pad_token="[PAD]",
                                 word_delimiter_token="|")

feature_extractor = Wav2Vec2FeatureExtractor(feature_size=1,
                                             sampling_rate=16_000,
                                             padding_value=0.0,
                                             do_normalize=True,
                                             return_attention_mask=False)

processor = Wav2Vec2Processor(feature_extractor=feature_extractor,
                              tokenizer=tokenizer)
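
At a high level, the CTC tokenizer maps each character to an id (the space becoming the word delimiter `|`) and, when decoding model predictions, merges consecutive repeated ids and removes the `[PAD]` id, which acts as the CTC blank. The toy sketch below re-implements these two steps with a hypothetical seven-token vocabulary; it is for illustration only and is not the library's code:

```python
# Toy vocabulary in the same layout as vocab.json (hypothetical ids)
TOY_VOCAB = {"|": 0, "h": 1, "e": 2, "l": 3, "o": 4, "[UNK]": 5, "[PAD]": 6}
ID_TO_TOKEN = {i: tok for tok, i in TOY_VOCAB.items()}


def encode(text, vocab=TOY_VOCAB):
    """Characters -> ids; spaces become "|", unknown characters become [UNK]."""
    return [vocab.get("|" if ch == " " else ch, vocab["[UNK]"]) for ch in text]


def greedy_ctc_decode(frame_ids, id_to_token=ID_TO_TOKEN, blank_id=TOY_VOCAB["[PAD]"]):
    """Merge consecutive repeated ids, drop the blank ([PAD]) id, map "|" back to spaces."""
    tokens, prev = [], None
    for idx in frame_ids:
        if idx != prev and idx != blank_id:
            tokens.append(id_to_token[idx])
        prev = idx
    return "".join(tokens).replace("|", " ").strip()


print(encode("hello oh"))  # [1, 2, 3, 3, 4, 0, 4, 1]
print(greedy_ctc_decode([1, 1, 2, 6, 3, 3, 6, 3, 4, 4, 0, 0, 4, 1, 6]))  # hello oh
```

Decoding an encoded label sequence this way would turn `hello` into `helo`, because nothing separates the two `l`s. This is why `compute_metrics` decodes the reference labels with `group_tokens=False`.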


We need to process the input data so that it matches the format expected by the Hugging Face model:
- `input_values`: the normalized audio waveform as a list of floats
- `labels`: tokenized transcriptions
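
Since `do_normalize=True`, the feature extractor standardizes each waveform to zero mean and unit variance before storing it in `input_values`. A minimal NumPy sketch of this step (the small `eps` that guards against silent clips is our assumption):

```python
import numpy as np


def zero_mean_unit_var(waveform: np.ndarray, eps: float = 1e-7) -> np.ndarray:
    """Standardize a 1-D waveform to zero mean and (approximately) unit variance."""
    waveform = np.asarray(waveform, dtype=np.float32)
    return (waveform - waveform.mean()) / np.sqrt(waveform.var() + eps)


rng = np.random.default_rng(0)
toy_wave = 0.1 * rng.standard_normal(16_000) + 0.05  # one second of fake audio at 16 kHz
normed = zero_mean_unit_var(toy_wave)
print(round(float(normed.mean()), 4), round(float(normed.std()), 4))
```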

In [15]:
def prepare_dataset(batch):
    audio = batch["audio"]

    # feature extractor: raw waveform -> normalized input_values
    batch["input_values"] = processor(audio["array"], sampling_rate=audio["sampling_rate"]).input_values[0]

    # tokenizer: transcription -> label ids (the `text=` keyword routes the input to the tokenizer)
    batch["labels"] = processor(text=batch["transcription"]).input_ids
    return batch

In [16]:
libri_speech = libri_speech.map(prepare_dataset,
                                remove_columns=libri_speech.column_names["train"],
                                num_proc=1)


Now the dataset has its columns ready to be fed into the pre-trained model.

In [17]:
pd.DataFrame(libri_speech["train"][:3])

,input_values,labels
0,"[-0.30116063356399536, -0.2587476074695587, -0...","[4, 9, 2, 17, 21, 6, 19, 0, 16, 15, 6, 0, 14, ..."
1,"[-0.2367478609085083, -0.18393221497535706, -0...","[21, 9, 2, 21, 0, 9, 2, 5, 0, 10, 21, 20, 0, 2..."
2,"[0.15846337378025055, 0.23213636875152588, 0.3...","[7, 16, 19, 0, 15, 16, 21, 0, 6, 23, 6, 15, 0,..."


## Training 🥋
To fine-tune Wav2Vec2 we need:
- a data collator: input sequences are padded dynamically, i.e. training samples are padded to the longest sequence in their batch
- an evaluation metric: we define a function computing the Word Error Rate (WER) on the dev set, evaluated at the end of every epoch
- a [model checkpoint](https://huggingface.co/facebook/wav2vec2-base): we load the pre-trained `facebook/wav2vec2-base` checkpoint from the Hugging Face Hub 🤗

The `DataCollatorCTCWithPadding` class was copied from [this repo](https://github.com/huggingface/transformers/blob/9a06b6b11bdfc42eea08fa91d0c737d1863c99e3/examples/research_projects/wav2vec2/run_asr.py#L81). Since audio and text belong to different modalities, they need different padding strategies. The initial loss settings and model hyperparameters were adapted from the official Hugging Face blog post, which can be found [here](https://huggingface.co/blog/fine-tune-wav2vec2-english).

In [18]:
import torch

from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Union

@dataclass
class DataCollatorCTCWithPadding:
    """
    Data collator that will dynamically pad the inputs received.
    Args:
        processor (:class:`~transformers.Wav2Vec2Processor`)
            The processor used for processing the data.
        padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
            Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
            among:
            * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
              sequence if provided).
            * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
              maximum acceptable input length for the model if that argument is not provided.
            * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
              different lengths).
        max_length (:obj:`int`, `optional`):
            Maximum length of the ``input_values`` of the returned list and optionally padding length (see above).
        max_length_labels (:obj:`int`, `optional`):
            Maximum length of the ``labels`` returned list and optionally padding length (see above).
        pad_to_multiple_of (:obj:`int`, `optional`):
            If set will pad the sequence to a multiple of the provided value.
            This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
            7.0 (Volta).
    """

    processor: Wav2Vec2Processor
    padding: Union[bool, str] = True
    max_length: Optional[int] = None
    max_length_labels: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None
    pad_to_multiple_of_labels: Optional[int] = None

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # split inputs and labels since they have to be of different lengths and need
        # different padding methods
        input_features = [{"input_values": feature["input_values"]} for feature in features]
        label_features = [{"input_ids": feature["labels"]} for feature in features]

        batch = self.processor.pad(
            input_features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="pt",
        )
        with self.processor.as_target_processor():
            labels_batch = self.processor.pad(
                label_features,
                padding=self.padding,
                max_length=self.max_length_labels,
                pad_to_multiple_of=self.pad_to_multiple_of_labels,
                return_tensors="pt",
            )

        # replace padding with -100 to ignore loss correctly
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        batch["labels"] = labels

        return batch


In [19]:
data_collator = DataCollatorCTCWithPadding(processor=processor, padding=True)
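
The collator pads the audio with the feature extractor's `padding_value` (0.0) and the labels with the `[PAD]` id, which it then replaces with `-100`; `Wav2Vec2ForCTC` treats negative label values as padding when it computes the target lengths for the CTC loss. The dependency-free sketch below (a hypothetical helper, without tensors or the processor) shows the net effect on a toy batch:

```python
def pad_batch(input_values, labels, input_pad=0.0, label_pad=-100):
    """Pad audio and labels separately, each to the longest item of its own kind."""
    max_input_len = max(len(x) for x in input_values)
    max_label_len = max(len(y) for y in labels)
    padded_inputs = [list(x) + [input_pad] * (max_input_len - len(x)) for x in input_values]
    padded_labels = [list(y) + [label_pad] * (max_label_len - len(y)) for y in labels]
    return padded_inputs, padded_labels


inputs, labels = pad_batch([[0.1, 0.2, 0.3], [0.4]], [[4, 9], [2, 17, 21]])
print(inputs)  # [[0.1, 0.2, 0.3], [0.4, 0.0, 0.0]]
print(labels)  # [[4, 9, -100], [2, 17, 21]]
```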

Metrics - Word Error Rate
- we use greedy decoding, i.e. we take the argmax of each logit vector (the most probable token per frame)
- we convert the encoded labels back to strings, after replacing `-100` with `pad_token_id`

In [20]:
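# `load_metric` is deprecated (and removed in `datasets` 3.x); `evaluate.load("wer")` is the replacement
wer_metric = load_metric("wer", trust_remote_code=True)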


In [22]:
def compute_metrics(pred):
    pred_logits = pred.predictions
    pred_ids = np.argmax(pred_logits, axis=-1)

    # substitute -100 with pad token
    pred.label_ids[pred.label_ids == -100] = processor.tokenizer.pad_token_id

    # get predicted and GT transcriptions
    pred_str = processor.batch_decode(pred_ids)
    label_str = processor.batch_decode(pred.label_ids, group_tokens=False)

    wer = wer_metric.compute(predictions=pred_str, references=label_str)

    return {"wer": wer}
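
`compute_metrics` delegates the actual scoring to the `wer` metric (which relies on `jiwer`): the WER is the minimum number of word substitutions, deletions and insertions turning the prediction into the reference, summed over all pairs and divided by the total number of reference words. A minimal pure-Python sketch of this definition, not the library implementation:

```python
def word_error_rate(predictions, references):
    """Corpus-level WER: total word edits divided by total reference words."""
    errors, total = 0, 0
    for pred, ref in zip(predictions, references):
        hyp_words, ref_words = pred.split(), ref.split()
        # prev_row[j] = edit distance between the reference prefix seen so far and hyp_words[:j]
        prev_row = list(range(len(hyp_words) + 1))
        for i, r in enumerate(ref_words, start=1):
            row = [i] + [0] * len(hyp_words)
            for j, h in enumerate(hyp_words, start=1):
                row[j] = min(prev_row[j] + 1,             # deletion of reference word r
                             row[j - 1] + 1,              # insertion of hypothesis word h
                             prev_row[j - 1] + (r != h))  # substitution or match
            prev_row = row
        errors += prev_row[-1]
        total += len(ref_words)
    return errors / total


print(word_error_rate(["they worchiped god they did not exist"],
                      ["they worshipped gods that did not exist"]))
```

On this pair, taken from row 7 of the results table further down, it counts 3 substitutions over 7 reference words, i.e. a WER of about 0.43.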

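We are now ready to load the checkpoint from the Hugging Face Hub. With `ctc_loss_reduction="mean"`, the CTC loss is averaged over the batch samples instead of summed (the default), which helps stabilize training. We also pass the tokenizer's `pad_token_id`, which the model uses as the CTC blank token.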

In [23]:
model = Wav2Vec2ForCTC.from_pretrained(
    "facebook/wav2vec2-base",
    ctc_loss_reduction="mean",
    pad_token_id=processor.tokenizer.pad_token_id,
)

Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at facebook/wav2vec2-base and are newly initialized: ['lm_head.bias', 'lm_head.weight', 'wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original0', 'wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original1']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
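
The `ctc_loss_reduction="mean"` argument passed above is forwarded to the `reduction` argument of PyTorch's `torch.nn.functional.ctc_loss`, with `pad_token_id` as the blank index. With `"mean"`, each sample's loss is divided by its target length and the result is averaged over the batch, so long transcriptions do not dominate. The toy example below (random log-probabilities and targets, not real model outputs) checks this relation numerically:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
T, N, C = 50, 2, 30  # frames, batch size, vocabulary size (toy values)
blank_id = C - 1  # toy choice: the last id plays the role of [PAD], used as the CTC blank

log_probs = torch.randn(T, N, C).log_softmax(dim=-1)  # (time, batch, classes)
targets = torch.randint(low=0, high=blank_id, size=(N, 12))  # padded targets, never the blank
input_lengths = torch.full((N,), T, dtype=torch.long)
target_lengths = torch.tensor([12, 7])


def ctc(reduction):
    return F.ctc_loss(log_probs, targets, input_lengths, target_lengths,
                      blank=blank_id, reduction=reduction)


per_sample = ctc("none")
print(per_sample, ctc("sum"), ctc("mean"))
```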


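We freeze the convolutional feature encoder; everything on top of it (the feature projection, the Transformer layers and the newly initialized `lm_head`) is fine-tuned.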

In [24]:
model.freeze_feature_extractor()
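
`freeze_feature_extractor()` (available as `freeze_feature_encoder()` in recent `transformers` releases, where the old name is deprecated) disables gradients for the convolutional feature encoder only. A simple way to check what is actually trained is to count parameters with `requires_grad=True`. The sketch below uses a hypothetical toy module in place of `Wav2Vec2ForCTC`:

```python
import torch.nn as nn


def count_parameters(model: nn.Module) -> tuple[int, int]:
    """Return (trainable, total) parameter counts."""
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    return trainable, total


# Hypothetical toy stand-in: a conv "feature encoder" and a linear head
toy = nn.ModuleDict({"feature_encoder": nn.Conv1d(1, 4, kernel_size=3),
                     "head": nn.Linear(4, 2)})
for p in toy["feature_encoder"].parameters():
    p.requires_grad = False  # freezing, parameter by parameter

print(count_parameters(toy))  # (10, 26)
```

On the real model, `count_parameters(model)` should report fewer trainable than total parameters once the encoder is frozen.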



In [25]:
training_args = TrainingArguments(
    output_dir="wav2vec2-base",
    group_by_length=True,
    per_device_train_batch_size=8,
    gradient_accumulation_steps=2,
    evaluation_strategy="epoch",  # renamed `eval_strategy` in recent transformers releases
    save_strategy="epoch",
    logging_strategy="epoch",
    num_train_epochs=5,
    fp16=True,
    gradient_checkpointing=True,
    learning_rate=1e-4,
    lr_scheduler_type="linear",
    weight_decay=0.005,
    warmup_ratio=0.3,
    save_total_limit=2,
)



In [26]:
trainer = Trainer(
    model=model,
    data_collator=data_collator,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=libri_speech["train"],
    eval_dataset=libri_speech["test"],
    tokenizer=processor.feature_extractor,  # `processing_class=` in recent transformers releases
)

In [27]:
trainer.train()



Epoch,Training Loss,Validation Loss,Wer
0,No log,3.173289,1.0
2,2.470300,0.423687,0.284567
4,2.470300,0.628218,0.21854




TrainOutput(global_step=890, training_loss=1.4850337424974762, metrics={'train_runtime': 3517.7819, 'train_samples_per_second': 4.055, 'train_steps_per_second': 0.253, 'total_flos': 1.5962668111454285e+18, 'train_loss': 1.4850337424974762, 'epoch': 4.985994397759104})
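
The `global_step=890` reported above can be reproduced by hand: with 2853 training examples, a per-device batch size of 8 and 2 gradient-accumulation steps (effective batch size 16), each epoch has 357 batches and 178 optimizer updates. The function below is a sketch of this arithmetic under our assumptions (a single GPU, no dropped last batch); it approximates, rather than reproduces, the `Trainer` internals. With `warmup_ratio=0.3`, about 267 of the 890 steps are warmup.

```python
import math


def total_optimizer_steps(num_examples, per_device_batch_size, grad_accum_steps, epochs, num_devices=1):
    """Estimate of the number of optimizer updates (assumes the last partial batch is kept)."""
    batches_per_epoch = math.ceil(num_examples / (per_device_batch_size * num_devices))
    updates_per_epoch = max(batches_per_epoch // grad_accum_steps, 1)
    return updates_per_epoch * epochs


steps = total_optimizer_steps(2853, per_device_batch_size=8, grad_accum_steps=2, epochs=5)
print(steps)  # 890
```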

## Evaluation
We now compare the predicted and ground truth transcriptions, and inspect possible mistakes.

In [28]:
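def map_to_result(batch):
    model.eval()  # disable dropout for inference
    with torch.no_grad():
        input_values = torch.tensor(batch["input_values"], device=model.device).unsqueeze(0)
        logits = model(input_values).logits

    # greedy CTC decoding: most probable token per frame
    pred_ids = torch.argmax(logits, dim=-1)
    batch["pred_str"] = processor.batch_decode(pred_ids)[0]
    # labels are plain token sequences, so consecutive repeats must not be merged
    batch["text"] = processor.decode(batch["labels"], group_tokens=False)

    return batch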

We now load a benchmark split.

In [None]:
def get_eval_set(url):
    valid_splits = {"dev-clean", "dev-other", "test-clean", "test-other"}
    if url not in valid_splits:
        # fail loudly instead of returning an error string that would break later cells
        raise ValueError(f"Wrong split {url!r}, expected one of {sorted(valid_splits)}")
    print(f"Downloading {url}...")
    eval_set = torchaudio.datasets.LIBRISPEECH(folder_name, url=url, download=True)
    hf_eval_set = HFDataset.from_list(get_data(eval_set, 1.0))
    hf_eval_set = hf_eval_set.cast_column("audio", Audio(sampling_rate=16_000))
    hf_eval_set = hf_eval_set.map(normalize)
    hf_eval_set = hf_eval_set.map(prepare_dataset,
                                  remove_columns=hf_eval_set.column_names,
                                  num_proc=1)
    return hf_eval_set

eval_set = get_eval_set("test-clean")

Downloading test-clean...



In [32]:
results = eval_set.map(map_to_result, remove_columns=eval_set.column_names)
print("Test WER: {:.4f}".format(wer_metric.compute(predictions=results["pred_str"], references=results["text"])))


Test WER: 0.2782


In [34]:
pd.DataFrame(results[:10])

,pred_str,text
0,as i approached the citty i heard bells ringin...,as i approached the city i heard bells ringing...
1,looking about me i saw a gentleman in a neet b...,looking about me i saw a gentleman in a neat b...
2,he must have realli'ed i was a stranger and wi...,he must have realized i was a stranger and wis...
3,we gayed for a moment lightly into each other'...,we gazed for a moment silently into each other...
4,of course you are going there to i said to my ...,of course you are going there too i said to my...
5,yes he answered i conduct the worchip i am a p...,yes he answered i conduct the worship i am a p...
6,an idl i whispered taken by surprise,an idol i whispered taken by surprise
7,they worchiped god they did not exist,they worshipped gods that did not exist
8,but the greks loved their gods i protested my ...,but the greeks loved their gods i protested my...
9,no i said in a low voice,no i said in a low voice
