In [27]:
!pip install -U transformers



In [None]:
# train_whisper_metrics_colab.py
# -*- coding: utf-8 -*-
"""
Fine-tune OpenAI Whisper on LibriSpeech with Gaussian noise augmentation,
and report Training/Validation Loss, WER and CER each epoch.
"""

In [None]:

import os
import torch
import torch.nn as nn
import torchaudio
from torch.utils.data import DataLoader
from transformers import (
    WhisperProcessor,
    WhisperForConditionalGeneration,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
    TrainerCallback,
)

In [29]:
# -------------------- 1) Configuration --------------------
# Checkpoint handed to the processor/model `from_pretrained` calls.
MODEL_NAME = "openai/whisper-small"

# Filesystem locations: dataset root and Trainer output directory.
DATA_DIR = "./data"
OUTPUT_DIR = "./whisper_finetuned"

# Optimisation hyper-parameters.
BATCH_SIZE = 4
NUM_EPOCHS = 1
LEARNING_RATE = 3e-5

# Label token budget and the audio sample rate (Hz) assumed by the collator.
MAX_TARGET_LENGTH = 128
SAMPLE_RATE = 16000

# Make sure both working directories exist before anything writes to them.
for _directory in (DATA_DIR, OUTPUT_DIR):
    os.makedirs(_directory, exist_ok=True)

In [30]:
# -------------------- 2) Noise Augmentation --------------------
class AddGaussianNoise(nn.Module):
    """Additive white Gaussian noise that is active only in training mode.

    Args:
        noise_level: Scale applied to the standard-normal noise before it is
            added to the waveform.
    """

    def __init__(self, noise_level=0.01):
        super().__init__()
        self.noise_level = noise_level

    def forward(self, waveform):
        """Return `waveform` plus scaled noise while training, else unchanged."""
        # In eval mode the module is a no-op and hands back the same tensor.
        if not self.training:
            return waveform
        noise = torch.randn_like(waveform) * self.noise_level
        return waveform + noise


# Shared augmentation instance used by the data collator and the inference cell.
audio_augment = AddGaussianNoise(noise_level=0.01)

In [31]:
# -------------------- 3) Load Whisper Processor & Model --------------------
# Load the pretrained processor (its `.feature_extractor` and `.tokenizer` are
# used by the collator and metrics below) and the seq2seq model to fine-tune.
# NOTE(review): no language/task is passed to the processor here — confirm that
# the tokenizer prefix used for the labels matches the prompt `generate()` uses.
processor = WhisperProcessor.from_pretrained(MODEL_NAME)
model     = WhisperForConditionalGeneration.from_pretrained(MODEL_NAME)

In [32]:
# -------------------- 4) Text-/Error-Rate Helpers --------------------
def levenshtein_distance(a, b):
    """Return the Levenshtein (edit) distance between two sequences.

    Insertions, deletions and substitutions each cost 1. Works on any pair of
    indexable sequences whose items support ``==`` (strings for CER, lists of
    words for WER).

    Only two rows of the dynamic-programming table are kept, so memory is
    proportional to the shorter sequence instead of ``len(a) * len(b)``.

    Args:
        a: First sequence.
        b: Second sequence.

    Returns:
        int: Minimum number of single-item edits turning ``a`` into ``b``.
    """
    # The distance is symmetric under unit costs, so iterate rows over the
    # longer sequence and keep the (smaller) row sized by the shorter one.
    if len(a) < len(b):
        a, b = b, a

    previous_row = list(range(len(b) + 1))
    for i, item_a in enumerate(a, start=1):
        current_row = [i]
        for j, item_b in enumerate(b, start=1):
            substitution_cost = 0 if item_a == item_b else 1
            current_row.append(min(
                previous_row[j] + 1,                      # deletion
                current_row[j - 1] + 1,                   # insertion
                previous_row[j - 1] + substitution_cost,  # match / substitution
            ))
        previous_row = current_row
    return previous_row[-1]

def wer(ref, hyp):
    """Word error rate of hypothesis `hyp` against reference `ref`.

    Both strings are split on whitespace; the word-level edit distance is
    divided by the reference word count (treated as 1 when the reference is
    empty, to avoid dividing by zero).
    """
    reference_tokens, hypothesis_tokens = ref.split(), hyp.split()
    edit_count = levenshtein_distance(reference_tokens, hypothesis_tokens)
    denominator = max(len(reference_tokens), 1)
    return edit_count / denominator

def cer(ref, hyp):
    """Character error rate of hypothesis `hyp` against reference `ref`.

    The character-level edit distance is divided by the reference length
    (treated as 1 when the reference is empty, to avoid dividing by zero).
    """
    reference_length = max(len(ref), 1)
    return levenshtein_distance(ref, hyp) / reference_length


In [33]:
# -------------------- 5) Custom Callback --------------------
class TrainingReporter(TrainerCallback):
    """Print a compact loss / WER / CER report whenever the Trainer evaluates."""

    @staticmethod
    def _format_metric(value):
        """Render `value` with four decimals, or 'N/A' if it is missing/non-numeric."""
        try:
            return f"{value:.4f}"
        except (TypeError, ValueError):
            # None, strings, etc. cannot take the ".4f" format spec.
            return "N/A"

    def on_evaluate(self, args, state, control, metrics=None, **kwargs):
        """Report the latest training loss next to the fresh evaluation metrics.

        Args:
            args: Training arguments (unused).
            state: Trainer state; `log_history` and `epoch` are read.
            control: Trainer control object (unused).
            metrics: Evaluation metrics dict; `eval_loss`, `eval_wer` and
                `eval_cer` are read when present.
        """
        eval_metrics = dict(metrics) if metrics else {}

        # Training loss is logged on a step schedule, so the logged `epoch` value
        # almost never equals `state.epoch` exactly. Use the most recent log
        # entry that carries a training `loss` instead of an exact float match.
        train_metrics = next(
            (log for log in reversed(state.log_history) if 'loss' in log),
            {},
        )

        print(f"\n=== Epoch {state.epoch} Report ===")
        if train_metrics:
            print(f"Training Loss:   {self._format_metric(train_metrics.get('loss'))}")
        if eval_metrics:
            print(f"Validation Loss: {self._format_metric(eval_metrics.get('eval_loss'))}")
            print(f"Validation WER:  {self._format_metric(eval_metrics.get('eval_wer'))}")
            print(f"Validation CER:  {self._format_metric(eval_metrics.get('eval_cer'))}")

In [34]:
# -------------------- 6) Data Collation --------------------
def preprocess_batch(batch, processor, augment=None):
    """Collate raw dataset samples into Whisper model inputs with labels.

    Args:
        batch: Iterable of samples shaped like ``(waveform, sample_rate,
            transcript, ...)``; ``waveform`` is indexed as ``[channel, time]``.
        processor: Object exposing ``feature_extractor`` and ``tokenizer``.
        augment: Optional callable applied to each fixed-length waveform.

    Returns:
        The feature-extractor output with an added ``"labels"`` tensor whose
        padded positions are set to -100.
    """
    # Every clip is brought to exactly 30 seconds (pad with zeros or truncate).
    target_length = 30 * SAMPLE_RATE
    audio_inputs = []
    labels_text = []

    # NOTE(review): the per-sample rate `sr` is not checked against SAMPLE_RATE;
    # this assumes the dataset is already at SAMPLE_RATE — confirm for new data.
    for waveform, sr, transcript, *_ in batch:
        num_samples = waveform.shape[1]
        if num_samples < target_length:
            # Pad with silence if too short
            waveform = torch.nn.functional.pad(waveform, (0, target_length - num_samples))
        elif num_samples > target_length:
            # Truncate if too long
            waveform = waveform[:, :target_length]

        if augment is not None:
            waveform = augment(waveform)

        audio_inputs.append(waveform.squeeze(0).numpy())
        labels_text.append(transcript.lower())

    inputs = processor.feature_extractor(
        audio_inputs,
        sampling_rate=SAMPLE_RATE,
        return_tensors="pt",
        padding=True
    )

    label_tokens = processor.tokenizer(
        labels_text,
        max_length=MAX_TARGET_LENGTH,
        padding=True,
        truncation=True,
        return_tensors="pt"
    )

    # Replace padded label positions with -100 so they are excluded from the
    # loss (compute_metrics maps -100 back to the pad token before decoding).
    # The attention mask is used rather than a comparison with pad_token_id,
    # because the pad token may share its id with a real end-of-text token
    # (presumably the case for Whisper — verify against the tokenizer).
    labels = label_tokens.input_ids.masked_fill(
        label_tokens.attention_mask.ne(1), -100
    )
    # NOTE(review): the tokenizer output may start with the decoder start token,
    # which the model may prepend again when shifting labels — confirm whether a
    # leading start token should be stripped here.
    inputs["labels"] = labels
    return inputs

def data_collator(batch):
    """Collate a batch for the Trainer, adding noise only while training.

    The Trainer uses this one collator for both the training and the
    evaluation dataloaders. Augmentation is therefore gated on the model's
    current mode so that validation metrics are computed on clean audio and
    stay deterministic.

    Caveats: this relies on collation happening in the main process (the
    default, i.e. no dataloader workers). A batch fetched right after the
    model was in eval mode may go through without augmentation.
    """
    augment = audio_augment if model.training else None
    return preprocess_batch(batch, processor, augment=augment)


In [35]:
# -------------------- 7) Prepare Datasets --------------------
# Both splits come from the same torchaudio dataset class and share DATA_DIR as
# root; `download=True` is passed for each split.
_librispeech = torchaudio.datasets.LIBRISPEECH

# Training split.
train_dataset = _librispeech(root=DATA_DIR, url="train-clean-100", download=True)
# Split used for per-epoch and final evaluation.
eval_dataset = _librispeech(root=DATA_DIR, url="test-clean", download=True)

In [36]:
# -------------------- 8) Metrics Callback --------------------
def compute_metrics(pred):
    """Decode predictions and labels and return mean per-utterance WER and CER.

    Args:
        pred: Evaluation output exposing ``predictions`` (generated token ids)
            and ``label_ids`` (reference token ids, with -100 on ignored
            positions).

    Returns:
        dict: ``{"wer": ..., "cer": ...}``, each the arithmetic mean of the
        per-utterance rates (not a corpus-level rate).
    """
    pad_id = processor.tokenizer.pad_token_id

    pred_ids = pred.predictions
    # Some model/trainer configurations return a tuple; the ids come first.
    if isinstance(pred_ids, tuple):
        pred_ids = pred_ids[0]
    # Predictions gathered across batches may be padded with -100, which the
    # tokenizer cannot decode; sanitise them the same way as the labels.
    pred_ids = pred_ids.copy()
    pred_ids[pred_ids == -100] = pad_id

    label_ids = pred.label_ids.copy()
    label_ids[label_ids == -100] = pad_id

    pred_strs  = processor.batch_decode(pred_ids,  skip_special_tokens=True)
    label_strs = processor.batch_decode(label_ids, skip_special_tokens=True)

    wers = [wer(r, h) for r, h in zip(label_strs, pred_strs)]
    cers = [cer(r, h) for r, h in zip(label_strs, pred_strs)]
    return {"wer": sum(wers)/len(wers), "cer": sum(cers)/len(cers)}


In [37]:
# -------------------- 9) Training Setup --------------------
training_args = Seq2SeqTrainingArguments(
    output_dir=OUTPUT_DIR,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    # Evaluate and checkpoint on the same schedule, as required for
    # load_best_model_at_end.
    eval_strategy="epoch",
    save_strategy="epoch",
    learning_rate=LEARNING_RATE,
    num_train_epochs=NUM_EPOCHS,
    # Generate token ids during evaluation so compute_metrics can decode text.
    predict_with_generate=True,
    fp16=torch.cuda.is_available(),
    logging_steps=50,
    report_to="none",  # Disable default logging to use our custom reports
    # Keep the checkpoint with the lowest validation WER.
    load_best_model_at_end=True,
    metric_for_best_model="wer",
    greater_is_better=False,
)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=data_collator,
    # `tokenizer=` is deprecated in recent transformers releases (it triggers a
    # FutureWarning on this call); `processing_class=` is its replacement.
    processing_class=processor.tokenizer,
    compute_metrics=compute_metrics,
    callbacks=[TrainingReporter()],  # Add our custom callback
)


  trainer = Seq2SeqTrainer(


In [38]:
# Release cached, unused GPU memory before training starts.
# (`torch` is already imported in the imports cell at the top of the notebook.)
torch.cuda.empty_cache()

In [None]:
# -------------------- 10) Execute Training --------------------
# Run fine-tuning with the trainer configured above.
print("=== Starting Training ===")
trainer.train()

# Evaluate once more on the trainer's eval_dataset and print the headline
# metrics. The keys read below carry an "eval_" prefix; eval_wer / eval_cer
# presumably correspond to the "wer" / "cer" values returned by compute_metrics.
print("\n=== Final Evaluation ===")
final_metrics = trainer.evaluate()
print(f"Final Validation Loss: {final_metrics['eval_loss']:.4f}")
print(f"Final Validation WER:  {final_metrics['eval_wer']:.4f}")
print(f"Final Validation CER:  {final_metrics['eval_cer']:.4f}")

=== Starting Training ===




Epoch,Training Loss,Validation Loss


In [None]:
# -------------------- 11) Sample Inference --------------------
sample = eval_dataset[0]
wave, sr, _, *_ = sample
# NOTE(review): audio_augment is never switched to eval mode in this notebook,
# so this call adds random noise to the sample — confirm that transcribing a
# noisy clip (rather than the clean one) is intended here.
wave = audio_augment(wave)
inputs = processor.feature_extractor(
    wave.squeeze(0).numpy(),
    sampling_rate=SAMPLE_RATE,
    return_tensors="pt",
    # Pad to the extractor's fixed maximum length. `padding=True` would pad a
    # single clip only to its own length, producing fewer mel frames than the
    # Whisper encoder accepts.
    padding="max_length",
).to(model.device)

generated_ids = model.generate(
    inputs.input_features,
    max_length=MAX_TARGET_LENGTH,
    num_beams=5,
    no_repeat_ngram_size=2
)
print("\nSample Transcription:", 
      processor.batch_decode(generated_ids, skip_special_tokens=True)[0])