# Fine-Tune Whisper

The Whisper checkpoints come in five configurations of varying model sizes.
The smallest four are trained on either English-only or multilingual data.
The largest checkpoint is multilingual only. All nine of the pre-trained checkpoints 
are available on the [Hugging Face Hub](https://huggingface.co/models?search=openai/whisper). The 
checkpoints are summarised in the following table with links to the models on the Hub:

| Size   | Layers | Width | Heads | Parameters | English-only                                         | Multilingual                                      |
|--------|--------|-------|-------|------------|------------------------------------------------------|---------------------------------------------------|
| tiny   | 4      | 384   | 6     | 39 M       | [✓](https://huggingface.co/openai/whisper-tiny.en)   | [✓](https://huggingface.co/openai/whisper-tiny)   |
| base   | 6      | 512   | 8     | 74 M       | [✓](https://huggingface.co/openai/whisper-base.en)   | [✓](https://huggingface.co/openai/whisper-base)   |
| small  | 12     | 768   | 12    | 244 M      | [✓](https://huggingface.co/openai/whisper-small.en)  | [✓](https://huggingface.co/openai/whisper-small)  |
| medium | 24     | 1024  | 16    | 769 M      | [✓](https://huggingface.co/openai/whisper-medium.en) | [✓](https://huggingface.co/openai/whisper-medium) |
| large  | 32     | 1280  | 20    | 1550 M     | x                                                    | [✓](https://huggingface.co/openai/whisper-large)  |
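
The Hub IDs in the table follow a simple pattern: `openai/whisper-{size}`, plus an `.en` suffix for the English-only variants. The hypothetical helper below is a small sketch that builds an ID from the table. It is not part of 🤗 Transformers, and it only covers the five sizes listed here.

```python
# Hypothetical helper (not part of transformers): build the Hub ID of a
# Whisper checkpoint listed in the table above.
SIZES = ("tiny", "base", "small", "medium", "large")


def whisper_checkpoint_id(size: str, english_only: bool = False) -> str:
    """Return the Hub model ID for one of the Whisper sizes in the table."""
    if size not in SIZES:
        raise ValueError(f"unknown Whisper size: {size!r}")
    if english_only and size == "large":
        # the largest checkpoint is multilingual only
        raise ValueError("there is no English-only 'large' checkpoint")
    suffix = ".en" if english_only else ""
    return f"openai/whisper-{size}{suffix}"


print(whisper_checkpoint_id("tiny"))                      # openai/whisper-tiny
print(whisper_checkpoint_id("small", english_only=True))  # openai/whisper-small.en
```

Enumerating every valid combination gives the nine checkpoints mentioned above.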


## Prepare Environment

In [None]:
from dataclasses import dataclass
from typing import Any, Dict, List, Union
from pathlib import Path

import torch
import wandb
import huggingface_hub

from transformers import (
    pipeline,
    WhisperTokenizer,
    WhisperFeatureExtractor,
    WhisperProcessor,
    WhisperForConditionalGeneration,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
    EarlyStoppingCallback,
    IntervalStrategy
)

import datasets
import evaluate
import gradio as gr

## Load Dataset

Using 🤗 Datasets, downloading and preparing data is extremely simple. 
We can download each Common Voice split with a single line of code. 

First, ensure you have accepted the terms of use on the Hugging Face Hub: [mozilla-foundation/common_voice_11_0](https://huggingface.co/datasets/mozilla-foundation/common_voice_11_0). Once you have accepted the terms, you will have full access to the dataset and be able to download the data locally.

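We'll combine the `train`, `validation` and `test` splits and then re-split them. This holds out 500 examples for validation and 2,000 for testing, and uses the rest for training, which gives more training data than the official `train` split alone. Merging splits follows the original notebook and may not be strictly necessary.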

In [None]:
LANG = "cs"
LANG_LONG = "czech"
MODEL_SIZE = "tiny"
PRETRAINED_MODEL_NAME = f"openai/whisper-{MODEL_SIZE}"
FINETUNED_MODEL_NAME_HUMAN = f"Whisper {MODEL_SIZE} {LANG}"
FINETUNED_MODEL_NAME = FINETUNED_MODEL_NAME_HUMAN.lower().replace(" ", "-")
WANDB_PROJECT_NAME = f"whisper-{LANG}"
FINETUNED_MODEL_NAME

In [None]:
huggingface_hub.notebook_login()

In [None]:
wandb.login()

## Prepare Feature Extractor, Tokenizer and Data

The ASR pipeline can be decomposed into three stages: 
1) A feature extractor which pre-processes the raw audio inputs
2) The model which performs the sequence-to-sequence mapping 
3) A tokenizer which post-processes the model outputs to text format

In 🤗 Transformers, the Whisper model has an associated feature extractor and tokenizer, 
called [WhisperFeatureExtractor](https://huggingface.co/docs/transformers/main/model_doc/whisper#transformers.WhisperFeatureExtractor)
and [WhisperTokenizer](https://huggingface.co/docs/transformers/main/model_doc/whisper#transformers.WhisperTokenizer) 
respectively.

We'll go through the details of setting up the feature extractor and tokenizer one by one.

### Load WhisperFeatureExtractor

The Whisper feature extractor performs two operations:
1. Pads / truncates the audio inputs to 30s: any audio inputs shorter than 30s are padded to 30s with silence (zeros), and those longer than 30s are truncated to 30s
2. Converts the audio inputs to _log-Mel spectrogram_ input features, a visual representation of the audio and the form of the input expected by the Whisper model
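
Here is a minimal NumPy sketch of the first operation, assuming 16 kHz mono input. It only illustrates the pad/truncate logic, since `WhisperFeatureExtractor` already does this (and the log-Mel computation) for you. The hop length of 160 samples (10 ms) is an assumption based on the feature extractor's default configuration, and it is what yields 3,000 spectrogram frames per 30 s window.

```python
import numpy as np

SAMPLING_RATE = 16_000  # Whisper expects 16 kHz audio
CHUNK_SECONDS = 30      # fixed input window
HOP_LENGTH = 160        # assumed STFT hop: 10 ms at 16 kHz


def pad_or_truncate(audio: np.ndarray, n_samples: int = SAMPLING_RATE * CHUNK_SECONDS) -> np.ndarray:
    """Zero-pad (silence) or cut a 1-D waveform to exactly n_samples."""
    if audio.shape[0] >= n_samples:
        return audio[:n_samples]
    # np.pad defaults to constant (zero) padding, i.e. silence at the end
    return np.pad(audio, (0, n_samples - audio.shape[0]))


short = pad_or_truncate(np.ones(5 * SAMPLING_RATE, dtype=np.float32))       # 5 s clip
long_clip = pad_or_truncate(np.ones(45 * SAMPLING_RATE, dtype=np.float32))  # 45 s clip
print(short.shape, long_clip.shape)  # (480000,) (480000,)
print(short.shape[0] // HOP_LENGTH)  # 3000 frames per 30 s window
```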

### Load Tokenizer, FeatureExtractor, Processor

In [None]:
feature_extractor = WhisperFeatureExtractor.from_pretrained(PRETRAINED_MODEL_NAME)
tokenizer = WhisperTokenizer.from_pretrained(PRETRAINED_MODEL_NAME, language=LANG_LONG, task="transcribe")
processor = WhisperProcessor.from_pretrained(PRETRAINED_MODEL_NAME, language=LANG_LONG, task="transcribe")

### Prepare Data

In [None]:
def preprocess(batch):
    # accessing the audio column decodes the file and resamples it from 48 to 16 kHz
    # (the target rate is set by cast_column below)
    audio = batch["audio"]

    # compute log-Mel input features from input audio array 
    batch["input_features"] = feature_extractor(audio["array"], sampling_rate=audio["sampling_rate"]).input_features[0]

    # encode target text to label ids 
    batch["labels"] = tokenizer(batch["sentence"]).input_ids
    return batch


data_path = Path(f"../data/asr_common_voice_11_for_{FINETUNED_MODEL_NAME}")


if not data_path.exists():
    common_voice = datasets.DatasetDict()

    common_voice["train"] = datasets.load_dataset("mozilla-foundation/common_voice_11_0", LANG, split="train", use_auth_token=True)
    common_voice["validation"] = datasets.load_dataset("mozilla-foundation/common_voice_11_0", LANG, split="validation", use_auth_token=True)
    common_voice["test"] = datasets.load_dataset("mozilla-foundation/common_voice_11_0", LANG, split="test", use_auth_token=True)

    common_voice = common_voice.remove_columns(["accent", "age", "client_id", "down_votes", "gender", "locale", "path", "segment", "up_votes"])
    common_voice = common_voice.cast_column("audio", datasets.Audio(sampling_rate=16000))
    common_voice = common_voice.map(preprocess, remove_columns=common_voice.column_names["train"], num_proc=4)

    common_voice.save_to_disk(data_path)


In [None]:
def load_locally_custom_split(data_path: Path):

    ds = datasets.concatenate_datasets([
        datasets.load_from_disk(data_path / "train"),
        datasets.load_from_disk(data_path / "validation"),
        datasets.load_from_disk(data_path / "test"),
    ])
    
    train_testvalid = ds.train_test_split(test_size=2500, seed=0)
    valid_test = train_testvalid['test'].train_test_split(test_size=2000, seed=0)

    common_voice = datasets.DatasetDict({
        'train': train_testvalid['train'],
        'validation': valid_test['train'],
        'test': valid_test['test'],
    })
    return common_voice

common_voice = load_locally_custom_split(data_path)
print(common_voice)
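
`train_test_split` treats an integer `test_size` as an absolute number of rows, so the sizes of the custom splits follow from simple arithmetic. The hypothetical helper below spells this out. The total of 30,000 examples is an illustrative number, not the real size of the Czech data.

```python
def custom_split_sizes(n_total: int, held_out: int = 2500, n_test: int = 2000) -> dict:
    """Sizes produced by the two train_test_split calls in load_locally_custom_split.

    1st split: n_total  -> (n_total - held_out) train rows + held_out held-out rows
    2nd split: held_out -> (held_out - n_test) validation rows + n_test test rows
    """
    if not n_total > held_out > n_test:
        raise ValueError("need n_total > held_out > n_test")
    return {
        "train": n_total - held_out,
        "validation": held_out - n_test,
        "test": n_test,
    }


print(custom_split_sizes(30_000))  # {'train': 27500, 'validation': 500, 'test': 2000}
```

Note that the validation set used for early stopping and model selection ends up with only 500 examples.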

## Training and Evaluation

### Define a Data Collator

The data collator for a sequence-to-sequence speech model is unique in the sense that it 
treats the `input_features` and `labels` independently: the  `input_features` must be 
handled by the feature extractor and the `labels` by the tokenizer.

The `input_features` are already padded to 30s and converted to a log-Mel spectrogram 
of fixed dimension by action of the feature extractor, so all we have to do is convert the `input_features`
to batched PyTorch tensors. We do this using the feature extractor's `.pad` method with `return_tensors="pt"`.

The `labels`, on the other hand, are unpadded. We first pad the sequences
to the maximum length in the batch using the tokenizer's `.pad` method. The padding tokens 
are then replaced by `-100` so that these tokens are **not** taken into account when 
computing the loss. Finally, if every sequence starts with the decoder start token 
(`<|startoftranscript|>`), we cut it, because the model prepends this token itself when it 
shifts the labels right to build the decoder inputs during training.
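
To make the padding step concrete, here is a plain-Python sketch with toy label ids. It assumes, as for the Whisper tokenizer, that the padding token shares its id with `<|endoftext|>` (taken to be 50257 here). Because of that, the positions to ignore must come from the attention mask rather than from comparing token ids; otherwise the genuine end-of-text token would be ignored as well.

```python
from typing import List

PAD_ID = 50257  # assumption: Whisper's pad token is <|endoftext|>


def pad_and_mask(label_seqs: List[List[int]], pad_id: int = PAD_ID) -> List[List[int]]:
    """Right-pad label sequences to the batch max length, then replace padding with -100."""
    max_len = max(len(seq) for seq in label_seqs)
    padded, masks = [], []
    for seq in label_seqs:
        n_pad = max_len - len(seq)
        padded.append(seq + [pad_id] * n_pad)
        masks.append([1] * len(seq) + [0] * n_pad)  # 1 = real token, 0 = padding
    # decide what to ignore from the mask, not the token value: the real
    # end-of-text token has the same id as the padding
    return [
        [tok if m == 1 else -100 for tok, m in zip(seq, mask)]
        for seq, mask in zip(padded, masks)
    ]


batch = [[50258, 11, 22, 50257], [50258, 33, 50257]]  # toy label ids
print(pad_and_mask(batch))
# [[50258, 11, 22, 50257], [50258, 33, 50257, -100]]
```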

We can leverage the `WhisperProcessor` we defined earlier to perform both the 
feature extractor and the tokenizer operations:

In [None]:
@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    processor: Any

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # split inputs and labels since they have to be of different lengths and need different padding methods
        # first treat the audio inputs by simply returning torch tensors
        input_features = [{"input_features": feature["input_features"]} for feature in features]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")

        # get the tokenized label sequences
        label_features = [{"input_ids": feature["labels"]} for feature in features]
        # pad the labels to max length
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")

        # replace padding with -100 to ignore loss correctly
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

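        # the Whisper tokenizer prepends <|startoftranscript|>, which is also the model's
        # decoder start token; cut it here because the model prepends it again when it
        # shifts the labels right (the tokenizer's bos_token is <|endoftext|>, so a check
        # against bos_token_id would never match)
        decoder_start_token_id = self.processor.tokenizer.convert_tokens_to_ids("<|startoftranscript|>")
        if (labels[:, 0] == decoder_start_token_id).all().cpu().item():
            labels = labels[:, 1:]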

        batch["labels"] = labels

        return batch
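
Why cut the start token? When `labels` are passed without `decoder_input_ids`, the 🤗 Transformers Whisper model builds the decoder inputs by shifting the labels one position to the right and inserting `decoder_start_token_id` (`<|startoftranscript|>`) at the front. The plain-Python sketch below re-implements that shift for a single sequence, using toy token ids. It shows the duplicated start token that the collator avoids.

```python
from typing import List


def shift_tokens_right(labels: List[int], pad_id: int, decoder_start_id: int) -> List[int]:
    """Plain-Python sketch of how decoder inputs are built from one label sequence."""
    shifted = [decoder_start_id] + labels[:-1]
    # -100 only matters for the loss; the decoder needs a real token id there
    return [pad_id if tok == -100 else tok for tok in shifted]


SOT, LANG_TOKEN, EOT = 50258, 50300, 50257  # toy ids; SOT stands for <|startoftranscript|>
labels = [SOT, LANG_TOKEN, 11, 22, EOT]

print(shift_tokens_right(labels, pad_id=EOT, decoder_start_id=SOT))
# [50258, 50258, 50300, 11, 22]  <- start token duplicated
print(shift_tokens_right(labels[1:], pad_id=EOT, decoder_start_id=SOT))
# [50258, 50300, 11, 22]         <- after the collator cuts it
```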

In [None]:
data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)

### Evaluation Metrics

We'll use the word error rate (WER), the most common metric for evaluating ASR systems.

We then simply have to define a function that takes our model 
predictions and returns the WER metric. This function, called
`compute_metrics`, first replaces `-100` with the `pad_token_id`
in the `label_ids` (undoing the step we applied in the 
data collator to ignore padded tokens correctly in the loss).
It then decodes the predicted and label ids to strings. Finally,
it computes the WER between the predictions and reference labels:

In [None]:
metric = evaluate.load("wer")

def compute_metrics(pred):
    pred_ids = pred.predictions
    label_ids = pred.label_ids

    # replace -100 with the pad_token_id
    label_ids[label_ids == -100] = tokenizer.pad_token_id

    # we do not want to group tokens when computing the metrics
    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)

    wer = 100 * metric.compute(predictions=pred_str, references=label_str)

    return {"wer": wer}
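
WER counts the word-level substitutions, deletions and insertions needed to turn the prediction into the reference, divided by the number of reference words. The sketch below computes it from scratch with a word-level Levenshtein distance, summing errors and reference words over the whole corpus. It is meant to show what the `wer` metric measures. The loaded metric is backed by `jiwer`, which may normalise text slightly differently.

```python
from typing import List


def word_edit_distance(ref: List[str], hyp: List[str]) -> int:
    """Minimum substitutions + deletions + insertions turning ref into hyp."""
    prev = list(range(len(hyp) + 1))  # distances from the empty reference prefix
    for i, r in enumerate(ref, start=1):
        curr = [i] + [0] * len(hyp)
        for j, h in enumerate(hyp, start=1):
            curr[j] = min(
                prev[j] + 1,             # deletion of reference word r
                curr[j - 1] + 1,         # insertion of hypothesis word h
                prev[j - 1] + (r != h),  # substitution (cost 0 if the words match)
            )
        prev = curr
    return prev[-1]


def corpus_wer(predictions: List[str], references: List[str]) -> float:
    """Total word errors divided by total reference words, as a percentage."""
    errors = sum(word_edit_distance(r.split(), p.split()) for p, r in zip(predictions, references))
    n_words = sum(len(r.split()) for r in references)
    return 100 * errors / n_words


# one substitution (sat -> sit) and one deletion (the): 2 errors / 6 words, about 33.3
print(corpus_wer(["the cat sit on mat"], ["the cat sat on the mat"]))
```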

### Load a Pre-Trained Checkpoint

In [None]:
model = WhisperForConditionalGeneration.from_pretrained(PRETRAINED_MODEL_NAME)

Override the generation arguments: no tokens are forced as decoder outputs (see [`forced_decoder_ids`](https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.generation_utils.GenerationMixin.generate.forced_decoder_ids)), and no tokens are suppressed during generation (see [`suppress_tokens`](https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.generation_utils.GenerationMixin.generate.suppress_tokens)):

In [None]:
model.config.forced_decoder_ids = None
model.config.suppress_tokens = []

### Define the Training Configuration

In the final step, we define all the parameters related to training. For more detail on the training arguments, refer to the Seq2SeqTrainingArguments [docs](https://huggingface.co/docs/transformers/main_classes/trainer#transformers.Seq2SeqTrainingArguments).

In [None]:
wandb_run = wandb.init(project=WANDB_PROJECT_NAME)

In [None]:
early_stopping = EarlyStoppingCallback(early_stopping_patience=5)

training_args = Seq2SeqTrainingArguments(
    output_dir=f"../models/{FINETUNED_MODEL_NAME}/{wandb_run.name}",
    per_device_train_batch_size=16,
    per_device_eval_batch_size=8,
    gradient_accumulation_steps=1,
    learning_rate=1e-5,
    warmup_steps=500,
    max_steps=50000,
    gradient_checkpointing=False,
    fp16=True,
    group_by_length=True,
    evaluation_strategy=IntervalStrategy.STEPS,
    predict_with_generate=True,
    generation_max_length=225,
    save_steps=1000,
    eval_steps=1000,
    logging_steps=50,
    report_to=["wandb"],
    load_best_model_at_end=True,
    metric_for_best_model="wer",
    greater_is_better=False,
    push_to_hub=False,
)
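
Two numbers follow directly from these arguments. First, the Trainer's default scheduler (`lr_scheduler_type="linear"`) warms the learning rate up linearly to `learning_rate` over `warmup_steps`, then decays it linearly to zero at `max_steps`. Second, the number of examples seen is `max_steps` times the effective batch size. A small sketch, assuming a single GPU:

```python
WARMUP_STEPS = 500
MAX_STEPS = 50_000
PEAK_LR = 1e-5

PER_DEVICE_BATCH = 16
GRAD_ACCUM = 1
N_DEVICES = 1  # assumption: a single GPU


def linear_warmup_decay_lr(step: int, warmup: int = WARMUP_STEPS,
                           total: int = MAX_STEPS, peak: float = PEAK_LR) -> float:
    """Linear warm-up from 0 to peak, then linear decay back to 0 at `total` steps."""
    if step < warmup:
        return peak * step / warmup
    return peak * max(0.0, (total - step) / (total - warmup))


for step in (0, 250, 500, 25_250, 50_000):
    print(step, linear_warmup_decay_lr(step))

effective_batch = PER_DEVICE_BATCH * GRAD_ACCUM * N_DEVICES
print(effective_batch * MAX_STEPS)  # 800000 examples seen if training is not stopped early
```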

In [None]:
trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    train_dataset=common_voice["train"],
    eval_dataset=common_voice["validation"],
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    tokenizer=processor.feature_extractor,
    callbacks=[early_stopping]
)
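
With `eval_steps=1000` and `early_stopping_patience=5`, training stops once the validation WER has failed to improve for five consecutive evaluations, i.e. 5,000 steps. The sketch below is a simplified, plain-Python model of that rule, not the callback's actual code. It assumes, as in this configuration, that the best metric is updated at every evaluation. The WER values are hypothetical.

```python
from typing import List, Optional


def early_stop_eval_index(wers: List[float], patience: int = 5,
                          threshold: float = 0.0) -> Optional[int]:
    """Index of the evaluation after which training would stop, or None.

    Lower WER is better (greater_is_better=False). An evaluation resets the
    counter only if it beats the best WER so far by more than `threshold`;
    otherwise the counter grows, and training stops once it reaches `patience`.
    """
    best: Optional[float] = None
    counter = 0
    for i, wer in enumerate(wers):
        improved = best is None or (wer < best and abs(wer - best) > threshold)
        counter = 0 if improved else counter + 1
        if best is None or wer < best:
            best = wer  # best metric is tracked separately from the counter
        if counter >= patience:
            return i
    return None


wers = [50.0, 40.0, 41.0, 42.0, 43.0, 44.0, 45.0]  # hypothetical WER every 1000 steps
stop = early_stop_eval_index(wers)
print(stop, (stop + 1) * 1000)  # 6 7000
```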

### Training

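As a rough reference, training on Hindi (about 8 hours of audio) took approximately 5-10 hours on Colab; the time for other languages, dataset sizes and training settings will differ.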

In [None]:
trainer.train()

In [None]:
model.eval()
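# prefix the metrics with "test" so they are not logged as validation ("eval") metrics
trainer.evaluate(common_voice["test"], metric_key_prefix="test")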

## Building a Demo

Now that we've fine-tuned our model, we can build a demo to show 
off its ASR capabilities! We'll make use of 🤗 Transformers 
`pipeline`, which will take care of the entire ASR pipeline, 
right from pre-processing the audio inputs to decoding the 
model predictions.

Running the example below will generate a Gradio demo where we 
can record speech through the microphone of our computer and input it to 
our fine-tuned Whisper model to transcribe the corresponding text:

In [None]:
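# the Trainer only writes checkpoint-* sub-directories during training, so save the
# final (best) model and the processor (tokenizer + feature extractor) to output_dir
# so that pipeline() can load everything from a single directory
trainer.save_model(training_args.output_dir)
processor.save_pretrained(training_args.output_dir)

pipe = pipeline(
    task="automatic-speech-recognition",
    model=training_args.output_dir,
    device=0,
)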

def transcribe(audio):
    text = pipe(audio)["text"]
    return text

iface = gr.Interface(
    fn=transcribe, 
    inputs=gr.Audio(source="microphone", type="filepath"), 
    outputs="text",
    title=FINETUNED_MODEL_NAME_HUMAN,
    description=f"Realtime demo for {LANG_LONG.capitalize()} speech recognition using a fine-tuned Whisper {MODEL_SIZE} model.",
)

iface.launch()