This notebook contains the basic training and evaluation loops for fine-tuning Whisper.
- See whisper-dataset-creation.ipynb to create a dataset from raw audio files
- A standalone evaluation cell (per-example WER/MER) follows the main training cycle

-------------------

Create the Whisper processor
- whisper-base is a common starting checkpoint
- Selects CUDA when available, otherwise CPU



In [None]:
from transformers import WhisperProcessor
import torch
import pandas as pd

# Base checkpoint to fine-tune, plus the language/task the processor's decoder prompt is built for.
model_name = "openai/whisper-base"
language = "english"  # Change to your dataset's language
task = "transcribe"  # Use "translate" if you're translating to English

# The processor bundles the feature extractor (raw audio -> Mel-spectrogram input features)
# and the tokenizer (text <-> token ids); both are used throughout this notebook.
processor = WhisperProcessor.from_pretrained(model_name, language=language, task=task)

# Prefer the GPU when one is available, otherwise fall back to CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

Create the Whisper model.
- To load a fine-tuned checkpoint, pass its directory as `pretrained_path`
- `create_whisper_model` freezes all parameters except those of the output projection (`proj_out`, the LM head)

In [None]:
from transformers import WhisperForConditionalGeneration

def _report_trainable_params(model):
    """Print every trainable parameter, then trainable / frozen / total counts and their ratio."""
    print("\nTrainable parameters after freezing:")
    n_trainable, n_frozen = 0, 0
    for param_name, tensor in model.named_parameters():
        count = tensor.numel()
        if not tensor.requires_grad:
            n_frozen += count
            # print(f"  - {param_name} (Frozen)") # Uncomment to see all frozen params
            continue
        n_trainable += count
        print(f"  - {param_name} (Trainable, shape: {tensor.shape})")

    n_total = n_trainable + n_frozen
    print(f"\nTotal trainable parameters: {n_trainable}")
    print(f"Total frozen parameters: {n_frozen}")
    print(f"Total parameters: {n_total}")
    print(f"Ratio of trained params to total params: {n_trainable / n_total:.4f}")


def create_whisper_model(model_name, device, pretrained_path=False):
    """Load a Whisper model, move it to `device`, and freeze everything except `proj_out`.

    Args:
        model_name: Base checkpoint id (e.g. "openai/whisper-base"). It is the load source
            only when `pretrained_path` is falsy; otherwise it is used just in the log line.
        device: Device string the model is moved to (e.g. 'cuda' or 'cpu').
        pretrained_path: Optional path of a fine-tuned checkpoint directory. A falsy value
            (the default) loads `model_name` instead.

    Returns:
        The WhisperForConditionalGeneration model on `device`, with only the parameters of
        `model.proj_out` requiring gradients. A summary of trainable/frozen parameters is printed.

    NOTE(review): Whisper's `proj_out` is usually weight-tied to the decoder token embeddings,
    so unfreezing it presumably trains those embeddings too, and the printed trainable list may
    show the shared tensor under the embedding's name — confirm for the checkpoint in use.
    """
    load_from = pretrained_path if pretrained_path else model_name
    model = WhisperForConditionalGeneration.from_pretrained(load_from)
    if pretrained_path:
        print(f'Loaded {model_name} on {device} from checkpoint {pretrained_path}')
    else:
        print(f'Loaded {model_name} on {device}')

    model.to(device)

    # Freeze the whole network, then re-enable gradients for the output projection only.
    for frozen in model.parameters():
        frozen.requires_grad = False
    for head_param in model.proj_out.parameters():
        head_param.requires_grad = True

    _report_trainable_params(model)
    return model

# Fresh model from the base checkpoint; pass pretrained_path=... to start from a fine-tuned checkpoint instead.
model = create_whisper_model(model_name, device)
    
    

Define DataCollator class for training

In [6]:
from transformers import DataCollatorForSeq2Seq
from dataclasses import dataclass
from typing import Any, Dict, List, Union
import torch

# --- Data Collator ---
@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    """Collate Whisper examples into one padded batch.

    Each element of `features` must provide "input_features" (padded with the processor's
    feature extractor) and "labels" (token ids, padded with the processor's tokenizer).
    The returned batch is the feature extractor's output (including "attention_mask") plus
    a "labels" tensor whose padding positions are set to -100.
    """

    # Processor exposing `.feature_extractor.pad`, `.tokenizer.pad` and
    # `.tokenizer.convert_tokens_to_ids` (e.g. a WhisperProcessor).
    processor: Any
    # Padding strategy forwarded to both pad() calls (True == pad to the longest item).
    padding: Union[bool, str] = True
    # Start token to strip when EVERY label row begins with it. None -> look up
    # "<|startoftranscript|>" in the tokenizer.
    # NOTE(review): assumes labels were tokenized with Whisper's prefix tokens and that the
    # model builds decoder inputs by shifting labels right behind this same start token (as
    # HF WhisperForConditionalGeneration does when only `labels` is passed), which would
    # otherwise duplicate it — confirm against the dataset-creation notebook.
    decoder_start_token_id: Union[int, None] = None

    def __post_init__(self):
        if self.decoder_start_token_id is None:
            self.decoder_start_token_id = self.processor.tokenizer.convert_tokens_to_ids("<|startoftranscript|>")

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # Split inputs and labels since they have to be of different lengths and need different padding methods.
        # "input_features" for Whisper-based models (vs. "input_values" for wav2vec...)
        input_features = [{"input_features": feature["input_features"]} for feature in features]
        label_features = [{"input_ids": feature["labels"]} for feature in features]

        batch = self.processor.feature_extractor.pad(input_features,
                                                     padding=self.padding,
                                                     return_tensors="pt",
                                                     return_attention_mask=True)

        labels_batch = self.processor.tokenizer.pad(label_features,
                                                    padding=self.padding,
                                                    return_tensors="pt")

        # Replace padding with -100 so those positions are ignored by the loss
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        # Drop a start token shared by every row so it is not prepended a second time.
        if (self.decoder_start_token_id is not None
                and labels.size(1) > 0
                and (labels[:, 0] == self.decoder_start_token_id).all().item()):
            labels = labels[:, 1:]

        batch["labels"] = labels
        return batch

# Single collator instance shared by the train and eval DataLoaders.
data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor, padding=True)

Load an already prepared dataset from disk.
- `sample_percentage` is a fraction (e.g. 0.1 = 10%) used to downsample both the train and test splits for quick testing

In [None]:
from datasets import load_from_disk

# Fraction of each split to keep (e.g. .1 = 10%); set to 1 to use the full dataset.
sample_percentage = .1

# Load full prepared dataset
prepared_dataset_path = 'wer0-dataset-fixed-padding'
prepared_datasets = load_from_disk(prepared_dataset_path)
print("--- Full Prepared Dataset ---")
print(prepared_datasets)

# Downsample train and test with a fixed seed so runs are reproducible. A train_size of 1
# would leave an empty "test" side, which train_test_split rejects, so the full-size case
# skips sampling entirely.
if sample_percentage < 1:
    for split_name in ("train", "test"):
        # Only the 'train' side of this throwaway split (the sampled fraction) is kept;
        # it overwrites the original split.
        prepared_datasets[split_name] = prepared_datasets[split_name].train_test_split(
            train_size=sample_percentage, shuffle=True, seed=555
        )['train']

print(f"\n--- Sampled ({sample_percentage*100}%) Dataset ---")
print(prepared_datasets)


Training parameters and dataloaders
- Sets learning rate, batch sizes, number of epochs, optimizer and LR scheduler

In [None]:
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.optim import AdamW
import re

# Anything that is not an ASCII letter, a digit, or whitespace.
_NON_ALNUM_RE = re.compile(r'[^a-zA-Z0-9\s]')


def remove_punctuation(s):
    """Normalize text for WER/MER scoring.

    Deletes every character other than ASCII letters, digits and whitespace (so apostrophes
    and accented letters are dropped too), then lowercases the result.
    """
    stripped = _NON_ALNUM_RE.sub('', s)
    return stripped.lower()

# Training parameters
learning_rate = .001  # Max learning rate (cosine-annealed down to eta_min below)
train_batch_size = 64  # 64 works with 16GB of VRAM
eval_batch_size = 64
num_epochs = 20

# Train and test DataLoaders (only the training set is shuffled)
train_dataloader = DataLoader(prepared_datasets["train"], shuffle=True, collate_fn=data_collator, batch_size=train_batch_size)
eval_dataloader = DataLoader(prepared_datasets["test"], collate_fn=data_collator, batch_size=eval_batch_size)

# The training loop steps the scheduler once per batch, so the cosine horizon is measured
# in batches. Computed only after train_dataloader exists.
total_steps = len(train_dataloader) * num_epochs

# Optim and LR scheduler. Only unfrozen parameters need optimizer state.
params_to_train = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(params_to_train, lr=learning_rate)
scheduler = CosineAnnealingLR(optimizer, T_max=total_steps, eta_min=.00001)

# Decoder prompt ids for model.generate, built from the language/task configured with the processor
forced_decoder_ids = processor.get_decoder_prompt_ids(language=language, task=task)

# Directory for best model to be saved (relative path, so it works on every OS)
output_dir = "whisper-ft"

Main training cycle
- Whenever eval MER improves, saves the model and a CSV of per-example outputs (best-mer-outputs.csv)
- Uses a patience counter to stop training early when MER stops improving

In [None]:
from tqdm import tqdm
import jiwer

# use for early stopping of training if no increase in MER is detected
# Early stopping: stop once MER (lower is better) has not reached a new best for this
# many consecutive epochs.
early_stopping_patience = 5
patience = 0
best_mer = float('inf')

for epoch in range(num_epochs):
    # ---- Train ----
    model.train()
    train_loss = 0

    for batch in tqdm(train_dataloader, desc=f"(Epoch {epoch+1} / {num_epochs}) Training "):
        # Move batch to device
        batch = {k: v.to(device) for k, v in batch.items()}
        optimizer.zero_grad()

        # Forward pass; the batch carries "labels", so outputs.loss is populated
        outputs = model(**batch)
        loss = outputs.loss

        # Backward pass; the LR scheduler advances once per batch
        loss.backward()
        optimizer.step()
        scheduler.step()
        train_loss += loss.item()

    avg_train_loss = train_loss / len(train_dataloader)

    # ---- Evaluate ----
    model.eval()
    all_predictions = []
    all_labels = []

    with torch.no_grad():
        for batch in tqdm(eval_dataloader, desc="Evaluating"):
            batch = {k: v.to(device) for k, v in batch.items()}

            # Generate predictions. Note this is different than model.transcribe.
            # Use whisper-timestamped for time stamp generation.
            # model.generate will only generate text transcriptions for 30s of audio.
            # language/task are passed directly, matching the evaluation-only cell.
            generated_ids = model.generate(input_features=batch["input_features"],
                                           attention_mask=batch["attention_mask"],
                                           num_beams=3,
                                           length_penalty=.8,
                                           early_stopping=True,
                                           language=language,
                                           task=task,
                                           pad_token_id=processor.tokenizer.pad_token_id,
                                           eos_token_id=processor.tokenizer.eos_token_id)

            # Decode predictions
            predictions = processor.batch_decode(generated_ids, skip_special_tokens=True)

            # Decode labels, replacing the -100 loss mask with the pad token first
            labels = batch["labels"].clone()
            labels[labels == -100] = processor.tokenizer.pad_token_id
            labels_str = processor.batch_decode(labels, skip_special_tokens=True)

            all_predictions.extend(predictions)
            all_labels.extend(labels_str)

    # Normalize predictions AND references identically so casing/punctuation are not scored.
    all_predictions = [remove_punctuation(p) for p in all_predictions]
    all_labels = [remove_punctuation(t) for t in all_labels]

    # jiwer takes (reference, hypothesis): ground truth first. WER is normalized by the
    # reference word count, so the order matters. MER is the selection metric because
    # the text lengths vary.
    wer = jiwer.wer(all_labels, all_predictions)
    mer = jiwer.mer(all_labels, all_predictions)

    print(f"Avg training loss: {avg_train_loss:.4f} | Eval. MER: {mer:.5f}, WER: {wer:.5f}")
    print()

    # Save the model if it has the best MER so far
    if mer < best_mer:
        patience = 0  # reset patience counter

        best_mer = mer
        print(f"(!) New best MER: {best_mer}. Saving model...")
        model.save_pretrained(output_dir)
        processor.save_pretrained(output_dir)
        print(f"Model saved to {output_dir}")
        print()

        # Per-example breakdown of this best epoch. The per-sample scores get their own
        # names so the epoch-level `wer`/`mer` are not overwritten.
        rows = []
        for pred, actual in zip(all_predictions, all_labels):
            sample_mer = jiwer.mer(actual, pred)
            sample_wer = jiwer.wer(actual, pred)
            rows.append([pred, actual, sample_mer, sample_wer])

        df_best = pd.DataFrame(rows, columns=['predicted', 'actual', 'mer', 'wer'])
        df_best.to_csv('best-mer-outputs.csv', index=False)

    else:
        patience += 1

    if patience >= early_stopping_patience:
        print(f'No improvement in MER detected in {early_stopping_patience} rounds, breaking')
        break

print("\n--- Training Complete ---")
print(f"Best MER achieved: {best_mer}")

--------------------------

For investigating the outputs of Whisper
- Load a model here if one is not already in memory

In [None]:
from transformers import WhisperForConditionalGeneration
import pandas as pd
import jiwer
import torch
import re

#model = create_whisper_model(model_name, device)

Fine tuned model from .\whisper-ft loaded on cuda


Evaluation cycle only (no training).
- Builds a dataframe of per-example predictions, labels, WER and MER, and saves it to CSV

In [None]:
from tqdm import tqdm 

# Fall back to CPU when CUDA is unavailable ('gpu' is not a valid torch device string)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model.eval()

all_predictions = []
all_labels = []

with torch.no_grad():
    for batch in tqdm(eval_dataloader, desc="Evaluating"):
        batch = {k: v.to(device) for k, v in batch.items()}

        # Beam-search decoding with the language/task configured alongside the processor.
        # Note this is different than model.transcribe.
        generated_ids = model.generate(input_features=batch["input_features"],
                                       attention_mask=batch["attention_mask"],
                                       num_beams=3,
                                       length_penalty=.8,
                                       early_stopping=True,
                                       task=task,
                                       language=language,
                                       pad_token_id=processor.tokenizer.pad_token_id,
                                       eos_token_id=processor.tokenizer.eos_token_id)

        # Decode predictions
        predictions = processor.batch_decode(generated_ids, skip_special_tokens=True)

        # Decode labels, replacing the -100 loss mask with the pad token first
        labels = batch["labels"].clone()
        labels[labels == -100] = processor.tokenizer.pad_token_id
        labels_str = processor.batch_decode(labels, skip_special_tokens=True)

        all_predictions.extend(predictions)
        all_labels.extend(labels_str)

# One row per example. Both sides are normalized before scoring, and jiwer takes
# (reference, hypothesis): ground truth first.
rows = []
for pred, actual in zip(all_predictions, all_labels):
    pred_norm = remove_punctuation(pred)
    actual_norm = remove_punctuation(actual)
    rows.append([pred, pred_norm, actual, actual_norm,
                 jiwer.wer(actual_norm, pred_norm),
                 jiwer.mer(actual_norm, pred_norm)])

df_wer = pd.DataFrame(rows, columns=['Prediction (raw)', 'Prediction (normalized)',
                                     'Actual', 'Actual (normalized)', 'WER', 'MER'])
df_wer.to_csv('wer0-dataset-base-untrained.csv', index=False)

# Overall WER and MER over the whole eval set
references = df_wer['Actual (normalized)'].tolist()
hypotheses = df_wer['Prediction (normalized)'].tolist()
jwer = jiwer.wer(references, hypotheses)
jmer = jiwer.mer(references, hypotheses)

print(f'Overall Jiwer wer: {jwer:.5f} | mer {jmer:.5f}')


For investigating the outputs. 

In [None]:
# To inspect a previously saved run instead of the in-memory results:
# df_wer = pd.read_csv('small-dataset-whisper-trained-outputs.csv')

# Wider text cells and more visible rows, so transcriptions can be read side by side
pd.set_option('display.max_colwidth', 60, 'display.max_rows', 500)

# Lowest-WER examples first
df_wer = df_wer.sort_values(by='WER')
df_wer

------------------------------

Test a model on an audio file.
- The audio does not need to be preprocessed: it is resampled to 16 kHz and converted to mono automatically

In [None]:
import torchaudio

def _load_mono_16k(audio_path):
    """Load `audio_path` with torchaudio; return (waveform, sample_rate) resampled to 16 kHz and down-mixed to one channel."""
    print(f"Loading audio from: {audio_path}...")
    waveform, sample_rate = torchaudio.load(audio_path)

    target_rate = 16000  # Whisper expects 16kHz input
    if sample_rate != target_rate:
        print(f"Resampling audio from {sample_rate}Hz to 16kHz...")
        resample = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_rate)
        waveform = resample(waveform)
        sample_rate = target_rate

    # Whisper expects a single channel: average the channels when there are several
    if waveform.shape[0] > 1:
        print("Converting stereo audio to mono...")
        waveform = waveform.mean(dim=0, keepdim=True)

    return waveform, sample_rate


def test_transcribe(audio_path):
    """Transcribe one audio file with the notebook's global `model`, `processor` and `device`.

    Args:
        audio_path: Path of any file torchaudio can load; no preprocessing is required.

    Returns:
        The decoded transcription (special tokens skipped) of the first item in the batch.
    """
    model.eval()

    waveform, sample_rate = _load_mono_16k(audio_path)

    # The feature extractor takes the raw samples as a 1-D numpy array and returns Mel features
    extracted = processor.feature_extractor(waveform.squeeze().numpy(),
                                            sampling_rate=sample_rate,
                                            return_tensors="pt",
                                            return_attention_mask=True)

    print("Generating transcription...")
    with torch.no_grad():
        generated_ids = model.generate(input_features=extracted.input_features.to(device),
                                       attention_mask=extracted.attention_mask.to(device),
                                       max_new_tokens=400,
                                       temperature=0.0,
                                       # no_speech_threshold=.3 # Error when using this ?
                                       )

    decoded = processor.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
    return decoded[0]

In [None]:
# Load and preprocess the audio file
# Transcribe a local file; test_transcribe handles resampling and down-mixing itself
audio_path = 'vocals.wav'
transcription = test_transcribe(audio_path)
print("\nTranscription:\n", transcription)