# In [None]:
"""
Iteration: 1 - CORAL-Urdu-ASR - CORAL_Iteration1_Confidence_Caliberation.ipynb
==================================================================
Implements:
1. WER/CER computation
2. Confidence calibration metrics (ECE, MCE, Reliability diagrams)
3. Benchmark dataset integration (Mozilla Common Voice Urdu)
4. Baseline establishment across all models
5. Error analysis and categorization
"""

import torch
import numpy as np
import pandas as pd
from pathlib import Path
from typing import List, Tuple, Dict, Optional
import json
from dataclasses import dataclass, asdict
from collections import defaultdict
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import editdistance
from sklearn.metrics import mean_absolute_error, mean_squared_error
import warnings
warnings.filterwarnings('ignore')

# Import your existing wrapper
# from coral import UrduASRWrapper


@dataclass
class TranscriptionResult:
    """Per-utterance evaluation record for a single model.

    ``BenchmarkEvaluator.evaluate_model`` builds one instance for every sample
    it manages to process; ``evaluate_all_models`` flattens the instances with
    ``asdict`` into the results DataFrame.
    """
    audio_id: str  # sample identifier; the evaluator uses the TSV ``path`` value
    reference: str  # ground-truth sentence
    hypothesis: str  # model transcript: the words of ``word_probs`` joined by spaces
    word_probs: List[Tuple[str, float]]  # (word, confidence) pairs from the ASR wrapper
    model_name: str  # key of the model that produced the hypothesis
    wer: float  # word error rate of ``hypothesis`` against ``reference``
    cer: float  # character error rate of ``hypothesis`` against ``reference``
    duration: float  # audio length in seconds (samples / sample rate)
    avg_confidence: float  # mean confidence over ``word_probs``; 0.0 when it is empty


@dataclass
class CalibrationMetrics:
    """Calibration summary for the word-level confidences of one utterance.

    Instances are produced by ``ConfidenceCalibrator.analyze_calibration``.
    """
    ece: float  # Expected Calibration Error
    mce: float  # Maximum Calibration Error
    ace: float  # Average Calibration Error: mean absolute gap between confidence and 0/1 correctness
    brier_score: float  # mean squared gap between confidence and 0/1 correctness
    confidence_accuracy_correlation: float  # correlation coefficient of confidence vs. correctness
    reliability_bins: Dict[str, List[float]]  # keys: 'confidence_bins', 'accuracies', 'counts'


class WERCalculator:
    """Word/character error rates and an error-type breakdown for ASR output."""

    @staticmethod
    def normalize_text(text: str) -> str:
        """Collapse every run of whitespace to one space and trim both ends.

        Case is left untouched; lowercasing would only matter for Roman-Urdu
        transcripts and is intentionally not applied.
        """
        # split() without a separator already discards leading/trailing blanks
        return ' '.join(text.split())

    @staticmethod
    def compute_wer(reference: str, hypothesis: str) -> float:
        """Return the word error rate ``(S + D + I) / N``.

        ``S``, ``D`` and ``I`` are the substitutions, deletions and insertions
        of the word-level edit distance; ``N`` is the number of reference
        words. The result can exceed 1.0 when the hypothesis contains many
        extra words.

        An empty reference yields 0.0 for an empty hypothesis, 1.0 otherwise.
        """
        ref_words = WERCalculator.normalize_text(reference).split()
        hyp_words = WERCalculator.normalize_text(hypothesis).split()

        if not ref_words:
            return 0.0 if not hyp_words else 1.0

        return editdistance.eval(ref_words, hyp_words) / len(ref_words)

    @staticmethod
    def compute_cer(reference: str, hypothesis: str) -> float:
        """Return the character error rate (edit distance / reference length).

        Characters are compared after ``normalize_text``; the single spaces
        left between words count as characters. An empty reference yields 0.0
        for an empty hypothesis, 1.0 otherwise.
        """
        ref_chars = list(WERCalculator.normalize_text(reference))
        hyp_chars = list(WERCalculator.normalize_text(hypothesis))

        if not ref_chars:
            return 0.0 if not hyp_chars else 1.0

        return editdistance.eval(ref_chars, hyp_chars) / len(ref_chars)

    @staticmethod
    def _edit_table(ref_words: List[str], hyp_words: List[str]) -> List[List[int]]:
        """Build the Levenshtein table; cell ``[i][j]`` is the distance between
        ``ref_words[:i]`` and ``hyp_words[:j]``."""
        rows, cols = len(ref_words) + 1, len(hyp_words) + 1
        table = [[0] * cols for _ in range(rows)]
        for i in range(1, rows):
            table[i][0] = i
        for j in range(1, cols):
            table[0][j] = j
        for i in range(1, rows):
            for j in range(1, cols):
                cost = 0 if ref_words[i - 1] == hyp_words[j - 1] else 1
                table[i][j] = min(
                    table[i - 1][j - 1] + cost,  # match / substitution
                    table[i - 1][j] + 1,         # deletion
                    table[i][j - 1] + 1,         # insertion
                )
        return table

    @staticmethod
    def analyze_errors(reference: str, hypothesis: str) -> Dict[str, int]:
        """Count substitutions, deletions, insertions and correct words.

        The counts come from backtracking a minimum-edit-distance alignment of
        the normalized word sequences, so
        ``substitutions + deletions + insertions`` equals the word edit
        distance and ``correct + substitutions + deletions`` equals the number
        of reference words. Deriving the counts from the length difference
        alone (as a shortcut) mislabels shifted hypotheses: "a b c" vs
        "b c d" is one deletion plus one insertion, not two substitutions.

        When several optimal alignments exist, a match/substitution is
        preferred over a deletion, and a deletion over an insertion.

        Returns:
            Dict with the keys 'substitutions', 'deletions', 'insertions'
            and 'correct'.
        """
        ref_words = WERCalculator.normalize_text(reference).split()
        hyp_words = WERCalculator.normalize_text(hypothesis).split()

        errors = {
            'substitutions': 0,
            'deletions': 0,
            'insertions': 0,
            'correct': 0
        }

        table = WERCalculator._edit_table(ref_words, hyp_words)

        # Walk back from the bottom-right corner, classifying every step.
        i, j = len(ref_words), len(hyp_words)
        while i > 0 or j > 0:
            on_diagonal = i > 0 and j > 0
            if (on_diagonal and ref_words[i - 1] == hyp_words[j - 1]
                    and table[i][j] == table[i - 1][j - 1]):
                errors['correct'] += 1
                i -= 1
                j -= 1
            elif on_diagonal and table[i][j] == table[i - 1][j - 1] + 1:
                errors['substitutions'] += 1
                i -= 1
                j -= 1
            elif i > 0 and table[i][j] == table[i - 1][j] + 1:
                errors['deletions'] += 1
                i -= 1
            else:
                errors['insertions'] += 1
                j -= 1

        return errors


class ConfidenceCalibrator:
    """Binned calibration statistics for word-level ASR confidences.

    All methods take ``confidences`` (predicted probability per word) and
    ``accuracies`` (1.0 when the word is correct, 0.0 otherwise) as parallel
    1-D arrays. Degenerate inputs (no words, constant series) yield 0.0
    rather than raising or returning NaN.
    """

    @staticmethod
    def _bin_masks(confidences: np.ndarray, n_bins: int):
        """Yield ``(lower, upper, mask)`` for each equal-width bin over [0, 1].

        Bins are right-closed ``(lower, upper]``. The first bin also admits
        its lower edge so a confidence of exactly 0.0 is counted instead of
        falling outside every bin. Values outside [0, 1] match no bin.
        """
        edges = np.linspace(0, 1, n_bins + 1)
        for index in range(n_bins):
            lower, upper = edges[index], edges[index + 1]
            above = confidences >= lower if index == 0 else confidences > lower
            yield lower, upper, above & (confidences <= upper)

    @staticmethod
    def _aligned_correctness(ref_words: List[str], hyp_words: List[str]) -> List[float]:
        """Return 1.0/0.0 per hypothesis word: 1.0 when a minimum-edit-distance
        alignment pairs the word with an identical reference word."""
        rows, cols = len(ref_words) + 1, len(hyp_words) + 1
        table = [[0] * cols for _ in range(rows)]
        for i in range(1, rows):
            table[i][0] = i
        for j in range(1, cols):
            table[0][j] = j
        for i in range(1, rows):
            for j in range(1, cols):
                cost = 0 if ref_words[i - 1] == hyp_words[j - 1] else 1
                table[i][j] = min(
                    table[i - 1][j - 1] + cost,
                    table[i - 1][j] + 1,
                    table[i][j - 1] + 1,
                )

        flags = [0.0] * len(hyp_words)
        i, j = len(ref_words), len(hyp_words)
        # Hypothesis words still unvisited when the reference runs out are
        # insertions and keep their 0.0 flag.
        while i > 0 and j > 0:
            if (ref_words[i - 1] == hyp_words[j - 1]
                    and table[i][j] == table[i - 1][j - 1]):
                flags[j - 1] = 1.0
                i -= 1
                j -= 1
            elif table[i][j] == table[i - 1][j - 1] + 1:  # substitution
                i -= 1
                j -= 1
            elif table[i][j] == table[i - 1][j] + 1:      # deletion
                i -= 1
            else:                                         # insertion
                j -= 1
        return flags

    @staticmethod
    def compute_ece(confidences: np.ndarray, accuracies: np.ndarray, n_bins: int = 10) -> float:
        """
        Expected Calibration Error (ECE)

        Sum over bins of ``|mean confidence - mean accuracy|`` weighted by the
        fraction of all samples that fall in the bin. Returns 0.0 for empty
        input.
        """
        confidences = np.asarray(confidences, dtype=float)
        accuracies = np.asarray(accuracies, dtype=float)
        if confidences.size == 0:
            return 0.0

        ece = 0.0
        for _, _, in_bin in ConfidenceCalibrator._bin_masks(confidences, n_bins):
            if in_bin.any():
                gap = abs(confidences[in_bin].mean() - accuracies[in_bin].mean())
                ece += gap * in_bin.mean()
        return float(ece)

    @staticmethod
    def compute_mce(confidences: np.ndarray, accuracies: np.ndarray, n_bins: int = 10) -> float:
        """Maximum Calibration Error (MCE): the largest per-bin gap between
        mean confidence and mean accuracy; 0.0 when no bin is populated."""
        confidences = np.asarray(confidences, dtype=float)
        accuracies = np.asarray(accuracies, dtype=float)

        gaps = [
            abs(confidences[in_bin].mean() - accuracies[in_bin].mean())
            for _, _, in_bin in ConfidenceCalibrator._bin_masks(confidences, n_bins)
            if in_bin.any()
        ]
        return float(max(gaps)) if gaps else 0.0

    @staticmethod
    def compute_brier_score(confidences: np.ndarray, accuracies: np.ndarray) -> float:
        """Brier score: mean squared difference between confidence and 0/1
        correctness. Returns 0.0 for empty input."""
        confidences = np.asarray(confidences, dtype=float)
        accuracies = np.asarray(accuracies, dtype=float)
        if confidences.size == 0:
            return 0.0
        return float(np.mean((accuracies - confidences) ** 2))

    @staticmethod
    def reliability_diagram_data(confidences: np.ndarray, accuracies: np.ndarray, n_bins: int = 10) -> Dict:
        """Prepare data for a reliability diagram.

        Returns:
            Dict with parallel lists, one entry per non-empty bin:
            'confidence_bins' (bin midpoint), 'accuracies' (mean accuracy in
            the bin) and 'counts' (number of samples in the bin).
        """
        confidences = np.asarray(confidences, dtype=float)
        accuracies = np.asarray(accuracies, dtype=float)
        bin_data = {
            'confidence_bins': [],
            'accuracies': [],
            'counts': []
        }

        for lower, upper, in_bin in ConfidenceCalibrator._bin_masks(confidences, n_bins):
            if in_bin.any():
                bin_data['confidence_bins'].append(float((lower + upper) / 2))
                bin_data['accuracies'].append(float(accuracies[in_bin].mean()))
                bin_data['counts'].append(int(in_bin.sum()))

        return bin_data

    @staticmethod
    def analyze_calibration(word_probs: List[Tuple[str, float]],
                          reference: str,
                          hypothesis: str) -> "CalibrationMetrics":
        """Complete calibration analysis for one utterance.

        Args:
            word_probs: ``(word, confidence)`` pairs of the hypothesis.
            reference: Ground-truth sentence; split on whitespace.
            hypothesis: Kept for interface compatibility; the hypothesis
                words are taken from ``word_probs``.

        Returns:
            ``CalibrationMetrics``. Every scalar is 0.0 and the reliability
            lists are empty when ``word_probs`` is empty.

        A hypothesis word counts as correct when the edit-distance alignment
        pairs it with an identical reference word. Comparing by position
        would mark every word after the first insertion or deletion as wrong.
        """
        if not word_probs:
            # Nothing was predicted; there is nothing to calibrate.
            return CalibrationMetrics(
                ece=0.0,
                mce=0.0,
                ace=0.0,
                brier_score=0.0,
                confidence_accuracy_correlation=0.0,
                reliability_bins={'confidence_bins': [], 'accuracies': [], 'counts': []}
            )

        ref_words = reference.split()
        hyp_words = [word for word, _ in word_probs]

        confidences = np.array([conf for _, conf in word_probs], dtype=float)
        accuracies = np.array(
            ConfidenceCalibrator._aligned_correctness(ref_words, hyp_words),
            dtype=float
        )

        ece = ConfidenceCalibrator.compute_ece(confidences, accuracies)
        mce = ConfidenceCalibrator.compute_mce(confidences, accuracies)
        ace = float(np.mean(np.abs(accuracies - confidences)))
        brier = ConfidenceCalibrator.compute_brier_score(confidences, accuracies)

        # Pearson correlation is undefined for fewer than two points or when
        # either series is constant; report 0.0 in those cases.
        if confidences.size > 1 and confidences.std() > 0 and accuracies.std() > 0:
            correlation = float(np.corrcoef(confidences, accuracies)[0, 1])
        else:
            correlation = 0.0

        reliability_bins = ConfidenceCalibrator.reliability_diagram_data(confidences, accuracies)

        return CalibrationMetrics(
            ece=ece,
            mce=mce,
            ace=ace,
            brier_score=brier,
            confidence_accuracy_correlation=correlation,
            reliability_bins=reliability_bins
        )


class BenchmarkEvaluator:
    """Run ASR models over a benchmark set and persist metrics, plots and a report.

    Every artefact is written below ``output_dir``. Per-sample results
    accumulate in ``self.results`` across calls to ``evaluate_all_models``.
    """

    def __init__(self, asr_wrapper, output_dir: str = "./eval_results"):
        """
        Args:
            asr_wrapper: Object exposing
                ``word_probabilities(audio_path, model_name)`` that returns a
                list of ``(word, confidence)`` pairs.
            output_dir: Directory for all outputs; created when missing.
        """
        self.asr_wrapper = asr_wrapper
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(exist_ok=True, parents=True)
        self.results = []

    def load_common_voice_dataset(self, dataset_path: str, split: str = "test", max_samples: int = 100):
        """
        Load Mozilla Common Voice Urdu dataset
        Expected structure:
        dataset_path/
            clips/
                audio1.mp3
                audio2.mp3
            test.tsv (or validated.tsv)

        Args:
            dataset_path: Directory that contains ``clips/`` and the TSV files.
            split: TSV base name to read; ``validated.tsv`` is the fallback.
            max_samples: Maximum number of samples to return. Rows whose
                audio file is missing are skipped and do not count.

        Returns:
            List of dicts with the keys 'audio_id', 'audio_path', 'reference'.

        Raises:
            FileNotFoundError: Neither ``<split>.tsv`` nor ``validated.tsv``
                exists under ``dataset_path``.
        """
        import csv

        dataset_path = Path(dataset_path)

        tsv_file = dataset_path / f"{split}.tsv"
        if not tsv_file.exists():
            tsv_file = dataset_path / "validated.tsv"
        if not tsv_file.exists():
            raise FileNotFoundError(
                f"No {split}.tsv or validated.tsv found in {dataset_path}; "
                "dataset_path must be the directory that holds the TSV files and clips/"
            )

        clips_dir = dataset_path / "clips"
        samples = []
        with open(tsv_file, 'r', encoding='utf-8', newline='') as f:
            # QUOTE_NONE: sentences may contain literal double quotes, which
            # the default dialect would treat as field delimiters.
            reader = csv.DictReader(f, delimiter='\t', quoting=csv.QUOTE_NONE)
            for row in reader:
                if len(samples) >= max_samples:
                    break

                audio_path = clips_dir / row['path']
                if audio_path.exists():
                    samples.append({
                        'audio_id': row['path'],
                        'audio_path': str(audio_path),
                        'reference': row['sentence']
                    })

        return samples

    def evaluate_model(self, model_name: str, test_samples: List[Dict]) -> "List[TranscriptionResult]":
        """Evaluate a single model on ``test_samples``.

        A sample that raises during transcription or scoring is reported on
        stdout and skipped, so one bad clip does not abort the run.

        Returns:
            One ``TranscriptionResult`` per successfully processed sample.
        """
        # Imported here (once per call) because it is only needed to measure
        # clip duration.
        import librosa

        results = []

        print(f"\nEvaluating {model_name}...")

        for sample in tqdm(test_samples, desc=f"Processing {model_name}"):
            try:
                word_probs = self.asr_wrapper.word_probabilities(
                    sample['audio_path'],
                    model_name
                )

                hypothesis = ' '.join([w for w, p in word_probs])

                wer = WERCalculator.compute_wer(sample['reference'], hypothesis)
                cer = WERCalculator.compute_cer(sample['reference'], hypothesis)
                avg_conf = np.mean([p for w, p in word_probs]) if word_probs else 0.0

                # Duration in seconds = decoded samples / sample rate
                audio, sr = librosa.load(sample['audio_path'], sr=16000)
                duration = len(audio) / sr

                results.append(TranscriptionResult(
                    audio_id=sample['audio_id'],
                    reference=sample['reference'],
                    hypothesis=hypothesis,
                    word_probs=word_probs,
                    model_name=model_name,
                    wer=wer,
                    cer=cer,
                    duration=duration,
                    avg_confidence=avg_conf
                ))

            except Exception as e:
                # Deliberate best-effort: log and move on to the next clip.
                print(f"Error processing {sample['audio_id']}: {str(e)}")
                continue

        return results

    def evaluate_all_models(self, test_samples: List[Dict], models: List[str]) -> pd.DataFrame:
        """Evaluate every model in ``models`` and compile the results.

        Results are appended to ``self.results`` and the rows produced by
        this call are written to ``evaluation_results.csv``.

        Returns:
            DataFrame with one row per (model, sample) that succeeded.
        """
        all_results = []

        for model in models:
            model_results = self.evaluate_model(model, test_samples)
            all_results.extend(model_results)
            self.results.extend(model_results)

        df = pd.DataFrame([asdict(r) for r in all_results])

        output_file = self.output_dir / "evaluation_results.csv"
        df.to_csv(output_file, index=False, encoding='utf-8')
        print(f"\nResults saved to {output_file}")

        return df

    def compute_aggregate_metrics(self, df: pd.DataFrame) -> pd.DataFrame:
        """Aggregate WER, CER, confidence and duration per model.

        Writes ``aggregate_metrics.csv`` and returns the aggregated frame
        (values rounded to 4 decimals).
        """
        aggregate = df.groupby('model_name').agg({
            'wer': ['mean', 'std', 'min', 'max'],
            'cer': ['mean', 'std', 'min', 'max'],
            'avg_confidence': ['mean', 'std'],
            'duration': 'sum'
        }).round(4)

        output_file = self.output_dir / "aggregate_metrics.csv"
        aggregate.to_csv(output_file)
        print(f"Aggregate metrics saved to {output_file}")

        return aggregate

    def analyze_confidence_calibration_all(self) -> pd.DataFrame:
        """Compute calibration metrics for every stored result, per model.

        Writes ``calibration_metrics.csv`` and returns the mean/std of each
        metric grouped by model. Returns an empty DataFrame (and writes
        nothing) when no results have been collected yet.
        """
        if not self.results:
            print("No results available for calibration analysis")
            return pd.DataFrame()

        calibration_results = []

        for result in self.results:
            cal_metrics = ConfidenceCalibrator.analyze_calibration(
                result.word_probs,
                result.reference,
                result.hypothesis
            )

            calibration_results.append({
                'model_name': result.model_name,
                'audio_id': result.audio_id,
                'ece': cal_metrics.ece,
                'mce': cal_metrics.mce,
                'ace': cal_metrics.ace,
                'brier_score': cal_metrics.brier_score,
                'correlation': cal_metrics.confidence_accuracy_correlation
            })

        df_cal = pd.DataFrame(calibration_results)

        cal_aggregate = df_cal.groupby('model_name').agg({
            'ece': ['mean', 'std'],
            'mce': ['mean', 'std'],
            'ace': ['mean', 'std'],
            'brier_score': ['mean', 'std'],
            'correlation': ['mean', 'std']
        }).round(4)

        output_file = self.output_dir / "calibration_metrics.csv"
        cal_aggregate.to_csv(output_file)
        print(f"Calibration metrics saved to {output_file}")

        return cal_aggregate

    def generate_visualizations(self, df: pd.DataFrame):
        """Write the WER comparison, WER-vs-confidence and WER distribution
        plots as PNG files into ``output_dir``."""
        # 1. WER comparison
        plt.figure(figsize=(12, 6))
        model_wer = df.groupby('model_name')['wer'].mean().sort_values()
        plt.barh(model_wer.index, model_wer.values)
        plt.xlabel('Word Error Rate (WER)')
        plt.title('Model Comparison: Average WER')
        plt.tight_layout()
        plt.savefig(self.output_dir / 'wer_comparison.png', dpi=300, bbox_inches='tight')
        plt.close()

        # 2. WER vs Confidence scatter
        plt.figure(figsize=(10, 6))
        for model in df['model_name'].unique():
            model_data = df[df['model_name'] == model]
            plt.scatter(model_data['avg_confidence'], model_data['wer'],
                       label=model, alpha=0.6)
        plt.xlabel('Average Confidence')
        plt.ylabel('WER')
        plt.title('WER vs Confidence by Model')
        plt.legend()
        plt.tight_layout()
        plt.savefig(self.output_dir / 'wer_vs_confidence.png', dpi=300, bbox_inches='tight')
        plt.close()

        # 3. Error distribution
        # Draw on an explicit Axes: without ``ax`` DataFrame.boxplot opens its
        # own figure, leaving a blank one behind that is never closed.
        fig, ax = plt.subplots(figsize=(12, 6))
        df.boxplot(column='wer', by='model_name', ax=ax)
        ax.set_ylabel('WER')
        ax.set_title('WER Distribution by Model')
        fig.suptitle('')  # drop the automatic "Boxplot grouped by ..." heading
        ax.tick_params(axis='x', labelrotation=45)
        fig.tight_layout()
        fig.savefig(self.output_dir / 'wer_distribution.png', dpi=300, bbox_inches='tight')
        plt.close(fig)

        print(f"Visualizations saved to {self.output_dir}")

    def generate_report(self, df: pd.DataFrame):
        """Write ``iteration1_report.txt`` summarising ``df``.

        The report re-runs ``compute_aggregate_metrics`` and
        ``analyze_confidence_calibration_all``, so their CSV files are
        rewritten as a side effect.
        """
        report_path = self.output_dir / "iteration1_report.txt"

        # Per-model means are needed for both the winner and its score.
        mean_wer = df.groupby('model_name')['wer'].mean()
        mean_cer = df.groupby('model_name')['cer'].mean()

        with open(report_path, 'w', encoding='utf-8') as f:
            f.write("="*80 + "\n")
            f.write("CORAL ITERATION 1: BASELINE EVALUATION REPORT\n")
            f.write("="*80 + "\n\n")

            f.write(f"Total samples evaluated: {len(df)}\n")
            f.write(f"Models evaluated: {df['model_name'].nunique()}\n")
            f.write(f"Total audio duration: {df['duration'].sum():.2f} seconds\n\n")

            f.write("-"*80 + "\n")
            f.write("AGGREGATE METRICS BY MODEL\n")
            f.write("-"*80 + "\n\n")

            aggregate = self.compute_aggregate_metrics(df)
            f.write(aggregate.to_string())
            f.write("\n\n")

            f.write("-"*80 + "\n")
            f.write("BEST PERFORMING MODELS\n")
            f.write("-"*80 + "\n\n")

            f.write(f"Best WER: {mean_wer.idxmin()} ({mean_wer.min():.4f})\n\n")
            f.write(f"Best CER: {mean_cer.idxmin()} ({mean_cer.min():.4f})\n\n")

            f.write("-"*80 + "\n")
            f.write("CALIBRATION ANALYSIS\n")
            f.write("-"*80 + "\n\n")

            cal_metrics = self.analyze_confidence_calibration_all()
            f.write(cal_metrics.to_string())
            f.write("\n\n")

            f.write("="*80 + "\n")
            f.write("Report generation complete\n")
            f.write("="*80 + "\n")

        print(f"\nComprehensive report saved to {report_path}")


# ============================================================================
# USAGE EXAMPLE
# ============================================================================

def run_iteration1_evaluation(dataset_path: str, max_samples: int = 50):
    """
    Complete Iteration 1 evaluation pipeline

    Args:
        dataset_path: Path to Common Voice Urdu dataset
        max_samples: Number of test samples (start small for testing)

    NOTE(review): this ``def`` contains nothing but the docstring. The
    column-0 code that follows (the transformers import and the
    UrduASRWrapper class) closes the function here; the pipeline statements
    sit after that class.
    """
# ============================================================================
# UrduASRWrapper Class Definition (from CORAL_Iteration1_ASR_Ensemble.ipynb)
# ============================================================================

from transformers import (
    WhisperProcessor, WhisperForConditionalGeneration,
    Wav2Vec2Processor, Wav2Vec2ForCTC,
    SeamlessM4TForSpeechToText, SeamlessM4TProcessor,
    AutoProcessor, AutoModelForCTC
)

class UrduASRWrapper:
    """Unified wrapper for multiple Urdu ASR models.

    One model is resident at a time. ``word_probabilities`` loads the
    requested model on first use and keeps it loaded for later calls with the
    same name, so evaluating many clips with one model loads the weights
    once. Requesting a different model, or any failure, releases the current
    one; call ``_cleanup()`` to free memory explicitly.

    The confidences returned are utterance-level: every word of a transcript
    receives the same value (see the ``_extract_*`` methods).
    """

    SUPPORTED_MODELS = {
        "whisper-large": "openai/whisper-large-v3",
        "whisper-medium": "openai/whisper-medium",
        "whisper-small": "openai/whisper-small",
        "seamless-large": "facebook/seamless-m4t-v2-large",
        "seamless-medium": "facebook/seamless-m4t-medium",
        "mms-1b": "facebook/mms-1b-all",
        "mms-300m": "facebook/mms-300m",
        "wav2vec2-urdu": "kingabzpro/wav2vec2-large-xls-r-300m-Urdu"
    }

    def __init__(self, device: Optional[str] = None):
        """
        Args:
            device: Torch device string. ``None`` selects "cuda" when
                available, otherwise "cpu".
        """
        if device is None:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        else:
            self.device = device

        print(f"ASR Wrapper initialized on: {self.device}")

        self.current_model = None
        self.processor = None
        self.current_model_name = None

    def _preprocess_audio(self, file_path: str, target_sr: int = 16000) -> np.ndarray:
        """Load ``file_path`` as mono float32 at ``target_sr`` and peak-normalize it.

        Raises:
            ValueError: The file could not be loaded or decoded.
        """
        # Local import: librosa is not imported at module level in this file.
        import librosa

        try:
            audio, sr = librosa.load(file_path, sr=target_sr, mono=True)

            if audio.dtype != np.float32:
                audio = audio.astype(np.float32)

            # Scale so the loudest sample has magnitude 1; silence is left as is.
            max_val = np.abs(audio).max()
            if max_val > 0:
                audio = audio / max_val

            return audio

        except Exception as e:
            raise ValueError(f"Error loading audio file {file_path}: {str(e)}") from e

    def _load_model(self, model_name: str):
        """Instantiate processor and model for ``model_name`` on ``self.device``.

        Raises:
            ValueError: ``model_name`` is not a key of ``SUPPORTED_MODELS``.
            RuntimeError: Loading the processor or the weights failed.
        """
        if model_name not in self.SUPPORTED_MODELS:
            raise ValueError(f"Model {model_name} not supported. Choose from: {list(self.SUPPORTED_MODELS.keys())}")

        model_id = self.SUPPORTED_MODELS[model_name]
        print(f"Loading {model_name} ({model_id})...")

        try:
            if "whisper" in model_name:
                self.processor = WhisperProcessor.from_pretrained(model_id)
                self.current_model = WhisperForConditionalGeneration.from_pretrained(model_id)

            elif "seamless" in model_name:
                self.processor = SeamlessM4TProcessor.from_pretrained(model_id)
                self.current_model = SeamlessM4TForSpeechToText.from_pretrained(model_id)

            elif "mms" in model_name:
                # NOTE(review): no target language/adapter is selected for the
                # multilingual MMS checkpoints - confirm the default
                # vocabulary is the intended one for Urdu.
                self.processor = AutoProcessor.from_pretrained(model_id)
                self.current_model = AutoModelForCTC.from_pretrained(model_id)

            elif "wav2vec2" in model_name:
                self.processor = Wav2Vec2Processor.from_pretrained(model_id)
                self.current_model = Wav2Vec2ForCTC.from_pretrained(model_id)

            self.current_model = self.current_model.to(self.device)
            self.current_model.eval()
            self.current_model_name = model_name

            print(f"{model_name} loaded successfully")

        except Exception as e:
            raise RuntimeError(f"Failed to load model {model_name}: {str(e)}") from e

    @staticmethod
    def _uniform_word_confidences(transcription: str, scores, fallback: float) -> List[Tuple[str, float]]:
        """Pair every word of ``transcription`` with one shared confidence.

        The confidence is the mean, over generation steps, of the highest
        softmax probability of each step. ``fallback`` is used when the
        generation output carries no scores.
        """
        words = transcription.strip().split()
        if not scores:
            return [(word, fallback) for word in words]

        step_peaks = [torch.softmax(score, dim=-1).max().item() for score in scores]
        # Plain float so the value serializes cleanly downstream.
        avg_prob = float(np.mean(step_peaks))
        return [(word, avg_prob) for word in words]

    def _extract_whisper_probabilities(self, audio_array: np.ndarray) -> List[Tuple[str, float]]:
        """Transcribe with the loaded Whisper model.

        Returns ``(word, confidence)`` pairs; the confidence is identical for
        all words (0.8 when no generation scores are available).
        """
        input_features = self.processor(
            audio_array,
            sampling_rate=16000,
            return_tensors="pt"
        ).input_features.to(self.device)

        # NOTE(review): no language/task is passed to generate(), so both are
        # left to the model's defaults - confirm Urdu is actually selected.
        with torch.no_grad():
            predicted_ids = self.current_model.generate(
                input_features,
                return_dict_in_generate=True,
                output_scores=True
            )

        transcription = self.processor.batch_decode(
            predicted_ids.sequences,
            skip_special_tokens=True
        )[0]

        return self._uniform_word_confidences(
            transcription, getattr(predicted_ids, 'scores', None), 0.8
        )

    def _extract_ctc_probabilities(self, audio_array: np.ndarray) -> List[Tuple[str, float]]:
        """Transcribe with the loaded CTC model (MMS / wav2vec2).

        Decoding is greedy (arg-max per frame). The confidence shared by all
        words is the mean of the per-frame maximum softmax probability, taken
        over every frame of the utterance.
        """
        inputs = self.processor(
            audio_array,
            sampling_rate=16000,
            return_tensors="pt",
            padding=True
        )

        input_values = inputs.input_values.to(self.device)

        with torch.no_grad():
            logits = self.current_model(input_values).logits

        probs = torch.softmax(logits, dim=-1)
        predicted_ids = torch.argmax(logits, dim=-1)
        transcription = self.processor.batch_decode(predicted_ids)[0]

        words = transcription.strip().split()
        if not words:
            return []

        avg_confidence = probs.max(dim=-1).values.squeeze().mean().item()
        return [(word, avg_confidence) for word in words]

    def _extract_seamless_probabilities(self, audio_array: np.ndarray) -> List[Tuple[str, float]]:
        """Transcribe with the loaded SeamlessM4T model (target language "urd").

        Returns ``(word, confidence)`` pairs; the confidence is identical for
        all words (0.7 when no generation scores are available).
        """
        audio_inputs = self.processor(
            audios=audio_array,
            sampling_rate=16000,
            return_tensors="pt"
        ).to(self.device)

        with torch.no_grad():
            output = self.current_model.generate(
                **audio_inputs,
                tgt_lang="urd",
                return_dict_in_generate=True,
                output_scores=True
            )

        transcription = self.processor.decode(
            output.sequences[0].tolist(),
            skip_special_tokens=True
        )

        return self._uniform_word_confidences(
            transcription, getattr(output, 'scores', None), 0.7
        )

    def _cleanup(self):
        """Drop the loaded model and processor and release cached memory."""
        # Local import: gc is not imported at module level in this file.
        import gc

        if self.current_model is not None:
            del self.current_model
            self.current_model = None

        if self.processor is not None:
            del self.processor
            self.processor = None

        self.current_model_name = None

        if self.device == "cuda":
            torch.cuda.empty_cache()
        gc.collect()

    def word_probabilities(self, audio_file_path: str, model_name: str) -> List[Tuple[str, float]]:
        """Transcribe ``audio_file_path`` with ``model_name``.

        Args:
            audio_file_path: Path to an audio file readable by librosa.
            model_name: Key of ``SUPPORTED_MODELS``.

        Returns:
            List of ``(word, confidence)`` pairs in transcript order.

        Raises:
            RuntimeError: Any failure while loading the audio, loading the
                model or decoding; the original exception is chained.
        """
        try:
            audio_array = self._preprocess_audio(audio_file_path)

            # Reuse the resident model; reloading the weights for every clip
            # dominates the run time of an evaluation.
            if self.current_model is None or self.current_model_name != model_name:
                self._cleanup()
                self._load_model(model_name)

            if "whisper" in model_name:
                results = self._extract_whisper_probabilities(audio_array)
            elif "mms" in model_name or "wav2vec2" in model_name:
                results = self._extract_ctc_probabilities(audio_array)
            elif "seamless" in model_name:
                results = self._extract_seamless_probabilities(audio_array)
            else:
                raise ValueError(f"Unknown model type: {model_name}")

            return results

        except Exception as e:
            # Leave no half-initialised model behind after a failure.
            self._cleanup()
            raise RuntimeError(f"Error processing audio with {model_name}: {str(e)}") from e

# ============================================================================
# End of UrduASRWrapper Class
# ============================================================================


def run_iteration1_evaluation(dataset_path: str, max_samples: int = 50,
                              models: Optional[List[str]] = None,
                              device: Optional[str] = 'cpu'):
    """
    Complete Iteration 1 evaluation pipeline

    Loads the test samples, evaluates every model, then writes the aggregate
    metrics, calibration metrics, plots and the text report into
    ``./iteration1_results``.

    This is the effective definition: it follows UrduASRWrapper at module
    level and replaces the docstring-only ``def`` of the same name that
    appears before the wrapper class.

    Args:
        dataset_path: Path to Common Voice Urdu dataset (the directory that
            holds ``clips/`` and the TSV files)
        max_samples: Number of test samples (start small for testing)
        models: Model keys to evaluate; defaults to the smaller models
            "whisper-small", "whisper-medium", "wav2vec2-urdu", "mms-300m"
        device: Device handed to UrduASRWrapper; ``None`` lets the wrapper
            pick "cuda" when available

    Returns:
        DataFrame of per-sample results (empty when no clip could be
        transcribed).

    Raises:
        ValueError: No test sample could be loaded from ``dataset_path``.
    """
    # Initialize
    asr_wrapper = UrduASRWrapper(device=device)
    evaluator = BenchmarkEvaluator(asr_wrapper, output_dir="./iteration1_results")

    # Load test data
    print("Loading test dataset...")
    test_samples = evaluator.load_common_voice_dataset(dataset_path, max_samples=max_samples)
    print(f"Loaded {len(test_samples)} test samples")
    if not test_samples:
        # Every later step groups by model and would fail on an empty frame.
        raise ValueError(f"No test samples could be loaded from {dataset_path}")

    # Models to evaluate (start with smaller models for speed)
    if models is None:
        models_to_test = [
            "whisper-small",
            "whisper-medium",
            "wav2vec2-urdu",
            "mms-300m"
        ]
    else:
        models_to_test = list(models)

    # Run evaluation
    print("\nStarting evaluation...")
    results_df = evaluator.evaluate_all_models(test_samples, models_to_test)
    if results_df.empty:
        print("\nNo sample could be transcribed; skipping metrics, plots and report.")
        return results_df

    # Compute metrics
    print("\nComputing aggregate metrics...")
    evaluator.compute_aggregate_metrics(results_df)

    # Analyze calibration
    print("\nAnalyzing confidence calibration...")
    evaluator.analyze_confidence_calibration_all()

    # Generate visualizations
    print("\nGenerating visualizations...")
    evaluator.generate_visualizations(results_df)

    # Generate report
    print("\nGenerating final report...")
    evaluator.generate_report(results_df)

    print("\n" + "="*80)
    print("ITERATION 1 EVALUATION COMPLETE!")
    print("="*80)
    print(f"Results directory: {evaluator.output_dir}")
    print("\nDeliverables:")
    print("  - evaluation_results.csv: Detailed per-sample results")
    print("  - aggregate_metrics.csv: Model performance summary")
    print("  - calibration_metrics.csv: Confidence calibration analysis")
    print("  - wer_comparison.png: Model WER comparison")
    print("  - wer_vs_confidence.png: Confidence vs accuracy analysis")
    print("  - wer_distribution.png: Error distribution visualization")
    print("  - iteration1_report.txt: Comprehensive evaluation report")
    print("="*80)

    return results_df


if __name__ == "__main__":
    # Example usage
    # The pipeline needs the dataset root - the directory that contains the
    # ``clips/`` folder and the ``test.tsv`` / ``validated.tsv`` metadata -
    # not the path of a single clip.
    dataset_path = r"C:/Users/Nouman Hafeez\Desktop/CORAL-Urdu-ASR\dataset/cv-corpus-22.0-delta-2025-06-20/ur"

    # Run evaluation with 50 samples (increase for final evaluation)
    results = run_iteration1_evaluation(dataset_path, max_samples=50)

# Notebook cell output: ModuleNotFoundError: No module named 'coral'