# Imports

In [1]:
import os
from pathlib import Path

In [2]:
import librosa
import jiwer

In [3]:
import numpy as np
import pandas as pd
from tqdm import tqdm

In [4]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

In [5]:
import torchaudio

In [6]:
from transformers import  WhisperProcessor, WhisperForConditionalGeneration

# Project Variables

In [7]:
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DEVICE

'cpu'

# Load the Librispeech dataset

In [8]:
processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")

In [9]:
class LibriSpeech(Dataset):
    def __init__(self, split="test-clean"):
        self.dataset = torchaudio.datasets.LIBRISPEECH(
            root=Path("~/.cache/torch_datasets").expanduser(),
            url="test-clean",
            download=True
            )
        self.device = DEVICE
        self.processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
        self.tokenizer = processor.tokenizer
        

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        audio, sample_rate, text, _, _, _ = self.dataset[idx]
        assert sample_rate == 16000 
        audio = audio.flatten()
        audio_features = self.processor(audio, sampling_rate=sample_rate, truncation=True, padding="max_length", return_tensors="pt").input_features.squeeze(0)
        text_features = self.tokenizer(text, padding="max_length", truncation=True, return_tensors="pt", max_length=448)
        text_tokens = text_features.input_ids.squeeze(0)
        attention_mask = text_features.attention_mask.squeeze(0)
        return audio_features, text, text_tokens, attention_mask 

In [10]:
dataset = LibriSpeech()
loader = DataLoader(dataset, batch_size=8)

In [11]:
print(next(iter(loader))[0].shape)
print(next(iter(loader))[1])
print(next(iter(loader))[2].shape)
print(next(iter(loader))[3].shape)

torch.Size([8, 80, 3000])
('HE HOPED THERE WOULD BE STEW FOR DINNER TURNIPS AND CARROTS AND BRUISED POTATOES AND FAT MUTTON PIECES TO BE LADLED OUT IN THICK PEPPERED FLOUR FATTENED SAUCE', 'STUFF IT INTO YOU HIS BELLY COUNSELLED HIM', 'AFTER EARLY NIGHTFALL THE YELLOW LAMPS WOULD LIGHT UP HERE AND THERE THE SQUALID QUARTER OF THE BROTHELS', 'HELLO BERTIE ANY GOOD IN YOUR MIND', 'NUMBER TEN FRESH NELLY IS WAITING ON YOU GOOD NIGHT HUSBAND', "THE MUSIC CAME NEARER AND HE RECALLED THE WORDS THE WORDS OF SHELLEY'S FRAGMENT UPON THE MOON WANDERING COMPANIONLESS PALE FOR WEARINESS", 'THE DULL LIGHT FELL MORE FAINTLY UPON THE PAGE WHEREON ANOTHER EQUATION BEGAN TO UNFOLD ITSELF SLOWLY AND TO SPREAD ABROAD ITS WIDENING TAIL', 'A COLD LUCID INDIFFERENCE REIGNED IN HIS SOUL')
torch.Size([8, 448])
torch.Size([8, 448])


# Evaluate the model on data before training

In [12]:
model.eval()
audio_inputs, text, text_tokens, attention_mask = next(iter(loader))
predicted_ids = model.generate(audio_inputs, attention_mask=attention_mask)

In [13]:
text_pred = processor.batch_decode(predicted_ids, skip_special_tokens=True)
text_pred

[' He hoped there would be stew for dinner, turnips and carrots and bruised potatoes and fat-mutton pieces to be ladled out and thick, peppered flower-fattened sauce.',
 ' Stuffered into you, his belly, counseled him.',
 ' After early nightfall, the yellow lamps would light up here and there, the squalid quarter of the brothels.',
 ' Hello Bertie, any good in your mind?',
 ' Number 10, fresh Nelly is waiting on you. Good night, husband.',
 " The music came nearer and he recalled the words. The words of Shelley's fragment upon the moon wandering companionless, pale for weariness.",
 ' The dull light fell more faintly upon the page where on another equation began to unfold itself slowly and to spread abroad its widening tale.',
 ' a cold, lucid indifference rained in his soul.']

In [14]:
text = [processor.tokenizer.basic_normalize(t) for t in text]

In [15]:
df = pd.DataFrame({"original_text": text, "predicted_text": text_pred})
df

Unnamed: 0,original_text,predicted_text
0,he hoped there would be stew for dinner turnip...,"He hoped there would be stew for dinner, turn..."
1,stuff it into you his belly counselled him,"Stuffered into you, his belly, counseled him."
2,after early nightfall the yellow lamps would l...,"After early nightfall, the yellow lamps would..."
3,hello bertie any good in your mind,"Hello Bertie, any good in your mind?"
4,number ten fresh nelly is waiting on you good ...,"Number 10, fresh Nelly is waiting on you. Goo..."
5,the music came nearer and he recalled the word...,The music came nearer and he recalled the wor...
6,the dull light fell more faintly upon the page...,The dull light fell more faintly upon the pag...
7,a cold lucid indifference reigned in his soul,"a cold, lucid indifference rained in his soul."


In [16]:
wer = jiwer.wer(df["original_text"].tolist(), df["predicted_text"].tolist())
print(f"WER before training: {wer * 100:.3f}%")

WER before training: 33.858%


# Train the model to overfit a single batch

In [17]:
model.train()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

In [19]:
num_steps = 10

In [20]:
num_batches = len(loader.dataset) // loader.batch_size
num_batches = 1

In [21]:
for step in range(num_steps):
    for i, batch_data in enumerate(tqdm(loader, total=num_batches)):
        audio_input, text, text_tokens, attention_mask = batch_data

        optimizer.zero_grad()
        
        outputs = model(audio_input, labels=text_tokens, attention_mask=attention_mask)
        
        loss = outputs.loss
    
        loss.backward()
        optimizer.step()
    
        print(f"Step {step+1}. Batch {i+1}, Loss: {loss.item():.4f}")
        break

Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.43.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.
  0%|                                                                                                                                       | 0/327 [00:13<?, ?it/s]


Step 1. Batch 1, Loss: 3.3286


  0%|                                                                                                                                       | 0/327 [00:13<?, ?it/s]


Step 2. Batch 1, Loss: 0.6258


  0%|                                                                                                                                       | 0/327 [00:14<?, ?it/s]


Step 3. Batch 1, Loss: 0.3406


  0%|                                                                                                                                       | 0/327 [00:14<?, ?it/s]


Step 4. Batch 1, Loss: 0.2338


  0%|                                                                                                                                       | 0/327 [00:14<?, ?it/s]


Step 5. Batch 1, Loss: 0.1845


  0%|                                                                                                                                       | 0/327 [00:14<?, ?it/s]


Step 6. Batch 1, Loss: 0.1595


  0%|                                                                                                                                       | 0/327 [00:14<?, ?it/s]


Step 7. Batch 1, Loss: 0.1455


  0%|                                                                                                                                       | 0/327 [00:15<?, ?it/s]


Step 8. Batch 1, Loss: 0.1368


  0%|                                                                                                                                       | 0/327 [00:15<?, ?it/s]


Step 9. Batch 1, Loss: 0.1306


  0%|                                                                                                                                       | 0/327 [00:15<?, ?it/s]

Step 10. Batch 1, Loss: 0.1259





# Eval model outputs after training

In [22]:
model.eval()
audio_inputs, text, text_tokens, attention_mask = next(iter(loader))
predicted_ids = model.generate(audio_inputs, attention_mask=attention_mask)

In [23]:
text_pred = processor.batch_decode(predicted_ids, skip_special_tokens=True)
text_pred

[' He hoped there would be stew for dinner, turnips and carrots and bruised potatoes and fat-mutton pieces to be ladled out and thick peppered flower-fat and sauce.',
 ' stuff it into you, his belly counseled him.',
 ' After early nightfall, the yellow lamps would light up here and there, the squalid quarter of the brothels.',
 ' Hello, Bertie. Any good in your mind?',
 ' Number 10, fresh Nelly is waiting on you. Good night, husband.',
 " The music came nearer and he recalled the words. The words of Shelley's fragment upon the moon wandering companionless pale for weariness.",
 ' The dull light fell more faintly upon the page where on another equation began to unfold itself slowly and to spread abroad its widening tail.',
 ' a cold, lucid indifference rained in his soul.']

In [26]:
text_clean = [processor.tokenizer.basic_normalize(t) for t in text]
text_pred_clean = [processor.tokenizer.normalize(t) for t in text_pred]

In [29]:
df = pd.DataFrame({"original_text": text_clean, "predicted_text": text_pred_clean})
df

Unnamed: 0,original_text,predicted_text
0,he hoped there would be stew for dinner turnip...,he hoped there would be stew for dinner turnip...
1,stuff it into you his belly counselled him,stuff it into you his belly counseled him
2,after early nightfall the yellow lamps would l...,after early nightfall the yellow lamps would l...
3,hello bertie any good in your mind,hello bertie any good in your mind
4,number ten fresh nelly is waiting on you good ...,number 10 fresh nelly is waiting on you good n...
5,the music came nearer and he recalled the word...,the music came nearer and he recalled the word...
6,the dull light fell more faintly upon the page...,the dull light fell more faintly upon the page...
7,a cold lucid indifference reigned in his soul,a cold lucid indifference rained in his soul


In [30]:
wer = jiwer.wer(df["original_text"].tolist(), df["predicted_text"].tolist())
print(f"WER before training: {wer * 100:.3f}%")

WER before training: 7.874%


In [70]:
df.predicted_text.tolist()

['he hoped there would be stew for dinner turnips and carrots and bruised potatoes and fat mutton pieces to be ladled out and thick peppered flower fat and sauce',
 'stuff it into you his belly counseled him',
 'after early nightfall the yellow lamps would light up here and there the squalid quarter of the brothels',
 'hello bertie any good in your mind',
 'number 10 fresh nelly is waiting on you good night husband',
 'the music came nearer and he recalled the words the words of shelley is fragment upon the moon wandering companionless pale for weariness',
 'the dull light fell more faintly upon the page where on another equation began to unfold itself slowly and to spread abroad its widening tail',
 'a cold lucid indifference rained in his soul']

In [32]:
predicted_ids

tensor([[  679, 10719,   612,   561,   307, 20798,   329,  8073,    11,  1210,
          2419,   290, 34397,   290, 44379, 18821,   290,  3735,    12,    76,
         21115,  5207,   284,   307,  9717,   992,   503,   290,  6546, 49038,
          1068, 15061,    12, 17359,   290, 10746,    13],
        [ 3404,   340,   656,   345,    11,   465, 19921,  7739,   276,   683,
            13, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
         50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
         50256, 50256, 50256, 50256, 50256, 50256, 50256],
        [ 2293,  1903,  1755,  7207,    11,   262,  7872, 32209,   561,  1657,
           510,   994,   290,   612,    11,   262,  2809, 10751,  3860,   286,
           262,  1379,  1169,  7278,    13, 50256, 50256, 50256, 50256, 50256,
         50256, 50256, 50256, 50256, 50256, 50256, 50256],
        [18435,    11, 22108,   494,    13,  4377,   922,   287,   534,  2000,
            30, 50256, 50256, 502

In [84]:
io = audio_inputs[-2, :, :].unsqueeze(0)
io.shape

torch.Size([1, 80, 3000])

In [85]:
ia = attention_mask[-2, :].unsqueeze(0)
ia.shape

torch.Size([1, 448])

In [103]:
output = model.generate(
    io, 
    attention_mask=ia, 
    return_segments=True, 
    return_timestamps=True,
    no_speech_threshold=0.8,  # Adjust silence sensitivity
    logprob_threshold=-1.0,   # Ensure log-probabilities are computed
    temperature=(0.0, 0.2, 0.4, 0.6, 0.8, 1.0)  # Enable temperature fallback
)

In [99]:
# Extract segments
segments = output["segments"][0]  # First batch element

# Convert tokens to readable text
for segment in segments:
    start, end = segment["start"].item(), segment["end"].item()
    tokens = segment["tokens"]
    text = processor.tokenizer.decode(tokens, skip_special_tokens=True)

    print(f"Start: {start:.2f}s - End: {end:.2f}s")
    print(f"Text: {text}\n")

Start: 0.00s - End: 10.00s
Text:  The dull light fell more faintly upon the page where on another equation began to unfold itself slowly and to spread abroad its widening tail.

