# Whisper Fine-tuning on LibriSpeech EN

This notebook is modified from the *Fine-Tune Whisper For Multilingual ASR with 🤗 Transformers* Colab created by Sanchit Gandhi, available at https://colab.research.google.com/github/sanchit-gandhi/notebooks/blob/main/fine_tune_whisper.ipynb. Our work consists only of the modifications to the original notebook.

# GPU

In [None]:
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:
  print('Not connected to a GPU')
else:
  print(gpu_info)

Wed Dec  7 13:14:43 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.32.03    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   69C    P0    30W /  70W |      0MiB / 15109MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

# Imports

In [None]:
# use datasets to download and prepare the training data, and transformers to load and train the Whisper model
# (version specifiers are quoted so that the shell does not treat ">" as an output redirect)
!pip install "datasets>=2.6.1"
!pip install git+https://github.com/huggingface/transformers
!pip install librosa
!pip install "evaluate>=0.3.0"
!pip install jiwer
!pip install gradio

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting git+https://github.com/huggingface/transformers
  Cloning https://github.com/huggingface/transformers to /tmp/pip-req-build-oqf_b7uw
  Running command git clone -q https://github.com/huggingface/transformers /tmp/pip-req-build-oqf_b7uw
  Installing build dependencies ... done
  Getting requirements to build wheel ... done
    Preparing wheel metadata ... done
Collecting tokenizers!=0.11.3,<0.14,>=0.11.1
  Downloading tokenizers-0.13.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (7.6 MB)
Building wheels for collected packages: transformers
  Building wheel for transformers (PEP 517) ... [?25l[?25hdone
  Created wheel for transformers: filename=transformers-4.26.0.dev0-py3-none-any.whl size=5931315 sha256=d400e2855e48b462862f4079d67ae6a372c4e9cfec8f8621b65892618469b6cd


In [None]:
# log in so that model checkpoints can be uploaded directly to the Hugging Face Hub

from huggingface_hub import notebook_login

notebook_login()

Token is valid.
Your token has been saved in your configured git credential helpers (store).
Your token has been saved to /root/.huggingface/token
Login successful


In [None]:
import pickle
from datasets import Audio, Dataset, Features

# Load Data

## Load WhisperFeatureExtractor

Load the feature extractor from the pre-trained checkpoint with its default values.
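Whisper's feature extractor first pads or truncates every clip to a fixed 30-second window and then converts it into a log-Mel spectrogram; for `openai/whisper-tiny` this yields `input_features` of shape (80, 3000). The sketch below illustrates only the pad-or-truncate step and the frame-count arithmetic in plain Python, assuming Whisper's default 16 kHz sampling rate and 160-sample hop length. The helper `pad_or_trim` is hypothetical and is not the library's implementation.

```python
from typing import List

# Assumed WhisperFeatureExtractor defaults: 16 kHz audio, 30 s chunks, hop of 160 samples
SAMPLING_RATE = 16_000
CHUNK_LENGTH_S = 30
HOP_LENGTH = 160
N_SAMPLES = SAMPLING_RATE * CHUNK_LENGTH_S  # 480,000 samples per input


def pad_or_trim(audio: List[float], length: int = N_SAMPLES) -> List[float]:
    """Hypothetical helper: truncate a waveform to `length` samples or right-pad it with zeros."""
    if len(audio) >= length:
        return audio[:length]
    return audio + [0.0] * (length - len(audio))


# Each 30 s input becomes N_SAMPLES // HOP_LENGTH log-Mel frames
N_FRAMES = N_SAMPLES // HOP_LENGTH  # 3,000

short_clip = [0.1] * (5 * SAMPLING_RATE)  # 5 s of audio -> zero-padded
long_clip = [0.1] * (45 * SAMPLING_RATE)  # 45 s of audio -> truncated
print(len(pad_or_trim(short_clip)), len(pad_or_trim(long_clip)), N_FRAMES)
# 480000 480000 3000
```

One consequence of the fixed window is that any audio beyond 30 seconds is simply discarded before the model sees it.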

In [None]:
from transformers import WhisperFeatureExtractor

feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-tiny")

## Load WhisperTokenizer

The Whisper model outputs a sequence of token ids.

The tokenizer maps each of these token ids to its corresponding text string.

We load the pre-trained tokenizer and use it for fine-tuning without any further modifications.
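As a toy illustration of this id-to-text mapping, the sketch below decodes ids with a hypothetical seven-entry vocabulary. The special-token strings mirror the prefix Whisper's tokenizer adds when `language` and `task` are set (`<|startoftranscript|><|en|><|transcribe|><|notimestamps|>`, with `<|endoftext|>` at the end), but the ids and word pieces are made up; the real tokenizer uses a much larger byte-level BPE vocabulary. Dropping the special tokens corresponds to what `skip_special_tokens=True` does in `compute_metrics` later in the notebook.

```python
from typing import Dict, List, Set

# Hypothetical toy vocabulary: the special-token strings follow Whisper's, the ids do not
VOCAB: Dict[int, str] = {
    0: "<|startoftranscript|>",
    1: "<|en|>",
    2: "<|transcribe|>",
    3: "<|notimestamps|>",
    4: "Hello",
    5: " world",
    6: "<|endoftext|>",
}
SPECIAL_IDS: Set[int] = {0, 1, 2, 3, 6}


def decode(ids: List[int], skip_special_tokens: bool = False) -> str:
    """Map each id to its string and concatenate, optionally dropping special tokens."""
    return "".join(
        VOCAB[i] for i in ids if not (skip_special_tokens and i in SPECIAL_IDS)
    )


ids = [0, 1, 2, 3, 4, 5, 6]
print(decode(ids))
print(decode(ids, skip_special_tokens=True))  # Hello world
```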

In [None]:
from transformers import WhisperTokenizer

tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny", language="English", task="transcribe")

## Combine To Create A WhisperProcessor

In [None]:
from transformers import WhisperProcessor

processor = WhisperProcessor.from_pretrained("openai/whisper-tiny", language="English", task="transcribe")

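## Load Dataset from the Hub

The `bgstud/libri-mini-proc-whisper` dataset already contains the model-ready `input_features` and tokenized `labels` alongside the raw `audio` and transcript `sentence`, so the notebook trains on it directly without a separate pre-processing step.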

In [None]:
from datasets import load_dataset
from datasets import DownloadConfig
librispeech = load_dataset("bgstud/libri-mini-proc-whisper",
                           download_config=DownloadConfig(delete_extracted=True))
librispeech

Downloading and preparing dataset librispeech_asr/clean (download: 3.06 GiB, generated: 5.79 GiB, post-processed: Unknown size, total: 8.85 GiB) to /root/.cache/huggingface/datasets/bgstud___parquet/bgstud--libri-mini-proc-whisper-2fc5e0e6ea4e6da5/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec...


Dataset parquet downloaded and prepared to /root/.cache/huggingface/datasets/bgstud___parquet/bgstud--libri-mini-proc-whisper-2fc5e0e6ea4e6da5/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec. Subsequent calls will reuse this data.


DatasetDict({
    validation: Dataset({
        features: ['audio', 'sentence', 'input_features', 'labels'],
        num_rows: 901
    })
    train: Dataset({
        features: ['audio', 'sentence', 'input_features', 'labels'],
        num_rows: 3000
    })
    test: Dataset({
        features: ['audio', 'sentence', 'input_features', 'labels'],
        num_rows: 874
    })
})

# Training and Evaluation

We'll follow these steps:

* Define a **data collator**: the data collator takes the pre-processed data and prepares PyTorch tensors ready for the model.
* **Evaluation metrics**: during evaluation, we evaluate the model using the word error rate (WER) metric. We need to define a `compute_metrics` function that handles this computation.
* **Load a pre-trained checkpoint**: load a pre-trained checkpoint and configure it correctly for training.
* Define the **training configuration**: the **Trainer** uses this to define the training schedule.

After fine-tuning the model, we evaluate it on test data to verify that it has learned to transcribe speech correctly.

## Define Data Collator

In [None]:
import torch
from dataclasses import dataclass
from typing import Any, Dict, List, Union

@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    processor: Any

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # split inputs and labels since they have to be of different lengths and need different padding methods
        # first treat the audio inputs by simply returning torch tensors
        input_features = [{"input_features": feature["input_features"]} for feature in features]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")

        # get the tokenized label sequences
        label_features = [{"input_ids": feature["labels"]} for feature in features]
        # pad the labels to max length
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")

        # replace padding with -100 to ignore loss correctly
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

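        # if the start-of-transcript token was prepended in the previous tokenization step,
        # cut it here, as the model prepends it again (decoder_start_token_id) when shifting the labels
        decoder_start_token_id = self.processor.tokenizer.convert_tokens_to_ids("<|startoftranscript|>")
        if (labels[:, 0] == decoder_start_token_id).all().cpu().item():
            labels = labels[:, 1:]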

        batch["labels"] = labels

        return batch

Initialise the data collator we just defined:

In [None]:
data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)
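The label handling in `__call__` can be traced on toy data. The sketch below re-implements just those steps in plain Python with hypothetical ids (`pad_id=0`, `start_id=7`): pad every sequence to the longest one, build an attention mask, replace padded positions with -100 (the label value that PyTorch's cross-entropy loss ignores by default), and drop the first column when every row starts with the start token. It is an illustration of the logic, not a drop-in replacement for the collator.

```python
from typing import List

IGNORE_INDEX = -100  # label value ignored by torch.nn.CrossEntropyLoss by default


def collate_labels(sequences: List[List[int]], pad_id: int, start_id: int) -> List[List[int]]:
    """Hypothetical plain-Python mirror of the collator's label steps."""
    max_len = max(len(seq) for seq in sequences)
    # 1. pad every sequence to the longest one and record which positions are real tokens
    input_ids = [seq + [pad_id] * (max_len - len(seq)) for seq in sequences]
    attention_mask = [[1] * len(seq) + [0] * (max_len - len(seq)) for seq in sequences]
    # 2. replace padded positions (mask != 1) with -100 so they do not contribute to the loss
    labels = [
        [tok if keep == 1 else IGNORE_INDEX for tok, keep in zip(ids, mask)]
        for ids, mask in zip(input_ids, attention_mask)
    ]
    # 3. drop the leading start token only if every row has it; the model prepends it again
    if all(row[0] == start_id for row in labels):
        labels = [row[1:] for row in labels]
    return labels


print(collate_labels([[7, 11, 12, 9], [7, 13, 9]], pad_id=0, start_id=7))
# [[11, 12, 9], [13, 9, -100]]
```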

## Evaluation Metrics

In [None]:
import evaluate

metric = evaluate.load("wer")

Define a function that takes the model predictions and returns the WER metric:

* It first replaces -100 with the `pad_token_id` in the `label_ids` (undoing the step we applied in the data collator to ignore padded tokens correctly in the loss).

* It then decodes the predicted and label ids to strings.

* Finally, it computes the WER between the predictions and the reference labels:

In [None]:
def compute_metrics(pred):
    pred_ids = pred.predictions
    label_ids = pred.label_ids

    # replace -100 with the pad_token_id
    label_ids[label_ids == -100] = tokenizer.pad_token_id

    # decode ids to text, dropping special tokens such as <|startoftranscript|> and <|endoftext|>
    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)

    wer = 100 * metric.compute(predictions=pred_str, references=label_str)

    return {"wer": wer}
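The WER returned by `metric.compute` is the word-level edit distance (substitutions + deletions + insertions) between each prediction and its reference, summed over all pairs and divided by the total number of reference words. A minimal pure-Python sketch of that computation, useful for checking small examples by hand, might look like this; `edit_distance` and `word_error_rate` are hypothetical helpers, and the `evaluate`/`jiwer` implementation may differ in details such as whitespace handling.

```python
from typing import List


def edit_distance(ref: List[str], hyp: List[str]) -> int:
    """Levenshtein distance between two word sequences (substitutions, deletions, insertions)."""
    # at the start of row i, prev[j] is the distance between ref[:i-1] and hyp[:j]
    prev = list(range(len(hyp) + 1))
    for i in range(1, len(ref) + 1):
        cur = [i] + [0] * len(hyp)
        for j in range(1, len(hyp) + 1):
            substitution_cost = 0 if ref[i - 1] == hyp[j - 1] else 1
            cur[j] = min(
                prev[j] + 1,                      # deletion of ref[i-1]
                cur[j - 1] + 1,                   # insertion of hyp[j-1]
                prev[j - 1] + substitution_cost,  # match or substitution
            )
        prev = cur
    return prev[-1]


def word_error_rate(predictions: List[str], references: List[str]) -> float:
    """Hypothetical helper: total word edits divided by total reference words."""
    errors = sum(edit_distance(r.split(), p.split()) for p, r in zip(predictions, references))
    total_words = sum(len(r.split()) for r in references)
    return errors / total_words


preds = ["the cat sit on mat", "hello world"]
refs = ["the cat sat on the mat", "hello world"]
# one substitution (sat -> sit) and one deletion (the) over 8 reference words
print(100 * word_error_rate(predictions=preds, references=refs))  # 25.0
```

As in `compute_metrics`, the ratio is multiplied by 100, so the values logged during training are percentages.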

## Load a Pre-Trained Checkpoint 

In [None]:
# load the pre-trained Whisper tiny checkpoint.
from transformers import WhisperForConditionalGeneration

model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")

Override the generation arguments: no tokens are forced as decoder outputs (see `forced_decoder_ids`) and no tokens are suppressed during generation (see `suppress_tokens`):

In [None]:
model.config.forced_decoder_ids = None
model.config.suppress_tokens = []

## Define Training Configuration

**Final step**: define all parameters related to training.

In [None]:
from transformers import Seq2SeqTrainingArguments

training_args = Seq2SeqTrainingArguments(
    output_dir="./whisper-ft-libri-en",  # change to a repo name of your choice
    per_device_train_batch_size=8, # paper: 256, orig: 16
    gradient_accumulation_steps=1,  # increase by 2x for every 2x decrease in batch size
    learning_rate=7.740176574997311e-05, # paper: 5e-4, orig: 1e-5
    warmup_steps=2, # paper: 2048, orig: 2
    max_steps=400, # paper: 1048576, orig: 100
    gradient_checkpointing=True,
    fp16=True,
    group_by_length=True,
    evaluation_strategy="steps",
    per_device_eval_batch_size=8,
    predict_with_generate=True,
    generation_max_length=225,
    save_steps=20, 
    eval_steps=5,
    logging_steps=1,
    report_to=["tensorboard"],
    load_best_model_at_end=True,
    metric_for_best_model="wer",
    greater_is_better=False,
    seed=42,
    push_to_hub=True,
    # optimizer settings below are based on the paper
    optim="adamw_torch",
    weight_decay=0.1146914,
    # adam_beta1=0.9,
    adam_beta2=0.98,
    adam_epsilon=1e-6,
    max_grad_norm=1.0
)

# Note: 'paper' refers to https://arxiv.org/abs/2212.04356 and 'orig' to the original
# version of this notebook. Our values are the best tuning we could achieve (within
# our resource limitations).
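The comment on `gradient_accumulation_steps` follows from how the effective (total) batch size is formed: per-device batch size × accumulation steps × number of devices. A small sketch, with hypothetical helper names, for choosing the accumulation steps that reach a target batch size on a single GPU:

```python
def effective_batch_size(per_device: int, accumulation_steps: int, num_devices: int = 1) -> int:
    """Number of examples that contribute to one optimizer update."""
    return per_device * accumulation_steps * num_devices


def accumulation_steps_for(target_batch: int, per_device: int, num_devices: int = 1) -> int:
    """Hypothetical helper: accumulation steps needed to reach `target_batch` exactly."""
    per_pass = per_device * num_devices
    if target_batch % per_pass != 0:
        raise ValueError(f"{target_batch} is not a multiple of {per_pass}")
    return target_batch // per_pass


print(effective_batch_size(per_device=8, accumulation_steps=1))  # 8, this notebook
print(accumulation_steps_for(16, per_device=8))   # 2, to match the original notebook's 16
print(accumulation_steps_for(256, per_device=8))  # 32, to match the paper's 256
```

Halving `per_device` doubles the required accumulation steps, which is the rule stated in the comment. The training log's `Total train batch size (w. parallel, distributed & accumulation) = 8` corresponds to the first case.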

**Note**: if you do not want to upload the model checkpoints to the Hub, set `push_to_hub=False`.

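We train and evaluate on small subsets: shard 0 of 30 for each split, which gives 100 training and 30 test examples. The training arguments are then forwarded to the `Seq2SeqTrainer` along with the model, datasets, data collator, and `compute_metrics` function.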

In [None]:
num_shards = 30
print(librispeech['test'].shard(num_shards=num_shards, index=0))
print(librispeech['train'].shard(num_shards=num_shards, index=0))

Dataset({
    features: ['audio', 'sentence', 'input_features', 'labels'],
    num_rows: 30
})
Dataset({
    features: ['audio', 'sentence', 'input_features', 'labels'],
    num_rows: 100
})
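These shard sizes also explain the `Num Epochs = 31` reported when training starts: with 100 training examples and a batch size of 8 (no accumulation, one GPU), an epoch is 13 optimizer steps, so `max_steps=400` needs 31 passes over the data, the last one partial. The arithmetic, as a sketch that assumes the Trainer rounds both quantities up in this way:

```python
import math

train_rows = 100      # size of the training shard printed above
eval_rows = 30        # size of the test shard printed above
train_batch_size = 8  # per_device_train_batch_size on a single GPU
eval_batch_size = 8   # per_device_eval_batch_size
grad_accum = 1        # gradient_accumulation_steps
max_steps = 400

batches_per_epoch = math.ceil(train_rows / train_batch_size)  # the last batch is partial
steps_per_epoch = max(batches_per_epoch // grad_accum, 1)
num_epochs = math.ceil(max_steps / steps_per_epoch)
eval_batches = math.ceil(eval_rows / eval_batch_size)

print(steps_per_epoch, num_epochs, eval_batches)  # 13 31 4
```

With `eval_steps=5`, each of the 80 evaluations over the 400 steps runs these 4 generation batches on the test shard.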


In [None]:
from transformers import Seq2SeqTrainer

trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    train_dataset=librispeech['train'].shard(num_shards=num_shards, index=0),
    eval_dataset=librispeech['test'].shard(num_shards=num_shards, index=0),
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    tokenizer=processor.feature_extractor,
)

Cloning https://huggingface.co/garnagar/whisper-ft-libri-en into local empty directory.


max_steps is given, it will override any value given in num_train_epochs
Using cuda_amp half precision backend


## Training

Depending on the GPU allocated to this Google Colab session, training can take approximately 5-10 hours. If you use this Colab directly to fine-tune a Whisper model, make sure that training isn't interrupted due to inactivity.

A simple workaround is to paste the following code into the browser console for this tab (right-click → Inspect → Console tab → paste the code).

```javascript
function ConnectButton(){
    console.log("Connect pushed"); 
    document.querySelector("#top-toolbar > colab-connect-button").shadowRoot.querySelector("#connect").click() 
}
setInterval(ConnectButton, 60000);
```

In [None]:
train_result = trainer.train()

The following columns in the training set don't have a corresponding argument in `WhisperForConditionalGeneration.forward` and have been ignored: sentence, audio. If sentence, audio are not expected by `WhisperForConditionalGeneration.forward`,  you can safely ignore this message.
***** Running training *****
  Num examples = 100
  Num Epochs = 31
  Instantaneous batch size per device = 8
  Total train batch size (w. parallel, distributed & accumulation) = 8
  Gradient Accumulation steps = 1
  Total optimization steps = 400
  Number of trainable parameters = 37760640
`use_cache = True` is incompatible with gradient checkpointing. Setting `use_cache = False`...


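| Step | Training Loss | Validation Loss | WER (%) |
|-----:|--------------:|----------------:|--------:|
| 5 | 2.1717 | 2.170907 | 98.046181 |
| 10 | 1.2371 | 1.271872 | 79.928952 |
| 15 | 0.7577 | 1.051027 | 35.346359 |
| 20 | 0.5325 | 0.947496 | 32.68206 |
| 25 | 0.5545 | 0.860749 | 30.373002 |
| 30 | 0.2957 | 0.805133 | 33.39254 |
| 35 | 0.1846 | 0.748742 | 30.195382 |
| 40 | 0.0748 | 0.688234 | 32.149201 |
| 45 | 0.0709 | 0.669224 | 31.261101 |
| 50 | 0.0908 | 0.646474 | 29.484902 |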


The following columns in the evaluation set don't have a corresponding argument in `WhisperForConditionalGeneration.forward` and have been ignored: sentence, audio. If sentence, audio are not expected by `WhisperForConditionalGeneration.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 30
  Batch size = 8

### Push results to Hub

In [None]:
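# the target repository is derived from output_dir (or hub_model_id); the positional argument is the commit message
trainer.push_to_hub(commit_message="whisper-ft-libri-en")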

Saving model checkpoint to ./whisper-ft-libri-en
Configuration saved in ./whisper-ft-libri-en/config.json
Model weights saved in ./whisper-ft-libri-en/pytorch_model.bin
Feature extractor saved in ./whisper-ft-libri-en/preprocessor_config.json
Several commits (2) will be pushed upstream.
The progress bars may be unreliable.


remote: Scanning LFS files for validity, may be slow...        
remote: LFS file scan complete.        
To https://huggingface.co/garnagar/whisper-ft-libri-en
   ec5352d..cd14559  main -> main

To https://huggingface.co/garnagar/whisper-ft-libri-en
   cd14559..929b19c  main -> main


'https://huggingface.co/garnagar/whisper-ft-libri-en/commit/cd14559e6154602ad5052de0af175c33672f52bc'