# Real-Time Multilingual ASR with Whisper

This notebook implements a production-ready, real-time speech recognition system using OpenAI's Whisper models.

**Features:**
- Data preparation with augmentation
- Model fine-tuning with multiple variants
- Comprehensive evaluation (WER, CER, latency)
- Real-time streaming inference
- Full MLOps best practices (versioning, logging, reproducibility)

**Author:** COMP3057 Project  
**Version:** 1.0.0

## 1. Setup & Installation

In [None]:
# Setup environment and clone project
import os
import sys

# Check if running in Colab
IN_COLAB = 'google.colab' in sys.modules

if IN_COLAB:
    print("🔧 Running in Google Colab")
    
    # Mount Google Drive for saving checkpoints and logs
    from google.colab import drive
    if not os.path.exists('/content/drive'):
        print("📁 Mounting Google Drive...")
        drive.mount('/content/drive', force_remount=False)
        print("✓ Drive mounted at /content/drive")
    else:
        print("✓ Drive already mounted")
    
    # Clone repository if not exists
    if not os.path.exists('COMP3057_Project'):
        print("\n📦 Cloning repository from GitHub...")
        !git clone https://github.com/jimmy00415/COMP3057_Project.git
        print("✓ Repository cloned")
    else:
        print("✓ Repository already exists")
    
    # Change to project directory
    os.chdir('COMP3057_Project')
    print(f"✓ Working directory: {os.getcwd()}")
else:
    print("💻 Running locally")
    print(f"Working directory: {os.getcwd()}")

# Check GPU availability
import torch
gpu_available = torch.cuda.is_available()
if gpu_available:
    gpu_name = torch.cuda.get_device_name(0)
    gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"\n🎮 GPU: {gpu_name}")
    print(f"   Memory: {gpu_memory:.1f} GB")
    print(f"   CUDA Version: {torch.version.cuda}")
else:
    print("\n⚠️  WARNING: No GPU detected! Training will be very slow.")
    print("   Enable GPU: Runtime → Change runtime type → GPU (T4)")

In [None]:
# Install dependencies
# Installs everything listed in requirements.txt. The path is relative, so the
# working directory must be the project root when this cell runs.
!pip install -q -r requirements.txt

# Install additional Colab-specific packages
# IN_COLAB is assigned by the setup cell, so that cell has to run first.
if IN_COLAB:
    !pip install -q sounddevice

In [None]:
# Import project modules
import sys
import os

# Put the current working directory at the front of sys.path (once) so the
# project's `src` package can be imported by the cells below.
project_root = os.getcwd()
if sys.path.count(project_root) == 0:
    sys.path.insert(0, project_root)
    print(f"Added '{project_root}' to Python path")

import torch
import numpy as np
import matplotlib.pyplot as plt
from IPython.display import Audio, display

# Import project modules
from src.utils import (
    load_config,
    set_seed,
    setup_logging,
    get_device,
    ExperimentLogger,
    DataVersionManager,
    ModelRegistry
)

from src.data import (
    AudioPreprocessor,
    VoiceActivityDetector,
    AudioAugmenter,
    WhisperDataset,
    prepare_datasets,
    create_dataloaders
)

from src.models import (
    WhisperModelManager,
    compare_models
)

from src.training import WhisperTrainer

from src.evaluation import (
    ModelEvaluator,
    LatencyBenchmark,
    TrainingVisualizer,
    EvaluationVisualizer,
    generate_comparison_table
)

from src.inference import (
    StreamingASR,
    BatchInference
)

print("\n✓ All modules imported successfully")

## 2. Configuration & Reproducibility Setup

In [None]:
# Read the experiment settings from config.yaml
config = load_config('config.yaml')

# Seed from the config for reproducibility
set_seed(config['project']['seed'])

# Logger writing to logs/training.log at INFO level
logger = setup_logging('INFO', 'logs/training.log')

# Resolve the compute device requested in the config
device = get_device(config['project']['device'])
logger.info(f"Using device: {device}")

# Experiment tracking (backend is chosen in the config: wandb, mlflow, or tensorboard)
_tracking_cfg = config['mlops']['experiment_tracking']
experiment_logger = ExperimentLogger(
    backend=_tracking_cfg['backend'],
    project_name=_tracking_cfg['project_name'],
    config=config,
)

# Data / model versioning helpers
_versioning_cfg = config['mlops']['versioning']
data_version_manager = DataVersionManager(_versioning_cfg['data_version_file'])
model_registry = ModelRegistry(_versioning_cfg['model_registry'])

# One status report covering the values chosen above
print("\n".join([
    "✓ Configuration loaded",
    f"  - Seed: {config['project']['seed']}",
    f"  - Device: {device}",
    f"  - Tracking: {config['mlops']['experiment_tracking']['backend']}",
]))

## 3. Data Preparation

In [None]:
# Build the preprocessing helpers from the `data` section of the config
audio_preprocessor = AudioPreprocessor(target_sr=config['data']['sampling_rate'], normalize=True)

vad = VoiceActivityDetector(threshold=config['data']['vad_threshold'])

# The augmenter is optional: it is only constructed when augmentation is
# switched on in the config, otherwise the name is bound to None.
if config['data']['augmentation']['enabled']:
    _aug_cfg = config['data']['augmentation']
    augmenter = AudioAugmenter(
        speed_perturbation=_aug_cfg['speed_perturbation'],
        pitch_shift_semitones=_aug_cfg['pitch_shift_semitones'],
        background_noise_prob=_aug_cfg['background_noise_prob'],
    )
else:
    augmenter = None

print("✓ Preprocessing utilities initialized")

In [None]:
# Load datasets
# Demo-sized setup built on Common Voice only.
# For full implementation, add People's Speech and SpeechOcean762

from datasets import load_dataset

# 1000 train + 200 validation rows are pulled as one pool ...
print("Loading Common Voice dataset...")
dataset = load_dataset(
    "mozilla-foundation/common_voice_11_0",
    "en",
    split="train[:1000]+validation[:200]",
    trust_remote_code=True,
)

# ... and that pool is re-split 90/10 with the project seed.
dataset = dataset.train_test_split(test_size=0.1, seed=config['project']['seed'])
train_dataset, val_dataset = dataset['train'], dataset['test']

print("✓ Dataset loaded")
print(f"  - Training samples: {len(train_dataset)}")
print(f"  - Validation samples: {len(val_dataset)}")

# Record which data this run used
_split_sizes = {'train': len(train_dataset), 'val': len(val_dataset)}
data_version_manager.log_dataset_version(
    dataset_name="common_voice_11_0",
    version="en_subset_1200",
    metadata=_split_sizes,
)

## 4. Model Initialization

In [None]:
# Model manager built from the project config
model_manager = WhisperModelManager(config)

# Variant to fine-tune: tiny, base, small, medium, distil
# ('tiny' trains faster, 'small' is more accurate)
MODEL_VARIANT = 'base'

# Load the model together with its processor
model, processor = model_manager.initialize_model(variant=MODEL_VARIANT, device=str(device))

# Report parameter counts
model_info = model_manager.get_model_info()
print(
    f"\n✓ Model initialized: {MODEL_VARIANT}",
    f"  - Total parameters: {model_info['total_parameters']:,}",
    f"  - Trainable parameters: {model_info['trainable_parameters']:,}",
    sep="\n",
)

In [None]:
# Print one row per Whisper variant returned by compare_models
variants_info = compare_models(config)

print("\nWhisper Model Variants Comparison:")
print("-" * 60)
for variant in variants_info:
    info = variants_info[variant]
    params, speed, accuracy = info['params'], info['speed'], info['accuracy']
    print(f"{variant:10s} | Params: {params:8s} | Speed: {speed:10s} | Accuracy: {accuracy}")

## 5. Prepare Data Loaders

In [None]:
# Wrap the train/validation splits for PyTorch. Both wrappers share the same
# column names and length limit; only the augmenter differs.
_wrapper_kwargs = dict(
    audio_column="audio",
    text_column="sentence",
    max_audio_length_sec=config['data']['audio_max_length_sec'],
)

# Training split: augmentation applied
train_dataset_wrapper = WhisperDataset(train_dataset, processor, augmenter=augmenter, **_wrapper_kwargs)

# Validation split: no augmentation
val_dataset_wrapper = WhisperDataset(val_dataset, processor, augmenter=None, **_wrapper_kwargs)

# Create data loaders
from src.data import DataCollatorWithPadding
from torch.utils.data import DataLoader

collator = DataCollatorWithPadding(processor)

# Only the training loader shuffles.
train_loader = DataLoader(
    train_dataset_wrapper,
    batch_size=config['training']['batch_size'],
    shuffle=True,
    collate_fn=collator,
    num_workers=0,
)
val_loader = DataLoader(
    val_dataset_wrapper,
    batch_size=config['training']['batch_size'],
    shuffle=False,
    collate_fn=collator,
    num_workers=0,
)

print(
    "✓ Data loaders created",
    f"  - Training batches: {len(train_loader)}",
    f"  - Validation batches: {len(val_loader)}",
    sep="\n",
)

## 6. Model Fine-Tuning

In [None]:
# Let the model manager put the model into its training configuration
model = model_manager.prepare_for_training()

# Everything the trainer needs, gathered in one place
_trainer_kwargs = {
    'model': model,
    'processor': processor,
    'train_loader': train_loader,
    'val_loader': val_loader,
    'config': config,
    'device': str(device),
    'experiment_logger': experiment_logger,
}
trainer = WhisperTrainer(**_trainer_kwargs)

print("✓ Trainer initialized")

In [None]:
# Run fine-tuning
# Keep the epoch count small (1-2) for a quick demo; for production use
# config['training']['num_epochs'] (10 epochs) or more.
TRAIN_EPOCHS = 2

print(f"Starting training for {TRAIN_EPOCHS} epochs...\n")

# The trainer's return value is kept as the best validation loss
best_val_loss = trainer.train(num_epochs=TRAIN_EPOCHS)

print("\n✓ Training completed!")
print("  - Best validation loss: {:.4f}".format(best_val_loss))

In [None]:
# Record the fine-tuned checkpoint in the model registry
from src.utils import get_git_revision

# Registry entry: identifier, checkpoint location, metric, and provenance
_registration = {
    'model_id': f"{MODEL_VARIANT}_finetuned_{TRAIN_EPOCHS}ep",
    'model_path': "checkpoints/best_model_hf",
    'metrics': {'val_loss': best_val_loss},
    'config': config,
    'git_revision': get_git_revision(),
    'dataset_version': "common_voice_11_0_en_subset",
}
model_id = model_registry.register_model(**_registration)

print(f"✓ Model registered: {model_id}")

## 7. Model Evaluation

In [None]:
# Evaluator bound to the fine-tuned model
evaluator = ModelEvaluator(model=model, processor=processor, device=str(device))

# Score the validation loader and keep a handful of example predictions
print("Evaluating model...\n")
results = evaluator.evaluate_with_samples(val_loader, num_samples=5)

print("\n✓ Evaluation Results:")
for _label, _key in (("WER", 'wer'), ("CER", 'cer')):
    print(f"  - {_label}: {results['metrics'][_key]:.3f}")

# Show the first three example predictions next to their references
print("\n📝 Sample Predictions:")
print("-" * 80)
for _idx, _example in enumerate(results['samples'][:3], start=1):
    print(f"\nSample {_idx}:")
    print(f"  Reference:  {_example['reference']}")
    print(f"  Prediction: {_example['prediction']}")

In [None]:
# Benchmark latency
#
# Measures inference latency and real-time factor (RTF) on a few validation
# clips. The clips are resampled to the rate declared to the benchmark first:
# each decoded sample carries its own `sampling_rate` (48 kHz for Common Voice
# per its dataset card) and nothing earlier in this notebook resamples
# `val_dataset`, so passing the raw arrays with sr=16000 would misstate their
# duration and feed wrong-rate audio to the model.
print("Benchmarking inference latency...\n")

import torchaudio

BENCHMARK_SR = 16000       # rate declared to the benchmark below
NUM_BENCHMARK_CLIPS = 10   # upper bound; fewer are used if the split is smaller

latency_bench = LatencyBenchmark(
    model=model,
    processor=processor,
    device=str(device)
)

# Generate test audio clips
test_audios = []
for i in range(min(NUM_BENCHMARK_CLIPS, len(val_dataset))):
    clip = val_dataset[i]['audio']
    # float32 is required by the resampler; resample() returns the input
    # unchanged when the clip is already at BENCHMARK_SR.
    waveform = torch.tensor(clip['array'], dtype=torch.float32)
    audio = torchaudio.functional.resample(waveform, clip['sampling_rate'], BENCHMARK_SR)
    test_audios.append(audio)

latency_results = latency_bench.benchmark_batch(test_audios, sr=BENCHMARK_SR)

print(f"✓ Latency Benchmark Results:")
print(f"  - Mean latency: {latency_results['mean_latency']:.3f}s")
print(f"  - Std latency: {latency_results['std_latency']:.3f}s")
print(f"  - Mean RTF: {latency_results['mean_rtf']:.3f}x")
print(f"\n  RTF < 1.0 = Real-time capable ✓" if latency_results['mean_rtf'] < 1.0 else "  RTF >= 1.0 = Not real-time")

## 8. Real-Time Streaming Inference

In [None]:
# Build the streaming recognizer from the `inference.streaming` config section
_streaming_cfg = config['inference']['streaming']

streaming_asr = StreamingASR(
    model=model,
    processor=processor,
    vad=vad,
    chunk_length_sec=_streaming_cfg['buffer_size_sec'],
    overlap_sec=_streaming_cfg['overlap_sec'],
    device=str(device),
)

print("✓ Streaming ASR initialized")

In [None]:
# Test streaming on audio file
#
# Writes one validation clip to a temporary 16 kHz WAV file, streams it
# through `streaming_asr` in 0.5 s chunks, and prints the merged transcription
# next to the reference text. The temporary file is removed even when
# streaming fails.

import torchaudio
import tempfile

STREAM_TEST_SR = 16000  # sample rate written into the WAV header

# Get a test audio file
test_sample = val_dataset[0]
test_audio = test_sample['audio']['array']
test_text = test_sample['sentence']

# Resample to the rate we are about to declare in the file header; the decoded
# clip carries its own `sampling_rate`, which is not necessarily 16 kHz.
# float32 is a dtype torchaudio can both resample and save.
test_waveform = torchaudio.functional.resample(
    torch.tensor(test_audio, dtype=torch.float32),
    test_sample['audio']['sampling_rate'],
    STREAM_TEST_SR,
)

# Reserve a path, then close the handle before writing: on Windows a file that
# is still open via NamedTemporaryFile cannot be reopened by name.
with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as tmp:
    tmp_path = tmp.name

transcriptions = []
def callback(text):
    """Print each partial transcription and keep it in `transcriptions`."""
    print(f"[CHUNK] {text}")
    transcriptions.append(text)

try:
    # Save to temporary file
    torchaudio.save(tmp_path, test_waveform.unsqueeze(0), STREAM_TEST_SR)

    print(f"Test audio saved to: {tmp_path}")
    print(f"Reference text: {test_text}\n")

    # Stream from file
    print("Streaming transcription:")
    print("-" * 80)

    streaming_asr.reset()
    streaming_asr.stream_from_file(tmp_path, chunk_duration_sec=0.5, callback=callback)

    # Get full transcription
    full_transcription = streaming_asr.get_full_transcription(merge=True)

    print("\n" + "-" * 80)
    print(f"\n📝 Final Transcription: {full_transcription}")
    print(f"📖 Reference:          {test_text}")
finally:
    # Cleanup (delete=False above means nothing else removes the file)
    os.unlink(tmp_path)

In [None]:
# Live microphone streaming (requires microphone access)
# Uncomment to use:

# print("Starting live microphone transcription...")
# print("Speak into your microphone. Press Ctrl+C to stop.\n")

# streaming_asr.reset()

# def mic_callback(text):
#     print(f"🎤 {text}")

# streaming_asr.stream_from_microphone(
#     duration_sec=30,  # Record for 30 seconds
#     callback=mic_callback
# )

# full_transcription = streaming_asr.get_full_transcription(merge=True)
# print(f"\n📝 Full Transcription: {full_transcription}")

## 9. Batch Inference

In [None]:
# Batch inference for multiple files
#
# Writes a few validation clips to temporary 16 kHz WAV files, transcribes
# them in one batch call, and prints each prediction next to its reference.
# All temporary files are removed even when saving or transcription fails.
batch_inference = BatchInference(
    model=model,
    processor=processor,
    device=str(device),
    batch_size=config['inference']['batch_size']
)

# Create temporary test files
import tempfile
import torchaudio

BATCH_TEST_SR = 16000   # sample rate written into each WAV header
NUM_BATCH_FILES = 5     # upper bound; fewer are used if the split is smaller

test_files = []
references = []  # collected here so the clips are not decoded a second time
try:
    for i in range(min(NUM_BATCH_FILES, len(val_dataset))):
        sample = val_dataset[i]
        # Resample to the rate declared in the header; the decoded clip carries
        # its own `sampling_rate`, which is not necessarily 16 kHz. float32 is
        # a dtype torchaudio can both resample and save.
        audio = torchaudio.functional.resample(
            torch.tensor(sample['audio']['array'], dtype=torch.float32),
            sample['audio']['sampling_rate'],
            BATCH_TEST_SR,
        )

        # Reserve a path and close the handle before writing: on Windows a file
        # still open via NamedTemporaryFile cannot be reopened by name.
        with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as tmp:
            test_files.append(tmp.name)
        torchaudio.save(test_files[-1], audio.unsqueeze(0), BATCH_TEST_SR)
        references.append(sample['sentence'])

    # Batch transcribe
    print(f"Batch transcribing {len(test_files)} files...\n")
    batch_transcriptions = batch_inference.transcribe_batch(test_files)

    # Display results
    for i, (transcription, reference) in enumerate(zip(batch_transcriptions, references)):
        print(f"File {i+1}:")
        print(f"  Prediction: {transcription}")
        print(f"  Reference:  {reference}")
        print()
finally:
    # Cleanup (delete=False above means nothing else removes the files)
    for f in test_files:
        os.unlink(f)

print("✓ Batch inference completed")

## 10. Visualization & Analysis

In [None]:
# Plot a model comparison; figures go under plots/
eval_viz = EvaluationVisualizer(save_dir='plots')

# Example comparison data. Only the 'whisper-base' row mixes in numbers
# measured in this notebook (WER, latency, RTF); every other figure is a
# hard-coded example value.
comparison_data = {
    'whisper-tiny': dict(params='39M', wer_clean=0.15, wer_accented=0.20, latency=0.1, rtf=0.1),
    'whisper-base': dict(
        params='74M',
        wer_clean=results['metrics']['wer'],
        wer_accented=0.14,
        latency=latency_results['mean_latency'],
        rtf=latency_results['mean_rtf'],
    ),
    'whisper-small': dict(params='244M', wer_clean=0.08, wer_accented=0.12, latency=0.5, rtf=0.5),
}

# Render the figure and report where it was written
comparison_path = eval_viz.plot_model_comparison(comparison_data)
print(f"✓ Model comparison plot saved: {comparison_path}")

# Show the saved image inline
from IPython.display import Image
display(Image(comparison_path))

In [None]:
# Render the comparison data as a table and print it under a heading
table = generate_comparison_table(comparison_data)
print("\nModel Comparison Table:", table, sep="\n")

## 11. Export & Deployment

In [None]:
# Export the model and its processor side by side for deployment
FINAL_MODEL_PATH = 'final_model'

for _artifact in (model, processor):
    _artifact.save_pretrained(FINAL_MODEL_PATH)

# Confirmation plus a copy-paste snippet for reloading later
print("\n".join([
    f"✓ Final model saved to: {FINAL_MODEL_PATH}",
    "\nTo load model later:",
    "  from transformers import WhisperForConditionalGeneration, WhisperProcessor",
    f"  model = WhisperForConditionalGeneration.from_pretrained('{FINAL_MODEL_PATH}')",
    f"  processor = WhisperProcessor.from_pretrained('{FINAL_MODEL_PATH}')",
]))

In [None]:
# Optional: Upload to HuggingFace Hub
# Requires HuggingFace token and authentication

# from huggingface_hub import notebook_login
# notebook_login()

# HF_MODEL_NAME = "your-username/whisper-base-finetuned-en"
# model.push_to_hub(HF_MODEL_NAME)
# processor.push_to_hub(HF_MODEL_NAME)
# print(f"✓ Model uploaded to HuggingFace: {HF_MODEL_NAME}")

## 12. Cleanup & Summary

In [None]:
# Close the experiment-tracking run before reporting
experiment_logger.finish()

# Final report: banner, headline metrics, artifact locations, closing banner
_rule = "=" * 80
_summary_lines = [
    "\n" + _rule,
    "📊 PROJECT SUMMARY",
    _rule,
    f"\n✓ Model: {MODEL_VARIANT}",
    f"✓ Training epochs: {TRAIN_EPOCHS}",
    f"✓ Best validation loss: {best_val_loss:.4f}",
    f"✓ WER: {results['metrics']['wer']:.3f}",
    f"✓ CER: {results['metrics']['cer']:.3f}",
    f"✓ Mean latency: {latency_results['mean_latency']:.3f}s",
    f"✓ RTF: {latency_results['mean_rtf']:.3f}x",
    f"\n✓ Model saved: {FINAL_MODEL_PATH}",
    "✓ Checkpoints: checkpoints/",
    "✓ Logs: logs/",
    "✓ Plots: plots/",
    "\n" + _rule,
    "🎉 Real-Time ASR System Ready for Deployment!",
    _rule,
]
print("\n".join(_summary_lines))