# Automatic Speech Recognition (ASR) Tutorial 

## Fine-tune a pretrained, multilingual ASR model on FLEURS

In this tutorial, we will be evaluating and improving a multilingual ASR model for a language in the FLEURS dataset. We will focus on **Hausa**, but you can follow along in any language in FLEURS. See the FLEURS dataset paper for a list of supported languages: https://arxiv.org/abs/2205.12446.

We will be looking at three major open-source multilingual ASR models:
* XLS-R  (https://arxiv.org/abs/2111.09296)
* Whisper (https://cdn.openai.com/papers/whisper.pdf)
* MMS (https://scontent-sjc3-1.xx.fbcdn.net/v/t39.8562-6/348827959_6967534189927933_6819186233244071998_n.pdf?_nc_cat=104&ccb=1-7&_nc_sid=ad8a9d&_nc_ohc=-JOSFMsFL-UAX-4O6o4&_nc_ht=scontent-sjc3-1.xx&oh=00_AfDdMFq0DP2xIRyjWpGrmIpqncnouiylLfWnFsAgxboLWw&oe=6497E242)

## Before you start: Setting up your coding environment 

You can follow along by running the code in this notebook, and you can also use the scripts in this GitHub repository. Before starting, create a virtual environment for this project so that you can install the required packages without affecting your other projects. We recommend Anaconda (conda). We have provided an `environment.yml` file that creates an environment named `asr` containing all the required packages. In your terminal, run:

```
git clone https://github.com/kashrest/lrl-asr-experiments.git
cd lrl-asr-experiments
conda env create -f environment.yml 
conda activate asr
```

Note: The pretrained multilingual ASR models used in this notebook need a GPU with a large amount of memory to be practical. We estimate at least 40 GB, but this figure has not been verified.

Now you can run the cells in this notebook.

## Data Preprocessing

The first step is to download and prepare the data for the ASR model. Hugging Face provides an easy way to download FLEURS data for any supported language, and lets you specify which split to load.

In [1]:
import os 
from datasets import load_dataset

cache_dir_fleurs = "./data/fleurs/"
out_dir = "./asr-tutorial-experiment-1/"

# create the cache and output directories (including missing parents);
# exist_ok=True makes this a no-op if they already exist
os.makedirs(cache_dir_fleurs, exist_ok=True)
os.makedirs(out_dir, exist_ok=True)
    
# for Hausa, the language code is "ha_ng"
train_data = load_dataset("google/fleurs", "ha_ng", split="train", cache_dir=cache_dir_fleurs)
val_data = load_dataset("google/fleurs", "ha_ng", split="validation", cache_dir=cache_dir_fleurs)
test_data = load_dataset("google/fleurs", "ha_ng", split="test", cache_dir=cache_dir_fleurs)

Found cached dataset fleurs (/data/users/kashrest/lrl-asr-experiments/data/fleurs/google___fleurs/ha_ng/2.0.0/af82dbec419a815084fa63ebd5d5a9f24a6e9acdf9887b9e3b8c6bbd64e0b7ac)
Found cached dataset fleurs (/data/users/kashrest/lrl-asr-experiments/data/fleurs/google___fleurs/ha_ng/2.0.0/af82dbec419a815084fa63ebd5d5a9f24a6e9acdf9887b9e3b8c6bbd64e0b7ac)
Found cached dataset fleurs (/data/users/kashrest/lrl-asr-experiments/data/fleurs/google___fleurs/ha_ng/2.0.0/af82dbec419a815084fa63ebd5d5a9f24a6e9acdf9887b9e3b8c6bbd64e0b7ac)


Each FLEURS example is organized as follows:

In [2]:
train_data[0]

{'id': 302,
 'num_samples': 301440,
 'path': '/data/users/kashrest/asr-experiments/downloads/extracted/953575415c600d2042020b042380119265aefaa4fcf95afc300adfcd2b79784e/10002175198254707815.wav',
 'audio': {'path': 'train/10002175198254707815.wav',
  'array': array([0.00000000e+00, 0.00000000e+00, 0.00000000e+00, ...,
         6.78300858e-05, 1.26361847e-05, 6.46114349e-05]),
  'sampling_rate': 16000},
 'transcription': 'nasarorin da vautier ta samu marasa alaka da bada umarni ba sun hada da yajin cin abinci a 1973 a kan abin da ya ke ganin dabaibayin siyasa ne',
 'raw_transcription': 'Nasarorin da Vautier ta samu marasa alaka da bada umarni ba sun hada da yajin cin abinci a 1973 a kan abin da ya ke ganin dabaibayin siyasa ne.',
 'gender': 1,
 'lang_id': 30,
 'language': 'Hausa',
 'lang_group_id': 3}

We are interested in the audio and the corresponding transcript. The audio is an array of floats, each giving the amplitude of the waveform at one point in time. The length of the array is the clip duration multiplied by the sampling rate, which here is 16,000 samples per second.

In [3]:
train_transcripts, val_transcripts, test_transcripts = [], [], []
train_audio, val_audio, test_audio = [], [], []

for elem in train_data:
    train_audio.append(elem["audio"]["array"])
    train_transcripts.append(elem["raw_transcription"])
    
for elem in val_data:
    val_audio.append(elem["audio"]["array"])
    val_transcripts.append(elem["raw_transcription"])
    
for elem in test_data:
    test_audio.append(elem["audio"]["array"])
    test_transcripts.append(elem["raw_transcription"])
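
To make the relationship between array length and sampling rate concrete, here is a minimal sketch that converts a sample count into a duration. The `num_samples` value comes from the example record shown earlier; the helper name `duration_seconds` is made up for illustration.

```python
def duration_seconds(num_samples: int, sampling_rate: int) -> float:
    """Length of a clip in seconds: samples divided by samples per second."""
    return num_samples / sampling_rate

# Values from the first training example shown earlier.
clip_seconds = duration_seconds(num_samples=301_440, sampling_rate=16_000)
print(f"{clip_seconds:.2f} s")  # 18.84 s
```

So the first Hausa training clip is a little under 19 seconds long.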

Since we are interested in transcribing speech, we clean the transcripts by removing characters that have no clear sound (such as `!` and `'`). The right choices depend on your target application and language. For Hausa, for example, many native speakers do not speak English and there is little code-switching, so we also normalize foreign characters (such as ç and ş) and remove symbols (such as %, & and $).

In [4]:
import re
def preprocess_texts_hausa(transcriptions):
    # raw strings keep the backslashes for the regex engine instead of Python's string parser
    chars_to_remove_regex = r'[\,\?\!\-\;\:\"\“\%\‘\'\ʻ\”\�\$\&\(\)\–\—]'

    def _remove_special_characters(transcription):
        transcription = transcription.strip()
        transcription = transcription.lower()
        transcription = re.sub(chars_to_remove_regex, '', transcription)
        # character class, so that each bracket is removed wherever it appears
        transcription = re.sub(r'[\[\]\{\}]', '', transcription)
        transcription = re.sub(r'[\\]', '', transcription)
        transcription = re.sub(r'[/]', '', transcription)
        transcription = re.sub(r'[¥£°¾½²]', '', transcription)
        transcription = re.sub(r'[\+><]', '', transcription)
        return transcription
        return transcription

    def _normalize_diacritics(transcription):
        a = '[āăáã]'
        u = '[ūúü]'
        o = '[öõó]' 
        c = '[ç]'
        i = '[í]'
        s = '[ş]'
        e = '[é]'

        transcription = re.sub(a, "a", transcription)
        transcription = re.sub(u, "u", transcription)
        transcription = re.sub(o, "o", transcription)
        transcription = re.sub(c, "c", transcription)
        transcription = re.sub(i, "i", transcription)
        transcription = re.sub(s, "s", transcription)
        transcription = re.sub(e, "e", transcription)

        return transcription

    cleaned_transcriptions = map(_remove_special_characters, transcriptions)
    cleaned_transcriptions = list(map(_normalize_diacritics, list(cleaned_transcriptions)))
    return cleaned_transcriptions

train_transcripts = preprocess_texts_hausa(train_transcripts)
val_transcripts = preprocess_texts_hausa(val_transcripts)
test_transcripts = preprocess_texts_hausa(test_transcripts)
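
The diacritic normalization above lists each accented character by hand. As an alternative sketch (not the method used in this notebook), Unicode decomposition can strip accents generically. The helper name `strip_combining_marks` is made up for illustration. Note that it removes every combining mark, so it is broader than the explicit list above and would also remove tone marks if your transcripts contain them.

```python
import unicodedata

def strip_combining_marks(text: str) -> str:
    """Decompose characters (NFD) and drop the combining marks that result."""
    decomposed = unicodedata.normalize("NFD", text)
    return "".join(ch for ch in decomposed if not unicodedata.combining(ch))

print(strip_combining_marks("façade şükür café"))  # facade sukur cafe
print(strip_combining_marks("ɓ ɗ ƙ"))              # hooked letters are left unchanged
```

The Hausa hooked letters (ɓ, ɗ, ƙ) have no Unicode decomposition, so they pass through untouched, which is what we want for native orthography.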

Some models predict one character at a time, so we need a character vocabulary made up of all characters that appear in the dataset after preprocessing. We save the vocabulary in a JSON file.

In [5]:
import json 

def extract_all_chars(transcription):
    all_text = " ".join(transcription)
    vocab = list(set(all_text))
    return {"vocab": [vocab], "all_text": [all_text]}

vocab_train = list(map(extract_all_chars, train_transcripts))
vocab_val = list(map(extract_all_chars, val_transcripts))
vocab_test = list(map(extract_all_chars, test_transcripts))

vocab_train_chars = []
for elem in [elem["vocab"][0] for elem in vocab_train]:
    vocab_train_chars.extend(elem)

vocab_val_chars = []
for elem in [elem["vocab"][0] for elem in vocab_val]:
    vocab_val_chars.extend(elem)

vocab_test_chars = []
for elem in [elem["vocab"][0] for elem in vocab_test]:
    vocab_test_chars.extend(elem)

vocab_list = list(set(vocab_train_chars) | set(vocab_val_chars) | set(vocab_test_chars))
vocab_dict = {v: k for k, v in enumerate(vocab_list)}

# for word delimiter, change " " --> "|" (ex. "Hello my name is Bob" --> "Hello|my|name|is|Bob")
vocab_dict["|"] = vocab_dict[" "]
del vocab_dict[" "]
vocab_dict["[UNK]"] = len(vocab_dict)
vocab_dict["[PAD]"] = len(vocab_dict) # [PAD] doubles as the CTC blank token, which separates repeated characters (e.g. "hhh[PAD]iii[PAD]iii[PAD]" == "hii")

with open(out_dir+"vocab_hausa.json", 'w') as vocab_file:
    json.dump(vocab_dict, vocab_file)
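
To illustrate how the `[PAD]` and `|` entries are used, here is a minimal sketch of greedy CTC post-processing on a list of per-frame token predictions. It assumes `[PAD]` acts as the CTC blank, as the comment in the vocabulary cell describes; the function name `ctc_collapse` is made up for illustration and is not part of any library.

```python
from itertools import groupby

def ctc_collapse(tokens, blank="[PAD]", word_delimiter="|"):
    """Greedy CTC post-processing: merge repeats, drop blanks, restore spaces."""
    merged = [token for token, _ in groupby(tokens)]      # collapse runs of identical tokens
    kept = [token for token in merged if token != blank]  # remove the blank token
    return "".join(kept).replace(word_delimiter, " ")

frames = ["h", "h", "h", "[PAD]", "i", "i", "i", "[PAD]", "i", "i", "i", "[PAD]"]
print(ctc_collapse(frames))  # hii
```

Without the blank between the two runs of `i`, the repeats would be merged into a single character, which is why CTC models need this extra token.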

## Evaluation code

In ASR, word error rate (WER) and character error rate (CER) are the common metrics used to evaluate how good a model-produced transcript is in comparison to the gold transcript.

In [6]:
import evaluate

# `datasets.load_metric` is deprecated; the `evaluate` library provides the same metrics
wer_metric = evaluate.load("wer")
cer_metric = evaluate.load("cer")

def compute_metrics(label_strs, pred_strs):
    # make sure labels are preprocessed the same way for proper comparison with other models after fine-tuning
    wer = wer_metric.compute(predictions=pred_strs, references=label_strs) * 100
    cer = cer_metric.compute(predictions=pred_strs, references=label_strs) * 100
    return {"wer": wer, "cer": cer}
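
For intuition about what these metrics measure, here is a small from-scratch sketch for a single reference and hypothesis pair. Both metrics are an edit distance (substitutions, insertions, deletions) divided by the reference length, counted in words for WER and in characters for CER. The function names are made up for illustration, and the hypothesis string is an invented example. Library implementations typically pool errors over the whole corpus rather than averaging per-sentence scores.

```python
def edit_distance(ref, hyp):
    """Levenshtein distance between two sequences (lists of words or strings)."""
    prev = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, start=1):
        curr = [i]
        for j, h in enumerate(hyp, start=1):
            curr.append(min(prev[j] + 1,               # deletion
                            curr[j - 1] + 1,           # insertion
                            prev[j - 1] + (r != h)))   # substitution or match
        prev = curr
    return prev[-1]

def word_error_rate(reference, hypothesis):
    ref_words = reference.split()
    return edit_distance(ref_words, hypothesis.split()) / len(ref_words)

def char_error_rate(reference, hypothesis):
    return edit_distance(reference, hypothesis) / len(reference)

reference = "nasarorin da vautier ta samu"
hypothesis = "nasarorin da votier ta samu"  # invented model output with one wrong word
print(word_error_rate(reference, hypothesis))  # 0.2
```

One wrong word out of five gives a WER of 20%, while the CER is much lower because only two of the 28 reference characters need editing. This is why CER is usually well below WER for the same system.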


## Section A: Zero-Shot ASR

Let's run inference on our audio dataset with different existing, ready-to-use pretrained and finetuned models. Later, when we compare model performance, we will compare performance on the test set since some models will be trained on the train split.

### OpenAI Whisper

OpenAI's Whisper is a pretrained encoder-decoder model that supports a set of languages without further fine-tuning. Here, we use whisper-large-v2 so that it is comparable in size to MMS-1b-all (both have on the order of 1 billion parameters). If you do not have enough GPU memory, you can use a smaller checkpoint from the Hugging Face Hub (for example, https://huggingface.co/openai/whisper-small) or decrease the batch size.

Note: Whisper requires input sampled at 16,000 Hz. FLEURS audio is already sampled at this rate, so no resampling is needed. Also, Whisper does not support all FLEURS languages (the 96 languages are listed in the paper).

In [7]:
from transformers import WhisperProcessor, WhisperForConditionalGeneration
from tqdm import tqdm
import time

device = "cuda:0" # first GPU; set to "cpu" if you do not have access to a GPU
model_id = "openai/whisper-large-v2" # 8630MiB for batch size 1, ~25426MiB for batch size 10
processor = WhisperProcessor.from_pretrained(model_id)
model = WhisperForConditionalGeneration.from_pretrained(model_id).to(device)
forced_decoder_ids = processor.get_decoder_prompt_ids(language="hausa", task="transcribe")

predicted_test_transcripts = []

batch_size = 10 # decrease if needed

for i in tqdm(range(0, len(test_audio), batch_size)):
    batch = test_audio[i:i+batch_size]  # slicing past the end simply returns the remaining items
    input_features = processor(batch, sampling_rate=16000, return_tensors="pt").input_features.to(device)
    # generate token ids
    predicted_ids = model.generate(input_features, forced_decoder_ids=forced_decoder_ids)
    # decode token ids to text
    predicted_test_transcripts.extend(processor.batch_decode(predicted_ids, skip_special_tokens=True))

100%|████████████████████████████████████████████████████████████████████████████████| 63/63 [05:19<00:00,  5.07s/it]


In [8]:
len(predicted_test_transcripts), len(test_audio)

(621, 621)
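
The batching pattern used in the inference loop can be sketched in isolation. This example uses a list of integers as a stand-in for the 621 test clips; the helper name `iter_batches` is made up for illustration.

```python
import math

def iter_batches(items, batch_size):
    """Yield consecutive slices; slicing past the end just returns a shorter list."""
    for start in range(0, len(items), batch_size):
        yield items[start:start + batch_size]

dummy_clips = list(range(621))  # stand-in for the 621 test clips
batches = list(iter_batches(dummy_clips, batch_size=10))

# The number of batches is the clip count divided by the batch size, rounded up.
assert len(batches) == math.ceil(len(dummy_clips) / 10)
print(len(batches), len(batches[-1]))  # 63 1
```

This matches the 63 iterations in the progress bar above: 62 full batches of 10 clips and a final batch with the one remaining clip.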

Let's evaluate the performance of whisper-large-v2 on our preprocessed test dataset.

In [9]:
len(test_audio)

621

In [10]:
compute_metrics(test_transcripts, predicted_test_transcripts)

{'wer': 97.78959507944643, 'cer': 40.52395399540877}

It looks like we have a 97.8% WER and 40.5% CER. Let's create a running table of the performances of different models on our test dataset. 

| Model | WER % | CER %|
|-------|-----|----|
|whisper-large-v2| 97.8| 40.5|

### Facebook MMS

In [11]:
import torch
from transformers import Wav2Vec2ForCTC, AutoProcessor

model_id = "facebook/mms-1b-all"

processor = AutoProcessor.from_pretrained(model_id)
model = Wav2Vec2ForCTC.from_pretrained(model_id)

processor.tokenizer.set_target_lang("hau")
model.load_adapter("hau")

predicted_test_transcripts = []

for elem in tqdm(test_audio): # Todo: batch inference?
    inputs = processor(elem, sampling_rate=16_000, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs).logits
    ids = torch.argmax(outputs, dim=-1)[0]
    predicted_test_transcripts.append(processor.decode(ids))

100%|██████████████████████████████████████████████████████████████████████████████| 621/621 [28:10<00:00,  2.72s/it]


In [12]:
compute_metrics(test_transcripts, predicted_test_transcripts)

{'wer': 29.279856483854434, 'cer': 7.700116511126236}

It looks like we have a 29.3% WER and 7.7% CER. Let's add this result to our table. 

| Model | WER % | CER %|
|-------|-----|----|
|whisper-large-v2| 97.8| 40.5|
|mms-1b-all|29.3|7.7|

## Section B: Finetuning

For fine-tuning, we will be using functions from the Hugging Face API for the training loop and model setup. In order to use these functions, we need to wrap the data in a custom PyTorch Dataset object.

In [13]:
class ASRDataset(torch.utils.data.Dataset):
    def __init__(self, audio, transcripts, sampling_rate, processor):
        self.audio = audio
        self.transcripts = transcripts
        self.sampling_rate = sampling_rate
        self.processor = processor

    def __getitem__(self, idx):
        input_values = self.processor.feature_extractor(self.audio[idx], sampling_rate=self.sampling_rate).input_features[0]
        labels = self.processor.tokenizer(self.transcripts[idx]).input_ids
        
        item = {}
        item["input_features"] = input_values
        item["labels"] = labels
        
        return item

    def __len__(self):
        return len(self.transcripts)

### Whisper

First, let's import the required Whisper classes and training loop functions from Hugging Face, along with some other utility functions.

In [14]:
from transformers import WhisperFeatureExtractor, WhisperTokenizer, WhisperProcessor, WhisperForConditionalGeneration, Seq2SeqTrainingArguments, Seq2SeqTrainer
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union
import evaluate

Let's fine-tune a whisper-large-v2 model, which we download from Hugging Face, and set up a WhisperProcessor object, which contains a feature extractor and a tokenizer. The feature extractor turns the raw audio into model inputs: for Whisper, it pads or truncates each clip to 30 seconds and converts it to a log-Mel spectrogram. The tokenizer splits the transcripts into tokens based on Whisper's vocabulary. Whisper uses byte-level BPE, the same type of tokenizer as GPT-2. If you are interested, refer to this page: https://huggingface.co/learn/nlp-course/chapter6/5?fw=pt

In [15]:
model_card = "openai/whisper-large-v2"
processor = WhisperProcessor.from_pretrained(model_card, language="Hausa", task="transcribe")
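
To give a feel for what the Whisper feature extractor produces, the sketch below works out the input shape with plain arithmetic. The constants (30-second chunks, a hop of 160 samples, 80 Mel bins) are the values we understand whisper-large-v2 to use; treat them as assumptions and inspect `processor.feature_extractor` to confirm them for your checkpoint. The helper `pad_or_trim` is made up for illustration and does not compute a spectrogram.

```python
SAMPLING_RATE = 16_000   # Hz
CHUNK_SECONDS = 30       # every clip is padded or cut to this length (assumed)
HOP_LENGTH = 160         # samples between consecutive spectrogram frames (assumed)
N_MELS = 80              # Mel frequency bins per frame (assumed)

def pad_or_trim(samples, target_len):
    """Zero-pad a short clip or cut a long one to exactly target_len samples."""
    return (samples + [0.0] * (target_len - len(samples)))[:target_len]

target_len = SAMPLING_RATE * CHUNK_SECONDS   # 480000 samples
clip = [0.1] * 301_440                       # same length as the example clip shown earlier
fixed = pad_or_trim(clip, target_len)
feature_shape = (N_MELS, target_len // HOP_LENGTH)
print(len(fixed), feature_shape)  # 480000 (80, 3000)
```

Under these assumptions every clip, regardless of its original duration, becomes an 80 x 3000 array, which is why the audio side of a Whisper batch needs no extra padding logic.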

As mentioned earlier, Whisper takes inputs sampled at 16,000 Hz, so we prepare our data at this sampling rate (FLEURS is already sampled at 16 kHz) using the ASRDataset class defined above.

In [16]:
model_sampling_rate = 16000
train_dataset = ASRDataset(train_audio, train_transcripts, model_sampling_rate, processor)
val_dataset = ASRDataset(val_audio,  val_transcripts, model_sampling_rate, processor)
test_dataset = ASRDataset(test_audio, test_transcripts, model_sampling_rate, processor)

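Next, we need a data collator that pads the inputs and labels of each batch, a function that computes WER and CER during evaluation, and the training configuration.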

In [None]:
@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    processor: Any

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # split inputs and labels since they have to be of different lengths and need different padding methods
        # first treat the audio inputs by simply returning torch tensors
        input_features = [{"input_features": feature["input_features"]} for feature in features]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")

        # get the tokenized label sequences
        label_features = [{"input_ids": feature["labels"]} for feature in features]
        # pad the labels to max length
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")

        # replace padding with -100 to ignore loss correctly
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        # if the bos token was added during tokenization,
        # cut it here since it is appended again later anyway
        if (labels[:, 0] == self.processor.tokenizer.bos_token_id).all().cpu().item():
            labels = labels[:, 1:]

        batch["labels"] = labels

        return batch
    
data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)

metric_wer = evaluate.load("wer")
metric_cer = evaluate.load("cer")

def compute_metrics(pred):
    pred_ids = pred.predictions
    label_ids = pred.label_ids

    # replace -100 with the pad_token_id
    label_ids[label_ids == -100] = processor.tokenizer.pad_token_id

    # we do not want to group tokens when computing the metrics
    pred_str = processor.tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = processor.tokenizer.batch_decode(label_ids, skip_special_tokens=True)

    wer = 100 * metric_wer.compute(predictions=pred_str, references=label_str)
    cer = 100 * metric_cer.compute(predictions=pred_str, references=label_str)

    return {"wer": wer, "cer": cer}


model = WhisperForConditionalGeneration.from_pretrained(model_card)

model.config.forced_decoder_ids = None
model.config.suppress_tokens = []

# NOTE: the hyperparameters below are not defined anywhere else in this notebook.
# These values are placeholder assumptions so that the cell runs; tune them for your setup.
train_batch_size = 8
learning_rate = 1e-5
num_train_epochs = 3

training_args = Seq2SeqTrainingArguments(
    output_dir=out_dir,  # change to a repo name of your choice
    per_device_train_batch_size=train_batch_size,
    gradient_accumulation_steps=1,  # increase by 2x for every 2x decrease in batch size
    learning_rate=learning_rate,
    warmup_steps=500,
    num_train_epochs=num_train_epochs,
    gradient_checkpointing=True,
    fp16=True,
    evaluation_strategy="steps",
    per_device_eval_batch_size=8,
    predict_with_generate=True,
    generation_max_length=225,
    save_steps=100,
    eval_steps=100,
    save_total_limit=2, 
    load_best_model_at_end=True,
    metric_for_best_model="wer",
    greater_is_better=False)

trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    tokenizer=processor.feature_extractor,
)

start = time.time()
trainer.train()
end = time.time()
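
The label handling in the data collator can be sketched without any libraries. Transcripts in a batch have different lengths, so the token id sequences are padded to a common length, and the padded positions are then replaced with -100, which PyTorch's cross-entropy loss ignores by default. The function name `pad_labels` and the token ids are made up for illustration.

```python
IGNORE_INDEX = -100  # label value skipped by PyTorch's cross-entropy loss by default

def pad_labels(label_seqs, pad_id):
    """Right-pad to the longest sequence, then mask the padding with IGNORE_INDEX."""
    max_len = max(len(seq) for seq in label_seqs)
    masked = []
    for seq in label_seqs:
        n_pad = max_len - len(seq)
        padded = seq + [pad_id] * n_pad
        attention_mask = [1] * len(seq) + [0] * n_pad  # 1 = real token, 0 = padding
        masked.append([tok if keep else IGNORE_INDEX
                       for tok, keep in zip(padded, attention_mask)])
    return masked

labels = pad_labels([[5, 6, 7], [8]], pad_id=0)
print(labels)  # [[5, 6, 7], [8, -100, -100]]
```

Masking by attention mask rather than by comparing against the pad id keeps the result correct even when the pad id also appears as a genuine label token.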


### MMS

### XLS-R

## Section C: Further Improvements

### Hyperparameter tuning

### Adding more data