# Environment set-up

In [None]:
# Report the attached GPU, if any. In IPython, `x = !cmd` captures the shell output as a list of lines.
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
# Presumably nvidia-smi's output contains "failed" when no NVIDIA driver/GPU is available — the check relies on that.
if gpu_info.find('failed') >= 0:
  print('Not connected to a GPU')
else:
  print(gpu_info)

Not connected to a GPU


In [None]:
# Install ffmpeg system-wide (presumably needed for audio decoding by whisper/torchaudio — confirm).
# NOTE(review): ppa:jonathonf/ffmpeg-4 may no longer be publicly available; if add-apt-repository
# fails, the distribution's own ffmpeg package may be sufficient — verify.
!add-apt-repository -y ppa:jonathonf/ffmpeg-4
!apt update
!apt install -y ffmpeg

In [None]:
!pip install datasets>=2.6.1
!pip install git+https://github.com/huggingface/transformers
!pip install evaluate>=0.30
!pip install jiwer
!pip install pytorch-lightning==1.7.7
!pip install -qqq evaluate==0.2.2
!pip install git+https://github.com/openai/whisper.git
!pip install ray[tune]

In [None]:
# Mount Google Drive so the dataset directory under /content/drive/MyDrive is reachable.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

# Import libraries

In [None]:
import gc
import os
from pathlib import Path

import IPython.display
import numpy as np
import pandas as pd
import sklearn.model_selection
import torch
import torchaudio
import torchaudio.transforms as at
import whisper
from torch import nn

# Optional dependency: ignored when it is not installed.
try:
    import tensorflow 
except ImportError:
    pass

import evaluate
from pytorch_lightning import LightningModule
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
from ray import air, tune
from ray.tune.integration.pytorch_lightning import TuneReportCallback
from tqdm.notebook import tqdm
# NOTE(review): transformers.AdamW is deprecated in favour of torch.optim.AdamW.
from transformers import (
    AdamW,
    get_linear_schedule_with_warmup
)

# Set up data and helper functions

In [None]:
# Load data: collect the transcript (.txt) files of the dataset.
DATASET_DIR = "/content/drive/MyDrive/asr-qcFrench-data/segments"
#DATASET_DIR = "/content/drive/MyDrive/asr-qcFrench-data/test-input"
dataset_dir = Path(DATASET_DIR)
# Sorted so the order — and therefore the seeded train/val/test split done later —
# does not depend on the filesystem's directory-listing order.
transcripts_path_list = sorted(dataset_dir.glob("*.txt"))
print(len(transcripts_path_list))
print(transcripts_path_list[:10])  # preview only; the full list can be very long

In [None]:
# Global settings shared by the cells below.
SAMPLE_RATE = 16000  # audio is resampled to this rate (Hz)
BATCH_SIZE = 2

AUDIO_MAX_LENGTH = 6 * SAMPLE_RATE  # 96000 samples, i.e. 6 s at SAMPLE_RATE
# Deliberately large so it does not filter anything.
# NOTE(review): the pair-list cell passes a literal text limit of 120 rather than this constant.
TEXT_MAX_LENGTH = 1200
SEED = 3407
# Accelerator name handed to the Lightning Trainer.
DEVICE = "cpu" if not torch.cuda.is_available() else "gpu"
seed_everything(SEED, workers=True)

In [None]:
# Read wav files
def load_wave(wave_path, sample_rate:int=16000) -> torch.Tensor:
    waveform, sr = torchaudio.load(wave_path, normalize=True)
    if sample_rate != sr:
        waveform = at.Resample(sr, sample_rate)(waveform)
    return waveform

In [None]:
# French tokenizer: Whisper's multilingual tokenizer, configured for French transcription.
woptions = whisper.DecodingOptions(language="fr", without_timestamps=True)
# NOTE(review): `wmodel` does not appear to be used anywhere else in this notebook.
wmodel = whisper.load_model("base")
wtokenizer = whisper.tokenizer.get_tokenizer(
    multilingual=True,
    language="fr",
    task=woptions.task,
)

In [None]:
# The audio file list function for the 2.5 hours dataset
def get_audio_file_list(transcripts_path_list, text_max_length=120, audio_max_sample_length=96000, sample_rate=16000):
    """Pair each transcript with its audio file, keeping only pairs within the length limits.

    Args:
        transcripts_path_list: iterable of paths (``str`` or ``Path``) to ``.txt`` transcripts.
            The audio for ``<name>.txt`` is looked up at ``<name>.wav`` in the same directory;
            transcripts without a matching ``.wav`` are skipped.
        text_max_length: maximum transcript length in characters (inclusive).
        audio_max_sample_length: maximum audio length in samples (inclusive), measured on the
            first row returned by ``load_wave`` after resampling to ``sample_rate``.
        sample_rate: rate passed to ``load_wave``.

    Returns:
        list of ``(audio_id, audio_path, text)`` tuples, where ``audio_id`` is the file name
        without its ``.txt`` extension and ``audio_path`` is a ``str``.
    """
    audio_transcript_pair_list = []
    for transcripts_path in map(Path, transcripts_path_list):
        audio_path = transcripts_path.with_suffix(".wav")
        if not audio_path.exists():
            continue
        # Transcripts are French text: decode as UTF-8 explicitly instead of relying on
        # the platform's default encoding.
        text = transcripts_path.read_text(encoding="utf-8")
        # `.stem` removes only the ".txt" suffix; `str.rstrip(".txt")` strips any trailing
        # '.', 't' or 'x' characters (e.g. "clip_text.txt" -> "clip_te").
        audio_id = transcripts_path.stem

        # Keep the data that satisfy the criteria
        audio = load_wave(audio_path, sample_rate=sample_rate)[0]
        if len(text) > text_max_length or len(audio) > audio_max_sample_length:
            continue
        audio_transcript_pair_list.append((audio_id, str(audio_path), text))
    return audio_transcript_pair_list

In [None]:
# Make audio-transcript pairs, filtered on transcript length (characters) and audio length (samples).
# NOTE(review): the text limit is 120 here, not TEXT_MAX_LENGTH (1200) — confirm which is intended.
audio_transcript_pair_list = get_audio_file_list(
    transcripts_path_list,
    text_max_length=120,
    audio_max_sample_length=AUDIO_MAX_LENGTH,
    sample_rate=SAMPLE_RATE,
)

In [None]:
# Split data into train / val / test: 20% held out for test, then 25% of the remainder
# for val (60% / 20% / 20% overall), with a fixed random_state for reproducibility.
split_seed = 1
train, test = sklearn.model_selection.train_test_split(
    audio_transcript_pair_list, test_size=0.2, random_state=split_seed
)
train, val = sklearn.model_selection.train_test_split(
    train, test_size=0.25, random_state=split_seed
)

In [None]:
# Take a look at the split: the first 10 pairs and the size of each subset.
for subset_name, count_label, subset in (
    ("train", "TRAIN AUDIO DATASET NUM: ", train),
    ("val", "EVAL AUDIO DATASET NUM: ", val),
    ("test", "TEST AUDIO DATASET NUM: ", test),
):
    print(subset_name, subset[0:10])
    print(count_label, len(subset))

In [None]:
# Create dataset
class FrSpeechDataset(torch.utils.data.Dataset):
    """Map-style dataset turning ``(audio_id, audio_path, text)`` entries into Whisper examples.

    Each item is a dict with:
      * ``"input_ids"``: log-Mel spectrogram of the audio after ``whisper.pad_or_trim``;
      * ``"dec_input_ids"``: the tokenizer's no-timestamps start sequence followed by the
        encoded transcript (decoder input);
      * ``"labels"``: ``dec_input_ids`` shifted left by one and terminated by the EOT token.
    """

    def __init__(self, audio_info_list:list, tokenizer, sample_rate) -> None:
        super().__init__()
        self.tokenizer = tokenizer
        self.sample_rate = sample_rate
        self.audio_info_list = audio_info_list

    def __len__(self):
        return len(self.audio_info_list)

    def __getitem__(self, id):
        _audio_id, wav_path, transcript = self.audio_info_list[id]

        # Audio -> fixed-length log-Mel features.
        # NOTE(review): flatten() concatenates channels for multi-channel audio; data assumed mono.
        waveform = load_wave(wav_path, sample_rate=self.sample_rate).flatten()
        mel = whisper.log_mel_spectrogram(whisper.pad_or_trim(waveform))

        # Text -> decoder inputs and next-token targets.
        tok = self.tokenizer
        dec_input_ids = list(tok.sot_sequence_including_notimestamps) + tok.encode(transcript)
        labels = dec_input_ids[1:] + [tok.eot]

        return {"input_ids": mel, "labels": labels, "dec_input_ids": dec_input_ids}

In [None]:
# Prepare data for them to be used in the model
class WhisperDataCollatorWhithPadding:
    """Collate dataset items into one padded batch.

    Args:
        dec_pad_token_id: id used to right-pad ``dec_input_ids``. The default 50257 is the
            EOT token id this notebook's tokenizer uses; pass ``tokenizer.eot`` to stay in sync
            with a different tokenizer.
        label_pad_id: id used to right-pad ``labels``. The default -100 matches the
            ``ignore_index`` of the notebook's ``nn.CrossEntropyLoss``, so padding is ignored.

    ``__call__(features)`` returns a dict with:
        ``"labels"`` / ``"dec_input_ids"``: int64 tensors of shape (batch, max_len), where
            ``max_len`` is the longest label or decoder-input sequence in the batch;
        ``"input_ids"``: the items' ``input_ids`` stacked along a new leading batch dim.
    """

    def __init__(self, dec_pad_token_id: int = 50257, label_pad_id: int = -100) -> None:
        self.dec_pad_token_id = dec_pad_token_id
        self.label_pad_id = label_pad_id

    @staticmethod
    def _pad(sequences, length, value):
        """Right-pad each integer sequence to ``length`` with ``value`` into one int64 tensor."""
        padded = torch.full((len(sequences), length), value, dtype=torch.long)
        for row, seq in enumerate(sequences):
            padded[row, :len(seq)] = torch.as_tensor(seq, dtype=torch.long)
        return padded

    def __call__(self, features):
        input_ids = torch.stack([f["input_ids"] for f in features])
        labels = [f["labels"] for f in features]
        dec_input_ids = [f["dec_input_ids"] for f in features]

        # One common length for both so labels stay position-aligned with decoder inputs.
        max_len = max(len(seq) for seq in labels + dec_input_ids)

        return {
            "labels": self._pad(labels, max_len, self.label_pad_id),
            "dec_input_ids": self._pad(dec_input_ids, max_len, self.dec_pad_token_id),
            "input_ids": input_ids,
        }

In [None]:
# Wrap each of the three splits in a FrSpeechDataset.
train_dataset, val_dataset, test_dataset = (
    FrSpeechDataset(split, wtokenizer, SAMPLE_RATE) for split in (train, val, test)
)

In [None]:
# Batched loaders over the three splits (DataLoader default: no shuffling).
collate = WhisperDataCollatorWhithPadding()
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, collate_fn=collate)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=BATCH_SIZE, collate_fn=collate)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=BATCH_SIZE, collate_fn=collate)

In [None]:
# Sanity check: show the shapes of the first training batch and decode it back to text.
first_batch = next(iter(train_loader))
print(first_batch["labels"].shape)
print(first_batch["input_ids"].shape)
print(first_batch["dec_input_ids"].shape)

for token, dec in zip(first_batch["labels"], first_batch["dec_input_ids"]):
    # Label padding (-100) is not a token id; display it as EOT (without mutating the batch).
    token = token.masked_fill(token == -100, wtokenizer.eot)
    print(wtokenizer.decode(token, skip_special_tokens=False))
    # Decoder inputs are padded by the collator with token id 50257 (EOT), never -100,
    # so they can be decoded as-is.
    print(wtokenizer.decode(dec, skip_special_tokens=False))

# Define the fine-tuning model

In [None]:
class Config:
    """Training hyper-parameters, read through a ``Config()`` instance by WhisperModelModule."""

    # Optimizer and learning-rate schedule.
    learning_rate: float = 0.0005  # previously-commented alternatives: 0.005, 0.01
    weight_decay: float = 0.01
    adam_epsilon: float = 1e-8
    warmup_steps: int = 2

    # Data loading.
    batch_size: int = 16  # 32 was the commented alternative
    num_worker: int = 2
    sample_rate: int = SAMPLE_RATE

    # Training length.
    num_train_epochs: int = 10
    gradient_accumulation_steps: int = 1

In [None]:
class WhisperModelModule(LightningModule):
    """LightningModule that fine-tunes a pretrained Whisper model with a frozen encoder.

    Batches are dicts with ``input_ids`` (audio features for the encoder),
    ``dec_input_ids`` (teacher-forced decoder input) and ``labels`` (next-token targets,
    where -100 marks positions to ignore).

    Args:
        cfg: hyper-parameters (learning_rate, weight_decay, adam_epsilon, warmup_steps,
            batch_size, num_worker, num_train_epochs, gradient_accumulation_steps, sample_rate).
        model_name: checkpoint name passed to ``whisper.load_model``.
        lang: language used for the decoding options and the tokenizer.
        train_dataset: list of ``(audio_id, audio_path, text)`` tuples for training
            (``None`` means empty).
        eval_dataset: list of such tuples for validation (``None`` means empty).
    """

    def __init__(self, cfg:Config, model_name="small", lang="fr", train_dataset=None, eval_dataset=None) -> None:
        super().__init__()
        self.options = whisper.DecodingOptions(language=lang, without_timestamps=True)
        self.model = whisper.load_model(model_name)
        # The tokenizer follows `lang` (it used to be hard-coded to "fr", ignoring the argument).
        self.tokenizer = whisper.tokenizer.get_tokenizer(True, language=lang, task=self.options.task)

        # only decoder training: freeze every encoder parameter
        for p in self.model.encoder.parameters():
            p.requires_grad = False

        # Targets equal to -100 (padding) do not contribute to the loss.
        self.loss_fn = nn.CrossEntropyLoss(ignore_index=-100)
        self.metrics_wer = evaluate.load("wer")
        self.metrics_cer = evaluate.load("cer")

        self.cfg = cfg
        # `None` defaults avoid sharing one mutable default list between instances.
        self.__train_dataset = train_dataset if train_dataset is not None else []
        self.__eval_dataset = eval_dataset if eval_dataset is not None else []

    def forward(self, x):
        """Delegate to the wrapped Whisper model."""
        return self.model(x)

    def training_step(self, batch, batch_id):
        """Teacher-forced cross-entropy on the decoder output; the encoder runs without gradients."""
        input_ids = batch["input_ids"]
        labels = batch["labels"].long()
        dec_input_ids = batch["dec_input_ids"].long()

        with torch.no_grad():
            audio_features = self.model.encoder(input_ids)

        out = self.model.decoder(dec_input_ids, audio_features)
        # Flatten logits to (N, vocab) and targets to (N,) for CrossEntropyLoss.
        loss = self.loss_fn(out.view(-1, out.size(-1)), labels.view(-1))
        self.log("train/loss", loss, on_step=True, prog_bar=True, logger=True)
        return loss

    def validation_step(self, batch, batch_id):
        """Log validation loss plus CER/WER of the argmax (teacher-forced) predictions."""
        input_ids = batch["input_ids"]
        labels = batch["labels"].long()
        dec_input_ids = batch["dec_input_ids"].long()

        audio_features = self.model.encoder(input_ids)
        out = self.model.decoder(dec_input_ids, audio_features)

        loss = self.loss_fn(out.view(-1, out.size(-1)), labels.view(-1))

        # Score only positions that carry a real target. Previously the logits were compared
        # against the -100 label sentinel (meaningless for logits), and predictions at padded
        # positions were decoded too, which could add spurious words to the hypotheses.
        o_list, l_list = [], []
        for o, l in zip(out, labels):
            keep = l != -100
            pred_ids = torch.argmax(o, dim=-1)[keep]
            o_list.append(self.tokenizer.decode(pred_ids.tolist(), skip_special_tokens=True))
            l_list.append(self.tokenizer.decode(l[keep].tolist(), skip_special_tokens=True))
        cer = self.metrics_cer.compute(references=l_list, predictions=o_list)
        wer = self.metrics_wer.compute(references=l_list, predictions=o_list)

        self.log("val/loss", loss, on_step=True, prog_bar=True, logger=True)
        self.log("val/cer", cer, on_step=True, prog_bar=True, logger=True)
        self.log("val/wer", wer, on_step=True, prog_bar=True, logger=True)

        return {
            "cer": cer,
            "wer": wer,
            "loss": loss
        }

    def configure_optimizers(self):
        """AdamW with two weight-decay groups and a per-step linear warmup schedule."""
        model = self.model
        no_decay = ["bias", "LayerNorm.weight"]
        # NOTE(review): Whisper's layer-norm parameters may not contain "LayerNorm.weight" in
        # their names; confirm whether they are actually excluded from weight decay.
        optimizer_grouped_parameters = [
            {
                "params": [p for n, p in model.named_parameters() 
                            if not any(nd in n for nd in no_decay)],
                "weight_decay": self.cfg.weight_decay,
            },
            {
                "params": [p for n, p in model.named_parameters() 
                            if any(nd in n for nd in no_decay)],
                "weight_decay": 0.0,
            },
        ]
        # torch.optim.AdamW replaces the deprecated transformers.AdamW; with default betas
        # and bias correction both perform the same decoupled-weight-decay update.
        optimizer = torch.optim.AdamW(optimizer_grouped_parameters,
                                      lr=self.cfg.learning_rate,
                                      eps=self.cfg.adam_epsilon)
        self.optimizer = optimizer

        scheduler = get_linear_schedule_with_warmup(
            optimizer, num_warmup_steps=self.cfg.warmup_steps, 
            num_training_steps=self.t_total
        )
        self.scheduler = scheduler

        return [optimizer], [{"scheduler": scheduler, "interval": "step", "frequency": 1}]

    def setup(self, stage=None):
        """Compute the total number of optimizer steps (``t_total``) used by the scheduler."""
        if stage == 'fit' or stage is None:
            self.t_total = (
                (len(self.__train_dataset) // (self.cfg.batch_size))
                // self.cfg.gradient_accumulation_steps
                * float(self.cfg.num_train_epochs)
            )

    def train_dataloader(self):
        """Shuffled training loader (drops the last incomplete batch)."""
        dataset = FrSpeechDataset(self.__train_dataset, self.tokenizer, self.cfg.sample_rate)
        return torch.utils.data.DataLoader(dataset, 
                          batch_size=self.cfg.batch_size, 
                          drop_last=True, shuffle=True, num_workers=self.cfg.num_worker,
                          collate_fn=WhisperDataCollatorWhithPadding()
                          )

    def val_dataloader(self):
        """Validation loader in dataset order."""
        dataset = FrSpeechDataset(self.__eval_dataset, self.tokenizer, self.cfg.sample_rate)
        return torch.utils.data.DataLoader(dataset, 
                          batch_size=self.cfg.batch_size, 
                          num_workers=self.cfg.num_worker,
                          collate_fn=WhisperDataCollatorWhithPadding()
                          )
       

# Main: fine-tune the model

In [None]:
# Run configuration: output locations (TensorBoard logs, checkpoints), run naming and model choice.
log_output_dir, check_output_dir = "/content/logs", "/content/artifacts"

train_name, train_id = "whisper", "00001"

model_name, lang = "small", "fr"

In [None]:
# Clean cache: drop the reference to any previous model and collect garbage first, then
# release PyTorch's unoccupied cached CUDA memory. Emptying the cache while `model` is
# still referenced cannot release the memory that model occupies.
model = None
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()

In [None]:
cfg = Config()

# Make sure both output directories exist.
for out_dir in (log_output_dir, check_output_dir):
    Path(out_dir).mkdir(exist_ok=True)

tflogger = TensorBoardLogger(save_dir=log_output_dir, name=train_name, version=train_id)

# save_top_k=-1 keeps every checkpoint ModelCheckpoint writes (all model save).
checkpoint_callback = ModelCheckpoint(
    dirpath=f"{check_output_dir}/checkpoint",
    filename="checkpoint-{epoch:04d}",
    save_top_k=-1,
)
lr_monitor = LearningRateMonitor(logging_interval="epoch")
callback_list = [checkpoint_callback, lr_monitor]

model = WhisperModelModule(cfg, model_name, lang, train, val)


In [None]:
# 16-bit mixed precision is requested only on GPU.
# NOTE(review): pytorch-lightning 1.7 rejects precision=16 on CPU — confirm for the installed version.
trainer = Trainer(
    precision=16 if DEVICE == "gpu" else 32,
    accelerator=DEVICE,
    max_epochs=cfg.num_train_epochs,
    accumulate_grad_batches=cfg.gradient_accumulation_steps,
    logger=tflogger,
    callbacks=callback_list,
    auto_lr_find=True,  # Python's boolean literal; `TRUE` raised NameError
    # auto_scale_batch_size=True,
)

In [None]:
# Search for a good learning rate, plot the finder results with the suggestion marked,
# and display the suggested value. The suggestion is not applied to `model` here.
lr_finder = trainer.tuner.lr_find(model)
fig = lr_finder.plot(suggest=True)
suggested_lr = lr_finder.suggestion()
suggested_lr

In [None]:
# Tuned hyper-parameters; learning_rate presumably comes from the LR finder suggestion above.
class Config:
    learning_rate = 2.7542287033381663e-05 #0.0005 #0.005 #0.01
    weight_decay = 0.01
    adam_epsilon = 1e-8
    warmup_steps = 4 # 2
    batch_size = 16 # 32
    num_worker = 2
    num_train_epochs = 15 #10
    gradient_accumulation_steps = 1
    sample_rate = SAMPLE_RATE


# Redefining the class does not update objects built from the previous definition:
# `cfg`, `model` (which keeps its own `cfg`) and `trainer` (max_epochs) still hold the old
# values. Rebuild them so the trainer.fit() call below actually uses these settings.
model = trainer = None  # release the objects built with the old settings first
cfg = Config()
model = WhisperModelModule(cfg, model_name, lang, train, val)
trainer = Trainer(
    precision=16 if DEVICE == "gpu" else 32,
    accelerator=DEVICE,
    max_epochs=cfg.num_train_epochs,
    accumulate_grad_batches=cfg.gradient_accumulation_steps,
    logger=tflogger,
    callbacks=callback_list,
)

In [None]:
# Fine-tune; checkpoints are written by `checkpoint_callback`, logs by `tflogger`.
trainer.fit(model=model)

## Visualize the model performance throughout training

In [None]:
# Load the TensorBoard notebook extension (provides the %tensorboard magic).
%load_ext tensorboard

In [None]:
# Show the TensorBoard dashboard for the training logs (same directory as log_output_dir).
%tensorboard --logdir /content/logs

## Best model

In [None]:
# Checkpoint path ModelCheckpoint reports as "best". No `monitor` was configured, so this is
# presumably just the most recently saved checkpoint — confirm.
best_checkpoint_path = checkpoint_callback.best_model_path
best_checkpoint_path

In [None]:
# Checkpoint to evaluate, built from check_output_dir and the ModelCheckpoint filename pattern
# ("checkpoint-{epoch:04d}") instead of repeating the hard-coded directory.
CHECKPOINT_EPOCH = 9
checkpoint_path = f"{check_output_dir}/checkpoint/checkpoint-epoch={CHECKPOINT_EPOCH:04d}.ckpt"

In [None]:
# Load the saved checkpoint, show its top-level keys, and keep only the model weights.
checkpoint = torch.load(checkpoint_path)
print(checkpoint.keys())
state_dict = checkpoint['state_dict']

In [None]:
# Rebuild the LightningModule (default model_name/lang) and load the fine-tuned weights into it.
whisper_model = WhisperModelModule(cfg=cfg)
whisper_model.load_state_dict(state_dict=state_dict)

In [None]:
# Pretrained (not fine-tuned) Whisper "small", used as the baseline.
whisper_base = whisper.load_model(name="small")

In [None]:
# Transcribe the validation set with the fine-tuned model and the baseline, and decode references.
# Run on whatever device the fine-tuned model lives on instead of assuming CUDA, and let
# DecodingOptions.fp16 request half precision only when that device is a GPU.
# whisper_base is assumed to be on the same device (both come from whisper.load_model).
eval_device = next(whisper_model.model.parameters()).device
woptions = whisper.DecodingOptions(language="fr", without_timestamps=True, fp16=eval_device.type == "cuda")
dataset = FrSpeechDataset(val, wtokenizer, SAMPLE_RATE)
loader = torch.utils.data.DataLoader(dataset, batch_size=2, collate_fn=WhisperDataCollatorWhithPadding())

refs = []
res = []
res_b = []
for b in tqdm(loader):
    input_ids = b["input_ids"].to(eval_device)
    labels = b["labels"].long()  # only decoded to text; no need to move it to the device
    with torch.no_grad():
        results = whisper_model.model.decode(input_ids, woptions)
        results_b = whisper.decode(whisper_base, input_ids, woptions)
    res.extend(r.text for r in results)
    res_b.extend(r.text for r in results_b)
    for l in labels:
        # Label padding (-100) is not a token id; map it to EOT, which is skipped when decoding.
        l[l == -100] = wtokenizer.eot
        refs.append(wtokenizer.decode(l, skip_special_tokens=True))

## Baseline WER (pretrained Whisper, before fine-tuning)

In [None]:
# Baseline WER. `wer_metrics` used to be created only in a later cell, so this raised
# NameError when the notebook was run top to bottom; load the metric here.
wer_metrics = evaluate.load("wer")
wer_metrics.compute(references=refs, predictions=res_b)

In [None]:
# Side-by-side view: each reference transcript followed by the baseline's hypothesis.
for reference, hypothesis in zip(refs, res_b):
    print("-" * 10)
    print(reference)
    print(hypothesis)

## Fine-tuned model's WER

In [None]:
# WER of the fine-tuned model against the same validation references.
wer_metrics = evaluate.load("wer")
finetuned_wer = wer_metrics.compute(references=refs, predictions=res)
finetuned_wer

In [None]:
# Side-by-side view: each reference transcript followed by the fine-tuned model's hypothesis.
separator = "-" * 10
for pair in zip(refs, res):
    print(separator)
    print(*pair, sep="\n")