# Fine-tuning Whisper Large-v3 with LoRA/PEFT

**Target Hardware**: RunPod A6000 (48GB VRAM)

This notebook fine-tunes Whisper large-v3 on Veterans History Project (VHP) oral history audio using LoRA (Low-Rank Adaptation).

## Key Configuration
- **Model**: openai/whisper-large-v3 (via HuggingFace transformers)
- **Method**: LoRA with r=32, alpha=64
- **Learning Rate**: 1e-5 (CRITICAL: not 1e-3)
- **Precision**: fp16 (NOT int8 for V3)
- **Batch Size**: 4 with gradient accumulation 4 (effective: 16)
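The two derived numbers in this configuration can be checked with a few lines of arithmetic. This sketch assumes a single GPU and PEFT's default (non-rsLoRA) scaling, where the low-rank adapter update is multiplied by `lora_alpha / r`.

```python
# Derived quantities for the key configuration above (values copied from CONFIG)
per_device_batch_size = 4
gradient_accumulation_steps = 4
lora_r = 32
lora_alpha = 64

# On a single GPU, each optimizer step sees batch_size * accumulation examples
effective_batch_size = per_device_batch_size * gradient_accumulation_steps

# Standard LoRA (no rsLoRA) scales the low-rank update B @ A by alpha / r
lora_scaling = lora_alpha / lora_r

print(f"Effective batch size: {effective_batch_size}")  # 16
print(f"LoRA scaling (alpha / r): {lora_scaling}")       # 2.0
```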

## Data Requirements
- Parquet files: `veterans_history_project_resources_pre2010_train.parquet` and `_val.parquet`
- Azure Blob Storage credentials (loaded from `../credentials/creds.env`) for downloading audio files

See [learnings/whisper-lora-finetuning.md](../learnings/whisper-lora-finetuning.md) for gotchas.

## 1. Setup Dependencies

Fine-tuning requires additional packages not in the base project. Add them via uv:

In [None]:
# Add fine-tuning dependencies to pyproject.toml (run once)
# uv add peft accelerate
#
# For A6000 with CUDA 11.8:
# uv add torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
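Before running the rest of the notebook, a quick sketch can confirm the fine-tuning packages are importable. The list of import names is an assumption drawn from the imports used below; run it before `scripts/` is added to `sys.path`, otherwise the local `scripts/evaluate.py` would satisfy the `evaluate` check. pydub also relies on an `ffmpeg` binary to decode non-WAV audio.

```python
import importlib.util
import shutil

def missing_packages(module_names):
    """Return the module names that cannot be found in the current environment."""
    return [name for name in module_names if importlib.util.find_spec(name) is None]

# Import names (not PyPI names) used by this notebook; assumed list, adjust as needed
required = ["torch", "transformers", "peft", "accelerate", "datasets",
            "evaluate", "dotenv", "pydub", "librosa"]
print("Missing packages:", missing_packages(required) or "none")

# pydub shells out to ffmpeg for formats other than WAV
print("ffmpeg on PATH:", shutil.which("ffmpeg") is not None)
```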

In [None]:
import sys
import os
import importlib.util
from pathlib import Path

# Load Azure credentials from .env file
from dotenv import load_dotenv
load_dotenv(dotenv_path='../credentials/creds.env')

# Import the HuggingFace `evaluate` library (WER metric) BEFORE adding scripts/ to sys.path:
# scripts/evaluate.py has the same module name and would otherwise shadow it.
import evaluate as hf_evaluate

# Add scripts directory to path for imports
SCRIPTS_DIR = Path.cwd().parent / "scripts"
sys.path.insert(0, str(SCRIPTS_DIR))

import torch
import pandas as pd
from dataclasses import dataclass
from typing import Any, Dict, List, Union

from transformers import (
    WhisperProcessor,
    WhisperForConditionalGeneration,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
)
from peft import LoraConfig, get_peft_model
from datasets import Dataset, Audio

# Import project modules
import data_loader
import azure_utils

# Load local scripts/evaluate.py under a distinct module name, since `evaluate`
# is already bound to the HuggingFace library in sys.modules
_spec = importlib.util.spec_from_file_location("vhp_evaluate", str(SCRIPTS_DIR / "evaluate.py"))
vhp_evaluate = importlib.util.module_from_spec(_spec)
_spec.loader.exec_module(vhp_evaluate)
clean_raw_transcript_str = vhp_evaluate.clean_raw_transcript_str

print(f"PyTorch: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

## 2. Configuration

In [None]:
# =============================================================================
# CONFIGURATION
# =============================================================================

CONFIG = {
    # Data paths - VHP parquet files (same naming as data splits)
    "train_parquet": "../data/raw/loc/veterans_history_project_resources_pre2010_train.parquet",
    "val_parquet": "../data/raw/loc/veterans_history_project_resources_pre2010_val.parquet",
    
    # Azure blob settings (same as inference configs)
    "blob_prefix": "loc_vhp",
    
    # Sampling (set to None to use all data, or small number for testing)
    "train_sample_size": 100,  # Set to None for full training
    "val_sample_size": 20,
    "random_seed": 42,
    
    # Output directory - follows convention: {dataset}-{model}-{task}-{infra}
    "output_dir": "../outputs/vhp-pre2010-whisper-large-v3-lora-ft-a6000",
    
    # Model - using HuggingFace transformers (not faster-whisper, which is inference-only)
    # Note: For inference we use faster-whisper, but for fine-tuning we need the original HF model
    "model_name": "openai/whisper-large-v3",
    
    # LoRA configuration
    "lora_r": 32,
    "lora_alpha": 64,
    "lora_dropout": 0.05,
    "target_modules": ["q_proj", "v_proj"],
    
    # Training hyperparameters
    "learning_rate": 1e-5,           # CRITICAL: Use 1e-5 for V3 (not 1e-3)
    "batch_size": 4,                  # Per device
    "gradient_accumulation": 4,       # Effective batch = 16
    "warmup_steps": 500,
    "max_steps": 5000,               # Adjust to data size: ~n_train / 16 optimizer steps per epoch
    "eval_steps": 500,
    "save_steps": 500,
    
    # Precision
    "fp16": True,                    # Use fp16 for V3
    "bf16": False,
}

# Create output directory
os.makedirs(CONFIG["output_dir"], exist_ok=True)
print(f"Output directory: {CONFIG['output_dir']}")
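Because `max_steps` is fixed rather than derived from the dataset, it is worth checking how many passes over the data it implies. A rough sketch (it ignores exactly how the Trainer handles a final partial accumulation batch):

```python
def training_schedule_summary(n_train, batch_size, grad_accum, max_steps, warmup_steps):
    """Approximate relation between max_steps and epochs for a single-GPU run."""
    effective_batch = batch_size * grad_accum
    steps_per_epoch = n_train / effective_batch  # approximate optimizer steps per epoch
    return {
        "effective_batch": effective_batch,
        "steps_per_epoch": steps_per_epoch,
        "epochs": max_steps / steps_per_epoch,
        "warmup_fraction": warmup_steps / max_steps,
    }

# 100 = CONFIG["train_sample_size"] in the test configuration above
print(training_schedule_summary(n_train=100, batch_size=4, grad_accum=4,
                                max_steps=5000, warmup_steps=500))
```

With the 100-sample test setting, 5000 steps is roughly 800 epochs, so for small runs it makes sense to lower `max_steps` (or set `num_train_epochs` instead).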

## 3. Load Data

Using existing `data_loader.py` and `azure_utils.py` from scripts/.

Ground truth is extracted from `fulltext_file_str` column using `clean_raw_transcript_str()` from evaluate.py (see [notebooks/evals_learn.ipynb](./evals_learn.ipynb) for details on how this works).

In [None]:
def load_finetune_dataset(parquet_path: str, sample_size: int = None, random_seed: int = 42):
    """
    Load a VHP parquet split for fine-tuning using the existing data_loader infrastructure.

    data_loader.load_vhp_dataset() handles parquet loading, filtering (rows must have a
    transcript and media), and sampling. Transcript cleaning and audio download happen
    later, in prepare_hf_dataset().

    Note: random_seed is not forwarded to data_loader.load_vhp_dataset() here, so the
    reproducibility of the sample depends on how that function samples.
    """
    # Load using existing data_loader (handles filtering, sampling)
    df = data_loader.load_vhp_dataset(
        parquet_path=parquet_path,
        sample_size=sample_size,
        filter_has_transcript=True,
        filter_has_media=True
    )
    
    print(f"Loaded {len(df)} samples")
    return df

# Load train and validation sets
print("Loading training data...")
df_train = load_finetune_dataset(
    CONFIG["train_parquet"], 
    sample_size=CONFIG["train_sample_size"],
    random_seed=CONFIG["random_seed"]
)

print("\nLoading validation data...")
df_val = load_finetune_dataset(
    CONFIG["val_parquet"],
    sample_size=CONFIG["val_sample_size"],
    random_seed=CONFIG["random_seed"]
)

print(f"\nTrain: {len(df_train)} samples")
print(f"Val: {len(df_val)} samples")

In [None]:
def prepare_hf_dataset(df: pd.DataFrame, blob_prefix: str):
    """
    Convert DataFrame to HuggingFace Dataset with audio and cleaned transcripts.
    
    Downloads audio from Azure blob and cleans transcripts using existing evaluate.py function.
    """
    from tempfile import NamedTemporaryFile
    from pydub import AudioSegment
    import librosa
    
    records = []
    
    for idx, row in df.iterrows():
        # Get blob path candidates using existing data_loader function
        blob_path_candidates = data_loader.get_blob_path_for_row(row, idx, blob_prefix)
        
        if not blob_path_candidates:
            print(f"  Skipping {idx}: no blob path")
            continue
        
        # Clean transcript using existing evaluate.py function
        raw_transcript = row.get('fulltext_file_str', '')
        cleaned_transcript = clean_raw_transcript_str(raw_transcript)
        
        if not cleaned_transcript.strip():
            print(f"  Skipping {idx}: empty transcript")
            continue
        
        # Download audio from Azure blob (try each candidate path in order)
        audio_data = None
        for blob_path in blob_path_candidates:
            tmp_path = wav_path = None
            try:
                if not azure_utils.blob_exists(blob_path):
                    continue
                audio_bytes = azure_utils.download_blob_to_memory(blob_path)
                
                # Write the original file so pydub/ffmpeg can decode it
                with NamedTemporaryFile(suffix=Path(blob_path).suffix, delete=False) as tmp:
                    tmp.write(audio_bytes)
                    tmp_path = tmp.name
                
                # Convert to WAV 16kHz mono using pydub (same as infer_whisper.py)
                audio_seg = AudioSegment.from_file(tmp_path)
                audio_seg = audio_seg.set_frame_rate(16000).set_channels(1)
                
                # Export to a separate wav file; appending a suffix avoids clobbering
                # tmp_path when the source is already .wav or has no extension
                wav_path = tmp_path + ".16k.wav"
                audio_seg.export(wav_path, format='wav')
                
                # Load as a float32 numpy array
                audio_data, _ = librosa.load(wav_path, sr=16000)
                break
            except Exception as e:
                print(f"  Error processing {blob_path}: {e}")
            finally:
                # Remove temp files even if download or decoding failed
                for path in (tmp_path, wav_path):
                    if path and os.path.exists(path):
                        os.unlink(path)
        
        if audio_data is None:
            print(f"  Skipping {idx}: could not download audio")
            continue
        
        records.append({
            "audio": {"array": audio_data, "sampling_rate": 16000},
            "sentence": cleaned_transcript
        })
        
        if len(records) % 10 == 0:
            print(f"  Processed {len(records)} samples...")
    
    print(f"  Total valid samples: {len(records)}")
    
    # Create HuggingFace dataset
    dataset = Dataset.from_dict({
        "audio": [r["audio"] for r in records],
        "sentence": [r["sentence"] for r in records]
    })
    dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))
    
    return dataset

In [None]:
# Prepare HuggingFace datasets
print("Preparing training dataset (downloading audio from Azure)...")
train_dataset = prepare_hf_dataset(df_train, CONFIG["blob_prefix"])

print("\nPreparing validation dataset...")
val_dataset = prepare_hf_dataset(df_val, CONFIG["blob_prefix"])

print(f"\nFinal dataset sizes:")
print(f"  Train: {len(train_dataset)}")
print(f"  Val: {len(val_dataset)}")

In [None]:
# Preview a random sample
import random
random.seed(CONFIG["random_seed"])

sample_idx = random.randint(0, len(train_dataset) - 1)
sample = train_dataset[sample_idx]

print(f"Sample {sample_idx}:")
print(f"  Audio duration: {len(sample['audio']['array']) / sample['audio']['sampling_rate']:.1f}s")
print(f"  Transcript preview: {sample['sentence'][:200]}...")

## 4. Initialize Model

**Note on model choice**: For fine-tuning we use `openai/whisper-large-v3` from HuggingFace transformers. This is different from inference where we use `faster-whisper` (CTranslate2 optimized). The fine-tuned weights can later be converted to faster-whisper format for inference.

In [None]:
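# Load processor. Setting language/task makes tokenized labels start with the same prefix
# tokens (<|en|><|transcribe|>) that generation uses; this assumes the VHP audio is English.
processor = WhisperProcessor.from_pretrained(CONFIG["model_name"], language="english", task="transcribe")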

# Load model in fp16
print(f"Loading model: {CONFIG['model_name']}")
model = WhisperForConditionalGeneration.from_pretrained(
    CONFIG["model_name"],
    torch_dtype=torch.float16,
    device_map="auto",
    load_in_8bit=False,  # IMPORTANT: Don't use 8-bit for V3 (causes hallucinations)
)

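# Clear forced decoder IDs (important for fine-tuning). Newer transformers versions read
# decoding settings from generation_config, so clear them there as well.
model.config.forced_decoder_ids = None
model.config.suppress_tokens = []
model.generation_config.forced_decoder_ids = None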

print(f"Model loaded. Parameters: {sum(p.numel() for p in model.parameters()) / 1e9:.2f}B")

## 5. Apply LoRA

### Why LoRA for Whisper Fine-tuning?

**LoRA (Low-Rank Adaptation)** is chosen over full fine-tuning for several reasons:

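1. **Memory Efficiency**: Full fine-tuning of Whisper large-v3 (1.5B params) requires roughly 30-40GB VRAM, since every parameter needs gradients and optimizer state. LoRA reduces this to about 20GB by training only the adapter weights.

2. **Catastrophic Forgetting Prevention**: LoRA preserves the base model's general ASR capabilities while adapting to domain-specific audio. Full fine-tuning risks losing pre-trained knowledge.

3. **Faster Training**: Only ~1% of parameters are trainable (15.7M vs 1.5B), which shrinks gradient and optimizer-state memory and keeps checkpoints small. The forward and backward passes still run through the full model, so the per-step speedup is smaller than the parameter ratio suggests.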

4. **Easy Model Merging**: LoRA weights can be merged with base model for deployment, or kept separate for A/B testing.
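Merging is possible because, in the LoRA paper's formulation (with PEFT's default `lora_alpha / r` scaling), each adapted weight matrix only receives an additive low-rank update:

$$
h = W_0 x + \frac{\alpha}{r} B A x, \qquad
W_0 \in \mathbb{R}^{d_{\text{out}} \times d_{\text{in}}},\;
B \in \mathbb{R}^{d_{\text{out}} \times r},\;
A \in \mathbb{R}^{r \times d_{\text{in}}}
$$

$$
W_{\text{merged}} = W_0 + \frac{\alpha}{r} B A
$$

With $r = 32$ and $\alpha = 64$ the update is scaled by 2. After merging, inference costs the same as the base model; keeping $A$ and $B$ separate lets one set of base weights serve several adapters.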

**References:**
- [HuggingFace PEFT example: int8 LoRA training of Whisper large-v2](https://github.com/huggingface/peft/blob/main/examples/int8_training/peft_bnb_whisper_large_v2_training.ipynb)
- [LoRA Paper](https://arxiv.org/abs/2106.09685)
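The ~15.7M trainable-parameter figure can be reproduced from the model dimensions. This sketch hard-codes the large-v3 config values (d_model 1280, 32 encoder and 32 decoder layers) and assumes PEFT matches `target_modules` by module-name suffix; `model.print_trainable_parameters()` in the next cell is expected to report the same count.

```python
# Dimensions of openai/whisper-large-v3 (hard-coded from its config)
D_MODEL = 1280
ENCODER_LAYERS = 32
DECODER_LAYERS = 32
LORA_R = 32

# Suffix matching of "q_proj"/"v_proj" hits every attention block:
#   1 self-attention block per encoder layer,
#   self-attention + cross-attention (2 blocks) per decoder layer
attention_blocks = ENCODER_LAYERS * 1 + DECODER_LAYERS * 2
adapted_linears = attention_blocks * 2  # q_proj and v_proj in each block

# Each d_model x d_model Linear gets A (r x d_in) and B (d_out x r); bias="none" adds nothing
params_per_linear = LORA_R * (D_MODEL + D_MODEL)
lora_params = adapted_linears * params_per_linear

print(f"Adapted Linear layers: {adapted_linears}")  # 192
print(f"LoRA parameters: {lora_params:,}")          # 15,728,640
```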

In [None]:
# Configure LoRA
lora_config = LoraConfig(
    r=CONFIG["lora_r"],
    lora_alpha=CONFIG["lora_alpha"],
    target_modules=CONFIG["target_modules"],
    lora_dropout=CONFIG["lora_dropout"],
    bias="none",
    task_type="SEQ_2_SEQ_LM",
)

# Apply LoRA to model
model = get_peft_model(model, lora_config)
model.config.use_cache = False  # Disable cache for training

# Print trainable parameters
model.print_trainable_parameters()

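## 6. Data Preprocessing

Whisper sees at most 30 s of audio per example: the feature extractor pads or truncates every input to 30 s of 16 kHz audio, and the decoder has 448 positions (`model.config.max_target_positions`), so longer label sequences cannot be trained on. VHP interviews are usually far longer than this (the test cell in section 11 trims them to 5 minutes), so full recordings paired with full transcripts need to be segmented into chunks of at most 30 s with matching text before training is meaningful.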

In [None]:
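# Whisper's decoder has a fixed number of positions (448 for large-v3);
# label sequences longer than this cannot be trained on.
MAX_LABEL_LENGTH = model.config.max_target_positions

def prepare_dataset(batch):
    """Preprocess audio and text for training."""
    audio = batch["audio"]
    
    # Extract log-Mel features (the feature extractor pads/truncates to 30 s)
    batch["input_features"] = processor.feature_extractor(
        audio["array"],
        sampling_rate=audio["sampling_rate"],
    ).input_features[0]
    
    # Tokenize transcription
    batch["labels"] = processor.tokenizer(batch["sentence"]).input_ids
    
    return batch

def labels_fit(labels):
    """Keep only examples whose labels fit in the decoder."""
    return len(labels) <= MAX_LABEL_LENGTH

# Apply preprocessing, then drop examples the decoder cannot handle
print("Preprocessing training data...")
train_dataset = train_dataset.map(
    prepare_dataset,
    remove_columns=train_dataset.column_names,
    num_proc=4,
)
n_train_before = len(train_dataset)
train_dataset = train_dataset.filter(labels_fit, input_columns=["labels"])
print(f"  Kept {len(train_dataset)}/{n_train_before} training examples")

print("Preprocessing validation data...")
val_dataset = val_dataset.map(
    prepare_dataset,
    remove_columns=val_dataset.column_names,
    num_proc=4,
)
n_val_before = len(val_dataset)
val_dataset = val_dataset.filter(labels_fit, input_columns=["labels"])
print(f"  Kept {len(val_dataset)}/{n_val_before} validation examples")
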
print(f"Preprocessing complete.")
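A small diagnostic makes the 30 s / 448-token limits concrete for a single recording. This is a sketch: the limits are hard-coded from the large-v3 configuration, and the 45-minute, 9,000-token example is hypothetical, not measured from VHP data.

```python
# Whisper large-v3 limits: 30 s of 16 kHz audio per example, 448 decoder positions
SAMPLE_RATE = 16_000
CHUNK_SECONDS = 30
MAX_TARGET_POSITIONS = 448

def whisper_fit_report(num_audio_samples, num_label_tokens, sample_rate=SAMPLE_RATE):
    """Describe how far one (audio, transcript) pair exceeds Whisper's per-example limits."""
    duration_s = num_audio_samples / sample_rate
    # Number of 30 s windows needed to cover the whole recording (ceiling division)
    windows_needed = -(-num_audio_samples // (CHUNK_SECONDS * sample_rate))
    return {
        "duration_s": duration_s,
        "seconds_seen_by_model": min(duration_s, CHUNK_SECONDS),
        "windows_needed": windows_needed,
        "labels_fit": num_label_tokens <= MAX_TARGET_POSITIONS,
    }

# Hypothetical example: a 45-minute interview with a 9,000-token transcript
print(whisper_fit_report(num_audio_samples=45 * 60 * SAMPLE_RATE, num_label_tokens=9_000))
```

On real data it could be applied to the raw `sample` previewed in section 3, e.g. `whisper_fit_report(len(sample["audio"]["array"]), len(processor.tokenizer(sample["sentence"]).input_ids))`.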

## 7. Data Collator and Metrics

In [None]:
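@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    """Data collator for Whisper seq2seq training."""
    processor: Any
    decoder_start_token_id: int
    
    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # Pad log-Mel input features into a batch tensor
        input_features = [{"input_features": f["input_features"]} for f in features]
        batch = self.processor.feature_extractor.pad(
            input_features,
            return_tensors="pt"
        )
        
        # Pad tokenized labels to the longest sequence in the batch
        label_features = [{"input_ids": f["labels"]} for f in features]
        labels_batch = self.processor.tokenizer.pad(
            label_features,
            return_tensors="pt"
        )
        
        # Replace padding with -100 (ignored by loss). Masking via attention_mask keeps
        # real <|endoftext|> labels, which share an ID with Whisper's pad token.
        labels = labels_batch["input_ids"].masked_fill(
            labels_batch.attention_mask.ne(1), -100
        )
        
        # Strip the leading <|startoftranscript|> added by the tokenizer: the model prepends
        # decoder_start_token_id itself when shifting labels into decoder_input_ids.
        # (Whisper's bos_token is <|endoftext|>, so checking bos_token_id would never match.)
        if (labels[:, 0] == self.decoder_start_token_id).all().cpu().item():
            labels = labels[:, 1:]
        
        batch["labels"] = labels
        return batch

data_collator = DataCollatorSpeechSeq2SeqWithPadding(
    processor=processor,
    decoder_start_token_id=model.config.decoder_start_token_id,
)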

# WER metric (using HuggingFace evaluate library)
wer_metric = hf_evaluate.load("wer")

def compute_metrics(pred):
    """Compute WER for evaluation (requires predict_with_generate=True)."""
    pred_ids = pred.predictions
    label_ids = pred.label_ids
    
    # -100 marks ignored label positions, and the Trainer also pads predictions
    # gathered across batches with -100; map both to the pad token before decoding
    label_ids[label_ids == -100] = processor.tokenizer.pad_token_id
    pred_ids[pred_ids == -100] = processor.tokenizer.pad_token_id
    
    # Decode
    pred_str = processor.tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = processor.tokenizer.batch_decode(label_ids, skip_special_tokens=True)
    
    # Compute WER
    wer = wer_metric.compute(predictions=pred_str, references=label_str)
    
    return {"wer": wer}

print("Data collator and metrics ready.")
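The collator's label handling is easiest to see on toy token IDs. This pure-Python sketch mirrors its three steps (pad, mask padded positions with -100 via the attention mask, strip a shared leading start token). The IDs are made up: 1 stands in for `<|startoftranscript|>` and 9 for `<|endoftext|>`, which Whisper checkpoints also use as the pad token.

```python
IGNORE_INDEX = -100  # label value skipped by PyTorch's cross-entropy loss

def pad_and_mask_labels(label_seqs, pad_id, decoder_start_id):
    """Pure-Python mirror of the collator's label handling (no torch needed)."""
    max_len = max(len(seq) for seq in label_seqs)
    # Right-pad with the pad token and build the matching attention mask
    padded = [seq + [pad_id] * (max_len - len(seq)) for seq in label_seqs]
    attention = [[1] * len(seq) + [0] * (max_len - len(seq)) for seq in label_seqs]
    # Mask by attention (not by token value) so real end-of-text labels survive
    masked = [[tok if m == 1 else IGNORE_INDEX for tok, m in zip(p, a)]
              for p, a in zip(padded, attention)]
    # Drop a leading decoder start token only if every sequence has one
    if all(row[0] == decoder_start_id for row in masked):
        masked = [row[1:] for row in masked]
    return masked

print(pad_and_mask_labels([[1, 5, 6, 9], [1, 7, 9]], pad_id=9, decoder_start_id=1))
# [[5, 6, 9], [7, 9, -100]]
```

Because the pad ID equals the end-of-text ID here, masking by value (`labels == pad_id`) would also erase each sequence's real end-of-text label.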

## 8. Training Configuration

### Why Seq2SeqTrainer?

Whisper is a sequence-to-sequence (encoder-decoder) model that takes audio input and generates text output. `Seq2SeqTrainer` from HuggingFace is specifically designed for this architecture and provides:

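1. **Generation during evaluation**: with `predict_with_generate=True`, evaluation calls `model.generate()` instead of a single forward pass, so `compute_metrics` receives generated token IDs it can decode into text
2. **Padding of generated outputs**: pads generated sequences to a common length so predictions from different batches can be gathered (the right-shift of labels into decoder inputs is done by `WhisperForConditionalGeneration` itself)
3. **Beam search support**: `generation_num_beams` enables beam search during evaluation, at extra cost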

**Reference**: [HuggingFace Fine-tune Whisper Guide](https://huggingface.co/blog/fine-tune-whisper)

In [None]:
training_args = Seq2SeqTrainingArguments(
    output_dir=CONFIG["output_dir"],
    per_device_train_batch_size=CONFIG["batch_size"],
    gradient_accumulation_steps=CONFIG["gradient_accumulation"],
    learning_rate=CONFIG["learning_rate"],
    warmup_steps=CONFIG["warmup_steps"],
    max_steps=CONFIG["max_steps"],
    eval_strategy="steps",            # named `evaluation_strategy` in older transformers releases
    eval_steps=CONFIG["eval_steps"],
    save_steps=CONFIG["save_steps"],
    logging_steps=50,
    save_total_limit=3,
    fp16=CONFIG["fp16"],
    bf16=CONFIG["bf16"],
    weight_decay=0.01,
    dataloader_num_workers=4,
    remove_unused_columns=False,
    label_names=["labels"],
    predict_with_generate=True,       # generate during eval so compute_metrics gets token IDs, not logits
    generation_max_length=225,        # value used in the HF fine-tune-whisper guide
    load_best_model_at_end=True,
    metric_for_best_model="wer",
    greater_is_better=False,
    push_to_hub=False,
    report_to=["tensorboard"],
)

print(f"Training configuration:")
print(f"  Learning rate: {CONFIG['learning_rate']}")
print(f"  Batch size: {CONFIG['batch_size']} x {CONFIG['gradient_accumulation']} = {CONFIG['batch_size'] * CONFIG['gradient_accumulation']}")
print(f"  Max steps: {CONFIG['max_steps']}")
print(f"  Precision: {'fp16' if CONFIG['fp16'] else 'fp32'}")
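`Seq2SeqTrainingArguments` defaults to `lr_scheduler_type="linear"`, so `warmup_steps` and `max_steps` together shape the learning rate. A sketch of that schedule, modeled on transformers' `get_linear_schedule_with_warmup` rather than copied from it:

```python
def linear_warmup_lr(step, peak_lr, warmup_steps, max_steps):
    """Learning rate under linear warmup followed by linear decay to zero."""
    if step < warmup_steps:
        # Ramp from 0 to peak_lr over the warmup period
        return peak_lr * step / max(1, warmup_steps)
    # Decay linearly from peak_lr at the end of warmup to 0 at max_steps
    remaining = max(0, max_steps - step)
    return peak_lr * remaining / max(1, max_steps - warmup_steps)

for step in (0, 250, 500, 2750, 5000):
    print(step, linear_warmup_lr(step, peak_lr=1e-5, warmup_steps=500, max_steps=5000))
```

With this config the peak of 1e-5 is reached at step 500 and decays to 0 by step 5000.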

## 9. Train

In [None]:
# Initialize trainer
trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    processing_class=processor.feature_extractor,
)

# Clear GPU cache before training
torch.cuda.empty_cache()

print("Starting training...")
print("="*60)

In [None]:
# Train the model
trainer.train()

## 10. Fine-tuning Summary

**Is LoRA "real" fine-tuning?**

LoRA is a form of **parameter-efficient fine-tuning (PEFT)**, not full fine-tuning. The distinction:

- **Full fine-tuning**: Updates all 1.5B parameters. Higher capacity but requires more VRAM and risks overfitting.
- **LoRA**: Updates only ~15M adapter parameters (~1%). More efficient, preserves base knowledge, and can still achieve strong domain adaptation.

For domain adaptation (like VHP historical audio), LoRA is often preferred because:
1. The base model already has strong ASR capabilities
2. We want to adapt to acoustic characteristics, not relearn language
3. Limited training data makes full fine-tuning prone to overfitting

In [None]:
# Save LoRA weights
lora_path = os.path.join(CONFIG["output_dir"], "lora-weights")
model.save_pretrained(lora_path)
processor.save_pretrained(lora_path)

print(f"LoRA weights saved to: {lora_path}")

# Optionally merge and save full model
print("\nMerging LoRA weights with base model...")
merged_model = model.merge_and_unload()
merged_model.config.use_cache = True  # re-enable the KV cache (disabled for training) for faster generation
merged_path = os.path.join(CONFIG["output_dir"], "merged-model")
merged_model.save_pretrained(merged_path)
processor.save_pretrained(merged_path)

print(f"Merged model saved to: {merged_path}")

In [None]:
# Print summary
print("="*60)
print("FINE-TUNING COMPLETE")
print("="*60)
print(f"\nBase Model: {CONFIG['model_name']}")
print(f"LoRA config: r={CONFIG['lora_r']}, alpha={CONFIG['lora_alpha']}")
print(f"Learning rate: {CONFIG['learning_rate']}")
print(f"\nData:")
print(f"  Train parquet: {CONFIG['train_parquet']}")
print(f"  Val parquet: {CONFIG['val_parquet']}")
print(f"\nOutputs:")
print(f"  LoRA weights: {lora_path}")
print(f"  Merged model: {merged_path}")

## 11. Test Inference

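Quick test of the fine-tuned model on 10 samples from the test set. Each recording is trimmed to its first 5 minutes while the ground truth is the full cleaned transcript, so these outputs support a qualitative check, not a WER number.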

**Note**: The merged model is in HuggingFace format. Our production `infer_whisper.py` uses faster-whisper (CTranslate2 format) for speed. For this quick test, we use HuggingFace transformers pipeline directly. See "Next Steps" for converting to faster-whisper format for full evaluation.

In [None]:
# Quick test using HuggingFace transformers pipeline
# (For production inference, convert to faster-whisper format - see Next Steps)

from transformers import pipeline as hf_pipeline

# Load the merged fine-tuned model
pipe = hf_pipeline(
    "automatic-speech-recognition",
    model=merged_path,
    torch_dtype=torch.float16,
    device=0 if torch.cuda.is_available() else -1,
)

# Load test parquet (small sample)
test_parquet = CONFIG["train_parquet"].replace("_train", "_test")
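df_test = data_loader.load_vhp_dataset(
    test_parquet,
    sample_size=10,
    filter_has_transcript=True,  # same filters as the training/validation splits
    filter_has_media=True,
)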

print(f"Test samples: {len(df_test)}")
print(f"Model: {merged_path}")

In [None]:
# Run inference on test samples
from tempfile import NamedTemporaryFile
from pydub import AudioSegment

test_results = []

for idx, row in df_test.iterrows():
    blob_paths = data_loader.get_blob_path_for_row(row, idx, CONFIG["blob_prefix"])
    
    for blob_path in blob_paths:
        if azure_utils.blob_exists(blob_path):
            print(f"[{idx}] Processing: {blob_path}")
            
            try:
                # Download audio
                audio_bytes = azure_utils.download_blob_to_memory(blob_path)
                
                # Save to temp file
                with NamedTemporaryFile(suffix=Path(blob_path).suffix, delete=False) as tmp:
                    tmp.write(audio_bytes)
                    tmp_path = tmp.name
                
                # Convert to wav 16kHz mono (like infer_whisper.py)
                audio_seg = AudioSegment.from_file(tmp_path)
                audio_seg = audio_seg.set_frame_rate(16000).set_channels(1)
                
                # Limit to first 5 minutes for quick test
                if len(audio_seg) > 300000:  # 300 seconds in ms
                    audio_seg = audio_seg[:300000]
                
                wav_path = tmp_path.replace(Path(blob_path).suffix, '.wav')
                audio_seg.export(wav_path, format='wav')
                
                # Run inference
                result = pipe(wav_path, return_timestamps=True)
                hypothesis = result["text"]
                
                # Get ground truth
                gt = clean_raw_transcript_str(row.get('fulltext_file_str', ''))
                
                test_results.append({
                    "file_id": idx,
                    "hypothesis": hypothesis,
                    "ground_truth": gt,
                    "blob_path": blob_path
                })
                
                # Cleanup
                os.unlink(tmp_path)
                if os.path.exists(wav_path):
                    os.unlink(wav_path)
                    
            except Exception as e:
                print(f"  Error: {e}")
            
            break  # Only process first available blob path

print(f"\nCompleted: {len(test_results)} files")

In [None]:
# View test results
print("=" * 70)
print("TEST INFERENCE RESULTS")
print("=" * 70)

for r in test_results[:3]:  # Show first 3
    print(f"\nFile ID: {r['file_id']}")
    print(f"Blob: {r['blob_path']}")
    print(f"\nHypothesis (first 300 chars):")
    print(r['hypothesis'][:300] + "..." if len(r['hypothesis']) > 300 else r['hypothesis'])
    print(f"\nGround truth (first 300 chars):")
    print(r['ground_truth'][:300] + "..." if len(r['ground_truth']) > 300 else r['ground_truth'])
    print("-" * 70)

# Save test results
test_output_dir = Path(CONFIG["output_dir"]) / "test-inference"
test_output_dir.mkdir(parents=True, exist_ok=True)

df_test_results = pd.DataFrame(test_results)
df_test_results.to_parquet(test_output_dir / "test_results.parquet", index=False)
print(f"\nTest results saved to: {test_output_dir / 'test_results.parquet'}")
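The saved results are not scored above. A minimal pure-Python sketch of corpus-level WER (total word edits divided by total reference words, the same aggregation as the `wer` metric loaded in section 7) follows. It applies no text normalization, and because hypotheses cover only the first 5 minutes while references are full transcripts, a score on `df_test_results` would be dominated by deletions.

```python
def word_error_rate(references, hypotheses):
    """Corpus-level WER: total word edits divided by total reference words."""
    total_edits = 0
    total_words = 0
    for ref, hyp in zip(references, hypotheses):
        r, h = ref.split(), hyp.split()
        # Word-level Levenshtein distance, keeping one DP row at a time
        prev = list(range(len(h) + 1))
        for i, r_word in enumerate(r, start=1):
            curr = [i] + [0] * len(h)
            for j, h_word in enumerate(h, start=1):
                cost = 0 if r_word == h_word else 1
                curr[j] = min(prev[j] + 1,          # deletion
                              curr[j - 1] + 1,      # insertion
                              prev[j - 1] + cost)   # substitution or match
            prev = curr
        total_edits += prev[len(h)]
        total_words += len(r)
    return total_edits / max(1, total_words)

print(word_error_rate(["the cat sat down"], ["the cat sat"]))  # 0.25
```

Usage on the results above would be `word_error_rate(df_test_results["ground_truth"], df_test_results["hypothesis"])`.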

## 12. Next Steps

To run full evaluation on the test set using the production inference pipeline:

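1. **Convert to faster-whisper format** (required for `infer_whisper.py`, which loads CTranslate2 models via faster-whisper):
   ```bash
   ct2-transformers-converter --model outputs/vhp-pre2010-whisper-large-v3-lora-ft-a6000/merged-model \
       --output_dir models/whisper-large-v3-vhp-lora \
       --copy_files tokenizer.json preprocessor_config.json --quantization float16
   ```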
   ```

2. **Create inference config** (e.g., `configs/runs/vhp-pre2010-whisper-large-v3-lora-sample100.yaml`):
   ```yaml
   experiment_id: vhp-pre2010-whisper-large-v3-lora-sample100
   model:
     name: "whisper-large-v3-lora"
     dir: "./models/whisper-large-v3-vhp-lora"  # or use HF path
     batch_size: 12
     device: "cuda"
     compute_type: "float16"
   input:
     source: "azure_blob"
     parquet_path: "data/raw/loc/veterans_history_project_resources_pre2010_test.parquet"
     blob_prefix: "loc_vhp"
     sample_size: 100
   output:
     dir: "outputs/vhp-pre2010-whisper-large-v3-lora-sample100"
   ```

3. **Run inference**:
   ```bash
   uv run python scripts/infer_whisper.py --config configs/runs/vhp-pre2010-whisper-large-v3-lora-sample100.yaml
   ```