In [None]:
!pip install kagglehub==0.3.5 speechbrain==1.0.2 torchaudio==2.5.1 soundfile==0.12.1 torch==2.5.1 tqdm==4.67.1 scikit-learn==1.6.0 transformers==4.47.1 datasets==3.2.0 jiwer==3.0.5 hugginface_hub==0.27.1 optuna optimum[onnxruntime] safetensors logger

Vocab Generator

In [None]:
from datasets import DatasetDict, load_dataset
from typing import Dict
import json


def generate_vocab(dataset_dict: DatasetDict) -> Dict:
    """Build a character-level CTC vocabulary from the dataset transcriptions.

    Every transcription in the ``train``, ``validation`` and ``test`` splits
    (those that are present) is lower-cased and its characters are collected.
    The result maps each token to a contiguous integer id, with the special
    tokens ``<pad>``, ``<unk>`` and ``|`` occupying ids 0, 1 and 2.

    Args:
        dataset_dict: Mapping from split name to a dataset exposing a
            ``"transcription"`` column of strings.

    Returns:
        Dict mapping token -> id, ids running from 0 to ``len(vocab) - 1``.
    """
    special_tokens = ["<pad>", "<unk>", "|"]
    chars = set()

    for split in ("train", "validation", "test"):
        if split not in dataset_dict:
            continue
        # Column access reads only the text; iterating whole rows would also
        # materialise every other column (the audio, in this notebook).
        for text in dataset_dict[split]["transcription"]:
            chars.update(text.lower())  # Normalize to lowercase

    # The tokenizer is configured with "|" as word delimiter, so a literal
    # space entry would be an output class that no label ever uses.
    chars.discard(" ")
    # A special token that also occurs in the text must not be listed twice:
    # the dict below would keep only the later index and leave a hole in the ids.
    chars.difference_update(special_tokens)

    vocab = special_tokens + sorted(chars)
    return {char: idx for idx, char in enumerate(vocab)}

# Build the character vocabulary from the STT dataset and persist it as vocab.json
dataset = load_dataset("StefanStefan/STT")
vocab = generate_vocab(dataset)
with open("vocab.json", "w") as f:
    # Serialise first, then write: same file content as json.dump(vocab, f)
    f.write(json.dumps(vocab))

Fine-Tuning Script

In [None]:
import os
import json
from dataclasses import dataclass
from typing import Dict, List, Union, Optional

import torch
import torchaudio
import numpy as np
from datasets import Dataset, DatasetDict, Features, Sequence, Value, load_dataset
from transformers import (
    Wav2Vec2CTCTokenizer,
    Wav2Vec2FeatureExtractor,
    Wav2Vec2Processor,
    Wav2Vec2ForCTC,
    TrainingArguments,
    Trainer,
)
import jiwer
import logging
from abc import ABC, abstractmethod


class ConfigManager:
    """Manages configuration settings for the speech recognition pipeline.

    The configuration is a two-level dict with ``data``, ``processor`` and
    ``training`` sections. Defaults are built in; an optional JSON file can
    override individual values.
    """

    def __init__(self, config_path: Optional[str] = None):
        """Load the defaults and, if ``config_path`` is given, overlay that file."""
        self.config = self._load_default_config()
        if config_path:
            self.update_config(config_path)

    def _load_default_config(self) -> Dict:
        """Return a fresh copy of the built-in default configuration."""
        return {
            "data": {
                "dataset_name": "StefanStefan/STT",
                "audio_dir": "1/output"
            },
            "processor": {
                "pretrained_model_name_or_path": "facebook/wav2vec2-base-100h"
            },
            "training": {
                "output_dir": "./wav2vec2-finetuned",
                "group_by_length": False,
                "lr_scheduler_type": "cosine_with_restarts",
                "length_column_name": "length",
                "per_device_train_batch_size": 4,
                "per_device_eval_batch_size": 8,
                "gradient_accumulation_steps": 5,
                "load_best_model_at_end": True,
                "metric_for_best_model": "wer",
                "greater_is_better": False,
                "eval_strategy": "steps",
                "num_train_epochs": 30,
                "fp16": True,
                "save_steps": 500,
                "eval_steps": 500,
                "logging_steps": 100,
                "learning_rate": 4.48e-5,
                "weight_decay": 0.01008,
                "warmup_ratio": 0.1167,
                "max_grad_norm": 0.3097,
                "save_total_limit": 5,
                "save_strategy": "steps",
                "report_to": "none",
                "gradient_checkpointing": True,
                "fp16_full_eval": True,
                "dataloader_num_workers": 4,
                "prediction_loss_only": False,
                "optim": "adamw_torch",
            }
        }

    @staticmethod
    def _deep_merge(base: Dict, overrides: Dict) -> Dict:
        """Return ``base`` overlaid with ``overrides``, merging nested dicts key by key.

        Neither input is modified. A non-dict override replaces the base value.
        """
        merged = dict(base)
        for key, value in overrides.items():
            if isinstance(value, dict) and isinstance(merged.get(key), dict):
                merged[key] = ConfigManager._deep_merge(merged[key], value)
            else:
                merged[key] = value
        return merged

    def update_config(self, config_path: str) -> None:
        """Update configuration with values from a JSON file.

        Sections are merged recursively, so a file containing only
        ``{"training": {"learning_rate": 1e-4}}`` changes that single value and
        keeps every other training default (a shallow merge would have dropped
        the rest of the section).

        Raises:
            ValueError: If the file's top-level JSON value is not an object.
        """
        with open(config_path, 'r') as f:
            custom_config = json.load(f)
        if not isinstance(custom_config, dict):
            raise ValueError(f"Config file {config_path} must contain a JSON object")
        self.config = self._deep_merge(self.config, custom_config)

    def get_config(self) -> Dict:
        """Get the current configuration (the live dict, not a copy)."""
        return self.config


class LoggerSetup:
    """Configures logging output and hands back a logger for this module."""

    @staticmethod
    def setup(level: int = logging.INFO) -> logging.Logger:
        """Apply the shared log format at ``level`` and return the module logger.

        Args:
            level: Logging level passed to ``logging.basicConfig``.

        Returns:
            The logger named after this module (``__name__``).
        """
        log_format = "%(asctime)s - %(levelname)s - %(name)s - %(message)s"
        logging.basicConfig(level=level, format=log_format)
        module_logger = logging.getLogger(__name__)
        return module_logger


class DataProcessor(ABC):
    """Interface implemented by every data-processing step."""

    @abstractmethod
    def process(self, *args, **kwargs):
        """Transform the given input; the base implementation does nothing."""
        return None


class AudioProcessor(DataProcessor):
    """Handles audio data processing: returns a waveform at the processor's sampling rate."""

    def __init__(self, processor: Wav2Vec2Processor):
        """Store the processor whose feature extractor defines the target sampling rate."""
        self.processor = processor
        # Resample transforms keyed by source rate, so each kernel is built
        # once per rate instead of once per clip.
        self._resamplers = {}

    def process(self, audio_dict: dict) -> np.ndarray:
        """Return the clip's samples, resampled if its rate differs from the target.

        Args:
            audio_dict: Dict with ``"array"`` (the samples) and
                ``"sampling_rate"`` (their rate).

        Returns:
            The original array when the rates already match; otherwise a
            float32 numpy array at the feature extractor's sampling rate.
        """
        # Extract audio array and sampling rate from the dataset's audio dict
        speech = audio_dict["array"]
        sampling_rate = audio_dict["sampling_rate"]
        target_rate = self.processor.feature_extractor.sampling_rate

        # Resample if needed (using the processor's expected sampling rate)
        if sampling_rate != target_rate:
            resampler = self._resamplers.get(sampling_rate)
            if resampler is None:
                resampler = torchaudio.transforms.Resample(sampling_rate, target_rate)
                self._resamplers[sampling_rate] = resampler
            # Resample caches its kernel as float32 by default, so the input is
            # cast to match; a float64 array would otherwise hit a dtype mismatch.
            waveform = torch.as_tensor(speech, dtype=torch.float32)
            speech = resampler(waveform).squeeze().numpy()

        return speech


@dataclass
class DataCollatorCTCWithPadding:
    """Collates CTC training examples, padding audio inputs and labels separately.

    ``processor`` supplies the feature extractor (pads ``input_values``) and
    the tokenizer (pads label ids); ``padding`` is forwarded to both.
    """

    processor: Wav2Vec2Processor
    padding: Union[bool, str] = True

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        """Pad a list of ``{"input_values", "labels"}`` examples into one batch.

        Label positions the tokenizer marks as padding (attention mask != 1)
        are replaced by -100 before being stored under ``batch["labels"]``.
        """
        audio_items = []
        target_items = []
        for example in features:
            audio_items.append({"input_values": example["input_values"]})
            target_items.append({"input_ids": example["labels"]})

        batch = self.processor.feature_extractor.pad(
            audio_items, padding=self.padding, return_tensors="pt"
        )
        padded_targets = self.processor.tokenizer.pad(
            target_items, padding=self.padding, return_tensors="pt"
        )

        is_padding = padded_targets["attention_mask"].ne(1)
        batch["labels"] = padded_targets["input_ids"].masked_fill(is_padding, -100)
        return batch


class DatasetManager:
    """Loads the speech dataset and converts raw rows into model-ready features."""

    def __init__(self, config: Dict, processor: Wav2Vec2Processor):
        """Keep the config and processor, and wrap the processor in an AudioProcessor."""
        self.config = config
        self.processor = processor
        self.audio_processor = AudioProcessor(processor)

    def load_dataset_from_hf(self, dataset_name: str) -> DatasetDict:
        """Load the named dataset with ``load_dataset`` and return it unchanged."""
        return load_dataset(dataset_name)

    def preprocess_function(self, batch: Dict) -> Dict:
        """Turn one example into ``input_values`` and ``labels``.

        The transcription is lower-cased (and written back into ``batch``),
        the audio goes through ``self.audio_processor``, and the processor is
        called once for the audio and once for the text.
        """
        transcript = batch["transcription"].lower()
        batch["transcription"] = transcript

        waveform = self.audio_processor.process(batch["audio"])

        encoded_audio = self.processor(
            waveform,
            sampling_rate=self.processor.feature_extractor.sampling_rate,
            return_attention_mask=False,
        )
        input_values = encoded_audio.input_values[0]

        encoded_text = self.processor(text=transcript, return_attention_mask=False)

        return {"input_values": input_values, "labels": encoded_text.input_ids}

    def create_dataset_dict(self) -> DatasetDict:
        """Load the configured dataset and keep its train/validation/test splits."""
        source = self.load_dataset_from_hf(self.config["data"]["dataset_name"])

        # The dataset is expected to provide exactly these three splits
        wanted_splits = ("train", "validation", "test")
        return DatasetDict({split: source[split] for split in wanted_splits})


class ModelManager:
    """Manages model operations: building the CTC model and scoring its predictions."""

    def __init__(self, config: Dict, processor: Wav2Vec2Processor):
        """Keep the config (for the checkpoint name) and the processor (for the vocabulary)."""
        self.config = config
        self.processor = processor

    def initialize_model(self) -> Wav2Vec2ForCTC:
        """Initialize and configure the model.

        Loads the configured pretrained checkpoint with an output layer sized
        to the tokenizer's vocabulary (``ignore_mismatched_sizes=True`` allows
        the size to differ from the checkpoint's), then moves the model to
        CUDA when available, otherwise CPU.
        """
        model = Wav2Vec2ForCTC.from_pretrained(
            self.config["processor"]["pretrained_model_name_or_path"],
            vocab_size=len(self.processor.tokenizer),
            pad_token_id=self.processor.tokenizer.pad_token_id,
            ignore_mismatched_sizes=True
        )

        model.config.label_pad_token_id = -100
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model.to(device)
        return model

    @staticmethod
    def compute_metrics(pred, processor: Wav2Vec2Processor) -> Dict[str, float]:
        """Compute word and character error rates for a batch of predictions.

        Args:
            pred: Object with ``predictions`` (logits; argmax is taken over the
                last axis) and ``label_ids`` (label ids with -100 at ignored
                positions).
            processor: Processor used to decode ids back to text.

        Returns:
            Dict with ``"wer"`` and ``"cer"``.
        """
        pred_ids = np.argmax(pred.predictions, axis=-1)
        pred_str = processor.batch_decode(pred_ids)

        # np.where builds a new array, so the caller's label_ids keep their
        # -100 markers instead of being overwritten in place.
        label_ids = np.where(
            pred.label_ids == -100,
            processor.tokenizer.pad_token_id,
            pred.label_ids,
        )
        label_str = processor.batch_decode(label_ids, group_tokens=False)

        return {
            "wer": jiwer.wer(label_str, pred_str),
            "cer": jiwer.cer(label_str, pred_str)
        }


class ProcessorManager:
    """Builds the Wav2Vec2 processor from the local vocabulary and a pretrained feature extractor."""

    def __init__(self, config: Dict):
        """
        Store the configuration used when the processor is built.

        Args:
            config (Dict): Configuration dictionary containing processor settings
        """
        self.config = config

    def initialize_processor(self) -> Wav2Vec2Processor:
        """Create a processor whose tokenizer reads ``vocab.json`` from the working directory."""
        # Tokenizer over the custom character vocabulary
        ctc_tokenizer = Wav2Vec2CTCTokenizer(
            "vocab.json",
            unk_token="<unk>",
            pad_token="<pad>",
            word_delimiter_token="|",
        )

        # Feature extractor comes from the configured pretrained checkpoint
        checkpoint = self.config["processor"]["pretrained_model_name_or_path"]
        extractor = Wav2Vec2FeatureExtractor.from_pretrained(checkpoint)

        return Wav2Vec2Processor(feature_extractor=extractor, tokenizer=ctc_tokenizer)


class TrainingPipeline:
    """End-to-end fine-tuning driver: config, vocabulary, processor, datasets, trainer.

    Constructing the pipeline has side effects: it configures logging and, when
    ``vocab.json`` is missing from the working directory, loads the dataset
    to generate and write that file.
    """

    def __init__(self, config_path: Optional[str] = None):
        """Set up configuration, vocabulary, processor and the dataset/model managers.

        Args:
            config_path: Optional JSON file overriding the default configuration.
        """
        self.logger = LoggerSetup.setup()
        self.config_manager = ConfigManager(config_path)
        self.config = self.config_manager.get_config()

        # The dataset is only needed here to build the vocabulary, so skip
        # loading it when a vocab.json is already present.
        if not os.path.exists("vocab.json"):
            self.logger.info("Generating vocabulary...")
            vocab = generate_vocab(load_dataset(self.config["data"]["dataset_name"]))
            with open("vocab.json", "w") as f:
                json.dump(vocab, f)

        # Now initialize processor with custom vocab
        self.processor_manager = ProcessorManager(self.config)
        self.processor = self.processor_manager.initialize_processor()

        self.dataset_manager = DatasetManager(self.config, self.processor)
        self.model_manager = ModelManager(self.config, self.processor)

    def prepare_datasets(self) -> DatasetDict:
        """Prepare and preprocess datasets.

        Maps ``DatasetManager.preprocess_function`` over every split (one
        example at a time, 4 worker processes), drops the original columns and
        casts the result to float32 ``input_values`` / int64 ``labels``.
        """
        self.logger.info("Loading and preprocessing datasets...")
        dataset_dict = self.dataset_manager.create_dataset_dict()

        # Get all original column names to remove
        original_columns = dataset_dict["train"].column_names

        processed_dataset = dataset_dict.map(
            self.dataset_manager.preprocess_function,
            remove_columns=original_columns,
            batched=False,
            num_proc=4,
            desc="Preprocessing the dataset"
        )

        features = Features({
            'input_values': Sequence(Value('float32')),
            'labels': Sequence(Value('int64')),
        })

        return processed_dataset.cast(features)

    def initialize_trainer(self, processed_dataset: DatasetDict) -> Trainer:
        """Build the Trainer from the ``training`` config section.

        Trains on the ``train`` split, evaluates on ``validation`` and reports
        WER/CER through ``ModelManager.compute_metrics``.
        """
        training_args = TrainingArguments(**self.config["training"])

        model = self.model_manager.initialize_model()
        data_collator = DataCollatorCTCWithPadding(processor=self.processor)

        return Trainer(
            model=model,
            args=training_args,
            data_collator=data_collator,
            train_dataset=processed_dataset["train"],
            eval_dataset=processed_dataset["validation"],
            compute_metrics=lambda pred: self.model_manager.compute_metrics(pred, self.processor)
        )

    def run(self):
        """Execute the complete training pipeline.

        Preprocesses the data, trains, evaluates on the test split and saves
        the model and processor to ``my-wav2vec2-finetuned``. Any exception is
        logged with its traceback and re-raised.
        """
        try:
            processed_dataset = self.prepare_datasets()
            trainer = self.initialize_trainer(processed_dataset)

            self.logger.info("Starting training...")
            trainer.train()

            self.logger.info("Evaluating on test set...")
            # Prefix the metrics with "test" so they are not reported as
            # eval_* and confused with the validation numbers.
            test_results = trainer.evaluate(
                eval_dataset=processed_dataset["test"],
                metric_key_prefix="test",
            )
            self.logger.info(f"Test Results: {test_results}")

            save_path = "my-wav2vec2-finetuned"
            trainer.save_model(save_path)
            self.processor.save_pretrained(save_path)
            self.logger.info(f"Model and processor saved to {save_path}")

        except Exception as e:
            # logger.exception keeps the traceback in the log as well
            self.logger.exception("Pipeline failed: %s", e)
            raise


def main():
    """Build a TrainingPipeline with the default configuration and run it."""
    TrainingPipeline().run()


Quantization

In [None]:
import os
import torch
import numpy as np
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
from datasets import load_dataset
import torchaudio
import jiwer
import json
import time
from pathlib import Path


def evaluate_transcription(model, processor, dataset, num_samples=10, device="cpu"):
    """Transcribe a sample of ``dataset`` with ``model`` and report WER and timing.

    Args:
        model: Model moved to ``device`` and put in eval mode; it is called as
            ``model(input_values=...)`` and must return an object with ``.logits``.
        processor: Processor providing ``feature_extractor.sampling_rate``,
            audio preprocessing and ``batch_decode``.
        dataset: Indexable dataset whose rows have ``"audio"`` (with ``"array"``
            and ``"sampling_rate"``) and ``"transcription"``. ``select`` is only
            needed when subsampling.
        num_samples: Number of evenly spaced rows to evaluate; a falsy value or
            one not smaller than ``len(dataset)`` evaluates everything.
        device: Device for the model and inputs.

    Returns:
        Dict with ``"wer"``, ``"inference_time"`` (mean wall-clock seconds per
        sample over the whole loop, including audio loading and preprocessing),
        ``"predictions"`` and ``"references"`` (lower-cased transcriptions).

    Raises:
        ValueError: If there are no samples to evaluate.
    """
    model.to(device)
    model.eval()

    # Select samples for evaluation
    if num_samples and num_samples < len(dataset):
        indices = np.linspace(0, len(dataset) - 1, num_samples, dtype=int)
        test_dataset = dataset.select([int(idx) for idx in indices])  # plain Python ints
    else:
        test_dataset = dataset

    # Fail early with a clear message instead of dividing by zero at the end
    if len(test_dataset) == 0:
        raise ValueError("Cannot evaluate transcription on an empty dataset")

    # Inputs are cast to the model's parameter dtype (e.g. float16 after .half())
    model_dtype = next(model.parameters()).dtype
    target_rate = processor.feature_extractor.sampling_rate
    resamplers = {}  # one Resample transform per source rate, built on first use

    references = []
    predictions = []

    start_time = time.time()

    for idx in range(len(test_dataset)):
        # Process audio
        sample = test_dataset[idx]
        audio = sample["audio"]
        speech = audio["array"]
        sampling_rate = audio["sampling_rate"]

        # Resample if needed
        if sampling_rate != target_rate:
            if sampling_rate not in resamplers:
                resamplers[sampling_rate] = torchaudio.transforms.Resample(sampling_rate, target_rate)
            # Resample caches a float32 kernel by default; cast the input to match
            waveform = torch.as_tensor(speech, dtype=torch.float32)
            speech = resamplers[sampling_rate](waveform).squeeze().numpy()

        # Preprocess
        inputs = processor(speech, sampling_rate=target_rate, return_tensors="pt")

        # Work on the tensor itself rather than assigning an attribute on the
        # processor output
        input_values = inputs.input_values.to(device=device, dtype=model_dtype)

        # Run inference
        with torch.no_grad():
            outputs = model(input_values=input_values)

        # Decode
        predicted_ids = torch.argmax(outputs.logits, dim=-1)
        transcription = processor.batch_decode(predicted_ids)[0]

        # Store results
        predictions.append(transcription)
        references.append(sample["transcription"].lower())

    # Calculate metrics
    wer = jiwer.wer(references, predictions)
    elapsed_time = time.time() - start_time

    return {
        "wer": wer,
        "inference_time": elapsed_time / len(test_dataset),
        "predictions": predictions,
        "references": references
    }


def get_model_size_mb(model):
    """Get the size of a PyTorch model in MB.

    The model's ``state_dict`` is saved to a file inside a private temporary
    directory and the file size is measured. The directory is removed even if
    saving fails, and nothing is written to the working directory.

    Args:
        model: Object exposing ``state_dict()``.

    Returns:
        Size of the serialized state dict in MiB (bytes / 1024**2).
    """
    import tempfile  # local import: this cell's import block does not bring tempfile in

    with tempfile.TemporaryDirectory() as tmp_dir:
        tmp_path = os.path.join(tmp_dir, "model.pt")
        torch.save(model.state_dict(), tmp_path)
        return os.path.getsize(tmp_path) / (1024 * 1024)


def apply_fp16_quantization(model, device='cpu'):
    """
    Convert a PyTorch model to half precision (FP16) and move it to ``device``.

    The conversion is applied to the passed-in model via ``model.half()``;
    no copy is made here. If any step raises, the error is printed and the
    model object that was passed in is returned.

    Args:
        model: PyTorch model
        device: Device to use

    Returns:
        The FP16 model on success, otherwise the model as received
    """
    print(f"Applying FP16 quantization to the model...")
    try:
        # Inference mode first, then cast and place on the target device
        model.eval()
        half_model = model.half()
        half_model = half_model.to(device)
    except Exception as err:
        print(f"FP16 quantization failed: {err}")
        return model

    print("Successfully converted model to FP16")
    return half_model


def _relative_change_pct(new_value, old_value):
    """Percent change from ``old_value`` to ``new_value``, defined for a zero baseline.

    With a zero baseline there is nothing to scale by: equal values count as
    no change (0.0) and any other value as an unbounded change (+/- inf).
    """
    if old_value == 0:
        if new_value == old_value:
            return 0.0
        return float("inf") if new_value > old_value else float("-inf")
    return 100 * (new_value - old_value) / old_value


def quantize_model(model_id, dataset_name, output_dir, num_test_samples=5):
    """
    Quantize a Hugging Face model to FP16 and evaluate it.

    Evaluates the original model on the dataset's ``test`` split, converts it
    to FP16, evaluates again, saves the converted model and processor to
    ``output_dir`` and writes a JSON summary plus a markdown report to
    ``output_dir/evaluation_results``.

    Args:
        model_id: ID of the Hugging Face model
        dataset_name: ID of the dataset for evaluation
        output_dir: Directory to save the results
        num_test_samples: Number of samples to use for evaluation

    Returns:
        Dict containing information about the quantized model. Percentage
        changes are relative to the original model; when the original value is
        0 they are 0.0 (no change) or +/- infinity, which ``json.dump`` writes
        as ``Infinity``.
    """
    # Create output directory
    os.makedirs(output_dir, exist_ok=True)

    # Create a separate directory for evaluation results
    eval_dir = os.path.join(output_dir, "evaluation_results")
    os.makedirs(eval_dir, exist_ok=True)

    # Load dataset
    print(f"Loading dataset: {dataset_name}")
    dataset = load_dataset(dataset_name)
    test_dataset = dataset["test"]

    # Load model and processor
    print(f"Loading model: {model_id}")
    model = Wav2Vec2ForCTC.from_pretrained(model_id)
    processor = Wav2Vec2Processor.from_pretrained(model_id)

    # Determine device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Step 1: Original Model Evaluation
    print("Evaluating original model...")
    original_results = evaluate_transcription(
        model, processor, test_dataset, num_test_samples, device
    )

    # Get original model size (must happen before the FP16 conversion below)
    original_size_mb = get_model_size_mb(model)

    # Step 2: Apply FP16 quantization
    print("\nApplying FP16 quantization...")
    quantized_model = apply_fp16_quantization(model, device)

    # apply_fp16_quantization returns the model it was given when conversion
    # fails, so record the dtype that is actually in effect.
    quantized_dtype = str(next(quantized_model.parameters()).dtype).split(".")[-1]
    if quantized_dtype != "float16":
        print(f"Warning: model parameters are {quantized_dtype}, not float16; results below describe the unconverted model")

    # Get quantized model size
    quantized_size_mb = get_model_size_mb(quantized_model)

    # Step 3: Evaluate quantized model
    print("\nEvaluating FP16 quantized model...")
    quantized_results = evaluate_transcription(
        quantized_model, processor, test_dataset, num_test_samples, device
    )

    # Save quantized model and processor to the output directory
    print(f"Saving FP16 quantized model and processor to {output_dir}...")
    quantized_model.save_pretrained(output_dir)
    processor.save_pretrained(output_dir)
    print(f"Successfully saved model and processor to {output_dir}")

    # Calculate WER change (safe when the original WER is exactly 0)
    wer_change_pct = _relative_change_pct(quantized_results["wer"], original_results["wer"])

    # Calculate size reduction
    size_reduction_pct = 100 * (original_size_mb - quantized_size_mb) / original_size_mb

    # Create model info dictionary
    model_info = {
        "original": {
            "wer": original_results["wer"],
            "inference_time": original_results["inference_time"],
            "size_mb": original_size_mb
        },
        "quantized": {
            "path": output_dir,
            "size_mb": quantized_size_mb,
            "wer": quantized_results["wer"],
            "wer_change_pct": wer_change_pct,
            "inference_time": quantized_results["inference_time"],
            "quantization_type": quantized_dtype
        }
    }

    # Save results in the evaluation directory
    with open(os.path.join(eval_dir, "quantization_results.json"), "w") as f:
        json.dump(model_info, f, indent=2)

    # Generate markdown report
    report_path = os.path.join(eval_dir, "quantization_report.md")
    with open(report_path, "w") as f:
        f.write("# Wav2Vec2 Model FP16 Quantization Report\n\n")

        f.write("## Model Information\n\n")
        f.write(f"- Original model: {model_id}\n")
        f.write(f"- Dataset: {dataset_name}\n")
        f.write(f"- Quantization: FP16\n\n")

        f.write("## Quantization Results\n\n")
        f.write("| Format | WER | Inference Time (s) | Size (MB) |\n")
        f.write("|--------|-----|---------------------|----------|\n")
        f.write(f"| Original | {original_results['wer']:.4f} | {original_results['inference_time']:.4f} | {original_size_mb:.2f} |\n")
        f.write(f"| FP16 | {quantized_results['wer']:.4f} | {quantized_results['inference_time']:.4f} | {quantized_size_mb:.2f} |\n\n")

        f.write(f"- Size reduction from quantization: {size_reduction_pct:.2f}%\n")

        f.write("## Analysis\n\n")
        f.write(f"- WER change: {wer_change_pct:.2f}%\n")

        inference_time_change = _relative_change_pct(
            quantized_results['inference_time'], original_results['inference_time']
        )
        f.write(f"- Inference time change: {inference_time_change:.2f}%\n\n")

        # Add some conclusions
        f.write("### Conclusion\n\n")
        f.write(f"The FP16 quantization ")

        if abs(wer_change_pct) < 1.0:
            f.write(f"was successful with negligible impact on accuracy ({wer_change_pct:.2f}% WER change). ")
        elif wer_change_pct > 0:
            f.write(f"resulted in a {wer_change_pct:.2f}% increase in Word Error Rate, ")
            f.write("which may be acceptable depending on your use case and the benefits gained. ")
        else:
            f.write(f"surprisingly improved accuracy by {abs(wer_change_pct):.2f}%. ")

        if inference_time_change < 0:
            f.write(f"The model also shows a {abs(inference_time_change):.2f}% improvement in inference speed. ")
        else:
            f.write(f"However, the model is {inference_time_change:.2f}% slower than the original model. ")

        f.write(f"The quantization reduced the model size by {size_reduction_pct:.2f}%, ")
        if abs(wer_change_pct) < 5.0:
            f.write("with a reasonable trade-off in accuracy.")
        elif wer_change_pct > 0:
            f.write(f"but at the cost of a {wer_change_pct:.2f}% increase in WER.")
        else:
            f.write(f"while actually improving accuracy by {abs(wer_change_pct):.2f}%.")

    # Print summary to console
    print("\n" + "="*50)
    print("FP16 QUANTIZATION RESULTS SUMMARY")
    print("="*50)
    print(f"Original WER: {original_results['wer']:.4f}")
    print(f"FP16 quantized WER: {quantized_results['wer']:.4f} ({wer_change_pct:.2f}% change)")
    print(f"Original model size: {original_size_mb:.2f} MB")
    print(f"FP16 quantized model size: {quantized_size_mb:.2f} MB")
    print(f"Size reduction: {size_reduction_pct:.2f}%")
    print(f"Quantized model path: {output_dir}")
    print(f"\nDetailed report saved to {report_path}")

    return model_info


def main():
    """Run FP16 quantization and evaluation with this notebook's fixed settings."""
    quantize_model(
        model_id="StefanStefan/Wav2Vec-100-CSR",
        dataset_name="StefanStefan/STT",
        output_dir="wav2vec2-fp16",
        num_test_samples=5000,
    )


# Entry point when the cell is executed as a script / top-level module
if __name__ == "__main__":
    main()

Knowledge Distillation

In [None]:
import os
import torch
import numpy as np
import json
import logging
import time
from tqdm import tqdm
from datetime import datetime
from transformers import (
    Wav2Vec2ForCTC,
    Wav2Vec2Processor,
    Wav2Vec2Config,
    get_linear_schedule_with_warmup
)
from torch.optim.lr_scheduler import OneCycleLR
from datasets import load_dataset
import jiwer
from dataclasses import dataclass
from typing import Dict, List, Union
from prettytable import PrettyTable
import pandas as pd

# Configure log output at INFO level and create the logger used by the functions below
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
)
logger = logging.getLogger(__name__)


@dataclass
class DataCollatorCTCWithPadding:
    """Pads CTC examples into a batch and keeps the audio inputs in float32.

    ``processor`` provides the feature extractor (pads ``input_values``) and
    the tokenizer (pads label ids); ``padding`` is passed through to both.
    """
    processor: Wav2Vec2Processor
    padding: Union[bool, str] = True

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        """Collate ``{"input_values", "labels"}`` examples into one padded batch.

        Padded label positions (attention mask != 1) become -100, and
        float64 ``input_values`` are downcast to float32.
        """
        inputs = [{"input_values": item["input_values"]} for item in features]
        targets = [{"input_ids": item["labels"]} for item in features]

        pad_options = {"padding": self.padding, "return_tensors": "pt"}
        batch = self.processor.feature_extractor.pad(inputs, **pad_options)
        padded = self.processor.tokenizer.pad(targets, **pad_options)

        pad_positions = padded["attention_mask"].ne(1)
        batch["labels"] = padded["input_ids"].masked_fill(pad_positions, -100)

        # Keep inputs in float32 rather than float64/double
        values = batch["input_values"]
        if values.dtype == torch.float64:
            batch["input_values"] = values.to(torch.float32)

        return batch


def preprocess_dataset(dataset, processor, augment_train=True):
    """Preprocess the dataset for speech recognition.

    Each split is mapped (4 worker processes) to ``input_values`` (float32
    waveform features) and ``labels`` (token ids of the lower-cased
    transcription). The ``train`` split is used in full; every other split is
    cut to its first 100 examples.

    Args:
        dataset: Mapping from split name to a dataset supporting ``select``,
            ``map`` and ``column_names``, with ``audio`` and ``transcription``
            columns.
        processor: Processor providing ``feature_extractor.sampling_rate`` and
            the audio/text encoding calls.
        augment_train: When True (default), two random time masks, each 5% of
            the sequence, are zeroed in every training example. The masks are
            drawn once, during preprocessing.

    Returns:
        Plain dict mapping split name to the processed split.
    """
    target_rate = processor.feature_extractor.sampling_rate

    def preprocess_function(batch, split_name=""):
        # Normalize text
        batch["transcription"] = batch["transcription"].lower()

        # Process audio
        speech = batch["audio"]["array"]
        sampling_rate = batch["audio"]["sampling_rate"]

        # Resample if needed
        if sampling_rate != target_rate:
            import torchaudio
            resampler = torchaudio.transforms.Resample(sampling_rate, target_rate)
            # Resample caches a float32 kernel by default; cast the input to match
            waveform = torch.as_tensor(speech, dtype=torch.float32)
            speech = resampler(waveform).squeeze().numpy()

        # Get input values and labels
        input_values = processor(
            speech,
            sampling_rate=target_rate,
            return_attention_mask=False
        ).input_values[0]

        # Ensure we're using float32 (not float64/double); np.array also copies,
        # so the masking below never touches the processor's own buffer
        input_values = np.array(input_values, dtype=np.float32)

        # Apply SpecAugment-like time masking for training data. The split name
        # is supplied by the map call below; a "split" column in the row, if
        # the dataset has one, still takes precedence.
        if augment_train and batch.get("split", split_name) == "train":
            seq_len = input_values.shape[0]
            mask_length = int(seq_len * 0.05)  # 5% of sequence length
            if 0 < mask_length < seq_len:
                for _ in range(2):  # Apply 2 time masks
                    start = np.random.randint(0, seq_len - mask_length)
                    input_values[start:start + mask_length] = 0.0

        labels = processor(
            text=batch["transcription"],
            return_attention_mask=False
        ).input_ids

        return {"input_values": input_values, "labels": labels}

    # Process all splits
    processed_dataset = {}
    for split in dataset:
        logger.info(f"Processing {split} split...")
        # Use smaller subset for faster iterations
        subset_size = len(dataset[split]) if split == "train" else min(100, len(dataset[split]))
        dataset_subset = dataset[split].select(range(subset_size))
        logger.info(f"Using {subset_size} examples from {split} split")

        processed_dataset[split] = dataset_subset.map(
            preprocess_function,
            fn_kwargs={"split_name": split},  # tells the function which split it is processing
            remove_columns=dataset_subset.column_names,
            num_proc=4,
            desc=f"Processing {split} split"
        )
        logger.info(f"Finished processing {split} split: {len(processed_dataset[split])} examples")

    return processed_dataset


def evaluate_model(model, processor, test_dataset, device, max_eval_samples=200, batch_size=8):
    """Evaluate model on test dataset with validation loss.

    Args:
        model: Model called as ``model(input_values=..., labels=...)``; its
            output must expose ``loss`` and ``logits``.
        processor: Processor used for collation and for decoding ids to text.
        test_dataset: Preprocessed dataset with ``input_values`` and ``labels``;
            only the first ``max_eval_samples`` rows are used when it is larger.
        device: Device for the model and the batches.
        max_eval_samples: Upper bound on the number of evaluated rows.
        batch_size: Evaluation batch size (default 8).

    Returns:
        Dict with ``wer``, ``cer``, ``val_loss`` (mean batch loss), ``examples``
        (up to 20 reference/prediction pairs) and ``num_samples``. ``wer`` and
        ``cer`` are NaN when nothing was evaluated.

    The model's train/eval mode is restored on exit, so calling this in the
    middle of training does not leave the model in eval mode.
    """
    from torch.utils.data import DataLoader

    # Cap the evaluation set so WER measurement stays affordable
    if len(test_dataset) > max_eval_samples:
        logger.info(f"Evaluating on {max_eval_samples} samples instead of {len(test_dataset)}")
        test_dataset = test_dataset.select(range(max_eval_samples))
    else:
        logger.info(f"Evaluating on all {len(test_dataset)} samples")

    model.to(device)
    was_training = getattr(model, "training", False)
    model.eval()

    data_collator = DataCollatorCTCWithPadding(processor=processor)
    dataloader = DataLoader(test_dataset, batch_size=batch_size, collate_fn=data_collator)

    all_predictions = []
    all_references = []

    # Track validation loss
    val_loss = 0
    val_steps = 0

    try:
        with torch.no_grad():
            for batch in tqdm(dataloader, desc="Evaluating"):
                # Ensure all tensors are float32
                input_values = batch["input_values"].to(device, dtype=torch.float32)
                labels = batch["labels"].to(device)

                outputs = model(input_values=input_values, labels=labels)
                val_loss += outputs.loss.item()
                val_steps += 1

                predicted_ids = torch.argmax(outputs.logits, dim=-1).cpu().numpy()
                label_ids = labels.cpu().numpy()

                # Decode predictions and references
                predictions = processor.batch_decode(predicted_ids)

                # Clean up label ids (-100 -> pad_token_id)
                label_ids_cleaned = np.where(label_ids == -100, processor.tokenizer.pad_token_id, label_ids)
                references = processor.batch_decode(label_ids_cleaned, group_tokens=False)

                all_predictions.extend(predictions)
                all_references.extend(references)
    finally:
        # Put the model back into whichever mode the caller had it in
        model.train(was_training)

    # Calculate metrics (undefined for an empty evaluation set)
    if all_references:
        wer = jiwer.wer(all_references, all_predictions)
        cer = jiwer.cer(all_references, all_predictions)
    else:
        wer = cer = float("nan")
    avg_val_loss = val_loss / val_steps if val_steps > 0 else float('inf')

    # Store more examples for analysis
    examples = list(zip(all_references[:20], all_predictions[:20]))

    # Log WER calculation details
    logger.info(f"WER calculated on {len(all_references)} test samples: {wer:.4f}")

    return {"wer": wer, "cer": cer, "val_loss": avg_val_loss,
            "examples": examples,
            "num_samples": len(all_references)}


def calculate_model_size(model, as_mb=True):
    """Calculate model size in parameters or MB.

    Args:
        model: Object exposing ``parameters()`` that yields tensors.
        as_mb: When True return the parameter memory in MiB, otherwise the
            parameter count.

    Returns:
        Float MiB (sum of ``numel * element_size`` over all parameters, so
        float16 parameters count 2 bytes and float32 count 4) or the integer
        number of parameters.
    """
    params = list(model.parameters())

    if not as_mb:
        return sum(p.numel() for p in params)

    # Use each parameter's real element size rather than assuming float32
    total_bytes = sum(p.numel() * p.element_size() for p in params)
    return total_bytes / (1024 * 1024)


def copy_matching_layers(student_model, teacher_model):
    """Copy matching layer weights from teacher to student.

    For each student parameter a source teacher parameter is chosen:

    1. If the name contains a ``layers.<N>`` component and both models expose
       ``config.num_hidden_layers``, layer ``N`` is mapped proportionally to
       teacher layer ``N * teacher_layers // student_layers`` (capped at the
       last teacher layer), so a shallower student samples layers spread over
       the teacher's depth rather than only its first ones.
    2. Otherwise, or if that candidate is missing or has a different shape,
       the teacher parameter with the identical name is used.

    A parameter is copied only when the shapes match exactly; everything else
    keeps its current initialisation.

    Args:
        student_model: Model whose parameters are overwritten in place.
        teacher_model: Model providing the source weights.

    Returns:
        ``student_model`` (the same object).
    """
    teacher_params = dict(teacher_model.named_parameters())

    # Layer counts are optional: without them only same-name copying happens
    teacher_layers = getattr(getattr(teacher_model, "config", None), "num_hidden_layers", None)
    student_layers = getattr(getattr(student_model, "config", None), "num_hidden_layers", None)
    can_map_layers = bool(teacher_layers) and bool(student_layers)

    def mapped_teacher_name(name):
        """Return ``name`` with its first ``layers.<N>`` index mapped onto the teacher's depth."""
        parts = name.split(".")
        for i, part in enumerate(parts[:-1]):
            if part == "layers" and parts[i + 1].isdigit():
                student_layer_idx = int(parts[i + 1])
                # Integer arithmetic avoids float rounding in the proportional mapping
                teacher_layer_idx = min(
                    teacher_layers - 1,
                    student_layer_idx * teacher_layers // student_layers,
                )
                parts[i + 1] = str(teacher_layer_idx)
                return ".".join(parts)
        return name

    copied_params = 0
    total_params = 0

    for name, param in student_model.named_parameters():
        total_params += 1

        candidates = [name]
        if can_map_layers:
            mapped = mapped_teacher_name(name)
            if mapped != name:
                candidates.insert(0, mapped)  # prefer the proportionally mapped layer

        for teacher_name in candidates:
            source = teacher_params.get(teacher_name)
            if source is not None and source.shape == param.shape:
                param.data.copy_(source.data)
                copied_params += 1
                break  # count each student parameter at most once

    logger.info(f"Copied weights for {copied_params}/{total_params} parameters")
    return student_model


class DistilledWav2Vec2(torch.nn.Module):
    """Single-stage distilled Wav2Vec2 model with feature-level distillation.

    Wraps a trainable student ``Wav2Vec2ForCTC`` (built from ``student_config``)
    together with a frozen teacher. In training mode, when labels are given,
    ``forward`` sets the student's loss to a weighted sum of:

    * a focal-modulated CTC loss on the student logits (``alpha_ce``),
    * a temperature-scaled KL divergence between teacher and student logits
      (``alpha_kd``),
    * a depth-weighted cosine distance between mapped hidden states
      (``alpha_feat``).

    In every other case the call is forwarded to the student only.
    """

    def __init__(
        self,
        teacher_model,
        student_config,
        temperature=2.0,  # Lower temperature for sharper distribution
        alpha_ce=0.5,     # Balanced weight for CTC loss
        alpha_kd=0.4,     # Increased weight for knowledge distillation
        alpha_feat=0.1    # Same weight for feature distillation
    ):
        """Build the student, initialise it from the teacher and freeze the teacher.

        Args:
            teacher_model: Model exposing ``config``, ``lm_head`` and a forward
                pass returning ``logits`` and ``hidden_states``.
            student_config: Configuration used to instantiate the student.
            temperature: Softmax temperature for the logit-distillation term.
            alpha_ce: Weight of the focal CTC loss.
            alpha_kd: Weight of the logit-distillation (KL) loss.
            alpha_feat: Weight of the hidden-state cosine loss.
        """
        super().__init__()

        # Print info about dimensions
        logger.info(f"Teacher hidden_size: {teacher_model.config.hidden_size}")
        logger.info(f"Student hidden_size: {student_config.hidden_size}")

        # Create student model
        self.student = Wav2Vec2ForCTC(student_config)

        # Copy matching weights from teacher to student for better initialization
        self.student = copy_matching_layers(self.student, teacher_model)

        # Copy CTC head from teacher to student for better initialization.
        # Compare the full weight shape (not only out_features) so a student with
        # a different hidden size never receives a wrongly shaped weight tensor.
        student_head = self.student.lm_head
        teacher_head = teacher_model.lm_head
        if student_head.weight.shape == teacher_head.weight.shape:
            logger.info("Copying CTC head from teacher to student")
            student_head.weight.data = teacher_head.weight.data.clone()
            student_bias = getattr(student_head, 'bias', None)
            teacher_bias = getattr(teacher_head, 'bias', None)
            if student_bias is not None and teacher_bias is not None:
                student_bias.data = teacher_bias.data.clone()

        # Teacher model (frozen)
        self.teacher = teacher_model
        for param in self.teacher.parameters():
            param.requires_grad = False
        self.teacher.eval()

        # Distillation hyperparameters
        self.temperature = temperature
        self.alpha_ce = alpha_ce
        self.alpha_kd = alpha_kd
        self.alpha_feat = alpha_feat

        # Create feature adapters if dimensions don't match
        self.feat_adapter = None
        if teacher_model.config.hidden_size != student_config.hidden_size:
            self.feat_adapter = torch.nn.Linear(
                student_config.hidden_size,
                teacher_model.config.hidden_size
            )

    def train(self, mode=True):
        """Switch train/eval mode while keeping the frozen teacher in eval mode.

        ``torch.nn.Module.train`` propagates the mode to every submodule, which
        would otherwise put the teacher back into training mode (activating its
        dropout) whenever ``.train()`` is called on this wrapper.
        """
        super().train(mode)
        self.teacher.eval()
        return self

    @staticmethod
    def _prepare_ctc_inputs(log_probs, labels):
        """Build the label/length tensors passed to ``ctc_loss``.

        Every sequence is treated as spanning all frames of ``log_probs``;
        ``-100`` entries in ``labels`` are counted as padding and replaced by 0.
        """
        input_lengths = torch.full(
            (log_probs.shape[0],),
            log_probs.shape[1],
            dtype=torch.long,
            device=log_probs.device
        )
        target_lengths = torch.sum(labels != -100, dim=1)
        labels_no_pad = labels.clone()
        labels_no_pad[labels_no_pad == -100] = 0  # Replace padding with valid token
        return labels_no_pad, input_lengths, target_lengths

    def forward(self, input_values, attention_mask=None, labels=None):
        """Forward pass with distillation during training.

        Args:
            input_values: Model input; cast to float32 if needed.
            attention_mask: Optional mask forwarded to student and teacher.
            labels: Optional target ids, with ``-100`` marking padding.

        Returns:
            The student's output object. In training mode with labels its
            ``loss`` attribute holds the combined distillation loss.
        """
        # Ensure inputs are float32
        if input_values.dtype != torch.float32:
            input_values = input_values.to(torch.float32)

        # Get student outputs - request hidden states only (no attentions)
        student_outputs = self.student(
            input_values=input_values,
            attention_mask=attention_mask,
            output_hidden_states=True,    # Request hidden states
            output_attentions=False,      # Don't request attentions
            return_dict=True,
            labels=labels if not self.training else None  # Don't compute CTC loss here if training
        )

        # For inference, just return student outputs
        if not self.training or labels is None:
            # Fallback: compute the CTC loss here if the student did not provide one
            if hasattr(student_outputs, 'loss') and student_outputs.loss is None and labels is not None:
                log_probs = torch.nn.functional.log_softmax(student_outputs.logits, dim=-1)
                labels_no_pad, input_lengths, target_lengths = self._prepare_ctc_inputs(log_probs, labels)
                loss = torch.nn.functional.ctc_loss(
                    log_probs.transpose(0, 1),
                    labels_no_pad,
                    input_lengths,
                    target_lengths,
                    blank=0,  # Assuming 0 is blank/pad token
                    reduction='mean'
                )
                student_outputs.loss = loss

            return student_outputs

        # For training, compute multi-component distillation loss
        with torch.no_grad():
            teacher_outputs = self.teacher(
                input_values=input_values,
                attention_mask=attention_mask,
                output_hidden_states=True,  # Request teacher hidden states
                output_attentions=False,    # Don't request teacher attentions
                return_dict=True
            )

        # 1. Modified CTC loss with focal loss component
        gamma = 2.0  # Focal loss gamma parameter
        log_probs = torch.nn.functional.log_softmax(student_outputs.logits, dim=-1)
        labels_no_pad, input_lengths, target_lengths = self._prepare_ctc_inputs(log_probs, labels)

        # Handle cases where all labels might be padding
        if torch.all(target_lengths == 0):
            ctc_loss = torch.tensor(0.0, device=log_probs.device)
        else:
            # Standard CTC loss
            standard_ctc_loss = torch.nn.functional.ctc_loss(
                log_probs.transpose(0, 1),
                labels_no_pad,
                input_lengths,
                target_lengths,
                blank=0,  # Assuming 0 is blank/pad token
                reduction='none'  # Get per-sample loss
            )

            # Apply focal loss modulation - focus more on hard examples
            pt = torch.exp(-standard_ctc_loss)
            focal_weight = (1 - pt) ** gamma

            # Final focal CTC loss
            ctc_loss = torch.mean(standard_ctc_loss * focal_weight)

        # 2. CTC Logit Distillation using KL divergence
        # First ensure the sequence lengths match by resampling if needed
        teacher_logits = teacher_outputs.logits  # [batch, teacher_seq_len, vocab_size]
        student_logits = student_outputs.logits  # [batch, student_seq_len, vocab_size]

        # Check if sequence lengths match
        if teacher_logits.size(1) != student_logits.size(1):
            # Interpolate along the sequence axis:
            # [batch, seq_len, vocab] -> [batch, vocab, seq_len] -> resize -> back
            student_logits = torch.nn.functional.interpolate(
                student_logits.transpose(1, 2),
                size=teacher_logits.size(1),
                mode='linear'
            ).transpose(1, 2)

        # Now apply temperature and KL divergence
        soft_student_logits = torch.nn.functional.log_softmax(
            student_logits / self.temperature, dim=-1
        )
        soft_teacher_logits = torch.nn.functional.softmax(
            teacher_logits / self.temperature, dim=-1
        )

        kd_loss = torch.nn.functional.kl_div(
            soft_student_logits,
            soft_teacher_logits,
            reduction='batchmean'
        ) * (self.temperature ** 2)

        # 3. Feature-level distillation with cosine similarity
        feat_loss = 0.0
        if self.alpha_feat > 0:
            student_hidden_states = student_outputs.hidden_states
            teacher_hidden_states = teacher_outputs.hidden_states

            s_layers = len(student_hidden_states)
            t_layers = len(teacher_hidden_states)
            # Guard the depth normaliser so a single hidden state does not divide by zero
            depth_span = max(1, s_layers - 1)

            total_feat_loss = 0.0
            total_weight = 0.0
            for s_idx in range(s_layers):
                # Map student layer index to teacher layer index proportionally
                t_idx = min(t_layers - 1, int(s_idx * (t_layers / s_layers)))
                s_feat = student_hidden_states[s_idx]
                t_feat = teacher_hidden_states[t_idx]

                # Handle sequence length differences with interpolation if needed
                if s_feat.size(1) != t_feat.size(1):
                    s_feat = torch.nn.functional.interpolate(
                        s_feat.transpose(1, 2),
                        size=t_feat.size(1),
                        mode='linear'
                    ).transpose(1, 2)

                # Apply adapter if dimensions don't match
                if self.feat_adapter is not None:
                    s_feat = self.feat_adapter(s_feat)

                # Cosine distance (1 - cosine similarity) along the last axis
                s_feat_norm = torch.nn.functional.normalize(s_feat, p=2, dim=-1)
                t_feat_norm = torch.nn.functional.normalize(t_feat, p=2, dim=-1)
                similarity = torch.sum(s_feat_norm * t_feat_norm, dim=-1)
                cos_loss = torch.mean(1.0 - similarity)

                # Weight grows linearly from 0.5 to 1.0 with student layer index
                layer_weight = 0.5 + 0.5 * (s_idx / depth_span)
                total_feat_loss += cos_loss * layer_weight
                total_weight += layer_weight

            # Normalize by total weight
            if total_weight > 0:
                feat_loss = total_feat_loss / total_weight

        # Combine losses
        combined_loss = (
            self.alpha_ce * ctc_loss +
            self.alpha_kd * kd_loss +
            self.alpha_feat * feat_loss
        )

        # Return with updated loss
        student_outputs.loss = combined_loss
        return student_outputs


def create_student_config(teacher_config, size_reduction=0.15):
    """Create a student configuration for distillation.

    The student starts as a copy of the teacher configuration and is then
    shrunk: fewer hidden layers, a smaller feed-forward ``intermediate_size``
    and (for teachers with more than 8 heads) fewer attention heads. The hidden
    size is left unchanged.

    Args:
        teacher_config: Configuration of the teacher model.
        size_reduction: Requested reduction fraction; only 70% of it is applied
            to the layer count.

    Returns:
        The derived student ``Wav2Vec2Config``.
    """
    # Start with a copy of the teacher config
    student_config = Wav2Vec2Config.from_dict(teacher_config.to_dict())

    # Reduce depth by 70% of the requested reduction, keeping at least 6 layers
    # but never more layers than the teacher itself has.
    original_layers = teacher_config.num_hidden_layers
    target_layers = max(6, int(original_layers * (1 - size_reduction * 0.7)))
    target_layers = min(original_layers, target_layers)
    student_config.num_hidden_layers = target_layers

    # The hidden size is intentionally kept equal to the teacher's: it must stay
    # divisible by the conv groups and attention heads.

    # Reduce intermediate size less aggressively, rounded down to a multiple
    # of 8 for better hardware utilization
    student_config.intermediate_size = (int(teacher_config.intermediate_size * 0.8) // 8) * 8

    # Adjust attention heads - ensuring divisibility
    # NOTE(review): a different head count re-partitions any attention weights
    # copied from the teacher - confirm this is intended.
    if teacher_config.num_attention_heads > 8:
        # Candidate head counts: divisors of the hidden size, capped at 16
        divisors = [i for i in range(1, min(teacher_config.num_attention_heads, 16) + 1)
                    if student_config.hidden_size % i == 0]

        # Choose the largest divisor that's smaller than the original head count
        for divisor in sorted(divisors, reverse=True):
            if divisor < teacher_config.num_attention_heads:
                student_config.num_attention_heads = divisor
                break

    # Dropout adjustments for better convergence
    student_config.hidden_dropout = 0.1
    student_config.attention_dropout = 0.1
    student_config.activation_dropout = 0.1
    student_config.feat_proj_dropout = 0.0

    # Log the configuration differences
    logger.info("Student Model Configuration:")
    logger.info(f"Teacher config: layers={original_layers}, hidden={teacher_config.hidden_size}, heads={teacher_config.num_attention_heads}")
    logger.info(f"Student config: layers={target_layers}, hidden={student_config.hidden_size}, heads={student_config.num_attention_heads}")

    return student_config


def train_distilled_model(
    distilled_model,
    processor,
    train_dataset,
    eval_dataset,
    test_dataset,
    device,
    epochs=5,
    batch_size=4,
    learning_rate=3e-5,
    eval_steps=500,
    save_path=None,
    gradient_accumulation_steps=2  # New parameter for gradient accumulation
):
    """Train a distilled model with tabular output of metrics at each evaluation step.

    Args:
        distilled_model: Wrapper exposing ``student``, ``feat_adapter`` and the
            ``alpha_ce`` / ``alpha_kd`` / ``alpha_feat`` loss weights.
        processor: Processor used for collation, evaluation and saving.
        train_dataset: Dataset used for optimisation.
        eval_dataset: Dataset evaluated every ``eval_steps`` optimizer steps.
        test_dataset: Dataset evaluated once after training.
        device: Device the model and batches are moved to.
        epochs: Number of passes over ``train_dataset``.
        batch_size: Micro-batch size of the training data loader.
        learning_rate: Peak learning rate of the one-cycle schedule.
        eval_steps: Optimizer steps between two evaluations.
        save_path: Optional directory for the best and final checkpoints.
        gradient_accumulation_steps: Micro-batches accumulated per optimizer step.

    Returns:
        Tuple ``(student_model, metrics_data)``. The student carries the weights
        of the lowest-WER evaluation (or the final weights if no evaluation
        succeeded); ``metrics_data`` maps column names to lists of values.

    Raises:
        ValueError: If ``gradient_accumulation_steps`` is smaller than 1.
    """
    import re
    from torch.utils.data import DataLoader

    if gradient_accumulation_steps < 1:
        raise ValueError("gradient_accumulation_steps must be a positive integer")

    # Create data collator
    data_collator = DataCollatorCTCWithPadding(processor=processor)

    # Create data loaders
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=data_collator
    )

    # The last (possibly partial) accumulation window of every epoch is flushed,
    # so each epoch performs ceil(batches / accumulation) optimizer steps. Using
    # floor division here would let the scheduler be stepped past total_steps.
    batches_per_epoch = len(train_dataloader)
    updates_per_epoch = -(-batches_per_epoch // gradient_accumulation_steps)
    total_steps = updates_per_epoch * epochs
    step_count = 0
    steps_since_eval = 0

    # Setup optimizer with parameter groups and weight decay
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in distilled_model.student.named_parameters()
                      if not any(nd in n for nd in no_decay)],
            "weight_decay": 0.01,
        },
        {
            "params": [p for n, p in distilled_model.student.named_parameters()
                      if any(nd in n for nd in no_decay)],
            "weight_decay": 0.0,
        },
    ]

    # Add adapter parameters if they exist
    adapter_params = []
    if hasattr(distilled_model, 'feat_adapter') and distilled_model.feat_adapter is not None:
        adapter_params = list(distilled_model.feat_adapter.parameters())
        if adapter_params:
            optimizer_grouped_parameters.append({
                "params": adapter_params,
                "weight_decay": 0.01,
            })

    # Create optimizer
    optimizer = torch.optim.AdamW(
        optimizer_grouped_parameters,
        lr=learning_rate,
        eps=1e-6  # Slightly higher for better stability
    )

    # One-cycle schedule; it already provides the warmup, so the learning rate
    # is never overridden by hand.
    scheduler = OneCycleLR(
        optimizer,
        max_lr=learning_rate,
        total_steps=total_steps,
        pct_start=0.1,  # Warmup for 10% of training
        div_factor=10.0,  # initial_lr = max_lr / div_factor
        final_div_factor=100.0  # final_lr = initial_lr / final_div_factor
    )

    # Create results directory if saving
    if save_path:
        os.makedirs(save_path, exist_ok=True)

    # Setup metrics storage for tabular output
    metrics_data = {
        "Step": [],
        "Training Loss": [],
        "Validation Loss": [],
        "Wer": [],
        "Cer": []
    }

    # Training variables
    distilled_model.to(device)
    best_wer = float('inf')
    best_state = None  # CPU snapshot of the student weights at the best WER
    train_start_time = time.time()

    # Layer-wise L2 regularisation terms, resolved once instead of per batch:
    # (weight, parameter) with stronger weight for earlier layers.
    layer_pattern = re.compile(r'layers\.(\d+)\.')
    reg_lambda = 1e-5  # Small weight for regularization
    reg_terms = []
    for name, param in distilled_model.student.named_parameters():
        if 'layers' in name and 'weight' in name:
            match = layer_pattern.search(name)
            if match:
                reg_terms.append((1.0 / (int(match.group(1)) + 1), param))

    # Parameters subject to gradient clipping
    clip_params = list(distilled_model.student.parameters()) + adapter_params

    # Decay the distillation weights from their initial values rather than
    # from the already decayed values of the previous epoch.
    base_alpha_kd = distilled_model.alpha_kd
    base_alpha_feat = distilled_model.alpha_feat

    logger.info(f"Starting training for {epochs} epochs ({total_steps} steps)")
    logger.info(f"Evaluating every {eval_steps} steps")

    try:
        # Training loop
        for epoch in range(epochs):
            # Training epoch
            distilled_model.train()
            progress_bar = tqdm(train_dataloader,
                            desc=f"Epoch {epoch+1}/{epochs}",
                            total=batches_per_epoch)

            # Progressive distillation schedule - adjust at the start of each epoch
            progress = epoch / epochs
            distilled_model.alpha_kd = max(0.1, base_alpha_kd * (1.0 - 0.5 * progress))
            distilled_model.alpha_feat = max(0.05, base_alpha_feat * (1.0 - 0.5 * progress))
            distilled_model.alpha_ce = 1.0 - (distilled_model.alpha_kd + distilled_model.alpha_feat)

            logger.info(f"Epoch {epoch+1}: alpha_ce={distilled_model.alpha_ce:.3f}, "
                        f"alpha_kd={distilled_model.alpha_kd:.3f}, "
                        f"alpha_feat={distilled_model.alpha_feat:.3f}")

            optimizer.zero_grad()  # Zero gradients at the beginning

            for i, batch in enumerate(progress_bar):
                # Ensure tensors are float32
                input_values = batch["input_values"].to(device, dtype=torch.float32)
                labels = batch["labels"].to(device)

                try:
                    outputs = distilled_model(input_values=input_values, labels=labels)
                    loss = outputs.loss

                    # Layer-wise regularization for better generalization
                    if reg_terms:
                        l2_reg = sum(weight * torch.sum(param ** 2) for weight, param in reg_terms)
                        loss = loss + reg_lambda * l2_reg

                    step_loss = loss.item()

                    # Scale the whole loss (including regularization) so that the
                    # accumulated gradient is an average over the window
                    (loss / gradient_accumulation_steps).backward()

                    # Only update weights after accumulating enough gradients
                    if (i + 1) % gradient_accumulation_steps == 0 or (i + 1) == batches_per_epoch:
                        # Gradient clipping
                        torch.nn.utils.clip_grad_norm_(clip_params, max_norm=3.0)

                        optimizer.step()
                        scheduler.step()
                        optimizer.zero_grad()

                        step_count += 1
                        steps_since_eval += 1

                    # Update progress bar
                    progress_bar.set_postfix({
                        'loss': f"{step_loss:.4f}",
                        'lr': f"{scheduler.get_last_lr()[0]:.6f}"
                    })

                except Exception as e:
                    logger.error(f"Error in training step: {e}")
                    # Discard any partially accumulated gradients of the failed step
                    optimizer.zero_grad()
                    continue

                # Regular evaluation
                if steps_since_eval >= eval_steps:
                    steps_since_eval = 0

                    # Evaluation
                    distilled_model.eval()
                    student_model = distilled_model.student
                    try:
                        eval_metrics = evaluate_model(
                            student_model,
                            processor,
                            eval_dataset,
                            device,
                            max_eval_samples=200 # Increased evaluation samples for better WER accuracy
                        )

                        # Store metrics for table
                        metrics_data["Step"].append(step_count)
                        metrics_data["Training Loss"].append(f"{step_loss:.6f}")
                        metrics_data["Validation Loss"].append(f"{eval_metrics['val_loss']:.6f}")
                        metrics_data["Wer"].append(f"{eval_metrics['wer']:.6f}")
                        metrics_data["Cer"].append(f"{eval_metrics['cer']:.6f}")

                        # Print the table header before the first recorded row
                        if len(metrics_data["Step"]) == 1:
                            header = f"{'Step':<8} {'Train Loss':<12} {'Valid Loss':<12} {'WER':<10} {'CER':<10}"
                            divider = "-" * len(header)
                            print("\n" + divider)
                            print(header)
                            print(divider)

                        # Print only the latest row
                        latest = f"{step_count:<8} {step_loss:<12.6f} {eval_metrics['val_loss']:<12.6f} {eval_metrics['wer']:<10.6f} {eval_metrics['cer']:<10.6f}"
                        print(latest)

                        # Print example predictions - more comprehensive
                        print(f"\nExample predictions (from {eval_metrics['num_samples']} validation samples):")
                        for example_idx, (ref, pred) in enumerate(eval_metrics['examples'][:20]):
                            print(f"Example {example_idx+1}:")
                            print(f"Reference: {ref}")
                            print(f"Prediction: {pred}")
                            print("-" * 40)

                        # Save best model
                        if eval_metrics["wer"] < best_wer:
                            best_wer = eval_metrics["wer"]
                            # Snapshot the weights: the student object itself keeps
                            # being trained, so a reference would not stay "best".
                            best_state = {
                                key: tensor.detach().cpu().clone()
                                for key, tensor in student_model.state_dict().items()
                            }
                            logger.info(f"New best model with WER: {best_wer*100:.2f}%")

                            if save_path:
                                best_model_path = os.path.join(save_path, "best_model")
                                os.makedirs(best_model_path, exist_ok=True)
                                student_model.save_pretrained(best_model_path)
                                processor.save_pretrained(best_model_path)

                    except Exception as e:
                        logger.error(f"Error during evaluation: {e}")

                    # Return to training mode
                    distilled_model.train()

    except Exception as e:
        logger.error(f"Training error: {e}")

    logger.info(f"Training loop finished after {time.time() - train_start_time:.1f} seconds")

    # Final evaluation on the best weights seen (final weights if none recorded)
    logger.info("Final evaluation on test set...")
    distilled_model.eval()
    best_model = distilled_model.student
    if best_state is not None:
        best_model.load_state_dict(best_state)

    test_metrics = evaluate_model(best_model, processor, test_dataset, device, max_eval_samples=100)
    logger.info(f"Test WER: {test_metrics['wer']*100:.2f}%")

    # Save the returned model
    if save_path:
        final_model_path = os.path.join(save_path, "final_model")
        os.makedirs(final_model_path, exist_ok=True)
        best_model.save_pretrained(final_model_path)
        processor.save_pretrained(final_model_path)
        logger.info(f"Saved final model to {final_model_path}")

        # Generate full metrics table at the end of training
        print("\n")
        print("Full Training Metrics History:")
        # Display all collected metrics in a clear tabular format
        headers = ["Step", "Train Loss", "Valid Loss", "WER", "CER"]
        row_format = "{:<8} {:<12} {:<12} {:<10} {:<10}"

        # Print headers
        divider = "-" * 60
        print(divider)
        print(row_format.format(*headers))
        print(divider)

        # Print all rows
        for row_idx in range(len(metrics_data["Step"])):
            print(row_format.format(
                metrics_data["Step"][row_idx],
                float(metrics_data["Training Loss"][row_idx]),
                float(metrics_data["Validation Loss"][row_idx]),
                float(metrics_data["Wer"][row_idx]),
                float(metrics_data["Cer"][row_idx])
            ))
        print(divider)

    return best_model, metrics_data


def distill_wav2vec2(
    model_name,
    dataset_name,
    output_dir,
    size_reduction=0.12,
    temperature=2.0,
    epochs=15,
    batch_size=8,
    learning_rate=5e-5,
    eval_steps=500
):
    """Single-stage distillation for Wav2Vec2 with tabular metrics output.

    Loads the teacher model, processor and dataset, evaluates the teacher,
    builds and trains a smaller student, then writes ``final_metrics.csv`` and
    ``final_metrics.md`` into ``output_dir``.

    Args:
        model_name: Identifier or path passed to ``from_pretrained`` for the
            teacher model and its processor.
        dataset_name: Identifier passed to ``load_dataset``.
        output_dir: Directory for checkpoints and metric reports.
        size_reduction: Reduction fraction forwarded to ``create_student_config``.
        temperature: Distillation temperature.
        epochs: Number of training epochs.
        batch_size: Training micro-batch size.
        learning_rate: Peak learning rate.
        eval_steps: Optimizer steps between evaluations.

    Returns:
        Tuple ``(distilled_model, processor, metrics_data)``.
    """
    # Create output directory
    os.makedirs(output_dir, exist_ok=True)

    # Set device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info(f"Using device: {device}")

    # Explicitly set default dtype to float32
    torch.set_default_dtype(torch.float32)

    # 1. Load model, processor, and dataset
    logger.info(f"Loading model from {model_name}...")
    model = Wav2Vec2ForCTC.from_pretrained(model_name)
    processor = Wav2Vec2Processor.from_pretrained(model_name)

    logger.info(f"Loading dataset from {dataset_name}...")
    dataset = load_dataset(dataset_name)

    # 2. Preprocess dataset
    logger.info("Preprocessing dataset...")
    processed_dataset = preprocess_dataset(dataset, processor)

    # 3. Evaluate original model with increased sample size
    logger.info("Evaluating original model...")
    original_metrics = evaluate_model(model, processor, processed_dataset["test"], device, max_eval_samples=200)
    logger.info(f"Original model WER: {original_metrics['wer']*100:.2f}% (calculated on {original_metrics['num_samples']} samples)")

    original_size = calculate_model_size(model)
    logger.info(f"Original model size: {original_size:.2f} MB")

    # 4. Create student model for distillation
    logger.info("Creating student model...")
    student_config = create_student_config(
        model.config,
        size_reduction=size_reduction
    )

    # Create distilled model with improved implementation
    distillation_model = DistilledWav2Vec2(
        teacher_model=model,
        student_config=student_config,
        temperature=temperature,
        alpha_ce=0.5,      # Weight for CTC loss
        alpha_kd=0.4,      # Weight for logit distillation
        alpha_feat=0.1     # Weight for feature distillation
    )

    # 5. Train and get tabular metrics
    logger.info(f"Training distilled model for {epochs} epochs...")
    distilled_model, metrics_data = train_distilled_model(
        distillation_model,
        processor,
        processed_dataset["train"],
        processed_dataset["validation"],
        processed_dataset["test"],
        device,
        epochs=epochs,
        batch_size=batch_size,
        learning_rate=learning_rate,
        eval_steps=eval_steps,
        save_path=output_dir,
        gradient_accumulation_steps=2  # Enable gradient accumulation
    )

    # 6. Calculate size reduction
    distilled_size = calculate_model_size(distilled_model)
    size_reduction_pct = (original_size - distilled_size) / original_size * 100

    # 7. Print final summary
    logger.info("Training completed!")
    logger.info(f"Original model WER: {original_metrics['wer']*100:.2f}% (calculated on {original_metrics['num_samples']} samples)")
    logger.info(f"Original model size: {original_size:.2f} MB")

    # Test final model with increased sample size for more accurate WER
    final_metrics = evaluate_model(distilled_model, processor, processed_dataset["test"], device, max_eval_samples=300)
    logger.info(f"Distilled model WER: {final_metrics['wer']*100:.2f}% (calculated on {final_metrics['num_samples']} samples)")
    logger.info(f"Distilled model size: {distilled_size:.2f} MB")
    logger.info(f"Size reduction: {size_reduction_pct:.2f}%")
    logger.info(f"WER difference: {(final_metrics['wer'] - original_metrics['wer'])*100:.2f}%")

    # Create final metrics table as .md and .csv
    df = pd.DataFrame(metrics_data)
    df.to_csv(os.path.join(output_dir, "final_metrics.csv"), index=False)

    # DataFrame.to_markdown needs the optional `tabulate` package; fall back to
    # a fenced plain-text table so a missing dependency cannot lose the report
    # after training has already finished.
    try:
        metrics_table = df.to_markdown(index=False)
    except ImportError:
        logger.warning("tabulate is not installed; writing metrics as plain text")
        metrics_table = "```\n" + df.to_string(index=False) + "\n```"

    with open(os.path.join(output_dir, "final_metrics.md"), 'w') as f:
        f.write("# Knowledge Distillation Training Metrics\n\n")
        f.write(metrics_table)
        f.write("\n\n## Summary\n\n")
        f.write(f"- Original model size: {original_size:.2f} MB\n")
        f.write(f"- Distilled model size: {distilled_size:.2f} MB\n")
        f.write(f"- Size reduction: {size_reduction_pct:.2f}%\n")
        f.write(f"- Original WER: {original_metrics['wer']*100:.2f}% (calculated on {original_metrics['num_samples']} samples)\n")
        f.write(f"- Final WER: {final_metrics['wer']*100:.2f}% (calculated on {final_metrics['num_samples']} samples)\n")
        f.write(f"- WER change: {(final_metrics['wer'] - original_metrics['wer'])*100:.2f}%\n")

        # Add example transcriptions for analysis
        f.write("\n\n## Example Transcriptions\n\n")
        for i, (ref, pred) in enumerate(final_metrics['examples'][:20]):
            f.write(f"### Example {i+1}:\n")
            f.write(f"- **Reference**: {ref}\n")
            f.write(f"- **Prediction**: {pred}\n\n")

    return distilled_model, processor, metrics_data


if __name__ == "__main__":
    # Run the complete distillation pipeline: teacher checkpoint, dataset and
    # output directory first, then the training hyperparameters.
    distill_wav2vec2(
        "StefanStefan/Wav2Vec-100-CSR",
        "StefanStefan/STT",
        "./wav2vec2_distillation",
        # Model / loss settings
        size_reduction=0.10,
        temperature=2.0,
        # Optimisation settings
        learning_rate=5e-5,
        batch_size=8,
        epochs=15,
        # Number of optimizer steps between two evaluations
        eval_steps=500,
    )

Model Size Calculator

In [None]:
import os
import sys
import torch
import psutil
import numpy as np
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
from prettytable import PrettyTable

def get_directory_size(path):
    """Return the summed size, in MB, of all non-symlink files below ``path``."""
    byte_count = sum(
        os.path.getsize(candidate)
        for root, _subdirs, names in os.walk(path)
        for candidate in (os.path.join(root, entry) for entry in names)
        if not os.path.islink(candidate)
    )
    # Bytes -> mebibytes
    return byte_count / (1024 * 1024)

def count_parameters(model):
    """Return the total element count over everything ``model.parameters()`` yields."""
    element_total = 0
    for tensor in model.parameters():
        element_total += tensor.numel()
    return element_total

def count_non_zero_parameters(model):
    """Return how many elements of the model's parameters are non-zero."""
    nonzero_total = 0
    for tensor in model.parameters():
        nonzero_total += torch.count_nonzero(tensor).item()
    return nonzero_total

def get_model_sparsity(model):
    """Calculate the sparsity of a model (percentage of zero parameters).

    Returns:
        A value between 0 and 100. A model without any parameters is reported
        as 0.0 instead of raising ``ZeroDivisionError``.
    """
    total_params = count_parameters(model)
    if total_params == 0:
        # Nothing to be sparse about; avoid dividing by zero
        return 0.0
    non_zero_params = count_non_zero_parameters(model)
    return (1 - non_zero_params / total_params) * 100

def analyze_model(model_path, device='cpu'):
    """Collect on-disk size and parameter statistics for a saved Wav2Vec2 CTC model.

    Args:
        model_path: Local directory holding the saved model files.
        device: Device string the loaded model is moved to before its
            parameters are counted (default ``'cpu'``).

    Returns:
        dict with the keys
            ``disk_size``: recursive size of ``model_path`` in MB,
            ``total_params``: element count over all parameters,
            ``non_zero_params``: number of non-zero parameter values,
            ``sparsity``: percentage of parameter values that are zero,
            ``files``: mapping of each top-level file name to its size in MB.

    Raises:
        FileNotFoundError: if ``model_path`` is not an existing directory.
    """
    print(f"Analyzing model at {model_path}...")

    # Fail fast: os.walk silently yields nothing for a missing path, which
    # would report a 0 MB model and only blow up later in os.listdir.
    if not os.path.isdir(model_path):
        raise FileNotFoundError(f"Model directory not found: {model_path}")

    # Measure disk usage (recursive, includes sub-directories)
    disk_size = get_directory_size(model_path)

    # Only the model is needed for the statistics below; the processor is not
    # loaded, so a directory without processor files can still be analyzed.
    model = Wav2Vec2ForCTC.from_pretrained(model_path).to(device)

    # Count parameters once and derive sparsity from the same two numbers
    # rather than traversing every tensor again.
    total_params = count_parameters(model)
    non_zero_params = count_non_zero_parameters(model)
    sparsity = (1 - non_zero_params / total_params) * 100 if total_params else 0.0

    # Per-file sizes (MB) for the top level of the directory only
    model_files = {}
    for file in os.listdir(model_path):
        file_path = os.path.join(model_path, file)
        if os.path.isfile(file_path):
            model_files[file] = os.path.getsize(file_path) / (1024 * 1024)

    return {
        "disk_size": disk_size,
        "total_params": total_params,
        "non_zero_params": non_zero_params,
        "sparsity": sparsity,
        "files": model_files
    }

def format_number(num):
    """Render ``num`` with thousands separators; floats also get two decimals."""
    if isinstance(num, float):
        return format(num, ",.2f")
    return format(num, ",")

def print_model_comparison(original_stats, optimized_stats):
    """Print two tables contrasting the original and optimized model statistics.

    The first table covers disk size, total parameters, non-zero parameters
    and sparsity. The second lists every file name present in either model's
    ``files`` mapping together with its size change; a file missing on one
    side is shown with size 0. Both arguments are dicts with the keys
    ``disk_size``, ``total_params``, ``non_zero_params``, ``sparsity`` and
    ``files``.
    """

    def _reduction(before, after):
        # Absolute drop, plus the drop as a percentage of `before`
        # (0 when `before` is not positive).
        delta = before - after
        pct = (delta / before) * 100 if before > 0 else 0
        return delta, pct

    summary = PrettyTable()
    summary.field_names = ["Metric", "Original Model", "Optimized Model", "Difference", "Reduction %"]

    # Disk footprint row
    before, after = original_stats["disk_size"], optimized_stats["disk_size"]
    delta, pct = _reduction(before, after)
    summary.add_row(
        ["Disk Size (MB)", f"{before:.2f}", f"{after:.2f}", f"{delta:.2f}", f"{pct:.2f}%"]
    )

    # Parameter-count rows share one layout, so drive them from a small table
    count_rows = (
        ("Total Parameters", "total_params"),
        ("Non-Zero Parameters", "non_zero_params"),
    )
    for label, key in count_rows:
        before, after = original_stats[key], optimized_stats[key]
        delta, pct = _reduction(before, after)
        summary.add_row(
            [
                label,
                format_number(before),
                format_number(after),
                format_number(delta),
                f"{pct:.2f}%",
            ]
        )

    # Sparsity row: the difference is optimized minus original, no percentage
    sparsity_change = optimized_stats["sparsity"] - original_stats["sparsity"]
    summary.add_row(
        [
            "Sparsity (%)",
            f"{original_stats['sparsity']:.2f}%",
            f"{optimized_stats['sparsity']:.2f}%",
            f"{sparsity_change:.2f}%",
            "N/A",
        ]
    )

    print(summary)

    # Per-file breakdown over the union of file names from both models
    print("\nFile Size Breakdown (MB):")
    per_file = PrettyTable()
    per_file.field_names = ["File", "Original Size", "Optimized Size", "Difference", "Reduction %"]

    original_files = original_stats["files"]
    optimized_files = optimized_stats["files"]
    for name in sorted(set(original_files.keys()) | set(optimized_files.keys())):
        before = original_files.get(name, 0)
        after = optimized_files.get(name, 0)
        delta, pct = _reduction(before, after)
        per_file.add_row(
            [name, f"{before:.2f}", f"{after:.2f}", f"{delta:.2f}", f"{pct:.2f}%"]
        )

    print(per_file)

def main(
    original_model_path="/content/wav2vec2-quantization_fp16_kd",
    optimized_model_path="/content/wav2vec2-quantization_fp16",
    device=None,
):
    """Analyze two saved models and print a side-by-side size comparison.

    Args:
        original_model_path: Directory of the model reported in the
            "Original" columns. Defaults to the path the script always used.
        optimized_model_path: Directory of the model reported in the
            "Optimized" columns. Defaults to the path the script always used.
        device: Device string passed to ``analyze_model``. When ``None``,
            ``"cuda"`` is chosen if available, otherwise ``"cpu"``.

    Calling ``main()`` with no arguments behaves exactly as before.
    """
    # NOTE(review): by default the "_kd" directory is treated as the original
    # model - confirm that is the intended baseline for the comparison.
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    # Analyze both models
    original_stats = analyze_model(original_model_path, device)
    optimized_stats = analyze_model(optimized_model_path, device)

    # Print comparison
    print("\n=== Model Size Comparison ===")
    print_model_comparison(original_stats, optimized_stats)

if __name__ == "__main__":
    main()

In [None]:
# Install the Hugging Face Hub client and run its login command before the
# push cell that follows. Both lines are shell commands run from the notebook.
# NOTE(review): the install is unpinned here - confirm the intended version.
!pip install huggingface_hub
!huggingface-cli login

Push model to Hugging Face

In [None]:
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor

# Load your model and processor from the local folder.
# torch_dtype="auto" keeps the dtype stored with the checkpoint. Without it,
# from_pretrained loads the weights as float32, so push_to_hub would re-save
# and upload an up-cast copy instead of the weights that are on disk.
local_model_dir = "/content/wav2vec2-quantization_fp16"
model = Wav2Vec2ForCTC.from_pretrained(local_model_dir, torch_dtype="auto")
processor = Wav2Vec2Processor.from_pretrained(local_model_dir)

# Push to the Hub (replace with your HF Hub repository name)
repo_name = "StefanStefan/Wav2Vec-100-CSR-Quantized"
model.push_to_hub(repo_name)
processor.push_to_hub(repo_name)
