# Whisper Fine-Tuning

In this notebook, we will fine-tune an OpenAI Whisper model on Irish-accented speech data.

Heavily based on: https://huggingface.co/blog/fine-tune-whisper

In [1]:
import os
from transformers import WhisperFeatureExtractor
from transformers import WhisperTokenizer
from transformers import WhisperProcessor
from datasets import (
    Audio,
    load_dataset,
    load_dataset_builder,
    load_from_disk,
    concatenate_datasets,
    DatasetDict,
)
from huggingface_hub import notebook_login

In [2]:
# Authenticate with the Hugging Face Hub; the dataset download below passes use_auth_token=True.
notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [3]:
# The full English Common Voice 11 download is approx 80Gb, so it is skipped
# whenever the processed Irish-accent subset is already present on disk.

if os.path.exists('irish-accent-common-voice-test'):
    print('Data was already processed!')
else:
    common_voice = load_dataset(
        "mozilla-foundation/common_voice_11_0",
        "en",
        use_auth_token=True,
        cache_dir='/mnt/vol_c/huggingface/',
    )

    print(common_voice)

Data was already processed!


## Filter to Irish Accent Data

Here we are going to filter the large Common Voice dataset down to just those rows recorded by a speaker with an Irish accent. The overall aim is to produce a model that performs better on Irish accents.

In [4]:
def _is_irish_accent(example):
    """Return True when the row's accent label mentions 'irish' or 'ireland'.

    The comparison is case-insensitive. A missing (None) accent is treated as
    an empty label, so such rows are simply filtered out instead of raising.
    """
    accent = (example['accent'] or '').lower()
    return 'irish' in accent or 'ireland' in accent


if not os.path.exists('irish-accent-common-voice-test'):
    # Apply the same predicate to every split so the three subsets are
    # guaranteed to be selected by identical criteria.
    filtered_train = common_voice['train'].filter(_is_irish_accent)
    filtered_val = common_voice['validation'].filter(_is_irish_accent)
    filtered_test = common_voice['test'].filter(_is_irish_accent)
    print(len(filtered_train), len(filtered_val), len(filtered_test))
else:
    print('Data was already processed!')

Data was already processed!


In [5]:
if not os.path.exists('irish-accent-common-voice-test'):

    # Merge the filtered splits into a single dataset so that train vs test
    # can be rebalanced with a fresh split afterwards.
    merged_dataset = concatenate_datasets([filtered_train, filtered_test, filtered_val])
    # A bare expression nested inside an `if` block is not auto-displayed by
    # Jupyter, so print the summary explicitly.
    print(merged_dataset)

In [6]:
if not os.path.exists('irish-accent-common-voice-test'):

    # Hold out 20% of the merged data for evaluation. The seed is fixed so the
    # split (and therefore the WER measured on it) is reproducible.
    split_dataset = merged_dataset.train_test_split(test_size=0.2, seed=42)
    # A bare expression nested inside an `if` block is not auto-displayed by
    # Jupyter, so print the summary explicitly.
    print(split_dataset)

In [7]:
# Reuse the splits saved by a previous run when they exist; otherwise persist
# the freshly created ones so later runs can skip the download and filtering.
if os.path.exists('irish-accent-common-voice-test'):
    split_dataset = DatasetDict(
        {
            'train': load_from_disk('irish-accent-common-voice-train'),
            'test': load_from_disk('irish-accent-common-voice-test'),
        }
    )
else:
    split_dataset['train'].save_to_disk('irish-accent-common-voice-train')
    split_dataset['test'].save_to_disk('irish-accent-common-voice-test')

In [8]:
# Metadata columns that the preprocessing step below does not read.
_unused_columns = [
    "accent",
    "age",
    "client_id",
    "down_votes",
    "gender",
    "locale",
    "path",
    "segment",
    "up_votes",
]

# Drop the metadata, then cast the audio column to a 16 kHz Audio feature.
split_dataset = split_dataset.remove_columns(_unused_columns).cast_column(
    "audio", Audio(sampling_rate=16000)
)

In [9]:
# All three preprocessing components are loaded from the same pretrained
# checkpoint; the tokenizer and processor are configured for English transcription.
_WHISPER_CHECKPOINT = "openai/whisper-small"
_TRANSCRIBE_ENGLISH = {"language": "English", "task": "transcribe"}

feature_extractor = WhisperFeatureExtractor.from_pretrained(_WHISPER_CHECKPOINT)
tokenizer = WhisperTokenizer.from_pretrained(_WHISPER_CHECKPOINT, **_TRANSCRIBE_ENGLISH)
processor = WhisperProcessor.from_pretrained(_WHISPER_CHECKPOINT, **_TRANSCRIBE_ENGLISH)

In [10]:
def prepare_dataset(batch):
    """Turn one raw example into model inputs and target label ids.

    Reads ``batch["audio"]`` (its ``array`` and ``sampling_rate``) and
    ``batch["sentence"]``, then adds two keys to the same mapping:

    * ``input_features`` - first entry of the features produced by the
      module-level ``feature_extractor`` for the audio array.
    * ``labels`` - token ids produced by the module-level ``tokenizer`` for
      the sentence.

    Returns the (mutated) ``batch``.
    """
    # Read the audio entry once; presumably this access is what triggers
    # decoding/resampling of the clip - TODO confirm against the datasets version.
    clip = batch["audio"]

    # Log-Mel input features for the waveform.
    extracted = feature_extractor(clip["array"], sampling_rate=clip["sampling_rate"])
    batch["input_features"] = extracted.input_features[0]

    # Target transcription encoded as label ids.
    encoded = tokenizer(batch["sentence"])
    batch["labels"] = encoded.input_ids

    return batch

In [11]:
# Preprocess every example with 4 worker processes; the original columns of
# the train split are removed from the mapped output.
split_dataset = split_dataset.map(
    prepare_dataset,
    remove_columns=split_dataset.column_names["train"],
    num_proc=4,
)

Loading cached processed dataset at /home/ubuntu/irish-accent-common-voice-train/cache-566430680320d43a_*_of_00004.arrow
Loading cached processed dataset at /home/ubuntu/irish-accent-common-voice-test/cache-49b2ec5d50089605_*_of_00004.arrow


In [12]:
import torch

from dataclasses import dataclass
from typing import Any, Dict, List, Union


@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    """Collate speech examples into one padded batch of tensors.

    Audio features and label ids are padded separately (by the processor's
    feature extractor and tokenizer respectively), because the two have
    different lengths and need different padding methods.
    """

    # Object exposing ``feature_extractor`` and ``tokenizer``, each with a ``pad`` method.
    processor: Any

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        """Build a batch from ``features``.

        Each feature must provide ``"input_features"`` and ``"labels"``.
        Returns the padded audio batch with an added ``"labels"`` tensor in
        which padded positions are set to -100.
        """
        tokenizer = self.processor.tokenizer

        # Audio side: hand the features to the feature extractor, which returns torch tensors.
        audio_inputs = [{"input_features": example["input_features"]} for example in features]
        batch = self.processor.feature_extractor.pad(audio_inputs, return_tensors="pt")

        # Label side: pad the tokenized transcriptions to a common length.
        token_inputs = [{"input_ids": example["labels"]} for example in features]
        padded_labels = tokenizer.pad(token_inputs, return_tensors="pt")

        # Positions outside the attention mask are padding; mark them with
        # -100 so the loss ignores them.
        is_padding = padded_labels.attention_mask.ne(1)
        labels = padded_labels["input_ids"].masked_fill(is_padding, -100)

        # When every sequence starts with the BOS token, strip it here since
        # it is appended again later.
        starts_with_bos = labels[:, 0] == tokenizer.bos_token_id
        if starts_with_bos.all().cpu().item():
            labels = labels[:, 1:]

        batch["labels"] = labels
        return batch

In [13]:
# Collator instance handed to the trainer; it pads batches using the processor loaded above.
data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)

In [14]:
import evaluate

# Word error rate (WER) metric loaded via the `evaluate` library; used by compute_metrics.
metric = evaluate.load("wer")

In [15]:
def compute_metrics(pred):
    """Compute the word error rate, as a percentage, for one evaluation pass.

    Args:
        pred: object exposing ``predictions`` (generated token ids) and
            ``label_ids`` (reference token ids in which -100 marks ignored
            positions). ``label_ids`` is assumed to be a NumPy array, as
            supplied by the trainer - TODO confirm for other callers.

    Returns:
        A dict with the single key ``"wer"``.
    """
    pred_ids = pred.predictions
    # Work on a copy so the caller's label array keeps its -100 markers
    # instead of being overwritten as a side effect of computing the metric.
    label_ids = pred.label_ids.copy()

    # replace -100 with the pad_token_id so the ids can be decoded
    label_ids[label_ids == -100] = tokenizer.pad_token_id

    # we do not want to group tokens when computing the metrics
    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)

    # the metric's value is scaled by 100 to report a percentage
    wer = 100 * metric.compute(predictions=pred_str, references=label_str)

    return {"wer": wer}

In [16]:
from transformers import WhisperForConditionalGeneration

# Load the pretrained "openai/whisper-small" checkpoint - the same one the
# feature extractor, tokenizer and processor above were loaded from.
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small")

In [17]:
from transformers import Seq2SeqTrainingArguments

training_args = Seq2SeqTrainingArguments(
    # Output location: checkpoints are written to this directory, and nothing
    # is uploaded to the Hub because push_to_hub is disabled.
    output_dir="/mnt/vol_c/whisper-small-en-irish-accent-3",
    push_to_hub=False,
    # Optimisation schedule.
    per_device_train_batch_size=16,
    gradient_accumulation_steps=1,  # double this for every halving of the batch size
    learning_rate=5e-6,
    warmup_steps=500,
    max_steps=4000,
    gradient_checkpointing=True,
    fp16=False,
    # Evaluation: every 400 steps, scoring generated transcriptions.
    evaluation_strategy="steps",
    eval_steps=400,
    per_device_eval_batch_size=8,
    predict_with_generate=True,
    generation_max_length=225,
    # Checkpointing and logging: checkpoints are saved on the same 400-step
    # cadence as evaluation, and the best model is selected by lowest WER.
    save_steps=400,
    logging_steps=25,
    load_best_model_at_end=True,
    metric_for_best_model="wer",
    greater_is_better=False,
)

In [18]:
from transformers import Seq2SeqTrainer

# Wire the model, data, collator and metric together for training.
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    train_dataset=split_dataset["train"],
    eval_dataset=split_dataset["test"],
    # NOTE(review): the feature extractor is passed in the tokenizer slot,
    # presumably so its preprocessor config is saved with each checkpoint - confirm.
    tokenizer=processor.feature_extractor,
)

In [19]:
# Run fine-tuning with the arguments configured above (up to max_steps=4000).
# The run recorded below was interrupted manually after the step-2400 evaluation.
trainer.train()

`use_cache = True` is incompatible with gradient checkpointing. Setting `use_cache = False`...


Step,Training Loss,Validation Loss,Wer
400,0.2157,0.206165,11.315683
800,0.1232,0.183969,10.663917
1200,0.1176,0.18061,10.840667
1600,0.0494,0.191588,10.792797
2000,0.0466,0.190812,10.730198
2400,0.027,0.207609,11.061605


KeyboardInterrupt: 

In [21]:
# Install the `zip` command-line tool, used in the next cell to archive a checkpoint.
!sudo apt-get install -y zip

Reading package lists... Done
Building dependency tree       
Reading state information... Done
The following additional packages will be installed:
  unzip
The following NEW packages will be installed:
  unzip zip
0 upgraded, 2 newly installed, 0 to remove and 233 not upgraded.
Need to get 335 kB of archives.
After this operation, 1,231 kB of additional disk space will be used.
Get:1 http://nova.clouds.archive.ubuntu.com/ubuntu focal-updates/main amd64 unzip amd64 6.0-25ubuntu1.1 [168 kB]
Get:2 http://nova.clouds.archive.ubuntu.com/ubuntu focal/main amd64 zip amd64 3.0-11build1 [167 kB]
Fetched 335 kB in 0s (1,600 kB/s)
Selecting previously unselected package unzip.
(Reading database ... 88617 files and directories currently installed.)
Preparing to unpack .../unzip_6.0-25ubuntu1.1_amd64.deb ...
Unpacking unzip (6.0-25ubuntu1.1) ...
Selecting previously unselected package zip.
Preparing to unpack .../zip_3.0-11build1_amd64.deb ...
Unpacking zip (3.0-11build1) ...
Setting up unzip (6.0

In [23]:
# Archive checkpoint-800, which had the lowest validation WER (10.66) in the training log above.
# Note: zip is given an absolute path, so entries are stored under mnt/vol_c/... inside the archive.
!zip -r /mnt/vol_c/whisper-small-en-irish-accent-3-best.zip /mnt/vol_c/whisper-small-en-irish-accent-3/checkpoint-800/

  adding: mnt/vol_c/whisper-small-en-irish-accent-3/checkpoint-800/ (stored 0%)
  adding: mnt/vol_c/whisper-small-en-irish-accent-3/checkpoint-800/training_args.bin (deflated 48%)
  adding: mnt/vol_c/whisper-small-en-irish-accent-3/checkpoint-800/generation_config.json (deflated 71%)
  adding: mnt/vol_c/whisper-small-en-irish-accent-3/checkpoint-800/trainer_state.json (deflated 80%)
  adding: mnt/vol_c/whisper-small-en-irish-accent-3/checkpoint-800/scheduler.pt (deflated 48%)
  adding: mnt/vol_c/whisper-small-en-irish-accent-3/checkpoint-800/pytorch_model.bin (deflated 7%)
  adding: mnt/vol_c/whisper-small-en-irish-accent-3/checkpoint-800/preprocessor_config.json (deflated 42%)
  adding: mnt/vol_c/whisper-small-en-irish-accent-3/checkpoint-800/config.json (deflated 63%)
  adding: mnt/vol_c/whisper-small-en-irish-accent-3/checkpoint-800/optimizer.pt (deflated 7%)
  adding: mnt/vol_c/whisper-small-en-irish-accent-3/checkpoint-800/rng_state.pth (deflated 28%)
