Since I did this task on google colab, I needed to install several libraries and add-ons

In [1]:
# ! pip install -U openai-whisper
# import locale
# locale.getpreferredencoding = lambda: "UTF-8"
# !add-apt-repository -y ppa:jonathonf/ffmpeg-4
# !apt update
# !apt install -y ffmpeg

In [1]:
import whisper

import numpy as np
import torch
import torchaudio
from torch.utils.data import DataLoader
from torchaudio.datasets import LIBRISPEECH

from tqdm.notebook import tqdm

import metrics_task_1

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

In [2]:
def collate_fn(items: list):
    mels, tokens = zip(*items)

    # padding to match shapes
    max_token_len = max([len(token) for token in tokens])
    padded_tokens = torch.tensor([token + [0] * (max_token_len - len(token)) for token in tokens])

    mels = torch.stack([mel.clone().detach() for mel in mels])

    return mels, padded_tokens

In [3]:
class MyDataset(LIBRISPEECH):
    def __init__ (self, url, tokenizer):
        super().__init__('/content/repos', url=url, download=True)
        self.tokenizer = tokenizer

    def __getitem__(self, idx):
        audio, sample_rate, text, _, _, _ = super().__getitem__(idx)
        assert sample_rate == 16000

        audio = whisper.pad_or_trim(audio.flatten()).to(DEVICE)
        mel = whisper.log_mel_spectrogram(audio)

        # tokens = (start of sentence without <|start of trancr|>) + text + <|end of text|>
        tokens = [*self.tokenizer.sot_sequence_including_notimestamps[1:]] + self.tokenizer.encode(text.lower()) + [self.tokenizer.eot]

        return mel, tokens

In [4]:
class Trainer:
    def __init__(self,
            model: whisper.Whisper, train_dataset , valid_dataset, epochs, training_loader, validation_loader, lr=0.001):
        self.model = model
        self.train_dataset = train_dataset
        self.valid_dataset = valid_dataset
        self.epochs = epochs
        self.training_loader = training_loader
        self.validation_loader = validation_loader
        self.optimizer = torch.optim.Adam(model.parameters(), lr=lr)
        self.loss_fn = torch.nn.CrossEntropyLoss()
        self.best_val_loss = float("inf")

    def train_step(self, input_mel, target_tokens):
        input_mel = input_mel.to(DEVICE)
        target_tokens = target_tokens.to(DEVICE)

        # To train mode and zero our gradients
        self.model.train()
        self.optimizer.zero_grad()

        # Make predictions for this batch, equals to self.forward with output - self.decoder(tokens, self.encoder(mel))
        output = self.model(input_mel, target_tokens)

        # Compute the loss and its gradients        
        loss = self.loss_fn(output.transpose(1, 2), target_tokens) #transposing to match shapes
        loss.backward()

        # Adjust learning weights
        self.optimizer.step()

        return loss.item()

    def train(self):
        for epoch in range(self.epochs):
            running_loss = 0.0

            for input_mel, target_tokens in tqdm(self.training_loader):
                loss = self.train_step(input_mel, target_tokens)
                running_loss += loss

            running_loss = running_loss / len(self.training_loader) # calculate avg loss 

            print(f'EPOCH {epoch}')
            print(f"Training loss: {running_loss:.4f}")

            val_loss, wer = self.validate()

            print(f'Validation loss: {val_loss:.4f}')
            print(f'Validation WER: {wer:.4f}')
            
            if val_loss < self.best_val_loss:
                self.best_val_loss = val_loss
                torch.save(self.model.state_dict(), "/content/savedmodel/whisper_model.pt")

    def validate (self):
        val_loss = 0
        wer = 0
        self.model.eval() # to evaluation mode
        
        for input_mel, target_tokens in tqdm(self.validation_loader): 
            with torch.no_grad():
                input_mel = input_mel.to(DEVICE)
                target_tokens = target_tokens.to(DEVICE)

                output = self.model(input_mel, target_tokens).transpose(1, 2)

                loss = self.loss_fn(output, target_tokens)              
                val_loss += loss.item() * input_mel.size()[0]

                wer += metrics_task_1.wer_torch(output.argmax(dim=1), target_tokens)
        return val_loss / len(self.validation_loader), wer / len(self.validation_loader)

In [5]:
epochs = 1
tokenizer = whisper.tokenizer.get_tokenizer(multilingual=True, language='en', task='transcribe')

train_dataset = MyDataset(tokenizer=tokenizer, url="train-clean-100")

valid_dataset = MyDataset(tokenizer=tokenizer, url="dev-clean")

training_loader = torch.utils.data.DataLoader(train_dataset, batch_size=8, collate_fn=collate_fn, shuffle=True)
validation_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=8, collate_fn=collate_fn, shuffle=True)

In [7]:
params = torch.load('/content/patriotic_whisper_mixed_en_uk.pt')
model = whisper.load_model("tiny", device=DEVICE)
model.load_state_dict(params)
print("Patriotic model loaded!")

Patriotic model loaded!


In [8]:
model.transcribe('/content/tests/jfk.flac')['text']

'and і що? my fellow americans asked not Шо треба? your country можеш? do для тебе ти asked Шо треба? ти можеш? do для тебе your country Героям слава!'

In [8]:
trainer = Trainer(model, train_dataset, valid_dataset, epochs, training_loader, validation_loader)

In [10]:
trainer.train()

  0%|          | 0/3568 [00:00<?, ?it/s]

EPOCH 0
Training loss: 0.2091


  0%|          | 0/338 [00:00<?, ?it/s]

Validation loss: 0.1902
Validation WER: 0.0025


There is a problem with showing progress bar, I don't know how to fix it

In [40]:
mel, tokens = next(trainer.validation_loader._get_iterator())

mel = mel.to(DEVICE)
tokens = tokens.to(DEVICE)
output = model(mel, tokens).argmax(dim=2)


print(f'True: {tokenizer.decode(tokens[0])} \n')
print(f'Pred: {tokenizer.decode(output[0])}')

True: <|en|><|transcribe|><|notimestamps|>as i passed through the upper entry veronica opened her door she was undressed and had a little book in her hand which she shook at me saying<|endoftext|>!!!!!!!!!!!!!!!!!! 

Pred: <|en|><|transcribe|><|notimestamps|>as i passed through the upper entry veronica opened her door she was undressed and had a little book in her hand which she shook at me saying<|endoftext|>!!!!!!!!!!!!!!!!!!


As we can see model shows itself very good on test data

In [50]:
%%timeit

metrics_task_1.wer_torch(output.argmax(dim=2), tokens)

1.55 s ± 278 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


~278 ms per batch which consists out of 8 files is very good, as for me

In [29]:
waveform, _ = torchaudio.load('/content/tests/jfk.flac')
audio = whisper.pad_or_trim(waveform.flatten()).to(DEVICE)
mel_1 = whisper.log_mel_spectrogram(audio)
mels = mel_1.unsqueeze(0).repeat(8, 1, 1).to(DEVICE)
output = model(mels, tokens)
tokenizer.decode(output[0].argmax(dim=1)) # we can use any tokens here

'<|en|><|transcribe|><|notimestamps|>and so my fellow americans ask not what your country can do for you<|endoftext|> ask what you can do for your country of<|endoftext|>'


I don't know why but, model.transcribe() started to give poor results, so I was forced to use this workaround