# Winston Flywheel — Whisper Fine-Tuning
**Run this notebook in Google Colab after uploading training data with `scripts/upload_training_data.py`.**

This notebook:
1. Loads kitchen speech audio from your Google Drive
2. Generates high-quality pseudo-labels using Whisper Large-v3
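3. Fine-tunes Whisper Small.en with LoRA (on an fp16 base model, not quantized) on all utterances, using Large-v3's transcripts as labels — utterances where Small disagreed with Large carry the strongest learning signal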
4. Saves the merged model back to Drive for local MLX conversion

**Estimated runtime:** 30–90 minutes depending on dataset size and GPU assigned.  
**Required GPU:** T4 or better (Colab free tier works; Pro is faster).


In [None]:
# ── Step 1: Verify GPU and mount Google Drive ──────────────────────────────
# nvidia-smi is only present on GPU runtimes. On a CPU runtime the binary is
# missing and subprocess.run raises FileNotFoundError (rather than returning
# empty output), so treat that case as "no GPU" instead of crashing the cell.
import subprocess

NO_GPU_MSG = "NOT FOUND — go to Runtime → Change runtime type → T4 GPU"
try:
    result = subprocess.run(
        ["nvidia-smi", "--query-gpu=name,memory.total", "--format=csv,noheader"],
        capture_output=True,
        text=True,
    )
    gpu_info = result.stdout.strip()
except FileNotFoundError:
    gpu_info = ""
print("GPU:", gpu_info or NO_GPU_MSG)

from google.colab import drive
drive.mount('/content/drive')
print("Drive mounted ✓")


In [None]:
# ── Step 2: Install dependencies ───────────────────────────────────────────
# Expect ~2 minutes on the first run; pip's -q flag keeps the output short.
import subprocess, sys

# pip requirement specifiers, installed together in one pip invocation.
packages = [
    "transformers>=4.40",
    "peft>=0.10",
    "datasets>=2.18",
    "accelerate>=0.29",
    "evaluate",      # WER metric (Step 8)
    "jiwer",
    "librosa",       # audio loading (Step 3)
    "soundfile",
    "bitsandbytes",  # optional 8-bit loading; this notebook currently loads models in fp16
]

pip_install = [sys.executable, "-m", "pip", "install", "-q", *packages]
subprocess.run(pip_install, check=True)
print("Dependencies installed ✓")


## ⚙️ Configuration
**Edit the cell below.** Everything else in the notebook uses these values.

- `DRIVE_DATA_DIR`: where your uploaded WAVs live (from `upload_training_data.py`)
- `DRIVE_MODELS_DIR`: where the trained model will be saved
- `CYCLE_NAME`: increment this each run so you don't overwrite previous adapters
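- `CONFIDENCE_THRESHOLD`: Whisper Small confidence below which an utterance is counted as low-confidence. It is currently used only for the statistics printed in Step 3; every loaded utterance is used for training.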


In [None]:
# ══════════════════════════════════════════════════
# ⚙️  EDIT THIS CELL — all config lives here
# ══════════════════════════════════════════════════

# Google Drive locations (input WAV/JSON pairs and trained-model output)
DRIVE_DATA_DIR   = "/content/drive/MyDrive/winston-flywheel/data"
DRIVE_MODELS_DIR = "/content/drive/MyDrive/winston-flywheel/models"

# Run identifier — bump it for every training run so earlier outputs survive
CYCLE_NAME = "cycle-1"

# HuggingFace model ids: the pseudo-labelling "oracle" and the model being tuned
WHISPER_LARGE = "openai/whisper-large-v3"
WHISPER_SMALL = "openai/whisper-small.en"

# Data selection
CONFIDENCE_THRESHOLD = 0.7   # Small confidence below this counts as "low" (Step 3 stats only)
VALIDATION_SPLIT     = 0.15  # fraction of examples held out for validation
MAX_AUDIO_SECS       = 30    # longer utterances are skipped when loading

# Training hyperparameters — reasonable defaults for a small kitchen dataset
LORA_R           = 8
LORA_ALPHA       = 32
LEARNING_RATE    = 1e-3
TRAIN_EPOCHS     = 10
BATCH_SIZE       = 8         # drop to 4 if CUDA runs out of memory
GRAD_ACCUM_STEPS = 2

# ══════════════════════════════════════════════════
import os

# Create the per-cycle output folders (merged model first, then checkpoints).
for _subdir in ("merged", "checkpoints"):
    os.makedirs(f"{DRIVE_MODELS_DIR}/{_subdir}/{CYCLE_NAME}", exist_ok=True)

print(f"Cycle: {CYCLE_NAME}")
print(f"Data:  {DRIVE_DATA_DIR}/audio")
print(f"Output: {DRIVE_MODELS_DIR}/merged/{CYCLE_NAME}")


In [None]:
# ── Step 3: Load audio data from Drive ──────────────────────────────────────
# Every <name>.wav is paired with a sibling <name>.json metadata file
# (presumably written by upload_training_data.py) whose "text", "confidence",
# "low_confidence" and "speaker_id" fields describe Whisper Small's output.
import os
import json
import glob
import librosa
import numpy as np

SAMPLE_RATE = 16000  # all audio is resampled to 16 kHz mono for Whisper

audio_dir = f"{DRIVE_DATA_DIR}/audio"
wav_files = sorted(glob.glob(f"{audio_dir}/*.wav"))

print(f"Found {len(wav_files)} WAV files in {audio_dir}")
if len(wav_files) == 0:
    raise RuntimeError("No WAV files found. Did you run upload_training_data.py?")

examples = []
skipped = 0
for wav_path in wav_files:
    # splitext swaps only the final extension; str.replace(".wav", ".json")
    # would also rewrite any ".wav" appearing earlier in the directory path.
    json_path = os.path.splitext(wav_path)[0] + ".json"
    if not os.path.exists(json_path):
        skipped += 1
        continue
    with open(json_path, encoding="utf-8") as f:
        meta = json.load(f)

    audio, sr = librosa.load(wav_path, sr=SAMPLE_RATE, mono=True)
    duration = len(audio) / sr

    if duration > MAX_AUDIO_SECS:
        skipped += 1
        continue

    # JSON nulls come back as None; normalise them so later string/number
    # operations (.lower(), < threshold) do not fail on missing values.
    confidence = meta.get("confidence")
    examples.append({
        "path":           wav_path,
        "audio":          audio,
        "duration":       duration,
        "small_text":     meta.get("text") or "",
        "confidence":     1.0 if confidence is None else float(confidence),
        "low_confidence": meta.get("low_confidence", False),
        "speaker_id":     meta.get("speaker_id"),
    })

print(f"Loaded: {len(examples)} examples  |  Skipped: {skipped}")
if not examples:
    # min()/max() below would otherwise fail with an opaque ValueError.
    raise RuntimeError(
        f"All {len(wav_files)} WAV files were skipped (missing .json metadata "
        f"or longer than {MAX_AUDIO_SECS}s); nothing to train on."
    )

# Stats
confs = [e["confidence"] for e in examples]
low_conf = [e for e in examples if e["confidence"] < CONFIDENCE_THRESHOLD]
print(f"Low-confidence (<{CONFIDENCE_THRESHOLD}): {len(low_conf)} / {len(examples)}")
print(f"Confidence range: {min(confs):.3f} – {max(confs):.3f}  mean: {np.mean(confs):.3f}")
durations = [e["duration"] for e in examples]
print(f"Duration range: {min(durations):.1f}s – {max(durations):.1f}s  mean: {np.mean(durations):.1f}s")


In [None]:
# ── Step 4: Generate pseudo-labels with Whisper Large-v3 ───────────────────
# Large-v3 acts as the "oracle": every example is transcribed by it and that
# transcript becomes the training label. The useful signal is wherever the
# stored Small transcript disagrees with Large's.
#
# NOTE: This is the slowest step. ~1-3s per utterance on T4.
# For 500 utterances, expect ~15-25 minutes.

import torch
from transformers import WhisperProcessor, WhisperForConditionalGeneration

print(f"Loading {WHISPER_LARGE}...")
large_processor = WhisperProcessor.from_pretrained(WHISPER_LARGE)
large_model = WhisperForConditionalGeneration.from_pretrained(
    WHISPER_LARGE, torch_dtype=torch.float16, device_map="cuda"
)
large_model.eval()
print("Large-v3 loaded ✓")


def transcribe_large(audio: np.ndarray) -> str:
    """Return Large-v3's English transcript (special tokens removed) for one clip.

    *audio* is passed to the processor as 16 kHz samples; features are moved to
    the GPU in fp16 to match the model's weights.
    """
    features = large_processor(audio, sampling_rate=16000, return_tensors="pt").input_features
    features = features.to("cuda", dtype=torch.float16)
    with torch.no_grad():
        token_ids = large_model.generate(features, language="en", task="transcribe")
    decoded = large_processor.batch_decode(token_ids, skip_special_tokens=True)
    return decoded[0].strip()


total = len(examples)
print(f"\nRunning Large-v3 on {total} examples...")
for done, ex in enumerate(examples, start=1):
    ex["large_text"] = transcribe_large(ex["audio"])
    if done % 20 == 0:
        print(f"  {done}/{total}")

# Divergence = case-insensitive mismatch between the two transcripts.
divergent = [e for e in examples if e["large_text"].lower() != e["small_text"].lower()]
diff = len(divergent)
same = len(examples) - diff
print(f"\nLarge vs Small: {diff}/{len(examples)} differ ({100*diff/len(examples):.1f}%)")
print("Sample divergences:")
for ex in divergent[:5]:
    print(f"  Small: {ex['small_text']!r}")
    print(f"  Large: {ex['large_text']!r}")
    print()

# Release the large model before training Small on the same GPU.
del large_model
torch.cuda.empty_cache()
print("Large-v3 unloaded — GPU memory freed for training.")


In [None]:
# ── Step 5: Prepare training dataset ────────────────────────────────────────
# Training set = all examples, labelled with Large-v3's transcript.
# ALL examples are used (not just divergences) so Small learns the full
# distribution; the divergences are where the learning signal is strongest.

import random
from datasets import Dataset
from transformers import WhisperFeatureExtractor, WhisperTokenizer

SPLIT_SEED = 42  # fixed seed → the same train/val split on every re-run

feature_extractor = WhisperFeatureExtractor.from_pretrained(WHISPER_SMALL)
# Whisper's English-only checkpoints (*.en) decode from a prompt without
# <|en|>/<|transcribe|> tokens; only multilingual checkpoints use them.
# Baking language/task tokens into an English-only model's labels would train
# it on a prompt that does not match its native English-only prompt.
if WHISPER_SMALL.endswith(".en"):
    tokenizer = WhisperTokenizer.from_pretrained(WHISPER_SMALL)
else:
    tokenizer = WhisperTokenizer.from_pretrained(WHISPER_SMALL, language="en", task="transcribe")


def prepare_example(ex: dict) -> dict:
    """Turn one loaded example into model inputs plus Large-v3 label ids.

    Returns a dict with ``input_features`` (features for one clip),
    ``labels`` (token ids of Large's transcript, used as ground truth) and the
    raw ``reference`` / ``small_pred`` strings for stratification and checks.
    """
    input_features = feature_extractor(
        ex["audio"], sampling_rate=16000, return_tensors="np"
    ).input_features[0]
    labels = tokenizer(ex["large_text"]).input_ids
    return {
        "input_features": input_features,
        "labels":         labels,
        "reference":      ex["large_text"],
        "small_pred":     ex["small_text"],
    }


def _small_matched_large(p: dict) -> bool:
    """True when Small's transcript equals Large's, case-insensitively (as in Step 4)."""
    return p["small_pred"].lower() == p["reference"].lower()


def _stratified_split(items: list, val_fraction: float, seed: int) -> tuple:
    """Split *items* into ``(val, train)``, stratified by _small_matched_large.

    Each stratum sends ``int(len(stratum) * val_fraction)`` items to validation.
    If that yields no validation items at all, one random item is moved there
    so the Trainer always has something to evaluate.
    """
    rng = random.Random(seed)
    matched = [p for p in items if _small_matched_large(p)]
    diverged = [p for p in items if not _small_matched_large(p)]
    val, train = [], []
    for stratum in (matched, diverged):
        rng.shuffle(stratum)
        n_val = int(len(stratum) * val_fraction)
        val.extend(stratum[:n_val])
        train.extend(stratum[n_val:])
    if not val and train:
        val.append(train.pop(rng.randrange(len(train))))
    rng.shuffle(val)
    rng.shuffle(train)
    return val, train


print("Preparing dataset...")
prepared = [prepare_example(ex) for ex in examples]

# Train / val split, stratified by whether Small already matched Large so both
# sets contain the same mix of "already correct" and "corrected" utterances.
val_data, train_data = _stratified_split(prepared, VALIDATION_SPLIT, SPLIT_SEED)

train_ds = Dataset.from_list(train_data)
val_ds   = Dataset.from_list(val_data)

print(f"Train: {len(train_ds)}  |  Val: {len(val_ds)}")


In [None]:
# ── Step 6: Data collator ───────────────────────────────────────────────────
# Pads a list of prepared examples into one batch: input_features through the
# feature extractor, label ids through the tokenizer (to the longest in batch).
import torch
from dataclasses import dataclass
from typing import Any

@dataclass
class WhisperDataCollator:
    """Collate prepared Whisper examples into padded PyTorch tensors.

    Attributes:
        processor: object exposing ``feature_extractor.pad`` and
            ``tokenizer.pad`` (a ``WhisperProcessor`` in this notebook).
        decoder_start_token_id: when every label row begins with this id, the
            first label column is dropped (presumably because the model
            re-inserts it when shifting labels into decoder inputs).
    """

    processor: Any
    decoder_start_token_id: int

    def __call__(self, features):
        # Audio side: pad the per-example features into one tensor batch.
        batch = self.processor.feature_extractor.pad(
            [{"input_features": item["input_features"]} for item in features],
            return_tensors="pt",
        )

        # Text side: pad the label ids, then mark padding with -100 so the
        # loss ignores those positions.
        padded = self.processor.tokenizer.pad(
            [{"input_ids": item["labels"]} for item in features],
            return_tensors="pt",
        )
        is_padding = padded.attention_mask != 1
        labels = padded["input_ids"].masked_fill(is_padding, -100)

        # Drop the leading decoder-start column only if every row carries it.
        if bool(torch.all(labels[:, 0] == self.decoder_start_token_id)):
            labels = labels[:, 1:]

        batch["labels"] = labels
        return batch


In [None]:
# ── Step 7: Load Whisper Small + apply LoRA ─────────────────────────────────
from transformers import WhisperForConditionalGeneration, WhisperProcessor
from peft import LoraConfig, get_peft_model

print(f"Loading {WHISPER_SMALL}...")
processor = WhisperProcessor.from_pretrained(WHISPER_SMALL)
model = WhisperForConditionalGeneration.from_pretrained(
    WHISPER_SMALL, torch_dtype=torch.float16, device_map="cuda"
)

# Clear forced decoder ids (model config and generation config) and the
# suppressed-token list shipped with the checkpoint.
base_cfg = model.config
base_cfg.forced_decoder_ids = None
base_cfg.suppress_tokens = []
model.generation_config.forced_decoder_ids = None

# LoRA adapters on every module named q_proj / v_proj. In Whisper these names
# occur in the encoder self-attention AND in the decoder's self- and
# cross-attention, so the decoder is adapted too — not only the encoder.
lora_config = LoraConfig(
    task_type="SEQ_2_SEQ_LM",
    r=LORA_R,
    lora_alpha=LORA_ALPHA,
    lora_dropout=0.05,
    target_modules=["q_proj", "v_proj"],
    bias="none",
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

data_collator = WhisperDataCollator(
    decoder_start_token_id=model.config.decoder_start_token_id,
    processor=processor,
)
print("Model ready ✓")


In [None]:
# ── Step 8: Define WER metric ───────────────────────────────────────────────
import evaluate
import numpy as np

wer_metric = evaluate.load("wer")


def compute_metrics(pred):
    """Compute word error rate for one Seq2SeqTrainer evaluation pass.

    Args:
        pred: the trainer's prediction object; ``predictions`` holds generated
            token ids (predict_with_generate=True) and ``label_ids`` the
            collator's padded label ids.

    Returns:
        ``{"wer": float}`` rounded to 4 decimals.
    """
    pred_ids = pred.predictions
    if isinstance(pred_ids, tuple):  # some outputs arrive as (ids, ...) tuples
        pred_ids = pred_ids[0]

    # -100 marks ignored/padded positions (the collator uses it for labels, and
    # the trainer may use it when concatenating ragged prediction batches). It
    # is not a real token id and cannot be decoded, so map it to the pad token.
    # np.where builds new arrays instead of mutating the trainer's in place.
    pad_id = processor.tokenizer.pad_token_id
    pred_ids = np.where(pred_ids != -100, pred_ids, pad_id)
    label_ids = np.where(pred.label_ids != -100, pred.label_ids, pad_id)

    pred_str  = processor.tokenizer.batch_decode(pred_ids,  skip_special_tokens=True)
    label_str = processor.tokenizer.batch_decode(label_ids, skip_special_tokens=True)

    # The WER backend (jiwer) raises on empty reference strings, which occur if
    # Large-v3 produced an empty transcript; score only pairs with a reference.
    pairs = [(hyp, ref) for hyp, ref in zip(pred_str, label_str) if ref.strip()]
    if not pairs:
        # Every reference is empty: perfect only if every hypothesis is too.
        return {"wer": 0.0 if all(not hyp.strip() for hyp in pred_str) else 1.0}
    hyps, refs = zip(*pairs)

    wer = wer_metric.compute(predictions=list(hyps), references=list(refs))
    return {"wer": round(wer, 4)}


In [None]:
# ── Step 9: Train ───────────────────────────────────────────────────────────
import inspect
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments

checkpoint_dir = f"{DRIVE_MODELS_DIR}/checkpoints/{CYCLE_NAME}"


def _accepts(callable_obj, name):
    """Return True if *callable_obj*'s signature has a parameter called *name*."""
    return name in inspect.signature(callable_obj).parameters


# Step 2 installs an unpinned `transformers>=4.40`. Newer releases renamed
# `evaluation_strategy` → `eval_strategy` (and later dropped the old name), and
# superseded the Trainer's `tokenizer=` with `processing_class=`. Pick whichever
# keyword this installed version accepts so the cell works on either side.
eval_strategy_kw = (
    "eval_strategy" if _accepts(Seq2SeqTrainingArguments, "eval_strategy")
    else "evaluation_strategy"
)
processing_kw = (
    "processing_class" if _accepts(Seq2SeqTrainer, "processing_class")
    else "tokenizer"
)

training_args = Seq2SeqTrainingArguments(
    output_dir=checkpoint_dir,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=GRAD_ACCUM_STEPS,
    learning_rate=LEARNING_RATE,
    warmup_steps=max(10, len(train_ds) // 20),
    num_train_epochs=TRAIN_EPOCHS,
    fp16=True,
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="wer",
    greater_is_better=False,
    predict_with_generate=True,
    generation_max_length=225,
    logging_steps=10,
    remove_unused_columns=False,
    label_names=["labels"],
    report_to="none",    # disable wandb
    **{eval_strategy_kw: "epoch"},
)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    **{processing_kw: processor.feature_extractor},
)

print(f"Starting training — {TRAIN_EPOCHS} epochs on {len(train_ds)} examples")
print(f"Checkpoints → {checkpoint_dir}")
print()
trainer.train()
print("\nTraining complete ✓")


In [None]:
# ── Step 10: Quick validation ───────────────────────────────────────────────
# Spot-check the fine-tuned model on (up to) the first SPOT_CHECK_N held-out
# examples, reusing the input features prepared in Step 5.
SPOT_CHECK_N = 10

model.eval()  # turn off dropout (incl. LoRA dropout) for inference

# Recent transformers releases reject explicit language/task arguments for
# English-only checkpoints (generation_config.is_multilingual == False), so
# only pass them to multilingual models.
if getattr(model.generation_config, "is_multilingual", True):
    gen_kwargs = {"language": "en", "task": "transcribe"}
else:
    gen_kwargs = {}

print("Validation set results:")
print()

checked = val_data[:SPOT_CHECK_N]
correct = 0
for ex in checked:
    # Use the example's own stored features. Looking the audio up in `examples`
    # by transcript text picks the wrong clip whenever two utterances share a
    # transcript (e.g. short commands).
    inputs = torch.as_tensor(ex["input_features"]).unsqueeze(0).to("cuda", dtype=torch.float16)
    with torch.no_grad():
        # Keyword argument: the PEFT seq2seq wrapper's generate() takes kwargs only.
        pred_ids = model.generate(input_features=inputs, **gen_kwargs)
    pred = processor.batch_decode(pred_ids, skip_special_tokens=True)[0].strip()

    match = pred.lower() == ex["reference"].lower()
    correct += match
    status = "✓" if match else "✗"
    print(f"  {status} REF: {ex['reference']!r}")
    if not match:
        print(f"    HYP: {pred!r}")

print(f"\nSpot-check accuracy: {correct}/{len(checked)}")


In [None]:
# ── Step 11: Merge LoRA and save to Drive ───────────────────────────────────
# Fold the LoRA weights into the base model and save it as a standard
# HuggingFace model. This is what download_adapter.py will pick up and
# convert to MLX.

merged_dir = f"{DRIVE_MODELS_DIR}/merged/{CYCLE_NAME}"

print(f"Merging LoRA weights...")
# Upcast BEFORE merging: the saved model is fp32 anyway, and merging in fp16
# would round the small LoRA deltas to fp16 precision before they are stored.
model.to(torch.float32)
merged_model = model.merge_and_unload()

print(f"Saving to Drive: {merged_dir}")
merged_model.save_pretrained(merged_dir)
processor.save_pretrained(merged_dir)

print(f"\nSaved ✓")
print(f"\nNext steps on your local machine:")
print(f"  python scripts/download_adapter.py --cycle {CYCLE_NAME}")
print(f"  python scripts/evaluate_whisper.py --label {CYCLE_NAME}")


## Done — what to do next on your Mac

```bash
# 1. Download and convert to MLX
python scripts/download_adapter.py --cycle cycle-1

# 2. Evaluate — compare baseline vs this cycle
python scripts/evaluate_whisper.py --label cycle-1

# 3. Compare results
python scripts/evaluate_whisper.py --compare \
    data/benchmark/evals/<baseline>.json \
    data/benchmark/evals/<cycle-1>.json
```

The `download_adapter.py` script will automatically update `config/default.yaml` 
to point to the new MLX model. Restart the perception service to use it.

**To run a second flywheel cycle:**
1. Increment `CYCLE_NAME` to `cycle-2` in the config cell
2. Collect more data with the updated model (it will generate different low-confidence samples)
3. Re-run this notebook
