# Speech-to-Text Model Fine-Tuning Pipeline

A configurable notebook for fine-tuning various STT models (Whisper, Parakeet, and more) on custom datasets.

## Features
- **Multi-model support**: Whisper, Parakeet, and an extensible architecture for adding future models
- **Unified dataset handling**: Common dataset format for all models
- **Flexible data splitting**: Train/validation/test splits or separate datasets
- **Parallel training**: Optional concurrent training of multiple models
- **Colab compatible**: Designed to run on Google Colab with GPU support

## 1. Environment Setup

In [None]:
# Google Colab setup -- run this cell first when working on Colab.
import sys

# Presumably `google.colab` is only importable/imported inside a Colab runtime,
# so its presence in sys.modules is used as the environment check.
IN_COLAB = 'google.colab' in sys.modules

if not IN_COLAB:
    print("Running locally")
else:
    print("Running on Google Colab")
    print("Mounting Google Drive for data persistence...")
    # Imported lazily: the package only exists on Colab.
    from google.colab import drive
    drive.mount('/content/drive')

In [None]:
# Install required packages
# IMPORTANT: Run this cell, then RESTART RUNTIME before running subsequent cells

INSTALL_PACKAGES = True  # Set to False if packages are already installed

if INSTALL_PACKAGES:
    import sys
    IN_COLAB = 'google.colab' in sys.modules
    
    if IN_COLAB:
        # Fix NumPy compatibility issue on Colab
        print("Step 1/4: Upgrading NumPy to fix compatibility issues...")
        !pip uninstall -y numpy
        !pip install numpy>=1.26.0
        
        print("\nStep 2/4: Installing audio decoding dependencies...")
        !pip install -q torchcodec soundfile librosa
        
        print("\nStep 3/4: Installing core packages...")
        !pip install -q --upgrade transformers datasets accelerate evaluate jiwer
        !pip install -q torch torchaudio --upgrade
        !pip install -q huggingface_hub
        
        print("\nStep 4/4: Installing NeMo for Parakeet (optional, may take a while)...")
        # Uncomment the line below if you need Parakeet support
        # !pip install -q nemo_toolkit[asr]
        
        print("\n" + "="*60)
        print("IMPORTANT: Please restart the runtime now!")
        print("Go to: Runtime -> Restart runtime")
        print("Then run cells starting from the imports cell (skip this cell)")
        print("="*60)
        
        # Auto-restart option for Colab
        # Uncomment the lines below to auto-restart
        # import os
        # os.kill(os.getpid(), 9)
    else:
        # Local installation
        !pip install -q transformers datasets accelerate evaluate jiwer
        !pip install -q torch torchaudio torchcodec
        !pip install -q soundfile librosa
        !pip install -q huggingface_hub
        # !pip install -q nemo_toolkit[asr]  # Uncomment for Parakeet support
        print("Packages installed successfully!")

In [None]:
# Core imports
# NOTE: If you see NumPy errors, restart the runtime first (Runtime -> Restart runtime)

import os
import json
import logging
from pathlib import Path
from typing import Dict, List, Optional, Union, Any, Callable
from dataclasses import dataclass, field, asdict
from abc import ABC, abstractmethod
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed
import multiprocessing as mp

import numpy as np
print(f"NumPy version: {np.__version__}")

import torch
import torchaudio
from datasets import Dataset, DatasetDict, Audio, load_dataset, concatenate_datasets
from transformers import (
    WhisperProcessor,
    WhisperForConditionalGeneration,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    TrainerCallback,
)
import evaluate

# Detect Colab
import sys
IN_COLAB = 'google.colab' in sys.modules

# Configure logging.
# force=True removes any handlers already attached to the root logger before
# installing ours. Without it, basicConfig() does nothing when the root logger
# already has handlers (notebook front-ends and test runners may add their own),
# and the logger.info() progress messages used throughout this notebook would
# not be shown.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    force=True,
)
logger = logging.getLogger(__name__)

# Check GPU availability
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

print("\nAll imports successful!")

## 2. Configuration

All configurable parameters are centralized here for easy modification.

In [None]:
@dataclass
class DatasetConfig:
    """Configuration for dataset loading, preprocessing, and splitting.

    Inconsistent values are rejected with ``ValueError`` when the config is
    created, instead of surfacing later as an empty or unsplittable dataset.
    """
    # Primary dataset sources (can be HuggingFace dataset names or local paths)
    train_datasets: List[str] = field(default_factory=lambda: ["mozilla-foundation/common_voice_11_0"])

    # Optional separate validation/test datasets (if None, will split from train)
    validation_dataset: Optional[str] = None
    test_dataset: Optional[str] = None

    # Dataset split ratios (used if validation/test datasets not provided).
    # NOTE: DatasetHandler does not read train_split; training keeps whatever
    # remains after the validation/test fractions are held out.
    train_split: float = 0.8
    validation_split: float = 0.1
    test_split: float = 0.1

    # Column names in dataset
    audio_column: str = "audio"
    text_column: str = "sentence"

    # Audio settings
    sampling_rate: int = 16000
    max_audio_length_seconds: float = 30.0
    min_audio_length_seconds: float = 0.5

    # HuggingFace dataset settings
    # For datasets like Common Voice that require a config name (e.g., language code)
    dataset_config_name: Optional[str] = "en"  # e.g., "en" for Common Voice
    dataset_split: str = "train"  # Which split to load

    # Data limits (for quick testing)
    max_train_samples: Optional[int] = None
    max_val_samples: Optional[int] = None
    max_test_samples: Optional[int] = None

    # Trust remote code (for some HuggingFace datasets)
    trust_remote_code: bool = True

    def __post_init__(self):
        """Validate field values.

        Raises:
            ValueError: If no training source is given, a split ratio lies
                outside [0, 1], the fractions that must be carved out of the
                training data leave nothing for training, the sampling rate is
                not positive, or the duration bounds are negative or inverted.
        """
        if not self.train_datasets:
            raise ValueError("train_datasets must contain at least one dataset source")

        for name in ("train_split", "validation_split", "test_split"):
            value = getattr(self, name)
            if not 0.0 <= value <= 1.0:
                raise ValueError(f"{name} must be between 0 and 1, got {value}")

        # Only the sets without a dedicated dataset are carved out of training data.
        held_out = (
            (self.validation_split if self.validation_dataset is None else 0.0)
            + (self.test_split if self.test_dataset is None else 0.0)
        )
        if held_out >= 1.0:
            raise ValueError(
                f"validation/test fractions to split off sum to {held_out}, "
                "leaving no data for training"
            )

        if self.sampling_rate <= 0:
            raise ValueError(f"sampling_rate must be positive, got {self.sampling_rate}")

        if not 0.0 <= self.min_audio_length_seconds <= self.max_audio_length_seconds:
            raise ValueError(
                "expected 0 <= min_audio_length_seconds <= max_audio_length_seconds, got "
                f"{self.min_audio_length_seconds} and {self.max_audio_length_seconds}"
            )

In [None]:
@dataclass
class ModelConfig:
    """Hyperparameters and output settings for fine-tuning one model.

    ``model_type`` and ``model_name`` are required; every other field has a
    default. ``extra_kwargs`` holds free-form model-specific options and gets
    a fresh dict per instance.
    """
    # -- Which model -------------------------------------------------------
    model_type: str  # Trainer family, e.g. "whisper" or "parakeet"
    model_name: str  # Checkpoint, e.g. "openai/whisper-small" or "nvidia/parakeet-ctc-1.1b"

    # -- Optimisation ------------------------------------------------------
    learning_rate: float = 0.00001
    batch_size: int = 8
    gradient_accumulation_steps: int = 2
    num_epochs: int = 3
    warmup_steps: int = 500
    weight_decay: float = 1e-2

    # -- Freezing ----------------------------------------------------------
    freeze_encoder: bool = False
    freeze_encoder_layers: int = 0  # How many leading encoder layers to freeze; 0 freezes none

    # -- Checkpoints and logging -------------------------------------------
    output_dir: str = "./outputs"
    save_steps: int = 500
    eval_steps: int = 500
    logging_steps: int = 100

    # -- Mixed precision ---------------------------------------------------
    fp16: bool = True
    bf16: bool = False

    # -- Anything model-specific -------------------------------------------
    extra_kwargs: Dict[str, Any] = field(default_factory=lambda: {})

In [None]:
@dataclass
class TrainingConfig:
    """Top-level settings for one fine-tuning run.

    Groups the per-model configurations, the shared dataset configuration, and
    run-wide options. Constructing an instance creates ``output_base_dir`` on
    disk, including any missing parent directories.
    """
    # One ModelConfig per model to fine-tune.
    models: List[ModelConfig] = field(default_factory=lambda: [])

    # Dataset settings shared by every model (a fresh DatasetConfig by default).
    dataset: DatasetConfig = field(default_factory=DatasetConfig)

    # Parallel-training switch and the maximum number of models trained at once.
    enable_parallel: bool = False
    max_parallel_models: int = 2

    # Reproducibility seed and root output directory.
    seed: int = 42
    output_base_dir: str = "./outputs"

    # Which metrics to report.
    compute_wer: bool = True
    compute_cer: bool = True

    # Checkpoint resume path and retention limit.
    resume_from_checkpoint: Optional[str] = None
    save_total_limit: int = 3

    def __post_init__(self):
        base_dir = Path(self.output_base_dir)
        base_dir.mkdir(exist_ok=True, parents=True)

In [None]:
# ============================================================================
# MAIN CONFIGURATION - MODIFY THIS SECTION
# ============================================================================

# Define models to fine-tune
MODELS_TO_TRAIN = [
    ModelConfig(
        model_type="whisper",
        model_name="openai/whisper-small",  # Use smaller model for testing
        learning_rate=1e-5,
        batch_size=8,
        num_epochs=3,
        output_dir="./outputs/whisper-small-finetuned",
        freeze_encoder=False,
    ),
    # Uncomment to add more models:
    # ModelConfig(
    #     model_type="whisper",
    #     model_name="openai/whisper-medium",
    #     learning_rate=5e-6,
    #     batch_size=4,
    #     num_epochs=3,
    #     output_dir="./outputs/whisper-medium-finetuned",
    # ),
    # ModelConfig(
    #     model_type="parakeet",
    #     model_name="nvidia/parakeet-ctc-1.1b",
    #     learning_rate=1e-5,
    #     batch_size=8,
    #     num_epochs=3,
    #     output_dir="./outputs/parakeet-finetuned",
    # ),
]

# Dataset configuration
DATASET_CONFIG = DatasetConfig(
    # Add your dataset sources here
    train_datasets=[
        "mozilla-foundation/common_voice_11_0",
        # Add more datasets to combine them:
        # "librispeech_asr",
        # "/path/to/local/dataset",
    ],

    # Optional: Provide separate validation/test datasets
    # If None, will split from train_datasets
    validation_dataset=None,
    test_dataset=None,

    # Split ratios (used if validation/test datasets not provided)
    train_split=0.8,
    validation_split=0.1,
    test_split=0.1,

    # Column names (adjust based on your dataset)
    audio_column="audio",
    text_column="sentence",

    # HuggingFace dataset settings
    # For Common Voice: use language code like "en", "es", "fr", etc.
    # For other datasets: set to None if no config name needed
    dataset_config_name="en",
    dataset_split="train",

    # Limit samples for testing (set to None for full dataset)
    max_train_samples=1000,  # Set to None for full dataset
    max_val_samples=100,
    max_test_samples=100,

    # Trust remote code for some HuggingFace datasets
    trust_remote_code=True,
)

# Master configuration
CONFIG = TrainingConfig(
    models=MODELS_TO_TRAIN,
    dataset=DATASET_CONFIG,

    # Parallel training
    enable_parallel=False,  # Set to True to train models in parallel
    max_parallel_models=2,

    # General settings
    seed=42,
    output_base_dir="./outputs",

    # Evaluation
    compute_wer=True,
    compute_cer=True,
)

# Colab-specific paths: keep outputs on the mounted Google Drive.
if IN_COLAB:
    COLAB_DRIVE_ROOT = "/content/drive/MyDrive/stt-finetuning"
    CONFIG.output_base_dir = f"{COLAB_DRIVE_ROOT}/outputs"
    # TrainingConfig.__post_init__ only created the directory that was set at
    # construction time ("./outputs"); create the reassigned Drive directory too.
    Path(CONFIG.output_base_dir).mkdir(parents=True, exist_ok=True)
    for model in CONFIG.models:
        model.output_dir = f"{COLAB_DRIVE_ROOT}/{model.model_name.split('/')[-1]}-finetuned"

print("Configuration loaded successfully!")
print(f"Models to train: {[m.model_name for m in CONFIG.models]}")
print(f"Dataset: {CONFIG.dataset.train_datasets}")
print(f"Dataset config: {CONFIG.dataset.dataset_config_name}")
print(f"Parallel training: {CONFIG.enable_parallel}")

## 3. Dataset Handler

Unified dataset loading, preprocessing, and splitting for all models.

In [None]:
class DatasetHandler:
    """Load, standardize, filter, and split speech datasets for all trainers.

    Every source, whether a training source or a separately supplied
    validation/test source, is standardized the same way. Its audio/text
    columns are renamed to the configured names, all other columns are
    dropped, and audio is cast to ``DatasetConfig.sampling_rate``. The training
    sources are then concatenated and filtered by duration. Any validation/test
    set that was not supplied is split off the training pool. The resulting
    ``DatasetDict`` is cached on the instance.
    """

    # Fallback column names tried, in order, when a source lacks the configured column.
    _AUDIO_COLUMN_CANDIDATES = ('audio', 'file', 'path', 'audio_path')
    _TEXT_COLUMN_CANDIDATES = ('sentence', 'text', 'transcription', 'transcript')
    # Fixed seed so splits are reproducible across runs.
    _SPLIT_SEED = 42

    def __init__(self, config: DatasetConfig):
        self.config = config
        self._datasets: Optional[DatasetDict] = None

    def load_single_dataset(self, source: str) -> Dataset:
        """Load one dataset from a local path or the HuggingFace Hub.

        Args:
            source: A local ``.json``/``.jsonl``/``.csv`` file, a local audio
                folder, or a HuggingFace dataset name.

        Returns:
            The loaded (not yet standardized) dataset.

        Raises:
            ValueError: If ``source`` is a local path of an unsupported type.
        """
        logger.info(f"Loading dataset from: {source}")

        if os.path.exists(source):
            dataset = self._load_local_dataset(source)
        else:
            dataset = self._load_hub_dataset(source)

        logger.info(f"Loaded {len(dataset)} samples")
        return dataset

    @staticmethod
    def _load_local_dataset(source: str) -> Dataset:
        """Load a local JSON/JSONL/CSV file or an audio folder as one split."""
        if source.endswith(('.json', '.jsonl')):
            return load_dataset('json', data_files=source, split='train')
        if source.endswith('.csv'):
            return load_dataset('csv', data_files=source, split='train')
        if os.path.isdir(source):
            return load_dataset('audiofolder', data_dir=source, split='train')
        raise ValueError(f"Unsupported local dataset format: {source}")

    def _load_hub_dataset(self, source: str) -> Dataset:
        """Load ``source`` from the Hub, retrying without the config name on failure."""
        cfg = self.config
        logger.info(f"Loading HuggingFace dataset: {source}, config: {cfg.dataset_config_name}, split: {cfg.dataset_split}")
        common_kwargs = dict(split=cfg.dataset_split, trust_remote_code=cfg.trust_remote_code)

        if cfg.dataset_config_name:
            try:
                # Config name required by datasets like Common Voice (language code).
                dataset = load_dataset(source, cfg.dataset_config_name, **common_kwargs)
            except Exception as config_error:
                logger.warning(f"Failed to load with config name, trying without: {config_error}")
                try:
                    dataset = load_dataset(source, **common_kwargs)
                except Exception as fallback_error:
                    # Keep the first failure attached; it is usually the informative one.
                    raise fallback_error from config_error
        else:
            # No config name configured: a retry would repeat the identical call.
            dataset = load_dataset(source, **common_kwargs)

        # Handle DatasetDict if returned
        if isinstance(dataset, DatasetDict):
            available_splits = list(dataset.keys())
            logger.info(f"Available splits: {available_splits}")
            dataset = dataset[available_splits[0]]
        return dataset

    @staticmethod
    def _rename_first_match(ds: Dataset, target: str, candidates) -> Dataset:
        """Rename the first present candidate column to ``target`` if ``target`` is missing."""
        if target in ds.column_names:
            return ds
        for candidate in candidates:
            if candidate in ds.column_names:
                return ds.rename_column(candidate, target)
        return ds

    def _standardize(self, ds: Dataset) -> Dataset:
        """Normalize column names, keep only audio/text, and cast audio to the target rate.

        Raises:
            ValueError: If the audio or text column cannot be found.
        """
        audio_col, text_col = self.config.audio_column, self.config.text_column
        ds = self._rename_first_match(ds, audio_col, self._AUDIO_COLUMN_CANDIDATES)
        ds = self._rename_first_match(ds, text_col, self._TEXT_COLUMN_CANDIDATES)

        missing = [col for col in (audio_col, text_col) if col not in ds.column_names]
        if missing:
            raise ValueError(
                f"Dataset is missing required column(s) {missing}; "
                f"available columns: {ds.column_names}"
            )

        ds = ds.select_columns([audio_col, text_col])
        # Casting per source gives every source identical features, which
        # concatenate_datasets requires even when native sampling rates differ.
        return ds.cast_column(audio_col, Audio(sampling_rate=self.config.sampling_rate))

    def combine_datasets(self, datasets: List[Dataset]) -> Dataset:
        """Standardize every dataset (even a single one) and concatenate them.

        Raises:
            ValueError: If ``datasets`` is empty or a required column is missing.
        """
        if not datasets:
            raise ValueError("No training datasets to combine (train_datasets is empty)")
        standardized = [self._standardize(ds) for ds in datasets]
        if len(standardized) == 1:
            return standardized[0]
        return concatenate_datasets(standardized)

    def filter_by_audio_length(self, dataset: Dataset) -> Dataset:
        """Keep samples whose duration lies within the configured min/max seconds.

        Samples whose audio is not a decoded dict with an ``array`` are kept,
        because their length cannot be determined.
        """
        def is_valid_length(example):
            audio = example[self.config.audio_column]
            if isinstance(audio, dict) and 'array' in audio:
                duration = len(audio['array']) / audio.get('sampling_rate', self.config.sampling_rate)
            else:
                return True  # Can't determine length, keep sample

            return self.config.min_audio_length_seconds <= duration <= self.config.max_audio_length_seconds

        return dataset.filter(is_valid_length)

    def _load_optional_dataset(self, source: Optional[str]) -> Optional[Dataset]:
        """Load, standardize, and length-filter a separately supplied eval dataset."""
        if not source:
            return None
        dataset = self._standardize(self.load_single_dataset(source))
        # Filter like the training pool so all splits share the same duration bounds.
        return self.filter_by_audio_length(dataset)

    def _split_missing(self, train: Dataset, validation: Optional[Dataset], test: Optional[Dataset]):
        """Carve out only the validation/test sets that were not supplied.

        Returns:
            A ``(train, validation, test)`` tuple with no ``None`` entries.

        Raises:
            ValueError: If the fraction to hold out is not strictly between 0 and 1.
        """
        need_val, need_test = validation is None, test is None
        if not (need_val or need_test):
            return train, validation, test

        logger.info("Splitting dataset into train/val/test...")
        val_fraction = self.config.validation_split if need_val else 0.0
        test_fraction = self.config.test_split if need_test else 0.0
        held_out_fraction = val_fraction + test_fraction
        if not 0.0 < held_out_fraction < 1.0:
            raise ValueError(
                f"Fraction of training data to hold out must be in (0, 1), got {held_out_fraction}"
            )

        first = train.train_test_split(test_size=held_out_fraction, seed=self._SPLIT_SEED)
        train, held_out = first['train'], first['test']

        if need_val and need_test:
            # Second split divides the held-out part between validation and test.
            second = held_out.train_test_split(
                test_size=test_fraction / held_out_fraction, seed=self._SPLIT_SEED
            )
            validation, test = second['train'], second['test']
        elif need_val:
            validation = held_out
        else:
            test = held_out
        return train, validation, test

    @staticmethod
    def _limit(dataset: Dataset, max_samples: Optional[int]) -> Dataset:
        """Keep at most ``max_samples`` rows; a falsy limit (None or 0) keeps all rows."""
        if not max_samples:
            return dataset
        return dataset.select(range(min(len(dataset), max_samples)))

    def prepare_datasets(self) -> DatasetDict:
        """Load and prepare train/validation/test splits (cached after the first call).

        Returns:
            A ``DatasetDict`` with ``train``, ``validation`` and ``test`` splits.
        """
        if self._datasets is not None:
            return self._datasets

        cfg = self.config

        # Load, standardize and combine training datasets
        train = self.combine_datasets([self.load_single_dataset(src) for src in cfg.train_datasets])
        logger.info(f"Combined dataset columns: {train.column_names}")

        # Filter by audio length
        logger.info("Filtering by audio length...")
        train = self.filter_by_audio_length(train)
        logger.info(f"After filtering: {len(train)} samples")

        validation = self._load_optional_dataset(cfg.validation_dataset)
        test = self._load_optional_dataset(cfg.test_dataset)
        train, validation, test = self._split_missing(train, validation, test)

        # Apply sample limits
        train = self._limit(train, cfg.max_train_samples)
        validation = self._limit(validation, cfg.max_val_samples)
        test = self._limit(test, cfg.max_test_samples)

        self._datasets = DatasetDict({
            'train': train,
            'validation': validation,
            'test': test
        })

        logger.info(f"Dataset sizes - Train: {len(train)}, Val: {len(validation)}, Test: {len(test)}")

        return self._datasets

    def get_datasets(self) -> DatasetDict:
        """Get prepared datasets."""
        return self.prepare_datasets()

## 4. Model Trainers

Abstract base class and implementations for different STT models.

In [None]:
class BaseSTTTrainer(ABC):
    """Abstract base class for STT model trainers.

    Stores the model/dataset configs and loads the WER and CER metrics.
    Provides shared metric computation and model saving; subclasses implement
    loading, preprocessing, training, and evaluation.
    """

    def __init__(self, model_config: ModelConfig, dataset_config: DatasetConfig):
        self.model_config = model_config
        self.dataset_config = dataset_config
        # Set by load_model() in subclasses.
        self.model = None
        self.processor = None
        self.wer_metric = evaluate.load("wer")
        self.cer_metric = evaluate.load("cer")

    @abstractmethod
    def load_model(self):
        """Load the model and processor."""
        pass

    @abstractmethod
    def preprocess_dataset(self, dataset: DatasetDict) -> DatasetDict:
        """Preprocess dataset for this specific model."""
        pass

    @abstractmethod
    def train(self, dataset: DatasetDict) -> Dict[str, Any]:
        """Train the model and return results."""
        pass

    @abstractmethod
    def evaluate(self, dataset: Dataset) -> Dict[str, float]:
        """Evaluate the model on a dataset."""
        pass

    def compute_metrics(self, predictions: List[str], references: List[str]) -> Dict[str, float]:
        """Compute WER and CER after lower-casing and stripping both sides.

        Pairs whose normalized reference is empty are skipped. WER/CER divide
        by the reference length, and the metric backend rejects empty
        references instead of returning a value.

        Args:
            predictions: Model transcriptions.
            references: Ground-truth transcriptions, aligned with ``predictions``.

        Returns:
            ``{'wer': ..., 'cer': ...}``. Both values are NaN when no non-empty
            reference remains.

        Raises:
            ValueError: If the two lists differ in length.
        """
        if len(predictions) != len(references):
            raise ValueError(
                f"Got {len(predictions)} predictions but {len(references)} references"
            )

        # Normalize texts, then drop pairs with nothing to score against.
        pairs = [(pred.lower().strip(), ref.lower().strip()) for pred, ref in zip(predictions, references)]
        pairs = [(pred, ref) for pred, ref in pairs if ref]
        if not pairs:
            logger.warning("No non-empty references to score; returning NaN metrics")
            return {'wer': float('nan'), 'cer': float('nan')}

        kept_predictions = [pred for pred, _ in pairs]
        kept_references = [ref for _, ref in pairs]

        metrics = {}
        metrics['wer'] = self.wer_metric.compute(predictions=kept_predictions, references=kept_references)
        metrics['cer'] = self.cer_metric.compute(predictions=kept_predictions, references=kept_references)

        return metrics

    def save_model(self, path: str):
        """Save the fine-tuned model (and processor, if any) to ``path``.

        Raises:
            RuntimeError: If no model has been loaded yet.
        """
        if self.model is None:
            raise RuntimeError("No model loaded; call load_model() before save_model()")
        Path(path).mkdir(parents=True, exist_ok=True)
        self.model.save_pretrained(path)
        if self.processor:
            self.processor.save_pretrained(path)
        logger.info(f"Model saved to {path}")

In [None]:
class WhisperTrainer(BaseSTTTrainer):
    """Trainer for OpenAI Whisper models, built on HuggingFace ``Seq2SeqTrainer``."""

    def __init__(self, model_config: ModelConfig, dataset_config: DatasetConfig):
        super().__init__(model_config, dataset_config)
        # Built by load_model(): it needs the model's decoder_start_token_id.
        self.data_collator = None

    def load_model(self):
        """Load the Whisper processor and model, apply freezing, and build the collator.

        Returns:
            The loaded ``WhisperForConditionalGeneration`` (also stored on ``self.model``).
        """
        logger.info(f"Loading Whisper model: {self.model_config.model_name}")

        self.processor = WhisperProcessor.from_pretrained(self.model_config.model_name)
        self.model = WhisperForConditionalGeneration.from_pretrained(
            self.model_config.model_name
        )

        # Configure model: no forced decoder ids, no suppressed tokens, KV cache off.
        self.model.config.forced_decoder_ids = None
        self.model.config.suppress_tokens = []
        self.model.config.use_cache = False

        # Freeze encoder if specified (full freeze takes precedence over per-layer).
        if self.model_config.freeze_encoder:
            logger.info("Freezing encoder")
            for param in self.model.model.encoder.parameters():
                param.requires_grad = False
        elif self.model_config.freeze_encoder_layers > 0:
            logger.info(f"Freezing first {self.model_config.freeze_encoder_layers} encoder layers")
            for i, layer in enumerate(self.model.model.encoder.layers):
                if i < self.model_config.freeze_encoder_layers:
                    for param in layer.parameters():
                        param.requires_grad = False

        # Create data collator
        self.data_collator = self._create_data_collator()

        return self.model

    def _create_data_collator(self):
        """Create a collator that pads features and labels for Whisper."""

        @dataclass
        class DataCollatorSpeechSeq2SeqWithPadding:
            processor: Any
            decoder_start_token_id: int

            def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
                # Split inputs and labels: they need different padding.
                input_features = [{"input_features": feature["input_features"]} for feature in features]
                label_features = [{"input_ids": feature["labels"]} for feature in features]

                batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")
                labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")

                # Replace padding with -100 so those positions are ignored by the loss.
                labels = labels_batch["input_ids"].masked_fill(
                    labels_batch.attention_mask.ne(1), -100
                )

                # Remove a leading start token if every sequence has one; it is
                # added again when decoder inputs are derived from the labels.
                if (labels[:, 0] == self.decoder_start_token_id).all().cpu().item():
                    labels = labels[:, 1:]

                batch["labels"] = labels
                return batch

        return DataCollatorSpeechSeq2SeqWithPadding(
            processor=self.processor,
            decoder_start_token_id=self.model.config.decoder_start_token_id,
        )

    def preprocess_dataset(self, dataset: DatasetDict) -> DatasetDict:
        """Turn raw audio/text rows into ``input_features`` and ``labels`` for every split."""
        logger.info("Preprocessing dataset for Whisper")

        def prepare_dataset(example):
            audio = example[self.dataset_config.audio_column]

            # Compute log-Mel input features
            example["input_features"] = self.processor.feature_extractor(
                audio["array"],
                sampling_rate=audio["sampling_rate"]
            ).input_features[0]

            # Encode target text
            example["labels"] = self.processor.tokenizer(
                example[self.dataset_config.text_column]
            ).input_ids

            return example

        # Map each split with its own column list. Splits loaded from different
        # sources may not share the train split's schema.
        return DatasetDict({
            split_name: split.map(
                prepare_dataset,
                remove_columns=split.column_names,
                num_proc=1,  # Audio processing doesn't parallelize well
            )
            for split_name, split in dataset.items()
        })

    def train(self, dataset: DatasetDict) -> Dict[str, Any]:
        """Fine-tune Whisper, save the model and processor, and evaluate on ``test``.

        Returns:
            Dict with ``train_results``, ``test_results`` and ``model_path``.
        """
        logger.info("Starting Whisper training")

        # Preprocess dataset
        processed_dataset = self.preprocess_dataset(dataset)

        # Training arguments. `eval_strategy` is the current name; the old
        # `evaluation_strategy` keyword is rejected by the transformers releases
        # that accept `processing_class` (used below).
        training_args = Seq2SeqTrainingArguments(
            output_dir=self.model_config.output_dir,
            per_device_train_batch_size=self.model_config.batch_size,
            per_device_eval_batch_size=self.model_config.batch_size,
            gradient_accumulation_steps=self.model_config.gradient_accumulation_steps,
            learning_rate=self.model_config.learning_rate,
            warmup_steps=self.model_config.warmup_steps,
            num_train_epochs=self.model_config.num_epochs,
            weight_decay=self.model_config.weight_decay,
            fp16=self.model_config.fp16 and torch.cuda.is_available(),
            bf16=self.model_config.bf16 and torch.cuda.is_available(),
            eval_strategy="steps",
            eval_steps=self.model_config.eval_steps,
            save_steps=self.model_config.save_steps,
            logging_steps=self.model_config.logging_steps,
            save_total_limit=3,
            predict_with_generate=True,
            generation_max_length=225,
            load_best_model_at_end=True,
            metric_for_best_model="wer",
            greater_is_better=False,
            push_to_hub=False,
            report_to=["tensorboard"],
        )

        pad_token_id = self.processor.tokenizer.pad_token_id

        def compute_eval_metrics(pred):
            pred_ids = pred.predictions
            if isinstance(pred_ids, tuple):
                pred_ids = pred_ids[0]

            # Labels and generated ids can both contain -100: labels from the
            # collator, predictions from cross-batch padding. Map it to the pad
            # id (skipped when decoding) so the tokenizer never sees -100.
            pred_ids = np.where(pred_ids != -100, pred_ids, pad_token_id)
            label_ids = np.where(pred.label_ids != -100, pred.label_ids, pad_token_id)

            # Decode predictions and labels
            pred_str = self.processor.tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
            label_str = self.processor.tokenizer.batch_decode(label_ids, skip_special_tokens=True)

            return self.compute_metrics(pred_str, label_str)

        # Create trainer
        trainer = Seq2SeqTrainer(
            model=self.model,
            args=training_args,
            train_dataset=processed_dataset['train'],
            eval_dataset=processed_dataset['validation'],
            data_collator=self.data_collator,
            compute_metrics=compute_eval_metrics,
            processing_class=self.processor.feature_extractor,
        )

        # Train
        train_result = trainer.train()

        # Save model
        trainer.save_model()
        self.processor.save_pretrained(self.model_config.output_dir)

        # Evaluate on test set
        test_results = trainer.evaluate(processed_dataset['test'])

        return {
            'train_results': train_result,
            'test_results': test_results,
            'model_path': self.model_config.output_dir
        }

    def evaluate(self, dataset: Dataset) -> Dict[str, float]:
        """Transcribe each raw sample one at a time and return WER/CER.

        Inputs are moved to the model's own device and dtype, so this works
        whether or not the model has been through ``train()`` (which lets the
        Trainer place it on the accelerator).
        """
        self.model.eval()
        model_device = self.model.device
        model_dtype = self.model.dtype
        predictions = []
        references = []

        for sample in dataset:
            audio = sample[self.dataset_config.audio_column]
            input_features = self.processor.feature_extractor(
                audio["array"],
                sampling_rate=audio["sampling_rate"],
                return_tensors="pt"
            ).input_features.to(model_device, dtype=model_dtype)

            with torch.no_grad():
                predicted_ids = self.model.generate(input_features)

            transcription = self.processor.batch_decode(
                predicted_ids, skip_special_tokens=True
            )[0]

            predictions.append(transcription)
            references.append(sample[self.dataset_config.text_column])

        return self.compute_metrics(predictions, references)

In [None]:
class ParakeetTrainer(BaseSTTTrainer):
    """Trainer for NVIDIA Parakeet models using NeMo.

    Training uses a hand-written PyTorch loop (``_train_manual``) instead of a
    PyTorch Lightning ``Trainer``. The NeMo data loaders are configured from
    JSON-lines manifests that ``preprocess_dataset`` writes under
    ``<output_dir>/manifests``.
    """

    def __init__(self, model_config: ModelConfig, dataset_config: DatasetConfig):
        super().__init__(model_config, dataset_config)
        # Kept for interface compatibility; the manual loop never creates a Lightning trainer.
        self.nemo_trainer = None

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _move_batch(batch):
        """Return ``batch`` with its top-level tensors moved to ``device``.

        Lists/tuples come back as lists and dicts as dicts; any other batch
        object is returned unchanged.
        """
        if isinstance(batch, (list, tuple)):
            return [b.to(device) if isinstance(b, torch.Tensor) else b for b in batch]
        if isinstance(batch, dict):
            return {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}
        return batch

    @staticmethod
    def _extract_loss(output, preferred_keys):
        """Pull the loss out of a NeMo ``training_step``/``validation_step`` result.

        The return shape of these methods varies (commonly a dict such as
        ``{'loss': ...}`` or ``{'val_loss': ...}``). This accepts:

        - a list/tuple, whose first element is used;
        - a dict, where the first key in ``preferred_keys`` that is present wins,
          otherwise the first value;
        - anything else, which is returned as the loss itself.

        Raises:
            ValueError: If the step returned ``None`` or an empty container.
        """
        if isinstance(output, (list, tuple)):
            if not output:
                raise ValueError("step returned an empty sequence instead of a loss")
            output = output[0]
        if isinstance(output, dict):
            for key in preferred_keys:
                if key in output:
                    return output[key]
            if not output:
                raise ValueError("step returned an empty dict instead of a loss")
            return next(iter(output.values()))
        if output is None:
            raise ValueError("step returned None instead of a loss")
        return output

    @staticmethod
    def _transcription_text(output) -> str:
        """Reduce a single-input ``transcribe`` result to plain text.

        Nested lists/tuples are unwrapped by taking their first element. Some
        NeMo versions return a ``(best_hypotheses, all_hypotheses)`` tuple for
        RNNT models. The ``.text`` attribute is then used when present, as on
        NeMo ``Hypothesis`` objects, instead of the object's repr.
        """
        while isinstance(output, (list, tuple)):
            output = output[0] if output else ""
        text = getattr(output, 'text', output)
        return "" if text is None else str(text)

    def _enter_train_mode(self):
        """Put the model in training mode while keeping a frozen encoder in eval mode."""
        self.model.train()
        if self.model_config.freeze_encoder:
            # model.train() recursively re-enables dropout and norm-statistic updates,
            # including inside the encoder whose weights load_model froze.
            self.model.encoder.eval()

    # ------------------------------------------------------------------
    # BaseSTTTrainer interface
    # ------------------------------------------------------------------
    def load_model(self):
        """Load a Parakeet ASR model from a local ``.nemo`` file or by pretrained name.

        When ``model_config.freeze_encoder`` is set, the encoder is frozen.

        Returns:
            The loaded NeMo ``ASRModel``, which is also stored on ``self.model``.

        Raises:
            ImportError: If the NeMo toolkit is not installed.
        """
        logger.info(f"Loading Parakeet model: {self.model_config.model_name}")

        try:
            import nemo.collections.asr as nemo_asr
        except ImportError as err:
            raise ImportError("NeMo toolkit not installed. Run: pip install nemo_toolkit[asr]") from err

        if self.model_config.model_name.endswith('.nemo'):
            self.model = nemo_asr.models.ASRModel.restore_from(self.model_config.model_name)
        else:
            self.model = nemo_asr.models.ASRModel.from_pretrained(self.model_config.model_name)

        if self.model_config.freeze_encoder:
            logger.info("Freezing encoder")
            self.model.encoder.freeze()

        return self.model

    def _create_manifest(self, dataset: Dataset, manifest_path: str):
        """Write ``dataset`` as WAV files plus a NeMo JSON-lines manifest.

        Each line of the manifest holds ``audio_filepath``, ``text`` and
        ``duration`` (in seconds). Audio files go to
        ``<manifest dir>/audio/<manifest file stem>/audio_<idx>.wav``.

        Args:
            dataset: Rows with an audio dict (``array``, ``sampling_rate``)
                and a transcript, under the configured column names.
            manifest_path: Destination manifest file.

        Returns:
            ``manifest_path``, unchanged.
        """
        import soundfile as sf

        manifest_file = Path(manifest_path)
        # Every split numbers its files from audio_0.wav. Separate directories per
        # manifest stop later splits from overwriting audio referenced by earlier ones.
        audio_dir = manifest_file.parent / "audio" / manifest_file.stem
        audio_dir.mkdir(parents=True, exist_ok=True)

        audio_key = self.dataset_config.audio_column
        text_key = self.dataset_config.text_column

        count = 0
        with open(manifest_file, 'w', encoding='utf-8') as f:
            for idx, sample in enumerate(dataset):
                audio = sample[audio_key]
                audio_path = audio_dir / f"audio_{idx}.wav"
                sf.write(str(audio_path), audio['array'], audio['sampling_rate'])

                entry = {
                    "audio_filepath": str(audio_path),
                    "text": sample[text_key],
                    "duration": len(audio['array']) / audio['sampling_rate'],
                }
                f.write(json.dumps(entry) + '\n')
                count += 1

        logger.info(f"Created manifest with {count} entries at {manifest_path}")
        return manifest_path

    def preprocess_dataset(self, dataset: DatasetDict) -> Dict[str, str]:
        """Write one NeMo manifest per split and return ``{split: manifest_path}``.

        The splits are ``train``, ``validation`` and ``test``; all three must be
        present in ``dataset``.
        """
        logger.info("Creating NeMo manifest files")

        manifest_dir = Path(self.model_config.output_dir) / "manifests"
        manifest_dir.mkdir(parents=True, exist_ok=True)

        manifests = {}
        for split in ['train', 'validation', 'test']:
            manifest_path = str(manifest_dir / f"{split}_manifest.json")
            self._create_manifest(dataset[split], manifest_path)
            manifests[split] = manifest_path

        return manifests

    def train(self, dataset: DatasetDict) -> Dict[str, Any]:
        """Fine-tune the model with a manual loop, save it, and score the test split.

        The model's ``train_ds``/``validation_ds`` configs are pointed at freshly
        written manifests before the NeMo data loaders are built.

        Returns:
            ``{'model_path': <saved .nemo path>, 'test_results': <metrics dict>}``.
        """
        logger.info("Starting Parakeet training")

        try:
            from omegaconf import open_dict
        except ImportError as e:
            raise ImportError(f"Required packages not installed: {e}") from e

        manifests = self.preprocess_dataset(dataset)

        output_dir = Path(self.model_config.output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

        # open_dict allows assigning keys the loaded config may not declare.
        with open_dict(self.model.cfg):
            self.model.cfg.train_ds.manifest_filepath = manifests['train']
            self.model.cfg.train_ds.batch_size = self.model_config.batch_size

            self.model.cfg.validation_ds.manifest_filepath = manifests['validation']
            self.model.cfg.validation_ds.batch_size = self.model_config.batch_size

            self.model.cfg.optim.lr = self.model_config.learning_rate

        self.model.setup_training_data(self.model.cfg.train_ds)
        self.model.setup_validation_data(self.model.cfg.validation_ds)

        # Manual loop instead of trainer.fit() (PyTorch Lightning compatibility workaround).
        self._train_manual()

        # The final-epoch weights are saved here. The lowest-val-loss weights are
        # saved separately as best_checkpoint.nemo by _train_manual.
        model_path = output_dir / "parakeet_finetuned.nemo"
        self.model.save_to(str(model_path))
        logger.info(f"Model saved to {model_path}")

        test_results = self.evaluate(dataset['test'])

        return {
            'model_path': str(model_path),
            'test_results': test_results
        }

    def _train_manual(self):
        """Run ``num_epochs`` of AdamW training over ``self.model._train_dl``.

        A batch that raises is logged and skipped. If an entire epoch yields no
        successful batch, training stops with an error instead of reporting
        success for an untrained model. After each epoch the validation loss is
        computed, and the model is saved to ``best_checkpoint.nemo`` whenever it
        improves.

        Raises:
            ValueError: If the training data loader was not set up.
            RuntimeError: If no training batch in an epoch succeeded.
        """
        logger.info("Using manual training loop for NeMo model")

        self.model = self.model.to(device)
        self._enter_train_mode()

        # Only parameters that can receive gradients; a frozen encoder is excluded.
        trainable_params = [p for p in self.model.parameters() if p.requires_grad]
        optimizer = torch.optim.AdamW(
            trainable_params,
            lr=self.model_config.learning_rate,
            weight_decay=self.model_config.weight_decay,
            betas=(0.9, 0.98)
        )

        train_dl = self.model._train_dl
        if train_dl is None:
            raise ValueError("Training dataloader not set up. Call setup_training_data first.")

        use_amp = bool(self.model_config.fp16) and torch.cuda.is_available()
        scaler = torch.cuda.amp.GradScaler() if use_amp else None
        log_every = self.model_config.logging_steps
        checkpoint_path = Path(self.model_config.output_dir) / "best_checkpoint.nemo"

        total_steps = 0
        best_val_loss = float('inf')

        for epoch in range(self.model_config.num_epochs):
            logger.info(f"Epoch {epoch + 1}/{self.model_config.num_epochs}")

            epoch_loss = 0.0
            num_batches = 0
            failed_batches = 0
            self._enter_train_mode()

            for batch_idx, batch in enumerate(train_dl):
                batch = self._move_batch(batch)
                optimizer.zero_grad()

                try:
                    if scaler is not None:
                        # Forward under autocast; backward and step go through the scaler.
                        with torch.cuda.amp.autocast():
                            output = self.model.training_step(batch, batch_idx)
                            loss = self._extract_loss(output, ('loss', 'val_loss'))
                        scaler.scale(loss).backward()
                        scaler.step(optimizer)
                        scaler.update()
                    else:
                        output = self.model.training_step(batch, batch_idx)
                        loss = self._extract_loss(output, ('loss', 'val_loss'))
                        loss.backward()
                        optimizer.step()

                    loss_value = loss.item() if hasattr(loss, 'item') else float(loss)
                except Exception as e:
                    failed_batches += 1
                    logger.warning(f"Error in batch {batch_idx}: {e}")
                    continue

                epoch_loss += loss_value
                num_batches += 1
                total_steps += 1

                # A falsy logging_steps (0/None) disables per-step logging instead of
                # raising ZeroDivisionError.
                if log_every and batch_idx % log_every == 0:
                    logger.info(f"  Step {total_steps}, Batch {batch_idx}, Loss: {loss_value:.4f}")

            if num_batches == 0:
                raise RuntimeError(
                    f"Epoch {epoch + 1}: no training batch succeeded "
                    f"({failed_batches} failed); see the warnings above for the errors."
                )

            avg_loss = epoch_loss / num_batches
            logger.info(f"  Epoch {epoch + 1} - Avg Train Loss: {avg_loss:.4f}")

            val_loss = self._validate()
            logger.info(f"  Epoch {epoch + 1} - Val Loss: {val_loss:.4f}")

            if val_loss < best_val_loss:
                best_val_loss = val_loss
                self.model.save_to(str(checkpoint_path))
                logger.info(f"  New best model saved (val_loss: {val_loss:.4f})")

        logger.info(f"Training complete! Best validation loss: {best_val_loss:.4f}")

    def _validate(self) -> float:
        """Return the mean validation loss over ``self.model._validation_dl``.

        Returns ``inf`` when no validation loader is configured or when every
        batch failed, so a broken validation pass can never count as the best
        checkpoint. After a validation pass, training mode is restored.
        """
        val_dl = self.model._validation_dl
        if val_dl is None:
            return float('inf')

        self.model.eval()
        total_loss = 0.0
        num_batches = 0

        try:
            with torch.no_grad():
                for batch_idx, batch in enumerate(val_dl):
                    batch = self._move_batch(batch)
                    try:
                        output = self.model.validation_step(batch, batch_idx)
                        loss = self._extract_loss(output, ('val_loss', 'loss'))
                        total_loss += loss.item() if hasattr(loss, 'item') else float(loss)
                        num_batches += 1
                    except Exception as e:
                        logger.warning(f"Validation error in batch {batch_idx}: {e}")
        finally:
            self._enter_train_mode()

        if num_batches == 0:
            logger.warning("No validation batch succeeded; reporting val_loss=inf")
            return float('inf')
        return total_loss / num_batches

    def evaluate(self, dataset: Dataset) -> Dict[str, float]:
        """Transcribe ``dataset`` one sample at a time and compute metrics.

        Samples whose transcription raises are logged and excluded from both
        predictions and references.

        Returns:
            ``self.compute_metrics(predictions, references)``, or
            ``{'wer': 1.0, 'cer': 1.0}`` if no sample could be transcribed.
        """
        logger.info("Evaluating Parakeet model...")

        self.model.eval()
        self.model = self.model.to(device)

        audio_key = self.dataset_config.audio_column
        text_key = self.dataset_config.text_column
        predictions = []
        references = []

        for idx, sample in enumerate(dataset):
            audio = sample[audio_key]
            try:
                # NOTE(review): the raw array is passed without resampling, which
                # assumes the dataset's sampling rate matches the model's - confirm.
                output = self.model.transcribe([audio['array']], batch_size=1)
                text = self._transcription_text(output)
            except Exception as e:
                logger.warning(f"Failed to transcribe sample {idx}: {e}")
                continue

            predictions.append(text)
            references.append(sample[text_key])

        if not predictions:
            logger.warning("No successful predictions made")
            return {'wer': 1.0, 'cer': 1.0}

        return self.compute_metrics(predictions, references)

In [None]:
# Model factory for creating trainers
class ModelFactory:
    """Registry that maps a model-type name to its ``BaseSTTTrainer`` subclass.

    Model-type names are matched case-insensitively. Both ``register`` and
    ``create`` lower-case the name, so ``"Whisper"`` and ``"whisper"`` refer
    to the same entry.
    """

    # Built-in trainers; extended at runtime through ``register``.
    _trainers = {
        'whisper': WhisperTrainer,
        'parakeet': ParakeetTrainer,
    }

    @classmethod
    def register(cls, model_type: str, trainer_class: type):
        """Register ``trainer_class`` under ``model_type``, stored lower-cased.

        Registering an existing name replaces the previous trainer. The name is
        lower-cased because ``create`` lower-cases its lookup key. A mixed-case
        key would otherwise appear in ``available_models`` but could never be
        selected.
        """
        cls._trainers[model_type.lower()] = trainer_class

    @classmethod
    def create(cls, model_config: ModelConfig, dataset_config: DatasetConfig) -> BaseSTTTrainer:
        """Instantiate the trainer registered for ``model_config.model_type``.

        Raises:
            ValueError: If no trainer is registered for that model type.
        """
        model_type = model_config.model_type.lower()

        trainer_class = cls._trainers.get(model_type)
        if trainer_class is None:
            raise ValueError(
                f"Unknown model type: {model_type}. "
                f"Available types: {list(cls._trainers.keys())}"
            )

        return trainer_class(model_config, dataset_config)

    @classmethod
    def available_models(cls) -> List[str]:
        """Return the registered model-type names (lower-case)."""
        return list(cls._trainers.keys())

## 5. Training Pipeline

Orchestrates dataset loading and model training, running the configured models either sequentially or in parallel threads.

In [None]:
class TrainingPipeline:
    """Train every model in ``config.models`` on the datasets from ``config.dataset``.

    The datasets are loaded once and shared by all models. Models are trained
    sequentially, or concurrently in a thread pool when ``enable_parallel`` is
    set and more than one model is configured.
    """

    def __init__(self, config: TrainingConfig):
        self.config = config
        self.dataset_handler = DatasetHandler(config.dataset)
        self.results: Dict[str, Any] = {}
        # Datasets shared by every model; filled on first use by _get_datasets().
        self._datasets = None

        # Set random seeds
        self._set_seed(config.seed)

    def _set_seed(self, seed: int):
        """Seed Python, NumPy and torch (all CUDA devices too) for reproducibility."""
        import random
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)

    def _get_datasets(self):
        """Return the shared datasets, loading them through the handler on first use."""
        if self._datasets is None:
            self._datasets = self.dataset_handler.get_datasets()
        return self._datasets

    def train_single_model(self, model_config: ModelConfig) -> Dict[str, Any]:
        """Create, load and train one model, never letting an exception escape.

        Returns:
            The trainer's result dict plus ``status='success'`` and
            ``model_name``. On failure, returns ``{'status': 'failed',
            'model_name', 'error'}``.
        """
        logger.info(f"Training model: {model_config.model_name}")

        try:
            trainer = ModelFactory.create(model_config, self.config.dataset)
            trainer.load_model()

            results = trainer.train(self._get_datasets())
            results['status'] = 'success'
            results['model_name'] = model_config.model_name

            return results

        except Exception as e:
            # logger.exception also records the traceback, which str(e) alone loses.
            logger.exception(f"Error training {model_config.model_name}: {str(e)}")
            return {
                'status': 'failed',
                'model_name': model_config.model_name,
                'error': str(e)
            }

    def train_parallel(self) -> Dict[str, Any]:
        """Train all models concurrently; results are keyed by model name in config order."""
        logger.info(f"Starting parallel training of {len(self.config.models)} models")

        # Load the shared datasets before any worker thread starts, so threads never
        # race to load them.
        self._get_datasets()

        # Note: For GPU training, we use ThreadPoolExecutor since
        # ProcessPoolExecutor can have issues with CUDA contexts
        results = {}

        with ThreadPoolExecutor(max_workers=self.config.max_parallel_models) as executor:
            future_to_model = {
                executor.submit(self.train_single_model, model_config): model_config
                for model_config in self.config.models
            }

            for future in as_completed(future_to_model):
                name = future_to_model[future].model_name
                try:
                    results[name] = future.result()
                except Exception as e:
                    logger.error(f"Training failed for {name}: {e}")
                    results[name] = {
                        'status': 'failed',
                        'model_name': name,
                        'error': str(e)
                    }

        # as_completed yields in completion order; report in configuration order instead.
        return {m.model_name: results[m.model_name] for m in self.config.models}

    def train_sequential(self) -> Dict[str, Any]:
        """Train models one after another, freeing memory between them."""
        import gc

        logger.info(f"Starting sequential training of {len(self.config.models)} models")

        results = {}
        for model_config in self.config.models:
            results[model_config.model_name] = self.train_single_model(model_config)

            # Collect the finished trainer/model first so empty_cache can actually
            # return its GPU memory before the next model is loaded.
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        return results

    def run(self) -> Dict[str, Any]:
        """Load the datasets, train every configured model, print a summary, and return the results."""
        logger.info("Starting training pipeline")

        # Load datasets first (shared across all models)
        datasets = self._get_datasets()
        logger.info(f"Datasets loaded: {datasets}")

        if self.config.enable_parallel and len(self.config.models) > 1:
            self.results = self.train_parallel()
        else:
            self.results = self.train_sequential()

        self._print_summary()

        return self.results

    def _print_summary(self):
        """Print status, test WER/CER and model path (or the error) for each model."""
        print("\n" + "="*60)
        print("TRAINING SUMMARY")
        print("="*60)

        for model_name, result in self.results.items():
            print(f"\nModel: {model_name}")
            print(f"  Status: {result.get('status', 'unknown')}")

            if result.get('status') == 'success':
                test_results = result.get('test_results', {})
                if 'wer' in test_results or 'eval_wer' in test_results:
                    wer = test_results.get('wer', test_results.get('eval_wer', 'N/A'))
                    print(f"  Test WER: {wer:.4f}" if isinstance(wer, float) else f"  Test WER: {wer}")
                if 'cer' in test_results or 'eval_cer' in test_results:
                    cer = test_results.get('cer', test_results.get('eval_cer', 'N/A'))
                    print(f"  Test CER: {cer:.4f}" if isinstance(cer, float) else f"  Test CER: {cer}")
                print(f"  Model saved: {result.get('model_path', 'N/A')}")
            else:
                print(f"  Error: {result.get('error', 'Unknown error')}")

        print("\n" + "="*60)

## 6. Run Training

In [None]:
# Verify configuration: print a short summary of CONFIG before launching training.
summary_lines = [
    "Configuration Summary:",
    f"  Models to train: {len(CONFIG.models)}",
    *(f"    - {cfg.model_type}: {cfg.model_name}" for cfg in CONFIG.models),
    f"  Training datasets: {CONFIG.dataset.train_datasets}",
    f"  Parallel training: {CONFIG.enable_parallel}",
    f"  Output directory: {CONFIG.output_base_dir}",
]
print("\n".join(summary_lines))

In [None]:
# Build the pipeline from the global CONFIG and train every configured model.
# `results` (model name -> result dict) is consumed by the cells below.
pipeline = TrainingPipeline(CONFIG)
results = pipeline.run()

## 7. Evaluation & Results

In [None]:
# Display detailed results
import pandas as pd

def create_results_table(results: Dict[str, Any]) -> pd.DataFrame:
    """Tabulate pipeline results with one row per model.

    Successful models get ``WER``, ``CER`` (read from ``wer``/``cer``, falling
    back to ``eval_wer``/``eval_cer``) and ``Model Path`` columns. Any other
    status gets an ``Error`` column instead. Missing cells are left to pandas
    (NaN).
    """
    def _row(name: str, outcome: Dict[str, Any]) -> Dict[str, Any]:
        status = outcome.get('status', 'unknown')
        entry = {'Model': name, 'Status': status}

        if status != 'success':
            entry['Error'] = outcome.get('error', '')
            return entry

        metrics = outcome.get('test_results', {})
        entry['WER'] = metrics.get('wer', metrics.get('eval_wer', None))
        entry['CER'] = metrics.get('cer', metrics.get('eval_cer', None))
        entry['Model Path'] = outcome.get('model_path', '')
        return entry

    return pd.DataFrame([_row(name, outcome) for name, outcome in results.items()])

# Tabulate the pipeline results and render the table in the notebook.
display(results_df := create_results_table(results))

In [None]:
# Save results to JSON
results_path = Path(CONFIG.output_base_dir).joinpath("training_results.json")

# Convert results to JSON-serializable format
def make_serializable(obj):
    """Recursively convert ``obj`` into values that ``json.dump`` can encode.

    Conversions:
    - dict: values converted. Keys JSON cannot use (anything but
      str/int/float/bool/None, e.g. NumPy integers) go through ``str``.
    - named tuple (e.g. a Hugging Face ``TrainOutput``): dict keyed by field name.
    - list/tuple/set: list.
    - NumPy scalar: the matching Python scalar via ``.item()``, so ints stay
      ints and bools stay bools.
    - ``os.PathLike`` (e.g. ``pathlib.Path``): its path string.
    - object with a ``tolist()`` method (NumPy arrays, torch tensors): nested lists.
    - other object with ``__dict__``: its attribute dict.
    Anything else is returned unchanged.
    """
    import os

    if isinstance(obj, dict):
        return {
            (key if key is None or isinstance(key, (str, int, float, bool)) else str(key)):
                make_serializable(value)
            for key, value in obj.items()
        }
    # Named tuples must be checked before plain tuples, or their field names are lost.
    if isinstance(obj, tuple) and hasattr(obj, '_fields'):
        return {field: make_serializable(value) for field, value in zip(obj._fields, obj)}
    if isinstance(obj, (list, tuple, set, frozenset)):
        return [make_serializable(value) for value in obj]
    if isinstance(obj, np.generic):
        # .item() preserves the kind: np.int64 -> int (not float), np.bool_ -> bool.
        return obj.item()
    if isinstance(obj, os.PathLike):
        return os.fspath(obj)
    # Checked before __dict__: tensors carry a (usually empty) __dict__ that would
    # otherwise serialize as {}.
    if callable(getattr(obj, 'tolist', None)):
        return make_serializable(obj.tolist())
    if hasattr(obj, '__dict__'):
        return make_serializable(obj.__dict__)
    return obj

# Create the output directory in case nothing has created it yet, and serialize
# before touching the file so a serialization error cannot leave a truncated JSON.
# default=str is a last resort for values make_serializable leaves unconverted.
results_path.parent.mkdir(parents=True, exist_ok=True)
serialized_results = json.dumps(make_serializable(results), indent=2, default=str)
results_path.write_text(serialized_results, encoding="utf-8")

print(f"Results saved to: {results_path}")

## 8. Inference Example

In [None]:
def load_finetuned_model(model_path: str, model_type: str = "whisper"):
    """Load a fine-tuned model for inference.

    Args:
        model_path: A directory written by ``save_pretrained`` (Whisper) or a
            ``.nemo`` file (Parakeet).
        model_type: ``"whisper"`` or ``"parakeet"``. Matched case-insensitively,
            like ``ModelFactory.create``.

    Returns:
        ``(processor, model)``, with the model on ``device`` and in eval mode.
        ``processor`` is ``None`` for Parakeet.

    Raises:
        ValueError: For an unknown ``model_type``.
    """
    kind = model_type.lower()

    if kind == "whisper":
        processor = WhisperProcessor.from_pretrained(model_path)
        model = WhisperForConditionalGeneration.from_pretrained(model_path)
    elif kind == "parakeet":
        import nemo.collections.asr as nemo_asr
        processor = None
        model = nemo_asr.models.ASRModel.restore_from(model_path)
    else:
        raise ValueError(f"Unknown model type: {model_type}")

    # Both backends end up on the same device and in inference mode. The Parakeet
    # model was previously left wherever restore_from put it.
    model.to(device)
    model.eval()
    return processor, model


def normalize_transcription(output) -> str:
    """Reduce a NeMo ``transcribe`` result for a single input to plain text.

    Nested lists/tuples are unwrapped by taking their first element. Depending
    on the NeMo version, RNNT models may return a ``(best, all)`` tuple. The
    ``.text`` attribute is then used when present, as on NeMo ``Hypothesis``
    objects. An empty result yields ``""``.
    """
    while isinstance(output, (list, tuple)):
        output = output[0] if output else ""
    text = getattr(output, "text", output)
    return "" if text is None else str(text)


def transcribe_audio(audio_path: str, processor, model, model_type: str = "whisper") -> str:
    """Transcribe one audio file with a fine-tuned model.

    Args:
        audio_path: Path to the audio file.
        processor: Whisper processor (unused for Parakeet).
        model: Model returned by ``load_finetuned_model``.
        model_type: ``"whisper"`` or ``"parakeet"`` (case-insensitive).

    Returns:
        The transcription text.

    Raises:
        ValueError: For an unknown ``model_type``.
    """
    kind = model_type.lower()

    if kind == "whisper":
        waveform, sr = torchaudio.load(audio_path)
        # torchaudio returns (channels, frames). Averaging the channels yields one
        # mono signal; squeezing a stereo file would make a 2-item batch of which
        # only the first channel's transcription was kept.
        waveform = waveform.mean(dim=0)
        if sr != 16000:
            waveform = torchaudio.functional.resample(waveform, sr, 16000)

        input_features = processor.feature_extractor(
            waveform.numpy(),
            sampling_rate=16000,
            return_tensors="pt"
        ).input_features.to(device)

        with torch.no_grad():
            predicted_ids = model.generate(input_features)

        return processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]

    if kind == "parakeet":
        return normalize_transcription(model.transcribe([audio_path]))

    raise ValueError(f"Unknown model type: {model_type}")

In [None]:
# Example: Load fine-tuned model and transcribe
# Uncomment and modify the path to use

# model_path = "./outputs/whisper-small-finetuned"
# processor, model = load_finetuned_model(model_path, model_type="whisper")

# # Transcribe a test audio file
# audio_path = "path/to/your/audio.wav"
# transcription = transcribe_audio(audio_path, processor, model, model_type="whisper")
# print(f"Transcription: {transcription}")

## 9. Adding Custom Models

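To add support for a new STT model, create a class that inherits from `BaseSTTTrainer`, implement `load_model`, `preprocess_dataset`, `train`, and `evaluate`, and register it with `ModelFactory.register` so that `ModelConfig.model_type` can select it.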

In [None]:
# Example: Adding a custom model trainer

class CustomModelTrainer(BaseSTTTrainer):
    """Skeleton trainer to copy when adding a new STT backend.

    Every hook raises ``NotImplementedError`` until it is filled in.
    """

    def load_model(self):
        """Load the model (and processor, if any) named by ``self.model_config``.

        A typical implementation::

            self.model = YourModel.from_pretrained(self.model_config.model_name)
            self.processor = YourProcessor.from_pretrained(self.model_config.model_name)
        """
        raise NotImplementedError("Implement model loading for your custom model")

    def preprocess_dataset(self, dataset: DatasetDict) -> DatasetDict:
        """Convert the shared dataset splits into the inputs this model consumes."""
        raise NotImplementedError("Implement preprocessing for your custom model")

    def train(self, dataset: DatasetDict) -> Dict[str, Any]:
        """Fine-tune the model and return a result dict (e.g. test metrics, model path)."""
        raise NotImplementedError("Implement training for your custom model")

    def evaluate(self, dataset: Dataset) -> Dict[str, float]:
        """Score the model on ``dataset`` and return metric name -> value."""
        raise NotImplementedError("Implement evaluation for your custom model")


# Register the custom model
# ModelFactory.register('custom', CustomModelTrainer)

# Then use it in configuration:
# ModelConfig(
#     model_type="custom",
#     model_name="your/model-name",
#     ...
# )

## 10. Utility Functions

In [None]:
def compare_models(results: Dict[str, Any]) -> pd.DataFrame:
    """Build a WER-sorted table of test metrics for the models that trained successfully.

    Failed models are omitted. WER/CER are read from ``wer``/``cer`` and fall
    back to the ``eval_``-prefixed keys reported by Hugging Face ``evaluate``.

    Returns:
        A DataFrame with columns ``Model``, ``WER``, ``CER`` sorted by ascending
        ``WER``, or an empty DataFrame when no model succeeded.
    """
    def _metrics_row(name: str, metrics: Dict[str, Any]) -> Dict[str, Any]:
        return {
            'Model': name,
            'WER': metrics.get('wer', metrics.get('eval_wer')),
            'CER': metrics.get('cer', metrics.get('eval_cer')),
        }

    rows = [
        _metrics_row(name, outcome.get('test_results', {}))
        for name, outcome in results.items()
        if outcome.get('status') == 'success'
    ]
    table = pd.DataFrame(rows)
    return table if table.empty else table.sort_values('WER')


def export_config(config: TrainingConfig, path: str):
    """Write ``config`` to ``path`` as indented JSON.

    Nested dataclasses are expanded by ``asdict``. Values JSON cannot encode
    natively are written via ``str``. The parent directory is created if it
    does not exist, so the export also works before any training output
    exists.
    """
    target = Path(path)
    target.parent.mkdir(parents=True, exist_ok=True)
    # Serialize first so an encoding error cannot leave a truncated file behind.
    payload = json.dumps(asdict(config), indent=2, default=str)
    target.write_text(payload, encoding="utf-8")
    print(f"Configuration exported to: {path}")


def load_config(path: str) -> TrainingConfig:
    """Rebuild a ``TrainingConfig`` from a JSON file such as one written by ``export_config``.

    ``models`` entries become ``ModelConfig`` objects and ``dataset`` becomes a
    ``DatasetConfig``. All remaining top-level keys are passed to
    ``TrainingConfig`` as keyword arguments. Fields that ``export_config``
    stringified via ``default=str`` come back as plain strings.
    """
    with open(path, 'r') as handle:
        raw = json.load(handle)

    model_configs = [ModelConfig(**entry) for entry in raw.pop('models', [])]
    dataset_config = DatasetConfig(**raw.pop('dataset', {}))

    return TrainingConfig(models=model_configs, dataset=dataset_config, **raw)

In [None]:
# Compare trained models: show a WER-sorted table, but only when the pipeline
# produced results and at least one model succeeded.
if results:
    comparison_df = compare_models(results)
    has_successful_models = not comparison_df.empty
    if has_successful_models:
        print("Model Comparison (sorted by WER):")
        display(comparison_df)

In [None]:
# Export configuration for reproducibility, next to the training outputs.
config_export_path = Path(CONFIG.output_base_dir).joinpath("training_config.json")
export_config(CONFIG, str(config_export_path))

---

## Quick Reference

### Adding a New Dataset
```python
DATASET_CONFIG = DatasetConfig(
    train_datasets=[
        "mozilla-foundation/common_voice_11_0",  # HuggingFace dataset
        "/path/to/local/data",                   # Local directory
        "/path/to/manifest.json",                # JSON manifest
    ],
    audio_column="audio",
    text_column="sentence",
    dataset_config_name="en",  # Language code for Common Voice, None for others
    dataset_split="train",
)
```

### Adding a New Model
```python
MODELS_TO_TRAIN.append(
    ModelConfig(
        model_type="whisper",  # or "parakeet", or custom
        model_name="openai/whisper-large-v3",
        learning_rate=1e-5,
        batch_size=4,
        num_epochs=5,
    )
)
```

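### Enabling Parallel Training

Parallel training runs the models in threads that share the same device, so make sure it has enough memory for `max_parallel_models` models at once.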
```python
CONFIG = TrainingConfig(
    ...
    enable_parallel=True,
    max_parallel_models=2,
)
```

### Using Separate Validation/Test Sets
```python
DATASET_CONFIG = DatasetConfig(
    train_datasets=["dataset/train"],
    validation_dataset="dataset/validation",
    test_dataset="dataset/test",
    dataset_config_name=None,  # Set if needed
)
```

### Common Dataset Examples
```python
# Common Voice (requires language config)
DATASET_CONFIG = DatasetConfig(
    train_datasets=["mozilla-foundation/common_voice_11_0"],
    dataset_config_name="en",  # "es", "fr", "de", etc.
    dataset_split="train",
)

# LibriSpeech (no config needed)
DATASET_CONFIG = DatasetConfig(
    train_datasets=["librispeech_asr"],
    dataset_config_name=None,
    dataset_split="train.clean.100",
    audio_column="audio",
    text_column="text",
)

# Local audiofolder
DATASET_CONFIG = DatasetConfig(
    train_datasets=["/path/to/audio/folder"],
    dataset_config_name=None,
    audio_column="audio",
    text_column="transcription",
)
```