## 4 Model Evaluation

This notebook loads the Whisper medium model finetuned on 8 hours of Sursilvan data and evaluates it on the test set.

In [10]:
# Cell 1: Imports
import os
import torch
import whisper
import pandas as pd
import numpy as np
from jiwer import wer, cer
from tqdm import tqdm
import warnings
from transformers import WhisperForConditionalGeneration, WhisperProcessor
warnings.filterwarnings("ignore")
import librosa
from torch.utils.data import DataLoader, Dataset

# Configuration
MODEL_PATH = "./whisper-medium-rm-finetuned"  # Path to your fine-tuned model
DATA_PATH = "romansh-data/sursilvan-small/"  # Path to your data
TEST_FILE = "test.tsv"
CLIPS_DIR = "clips"
BATCH_SIZE = 16
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
NUM_SAMPLES = None  # Set to a number for quick test, None for full test set

In [11]:
# Cell 2: Check GPU and Load Model (Fixed)
print("="*60)
print("Whisper Romansh Model Evaluation")
print("="*60)

print(f"Device: {DEVICE}")
if DEVICE == "cuda":
    print(f"GPU: {torch.cuda.get_device_name(0)}")

# Load your fine-tuned model (Hugging Face format)
print(f"\nüì• Loading fine-tuned model from {MODEL_PATH}...")

# Load processor (for feature extraction and tokenization)
processor = WhisperProcessor.from_pretrained(MODEL_PATH)

# Load model
model = WhisperForConditionalGeneration.from_pretrained(MODEL_PATH).to(DEVICE)

print("‚úÖ Model loaded successfully!")
print(f"Model parameters: {sum(p.numel() for p in model.parameters())/1e6:.1f}M")

Whisper Romansh Model Evaluation
Device: cuda
GPU: NVIDIA GeForce RTX 3090

üì• Loading fine-tuned model from ./whisper-medium-rm-finetuned...


Loading weights:   0%|          | 0/947 [00:00<?, ?it/s]

‚úÖ Model loaded successfully!
Model parameters: 763.9M


In [12]:
# Cell 3: Load Test Data
print("\nüìÇ Loading test data...")

test_tsv = os.path.join(DATA_PATH, TEST_FILE)
clips_path = os.path.join(DATA_PATH, CLIPS_DIR)

# Read test TSV
df_test = pd.read_csv(test_tsv, sep='\t')

# Filter for existing audio files
audio_paths = []
valid_indices = []
missing_files = []

for idx, row in df_test.iterrows():
    audio_path = os.path.join(clips_path, row['path'])
    if os.path.exists(audio_path):
        audio_paths.append(audio_path)
        valid_indices.append(idx)
    else:
        missing_files.append(row['path'])

# Keep only valid samples
df_test = df_test.loc[valid_indices].reset_index(drop=True)

print(f"Total test samples: {len(df_test)}")
if missing_files:
    print(f"Missing audio files: {len(missing_files)}")

# Limit samples if specified
if NUM_SAMPLES:
    df_test = df_test.head(NUM_SAMPLES)
    audio_paths = audio_paths[:NUM_SAMPLES]
    print(f"Using first {NUM_SAMPLES} samples for quick test")


üìÇ Loading test data...
Total test samples: 94


In [14]:
# Cell 4: Batch Transcription (Fixed)

class AudioDataset(Dataset):
    def __init__(self, audio_paths, processor, device):
        self.audio_paths = audio_paths
        self.processor = processor
        self.device = device
    
    def __len__(self):
        return len(self.audio_paths)
    
    def __getitem__(self, idx):
        audio_path = self.audio_paths[idx]
        # Load audio
        audio_array, sr = librosa.load(audio_path, sr=16000)
        
        # Process to features
        input_features = self.processor(
            audio_array, 
            sampling_rate=16000, 
            return_tensors="pt"
        ).input_features[0]  # Remove batch dimension
        
        return input_features

def collate_fn(batch):
    """Custom collate function to handle different sized tensors"""
    # Find max length in batch
    max_len = max(features.shape[-1] for features in batch)
    
    # Pad all features to max length
    padded_batch = []
    for features in batch:
        pad_len = max_len - features.shape[-1]
        if pad_len > 0:
            # Pad with zeros along the time dimension
            padding = torch.zeros((features.shape[0], pad_len))
            padded = torch.cat([features, padding], dim=-1)
        else:
            padded = features
        padded_batch.append(padded)
    
    # Stack into batch
    return torch.stack(padded_batch)

print(f"\nüéôÔ∏è Transcribing {len(audio_paths)} test files...")

# Create dataset and dataloader with custom collate
dataset = AudioDataset(audio_paths, processor, DEVICE)
dataloader = DataLoader(
    dataset, 
    batch_size=8, 
    shuffle=False, 
    num_workers=0,  # Set to 0 to avoid multiprocessing issues
    collate_fn=collate_fn
)

transcriptions = []

for batch_features in tqdm(dataloader, desc="Transcribing"):
    # Move batch to device
    batch_features = batch_features.to(DEVICE)
    
    # Generate transcriptions for the batch
    with torch.no_grad():
        predicted_ids = model.generate(
            batch_features,
            max_length=225,
            num_beams=1,
            task="transcribe"
        )
    
    # Decode batch
    batch_transcriptions = processor.batch_decode(
        predicted_ids, 
        skip_special_tokens=True
    )
    transcriptions.extend(batch_transcriptions)

print(f"‚úÖ Transcribed {len(transcriptions)} files")


üéôÔ∏è Transcribing 94 test files...


Transcribing:   0%|          | 0/12 [00:00<?, ?it/s]The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
A custom logits processor of type <class 'transformers.generation.logits_process.SuppressTokensLogitsProcessor'> has been passed to `.generate()`, but it was also created in `.generate()`, given its parameterization. The custom <class 'transformers.generation.logits_process.SuppressTokensLogitsProcessor'> will take precedence. Please check the docstring of <class 'transformers.generation.logits_process.SuppressTokensLogitsProcessor'> to see related `.generate()` flags.
A custom logits processor of type <class 'transformers.generation.logits_process.SuppressTokensAtBeginLogitsProcessor'> has been passed to `.generate()`, but it was also created in `.generate()`, given its parameterization. The custom <class 'tr

‚úÖ Transcribed 94 files





In [None]:
# Cell 5: Calculate WER and CER
print("\n" + "="*60)
print("üìä FINAL RESULTS")
print("="*60)

# Filter valid pairs
valid_pairs = [(ref, hyp) for ref, hyp in zip(references, transcriptions) 
               if ref and hyp]

if not valid_pairs:
    print("‚ùå No valid reference-hypothesis pairs found!")
else:
    refs = [pair[0] for pair in valid_pairs]
    hyps = [pair[1] for pair in valid_pairs]
    
    wer_score = wer(refs, hyps)
    cer_score = cer(refs, hyps)
    
    print(f"\nTest set size: {len(df_test)} utterances")
    print(f"Valid pairs: {len(valid_pairs)}/{len(df_test)}")
    print(f"\nüìà Word Error Rate (WER): {wer_score:.4f} ({wer_score*100:.2f}%)")
    print(f"üìà Character Error Rate (CER): {cer_score:.4f} ({cer_score*100:.2f}%)")


üìä FINAL RESULTS
["lgl agid svizzer per la muntogna sustegn ils purs grischuns. A Tartar sin la Muntogna han responsabels digl agid da muntogna surdau oz in schec da 3 100 1000 francs pils purs ch'han stuiu cumprar pavel muort la schitgira digl onn vargau. lls daners dueien vegnir reparti tochen la fin da matg, dabien enzatgei han mo purs ch'han piars silmeins 30 prozent da lur racolta. L'uniun purila grischuna schazegia ils dons totals sin 10 milliuns francs.", "II possessur dalla Casa Demont a Vella sto pagar il castitg maximal per haver spazzau illegalmein quella casa. 30 1000 francs cuosta ei ad el d'haver destruiu la casa da valeta historica e culturala malgrad che la damonda era vegnida renviada. L'interpresa ch'ha spazzau la casa sto pagar 2 1000 francs.", 'In incendi sil plazzal dalla NEAT a Sedrun ha caschunau donns da varga in milliunfrancs. La tschenta che transportescha il material ord il tunnel ella val Bugnei hapigliau fiug, perquei che lavurs da reparar quella han cas

In [17]:
# Cell 6: Sample Transcriptions
print("\n" + "="*60)
print("üìù SAMPLE TRANSCRIPTIONS")
print("="*60)

# Show 5 random samples
import random
sample_indices = random.sample(range(len(valid_pairs)), min(5, len(valid_pairs)))

for i, idx in enumerate(sample_indices):
    print(f"\n--- Sample {i+1} ---")
    print(f"Reference: {refs[idx][:200]}...")
    print(f"Hypothesis: {hyps[idx][:200]}...")
    
    # Calculate sample-level WER
    sample_wer = wer(refs[idx], hyps[idx])
    print(f"Sample WER: {sample_wer:.4f}")
    print("-" * 40)


üìù SAMPLE TRANSCRIPTIONS

--- Sample 1 ---
Reference: La solidaritad cun Bondo ei stada gronda ‚Ä¶ suenter la catastrofa dalla stad avon dus onns. Tier acziuns da collecta eis ei vegniu ensemen milliuns. Dapi quest'jamna sai ins denton ... che mo ina pintg...
Hypothesis: La solidaritad cun Bondo ei stada gronda suenter la catastrofa dalla stad avon dus onns. Tiracziuns da collecta eis vegniu ensemen milliuns. Dapi quest'jamna san ins denton che mo ina pintga part da q...
Sample WER: 0.1719
----------------------------------------

--- Sample 2 ---
Reference: Mo in punct, e nuot dapli.E puncto conclusiun, less jeu era concluder mia jamna dad impuls e giavischar avus tuts ina biala dumengia....
Hypothesis: Moi in punct e nuot dapli. E puncto conclusiun, less jeu era concluder moia jamna dad impuls e giavischar a vus tuts ina biala dumengia....
Sample WER: 0.3043
----------------------------------------

--- Sample 3 ---
Reference: Il Sven e sia battaria. Propi in duo ch'ins damogna bu

In [19]:
# Cell 8: Quick Stats
print("\n" + "="*60)
print("üìä QUICK STATISTICS")
print("="*60)

sample_wers = [wer(r, h) for r, h in zip(refs, hyps)]
print(f"Mean WER: {np.mean(sample_wers):.4f}")
print(f"Median WER: {np.median(sample_wers):.4f}")
print(f"Std WER: {np.std(sample_wers):.4f}")
print(f"Min WER: {np.min(sample_wers):.4f}")
print(f"Max WER: {np.max(sample_wers):.4f}")

# Success rate (WER < 30%)
success_rate = sum(1 for w in sample_wers if w < 0.3) / len(sample_wers)
print(f"\n‚úÖ Success rate (WER < 30%): {success_rate*100:.1f}%")
print("="*60)


üìä QUICK STATISTICS
Mean WER: 0.3028
Median WER: 0.2952
Std WER: 0.1587
Min WER: 0.0000
Max WER: 0.7600

‚úÖ Success rate (WER < 30%): 50.0%
