# Fine-tuning Parakeet-TDT-0.6B-v3 with NeMo

**Target Hardware**: RunPod A6000 (48GB VRAM)

This notebook fine-tunes NVIDIA Parakeet-TDT-0.6B-v3 on Veterans History Project (VHP) oral history audio using NeMo.

## Key Configuration
- **Model**: nvidia/parakeet-tdt-0.6b-v3
- **Learning Rate**: 5e-5
- **Precision**: 16-mixed
- **Batch Size**: 8 with gradient accumulation 4 (effective: 32)

## Data Requirements
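- Parquet files: `veterans_history_project_resources_pre2010_train.parquet` and `veterans_history_project_resources_pre2010_val.parquet` (an optional `veterans_history_project_resources_pre2010_test.parquet` is used for the inference check in section 9)
- Audio files must be pre-downloaded to a local directory (NeMo requires local paths)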

## Pre-requisite: Download Audio Files
Before running this notebook, download the audio files from Azure Blob Storage to local storage.
Use the download script or notebook to fetch files to `/workspace/audio/loc_vhp/`.

See [learnings/parakeet-nemo-finetuning.md](../learnings/parakeet-nemo-finetuning.md) for gotchas.

## 1. Setup Dependencies

Fine-tuning with NeMo requires additional packages. Add them via uv:

In [None]:
# Add fine-tuning dependencies to pyproject.toml (run once)
# uv add "nemo_toolkit[asr]" "pytorch-lightning>=2.0"   # quote so the shell doesn't glob [asr] or treat > as a redirect
#
# For A6000 with CUDA 11.8:
# uv add torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118

In [None]:
import sys
import os
import json
import re
from pathlib import Path

# Add project root to path for package imports
sys.path.insert(0, str(Path.cwd().parent))

import torch
import librosa
import pandas as pd
from omegaconf import OmegaConf, DictConfig
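# NeMo 2.x is built on the unified `lightning` package; a `pytorch_lightning.Trainer`
# will not accept NeMo 2.x models. Use `import pytorch_lightning as pl` only with NeMo 1.x.
import lightning.pytorch as pl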
import nemo.collections.asr as nemo_asr
from nemo.utils.exp_manager import exp_manager  # the function, not the module of the same name

# Import project modules for transcript cleaning
from scripts.eval.evaluate import clean_raw_transcript_str

print(f"PyTorch: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

## 2. Configuration

In [None]:
# =============================================================================
# CONFIGURATION
# =============================================================================

CONFIG = {
    # Data paths - VHP parquet files (upload these to RunPod)
    "train_parquet": "/workspace/data/veterans_history_project_resources_pre2010_train.parquet",
    "val_parquet": "/workspace/data/veterans_history_project_resources_pre2010_val.parquet",
    "test_parquet": "/workspace/data/veterans_history_project_resources_pre2010_test.parquet",
    
    # Audio directory - where audio files are pre-downloaded from Azure blob
    # Expected structure: {audio_dir}/{azure_blob_index}/video.mp4 or audio.mp3
    # Must download BEFORE running this notebook (NeMo requires local paths)
    "audio_dir": "/workspace/audio/loc_vhp",
    
    # Random seed for reproducibility
    "random_seed": 42,
    
    # NeMo manifest output (will be generated)
    "train_manifest": "/workspace/data/train_manifest.json",
    "val_manifest": "/workspace/data/val_manifest.json",
    
    # Output directory - follows convention: {dataset}-{model}-{task}-{infra}
    "output_dir": "/workspace/outputs/vhp-pre2010-parakeet-tdt-0.6b-ft-a6000",
    "exp_name": "parakeet_tdt_vhp",
    
    # Model
    "model_name": "nvidia/parakeet-tdt-0.6b-v3",
    
    # Training hyperparameters
    "learning_rate": 5e-5,
    "batch_size": 8,
    "gradient_accumulation": 4,
    "warmup_steps": 10000,
    "max_steps": 100000,
    "val_check_interval": 1000,
    
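    # Audio settings
    "sample_rate": 16000,
    "max_duration": 30,   # Max audio duration per sample (seconds) in the data loaders. NeMo drops
                          # longer manifest entries at load time, even if create_nemo_manifest() kept
                          # them (its default cutoff is 1440s), so the two limits must be chosen together.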
    "min_duration": 0.1,
    
    # Precision
    "precision": "16-mixed",
}

# Create output directories
os.makedirs(CONFIG["output_dir"], exist_ok=True)
os.makedirs(os.path.dirname(CONFIG["train_manifest"]), exist_ok=True)
print(f"Output directory: {CONFIG['output_dir']}")
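The batch and step settings count different units, which is easy to misread: in PyTorch Lightning, `max_steps` counts optimizer steps while an integer `val_check_interval` counts training batches. A small sketch of the resulting arithmetic with the `CONFIG` values (the helper `schedule_summary` is hypothetical, not part of NeMo or Lightning):

```python
def schedule_summary(batch_size: int, grad_accum: int, max_steps: int, val_check_interval: int) -> dict:
    """Back-of-the-envelope training schedule.

    Assumes max_steps counts optimizer steps and val_check_interval counts
    training (micro-)batches, as in PyTorch Lightning.
    """
    effective_batch = batch_size * grad_accum
    return {
        "effective_batch": effective_batch,
        # Each optimizer step consumes grad_accum micro-batches
        "total_micro_batches": max_steps * grad_accum,
        "total_samples_seen": max_steps * effective_batch,
        # Validation every val_check_interval micro-batches
        # = every val_check_interval / grad_accum optimizer steps
        "val_every_optimizer_steps": val_check_interval / grad_accum,
    }


summary = schedule_summary(batch_size=8, grad_accum=4, max_steps=100_000, val_check_interval=1000)
for key, value in summary.items():
    print(f"{key}: {value}")
```

Dividing `total_samples_seen` by the number of training manifest entries gives the approximate number of epochs the run will make.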

## 3. Create NeMo Manifests from Parquet

NeMo requires JSON lines manifest format:
```json
{"audio_filepath": "/path/to/audio.wav", "text": "transcription", "duration": 15.4}
```

We convert VHP parquet files to this format, using:
- `fulltext_file_str` column with `clean_raw_transcript_str()` for ground truth (same as Whisper notebook)
- `azure_blob_index` column to map to pre-downloaded audio files
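A minimal sketch of writing and reading this JSON-lines format with only the standard library; `write_manifest` and `read_manifest` are hypothetical helpers, not NeMo APIs. Each line must be a standalone JSON object, so each entry is serialized separately and followed by a newline:

```python
import json
import tempfile
from pathlib import Path


def write_manifest(entries: list[dict], path: str) -> None:
    """Write one JSON object per line (NeMo-style manifest)."""
    with open(path, "w", encoding="utf-8") as f:
        for entry in entries:
            # ensure_ascii=False keeps non-ASCII characters readable in the file
            f.write(json.dumps(entry, ensure_ascii=False) + "\n")


def read_manifest(path: str) -> list[dict]:
    """Read a JSON-lines manifest, skipping blank lines."""
    with open(path, encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]


entries = [
    {"audio_filepath": "/path/to/audio.wav", "text": "transcription", "duration": 15.4},
]
with tempfile.TemporaryDirectory() as tmp:
    manifest_path = Path(tmp) / "example_manifest.json"
    write_manifest(entries, str(manifest_path))
    loaded = read_manifest(str(manifest_path))
print(loaded)
```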

In [None]:
def find_audio_file(audio_dir: str, blob_idx: int) -> str | None:
    """
    Find audio file for a given azure_blob_index.
    
    Matches the structure from Azure blob download:
    {audio_dir}/{blob_idx}/video.mp4 or audio.mp3
    """
    base_path = Path(audio_dir) / str(blob_idx)
    
    # Priority order (matches upload script preference)
    candidates = [
        base_path / "video.mp4",
        base_path / "audio.mp3",
        base_path / "audio.wav",
        base_path / "video.mp3",
    ]
    
    for path in candidates:
        if path.exists():
            return str(path)
    
    # Fallback: any audio/video file in directory
    if base_path.exists():
        for ext in ['*.mp4', '*.mp3', '*.wav', '*.m4a']:
            files = list(base_path.glob(ext))
            if files:
                return str(files[0])
    
    return None


def create_nemo_manifest(parquet_path: str, audio_dir: str, output_path: str, max_duration_sec: int = 1440):
    """
    Convert VHP parquet to NeMo manifest format.
    
    Uses:
    - fulltext_file_str column with clean_raw_transcript_str() for ground truth
    - azure_blob_index column to find pre-downloaded audio files
    
    Only includes samples where audio duration <= max_duration_sec (for proper alignment).
    Files longer than this are SKIPPED because we don't have timestamped transcripts.
    
    Args:
        parquet_path: Path to parquet file
        audio_dir: Directory with pre-downloaded audio files
        output_path: Path to write NeMo manifest
        max_duration_sec: Maximum audio duration in seconds (default 1440 = 24 min for Parakeet).
                          Parakeet can handle up to 24 min on A100 80GB with full attention.
    
    NeMo manifest is JSON lines with: audio_filepath, text, duration
    """
    df = pd.read_parquet(parquet_path)
    print(f"Loaded {len(df)} rows from {parquet_path}")
    
    entries = []
    missing_audio = 0
    empty_transcript = 0
    skipped_too_long = 0
    
    for idx, row in df.iterrows():
        # Get blob index for audio file lookup
        blob_idx = row.get('azure_blob_index', row.get('original_parquet_index', idx))
        audio_path = find_audio_file(audio_dir, blob_idx)
        
        if audio_path is None:
            missing_audio += 1
            continue
        
        # Clean transcript using existing evaluate.py function (same as Whisper notebook)
        raw_transcript = row.get('fulltext_file_str', '')
        cleaned_transcript = clean_raw_transcript_str(raw_transcript)
        
        if not cleaned_transcript.strip():
            empty_transcript += 1
            continue
        
        # Get audio duration
        try:
            y, sr = librosa.load(audio_path, sr=None)
            duration = librosa.get_duration(y=y, sr=sr)
        except Exception as e:
            print(f"Error loading {audio_path}: {e}")
            continue
        
        # Skip if audio is too long (no proper alignment possible without timestamps)
        if duration > max_duration_sec:
            print(f"  Skipping {idx}: audio too long ({duration:.1f}s > {max_duration_sec}s)")
            skipped_too_long += 1
            continue
        
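        # Lowercase the text. Note: parakeet-tdt-0.6b-v3 itself outputs punctuated, cased text,
        # so lowercased targets may shift the model toward lowercase output.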
        entries.append({
            "audio_filepath": audio_path,
            "text": cleaned_transcript.lower(),
            "duration": round(duration, 2)
        })
        
        if len(entries) % 100 == 0:
            print(f"  Processed {len(entries)} entries...")
    
    # Write manifest
    with open(output_path, 'w') as f:
        for entry in entries:
            f.write(json.dumps(entry) + '\n')
    
    total_hours = sum(e['duration'] for e in entries) / 3600
    print(f"\nCreated manifest: {output_path}")
    print(f"  Valid entries: {len(entries)}")
    print(f"  Missing audio: {missing_audio}")
    print(f"  Empty transcripts: {empty_transcript}")
    print(f"  Skipped (too long): {skipped_too_long}")
    print(f"  Total hours: {total_hours:.1f}")
    
    if len(entries) == 0:
        raise ValueError(f"No usable samples in {parquet_path}: check the missing-audio and skip counts above. "
                         f"Only files <= {max_duration_sec}s are kept, and "
                         "VHP files are typically 30-60+ minute interviews. "
                         "Consider using forced alignment to create shorter segments.")
    
    return entries
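`librosa.load` decodes the entire file just to measure its length, which is slow for hour-long interviews. For WAV files, the header already stores the frame count and sample rate, so the standard-library `wave` module can read the duration without decoding any audio. A sketch, assuming PCM WAV input; `wav_duration_seconds` is a hypothetical helper, and MP3/MP4 files would still need a decoder (recent librosa releases also accept `librosa.get_duration(path=...)`, which avoids returning the samples):

```python
import tempfile
import wave
from pathlib import Path


def wav_duration_seconds(path: str) -> float:
    """Duration of a PCM WAV file, read from the header only."""
    with wave.open(path, "rb") as wf:
        return wf.getnframes() / wf.getframerate()


# Usage: write 2.5 s of 16 kHz mono silence, then measure it.
with tempfile.TemporaryDirectory() as tmp:
    path = str(Path(tmp) / "silence.wav")
    with wave.open(path, "wb") as wf:
        wf.setnchannels(1)        # mono
        wf.setsampwidth(2)        # 16-bit samples
        wf.setframerate(16000)    # 16 kHz, matching CONFIG["sample_rate"]
        wf.writeframes(b"\x00\x00" * 40000)  # 40,000 frames = 2.5 s
    duration = wav_duration_seconds(path)

print(f"{duration:.2f}s")
```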

In [None]:
# Create training manifest
print("Creating training manifest...")
train_entries = create_nemo_manifest(
    CONFIG["train_parquet"],
    CONFIG["audio_dir"],
    CONFIG["train_manifest"]
)

print("\nCreating validation manifest...")
val_entries = create_nemo_manifest(
    CONFIG["val_parquet"],
    CONFIG["audio_dir"],
    CONFIG["val_manifest"]
)

In [None]:
def validate_manifest(manifest_path):
    """Validate manifest and print statistics."""
    entries = []
    with open(manifest_path, 'r') as f:
        for line in f:
            entries.append(json.loads(line))
    
    total_duration = sum(e['duration'] for e in entries)
    avg_duration = total_duration / len(entries) if entries else 0
    
    print(f"{manifest_path}:")
    print(f"  Entries: {len(entries)}")
    print(f"  Total: {total_duration/3600:.1f} hours")
    print(f"  Avg duration: {avg_duration:.1f}s")
    return len(entries), total_duration

print("Validating manifests...")
validate_manifest(CONFIG["train_manifest"])
validate_manifest(CONFIG["val_manifest"])
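NeMo's audio datasets drop manifest entries whose `duration` falls outside `[min_duration, max_duration]` from the data loader config, so the manifest totals above can overstate what training actually sees. A sketch that counts what would survive that filter, assuming inclusive bounds (NeMo's exact boundary handling may differ); `duration_filter_stats` is a hypothetical helper:

```python
def duration_filter_stats(durations: list[float], min_duration: float, max_duration: float) -> dict:
    """Count entries kept/dropped by a [min_duration, max_duration] filter."""
    kept = [d for d in durations if min_duration <= d <= max_duration]
    return {
        "kept": len(kept),
        "dropped": len(durations) - len(kept),
        "kept_hours": sum(kept) / 3600,
    }


# Example: with max_duration=30, long interview files are all dropped.
example_durations = [12.0, 29.5, 31.0, 1800.0, 0.05]
stats = duration_filter_stats(example_durations, min_duration=0.1, max_duration=30)
print(stats)
```

In the notebook, the durations would come from `[e["duration"] for e in train_entries]` and the bounds from `CONFIG["min_duration"]` and `CONFIG["max_duration"]`.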

In [None]:
# Preview a sample from manifest
import random
random.seed(CONFIG["random_seed"])

with open(CONFIG["train_manifest"], 'r') as f:
    manifest_lines = f.readlines()

sample_entry = json.loads(random.choice(manifest_lines))
print("Sample manifest entry:")
print(f"  Audio: {sample_entry['audio_filepath']}")
print(f"  Duration: {sample_entry['duration']:.1f}s")
print(f"  Text (first 200 chars): {sample_entry['text'][:200]}...")

## 4. Load Model

### Why Parakeet-TDT?

**Parakeet-TDT (Token-and-Duration Transducer)** is NVIDIA's state-of-the-art ASR model with several advantages:

1. **Streaming-capable**: TDT architecture supports both streaming and offline transcription
2. **Fast inference**: Optimized for NVIDIA GPUs with TensorRT support
3. **Strong baseline**: Competitive with Whisper large-v3 on benchmarks
4. **NeMo integration**: Full fine-tuning support with NeMo toolkit

**References:**
- [Parakeet-TDT-0.6B-v3 on HuggingFace](https://huggingface.co/nvidia/parakeet-tdt-0.6b-v3)
- [NeMo ASR Documentation](https://docs.nvidia.com/nemo-framework/user-guide/latest/nemotoolkit/asr/models.html)
- [Turbocharge ASR with Parakeet-TDT](https://developer.nvidia.com/blog/turbocharge-asr-accuracy-and-speed-with-nvidia-nemo-parakeet-tdt/)

In [None]:
torch.cuda.empty_cache()

print(f"Loading model: {CONFIG['model_name']}")
asr_model = nemo_asr.models.ASRModel.from_pretrained(CONFIG["model_name"])

num_params = sum(p.numel() for p in asr_model.parameters())
print(f"Model loaded: {num_params / 1e6:.1f}M parameters")

## 5. Configure Training

### Full Fine-tuning vs LoRA

Unlike Whisper (where we used LoRA), we do **full fine-tuning** for Parakeet because:

1. **Smaller model**: Parakeet-TDT-0.6B has 600M params (vs Whisper's 1.5B), fits in A6000 memory
2. **NeMo optimization**: NeMo's training pipeline is highly optimized for full fine-tuning
3. **TDT architecture**: The transducer head benefits from end-to-end fine-tuning

**Note**: If you encounter OOM issues, reduce `batch_size` or `max_duration`, or use gradient checkpointing.
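A rough way to sanity-check the "fits in A6000 memory" point is a common rule of thumb for AdamW with mixed precision: about 16 bytes per parameter for fp32 weights (4), fp32 gradients (4), and the two Adam moment buffers (4 + 4). This ignores activations, which grow with batch size and audio length and can dominate for long utterances. A sketch (the 16-byte figure is an assumption, not a NeMo measurement):

```python
def adamw_state_gb(num_params: float, bytes_per_param: int = 16) -> float:
    """Approximate GPU memory (GB) for weights + grads + AdamW moments.

    Default 16 bytes/param = fp32 weights (4) + fp32 grads (4)
    + exp_avg (4) + exp_avg_sq (4). Activations and CUDA workspace are NOT included.
    """
    return num_params * bytes_per_param / 1e9


estimate = adamw_state_gb(600e6)  # ~600M parameters
print(f"~{estimate:.1f} GB before activations")
```

`num_params` from the model-loading cell can be passed in directly; by this estimate roughly 38 GB of the 48 GB remains for activations.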

In [None]:
# Setup data loaders
train_ds_cfg = DictConfig({
    "manifest_filepath": CONFIG["train_manifest"],
    "sample_rate": CONFIG["sample_rate"],
    "batch_size": CONFIG["batch_size"],
    "shuffle": True,
    "num_workers": 4,
    "pin_memory": True,
    "max_duration": CONFIG["max_duration"],
    "min_duration": CONFIG["min_duration"],
})

val_ds_cfg = DictConfig({
    "manifest_filepath": CONFIG["val_manifest"],
    "sample_rate": CONFIG["sample_rate"],
    "batch_size": 16,
    "shuffle": False,
    "num_workers": 4,
    "pin_memory": True,
    "max_duration": CONFIG["max_duration"],
    "min_duration": CONFIG["min_duration"],
})

asr_model.setup_training_data(train_ds_cfg)
asr_model.setup_validation_data(val_ds_cfg)
print("Data loaders ready.")

In [None]:
# Configure optimizer
optim_cfg = DictConfig({
    "name": "adamw",
    "lr": CONFIG["learning_rate"],
    "betas": [0.9, 0.999],
    "weight_decay": 0.0001,
    "sched": {
        "name": "CosineAnnealing",
        "warmup_steps": CONFIG["warmup_steps"],
        "max_steps": CONFIG["max_steps"],
    }
})

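# Also store the config on the model: NeMo's configure_optimizers() may call
# setup_optimization() again from asr_model.cfg.optim when trainer.fit() starts.
from omegaconf import open_dict
with open_dict(asr_model.cfg):
    asr_model.cfg.optim = optim_cfg
asr_model.setup_optimization(optim_cfg)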
print(f"Optimizer: lr={CONFIG['learning_rate']}, warmup={CONFIG['warmup_steps']}")
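To see what `warmup_steps=10000` and `max_steps=100000` mean for the learning rate, here is a sketch of linear warmup followed by cosine decay. It approximates, but is not copied from, NeMo's `CosineAnnealing` scheduler; the exact warmup formula and `min_lr` default there may differ. Steps are optimizer steps:

```python
import math


def lr_at_step(step: int, base_lr: float, warmup_steps: int, max_steps: int, min_lr: float = 0.0) -> float:
    """Linear warmup to base_lr, then cosine decay to min_lr at max_steps."""
    if step < warmup_steps:
        # Linear ramp from 0 to base_lr
        return base_lr * step / warmup_steps
    if step >= max_steps:
        return min_lr
    # Fraction of the decay phase completed, in [0, 1)
    progress = (step - warmup_steps) / (max_steps - warmup_steps)
    return min_lr + 0.5 * (base_lr - min_lr) * (1 + math.cos(math.pi * progress))


for step in (0, 5_000, 10_000, 55_000, 100_000):
    print(step, f"{lr_at_step(step, 5e-5, 10_000, 100_000):.2e}")
```

With these values the learning rate reaches its 5e-5 peak only at step 10,000 (10% of training) and falls to half the peak midway through the decay phase.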

## 6. Setup Trainer

In [None]:
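trainer = pl.Trainer(
    devices=1,
    accelerator="gpu",
    max_steps=CONFIG["max_steps"],                    # counted in optimizer steps
    val_check_interval=CONFIG["val_check_interval"],  # counted in training batches, not optimizer steps
    check_val_every_n_epoch=None,                     # let val_check_interval span epochs on small datasets
    log_every_n_steps=50,
    # exp_manager creates its own checkpoint callback and TensorBoard logger and
    # raises an error if the Trainer already has them, so disable both here.
    enable_checkpointing=False,
    logger=False,
    precision=CONFIG["precision"],
    gradient_clip_val=1.0,
    accumulate_grad_batches=CONFIG["gradient_accumulation"],
)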

exp_manager_cfg = DictConfig({
    "exp_dir": CONFIG["output_dir"],
    "name": CONFIG["exp_name"],
    "create_tensorboard_logger": True,
    "resume_if_exists": True,
    # Without this, exp_manager raises an error on the first run when no checkpoint exists yet
    "resume_ignore_no_checkpoint": True,
    "checkpoint_callback_params": {
        "monitor": "val_wer",
        "mode": "min",
        "save_top_k": 3,
        "save_last": True,
    }
})

exp_manager(trainer, exp_manager_cfg)
print(f"Trainer ready. Effective batch: {CONFIG['batch_size'] * CONFIG['gradient_accumulation']}")

## 7. Train

In [None]:
print("="*60)
print("STARTING TRAINING")
print("="*60)
print(f"Model: {CONFIG['model_name']}")
print(f"Learning rate: {CONFIG['learning_rate']}")
print(f"Batch: {CONFIG['batch_size']} x {CONFIG['gradient_accumulation']} = {CONFIG['batch_size'] * CONFIG['gradient_accumulation']}")
print(f"Max steps: {CONFIG['max_steps']}")
print(f"Precision: {CONFIG['precision']}")
print("="*60)

In [None]:
trainer.fit(asr_model)

## 8. Save Model

In [None]:
model_path = os.path.join(CONFIG["output_dir"], "parakeet_tdt_finetuned.nemo")
asr_model.save_to(model_path)
print(f"Model saved: {model_path}")

In [None]:
# Print summary
print("="*60)
print("FINE-TUNING COMPLETE")
print("="*60)
print(f"\nModel: {CONFIG['model_name']}")
print(f"Learning rate: {CONFIG['learning_rate']}")
print(f"\nData:")
print(f"  Train parquet: {CONFIG['train_parquet']}")
print(f"  Val parquet: {CONFIG['val_parquet']}")
print(f"\nSaved to: {model_path}")
print(f"\nTo load: nemo_asr.models.ASRModel.restore_from('{model_path}')")

## 9. Test Inference

Quick test of the fine-tuned model on samples from the test set.

In [None]:
# Load fine-tuned model
finetuned_model = nemo_asr.models.ASRModel.restore_from(model_path)
finetuned_model.eval()
print(f"Loaded fine-tuned model from: {model_path}")

In [None]:
# Test on samples from test parquet
test_results = []

if os.path.exists(CONFIG["test_parquet"]):
    test_df = pd.read_parquet(CONFIG["test_parquet"])
    print(f"Test parquet: {len(test_df)} rows")
    
    # Test on first 10 samples
    for idx, row in test_df.head(10).iterrows():
        blob_idx = row.get('azure_blob_index', row.get('original_parquet_index', idx))
        audio_path = find_audio_file(CONFIG["audio_dir"], blob_idx)
        
        if audio_path is None:
            print(f"  [{idx}] Audio not found for blob_idx={blob_idx}")
            continue
        
        print(f"  [{idx}] Processing: {audio_path}")
        
        try:
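            # Run inference
            result = finetuned_model.transcribe([audio_path])
            # Some older NeMo versions return (best_hyps, all_hyps) for transducer models
            if isinstance(result, tuple):
                result = result[0]
            # Recent NeMo returns Hypothesis objects; the transcript string is in `.text`
            first = result[0] if result else ""
            hypothesis = getattr(first, "text", first)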
            
            # Get ground truth
            gt = clean_raw_transcript_str(row.get('fulltext_file_str', ''))
            
            test_results.append({
                "file_id": idx,
                "blob_idx": blob_idx,
                "hypothesis": hypothesis,
                "ground_truth": gt,
                "audio_path": audio_path
            })
        except Exception as e:
            print(f"    Error: {e}")
    
    print(f"\nCompleted: {len(test_results)} files")
else:
    print(f"Test parquet not found: {CONFIG['test_parquet']}")

In [None]:
# View test results
print("=" * 70)
print("TEST INFERENCE RESULTS")
print("=" * 70)

for r in test_results[:3]:  # Show first 3
    print(f"\nFile ID: {r['file_id']} (blob_idx: {r['blob_idx']})")
    print(f"Audio: {r['audio_path']}")
    print(f"\nHypothesis (first 300 chars):")
    print(r['hypothesis'][:300] + "..." if len(r['hypothesis']) > 300 else r['hypothesis'])
    print(f"\nGround truth (first 300 chars):")
    print(r['ground_truth'][:300] + "..." if len(r['ground_truth']) > 300 else r['ground_truth'])
    print("-" * 70)

# Save test results
if test_results:
    test_output_dir = Path(CONFIG["output_dir"]) / "test-inference"
    test_output_dir.mkdir(parents=True, exist_ok=True)
    
    df_test_results = pd.DataFrame(test_results)
    df_test_results.to_parquet(test_output_dir / "test_results.parquet", index=False)
    print(f"\nTest results saved to: {test_output_dir / 'test_results.parquet'}")
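Before running the full evaluation, a quick word error rate on these few files can catch obvious regressions. A minimal sketch using word-level Levenshtein distance, WER = (S + D + I) / N with N the number of reference words; normalization here is only lowercasing and whitespace splitting, which is an assumption and likely simpler than what the project's evaluation script does. `word_error_rate` is a hypothetical helper:

```python
def word_error_rate(reference: str, hypothesis: str) -> float:
    """Word-level edit distance divided by the number of reference words."""
    ref = reference.lower().split()
    hyp = hypothesis.lower().split()
    if not ref:
        raise ValueError("WER is undefined for an empty reference")
    # prev[j] = edit distance between the first i-1 ref words and the first j hyp words
    prev = list(range(len(hyp) + 1))
    for i in range(1, len(ref) + 1):
        curr = [i] + [0] * len(hyp)
        for j in range(1, len(hyp) + 1):
            cost = 0 if ref[i - 1] == hyp[j - 1] else 1
            curr[j] = min(
                prev[j] + 1,         # deletion
                curr[j - 1] + 1,     # insertion
                prev[j - 1] + cost,  # substitution or match
            )
        prev = curr
    return prev[len(hyp)] / len(ref)


print(word_error_rate("the cat sat on the mat", "the cat sit on mat"))
```

Applied to the results above, this would be `word_error_rate(r["ground_truth"], r["hypothesis"])` for each `r` in `test_results`. For full-length interviews the O(N·M) loop is slow in pure Python; a dedicated library such as `jiwer` is the usual choice.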

## 10. Next Steps

To run full evaluation on the test set:

1. **Create inference config** for the fine-tuned model:
   ```yaml
   experiment_id: vhp-pre2010-parakeet-tdt-finetuned-test
   model:
     name: "parakeet-tdt-finetuned"
     nemo_path: "/workspace/outputs/vhp-pre2010-parakeet-tdt-0.6b-ft-a6000/parakeet_tdt_finetuned.nemo"
     device: "cuda"
   input:
     parquet_path: "data/raw/loc/veterans_history_project_resources_pre2010_test.parquet"
     audio_dir: "/workspace/audio/loc_vhp"
   output:
     dir: "outputs/vhp-pre2010-parakeet-tdt-finetuned-test"
   ```

2. **Run evaluation** with `scripts/eval/evaluate.py` to compute WER metrics

3. **Compare** with baseline Parakeet (pre-fine-tuning) results