# Fine-tuning Parakeet-TDT with NeMo

**Target Hardware**: RunPod A6000 (48GB VRAM)

This notebook fine-tunes an NVIDIA Parakeet-family ASR model on Veterans History Project (VHP) oral history audio using NeMo's official `speech_to_text_finetune.py` script.

## Important Notes

Based on [official NVIDIA tutorials](https://github.com/nvidia-riva/tutorials/blob/main/asr-finetune-parakeet-nemo.ipynb) and [community discussions](https://github.com/NVIDIA-NeMo/NeMo/issues/13825):

1. **Data Requirements**: Fine-tuning typically needs 1000+ hours for good WER. With ~31 demo samples, this is just a **smoke test**.
2. **Memory**: The 0.6B model ideally needs 80GB+ VRAM. A6000 (48GB) may require small batch sizes.
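3. **Training Method**: Training runs through the `speech_to_text_finetune.py` CLI script rather than calling the Python API directly.
4. **Model**: The training config uses the FastConformer Hybrid Transducer-CTC architecture, and section 5 initializes from `stt_en_fastconformer_hybrid_large_pc` rather than the Parakeet-TDT checkpoint.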

## Data Requirements
- **Segmented parquet files** with `segmented_audio_url` (Azure blob) and `segmented_audio_transcript` columns
- Audio segments are downloaded from Azure blob storage to local cache
- NeMo manifest format: `{"audio_filepath": "...", "text": "...", "duration": ...}`
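
As a minimal sketch of the manifest format above, the snippet below checks that each line of a JSON-lines manifest carries the three keys shown. The function name `validate_manifest_lines` is hypothetical (it is not part of NeMo), and the positive-duration check is an assumption about what counts as a usable entry.

```python
import json

REQUIRED_KEYS = {"audio_filepath", "text", "duration"}


def validate_manifest_lines(lines):
    """Return (valid_entries, errors) for an iterable of JSON-lines strings."""
    valid, errors = [], []
    for line_no, line in enumerate(lines, start=1):
        line = line.strip()
        if not line:
            continue  # ignore blank lines
        try:
            entry = json.loads(line)
        except json.JSONDecodeError as exc:
            errors.append(f"line {line_no}: invalid JSON ({exc.msg})")
            continue
        if not isinstance(entry, dict):
            errors.append(f"line {line_no}: expected a JSON object")
            continue
        missing = REQUIRED_KEYS - entry.keys()
        if missing:
            errors.append(f"line {line_no}: missing keys {sorted(missing)}")
            continue
        duration = entry["duration"]
        # Assumption: a usable entry has a positive, finite numeric duration.
        is_number = isinstance(duration, (int, float)) and not isinstance(duration, bool)
        if not is_number or not (0 < duration < float("inf")):
            errors.append(f"line {line_no}: bad duration {duration!r}")
            continue
        valid.append(entry)
    return valid, errors
```

Running this over the lines of a generated manifest before training can surface missing keys or NaN durations early, without loading any audio.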

## References
- [NVIDIA Riva Parakeet Fine-tuning Tutorial](https://github.com/nvidia-riva/tutorials/blob/main/asr-finetune-parakeet-nemo.ipynb)
- [NeMo ASR Models Documentation](https://docs.nvidia.com/nemo-framework/user-guide/latest/nemotoolkit/asr/models.html)
- [Parakeet-TDT-0.6B-v3 on HuggingFace](https://huggingface.co/nvidia/parakeet-tdt-0.6b-v3)

## 1. Setup Dependencies

Fine-tuning with NeMo requires additional packages. Add them via uv:

In [None]:
# Add fine-tuning dependencies to pyproject.toml (run once)
# uv add "nemo_toolkit[asr]" "pytorch-lightning>=2.0"   (quoted so the shell does not treat [] or > specially)
#
# For A6000 with CUDA 11.8:
# uv add torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118

In [None]:
import sys
import os
import json
import re
import tempfile
from pathlib import Path

# Add project root to path for package imports
PROJECT_ROOT = Path.cwd().parent
sys.path.insert(0, str(PROJECT_ROOT))

# Load Azure credentials from .env file
from dotenv import load_dotenv
load_dotenv(dotenv_path=PROJECT_ROOT / 'credentials/creds.env')

import torch
import librosa
import pandas as pd
from omegaconf import OmegaConf, DictConfig
import pytorch_lightning as pl
import nemo.collections.asr as nemo_asr
from nemo.utils import exp_manager
from pydub import AudioSegment

# Import project modules
from scripts.eval.evaluate import clean_raw_transcript_str
from scripts.cloud import azure_utils

print(f"PyTorch: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
print(f"Project root: {PROJECT_ROOT}")
print(f"Azure credentials loaded: {os.getenv('AZURE_STORAGE_CONNECTION_STRING') is not None}")

## 2. Configuration

In [None]:
# =============================================================================
# CONFIGURATION
# =============================================================================

CONFIG = {
    # Data paths - Demo segmented parquets (small 31-sample set for testing)
    "train_parquet": str(PROJECT_ROOT / "data/raw/loc/veterans_history_project_resources_pre2010_train_nfa_segmented_demo.parquet"),
    "val_parquet": str(PROJECT_ROOT / "data/raw/loc/veterans_history_project_resources_pre2010_val_nfa_segmented.parquet"),
    
    # Azure blob settings (for downloading segmented audio)
    "blob_prefix": "loc_vhp",
    
    # Sampling - None to use all data in demo parquets
    "train_sample_size": None,
    "val_sample_size": 10,  # Small val set for demo
    "random_seed": 42,
    
    # Local audio cache directory (segments downloaded from Azure)
    "audio_cache_dir": "/workspace/audio_cache/parakeet_ft",
    
    # NeMo manifest output (will be generated)
    "train_manifest": "/workspace/data/parakeet_train_manifest.json",
    "val_manifest": "/workspace/data/parakeet_val_manifest.json",
    
    # Output directory - follows convention: {dataset}-{model}-{task}-{infra}
    "output_dir": "/workspace/outputs/vhp-pre2010-parakeet-tdt-0.6b-ft-a6000",
    "exp_name": "parakeet_tdt_vhp",
    
    # Model
    "model_name": "nvidia/parakeet-tdt-0.6b-v3",
    
    # Training hyperparameters (reduced for demo)
    # Note: the CLI command in section 5 reads TRAIN_CONFIG, not these values
    "learning_rate": 5e-5,
    "batch_size": 4,
    "gradient_accumulation": 2,
    "warmup_steps": 50,
    "max_steps": 500,        # Small for demo
    "val_check_interval": 100,
    
    # Audio settings
    "sample_rate": 16000,
    "max_duration": 30,   # Max audio duration per sample (seconds)
    "min_duration": 0.1,
    
    # Precision
    "precision": "16-mixed",
}

# Create output directories
os.makedirs(CONFIG["output_dir"], exist_ok=True)
os.makedirs(CONFIG["audio_cache_dir"], exist_ok=True)
os.makedirs(os.path.dirname(CONFIG["train_manifest"]), exist_ok=True)
print(f"Output directory: {CONFIG['output_dir']}")
print(f"Audio cache: {CONFIG['audio_cache_dir']}")
print(f"Train sample size: {CONFIG['train_sample_size']} (None = use all)")
print(f"Val sample size: {CONFIG['val_sample_size']}")

## 3. Load and Sample Data

Load segmented parquet files and sample train/val sets.

In [None]:
def load_segmented_parquet(parquet_path: str, sample_size: int = None, random_seed: int = 42):
    """
    Load segmented parquet and optionally sample.
    
    Filters to rows that have both segmented_audio_url and segmented_audio_transcript.
    """
    df = pd.read_parquet(parquet_path)
    print(f"Loaded {len(df)} rows from {parquet_path}")
    
    # Filter to rows with required columns
    df = df[df['segmented_audio_url'].notna() & (df['segmented_audio_url'] != '')]
    print(f"Filtered to {len(df)} rows with segmented_audio_url")
    
    df = df[df['segmented_audio_transcript'].notna() & (df['segmented_audio_transcript'] != '')]
    print(f"Filtered to {len(df)} rows with segmented_audio_transcript")
    
    # Sample if requested
    if sample_size and sample_size < len(df):
        df = df.sample(n=sample_size, random_state=random_seed)
        print(f"Sampled {len(df)} rows")
    
    return df.reset_index(drop=True)

# Load datasets
print("Loading training data...")
df_train = load_segmented_parquet(
    CONFIG["train_parquet"],
    sample_size=CONFIG["train_sample_size"],
    random_seed=CONFIG["random_seed"]
)

print("\nLoading validation data...")
df_val = load_segmented_parquet(
    CONFIG["val_parquet"],
    sample_size=CONFIG["val_sample_size"],
    random_seed=CONFIG["random_seed"]
)

print(f"\nTrain: {len(df_train)} samples")
print(f"Val: {len(df_val)} samples")

In [None]:
def download_audio_to_local(blob_url: str, cache_dir: str) -> str:
    """
    Download audio from Azure blob to local cache.
    
    Returns local path, or None if failed.
    Skips download if file already exists in cache.
    """
    # Create deterministic local filename from blob URL
    # e.g., "loc_vhp_segments/123/segment_0.wav" -> "123_segment_0.wav"
    blob_name = blob_url.split('/')[-2] + '_' + blob_url.split('/')[-1]
    local_path = os.path.join(cache_dir, blob_name)
    
    # Skip if already cached
    if os.path.exists(local_path):
        return local_path
    
    try:
        # Download from Azure
        audio_bytes = azure_utils.download_blob_to_memory(blob_url)
        
        # Write to local file
        with open(local_path, 'wb') as f:
            f.write(audio_bytes)
        
        return local_path
    except Exception as e:
        print(f"  Error downloading {blob_url}: {e}")
        return None


def create_nemo_manifest_from_segmented(
    df: pd.DataFrame,
    cache_dir: str,
    output_path: str,
    max_duration_sec: float = 30.0
):
    """
    Create NeMo manifest from segmented parquet.
    
    Downloads audio from Azure blob URLs to local cache,
    then creates NeMo manifest with local paths.
    
    NeMo manifest format (JSON lines):
    {"audio_filepath": "/path/to/audio.wav", "text": "transcription", "duration": 15.4}
    """
    entries = []
    skipped_download = 0
    skipped_duration = 0
    skipped_empty = 0
    
    print(f"Processing {len(df)} samples...")
    
    for idx, row in df.iterrows():
        # Get transcript (already cleaned in segmentation)
        transcript = row.get('segmented_audio_transcript', '').strip()
        if not transcript:
            skipped_empty += 1
            continue
        
        # Download audio to local cache
        blob_url = row.get('segmented_audio_url', '')
        local_path = download_audio_to_local(blob_url, cache_dir)
        
        if local_path is None:
            skipped_download += 1
            continue
        
        # Get duration from metadata or load file
        # (a missing parquet value arrives as NaN rather than None, so check both)
        duration = row.get('segmented_audio_duration', None)
        if duration is None or pd.isna(duration):
            try:
                y, sr = librosa.load(local_path, sr=None)
                duration = librosa.get_duration(y=y, sr=sr)
            except Exception as e:
                print(f"  Error loading {local_path}: {e}")
                skipped_download += 1
                continue
        
        # Skip if too long
        if duration > max_duration_sec:
            skipped_duration += 1
            continue
        
        # Lowercase the transcript for normalization (a project choice; NeMo itself does not require it)
        entries.append({
            "audio_filepath": local_path,
            "text": transcript.lower(),
            "duration": round(float(duration), 2)
        })
        
        if len(entries) % 200 == 0:
            print(f"  Processed {len(entries)} entries...")
    
    # Write manifest
    with open(output_path, 'w') as f:
        for entry in entries:
            f.write(json.dumps(entry) + '\n')
    
    total_hours = sum(e['duration'] for e in entries) / 3600
    print(f"\nCreated manifest: {output_path}")
    print(f"  Valid entries: {len(entries)}")
    print(f"  Skipped (download failed): {skipped_download}")
    print(f"  Skipped (empty transcript): {skipped_empty}")
    print(f"  Skipped (too long): {skipped_duration}")
    print(f"  Total hours: {total_hours:.2f}")
    
    return entries

In [None]:
# Create training manifest (downloads audio to local cache)
print("="*60)
print("CREATING TRAINING MANIFEST")
print("="*60)
train_entries = create_nemo_manifest_from_segmented(
    df_train,
    CONFIG["audio_cache_dir"],
    CONFIG["train_manifest"],
    max_duration_sec=CONFIG["max_duration"]
)

print("\n" + "="*60)
print("CREATING VALIDATION MANIFEST")
print("="*60)
val_entries = create_nemo_manifest_from_segmented(
    df_val,
    CONFIG["audio_cache_dir"],
    CONFIG["val_manifest"],
    max_duration_sec=CONFIG["max_duration"]
)

In [None]:
# Preview a sample from manifest
import random
random.seed(CONFIG["random_seed"])

with open(CONFIG["train_manifest"], 'r') as f:
    manifest_lines = f.readlines()

if manifest_lines:
    sample_entry = json.loads(random.choice(manifest_lines))
    print("Sample manifest entry:")
    print(f"  Audio: {sample_entry['audio_filepath']}")
    print(f"  Duration: {sample_entry['duration']:.1f}s")
    print(f"  Text (first 200 chars): {sample_entry['text'][:200]}...")
else:
    print("WARNING: No entries in manifest!")

## 4. Find NeMo Installation Path

We need to locate the NeMo examples directory for the fine-tuning script.

In [None]:
# Find NeMo installation path
# Note: NeMo examples are NOT included in pip install, must clone the repo
import nemo

NEMO_GIT_DIR = "/workspace/NeMo"

# Always use the git repo for examples (pip install doesn't include them)
if not os.path.exists(NEMO_GIT_DIR):
    print("Cloning NeMo repo for training scripts...")
    !git clone --depth 1 https://github.com/NVIDIA/NeMo.git {NEMO_GIT_DIR}
else:
    print(f"NeMo repo already exists at {NEMO_GIT_DIR}")

NEMO_EXAMPLES = os.path.join(NEMO_GIT_DIR, "examples")

print(f"NeMo package: {os.path.dirname(nemo.__file__)}")
print(f"NeMo examples: {NEMO_EXAMPLES}")
print(f"Finetune script exists: {os.path.exists(os.path.join(NEMO_EXAMPLES, 'asr/speech_to_text_finetune.py'))}")

## 5. Configure Training

The official method uses NeMo's `speech_to_text_finetune.py` script with command-line config overrides.

**Key Parameters:**
- `init_from_pretrained_model`: Load pretrained FastConformer model
- `trainer.max_epochs`: Number of epochs (set low for demo)
- `trainer.precision`: bf16 for memory efficiency
- `model.optim.lr`: Learning rate
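
To illustrate how the overrides listed above reach the script, here is a small sketch that turns a dictionary into Hydra-style command-line arguments. `build_overrides` is a hypothetical helper, not a NeMo or Hydra function; it relies only on Hydra's documented prefixes, where `+key=value` adds a new key and `++key=value` adds the key or overrides it if it already exists. The values are copied from this notebook's demo settings.

```python
def build_overrides(params, prefix="++"):
    """Format a flat dict as Hydra override strings, e.g. '++trainer.devices=1'."""
    if prefix not in ("", "+", "++"):
        raise ValueError(f"unsupported Hydra prefix: {prefix!r}")
    overrides = []
    for key, value in params.items():
        if isinstance(value, bool):
            value = str(value).lower()  # YAML-style lowercase boolean
        overrides.append(f"{prefix}{key}={value}")
    return overrides


# '+' for the key that is new to the config, '++' for add-or-override keys.
args = build_overrides(
    {"init_from_pretrained_model": "stt_en_fastconformer_hybrid_large_pc"},
    prefix="+",
)
args += build_overrides({
    "trainer.max_epochs": 5,
    "trainer.precision": "bf16",
    "model.optim.lr": 0.0001,
    "exp_manager.use_datetime_version": False,
})
print(" ".join(args))
```

Keeping the overrides as a list of separate arguments also makes it possible to pass them to `subprocess` without shell quoting, should that be preferred over a single command string.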

In [None]:
# Training configuration
TRAIN_CONFIG = {
    # Pretrained model to fine-tune
    # Options: stt_en_fastconformer_hybrid_large_pc (recommended for Parakeet-like)
    #          nvidia/parakeet-tdt-0.6b-v3 (direct, but may need different script)
    "pretrained_model": "stt_en_fastconformer_hybrid_large_pc",
    
    # Training parameters (reduced for demo)
    "max_epochs": 5,           # Set to 50-200 for real training
    "precision": "bf16",       # bf16 or 16-mixed
    "lr": 0.0001,              # Learning rate (0.1 in tutorial, but that's aggressive)
    "weight_decay": 0.001,
    "warmup_steps": 50,
    
    # Batch size (small for A6000 memory)
    "batch_size": 4,
    
    # Checkpointing
    "exp_dir": CONFIG["output_dir"],
    "exp_name": "parakeet_ft_demo",
}

print("Training Configuration:")
for k, v in TRAIN_CONFIG.items():
    print(f"  {k}: {v}")

In [None]:
# Build the training command
# Based on: https://github.com/nvidia-riva/tutorials/blob/main/asr-finetune-parakeet-nemo.ipynb

finetune_script = os.path.join(NEMO_EXAMPLES, "asr/speech_to_text_finetune.py")
config_path = "../asr/conf/fastconformer/hybrid_transducer_ctc/"
config_name = "fastconformer_hybrid_transducer_ctc_bpe"

train_cmd = f"""python {finetune_script} \\
  --config-path="{config_path}" \\
  --config-name={config_name} \\
  +init_from_pretrained_model={TRAIN_CONFIG['pretrained_model']} \\
  ++model.train_ds.manifest_filepath="{CONFIG['train_manifest']}" \\
  ++model.validation_ds.manifest_filepath="{CONFIG['val_manifest']}" \\
  ++model.train_ds.batch_size={TRAIN_CONFIG['batch_size']} \\
  ++model.validation_ds.batch_size={TRAIN_CONFIG['batch_size']} \\
  ++model.optim.sched.d_model=1024 \\
  ++trainer.devices=1 \\
  ++trainer.max_epochs={TRAIN_CONFIG['max_epochs']} \\
  ++trainer.precision={TRAIN_CONFIG['precision']} \\
  ++model.optim.name="adamw" \\
  ++model.optim.lr={TRAIN_CONFIG['lr']} \\
  ++model.optim.weight_decay={TRAIN_CONFIG['weight_decay']} \\
  ++model.optim.sched.warmup_steps={TRAIN_CONFIG['warmup_steps']} \\
  ++exp_manager.exp_dir={TRAIN_CONFIG['exp_dir']} \\
  ++exp_manager.name={TRAIN_CONFIG['exp_name']} \\
  ++exp_manager.use_datetime_version=False \\
  ++exp_manager.version=v1
"""

print("Training command:")
print(train_cmd)
print("\nScript exists:", os.path.exists(finetune_script))

## 6. Run Training

⚠️ **Warning**: This will take a while even for the demo dataset. For the full dataset, consider running in a screen session.
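
For long runs, one option is to also write the script's output to a log file so it can be inspected after the browser tab is closed. The sketch below does this with the standard library only; `run_and_log` and the log file name are hypothetical, and the command shown is a stand-in rather than the real training command.

```python
import subprocess
import sys
import tempfile
from pathlib import Path


def run_and_log(cmd, log_path):
    """Run cmd (a list of arguments), echo its output, and append it to log_path.

    Returns the process exit code.
    """
    log_path = Path(log_path)
    log_path.parent.mkdir(parents=True, exist_ok=True)
    with log_path.open("a", encoding="utf-8") as log:
        proc = subprocess.Popen(
            cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,  # merge stderr into the same stream
            text=True,
        )
        for line in proc.stdout:
            print(line, end="")
            log.write(line)
            log.flush()
        return proc.wait()


# Stand-in command; replace it with the real training arguments.
demo_log = Path(tempfile.gettempdir()) / "parakeet_ft_demo_train.log"
exit_code = run_and_log([sys.executable, "-c", "print('training placeholder')"], demo_log)
print(f"exit code: {exit_code}")
```

This does not replace a `screen` session: the child process still belongs to the notebook kernel and stops if the kernel is shut down.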

In [None]:
# Run training
# Note: IPython's `!` shell escape streams the script's output into the notebook

print("="*60)
print("STARTING TRAINING")
print("="*60)
print(f"Pretrained model: {TRAIN_CONFIG['pretrained_model']}")
print(f"Train manifest: {CONFIG['train_manifest']}")
print(f"Val manifest: {CONFIG['val_manifest']}")
print(f"Max epochs: {TRAIN_CONFIG['max_epochs']}")
print(f"Output dir: {TRAIN_CONFIG['exp_dir']}")
print("="*60)

# Execute training
!{train_cmd}

## 7. Find Saved Model

After training, NeMo saves checkpoints to the experiment directory.
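
A minimal sketch of picking a checkpoint from such a directory, preferring a packaged `.nemo` file and otherwise the most recently modified `.ckpt`. `pick_checkpoint` is a hypothetical helper, and the preference order is an assumption for this notebook, not NeMo behavior.

```python
from pathlib import Path


def pick_checkpoint(checkpoint_dir):
    """Return the newest .nemo file, else the newest .ckpt file, else None."""
    directory = Path(checkpoint_dir)
    if not directory.is_dir():
        return None
    for pattern in ("*.nemo", "*.ckpt"):
        candidates = list(directory.glob(pattern))
        if candidates:
            # glob order is arbitrary, so compare modification times explicitly.
            return max(candidates, key=lambda p: p.stat().st_mtime)
    return None
```

Comparing modification times avoids relying on the order in which the filesystem happens to list files.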

In [None]:
# Find saved model checkpoint
import glob

checkpoint_dir = os.path.join(TRAIN_CONFIG['exp_dir'], TRAIN_CONFIG['exp_name'], "v1", "checkpoints")
print(f"Looking for checkpoints in: {checkpoint_dir}")

if os.path.exists(checkpoint_dir):
    nemo_files = glob.glob(os.path.join(checkpoint_dir, "*.nemo"))
    ckpt_files = glob.glob(os.path.join(checkpoint_dir, "*.ckpt"))
    
    print(f"\n.nemo files: {nemo_files}")
    print(f".ckpt files: {ckpt_files}")
    
    if nemo_files:
        MODEL_PATH = nemo_files[0]
        print(f"\nUsing model: {MODEL_PATH}")
    elif ckpt_files:
        MODEL_PATH = max(ckpt_files, key=os.path.getmtime)  # Use most recently modified (glob order is arbitrary)
        print(f"\nUsing checkpoint: {MODEL_PATH}")
    else:
        MODEL_PATH = None
        print("\nNo checkpoints found!")
else:
    print(f"Checkpoint directory not found: {checkpoint_dir}")
    MODEL_PATH = None

In [None]:
print("="*60)
print("TRAINING COMPLETE")
print("="*60)
print(f"\nModel saved to: {MODEL_PATH}")
print(f"\nTo load later:")
print(f"  model = nemo_asr.models.ASRModel.restore_from('{MODEL_PATH}')")

## 8. Test Inference

Load the fine-tuned model and test on validation samples.

In [None]:
# Load fine-tuned model and run test inference
test_results = []

if MODEL_PATH and os.path.exists(MODEL_PATH):
    print(f"Loading model from: {MODEL_PATH}")
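    if MODEL_PATH.endswith(".ckpt"):
        # restore_from() expects a .nemo archive; a raw Lightning checkpoint needs load_from_checkpoint()
        finetuned_model = nemo_asr.models.EncDecHybridRNNTCTCBPEModel.load_from_checkpoint(MODEL_PATH)
    else:
        finetuned_model = nemo_asr.models.ASRModel.restore_from(MODEL_PATH)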
    finetuned_model.eval()
    finetuned_model = finetuned_model.cuda()
    print("Model loaded successfully!\n")
    
    # Test on validation manifest
    with open(CONFIG["val_manifest"], 'r') as f:
        val_manifest_lines = f.readlines()
    
    print(f"Testing on {min(5, len(val_manifest_lines))} samples from validation manifest...")
    
    for i, line in enumerate(val_manifest_lines[:5]):
        entry = json.loads(line)
        audio_path = entry['audio_filepath']
        ground_truth = entry['text']
        
        print(f"  [{i}] Processing: {os.path.basename(audio_path)}")
        
        try:
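            result = finetuned_model.transcribe([audio_path])
            # Some NeMo versions return a (best, all) tuple for transducer models,
            # and newer ones return Hypothesis objects instead of plain strings
            if isinstance(result, tuple):
                result = result[0]
            first = result[0] if result else ""
            hypothesis = getattr(first, "text", first)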
            
            test_results.append({
                "file_id": i,
                "hypothesis": hypothesis.lower(),
                "ground_truth": ground_truth,
                "audio_path": audio_path
            })
        except Exception as e:
            print(f"    Error: {e}")
    
    print(f"\nCompleted: {len(test_results)} files")
else:
    print("No model checkpoint found. Skipping inference test.")

In [None]:
# View test results
if 'test_results' in dir() and test_results:
    print("=" * 70)
    print("TEST INFERENCE RESULTS")
    print("=" * 70)
    
    for r in test_results[:3]:  # Show first 3
        print(f"\nFile ID: {r['file_id']}")
        print(f"Audio: {os.path.basename(r['audio_path'])}")
        print(f"\nHypothesis (first 200 chars):")
        print(r['hypothesis'][:200] + "..." if len(r['hypothesis']) > 200 else r['hypothesis'])
        print(f"\nGround truth (first 200 chars):")
        print(r['ground_truth'][:200] + "..." if len(r['ground_truth']) > 200 else r['ground_truth'])
        print("-" * 70)
else:
    print("No test results available.")

## 9. Next Steps

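The fine-tuned model is saved under the experiment directory configured in section 5. The exact file is printed as `MODEL_PATH` in section 7:
```
/workspace/outputs/vhp-pre2010-parakeet-tdt-0.6b-ft-a6000/parakeet_ft_demo/v1/checkpoints/
```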

To run full evaluation:

1. **Create inference script** for the fine-tuned NeMo model (similar to `infer_parakeet.py`)

2. **Run evaluation** with `scripts/eval/evaluate.py` to compute WER metrics

3. **Compare** with baseline Parakeet (pre-fine-tuning) results
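
For reference, word error rate is the word-level edit distance between hypothesis and reference, divided by the number of reference words. The sketch below is a plain-Python version of that definition; it is not the implementation in `scripts/eval/evaluate.py`, and whitespace tokenization is an assumption.

```python
def word_error_rate(reference, hypothesis):
    """WER = (substitutions + deletions + insertions) / reference word count."""
    ref = reference.split()
    hyp = hypothesis.split()
    if not ref:
        raise ValueError("reference must contain at least one word")
    # prev[j] is the edit distance between the reference words seen so far and hyp[:j].
    prev = list(range(len(hyp) + 1))
    for i, ref_word in enumerate(ref, start=1):
        curr = [i] + [0] * len(hyp)
        for j, hyp_word in enumerate(hyp, start=1):
            cost = 0 if ref_word == hyp_word else 1
            curr[j] = min(
                prev[j] + 1,         # deletion
                curr[j - 1] + 1,     # insertion
                prev[j - 1] + cost,  # substitution or match
            )
        prev = curr
    return prev[-1] / len(ref)


# One substituted word out of four reference words.
print(word_error_rate("the quick brown fox", "the quick brown dog"))  # 0.25
```

Applying the same function to baseline and fine-tuned hypotheses for the same references gives a like-for-like comparison, provided both are normalized (for example lowercased) the same way first.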

To load the fine-tuned model later:
```python
import nemo.collections.asr as nemo_asr
# Replace <name> with the .nemo file name printed as MODEL_PATH in section 7
model = nemo_asr.models.ASRModel.restore_from("/workspace/outputs/vhp-pre2010-parakeet-tdt-0.6b-ft-a6000/parakeet_ft_demo/v1/checkpoints/<name>.nemo")
```