# Whisper Fine-Tuning on LibriSpeech with Optuna

This notebook is modified from the Colab notebook "Fine-Tune Whisper For Multilingual ASR with 🤗 Transformers" by Sanchit Gandhi, available at https://colab.research.google.com/github/sanchit-gandhi/notebooks/blob/main/fine_tune_whisper.ipynb. Our work consists only of the modifications to that notebook.

In [None]:
# datasets downloads and prepares the training data; transformers loads and trains the Whisper model
# quote version specifiers, otherwise the shell treats ">" as output redirection
!pip install "datasets>=2.6.1"
!pip install git+https://github.com/huggingface/transformers
!pip install librosa
!pip install "evaluate>=0.3.0"
!pip install jiwer
!pip install gradio
!pip install optuna


Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting git+https://github.com/huggingface/transformers
  Cloning https://github.com/huggingface/transformers to /tmp/pip-req-build-4eegb0gf
  Running command git clone -q https://github.com/huggingface/transformers /tmp/pip-req-build-4eegb0gf
  Installing build dependencies ... done
  Getting requirements to build wheel ... done
    Preparing wheel metadata ... done
Collecting tokenizers!=0.11.3,<0.14,>=0.11.1
  Downloading tokenizers-0.13.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (7.6 MB)
Building wheels for collected packages: transformers
  Building wheel for transformers (PEP 517) ... [?25l[?25hdone
  Created wheel for transformers: filename=transformers-4.26.0.dev0-py3-none-any.whl size=5881991 sha256=7046b10cf4f764a4ee2d0060f559fcd330fd6eccb00d48d96d79c8e19a604f10


In [None]:
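# import the relevant libraries for logging in to the Hugging Face Hub
import getpass

from huggingface_hub import HfFolder

# prompt for a Hub access token with write permission and save it locally,
# so that push_to_hub=True can authenticate when the trainer uploads checkpoints
token = getpass.getpass("Hugging Face access token: ")
HfFolder.save_token(token)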



In [None]:
import pickle
from datasets import Audio
from datasets import Dataset
from datasets import Features

In [None]:
from transformers import WhisperFeatureExtractor

feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-tiny.en")


In [None]:
from transformers import WhisperTokenizer

tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny.en", language="English", task="transcribe")


In [None]:
from transformers import WhisperProcessor

processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en", language="English", task="transcribe")
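Whisper's feature extractor pads or truncates every clip to 30 seconds of 16 kHz audio and computes an 80-bin log-Mel spectrogram with a hop of 160 samples (10 ms), so every `input_features` array has the same shape. A quick check of that arithmetic, assuming the `WhisperFeatureExtractor` defaults used by the `openai/whisper-tiny.en` checkpoint:

```python
# Default WhisperFeatureExtractor settings (assumed from the library defaults)
SAMPLING_RATE = 16_000   # samples per second
CHUNK_LENGTH = 30        # seconds every clip is padded or truncated to
HOP_LENGTH = 160         # samples between successive STFT frames
FEATURE_SIZE = 80        # number of log-Mel bins


def whisper_input_shape(sampling_rate=SAMPLING_RATE, chunk_length=CHUNK_LENGTH,
                        hop_length=HOP_LENGTH, feature_size=FEATURE_SIZE):
    """Return (mel_bins, frames) of one input_features array."""
    n_samples = chunk_length * sampling_rate   # 480,000 samples per 30 s clip
    n_frames = n_samples // hop_length         # one frame per hop
    return feature_size, n_frames


print(whisper_input_shape())       # (80, 3000)
print(HOP_LENGTH / SAMPLING_RATE)  # 0.01 seconds per frame
```

Because the audio side of every batch already has a fixed shape, the variable-length padding work in the data collator happens on the label side.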

In [None]:
from datasets import load_dataset
from datasets import DownloadConfig
# link = "DTU54DL/commo-test1k-whisper-proc"
ds_train = load_dataset(r"CristianaLazar/librispeech5k_train", download_config=DownloadConfig(delete_extracted=True))
ds_test = load_dataset(r"CristianaLazar/librispeech_augm_validation-tiny", download_config=DownloadConfig(delete_extracted=True))

Downloading and preparing dataset None/None to /root/.cache/huggingface/datasets/CristianaLazar___parquet/CristianaLazar--librispeech5k_train-d689137cca331e69/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec...



Generating train.360 split: 5000 examples

Dataset parquet downloaded and prepared to /root/.cache/huggingface/datasets/CristianaLazar___parquet/CristianaLazar--librispeech5k_train-d689137cca331e69/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec. Subsequent calls will reuse this data.



In [None]:
train = ds_train["train.360"].select(range(0, 500))
test = ds_test["validation"].select(range(0, 100))

In [None]:
import torch

from dataclasses import dataclass
from typing import Any, Dict, List, Union

@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    processor: Any

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # split inputs and labels since they have to be of different lengths and need different padding methods
        # first treat the audio inputs by simply returning torch tensors
        input_features = [{"input_features": feature["input_features"]} for feature in features]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")

        # get the tokenized label sequences
        label_features = [{"input_ids": feature["labels"]} for feature in features]
        # pad the labels to max length
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")

        # replace padding with -100 so the loss ignores those positions
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        # if the start-of-transcript token was prepended during tokenization,
        # cut it here: the model prepends decoder_start_token_id itself when shifting the labels
        sot_token_id = self.processor.tokenizer.convert_tokens_to_ids("<|startoftranscript|>")
        if (labels[:, 0] == sot_token_id).all().cpu().item():
            labels = labels[:, 1:]

        batch["labels"] = labels

        return batch
        
data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)
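The label handling in `DataCollatorSpeechSeq2SeqWithPadding` can be traced without a tokenizer or PyTorch. The sketch below repeats its three steps on plain Python lists: right-pad to the longest sequence, replace padded positions with -100 (the index PyTorch's cross-entropy loss ignores by default), and drop a leading start-of-transcript token shared by the whole batch. The token ids are hypothetical placeholders, not values read from the Whisper tokenizer.

```python
# Hypothetical token ids, for illustration only
SOT_ID = 50257       # stands in for <|startoftranscript|>
PAD_ID = 50256       # stands in for the tokenizer's pad token
IGNORE_INDEX = -100  # label value skipped by the loss


def collate_labels(label_seqs, sot_id=SOT_ID, pad_id=PAD_ID):
    max_len = max(len(seq) for seq in label_seqs)
    # right-pad every sequence and build the matching attention mask
    padded = [seq + [pad_id] * (max_len - len(seq)) for seq in label_seqs]
    masks = [[1] * len(seq) + [0] * (max_len - len(seq)) for seq in label_seqs]
    # replace padded positions (mask == 0) with -100 so the loss ignores them
    labels = [[tok if m == 1 else IGNORE_INDEX for tok, m in zip(row, mask)]
              for row, mask in zip(padded, masks)]
    # drop the leading start token only if every sequence in the batch has it
    if all(row[0] == sot_id for row in labels):
        labels = [row[1:] for row in labels]
    return labels


batch = [[SOT_ID, 11, 12, 13], [SOT_ID, 21]]
print(collate_labels(batch))  # [[11, 12, 13], [21, -100, -100]]
```

Masking through the attention mask, instead of comparing ids against the pad id, keeps genuine end-of-text tokens in the labels even when the pad token shares their id, which we understand to be the case for Whisper's tokenizer.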

In [None]:
import evaluate

metric = evaluate.load("wer")

def compute_metrics(pred):
    pred_ids = pred.predictions
    label_ids = pred.label_ids

    # replace -100 with the pad_token_id
    label_ids[label_ids == -100] = tokenizer.pad_token_id

    # decode token ids to text, dropping special tokens such as <|endoftext|>
    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)

    wer = 100 * metric.compute(predictions=pred_str, references=label_str)

    return {"wer": wer}
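`evaluate.load("wer")` scores the whole evaluation set at once: it sums word-level substitutions, deletions and insertions over all utterances and divides by the total number of reference words, and `compute_metrics` multiplies that by 100. A dependency-free sketch of that corpus-level computation (our own re-implementation for illustration, not the `evaluate`/`jiwer` code):

```python
def word_edit_distance(ref_words, hyp_words):
    """Minimum substitutions + deletions + insertions turning ref into hyp."""
    prev = list(range(len(hyp_words) + 1))
    for i, r in enumerate(ref_words, start=1):
        curr = [i] + [0] * len(hyp_words)
        for j, h in enumerate(hyp_words, start=1):
            curr[j] = min(prev[j] + 1,             # deletion
                          curr[j - 1] + 1,         # insertion
                          prev[j - 1] + (r != h))  # substitution or match
        prev = curr
    return prev[-1]


def corpus_wer(predictions, references):
    """Total word errors divided by total reference words."""
    errors = sum(word_edit_distance(ref.split(), hyp.split())
                 for hyp, ref in zip(predictions, references))
    total_words = sum(len(ref.split()) for ref in references)
    return errors / total_words


preds = ["the cat sat on mat", "hello world"]
refs = ["the cat sat on the mat", "hello word"]
print(100 * corpus_wer(preds, refs))  # 25.0 (1 deletion + 1 substitution over 8 words)
```

Because `str.split` compares words exactly, differences in casing or punctuation count as errors; the notebook's `compute_metrics` applies no text normalization either.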

In [None]:
def optuna_hp_space(trial):

    return {

        "learning_rate": trial.suggest_float("learning_rate", 1e-5, 1e-4, log=True),
        "weight_decay": trial.suggest_float("weight_decay", 0.1, 0.2, log=True),

        "per_device_train_batch_size": trial.suggest_categorical("per_device_train_batch_size", [8]),

    }
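`log=True` tells Optuna to treat a range as log-scaled, which suits a learning rate spanning a factor of ten (1e-5 to 1e-4): ranges with equal ratios, such as [1e-5, 2e-5] and [5e-5, 1e-4], are equally likely. Note also that `per_device_train_batch_size` has a single choice, so it is fixed at 8 in every trial. Optuna's default sampler (TPE) is adaptive, so the random log-uniform draw below is only a stand-in that shows what the log scale means, not how Optuna picks trials.

```python
import math
import random


def sample_log_uniform(low, high, rng):
    """Draw a value whose logarithm is uniform between log(low) and log(high)."""
    return math.exp(rng.uniform(math.log(low), math.log(high)))


rng = random.Random(0)  # fixed seed so the draws are reproducible
draws = [sample_log_uniform(1e-5, 1e-4, rng) for _ in range(10_000)]

# With log scaling, about half the draws fall below the geometric midpoint
# sqrt(1e-5 * 1e-4), roughly 3.16e-5, rather than below the arithmetic midpoint 5.5e-5.
geometric_mid = math.sqrt(1e-5 * 1e-4)
frac_below = sum(d < geometric_mid for d in draws) / len(draws)
print(frac_below)  # close to 0.5
```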

In [None]:
# load the pre-trained Whisper tiny.en checkpoint; model_init runs once per trial, so every trial starts from the same weights
from transformers import WhisperForConditionalGeneration

def model_init(trial):
    model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
    model.config.forced_decoder_ids = None
    model.config.suppress_tokens = []

    return model

In [None]:
from transformers import Seq2SeqTrainingArguments

training_args = Seq2SeqTrainingArguments(
    output_dir="/whisper-tiny-libri_search",  # change to a repo name of your choice
    gradient_accumulation_steps=1,  # increase by 2x for every 2x decrease in batch size
    warmup_steps=2,
    max_steps=500,
    gradient_checkpointing=True,
    fp16=True,
    evaluation_strategy="steps",
    per_device_eval_batch_size=8,
    predict_with_generate=True,
    generation_max_length=225,
    save_steps=50,
    eval_steps=50,
    logging_steps=50,
    report_to=["tensorboard"],
    load_best_model_at_end=True,
    metric_for_best_model="wer",
    greater_is_better=False,
    push_to_hub=True,
)

PyTorch: setting up devices
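With `max_steps=500`, the 500-example training subset and the single batch size the search allows (8 per device, no gradient accumulation, assuming one GPU), each trial makes roughly eight passes over the data and evaluates ten times. This back-of-the-envelope arithmetic helps when reading the logs below, where training loss keeps falling while validation loss stops improving after a few hundred steps:

```python
import math

train_examples = 500   # train = ds_train["train.360"].select(range(0, 500))
per_device_batch = 8   # the only value in optuna_hp_space
grad_accum = 1         # gradient_accumulation_steps
num_devices = 1        # assumption: a single GPU, as on Colab
max_steps = 500
eval_steps = 50

effective_batch = per_device_batch * grad_accum * num_devices
steps_per_epoch = math.ceil(train_examples / effective_batch)  # last batch is partial
epochs = max_steps / steps_per_epoch
evaluations = max_steps // eval_steps

print(effective_batch, steps_per_epoch, round(epochs, 2), evaluations)  # 8 63 7.94 10
```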


In [None]:
from transformers import Seq2SeqTrainer

trainer = Seq2SeqTrainer(
    args=training_args,
    model=None,
    train_dataset=train,
    eval_dataset=test,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    tokenizer=processor.feature_extractor,
    model_init=model_init,
    
)

/whisper-tiny-libri_search is already a clone of https://huggingface.co/CristianaLazar/whisper-tiny-libri_search. Make sure you pull the latest changes with `repo.git_pull()`.
max_steps is given, it will override any value given in num_train_epochs
Using cuda_amp half precision backend
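Before launching the search, note how each trial is scored. When `compute_objective` is not given, the Trainer's default objective drops `eval_loss`, `epoch` and the speed metrics and sums whatever is left, which here is just `eval_wer`; the search direction therefore has to be `minimize`. The sketch below mirrors that default as we understand it (a re-implementation, not the library code), and the per-trial metric values are hypothetical. Passing `compute_objective=lambda metrics: metrics["eval_wer"]` to `hyperparameter_search` makes the choice explicit.

```python
def default_objective(metrics):
    """Sketch of the Trainer's default: eval_loss if nothing else remains, else the sum of the rest."""
    metrics = dict(metrics)
    loss = metrics.pop("eval_loss", None)
    metrics.pop("epoch", None)
    # drop speed metrics such as eval_runtime and eval_samples_per_second
    for key in [k for k in metrics if k.endswith(("_runtime", "_per_second"))]:
        metrics.pop(key)
    return loss if not metrics else sum(metrics.values())


# Hypothetical final evaluation metrics of three trials
trials = {
    "trial_0": {"eval_loss": 0.70, "eval_wer": 22.0, "epoch": 7.94, "eval_runtime": 30.0},
    "trial_1": {"eval_loss": 0.66, "eval_wer": 25.0, "epoch": 7.94, "eval_runtime": 31.0},
    "trial_2": {"eval_loss": 0.68, "eval_wer": 23.0, "epoch": 7.94, "eval_runtime": 29.5},
}
objectives = {name: default_objective(m) for name, m in trials.items()}

print(min(objectives, key=objectives.get))  # trial_0: lowest WER, what "minimize" picks
print(max(objectives, key=objectives.get))  # trial_1: highest WER, what "maximize" would pick
```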


In [None]:
best_trial = trainer.hyperparameter_search(
    direction="minimize",  # the default objective here is eval_wer, and lower WER is better
    backend="optuna",
    hp_space=optuna_hp_space,
    n_trials=4,
)

`use_cache = True` is incompatible with gradient checkpointing. Setting `use_cache = False`...


| Step | Training Loss | Validation Loss | WER (%) |
|---:|---:|---:|---:|
| 50 | 0.9048 | 0.747488 | 28.880266 |
| 100 | 0.2567 | 0.685584 | 25.332594 |
| 150 | 0.1284 | 0.689597 | 23.946785 |
| 200 | 0.0582 | 0.682746 | 24.223947 |
| 250 | 0.0301 | 0.720203 | 24.390244 |
| 300 | 0.0167 | 0.705789 | 24.168514 |
| 350 | 0.0098 | 0.707749 | 22.89357 |
| 400 | 0.0056 | 0.697938 | 23.004435 |
| 450 | 0.0046 | 0.705192 | 22.67184 |
| 500 | 0.0038 | 0.705268 | 22.67184 |



| Step | Training Loss | Validation Loss | WER (%) |
|---:|---:|---:|---:|
| 50 | 0.9272 | 0.790834 | 27.549889 |
| 100 | 0.3103 | 0.645642 | 24.05765 |
| 150 | 0.1622 | 0.638435 | 24.05765 |
| 200 | 0.0823 | 0.637445 | 23.392461 |
| 250 | 0.0443 | 0.635826 | 23.281596 |
| 300 | 0.0225 | 0.648954 | 24.501109 |
| 350 | 0.0163 | 0.657079 | 24.279379 |
| 400 | 0.0115 | 0.661503 | 24.778271 |
| 450 | 0.0095 | 0.665024 | 24.279379 |
| 500 | 0.0083 | 0.665801 | 24.611973 |



| Step | Training Loss | Validation Loss | WER (%) |
|---:|---:|---:|---:|
| 50 | 0.8687 | 0.734508 | 27.494457 |
| 100 | 0.2591 | 0.661372 | 25.388027 |
| 150 | 0.1315 | 0.672508 | 25.277162 |
| 200 | 0.0576 | 0.66365 | 24.279379 |
| 250 | 0.0291 | 0.66749 | 23.946785 |
| 300 | 0.0162 | 0.672271 | 23.115299 |
| 350 | 0.0105 | 0.681922 | 23.226164 |
| 400 | 0.0065 | 0.685979 | 22.89357 |
| 450 | 0.0056 | 0.689046 | 22.89357 |
| 500 | 0.0048 | 0.690212 | 23.004435 |



| Step | Training Loss | Validation Loss | WER (%) |
|---:|---:|---:|---:|
| 50 | 0.9363 | 0.801701 | 28.991131 |
| 100 | 0.3215 | 0.643573 | 23.946785 |
| 150 | 0.1688 | 0.636789 | 23.558758 |
| 200 | 0.0884 | 0.634394 | 24.002217 |
| 250 | 0.0474 | 0.635364 | 23.614191 |
| 300 | 0.0249 | 0.646138 | 24.390244 |
| 350 | 0.0174 | 0.655878 | 23.83592 |
| 400 | 0.0125 | 0.66026 | 24.889135 |
| 450 | 0.0104 | 0.663743 | 24.889135 |
| 500 | 0.009 | 0.664997 | 25.110865 |
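With `load_best_model_at_end=True` and `metric_for_best_model="wer"`, each trial keeps the checkpoint with the lowest validation WER rather than the last one. Picking that checkpoint from a log table is a simple minimum; the rows below are copied from the first log table above (step, validation loss, WER).

```python
# (step, validation loss, WER %) rows from the first log table above
rows = [
    (50, 0.747488, 28.880266), (100, 0.685584, 25.332594),
    (150, 0.689597, 23.946785), (200, 0.682746, 24.223947),
    (250, 0.720203, 24.390244), (300, 0.705789, 24.168514),
    (350, 0.707749, 22.89357), (400, 0.697938, 23.004435),
    (450, 0.705192, 22.67184), (500, 0.705268, 22.67184),
]

# min() returns the first row with the lowest value, so ties go to the earlier step
best_by_wer = min(rows, key=lambda r: r[2])
best_by_loss = min(rows, key=lambda r: r[1])

print(best_by_wer)   # (450, 0.705192, 22.67184)
print(best_by_loss)  # (200, 0.682746, 24.223947)
```

The two criteria disagree: validation loss bottoms out at step 200 while WER keeps improving until step 450, so the choice of `metric_for_best_model` changes which checkpoint is kept.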

