# Whisper Fine Tuning Event

MIT License

## Import

In [1]:
import IPython.display
from pathlib import Path

import os
import numpy as np

import torch
from torch import nn
import pandas as pd
import whisper
import torchaudio
import torchaudio.transforms as at

from pytorch_lightning import LightningModule
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger

# from tqdm.notebook import tqdm   # for colab
from tqdm import tqdm              # for jupyter
import evaluate

from transformers import (
    AdamW,
    get_linear_schedule_with_warmup
)

from utils import CfgNode

from typing import List, Union

In [2]:
from huggingface_hub import notebook_login

notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [3]:
CN = CfgNode()

CN.DATASET_DIR = r"C:\Users\Hedronstone\Desktop\whisper_event\data_broadcastnews_sw\data"
CN.SAMPLE_RATE = 16000
CN.BATCH_SIZE = 2
CN.TRAIN_RATE = 0.8
CN.BATCH_SIZE = 2

CN.AUDIO_MAX_LENGTH = 480000
CN.TEXT_MAX_LENGTH = 120
CN.SEED = 3407
CN.DEVICE = "gpu" if torch.cuda.is_available() else "cpu"
seed_everything(CN.SEED, workers=True)

Global seed set to 3407


3407

## Util

To make the code efficient, we use the `concurrent.futures` module for multithreading or multiprocessing to parallelize the resampling of the waveform. This can be useful if the resampling operation is computationally expensive and there are multiple waveforms to be resampled:

In [4]:
from concurrent.futures import ThreadPoolExecutor

def load_wave(wave_path: str, sample_rate: int=16000) -> torch.Tensor:
    """Takes a path to a wave file and an optional sample rate as inputs, 
    loads the waveform and normalizes it, and then resamples the waveform 
    to the specified sample rate if necessary.
    
    Arguments:
    wave_path -- The path of the audio file.
    sample_rate -- The number of individual sound samples per second.
    
    Returns:
    The waveform as a torch.Tensor object.    
    """
    # Load the waveform and resample it if necessary
    waveform, sr = torchaudio.load(wave_path, normalize=True)
    if sample_rate != sr:
        # Use a thread pool to parallelize the resampling operation
        with ThreadPoolExecutor() as executor:
            waveform = executor.submit(at.Resample(sr, sample_rate), waveform).result()
    
    return waveform

In [5]:
def load_wave(wave_path, sample_rate:int=16000) -> torch.Tensor:
    waveform, sr = torchaudio.load(wave_path, normalize=True)
    if sample_rate != sr:
        waveform = at.Resample(sr, sample_rate)(waveform)
    return waveform

### Load Audio

In [6]:
import os
from pathlib import Path
from tqdm import tqdm
from typing import List, Tuple

def get_audio_path(audio_id: str, set_name: str, cfg: CfgNode) -> str:
    """Returns the path to the audio file with the given ID and set name.

    Arguments:
    audio_id -- The ID of the audio file.
    set_name -- The name of the set (train or test).
    cfg -- The configuration node containing the dataset directory.

    Returns:
    The path to the audio file.
    """
    if set_name == "test":
        audio_dir = Path(CN.DATASET_DIR) / set_name / "wav5"
        folders = [d.name for d in audio_dir.iterdir() if d.is_dir()]

        for folder in folders:
            audio_dirs = Path(CN.DATASET_DIR) / set_name / "wav5" / folder
            audio_paths = [p.name for p in audio_dirs.glob("*.wav")]

            for audio_path in audio_paths:
                return(audio_dirs / audio_path)      
        
    if set_name == "train":
        audio_dir = Path(cfg.DATASET_DIR) / set_name / "wav"
        folders = [d.name for d in audio_dir.iterdir() if d.is_dir()]

        for folder in folders:
            audio_dirs = audio_dir / folder
            audio_paths = [p.name for p in audio_dirs.glob("*.wav")]

            for audio_path in audio_paths:
                return(audio_dirs / audio_path) 


def stage_audio_data(cfg: CfgNode, set_name: str="train") -> List[Tuple[str, str, str]]:
    """Returns a list of tuples containing audio file IDs, paths, and transcriptions.

    Arguments:
    cfg -- The configuration node containing the dataset directory.
    set_name -- The name of the set (train or test). Defaults to "train".

    Returns:
    A list of tuples containing audio file IDs, paths, and transcriptions.
    """        
    path = Path(CN.DATASET_DIR)
    text_path = path / set_name / (set_name + "_text.txt")
    
    with open(text_path, "r") as f:
        text_list = f.readlines()
    
    audio_transcript_pairs = []
    for text in tqdm(text_list):
        audio_id, transcription = text.split("=")
        transcription = transcription.replace("\n", "")
        audio_path = get_audio_path(audio_id, set_name, CN)
        audio_transcript_pairs.append((audio_id, str(audio_path), transcription))
    
    return audio_transcript_pairs

In [32]:
train_audio_transcript_pairs = stage_audio_data(CN, "train")
test_audio_transcript_pairs = stage_audio_data(CN, "test")

100%|██████████████████████████████████████████████████████████████████| 10180/10180 [00:10<00:00, 1015.42it/s]
100%|████████████████████████████████████████████████████████████████████| 1991/1991 [00:01<00:00, 1735.83it/s]


In [8]:
print("TRAIN AUDIO DATASET NUM: ", len(train_audio_transcript_pairs))
print("EVAL AUDIO DATASET NUM: ", len(test_audio_transcript_pairs))

TRAIN AUDIO DATASET NUM:  10180
EVAL AUDIO DATASET NUM:  1991


Let's check our `stage_audio_data()` function for compatibility with `load_wave()`

In [9]:
print(load_wave(train_audio_transcript_pairs[0][1]))
print(load_wave(test_audio_transcript_pairs[0][1]))

tensor([[0.0304, 0.0286, 0.0277,  ..., 0.0885, 0.0946, 0.0796]])
tensor([[0.0019, 0.0025, 0.0040,  ..., 0.0955, 0.1230, 0.0723]])


### Data loader

In [10]:
woptions = whisper.DecodingOptions(language="sw", without_timestamps=True)
wmodel = whisper.load_model("base")
wtokenizer = whisper.tokenizer.get_tokenizer(True, language="sw", task=woptions.task)

The `SwahiliSpeechDataset` class below creates a dataset of audio information, 
including the audio file path, associated text, and tokenized text. It 
uses the sample rate and tokenizer specified in the init method, and has 
methods to return the length of the dataset and retrieve items from the 
dataset by index. Each dataset item is a dictionary containing the audio 
data as input_ids, the tokenized text as labels, and the original text as 
dec_input_ids.

In the `getitem` method, the model retrieves the audio information for a specific index, loads the audio file, pads or trims it to a specific length, and then extracts the log Mel-spectrogram as input_ids. It also encodes the text transcript using the tokenizer, and creates the labels and dec_input_ids for the model's training.


In [81]:
class SwahiliSpeechDataset(torch.utils.data.Dataset):
    """ Extracts log Mel-spectrogram as input_ids, encodes text transcripts using tokenizer, and creates the
    labels and dec_input_ids for the model's training.
    
    Arguments:
        audio_info_list -- A list of audio information, including the audio ID, audio path, and text transcript.
        sample_rate -- The sample rate of the audio files, which defaults to 16e3 (16kHz).
        tokenizer -- An instance of the whisper.tokenizer class that is used to encode the text transcript.        
    """
    def __init__(self, audio_info_list: str, tokenizer: whisper.tokenizer, sample_rate: int=16e3) -> None:
        super().__init__()
        
        self.audio_info_list = audio_info_list
        self.sample_rate = sample_rate
        self.tokenizer = tokenizer
        
    def __len__(self):
        return len(self.audio_info_list)
    
    def __getitem__(self, id):
        audio_id, audio_path, text = self.audio_info_list[id]
        
        audio = load_wave(audio_path, sample_rate=self.sample_rate)
        audio = whisper.pad_or_trim(audio.flatten())
        mel = whisper.log_mel_spectrogram(audio)
        
        text = self.audio_info_list[id][2]
        text = [*self.tokenizer.sot_sequence_including_notimestamps] + self.tokenizer.encode(text)
        labels = text[1:] + [self.tokenizer.eot]
        
        return {
            "input_ids": mel,
            "labels": labels,
            "dec_input_ids": text
        }

When called with a list of features, `WhisperDataCollatorWhithPadding` class 
collects the `input_ids` and `labels` attributes from each feature and concatenates 
them into a single tensor. It then pads the `labels` and `dec_input_ids` attributes 
with the constant values -100 and 50257, respectively, to the maximum length of all 
values in these attributes.Padded attributes are then converted into tensors and 
returned as a batch.

In [12]:
class WhisperDataCollatorWhithPadding:
    """Prepares batches for model training.
    
    Arguments:
        input_ids -- A list of input IDs representing the sequence of tokens in the input text.
        labels -- A list of labels corresponding to the input sequence.
        dec_input_ids -- A list of input IDs representing the sequence of tokens in the decoder input text.
    """

    def __call__(self, features):
        input_ids = [feature["input_ids"] for feature in features]
        labels = [feature["labels"] for feature in features]
        dec_input_ids = [feature["dec_input_ids"] for feature in features]

        input_ids = torch.concat([input_id[None, :] for input_id in input_ids])

        max_lengths = [len(lab) for lab in labels + dec_input_ids]
        max_length = max(max_lengths)

        labels = [np.pad(lab, (0, max_length - lab_len), 'constant', constant_values=-100) for lab, lab_len in zip(labels, max_lengths)]
        dec_input_ids = [np.pad(e, (0, max_length - e_len), 'constant', constant_values=50257) for e, e_len in zip(dec_input_ids, max_lengths)] # 50257 is eot token id

        batch = {
            "labels": labels,
            "dec_input_ids": dec_input_ids
        }

        batch = {k: torch.tensor(np.array(v), requires_grad=False) for k, v in batch.items()}
        batch["input_ids"] = input_ids

        return batch

Let's now check our `SwahiliSpeechDataset` and `WhisperDataCollatorWhithPadding` classes

In [13]:
dataset = SwahiliSpeechDataset(test_audio_transcript_pairs, wtokenizer, CN.SAMPLE_RATE)
loader = torch.utils.data.DataLoader(dataset, batch_size=2, collate_fn=WhisperDataCollatorWhithPadding())

for b in loader:
    print(b["labels"].shape)
    print(b["input_ids"].shape)
    print(b["dec_input_ids"].shape)
    
    for token, dec in zip(b["labels"], b["dec_input_ids"]):
        token[token == -100] = wtokenizer.eot
        text = wtokenizer.decode(token, skip_special_tokens=False)
        print(text)

        dec[dec == -100] = wtokenizer.eot
        text = wtokenizer.decode(dec, skip_special_tokens=False)
        print(text)
    
    break

torch.Size([2, 37])
torch.Size([2, 80, 3000])
torch.Size([2, 37])
<|sw|><|transcribe|><|notimestamps|>ya redio france internanational mimi ni zuhra mwera<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>
<|startoftranscript|><|sw|><|transcribe|><|notimestamps|>ya redio france internanational mimi ni zuhra mwera<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>
<|sw|><|transcribe|><|notimestamps|>marekani yasema iko tayari kuisaidia korea kusini kuikabili korea kaskazini<|endoftext|>
<|startoftranscript|><|sw|><|transcribe|><|notimestamps|>marekani yasema iko tayari kuisaidia korea kusini kuikabili korea kaskazini


In [14]:
with torch.no_grad():
    audio_features = wmodel.encoder(b["input_ids"].cuda())
    input_ids = b["input_ids"]
    labels = b["labels"].long()
    dec_input_ids = b["dec_input_ids"].long()
        
    audio_features = wmodel.encoder(input_ids.cuda())
    print(dec_input_ids)
    print(input_ids.shape, dec_input_ids.shape, audio_features.shape)
    print(audio_features.shape)
    print()
out = wmodel.decoder(dec_input_ids.cuda(), audio_features)

tensor([[50258, 50318, 50359, 50363,  3016,  2182,  1004,   431,   719,  2154,
           282,  1478,   275, 10121,  3867,  2164,    71,   424,   275,  1554,
            64, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257,
         50257, 50257, 50257, 50257, 50257, 50257, 50257],
        [50258, 50318, 50359, 50363, 15455,    74,  3782,   288,   296,  5619,
           741,  4093,   220,    83,   320,  3504, 17807,  3837,   327,   654,
           350,   418,    64,   350,   301,  3812, 17807,  1035,   455,  2312,
           350,   418,    64,   350,  3863,   921,  3812]])
torch.Size([2, 80, 3000]) torch.Size([2, 37]) torch.Size([2, 1500, 512])
torch.Size([2, 1500, 512])



In [15]:
print(out.shape)
print(out.view(-1, out.size(-1)).shape)
print(b["labels"].view(-1).shape)

torch.Size([2, 37, 51865])
torch.Size([74, 51865])
torch.Size([74])


In [16]:
tokens = torch.argmax(out, dim=2)
for token in tokens:
    token[token == -100] = wtokenizer.eot
    text = wtokenizer.decode(token, skip_special_tokens=True)
    print(text)

 i radiio frans intes e mimi ni zhra mureraa
 i diereau- niireakanawuataaereakisaauuaka kuso


## Data Preprocess and Synthesis Pipeline
We will create a pipeline for processing two data streams. The pipeline will load the data from each stream, preprocess it as needed, and then return it in a format compatible with the Whisper model.

In [82]:
from datasets import Audio, interleave_datasets, IterableDataset, load_dataset
from typing import List, Optional

We now create `DatasetLoader` class, which takes in the names and config names of datasets as well as the names of text columns for each dataset. The class has two methods: `load_datasets`, which loads and processes the datasets, and `print_samples`, which prints a specified number of samples from each dataset. The datasets are processed by casting the "audio" column to the correct format, renaming the text column to "sentence", and removing any other columns. The class can be used to easily load and process multiple datasets with a consistent format.

In [79]:
class DatasetLoader:
    def __init__(self, dataset_names, dataset_config_names, text_column_names):
        # Initialize class variables
        self.dataset_names = dataset_names
        self.dataset_config_names = dataset_config_names
        self.text_column_names = text_column_names
    
    def load_datasets(self):
        # Create a list of splits for each dataset
        self.splits = ["train" for i in range(len(self.dataset_names))]
        # Initialize empty list to store datasets
        datasets = []

        # Loop through each dataset and load, process, and append to datasets list
        for i, dataset_name in tqdm(enumerate(self.dataset_names)):
            dataset = load_dataset(self.dataset_names[i], self.dataset_config_names[i], split=self.splits[i], streaming=True)
            dataset = dataset.cast_column("audio", Audio(CN.SAMPLE_RATE))
            if self.text_column_names[i] != "sentence":
                dataset = dataset.rename_column(self.text_column_names[i], "sentence")
            dataset = dataset.remove_columns(set(dataset.features.keys()) - set(["audio", "sentence"]))
            datasets.append(dataset)
        return datasets

    def print_samples(self, num_samples):
        # Load datasets
        datasets = self.load_datasets()
        # Loop through each dataset and print specified number of samples
        for dataset in datasets:
            for i, sample in enumerate(dataset):
                print(i, sample["sentence"])
                if i == num_samples-1:
                    break

Create instance of DatasetLoader and print samples from loaded datasets:

In [80]:
dataset_loader = DatasetLoader(["mozilla-foundation/common_voice_11_0", "google/fleurs"], ["sw", "sw_ke"], ["sentence", "transcription"])
dataset_loader.print_samples(10)

2it [00:03,  1.52s/it]
Reading metadata...: 26614it [00:02, 12683.20it/s]


ValueError: Input signal length=0 is too small to resample from 32000->16000

In [76]:
dataset_names = ["mozilla-foundation/common_voice_11_0", "google/fleurs"]
dataset_config_names = ["sw", "sw_ke"]
text_column_names = ["sentence", "transcription"]

splits = ["train" for i in range(len(dataset_names))]

for i, dataset_name in tqdm(enumerate(dataset_names)):                         
    dataset = load_dataset(dataset_names[i], dataset_config_names[i], split=splits[i], streaming=True)
    dataset = dataset.cast_column("audio", Audio(CN.SAMPLE_RATE))
    dataset = dataset.rename_column(text_column_names[i], "sentence")
    dataset = dataset.remove_columns(set(dataset.features.keys()) - set(["audio", "sentence"]))

for i, sample in enumerate(dataset):
    print(i, sample["sentence"])
    if i == 9:
        break

2it [00:02,  1.28s/it]


0 jua halina ganda gumu kama dunia ambayo unaweza kusimama kwayo. jua zima limetengenezwa kwa gesi moto na gesi ya utegili
1 hata hivyo ugunduzi wa kaburi lake mnamo 1922 ulimfanya mtu mashuhuri. huku makaburi mengi ya zamani yakiporwa kaburi hili liliachwa karibu bila kusumbuliwa kamwe
2 gurudumu limebadilisha dunia kwa njia za ajabu. jambo kubwa zaidi ambalo gurudumu limetufanyia ni kutupatia uchukuzi rahisi na wa haraka
3 hakikisha kwamba basi unalofikiria kuabiri linaenda hebron na si katika makazi ya karibu ya kiyahudi ya kiryat arba tu
4 wakati kupwa kwa maji kulifungua mwanya katika mto mystic katika kaskazini mashariki mwa rasi waliendeleza ua kwa haraka kwa ukuta mfupi wa mawe upande wa kaskazini na kuishia kwenye ukingo wa maji katika pwani ndogo
5 mwanzoni aliipatia alfabeti ya hangeul jina la hunmin jeongeum kumaanisha  sauti sahihi za maagizo kwa watu
6 uonaji au uwezo wa kuona hutegemea viungo vya hisia vya mfumo wa kuona au macho
7 usilalie godoro au mto kwenye ardhi kat

In [29]:
def load_multiple_streaming_datasets(
    dataset_names: List,
    dataset_config_names: List,
    splits: Optional[List] = None,
    text_column_names: Optional[List] = None,
    sampling_rate: Optional[int] = 16000,
    stopping_strategy: Optional[str] = "all_exhausted",
    **kwargs
) -> IterableDataset:

    if len(dataset_names) != len(dataset_config_names):
        raise ValueError(
            f"Ensure one config is passed for each dataset, got {len(dataset_names)} datasets and"
            f" {len(dataset_config_names)} configs."
        )

    if splits is not None and len(splits) != len(dataset_names):
        raise ValueError(
            f"Ensure one split is passed for each dataset, got {len(dataset_names)} datasets and {len(splits)} splits."
        )

    if text_column_names is not None and len(text_column_names) != len(dataset_names):
        raise ValueError(
            f"Ensure one text column name is passed for each dataset, got {len(dataset_names)} datasets and"
            f" {len(text_column_names)} text column names."
        )

    splits = splits if splits is not None else ["train" for i in range(len(dataset_names))]
    text_column_names = (
        text_column_names if text_column_names is not None else ["text" for i in range(len(dataset_names))]
    )

    all_datasets = []
    # iterate over the datasets we want to interleave
    for i, dataset_name in enumerate(dataset_names):
        dataset = load_dataset(dataset_name, dataset_config_names[i], split=splits[i], streaming=True, **kwargs)
        # resample to specified sampling rate
        dataset = dataset.cast_column("audio", Audio(sampling_rate))
        #  normalise columns to ["audio", "sentence"]
        if text_column_names[i] != "sentence":
            dataset = dataset.rename_column(text_column_names[i], "sentence")
        dataset = dataset.remove_columns(set(dataset.features.keys()) - set(["audio", "sentence"]))
        all_datasets.append(dataset)

    interleaved_dataset = interleave_datasets(all_datasets, stopping_strategy=stopping_strategy)
    return interleaved_dataset

In [30]:
from datasets import IterableDatasetDict

raw_datasets = IterableDatasetDict()
raw_datasets["train"] = load_streaming_dataset("mozilla-foundation/common_voice_11_0", "sw", split="train", use_auth_token=True)  # set split="train+validation" for low-resource

from transformers import WhisperProcessor

processor = WhisperProcessor.from_pretrained("openai/whisper-small", language="Swahili", task="transcribe")

raw_datasets["train"].features

{'client_id': Value(dtype='string', id=None),
 'path': Value(dtype='string', id=None),
 'audio': Audio(sampling_rate=48000, mono=True, decode=True, id=None),
 'sentence': Value(dtype='string', id=None),
 'up_votes': Value(dtype='int64', id=None),
 'down_votes': Value(dtype='int64', id=None),
 'age': Value(dtype='string', id=None),
 'gender': Value(dtype='string', id=None),
 'accent': Value(dtype='string', id=None),
 'locale': Value(dtype='string', id=None),
 'segment': Value(dtype='string', id=None)}

## Trainer

In [20]:
CN.LEARNING_RATE = 0.0005
CN.WEIGHT_DECAY = 0.01
CN.ADAM_EPSILON = 1e-8
CN.WARMUP_STEPS = 2
CN.BATCH_SIZE = 16
CN.NUM_WORKER = 2
CN.NUM_TRAIN_EPOCHS = 1
CN.GRADIENT_ACCUMULATION_STEPS = 1

Let's create `WhisperModelModule` class that extends the `LightningModule` class from the PyTorch Lightning library. The class will initialize a whisper model for our chosen language and train only the decoder part of the model with a given dataset. The class will also define the forward method, training and validation steps, and the configuration of optimizers and schedulers. The class will log metrics such as loss, WER, and CER during training and validation.

#### Under Construction

In [17]:
class WhisperModelModule(LightningModule):
    def __init__(self):
        super().__init__()
        
    def forward(self):
        return
    
    def training_step(self):
        return
    
    def validation_step(self):
        return
    
    def configure_optimizer(self):
        return
    
    def setup(self):
        return
    
    def train_dataloader(self):
        return
    
    def val_dataloader(self):
        return