# Fine-Tune Whisper for Speech-to-Text Using Custom Datasets

## Introduction

In this blog post, we show how to fine-tune OpenAI's open-source Whisper speech-to-text model on your own dataset. Adapting a pre-trained checkpoint to the language, accents, and vocabulary of your target domain can lower its word error rate (WER) on that domain. The running example is Swahili: we train on a custom corpus of recordings and transcripts and evaluate on the Swahili test split of Mozilla Common Voice 11.0. We walk through preparing the data, configuring and running training with the Hugging Face `Seq2SeqTrainer`, and measuring WER. The result is a Whisper checkpoint fine-tuned for your application.

## Import Libraries

First, import the libraries used throughout the tutorial:

- Hugging Face `datasets` and `transformers`, for data loading and the Whisper model.
- pandas and NumPy, for data manipulation.
- PyTorch and torchaudio, for audio loading (imported in later cells).

In [1]:
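import os
# Set the working directory to the project root so relative paths such as
# datasets/metadata.csv resolve; "." keeps the current directory unchanged
os.chdir(".")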

In [2]:
import numpy as np
import pandas as pd

from datasets import load_dataset
from transformers import WhisperProcessor
from datasets import Audio



## Load Dataset

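Next, load the custom dataset, which pairs audio recordings with their transcriptions. Here the dataset is indexed by a CSV file, `datasets/metadata.csv`, with two columns:

- `file_name`: the path to each recording.
- `transcription`: the text of that recording.

The sample rows below are Swahili recordings from a `kenspeech` directory. Audio can be in any format `torchaudio` can load (WAV here). When loaded, it is resampled to the 16 kHz sampling rate that Whisper expects.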
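
If your recordings and transcripts are not yet in this form, a small script can build the metadata file. The following is a minimal sketch. It assumes you already have (audio path, transcription) pairs, and `write_metadata` is a hypothetical helper. It writes only the two columns the tutorial reads, with no index column. That avoids the `Unnamed: 0` column in the output below, which appears when a pandas DataFrame is saved with its index and read back.

```python
import csv
from pathlib import Path
from typing import Iterable, Tuple


def write_metadata(pairs: Iterable[Tuple[str, str]], out_path: str) -> int:
    """Write (audio_path, transcription) pairs to a metadata CSV.

    Hypothetical helper: produces exactly the `file_name` and `transcription`
    columns, with no index column. Returns the number of rows written.
    """
    out = Path(out_path)
    out.parent.mkdir(parents=True, exist_ok=True)
    count = 0
    with out.open("w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=["file_name", "transcription"])
        writer.writeheader()
        for audio_path, text in pairs:
            # Strip surrounding whitespace so stray newlines don't leak into the labels
            writer.writerow({"file_name": audio_path, "transcription": text.strip()})
            count += 1
    return count
```

The resulting file can be read with `pd.read_csv` exactly as in the next cell, yielding only the `file_name` and `transcription` columns.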

In [3]:
data_path = "datasets/metadata.csv"
combined_dataset = pd.read_csv(data_path)
combined_dataset[:3]

| Unnamed: 0 | file_name | transcription |
|---|---|---|
| 0 | datasets/kenspeech/audios/male/speaker_17/twee... | mungu ndiye muumba na mlinzi wa ulimwengu unae... |
| 1 | datasets/kenspeech/audios/male/speaker_17/twee... | hongera ni heshima kubwa kufanya kazi na wewe ... |
| 2 | datasets/kenspeech/audios/male/speaker_17/twee... | asilimia kubwa ya mwanamke anapenda mwanamume ... |


In [4]:
import torch
from transformers import WhisperFeatureExtractor
from transformers import WhisperTokenizer
from transformers import WhisperProcessor
from transformers.models.whisper.english_normalizer import BasicTextNormalizer

# Converts raw 16 kHz audio into 80-bin log-Mel spectrograms
feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-medium")
# Maps text to token ids, prefixed with the Swahili transcription task tokens
tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-medium", language="Swahili", task="transcribe")
# Bundles the feature extractor and tokenizer in a single object
processor = WhisperProcessor.from_pretrained("openai/whisper-medium", language="Swahili", task="transcribe")
# Lowercases and strips punctuation for WER computation
normalizer = BasicTextNormalizer()

In [5]:
import torchaudio
import torchaudio.transforms as at

def load_wave(wave_path, sample_rate: int) -> torch.Tensor:
    # Load audio as float32 samples with shape (channels, samples)
    waveform, sr = torchaudio.load(wave_path, normalize=True)
    if sample_rate != sr:
        # Resample to the rate Whisper expects (16 kHz)
        waveform = at.Resample(sr, sample_rate)(waveform)
    return waveform

def filter_long_samples(dataset, max_length: float = 30.0, sample_rate: int = 16000):
    # Whisper's feature extractor keeps only the first 30 s of audio, so collect
    # the row positions of clips longer than max_length seconds
    indices_of_long_samples = []
    for idx, audio_path in enumerate(dataset.file_name):
        audio = load_wave(audio_path, sample_rate=sample_rate)
        audio_length = audio.size(1) / sample_rate
        if audio_length > max_length:
            indices_of_long_samples.append(idx)
    return indices_of_long_samples
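
`filter_long_samples` decodes every file just to measure its length, which can be slow on a large corpus. For uncompressed WAV files, such as the `.wav` clips in this dataset, the duration can instead be read from the file header with Python's standard `wave` module: the frame count divided by the frame rate. This is a minimal sketch that assumes PCM WAV input; `wave` cannot read MP3 or compressed WAV. `wav_duration_seconds` and `long_wav_indices` are hypothetical names.

```python
import wave
from typing import Iterable, List


def wav_duration_seconds(path: str) -> float:
    """Return a PCM WAV file's duration in seconds, reading only its header."""
    with wave.open(path, "rb") as wav:
        return wav.getnframes() / wav.getframerate()


def long_wav_indices(paths: Iterable[str], max_length: float = 30.0) -> List[int]:
    """Positions of files longer than max_length seconds (same contract as filter_long_samples)."""
    return [idx for idx, path in enumerate(paths) if wav_duration_seconds(path) > max_length]
```

Used as `combined_dataset.drop(long_wav_indices(combined_dataset.file_name))`, it selects the same rows as `filter_long_samples`, as long as every file is a readable PCM WAV.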

In [6]:
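# Drop clips longer than 30 s and renumber rows 0..n-1 (drop=True avoids adding an "index" column)
indices_of_long_samples = filter_long_samples(combined_dataset)
combined_dataset = combined_dataset.drop(indices_of_long_samples).reset_index(drop=True)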

In [7]:
class PrepareSpeechDataset(torch.utils.data.Dataset):
    def __init__(self, audio_info_csv, sample_rate, processor, feature_extractor) -> None:
        super().__init__()
        self.audio_info_csv = audio_info_csv
        self.sample_rate = sample_rate
        self.processor = processor
        self.feature_extractor = feature_extractor
        
    def __len__(self):
        return len(self.audio_info_csv)
    
    def __getitem__(self, idx):
        audio_path = self.audio_info_csv.file_name[idx]
        transcription = self.audio_info_csv.transcription[idx]

        # Clips longer than 30 s were already dropped by filter_long_samples
        audio = self._load_wave(audio_path, sample_rate=self.sample_rate)

        return self.prepare_dataset(
            {
                'audio': {'path': audio_path,
                          'array': audio,
                          'sampling_rate': self.sample_rate,
                          'input_length': audio.size(1) / self.sample_rate,
                          },
                'sentence': transcription,
            }
        )

    def _load_wave(self, wave_path, sample_rate: int) -> torch.Tensor:
        waveform, sr = torchaudio.load(wave_path, normalize=True)
        if sample_rate != sr:
            # Resample to the target rate (16 kHz for Whisper)
            waveform = at.Resample(sr, sample_rate)(waveform)
        # Down-mix multi-channel audio to mono, keeping shape (1, samples);
        # flattening a stereo tensor would concatenate the two channels
        return waveform.mean(dim=0, keepdim=True)

    def prepare_dataset(self, batch):
        audio = batch["audio"]

        # Compute log-Mel input features from the mono waveform
        batch["input_features"] = self.feature_extractor(
            audio["array"].flatten().numpy(), sampling_rate=audio["sampling_rate"]
        ).input_features[0]

        # Encode the target text to label ids with the Swahili tokenizer defined above
        batch["labels"] = tokenizer(batch["sentence"]).input_ids
        return batch
    
    def get_output_dictionary(self, dic, split="train"):
        return {
            f"{split}" : dic
        }
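
Why 30 seconds? `WhisperFeatureExtractor` pads or truncates every clip to 30 s of 16 kHz audio (480,000 samples). It then computes an 80-bin log-Mel spectrogram with a hop of 160 samples, so each example becomes an 80 × 3000 array regardless of clip length. A clip longer than 30 s would be cut while its transcription still covers the whole recording, which is why long samples were dropped above. The NumPy sketch below reproduces only the pad-or-trim step and the frame arithmetic. `pad_or_trim` and `num_frames` are hypothetical helpers, not the feature extractor's own code.

```python
import numpy as np

SAMPLE_RATE = 16_000  # Hz, Whisper's expected input rate
CHUNK_SECONDS = 30    # fixed input window
HOP_LENGTH = 160      # samples between successive spectrogram frames
N_SAMPLES = SAMPLE_RATE * CHUNK_SECONDS  # 480,000


def pad_or_trim(audio: np.ndarray, length: int = N_SAMPLES) -> np.ndarray:
    """Zero-pad or truncate a 1-D waveform to exactly `length` samples."""
    if audio.shape[0] >= length:
        return audio[:length]
    return np.pad(audio, (0, length - audio.shape[0]))


def num_frames(n_samples: int = N_SAMPLES, hop_length: int = HOP_LENGTH) -> int:
    """Number of spectrogram frames kept for a padded 30 s clip."""
    return n_samples // hop_length
```

For the first training example shown below (12.61925 s, or 201,908 samples), `pad_or_trim` appends 278,092 zero samples. This padding accounts for the long runs of identical values in its `input_features`.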


In [9]:
# Train dataset: wrap the filtered DataFrame in a Dataset that loads and featurizes audio on demand
train_dataset = PrepareSpeechDataset(combined_dataset, sample_rate=16000, processor=processor, feature_extractor=feature_extractor)
# Inspect the first example
train_dataset[0]

{'audio': {'path': 'datasets/kenspeech/audios/male/speaker_17/tweet_286.wav',
  'array': tensor([[0., 0., 0.,  ..., 0., 0., 0.]]),
  'sampling_rate': 16000,
  'input_length': 12.61925},
 'sentence': 'mungu ndiye muumba na mlinzi wa ulimwengu unaeza kulindwa na majesshi lakini huezi kukiepuka kifo usalama wetu umehakikishiwa na mwenyezi mungu',
 'input_features': array([[-0.7177355, -0.7177355, -0.7177355, ..., -0.7177355, -0.7177355,
         -0.7177355],
        [-0.7177355, -0.7177355, -0.7177355, ..., -0.7177355, -0.7177355,
         -0.7177355],
        [-0.7177355, -0.7177355, -0.7177355, ..., -0.7177355, -0.7177355,
         -0.7177355],
        ...,
        [-0.7177355, -0.7177355, -0.7177355, ..., -0.7177355, -0.7177355,
         -0.7177355],
        [-0.7177355, -0.7177355, -0.7177355, ..., -0.7177355, -0.7177355,
         -0.7177355],
        [-0.7177355, -0.7177355, -0.7177355, ..., -0.7177355, -0.7177355,
         -0.7177355]], dtype=float32),
 'labels': [50258,
  50318,
  

In [12]:
# Test dataset: the Swahili test split of Common Voice 11.0
# (a gated dataset: requires a Hugging Face access token and accepting its terms on the Hub)
from datasets import load_dataset

# Optional text normalization applied to the reference transcriptions
do_lower_case = True
do_remove_punctuation = True

def normalize_transcriptions(batch):
    transcription = batch["sentence"]
    if do_lower_case:
        transcription = transcription.lower()
    if do_remove_punctuation:
        # BasicTextNormalizer (created above) replaces punctuation and symbols with spaces
        transcription = normalizer(transcription).strip()
    batch["sentence"] = transcription
    return batch

common_voice = load_dataset("mozilla-foundation/common_voice_11_0", "sw", split="test", use_auth_token=True)

# Normalize the references and drop metadata columns the model does not use
common_voice = common_voice.map(
    normalize_transcriptions,
    remove_columns=["accent", "age", "client_id", "down_votes", "gender", "locale", "path", "segment", "up_votes"]
)

next(iter(common_voice))

Found cached dataset common_voice_11_0 (/home/ubuntu/.cache/huggingface/datasets/mozilla-foundation___common_voice_11_0/sw/11.0.0/f8e47235d9b4e68fa24ed71d63266a02018ccf7194b2a8c9c598a5f3ab304d9f)
Loading cached processed dataset at /home/ubuntu/.cache/huggingface/datasets/mozilla-foundation___common_voice_11_0/sw/11.0.0/f8e47235d9b4e68fa24ed71d63266a02018ccf7194b2a8c9c598a5f3ab304d9f/cache-ae4ad253761f21fe.arrow


{'audio': {'path': '/home/ubuntu/.cache/huggingface/datasets/downloads/extracted/089cf2c8ac1b9b618eb8b26166acfdd2eb9872a1e2fc8ec769c3b65c985809e0/common_voice_sw_31428161.mp3',
  'array': array([0., 0., 0., ..., 0., 0., 0.], dtype=float32),
  'sampling_rate': 48000},
 'sentence': 'wachambuzi wa soka wanamtaja messi kama nyota hatari zaidi duniani'}
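
Normalizing both the references and, later, the model's predictions keeps WER from penalizing differences in casing or punctuation. To make the idea concrete, here is a simplified, dependency-free approximation. `simple_normalize` is hypothetical and is not identical to `BasicTextNormalizer`, which also handles bracketed text and other details.

```python
import re
import unicodedata


def simple_normalize(text: str) -> str:
    """Lowercase, replace punctuation/symbols with spaces, collapse whitespace.

    Hypothetical helper approximating the idea behind BasicTextNormalizer;
    it is not a reimplementation of that class.
    """
    text = unicodedata.normalize("NFKC", text.lower())
    # Unicode categories starting with "P" are punctuation, "S" are symbols
    text = "".join(" " if unicodedata.category(ch)[0] in "PS" else ch for ch in text)
    return re.sub(r"\s+", " ", text).strip()
```

For example, `simple_normalize("Wachambuzi wa soka, wanamtaja Messi!")` returns `"wachambuzi wa soka wanamtaja messi"`.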

In [23]:
from datasets import Audio

# Decode the Common Voice MP3s and resample them from 48 kHz to 16 kHz on access
common_voice = common_voice.cast_column("audio", Audio(sampling_rate=16000))

def prepare_dataset(batch):
    audio = batch["audio"]

    # compute log-Mel input features from the 16 kHz audio array
    batch["input_features"] = feature_extractor(audio["array"], sampling_rate=audio["sampling_rate"]).input_features[0]

    # encode target text to label ids
    batch["labels"] = tokenizer(batch["sentence"]).input_ids
    return batch

common_voice = common_voice.map(prepare_dataset)

## Data Collator

In [27]:
import torch

from dataclasses import dataclass
from typing import Any, Dict, List, Union

@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    processor: Any

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # split inputs and labels since they have to be of different lengths and need different padding methods
        # first treat the audio inputs by simply returning torch tensors
        input_features = [{"input_features": feature["input_features"]} for feature in features]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")

        # get the tokenized label sequences
        label_features = [{"input_ids": feature["labels"]} for feature in features]
        # pad the labels to max length
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")

        # replace padding with -100 to ignore loss correctly
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

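        # If the decoder start token <|startoftranscript|> was prepended during tokenization,
        # cut it here: the model prepends it again when shifting the labels right.
        # (Whisper's bos_token_id is <|endoftext|>, which never starts the labels.)
        decoder_start_token_id = self.processor.tokenizer.convert_tokens_to_ids("<|startoftranscript|>")
        if (labels[:, 0] == decoder_start_token_id).all().cpu().item():
            labels = labels[:, 1:]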

        batch["labels"] = labels

        return batch

In [28]:
data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)
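
The collator's label handling can be illustrated without a tokenizer. This NumPy sketch performs three steps:

1. Right-pads token-id sequences.
2. Sets padded positions to -100, so the cross-entropy loss ignores them.
3. Strips a shared leading `<|startoftranscript|>` token, which the model re-adds when it shifts labels right.

It assumes plain integer ids. 50258 is the first label id in the training example printed earlier, and 50257 is the `pad_token_id` in the model config. `pad_labels` is a hypothetical helper.

```python
from typing import List

import numpy as np

IGNORE_INDEX = -100  # labels with this value are skipped by PyTorch's cross-entropy loss


def pad_labels(sequences: List[List[int]], pad_id: int, start_id: int) -> np.ndarray:
    """Right-pad label sequences, mask padding with -100, strip a shared start token."""
    max_len = max(len(seq) for seq in sequences)
    ids = np.full((len(sequences), max_len), pad_id, dtype=np.int64)
    mask = np.zeros((len(sequences), max_len), dtype=bool)
    for row, seq in enumerate(sequences):
        ids[row, : len(seq)] = seq
        mask[row, : len(seq)] = True  # True marks real tokens, like attention_mask == 1
    # Equivalent of masked_fill(attention_mask.ne(1), -100)
    labels = np.where(mask, ids, IGNORE_INDEX)
    # Drop the start token if every sequence begins with it
    if (labels[:, 0] == start_id).all():
        labels = labels[:, 1:]
    return labels
```

Masking is based on the true sequence lengths rather than on comparing ids to the pad id. This matters because Whisper's end-of-text token (50257) doubles as the pad token. Masking by id would also hide the final end-of-text label the model needs to learn.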

## Evaluation Metrics

In [29]:
import evaluate

metric = evaluate.load("wer")

# evaluate with the 'normalised' WER
do_normalize_eval = True

def compute_metrics(pred):
    pred_ids = pred.predictions
    label_ids = pred.label_ids

    # replace -100 with the pad_token_id
    label_ids[label_ids == -100] = processor.tokenizer.pad_token_id

    # decode ids to text, dropping special tokens (start, language, task, end-of-text)
    pred_str = processor.tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = processor.tokenizer.batch_decode(label_ids, skip_special_tokens=True)

    if do_normalize_eval:
        pred_str = [normalizer(pred) for pred in pred_str]
        label_str = [normalizer(label) for label in label_str]

    wer = 100 * metric.compute(predictions=pred_str, references=label_str)

    return {"wer": wer}
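
WER is the word-level edit distance between the prediction and the reference, divided by the number of reference words:

WER = (substitutions + deletions + insertions) / N

The `wer` metric from `evaluate` sums errors and reference words over the whole evaluation set rather than averaging per-sentence rates. The following is a minimal pure-Python sketch of the same definition; `word_edit_distance` and `corpus_wer` are hypothetical names.

```python
from typing import List


def word_edit_distance(reference: List[str], hypothesis: List[str]) -> int:
    """Minimum substitutions + deletions + insertions turning reference into hypothesis."""
    # prev[j] holds the distance between the current reference prefix and hypothesis[:j]
    prev = list(range(len(hypothesis) + 1))
    for i, ref_word in enumerate(reference, start=1):
        curr = [i] + [0] * len(hypothesis)
        for j, hyp_word in enumerate(hypothesis, start=1):
            cost = 0 if ref_word == hyp_word else 1
            curr[j] = min(prev[j] + 1,         # deletion
                          curr[j - 1] + 1,     # insertion
                          prev[j - 1] + cost)  # substitution or match
        prev = curr
    return prev[-1]


def corpus_wer(predictions: List[str], references: List[str]) -> float:
    """Total word errors divided by total reference words, as a percentage."""
    errors = sum(word_edit_distance(r.split(), p.split()) for p, r in zip(predictions, references))
    words = sum(len(r.split()) for r in references)
    return 100 * errors / words
```

Because the denominator is the reference length, WER can exceed 100% when a model inserts many extra words.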

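## Load Pre-Trained Checkpoint

Load the model to fine-tune. Here the weights come from a local checkpoint, `whisper-medium-sw/checkpoint-3000`, apparently saved by an earlier run. To start from OpenAI's released weights instead, pass a Hub id such as `openai/whisper-small`.

Despite the directory name, this checkpoint is a Whisper small model, not medium. The evidence is the printed configuration (`"_name_or_path": "openai/whisper-small"`, a `d_model` of 768, 12 encoder and 12 decoder layers) and the 241,734,912 trainable parameters reported when training starts. The processor loaded from `openai/whisper-medium` still works with it, because the small and medium checkpoints share the same multilingual tokenizer and 80-bin log-Mel front end.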

In [43]:
from transformers import WhisperForConditionalGeneration

model = WhisperForConditionalGeneration.from_pretrained("whisper-medium-sw/checkpoint-3000")

loading configuration file whisper-medium-sw/checkpoint-3000/config.json
Model config WhisperConfig {
  "_name_or_path": "openai/whisper-small",
  "activation_dropout": 0.0,
  "activation_function": "gelu",
  "architectures": [
    "WhisperForConditionalGeneration"
  ],
  "attention_dropout": 0.0,
  "begin_suppress_tokens": [
    220,
    50257
  ],
  "bos_token_id": 50257,
  "d_model": 768,
  "decoder_attention_heads": 12,
  "decoder_ffn_dim": 3072,
  "decoder_layerdrop": 0.0,
  "decoder_layers": 12,
  "decoder_start_token_id": 50258,
  "dropout": 0.0,
  "encoder_attention_heads": 12,
  "encoder_ffn_dim": 3072,
  "encoder_layerdrop": 0.0,
  "encoder_layers": 12,
  "eos_token_id": 50257,
  "forced_decoder_ids": null,
  "init_std": 0.02,
  "is_encoder_decoder": true,
  "max_length": 448,
  "max_source_positions": 1500,
  "max_target_positions": 448,
  "model_type": "whisper",
  "num_hidden_layers": 12,
  "num_mel_bins": 80,
  "pad_token_id": 50257,
  "scale_embedding": false,
  "suppres

In [44]:
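# Let the model predict the language and task tokens itself instead of forcing them at generation time
model.config.forced_decoder_ids = None
# Don't suppress any tokens during generation
model.config.suppress_tokens = []
# The key/value cache is incompatible with gradient checkpointing during training
model.config.use_cache = False
# Add dropout regularization (the checkpoint's config uses 0.0)
model.config.dropout = 0.1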

## Define the Training Run

In [45]:
from transformers import Seq2SeqTrainingArguments

# ==== Configuration ====
OUTPUT_DIR = "tuned_weights/whisper-medium-sw"
BATCH_SIZE = 32
GRADIENT_ACCUMULATION_STEPS = 2
LEARNING_RATE = 1e-5
WARMUP_STEPS = 500
MAX_STEPS = 5000
EVAL_BATCH_SIZE=8
SAVE_STEPS = 1000
EVAL_STEPS = 1000
LOGGING_STEPS=25
# =======================

training_args = Seq2SeqTrainingArguments(
    output_dir=OUTPUT_DIR,
    per_device_train_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,  # increase by 2x for every 2x decrease in batch size
    learning_rate=LEARNING_RATE,
    warmup_steps=WARMUP_STEPS,
    max_steps=MAX_STEPS,
    gradient_checkpointing=True,
    fp16=True,
    evaluation_strategy="steps",  # renamed to eval_strategy in newer transformers releases
    per_device_eval_batch_size=EVAL_BATCH_SIZE,
    predict_with_generate=True,
    generation_max_length=225,
    save_steps=SAVE_STEPS,
    eval_steps=EVAL_STEPS,
    logging_steps=LOGGING_STEPS,
    report_to=["wandb"],
    load_best_model_at_end=True,
    metric_for_best_model="wer",
    greater_is_better=False,
    push_to_hub=False,
)

PyTorch: setting up devices
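
A quick check of what this configuration implies:

- Each optimizer step sees 32 × 2 = 64 examples on a single GPU.
- The Trainer's default `lr_scheduler_type` is `"linear"`. It raises the learning rate linearly over the 500 warmup steps, then decays it linearly to zero at step 5000.
- The training log reports 12,980 training examples, so 5000 steps correspond to 25 passes over the data. This matches the `Num Epochs = 25` line there.

The helper names below are hypothetical. The epoch arithmetic assumes the Trainer counts update steps per epoch as batches per epoch floor-divided by the accumulation steps.

```python
import math

BATCH_SIZE = 32
GRADIENT_ACCUMULATION_STEPS = 2
LEARNING_RATE = 1e-5
WARMUP_STEPS = 500
MAX_STEPS = 5000


def effective_batch_size(per_device: int, accumulation: int, devices: int = 1) -> int:
    """Examples contributing to each optimizer update."""
    return per_device * accumulation * devices


def epochs_for_max_steps(num_examples: int, per_device: int, accumulation: int, max_steps: int) -> int:
    """Passes over the data needed to reach max_steps optimizer updates (single device)."""
    batches_per_epoch = math.ceil(num_examples / per_device)
    updates_per_epoch = max(batches_per_epoch // accumulation, 1)
    return math.ceil(max_steps / updates_per_epoch)


def linear_warmup_lr(step: int, base_lr: float, warmup: int, total: int) -> float:
    """Learning rate at `step`: linear warmup, then linear decay to zero at `total`."""
    if step < warmup:
        return base_lr * step / max(1, warmup)
    return base_lr * max(0.0, (total - step) / max(1, total - warmup))
```

With these values, the learning rate is half its peak (5e-6) at step 250 during warmup and again at step 2750 during decay.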


In [46]:
from transformers import Seq2SeqTrainer

trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    train_dataset=train_dataset,
    eval_dataset=common_voice,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    tokenizer=processor.feature_extractor,  # saved with each checkpoint; newer releases call this processing_class
)


max_steps is given, it will override any value given in num_train_epochs
Using cuda_amp half precision backend


In [47]:
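# Save the starting weights and the processor (feature extractor + tokenizer) to the
# output directory so it can be loaded with from_pretrained alongside the checkpoints
model.save_pretrained(training_args.output_dir)
processor.save_pretrained(training_args.output_dir)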

Configuration saved in tuned_weights/whisper-medium-sw/config.json
Model weights saved in tuned_weights/whisper-medium-sw/pytorch_model.bin
Feature extractor saved in tuned_weights/whisper-medium-sw/preprocessor_config.json
tokenizer config file saved in tuned_weights/whisper-medium-sw/tokenizer_config.json
Special tokens file saved in tuned_weights/whisper-medium-sw/special_tokens_map.json
added tokens file saved in tuned_weights/whisper-medium-sw/added_tokens.json


In [48]:
import wandb

# W&B argument tracking
# Log the complete training configuration instead of repeating each argument by hand
config = training_args.to_dict()

# start a new W&B training run
wandb.init(
    project="whisper-medium-sw-fine_tuned", 
    entity="mldude", 
    tags=["whisper-event", "whisper-medium-sw", "Swahili", "SOTA"],
    config=config
)


In [49]:
%%wandb

trainer.train()

***** Running training *****
  Num examples = 12980
  Num Epochs = 25
  Instantaneous batch size per device = 32
  Total train batch size (w. parallel, distributed & accumulation) = 64
  Gradient Accumulation steps = 2
  Total optimization steps = 5000
  Number of trainable parameters = 241734912
Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"
The following columns in the training set don't have a corresponding argument in `WhisperForConditionalGeneration.forward` and have been ignored: audio, sentence. If audio, sentence are not expected by `WhisperForConditionalGeneration.forward`,  you can safely ignore this message.


| Step | Training loss | Validation loss | WER (%) |
|---|---|---|---|
| 1000 | 0.38 | 0.383123 | 43.809884 |
| 2000 | 0.073 | 0.458013 | 40.020884 |
| 3000 | 0.0077 | 0.509322 | 48.430812 |


The following columns in the evaluation set don't have a corresponding argument in `WhisperForConditionalGeneration.forward` and have been ignored: audio, sentence. If audio, sentence are not expected by `WhisperForConditionalGeneration.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 10238
  Batch size = 8
Saving model checkpoint to tuned_weights/whisper-medium-sw/checkpoint-1000
Configuration saved in tuned_weights/whisper-medium-sw/checkpoint-1000/config.json
Model weights saved in tuned_weights/whisper-medium-sw/checkpoint-1000/pytorch_model.bin
Feature extractor saved in tuned_weights/whisper-medium-sw/checkpoint-1000/preprocessor_config.json
The following columns in the evaluation set don't have a corresponding argument in `WhisperForConditionalGeneration.forward` and have been ignored: audio, sentence. If audio, sentence are not expected by `WhisperForConditionalGeneration.forward`,  you can safely ignore this message.
***** Running 

KeyboardInterrupt: 
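
Training was stopped by hand here, so `load_best_model_at_end` never ran, and the in-memory model holds the latest weights rather than the best ones. In the evaluation table, validation loss rises after step 1000 while training loss keeps falling, a sign of overfitting. WER is lowest at step 2000, so that checkpoint is the one to keep. Given `save_steps=1000`, it should be saved under `tuned_weights/whisper-medium-sw/checkpoint-2000`.

The sketch below picks the best step programmatically from a list of evaluation records. It assumes records shaped like the entries of the `log_history` list that the Trainer stores in each checkpoint's `trainer_state.json`, where the `wer` returned by `compute_metrics` appears as `eval_wer`. `best_step` is a hypothetical helper.

```python
from typing import Dict, List, Optional


def best_step(log_history: List[Dict[str, float]], metric: str = "eval_wer",
              greater_is_better: bool = False) -> Optional[int]:
    """Return the step whose evaluation record has the best value of `metric`."""
    evals = [rec for rec in log_history if metric in rec]
    if not evals:
        return None
    pick = max if greater_is_better else min
    return int(pick(evals, key=lambda rec: rec[metric])["step"])


# Evaluation results from the table above (WER in percent)
history = [
    {"step": 1000, "eval_loss": 0.383123, "eval_wer": 43.809884},
    {"step": 2000, "eval_loss": 0.458013, "eval_wer": 40.020884},
    {"step": 3000, "eval_loss": 0.509322, "eval_wer": 48.430812},
]
```

Here `best_step(history)` selects step 2000. That checkpoint directory can then be loaded with `WhisperForConditionalGeneration.from_pretrained` for inference or further training.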