# Automatic Speech Recognition (ASR) Tutorial 

## Fine-tune a pretrained, multilingual ASR model on FLEURS

In this tutorial, we will be evaluating and improving a multilingual ASR model for a language in the FLEURS dataset. We will focus on **Hausa**, but you can follow along in any language in FLEURS. See the FLEURS dataset paper for a list of supported languages: https://arxiv.org/abs/2205.12446.

We will be looking at three major open-source multilingual ASR models:
* XLS-R  (https://arxiv.org/abs/2111.09296)
* Whisper (https://cdn.openai.com/papers/whisper.pdf)
* MMS (https://scontent-sjc3-1.xx.fbcdn.net/v/t39.8562-6/348827959_6967534189927933_6819186233244071998_n.pdf?_nc_cat=104&ccb=1-7&_nc_sid=ad8a9d&_nc_ohc=-JOSFMsFL-UAX-4O6o4&_nc_ht=scontent-sjc3-1.xx&oh=00_AfDdMFq0DP2xIRyjWpGrmIpqncnouiylLfWnFsAgxboLWw&oe=6497E242)

## Before you start: Setting up your coding environment 

You can follow along by running the code in this notebook, and you can also use the scripts in this GitHub repository. Before starting, create a virtual environment for this project so that you can install all the required packages without affecting your other projects. We recommend using Anaconda (conda). We have provided an `environment.yml` file that creates a virtual environment named `asr` containing all the required packages. In your terminal, run:

```
git clone https://github.com/kashrest/lrl-asr-experiments.git
cd lrl-asr-experiments
conda env create -f environment.yml 
conda activate asr
```

Now you can run the lines in this notebook.

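Note: The pretrained multilingual ASR models used in this notebook are large. For practical use, we recommend a GPU with roughly 40 GB of memory or more; measured memory use for specific batch sizes is noted alongside the code later in the notebook.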

## Data Preprocessing

The first step is to download and prepare the data for the ASR model. The Hugging Face `datasets` library makes it easy to download FLEURS data for any supported language and to specify which split to load.

In [1]:
import os 
from datasets import load_dataset

cache_dir_fleurs = "./data/fleurs/"

# create a data directory for caching (including any missing parent directories)
os.makedirs(cache_dir_fleurs, exist_ok=True)

# create a directory for outputs
out_dir = "./tutorial/"
os.makedirs(out_dir, exist_ok=True)


# for Hausa, the language code is "ha_ng"
train_data = load_dataset("google/fleurs", "ha_ng", split="train", cache_dir=cache_dir_fleurs)
val_data = load_dataset("google/fleurs", "ha_ng", split="validation", cache_dir=cache_dir_fleurs)
test_data = load_dataset("google/fleurs", "ha_ng", split="test", cache_dir=cache_dir_fleurs)

Found cached dataset fleurs (/data/users/kashrest/lrl-asr-experiments/data/fleurs/google___fleurs/ha_ng/2.0.0/af82dbec419a815084fa63ebd5d5a9f24a6e9acdf9887b9e3b8c6bbd64e0b7ac)
Found cached dataset fleurs (/data/users/kashrest/lrl-asr-experiments/data/fleurs/google___fleurs/ha_ng/2.0.0/af82dbec419a815084fa63ebd5d5a9f24a6e9acdf9887b9e3b8c6bbd64e0b7ac)
Found cached dataset fleurs (/data/users/kashrest/lrl-asr-experiments/data/fleurs/google___fleurs/ha_ng/2.0.0/af82dbec419a815084fa63ebd5d5a9f24a6e9acdf9887b9e3b8c6bbd64e0b7ac)


FLEURS data is organized like so:

In [2]:
train_data[0]

{'id': 302,
 'num_samples': 301440,
 'path': '/data/users/kashrest/asr-experiments/downloads/extracted/953575415c600d2042020b042380119265aefaa4fcf95afc300adfcd2b79784e/10002175198254707815.wav',
 'audio': {'path': 'train/10002175198254707815.wav',
  'array': array([0.00000000e+00, 0.00000000e+00, 0.00000000e+00, ...,
         6.78300858e-05, 1.26361847e-05, 6.46114349e-05]),
  'sampling_rate': 16000},
 'transcription': 'nasarorin da vautier ta samu marasa alaka da bada umarni ba sun hada da yajin cin abinci a 1973 a kan abin da ya ke ganin dabaibayin siyasa ne',
 'raw_transcription': 'Nasarorin da Vautier ta samu marasa alaka da bada umarni ba sun hada da yajin cin abinci a 1973 a kan abin da ya ke ganin dabaibayin siyasa ne.',
 'gender': 1,
 'lang_id': 30,
 'language': 'Hausa',
 'lang_group_id': 3}

We are interested in the audio and the corresponding transcript. The audio is represented as an array of floats, each describing the amplitude of the sound wave at one instant. The number of floats equals the clip duration multiplied by the sampling rate, which here is 16,000 measurements per second.

In [3]:
train_transcripts, val_transcripts, test_transcripts = [], [], []
train_audio, val_audio, test_audio = [], [], []

for elem in train_data:
    assert elem["audio"]["sampling_rate"] == 16000
    train_audio.append(elem["audio"]["array"])
    train_transcripts.append(elem["raw_transcription"])
    
for elem in val_data:
    assert elem["audio"]["sampling_rate"] == 16000
    val_audio.append(elem["audio"]["array"])
    val_transcripts.append(elem["raw_transcription"])
    
for elem in test_data:
    assert elem["audio"]["sampling_rate"] == 16000
    test_audio.append(elem["audio"]["array"])
    test_transcripts.append(elem["raw_transcription"])
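
The relationship between array length, sampling rate, and clip duration can be checked with a little arithmetic. The sketch below uses the `num_samples` and `sampling_rate` values from the `train_data[0]` example shown earlier; the helper name `duration_seconds` is ours and is not part of the `datasets` API.

```python
# Values copied from the `train_data[0]` example shown earlier
num_samples = 301440    # number of floats in the audio array
sampling_rate = 16000   # measurements per second

def duration_seconds(num_samples, sampling_rate):
    """Return the clip length in seconds (hypothetical helper for illustration)."""
    return num_samples / sampling_rate

print(duration_seconds(num_samples, sampling_rate))  # 18.84
```

So this training example is just under 19 seconds long.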

Now, since we are interested in transcribing speech, we want to clean the transcripts by removing special characters that do not have a clear sound (such as `!` and `'`). This step may depend on your target application and language. For example, many native Hausa speakers do not speak English and the data does not contain much code-switching, so we also want to normalize foreign characters (ç, ş) and remove symbols (% & $).

In [4]:
import re

def preprocess_texts_hausa(transcriptions):
    chars_to_remove_regex = '[\,\?\!\-\;\:\"\“\%\‘\'\ʻ\”\�\$\&\(\)\–\—]'

    def _remove_special_characters(transcription):
        transcription = transcription.strip()
        transcription = transcription.lower()
        transcription = re.sub(chars_to_remove_regex, '', transcription)
        # remove square and curly brackets (a character class, so each bracket is matched on its own)
        transcription = re.sub(r'[\[\]\{\}]', '', transcription)
        transcription = re.sub(r'[\\]', '', transcription)
        transcription = re.sub(r'[/]', '', transcription)
        transcription = re.sub(u'[¥£°¾½²]', '', transcription)
        transcription = re.sub(u'[\+><]', '', transcription)
        return transcription

    def _normalize_diacritics(transcription):
        a = '[āăáã]'
        u = '[ūúü]'
        o = '[öõó]' 
        c = '[ç]'
        i = '[í]'
        s = '[ş]'
        e = '[é]'

        transcription = re.sub(a, "a", transcription)
        transcription = re.sub(u, "u", transcription)
        transcription = re.sub(o, "o", transcription)
        transcription = re.sub(c, "c", transcription)
        transcription = re.sub(i, "i", transcription)
        transcription = re.sub(s, "s", transcription)
        transcription = re.sub(e, "e", transcription)

        return transcription

    cleaned_transcriptions = map(_remove_special_characters, transcriptions)
    cleaned_transcriptions = list(map(_normalize_diacritics, list(cleaned_transcriptions)))
    return cleaned_transcriptions

train_transcripts = preprocess_texts_hausa(train_transcripts)
val_transcripts = preprocess_texts_hausa(val_transcripts)
test_transcripts = preprocess_texts_hausa(test_transcripts)
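
As a point of comparison, diacritics can also be stripped with Unicode decomposition instead of hand-written character classes. This is an alternative technique, not the method used in this tutorial, and it is broader: it removes every combining mark, not just the ones listed in `_normalize_diacritics`. The function name `strip_diacritics` is ours. A minimal sketch using only the standard library:

```python
import unicodedata

def strip_diacritics(text):
    """Remove combining marks after canonical decomposition (NFD)."""
    decomposed = unicodedata.normalize("NFD", text)
    stripped = "".join(ch for ch in decomposed if not unicodedata.combining(ch))
    return unicodedata.normalize("NFC", stripped)

print(strip_diacritics("çşéü"))   # foreign letters lose their marks
print(strip_diacritics("ɓ ɗ ƙ"))  # Hausa hooked letters have no decomposition and are kept
```

Whether the broader behavior is desirable depends on the language: letters that are distinct phonemes should not be folded into their base letter.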

Some models predict one character at a time, so we need a character vocabulary made up of all characters in the dataset after preprocessing. We can save the vocabulary in a JSON file.

In [5]:
import json 

def extract_all_chars(transcription):
      all_text = " ".join(transcription)
      vocab = list(set(all_text))
      return {"vocab": [vocab], "all_text": [all_text]}

vocab_train = list(map(extract_all_chars, train_transcripts))
vocab_val = list(map(extract_all_chars, val_transcripts))
vocab_test = list(map(extract_all_chars, test_transcripts))

vocab_train_chars = []
for elem in [elem["vocab"][0] for elem in vocab_train]:
    vocab_train_chars.extend(elem)

vocab_val_chars = []
for elem in [elem["vocab"][0] for elem in vocab_val]:
    vocab_val_chars.extend(elem)

vocab_test_chars = []
for elem in [elem["vocab"][0] for elem in vocab_test]:
    vocab_test_chars.extend(elem)

vocab_list = list(set(vocab_train_chars) | set(vocab_val_chars) | set(vocab_test_chars))
vocab_dict = {v: k for k, v in enumerate(vocab_list)}

# for word delimiter, change " " --> "|" (ex. "Hello my name is Bob" --> "Hello|my|name|is|Bob")
vocab_dict["|"] = vocab_dict[" "]
del vocab_dict[" "]
vocab_dict["[UNK]"] = len(vocab_dict)
vocab_dict["[PAD]"] = len(vocab_dict) # this is for models (like MMS and XLS-R) that use the CTC algorithm to predict the end of a character (e.g. "hhh[PAD]iii[PAD]iii[PAD]" == "hii")

vocab_file = out_dir+"vocab_hausa.json"
with open(vocab_file, 'w') as f:
    json.dump(vocab_dict, f)
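
To make the role of `[PAD]` and `|` concrete, here is a minimal sketch of how a frame-by-frame CTC prediction is collapsed into text: repeated tokens are merged, the blank (`[PAD]`) token is dropped, and the word delimiter is turned back into a space. The function `ctc_greedy_collapse` is a simplified illustration of the idea, not the tokenizer's actual implementation.

```python
from itertools import groupby

def ctc_greedy_collapse(tokens, blank="[PAD]", word_delimiter="|"):
    """Collapse a frame-level token sequence into text, CTC style."""
    # 1) merge runs of identical consecutive tokens
    merged = [token for token, _ in groupby(tokens)]
    # 2) drop the blank token, which only separates repeated characters
    chars = [token for token in merged if token != blank]
    # 3) turn the word delimiter back into a space
    return "".join(chars).replace(word_delimiter, " ")

frames = ["h", "h", "h", "[PAD]", "i", "i", "i", "[PAD]", "i", "i", "i", "[PAD]"]
print(ctc_greedy_collapse(frames))  # hii
```

Without the blank between the two runs of `i`, the output would be `hi`, which is why CTC models need the extra token.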

## Evaluation code

In ASR, word error rate (WER) and character error rate (CER) are the common metrics used to evaluate how good a model-produced transcript is in comparison to the gold transcript.

In [6]:
import evaluate

def compute_metrics(label_strs, pred_strs):
    # `datasets.load_metric` is deprecated; the `evaluate` library provides the same metrics
    wer_metric = evaluate.load("wer")
    cer_metric = evaluate.load("cer")
    # make sure labels are preprocessed the same way for proper comparison with other models after finetuning
    wer = wer_metric.compute(predictions=pred_strs, references=label_strs) * 100
    cer = cer_metric.compute(predictions=pred_strs, references=label_strs) * 100
    return {"wer": wer, "cer": cer}
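
Both metrics are edit distances normalized by the length of the reference: WER counts word-level substitutions, insertions, and deletions, and CER does the same at the character level. The sketch below computes them for a single sentence pair in plain Python. It illustrates the definition only; the metric libraries aggregate over the whole corpus. The hypothesis sentence is a made-up model output.

```python
def edit_distance(ref, hyp):
    """Levenshtein distance between two sequences (substitutions, insertions, deletions)."""
    prev = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, start=1):
        curr = [i]
        for j, h in enumerate(hyp, start=1):
            cost = 0 if r == h else 1
            curr.append(min(prev[j] + 1,          # deletion
                            curr[j - 1] + 1,      # insertion
                            prev[j - 1] + cost))  # substitution or match
        prev = curr
    return prev[-1]

def word_error_rate(reference, hypothesis):
    ref_words = reference.split()
    return edit_distance(ref_words, hypothesis.split()) / len(ref_words)

def char_error_rate(reference, hypothesis):
    return edit_distance(reference, hypothesis) / len(reference)

reference = "a kan abin da"
hypothesis = "a kan abun da"  # hypothetical model output with one wrong character
print(word_error_rate(reference, hypothesis))  # 0.25 (1 wrong word out of 4)
print(char_error_rate(reference, hypothesis))  # 1 wrong character out of 13
```

A single wrong character costs a whole word in WER but only one character in CER, which is why WER is usually much higher than CER.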

## Section A: Zero-Shot ASR

Let's run inference on our audio dataset with different existing, ready-to-use pretrained and finetuned models. Later, when we compare model performance, we will compare performance on the test set since some models will be trained on the train split.

### OpenAI Whisper

OpenAI's Whisper model is a pretrained encoder-decoder model that supports a set of languages without further fine-tuning. Here, we will use whisper-large-v2 so that we can compare it with MMS-1b-all (both are in the 1-billion-parameter range; Whisper large has about 1.5 billion). If you do not have enough GPU memory, you can use a smaller checkpoint (found on the Hugging Face Hub, e.g. https://huggingface.co/openai/whisper-small) or decrease the batch size.

Note: Whisper requires input sampled at 16,000 Hz. FLEURS data is already sampled at this rate, so no resampling is needed. Also, Whisper does not support all FLEURS languages (the 96 languages it supports are listed in the paper).

In [7]:
from transformers import WhisperProcessor, WhisperForConditionalGeneration
from tqdm import tqdm
import time

With a batch size of 10, inference takes about 5 minutes.

In [8]:
device = "cuda:0" # use a GPU if you have access to one, otherwise set to "cpu"
model_id = "openai/whisper-large-v2" # 8630MiB for batch size 1, ~25426MiB for batch size 10
processor = WhisperProcessor.from_pretrained(model_id)
model = WhisperForConditionalGeneration.from_pretrained(model_id).to(device)
forced_decoder_ids = processor.get_decoder_prompt_ids(language="hausa", task="transcribe")

predicted_test_transcripts = []

batch_size = 10 # decrease if needed

for i in tqdm(range(0, len(test_audio), batch_size)):
    batch = test_audio[i:i+batch_size]  # slicing past the end of a list is safe, so the last batch is simply shorter
    input_features = processor(batch, sampling_rate=16000, return_tensors="pt").input_features.to(device)
    # generate token ids
    predicted_ids = model.generate(input_features, forced_decoder_ids=forced_decoder_ids)
    # decode token ids to text
    predicted_test_transcripts.extend(processor.batch_decode(predicted_ids, skip_special_tokens=True))

100%|████████████████████████████████████████████████████████████████████████████████| 63/63 [05:19<00:00,  5.07s/it]


Let's evaluate the performance of whisper-large-v2 on our preprocessed test dataset.

In [11]:
compute_metrics(test_transcripts, predicted_test_transcripts)


{'wer': 97.78959507944643, 'cer': 40.52395399540877}

It looks like we have a 97.8% WER and 40.5% CER. Let's create a running table of the performances of different models on our test dataset. 

| Model | WER % | CER %|
|-------|-----|----|
|whisper-large-v2| 97.8| 40.5|

### Facebook MMS

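Note: in our run, inference with a batch size of 10 took approximately 1 hour. MMS inference is much slower when the model runs on the CPU, so move the model and inputs to a GPU if you have one.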

In [36]:
import torch
from transformers import Wav2Vec2ForCTC, AutoProcessor

model_id = "facebook/mms-1b-all"

processor = AutoProcessor.from_pretrained(model_id)
model = Wav2Vec2ForCTC.from_pretrained(model_id)

# load the Hausa vocabulary and adapter weights
processor.tokenizer.set_target_lang("hau")
model.load_adapter("hau")

# move the model to the GPU (`device` was set in the Whisper cell); inference on CPU is much slower
model = model.to(device)
model.eval()

predicted_test_transcripts = []

# batched alternative (inputs are padded to the longest utterance in each batch):
# batch_size = 10
# for i in tqdm(range(0, len(test_audio), batch_size)):
#     batch = test_audio[i:i+batch_size]
#     inputs = processor(batch, sampling_rate=16_000, return_tensors="pt", padding=True).to(device)
#     with torch.no_grad():
#         outputs = model(**inputs).logits
#     ids = torch.argmax(outputs, dim=-1)
#     predicted_test_transcripts.extend(processor.batch_decode(ids))

# transcribe one utterance at a time
for elem in tqdm(test_audio):
    inputs = processor(elem, sampling_rate=16_000, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model(**inputs).logits
    ids = torch.argmax(outputs, dim=-1)[0]
    predicted_test_transcripts.append(processor.decode(ids))


In [None]:
compute_metrics(test_transcripts, predicted_test_transcripts)

It looks like we have a 29.3% WER and 7.7% CER. Let's add this result to our table. This is substantially better than whisper-large!

| Model | WER % | CER %|
|-------|-----|----|
|whisper-large-v2| 97.8| 40.5|
|mms-1b-all|29.3|7.7|

## Section B: Finetuning

For fine-tuning, we will be using functions from the Hugging Face API for the training loop and model setup. In order to use these functions, we need to wrap the data in a custom PyTorch Dataset object.

In [53]:
import torch
class ASRDatasetWav2Vec2(torch.utils.data.Dataset):
    def __init__(self, audio, transcripts, sampling_rate, processor):
        self.audio = audio
        self.transcripts = transcripts
        self.sampling_rate = sampling_rate
        self.processor = processor
    
    def __getitem__(self, idx):
        input_values = self.processor.feature_extractor(self.audio[idx], sampling_rate=self.sampling_rate).input_values[0]
        labels = self.processor.tokenizer(self.transcripts[idx]).input_ids
        item = {}
        item["input_values"] = input_values
        item["labels"] = labels
        
        return item

    def __len__(self):
        return len(self.transcripts)

In [37]:
class ASRDatasetWhisper(torch.utils.data.Dataset):
    def __init__(self, audio, transcripts, sampling_rate, processor):
        self.audio = audio
        self.transcripts = transcripts
        self.sampling_rate = sampling_rate
        self.processor = processor
    
    def __getitem__(self, idx):
        input_values = self.processor.feature_extractor(self.audio[idx], sampling_rate=self.sampling_rate).input_features[0]
        labels = self.processor.tokenizer(self.transcripts[idx]).input_ids
        item = {}
        item["input_features"] = input_values
        item["labels"] = labels
        
        return item

    def __len__(self):
        return len(self.transcripts)

### Whisper

First, let's import the required Whisper classes and training loop functions from Hugging Face, along with some other utility functions.

Note: Hugging Face has a great tutorial that we referenced for fine-tuning Whisper. You can refer to this tutorial for more information if needed: https://huggingface.co/blog/fine-tune-whisper#prepare-feature-extractor-tokenizer-and-data

In [11]:
from transformers import WhisperFeatureExtractor, WhisperTokenizer, WhisperProcessor, WhisperForConditionalGeneration, Seq2SeqTrainingArguments, Seq2SeqTrainer
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union
import evaluate

Let's fine-tune a whisper-large-v2 model, which we will download from Hugging Face, and set up a WhisperProcessor object, which contains a feature extractor and a tokenizer. The feature extractor transforms the input into log-Mel spectrograms. This transformation takes the amplitude information represented by the input array and converts it into frequency information over time (refer to the Hugging Face tutorial for more information). Frequencies encode pitch and other properties that distinguish speech sounds, which makes them a useful signal for speech recognition. Additionally, the tokenizer splits the transcripts into tokens based on Whisper's vocabulary. Whisper uses byte-level BPE, the same kind of tokenizer as GPT-2. If interested, refer to this page: https://huggingface.co/learn/nlp-course/chapter6/5?fw=pt

In [27]:
model_card = "openai/whisper-large-v2"
processor = WhisperProcessor.from_pretrained(model_card, language="Hausa", task="transcribe")
model = WhisperForConditionalGeneration.from_pretrained(model_card)
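
To give some intuition for the amplitude-to-frequency step behind a log-Mel spectrogram, the sketch below runs a naive discrete Fourier transform over one 25 ms window (400 samples at 16 kHz) of a synthetic 440 Hz tone and finds the strongest frequency on a log scale. This is a simplified illustration with assumptions of our own (a pure test tone, no window function, no Mel filterbank); it is not Whisper's feature extractor.

```python
import cmath
import math

sampling_rate = 16000  # Hz, as in FLEURS
n = 400                # one 25 ms analysis window
tone_hz = 440.0        # hypothetical test tone

# amplitude samples of a pure tone
samples = [math.sin(2 * math.pi * tone_hz * t / sampling_rate) for t in range(n)]

def dft_power(x):
    """Power of each non-negative frequency bin of a naive discrete Fourier transform."""
    size = len(x)
    powers = []
    for k in range(size // 2 + 1):
        coeff = sum(x[t] * cmath.exp(-2j * math.pi * k * t / size) for t in range(size))
        powers.append(abs(coeff) ** 2)
    return powers

powers = dft_power(samples)
log_powers = [math.log10(p + 1e-10) for p in powers]  # log scale, as in a log-Mel spectrogram
peak_bin = max(range(len(log_powers)), key=log_powers.__getitem__)
peak_hz = peak_bin * sampling_rate / n  # each bin spans sampling_rate / n = 40 Hz
print(peak_bin, peak_hz)  # 11 440.0
```

A spectrogram repeats this analysis over many overlapping windows, so the model sees how the frequency content changes over time.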

The following lines are required to fine-tune the Whisper model. The first line sets the token ids that control the transcription language and task to `None`, so that no tokens are forced and the model predicts the language and task itself.

The second line makes sure that all possible tokens can be predicted by setting the list of suppressed tokens to an empty list.

In [28]:
model.config.forced_decoder_ids = None
model.config.suppress_tokens = []

As mentioned before, Whisper takes inputs sampled at 16,000 Hz, so we will prepare our data using this sampling rate (FLEURS is already sampled at 16 kHz) with the ASRDatasetWhisper class defined above.

In [38]:
model_sampling_rate = 16000
train_dataset = ASRDatasetWhisper(train_audio, train_transcripts, model_sampling_rate, processor)
val_dataset = ASRDatasetWhisper(val_audio,  val_transcripts, model_sampling_rate, processor)
test_dataset = ASRDatasetWhisper(test_audio, test_transcripts, model_sampling_rate, processor)

Next, we need a function that will pad all the inputs/outputs in a batch to the same length. This code is from the tutorial mentioned earlier.

In [30]:
@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    processor: Any

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # split inputs and labels since they have to be of different lengths and need different padding methods
        # first treat the audio inputs by simply returning torch tensors
        input_features = [{"input_features": feature["input_features"]} for feature in features]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")

        # get the tokenized label sequences
        label_features = [{"input_ids": feature["labels"]} for feature in features]
        # pad the labels to max length
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")

        # replace padding with -100 to ignore loss correctly
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        # if the bos token was prepended in the previous tokenization step,
        # cut it here since it is added again later anyway
        if (labels[:, 0] == self.processor.tokenizer.bos_token_id).all().cpu().item():
            labels = labels[:, 1:]

        batch["labels"] = labels

        return batch
    
data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)
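
The label handling in the collator can be illustrated without any tensors. In the sketch below, label sequences of different lengths are padded to the longest one, and the padded positions are then replaced with -100, the index that PyTorch's cross-entropy loss ignores by default. The function `pad_labels` and the token ids are hypothetical and only mirror the idea.

```python
def pad_labels(label_sequences, pad_token_id, ignore_index=-100):
    """Pad label id lists to equal length, then mask the padding with ignore_index."""
    max_len = max(len(seq) for seq in label_sequences)
    padded, masked = [], []
    for seq in label_sequences:
        n_pad = max_len - len(seq)
        padded.append(seq + [pad_token_id] * n_pad)
        masked.append(seq + [ignore_index] * n_pad)
    return padded, masked

# hypothetical token ids for two transcripts of different lengths
labels = [[7, 3, 9], [5, 2]]
padded, masked = pad_labels(labels, pad_token_id=0)
print(padded)  # [[7, 3, 9], [5, 2, 0]]
print(masked)  # [[7, 3, 9], [5, 2, -100]]
```

Masking matters because otherwise the model would be rewarded for predicting padding tokens, which says nothing about transcription quality.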

In order to use the Trainer class from Hugging Face, we need to define an evaluation function that takes in a model prediction object.

In [31]:
def compute_metrics(pred):
    metric_wer = evaluate.load("wer")
    metric_cer = evaluate.load("cer")
    
    pred_ids = pred.predictions
    label_ids = pred.label_ids

    # replace -100 with the pad_token_id
    label_ids[label_ids == -100] = processor.tokenizer.pad_token_id

    # we do not want to group tokens when computing the metrics
    pred_str = processor.tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = processor.tokenizer.batch_decode(label_ids, skip_special_tokens=True)

    wer = 100 * metric_wer.compute(predictions=pred_str, references=label_str)
    cer = 100 * metric_cer.compute(predictions=pred_str, references=label_str)

    return {"wer": wer, "cer": cer}

Now, we will set up the model training hyperparameters using Hugging Face's Seq2SeqTrainingArguments. Feel free to experiment with different hyperparameters; learning rate is an especially important one. See the official Seq2SeqTrainingArguments documentation for explanations of the hyperparameters: https://huggingface.co/docs/transformers/main_classes/trainer#transformers.Seq2SeqTrainingArguments

Note: Decrease the batch size if you have limited GPU memory. We have also enabled mixed-precision training (`fp16=True`) to reduce memory use.

In [41]:
train_batch_size = 16
num_train_epochs = 3 
learning_rate = 1e-05

training_args = Seq2SeqTrainingArguments(
    output_dir=out_dir+"whisper-finetuning-experiment-2/",  # change to a repo name of your choice
    per_device_train_batch_size=train_batch_size,
    gradient_accumulation_steps=1,  # increase by 2x for every 2x decrease in batch size
    learning_rate=learning_rate,
    warmup_steps=500,
    num_train_epochs=num_train_epochs,
    gradient_checkpointing=True, # saves GPU memory by recomputing activations during the backward pass (less memory, more time)
    fp16=True, # this enables mixed precision training, which lets some data be stored in 16 bit floating point precision instead of 32 bits.
    evaluation_strategy="steps",
    per_device_eval_batch_size=8,
    predict_with_generate=True,
    generation_max_length=225,
    save_steps=100,
    eval_steps=100,
    save_total_limit=2, 
    load_best_model_at_end=True,
    metric_for_best_model="wer",
    greater_is_better=False)
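
The comment on `gradient_accumulation_steps` reflects a simple product: the number of examples behind each optimizer update is the per-device batch size times the accumulation steps (times the number of devices). A minimal sketch, with a helper name of our own, showing lower-memory settings that keep the same effective batch size as the configuration above:

```python
def effective_batch_size(per_device_batch_size, gradient_accumulation_steps, num_devices=1):
    """Number of examples that contribute to each optimizer update."""
    return per_device_batch_size * gradient_accumulation_steps * num_devices

# the setting used above, and two lower-memory settings with the same effective batch size
for batch_size, accumulation in [(16, 1), (8, 2), (4, 4)]:
    print(batch_size, accumulation, effective_batch_size(batch_size, accumulation))
```

Each row yields an effective batch size of 16, so the optimization behaves similarly while peak memory drops with the per-device batch size.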

Now, we set up the Trainer object by passing in our training and validation datasets, our evaluation function, tokenizer, model, data collator, and previously instantiated training arguments.

In [43]:
trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    tokenizer=processor.feature_extractor,
)

Then we call `train` to start training. Training the whisper-large-v2 model with batch size 16 for 3 epochs takes about 1 hour and uses about 56846 MiB (roughly 57 GB) of memory on an NVIDIA A100 GPU.

In [44]:
trainer.train()

| Step | Training Loss | Validation Loss | Wer | Cer |
|------|------|------|------|------|
| 100 | No log | 0.678611 | 47.085015 | 15.846247 |
| 200 | No log | 0.626455 | 53.026297 | 22.505824 |
| 300 | No log | 0.603841 | 48.69904 | 20.497822 |
| 400 | No log | 0.572886 | 42.646445 | 18.659475 |
| 500 | 0.386100 | 0.596419 | 53.90288 | 24.265674 |
| 600 | 0.386100 | 0.549716 | 41.185474 | 18.540464 |


Training finished in 3869.6817395687103 seconds.


Let's see the performance on the FLEURS test set.

In [45]:
preds = trainer.predict(test_dataset)
eval_preds = compute_metrics(preds)
eval_preds

{'wer': 40.395950794464376, 'cer': 19.22895013093082}

It looks like we have a 40.4% WER and 19.2% CER. Great! We have made a substantial improvement after finetuning for 3 epochs. Let's add this result to our table.

| Model | WER % | CER %|
|-------|-----|----|
|whisper-large-v2| 97.8| 40.5|
|mms-1b-all|**29.3**|**7.7**|
|finetuned whisper-large-v2|40.4|19.2|

It looks like mms-1b-all still has the best results. Let's see whether finetuning mms-1b-all further improves on them.

### MMS

MMS-1b-all is a checkpoint of Facebook's MMS (**M**assively **M**ultilingual **S**peech) model, an ASR (Wav2Vec2) model that is pretrained on a large corpus of Bible data covering 1107 languages and then finetuned on additional labeled datasets. We will further fine-tune it on our FLEURS dataset to see whether it can be improved. You can refer to Hugging Face's MMS finetuning blog post for more details and explanations if needed: https://huggingface.co/blog/mms_adapters

MMS-1b-all incorporates adapters: small sets of extra, language-specific parameters inserted throughout the architecture that are trainable during finetuning. This lets the user finetune far fewer parameters than the entire model contains.

Here, we will finetune the MMS adapter weights for Hausa.

First, we will set up the tokenizer based on our previously made character vocabulary, setting special tokens for unknown characters, padding, and word delimiters according to the vocabulary. We need to specify our vocabulary for the specific language of interest in a dictionary so that the MMS-1b-all checkpoint will correctly finetune the adapter weights for Hausa.

In [7]:
from transformers import Wav2Vec2CTCTokenizer, Wav2Vec2FeatureExtractor, Wav2Vec2Processor

target_lang = "hau"

with open(vocab_file, "r") as f:
    vocab_dict = json.load(f)

new_vocab_dict = {target_lang: vocab_dict}

experiment_file = out_dir+"mms-1b-all-finetuning-1/"

os.makedirs(experiment_file, exist_ok=True)

with open(experiment_file+"vocab.json", 'w') as f:
    json.dump(new_vocab_dict, f)
    
tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(experiment_file, unk_token="[UNK]", pad_token="[PAD]", word_delimiter_token="|", target_lang=target_lang)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Then, we will set up the feature extractor, which transforms the input audio into features. Unlike Whisper, MMS takes in the raw audio and simply normalizes the values to zero mean and unit variance.

In [8]:
model_sampling_rate = 16000
feature_extractor = Wav2Vec2FeatureExtractor(feature_size=1, sampling_rate=16000, padding_value=0.0, do_normalize=True, return_attention_mask=True)
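
Zero-mean, unit-variance normalization means subtracting the mean of the waveform and dividing by its standard deviation. The sketch below shows the arithmetic on a short, made-up list of amplitudes. The small `eps` term guards against division by zero on silent clips; its value here is an assumption for illustration rather than a documented setting of the feature extractor.

```python
import math

def zero_mean_unit_variance(samples, eps=1e-7):
    """Shift samples to mean 0 and scale them to (approximately) variance 1."""
    mean = sum(samples) / len(samples)
    variance = sum((s - mean) ** 2 for s in samples) / len(samples)
    return [(s - mean) / math.sqrt(variance + eps) for s in samples]

waveform = [0.0, 0.5, -0.25, 0.75, -1.0]  # hypothetical amplitudes
normalized = zero_mean_unit_variance(waveform)

new_mean = sum(normalized) / len(normalized)
new_variance = sum((s - new_mean) ** 2 for s in normalized) / len(normalized)
print(round(new_mean, 6), round(new_variance, 6))
```

Normalizing each clip this way removes differences in recording volume, so the model does not have to learn them.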

And finally, the processor wraps both the tokenizer and feature extractor into one convenient class.

In [9]:
processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)

Now, we want to create a data collator (similar to the one we made for Whisper) that prepares the input in batches for the model.

In [56]:
import torch

from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Union


@dataclass
class DataCollatorCTCWithPaddingWav2Vec2: # credit: https://huggingface.co/blog/mms_adapters
    """
    Data collator that will dynamically pad the inputs received.
    Args:
        processor (:class:`~transformers.Wav2Vec2Processor`)
            The processor used for processing the data.
        padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
            Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
            among:
            * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
              sequence if provided).
            * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
              maximum acceptable input length for the model if that argument is not provided.
            * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
              different lengths).
    """

    processor: Wav2Vec2Processor
    padding: Union[bool, str] = True

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # split inputs and labels since they have to be of different lengths and need
        # different padding methods
        input_features = [{"input_values": feature["input_values"]} for feature in features]
        label_features = [{"input_ids": feature["labels"]} for feature in features]

        batch = self.processor.pad(
            input_features,
            padding=self.padding,
            return_tensors="pt",
        )

        labels_batch = self.processor.pad(
            labels=label_features,
            padding=self.padding,
            return_tensors="pt",
        )

        # replace padding with -100 to ignore loss correctly
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        batch["labels"] = labels

        return batch
data_collator = DataCollatorCTCWithPaddingWav2Vec2(processor=processor, padding=True)

Now we create an evaluation function.

In [23]:
import numpy as np
def compute_metrics(pred):
    wer_metric = evaluate.load("wer")
    cer_metric = evaluate.load("cer")
    pred_logits = pred.predictions
    pred_ids = np.argmax(pred_logits, axis=-1)

    pred.label_ids[pred.label_ids == -100] = processor.tokenizer.pad_token_id

    pred_str = processor.batch_decode(pred_ids)
    # we do not want to group tokens when computing the metrics
    label_str = processor.batch_decode(pred.label_ids, group_tokens=False)

    wer = wer_metric.compute(predictions=pred_str, references=label_str)
    cer = cer_metric.compute(predictions=pred_str, references=label_str)

    return {"wer": wer, "cer": cer}

Now, we can define the model.

In [13]:
from transformers import Wav2Vec2ForCTC

model_card = "facebook/mms-1b-all"
model = Wav2Vec2ForCTC.from_pretrained(
    model_card,
    attention_dropout=0.0,
    hidden_dropout=0.0,
    feat_proj_dropout=0.0,
    layerdrop=0.0,
    ctc_loss_reduction="mean",
    pad_token_id=processor.tokenizer.pad_token_id,
    vocab_size=len(processor.tokenizer),
    ignore_mismatched_sizes=True,
)

Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at facebook/mms-1b-all and are newly initialized because the shapes did not match:
- lm_head.bias: found shape torch.Size([154]) in the checkpoint and torch.Size([50]) in the model instantiated
- lm_head.weight: found shape torch.Size([154, 1280]) in the checkpoint and torch.Size([50, 1280]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


We re-initialize the adapter layers to prepare for finetuning.

In [14]:
model.init_adapter_layers()

Then we freeze all the parameters (learned from the pretraining and finetuning by the Meta team) except the adapter weights.

In [15]:
model.freeze_base_model()

adapter_weights = model._get_adapters()
for param in adapter_weights.values():
    param.requires_grad = True

Then, we set up the parameters for model training, as we did for Whisper.

In [16]:
train_batch_size = 16 # batch size 2 already takes 32472MiB (32 GB), batch size 16 takes 74049MiB (74 GB)
learning_rate = 1e-3
num_epochs = 3

In [17]:
from transformers import TrainingArguments
training_args = TrainingArguments(
  output_dir=experiment_file,
  group_by_length=True,
  per_device_train_batch_size=train_batch_size,
  evaluation_strategy="steps",
  num_train_epochs=num_epochs,
  gradient_checkpointing=True, # saves GPU memory by recomputing activations during the backward pass (less memory, more time)
  fp16=True, # this enables mixed precision training, which lets some data be stored in 16 bit floating point precision instead of 32 bits.
  save_steps=200,
  eval_steps=100,
  logging_steps=100,
  learning_rate=learning_rate,
  warmup_steps=100,
  save_total_limit=2,
  push_to_hub=False,
)

Then send everything to the Trainer class for training!

In [20]:
# since our processor is different, we will need to create new ASRDataset objects
train_dataset = ASRDatasetWav2Vec2(train_audio, train_transcripts, model_sampling_rate, processor)
val_dataset = ASRDatasetWav2Vec2(val_audio,  val_transcripts, model_sampling_rate, processor)
test_dataset = ASRDatasetWav2Vec2(test_audio, test_transcripts, model_sampling_rate, processor)

In [24]:
from transformers import Trainer

trainer = Trainer(
    model=model,
    data_collator=data_collator,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    tokenizer=processor.feature_extractor,
)

Training 3 epochs with batch size 16 takes about 25 minutes and about 74049MiB (74 GB) on an NVIDIA A100 GPU. 

In [25]:
trainer.train()

| Step | Training Loss | Validation Loss | Wer | Cer |
|------|------|------|------|------|
| 100 | 0.2209 | 0.253368 | 0.287324 | 0.074749 |
| 200 | 0.2071 | 0.245935 | 0.288298 | 0.073331 |
| 300 | 0.1983 | 0.248379 | 0.297064 | 0.083004 |
| 400 | 0.1984 | 0.253414 | 0.286907 | 0.075104 |
| 500 | 0.1851 | 0.249626 | 0.285933 | 0.082928 |
| 600 | 0.1825 | 0.249387 | 0.282594 | 0.079864 |


TrainOutput(global_step=612, training_loss=0.19829010495952532, metrics={'train_runtime': 1454.1031, 'train_samples_per_second': 6.724, 'train_steps_per_second': 0.421, 'total_flos': 1.4198738598624823e+19, 'train_loss': 0.19829010495952532, 'epoch': 3.0})
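
The step table stops at 600 even though training ran for 612 steps, because evaluation only fires at multiples of `eval_steps`. The sketch below illustrates that cadence in plain Python. It assumes that evaluation and checkpointing trigger on exact multiples of the configured intervals and that `save_total_limit` keeps only the most recent checkpoints; `trigger_steps` and `kept_checkpoints` are hypothetical helpers, not part of `transformers`.

```python
from typing import List


def trigger_steps(total_steps: int, every: int) -> List[int]:
    """Steps at which a periodic event fires (multiples of `every`)."""
    return list(range(every, total_steps + 1, every))


def kept_checkpoints(save_points: List[int], save_total_limit: int) -> List[int]:
    """Checkpoints left on disk if only the newest `save_total_limit` are kept."""
    return save_points[-save_total_limit:]


total_steps = 612  # global_step reported in the TrainOutput above
eval_points = trigger_steps(total_steps, every=100)  # eval_steps=100
save_points = trigger_steps(total_steps, every=200)  # save_steps=200

print(eval_points)                       # [100, 200, 300, 400, 500, 600]
print(kept_checkpoints(save_points, 2))  # [400, 600]
```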

In [26]:
preds = trainer.predict(test_dataset)
eval_preds = compute_metrics(preds)
eval_preds

{'wer': 0.27905933615276174, 'cer': 0.08098286900847898}
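
For reference, WER is the word-level edit distance between the prediction and the reference, divided by the number of reference words; CER is the same quantity at the character level. The sketch below is a plain-Python illustration of that definition, not the implementation behind `evaluate.load("wer")`. It scores a single sentence pair, whereas the metric above is aggregated over the whole test set, and the example sentences are made up.

```python
from typing import Sequence


def edit_distance(ref: Sequence, hyp: Sequence) -> int:
    """Levenshtein distance: minimum substitutions, insertions, and deletions."""
    prev = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, start=1):
        curr = [i]
        for j, h in enumerate(hyp, start=1):
            cost = 0 if r == h else 1
            curr.append(min(prev[j] + 1,          # delete a reference item
                            curr[j - 1] + 1,      # insert a hypothesis item
                            prev[j - 1] + cost))  # substitute or match
        prev = curr
    return prev[-1]


def wer(reference: str, hypothesis: str) -> float:
    """Word error rate for one reference/hypothesis pair."""
    ref_words = reference.split()
    return edit_distance(ref_words, hypothesis.split()) / len(ref_words)


def cer(reference: str, hypothesis: str) -> float:
    """Character error rate for one reference/hypothesis pair."""
    return edit_distance(reference, hypothesis) / len(reference)


reference = "the cat sat on the mat"
hypothesis = "the cat sit on mat"
print(f"{wer(reference, hypothesis):.3f}")  # 0.333 (1 substitution + 1 deletion over 6 words)
print(f"{cer(reference, hypothesis):.3f}")  # 0.227 (5 character edits over 22 characters)
```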

It looks like we have a 27.9% WER and 8.1% CER. Interestingly, compared with mms-1b-all before finetuning (29.3% WER, 7.7% CER), WER dropped by 1.4 points (about 5% relative) while CER rose by 0.4 points (about 5% relative). So finetuning gave a modest improvement in WER but not in CER. Perhaps a different set of hyperparameters (such as the learning rate) would show better results. Please refer to Section C for guidance on how to experiment with different hyperparameters. Let's add this result to our table.

| Model | WER % | CER %|
|-------|-----|----|
|whisper-large-v2| 97.8| 40.5|
|mms-1b-all|29.3|**7.7**|
|finetuned whisper-large-v2|40.4|19.2|
|finetuned mms-1b-all| **27.9**| 8.1|
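
When comparing rows in this table, it helps to distinguish an absolute change (in percentage points) from a relative change (as a fraction of the starting value). A minimal sketch using the two mms-1b-all rows; the helper names are hypothetical:

```python
def absolute_change(before: float, after: float) -> float:
    """Difference in percentage points."""
    return after - before


def relative_change(before: float, after: float) -> float:
    """Difference as a fraction of the starting value."""
    return (after - before) / before


wer_before, wer_after = 29.3, 27.9  # mms-1b-all vs. finetuned mms-1b-all
cer_before, cer_after = 7.7, 8.1

print(f"WER: {absolute_change(wer_before, wer_after):+.1f} points, "
      f"{100 * relative_change(wer_before, wer_after):+.1f}% relative")
# WER: -1.4 points, -4.8% relative
print(f"CER: {absolute_change(cer_before, cer_after):+.1f} points, "
      f"{100 * relative_change(cer_before, cer_after):+.1f}% relative")
# CER: +0.4 points, +5.2% relative
```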

### XLS-R

XLS-R was released before MMS, and the MMS paper claims (CHECK) that MMS performs better than XLS-R. However, it is still worth checking which model works better for your specific dataset and use case. Therefore, let's try finetuning XLS-R on the Hausa FLEURS dataset.

Similar to MMS, we will create a tokenizer from the character vocabulary file we made earlier in this tutorial, then the feature extractor and processor that wraps the tokenizer and feature extractor.

In [52]:
from transformers import Wav2Vec2CTCTokenizer, Wav2Vec2FeatureExtractor, Wav2Vec2Processor
tokenizer = Wav2Vec2CTCTokenizer(vocab_file, unk_token="[UNK]", pad_token="[PAD]", word_delimiter_token="|")

feature_extractor = Wav2Vec2FeatureExtractor(feature_size=1, sampling_rate=model_sampling_rate, padding_value=0.0, do_normalize=True, return_attention_mask=True)

processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)

Then, we want to create our dataset objects using our processor.

In [54]:
train_dataset = ASRDatasetWav2Vec2(train_audio, train_transcripts, model_sampling_rate, processor)
val_dataset = ASRDatasetWav2Vec2(val_audio,  val_transcripts, model_sampling_rate, processor)
test_dataset = ASRDatasetWav2Vec2(test_audio, test_transcripts, model_sampling_rate, processor)

Then, we instantiate a data collator of the same class as the one used for MMS.

In [57]:
data_collator = DataCollatorCTCWithPaddingWav2Vec2(processor=processor, padding=True)
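
Transcripts in a batch have different lengths, so label sequences must be padded before they can be stacked into one tensor. A common convention in Hugging Face CTC examples, and the one `compute_metrics` in this tutorial relies on when it maps `-100` back to the pad token, is to fill padded label positions with `-100` so the loss ignores them. The sketch below shows only that idea in plain Python. `pad_labels` is a hypothetical helper, and the actual `DataCollatorCTCWithPaddingWav2Vec2` defined earlier may differ in its details.

```python
from typing import List

IGNORE_INDEX = -100  # label value conventionally skipped by the loss


def pad_labels(label_ids: List[List[int]], ignore_index: int = IGNORE_INDEX) -> List[List[int]]:
    """Right-pad every label sequence to the longest one in the batch."""
    max_len = max(len(ids) for ids in label_ids)
    return [ids + [ignore_index] * (max_len - len(ids)) for ids in label_ids]


batch = [[5, 9, 2], [7, 1], [4]]  # toy token ids for three transcripts
print(pad_labels(batch))
# [[5, 9, 2], [7, 1, -100], [4, -100, -100]]
```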

Now, we create the variables and functions needed for training.

In [74]:
import evaluate
import numpy as np
def compute_metrics(pred):
    wer_metric = evaluate.load("wer")
    cer_metric = evaluate.load("cer")
    
    pred_logits = pred.predictions
    pred_ids = np.argmax(pred_logits, axis=-1)

    pred.label_ids[pred.label_ids == -100] = processor.tokenizer.pad_token_id

    pred_str = processor.batch_decode(pred_ids)
    # we do not want to group tokens when computing the metrics
    label_str = processor.batch_decode(pred.label_ids, group_tokens=False)

    wer = wer_metric.compute(predictions=pred_str, references=label_str)
    cer = cer_metric.compute(predictions=pred_str, references=label_str)

    return {"wer": wer, "cer": cer}
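
`compute_metrics` turns frame-level logits into text with greedy CTC decoding: take the most likely token per frame, merge consecutive repeats, then drop the blank token. The labels are decoded with `group_tokens=False` because they are already the target character sequence, so merging repeats would wrongly collapse genuine double letters. A minimal plain-Python sketch of that logic, assuming the pad token doubles as the CTC blank and using a toy token sequence; `ctc_decode` is a hypothetical stand-in for `processor.batch_decode`:

```python
from itertools import groupby
from typing import List

BLANK = "[PAD]"       # assumed CTC blank token
WORD_DELIMITER = "|"  # word boundary token used by the tokenizer above


def ctc_decode(tokens: List[str], group_tokens: bool = True) -> str:
    """Collapse repeats (optionally), drop blanks, and restore spaces."""
    if group_tokens:
        tokens = [tok for tok, _ in groupby(tokens)]
    chars = [tok for tok in tokens if tok != BLANK]
    return "".join(chars).replace(WORD_DELIMITER, " ").strip()


# Per-frame argmax output: repeats and blanks are expected here.
frames = ["h", "h", "[PAD]", "e", "l", "l", "[PAD]", "l", "o", "o", "|", "|", "h", "i"]
print(ctc_decode(frames))                      # hello hi

# Reference labels: one token per character, no blanks.
labels = list("hello|hi")
print(ctc_decode(labels, group_tokens=False))  # hello hi
print(ctc_decode(labels, group_tokens=True))   # helo hi (double "l" lost)
```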

In [75]:
batch_size = 8
learning_rate = 2e-4
num_train_epochs = 3
attention_dropout = 0.1
hidden_dropout = 0.1
feat_proj_dropout = 0.0
mask_time_prob = 0.05
layerdrop = 0.1
warmup_steps = 500
    
model_card = "facebook/wav2vec2-xls-r-1b"
model = Wav2Vec2ForCTC.from_pretrained(
    model_card, 
    attention_dropout=attention_dropout,
    hidden_dropout=hidden_dropout,
    feat_proj_dropout=feat_proj_dropout,
    mask_time_prob=mask_time_prob,
    layerdrop=layerdrop,
    ctc_loss_reduction="mean", 
    pad_token_id=processor.tokenizer.pad_token_id,
    vocab_size=len(processor.tokenizer),
    ignore_mismatched_sizes=True
)

model.freeze_feature_encoder()  # freeze_feature_extractor() is the deprecated name for this method
model.gradient_checkpointing_enable()

Some weights of the model checkpoint at facebook/wav2vec2-xls-r-1b were not used when initializing Wav2Vec2ForCTC: ['project_q.bias', 'project_q.weight', 'project_hid.weight', 'quantizer.codevectors', 'quantizer.weight_proj.weight', 'quantizer.weight_proj.bias', 'project_hid.bias']
- This IS expected if you are initializing Wav2Vec2ForCTC from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing Wav2Vec2ForCTC from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at facebook/wav2vec2-xls-r-1b and are newly initialized: ['lm_head.bias', 'lm_head.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.

In [76]:
training_args = TrainingArguments(
  output_dir=out_dir,
  group_by_length=True,
  per_device_train_batch_size=batch_size,
  gradient_accumulation_steps=3,
  evaluation_strategy="steps",
  num_train_epochs=num_train_epochs,
  fp16=True,
  save_steps=100,
  eval_steps=100,
  logging_steps=10,
  load_best_model_at_end=True,
  learning_rate=learning_rate,
  warmup_steps=warmup_steps,
  save_total_limit=2,
  metric_for_best_model="wer",
  greater_is_better=False
)
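
With gradient accumulation, the optimizer updates once every `gradient_accumulation_steps` batches, so the effective batch size here is 8 × 3 = 24 on a single GPU. The sketch below works through that arithmetic and the linear warmup controlled by `warmup_steps`, assuming the Trainer's default linear schedule. The training-set size of 3,200 examples is a placeholder assumption, not a figure from this tutorial, the helper functions are hypothetical, and the Trainer's exact step accounting may differ slightly.

```python
import math


def updates_per_epoch(num_examples: int, batch_size: int, grad_accum: int) -> int:
    """Optimizer updates per epoch when gradients are accumulated over several batches."""
    batches = math.ceil(num_examples / batch_size)
    return max(batches // grad_accum, 1)


def linear_warmup_lr(step: int, peak_lr: float, warmup_steps: int, total_steps: int) -> float:
    """Linear ramp from 0 to peak_lr, then linear decay to 0 at total_steps."""
    if step < warmup_steps:
        return peak_lr * step / max(1, warmup_steps)
    return peak_lr * max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))


num_examples = 3200  # placeholder assumption; use len(train_dataset) in practice
batch_size, grad_accum, epochs = 8, 3, 3
peak_lr, warmup_steps = 2e-4, 500

effective_batch = batch_size * grad_accum
total_steps = updates_per_epoch(num_examples, batch_size, grad_accum) * epochs

print(effective_batch)  # 24
print(total_steps)      # 399
print(linear_warmup_lr(total_steps, peak_lr, warmup_steps, total_steps))
```

If the total number of updates comes out below `warmup_steps`, as it does with these placeholder numbers, the learning rate never reaches its peak value, which is one reason a longer run can help.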

In [77]:
trainer = Trainer(
    model=model,
    data_collator=data_collator,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=train_dataset,
    eval_dataset=val_dataset, 
    tokenizer=processor.feature_extractor,
)

Training with a batch size of 8 for 3 epochs takes about _ minutes and about 61338MiB (61 GB) of GPU memory, as reported by `nvidia-smi`. This model should be trained for longer: we found that the loss stabilizes after around 10 epochs, so train for 10 or more epochs.

In [None]:
trainer.train()

Step,Training Loss,Validation Loss,Wer,Cer
100,3.0123,2.884076,0.965772,0.82875


In [None]:
preds = trainer.predict(test_dataset)
eval_preds = compute_metrics(preds)
eval_preds

It looks like we have a _% WER and _% CER. Let's add this result to our table. 

| Model | WER % | CER %|
|-------|-----|----|
|whisper-large-v2| 97.8| 40.5|
|mms-1b-all|29.3|7.7|
|finetuned whisper-large-v2|40.4|19.2|
|finetuned mms-1b-all|27.9|8.1|
|finetuned xls-r-1b|_|_|

## Section C: Further Improvements **(under construction)**

### Available scripts

For convenience, this GitHub repo provides a finetuning script (in progress) that lets you specify any FLEURS language or a custom prepared dataset, along with the training hyperparameters, and then runs finetuning and evaluation in one step.

#### Hyperparameter tuning

Using the scripts available in this GitHub repo, you can run your own experiments with different hyperparameters to see what gives the best model performance.
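
One simple way to organize such experiments is a small grid over the settings you want to compare. The sketch below only enumerates configurations. The value ranges are illustrative assumptions, `grid` is a hypothetical helper, and how each configuration is passed to the finetuning script depends on that script's interface, which is not shown here.

```python
from itertools import product
from typing import Dict, Iterator, List


def grid(search_space: Dict[str, List]) -> Iterator[Dict]:
    """Yield one config dict per combination of hyperparameter values."""
    names = list(search_space)
    for values in product(*(search_space[name] for name in names)):
        yield dict(zip(names, values))


search_space = {
    "learning_rate": [1e-4, 2e-4, 1e-3],
    "per_device_train_batch_size": [8, 16],
    "num_train_epochs": [3, 10],
}

configs = list(grid(search_space))
print(len(configs))  # 12
print(configs[0])    # {'learning_rate': 0.0001, 'per_device_train_batch_size': 8, 'num_train_epochs': 3}
```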

#### Adding more data

To use a dataset other than FLEURS, write a Python script that defines a function called `create_dataset()`. It must return three `ASRDataset` objects, one each for the training, validation, and test sets. The `ASRDataset` class is available in the `utilities.py` script.

*Example custom dataset script:*

```
from typing import Tuple

from utilities import ASRDataset

def create_dataset() -> Tuple[ASRDataset, ASRDataset, ASRDataset]:
    # your code
    train_dataset = ASRDataset(audio_train, transcripts_train, sampling_rate, processor)
    val_dataset = ASRDataset(audio_val, transcripts_val, sampling_rate, processor)
    test_dataset = ASRDataset(audio_test, transcripts_test, sampling_rate, processor)

    return train_dataset, val_dataset, test_dataset
```
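
The `# your code` part typically has to load the audio and transcripts and divide them into three splits before wrapping them in `ASRDataset`. Below is a minimal sketch of a reproducible split, assuming your data is a list of `(audio_path, transcript)` pairs. `split_examples`, the 80/10/10 fractions, and the file names are hypothetical and not part of `utilities.py`.

```python
import random
from typing import List, Sequence, Tuple

Example = Tuple[str, str]  # (audio_path, transcript)


def split_examples(
    examples: Sequence[Example],
    val_frac: float = 0.1,
    test_frac: float = 0.1,
    seed: int = 0,
) -> Tuple[List[Example], List[Example], List[Example]]:
    """Shuffle with a fixed seed and cut into train/validation/test lists."""
    shuffled = list(examples)
    random.Random(seed).shuffle(shuffled)
    n_val = int(len(shuffled) * val_frac)
    n_test = int(len(shuffled) * test_frac)
    val = shuffled[:n_val]
    test = shuffled[n_val:n_val + n_test]
    train = shuffled[n_val + n_test:]
    return train, val, test


# Made-up file names and transcripts, for illustration only.
examples = [(f"clip_{i:03d}.wav", f"transcript {i}") for i in range(100)]
train, val, test = split_examples(examples)
print(len(train), len(val), len(test))  # 80 10 10
```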