# In [None]:
"""
Iteration: 1 - CORAL-Urdu-ASR - CORAL_Iteration1_Complete_Pipeline.ipynb
====================================================
This notebook integrates your existing ASR wrapper with comprehensive
evaluation, calibration, and baseline establishment.

Run this on Kaggle with the Mozilla Common Voice Urdu dataset attached.
"""

# ============================================================================
# SECTION 1: SETUP & IMPORTS
# ============================================================================

import torch
import gc
import librosa
import soundfile as sf
import numpy as np
import pandas as pd
from pathlib import Path
from typing import List, Tuple, Dict
import warnings
import json
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
from dataclasses import dataclass, asdict
import editdistance
from collections import defaultdict

# Suppress every warning globally so the notebook output stays readable.
warnings.filterwarnings('ignore')

# Plot defaults shared by all figures generated later in the notebook.
sns.set_style("whitegrid")
plt.rcParams.update({'figure.dpi': 100})

# Environment sanity report: imports succeeded and which accelerator torch sees.
print("✅ Imports successful")
print(f"🔧 Device: {'CUDA' if torch.cuda.is_available() else 'CPU'}")

# ============================================================================
# SECTION 2: YOUR EXISTING ASR WRAPPER (Integrated)
# ============================================================================

# Copy your UrduASRWrapper class here from CORAL_Iteration1_ASR_Ensemble.ipynb
# Or import it if saved as a module

from transformers import (
    WhisperProcessor, 
    WhisperForConditionalGeneration,
    Wav2Vec2Processor, 
    Wav2Vec2ForCTC,
    SeamlessM4TForSpeechToText,
    SeamlessM4TProcessor,
    AutoProcessor,
    AutoModelForCTC
)

class UrduASRWrapper:
    """Run several Hugging Face ASR checkpoints on Urdu audio behind one API.

    ``word_probabilities(path, model_name)`` loads the audio, runs the chosen
    model and returns ``(word, confidence)`` pairs. Every word of an utterance
    currently receives the same utterance-level confidence (the mean over
    decoding steps / frames of the largest softmax probability), not a genuine
    per-word score.

    The most recently used model stays loaded and is reused while the same
    ``model_name`` is requested, so transcribing many clips does not re-read the
    checkpoint for every clip. Requesting a different model frees the previous
    one first; call ``unload()`` or pass ``keep_loaded=False`` to free memory
    explicitly.
    """

    SUPPORTED_MODELS = {
        "whisper-large": "openai/whisper-large-v3",
        "whisper-medium": "openai/whisper-medium",
        "whisper-small": "openai/whisper-small",
        "seamless-large": "facebook/seamless-m4t-v2-large",
        "seamless-medium": "facebook/seamless-m4t-medium",
        "mms-1b": "facebook/mms-1b-all",
        # NOTE(review): facebook/mms-300m appears to be a pretrained-only
        # checkpoint (no CTC vocabulary/tokenizer); confirm it can transcribe.
        "mms-300m": "facebook/mms-300m",
        "wav2vec2-urdu": "kingabzpro/wav2vec2-large-xls-r-300m-Urdu"
    }

    # Language settings that pin the multilingual checkpoints to Urdu output.
    WHISPER_LANGUAGE = "urdu"              # Whisper `generate(language=...)`
    MMS_TARGET_LANG = "urd-script_arabic"  # MMS-1b-all adapter / vocabulary id
    SEAMLESS_TGT_LANG = "urd"              # SeamlessM4T `generate(tgt_lang=...)`

    def __init__(self, device: str = None):
        """Select the torch device (``None`` -> CUDA if available, else CPU); no model is loaded yet."""
        if device is None:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        else:
            self.device = device

        print(f"🚀 ASR Wrapper initialized on: {self.device}")

        self.current_model = None
        self.processor = None
        self.current_model_name = None

    def _preprocess_audio(self, file_path: str, target_sr: int = 16000) -> np.ndarray:
        """Load *file_path* as mono float32 at *target_sr* Hz, peak-normalised to [-1, 1].

        Raises:
            ValueError: if the file cannot be loaded or normalised (including
                an empty signal, whose peak cannot be computed).
        """
        try:
            audio, _ = librosa.load(file_path, sr=target_sr, mono=True)

            if audio.dtype != np.float32:
                audio = audio.astype(np.float32)

            max_val = np.abs(audio).max()
            if max_val > 0:
                audio = audio / max_val

            return audio

        except Exception as e:
            raise ValueError(f"Error loading audio file {file_path}: {str(e)}") from e

    def _load_model(self, model_name: str):
        """Make *model_name* the active model, reusing it when it is already loaded.

        Raises:
            ValueError: if *model_name* is not a key of ``SUPPORTED_MODELS``.
            RuntimeError: if instantiating the processor or model fails.
        """
        if model_name not in self.SUPPORTED_MODELS:
            raise ValueError(f"Model {model_name} not supported")

        # Cache hit: avoid re-reading a multi-hundred-MB checkpoint per clip.
        if self.current_model is not None and self.current_model_name == model_name:
            return

        # Release any previously loaded (different) model before loading this one.
        self._cleanup()

        model_id = self.SUPPORTED_MODELS[model_name]

        try:
            if "whisper" in model_name:
                self.processor = WhisperProcessor.from_pretrained(model_id)
                self.current_model = WhisperForConditionalGeneration.from_pretrained(model_id)

            elif "seamless" in model_name:
                self.processor = SeamlessM4TProcessor.from_pretrained(model_id)
                if "seamless-m4t-v2" in model_id:
                    # v2 checkpoints are published for the SeamlessM4Tv2* classes,
                    # not the v1 SeamlessM4TForSpeechToText.
                    from transformers import SeamlessM4Tv2ForSpeechToText
                    self.current_model = SeamlessM4Tv2ForSpeechToText.from_pretrained(model_id)
                else:
                    self.current_model = SeamlessM4TForSpeechToText.from_pretrained(model_id)

            elif model_name == "mms-1b":
                # mms-1b-all selects its language via adapters; without target_lang
                # it loads the default (English) adapter and vocabulary.
                self.processor = AutoProcessor.from_pretrained(
                    model_id, target_lang=self.MMS_TARGET_LANG
                )
                self.current_model = Wav2Vec2ForCTC.from_pretrained(
                    model_id,
                    target_lang=self.MMS_TARGET_LANG,
                    ignore_mismatched_sizes=True,
                )

            elif "mms" in model_name:
                self.processor = AutoProcessor.from_pretrained(model_id)
                self.current_model = AutoModelForCTC.from_pretrained(model_id)

            elif "wav2vec2" in model_name:
                self.processor = Wav2Vec2Processor.from_pretrained(model_id)
                self.current_model = Wav2Vec2ForCTC.from_pretrained(model_id)

            self.current_model = self.current_model.to(self.device)
            self.current_model.eval()
            self.current_model_name = model_name

        except Exception as e:
            # Do not leave a half-initialised processor/model behind.
            self._cleanup()
            raise RuntimeError(f"Failed to load model {model_name}: {str(e)}") from e

    @staticmethod
    def _words_with_mean_step_confidence(transcription: str, scores,
                                         default_conf: float) -> List[Tuple[str, float]]:
        """Pair each whitespace-separated word of *transcription* with one confidence.

        The confidence is the mean, over generation steps in *scores*, of the
        largest softmax probability at that step. *default_conf* is used when
        the generate output carried no scores.
        """
        words = transcription.strip().split()
        if not scores:
            return [(word, default_conf) for word in words]

        step_probs = [torch.softmax(score, dim=-1).max().item() for score in scores]
        avg_prob = np.mean(step_probs)
        return [(word, avg_prob) for word in words]

    def _extract_whisper_probabilities(self, audio_array: np.ndarray) -> List[Tuple[str, float]]:
        """Transcribe with Whisper, forcing Urdu transcription instead of language auto-detection."""
        input_features = self.processor(
            audio_array,
            sampling_rate=16000,
            return_tensors="pt"
        ).input_features.to(self.device)

        with torch.no_grad():
            predicted_ids = self.current_model.generate(
                input_features,
                language=self.WHISPER_LANGUAGE,
                task="transcribe",
                return_dict_in_generate=True,
                output_scores=True
            )

        transcription = self.processor.batch_decode(
            predicted_ids.sequences,
            skip_special_tokens=True
        )[0]

        return self._words_with_mean_step_confidence(
            transcription, getattr(predicted_ids, 'scores', None), default_conf=0.8
        )

    def _extract_ctc_probabilities(self, audio_array: np.ndarray) -> List[Tuple[str, float]]:
        """Greedy CTC decode; confidence is the mean per-frame max softmax probability."""
        inputs = self.processor(
            audio_array,
            sampling_rate=16000,
            return_tensors="pt",
            padding=True
        )

        input_values = inputs.input_values.to(self.device)

        with torch.no_grad():
            logits = self.current_model(input_values).logits

        probs = torch.softmax(logits, dim=-1)
        predicted_ids = torch.argmax(logits, dim=-1)
        transcription = self.processor.batch_decode(predicted_ids)[0]

        words = transcription.strip().split()
        if not words:
            return []

        avg_confidence = probs.max(dim=-1).values.squeeze().mean().item()
        return [(word, avg_confidence) for word in words]

    def _extract_seamless_probabilities(self, audio_array: np.ndarray) -> List[Tuple[str, float]]:
        """Speech-to-text with SeamlessM4T, targeting Urdu output."""
        audio_inputs = self.processor(
            audios=audio_array,
            sampling_rate=16000,
            return_tensors="pt"
        ).to(self.device)

        with torch.no_grad():
            output = self.current_model.generate(
                **audio_inputs,
                tgt_lang=self.SEAMLESS_TGT_LANG,
                return_dict_in_generate=True,
                output_scores=True
            )

        transcription = self.processor.decode(
            output.sequences[0].tolist(),
            skip_special_tokens=True
        )

        return self._words_with_mean_step_confidence(
            transcription, getattr(output, 'scores', None), default_conf=0.7
        )

    def _cleanup(self):
        """Drop the loaded model/processor and release cached GPU memory (idempotent)."""
        self.current_model = None
        self.processor = None
        self.current_model_name = None

        # Collect first so freed tensors are actually returned by empty_cache();
        # match any CUDA device string ("cuda", "cuda:0", ...).
        gc.collect()
        if str(self.device).startswith("cuda") and torch.cuda.is_available():
            torch.cuda.empty_cache()

    def unload(self):
        """Free the currently loaded model, if any."""
        self._cleanup()

    def word_probabilities(self, audio_file_path: str, model_name: str,
                           keep_loaded: bool = True) -> List[Tuple[str, float]]:
        """Transcribe *audio_file_path* with *model_name*; return ``(word, confidence)`` pairs.

        Args:
            audio_file_path: Path to an audio file readable by librosa.
            model_name: Key of ``SUPPORTED_MODELS``.
            keep_loaded: Keep the model in memory for the next call (default).
                Pass ``False`` to free it right after this call.

        Raises:
            RuntimeError: wrapping any audio-loading, model-loading or inference error.
        """
        try:
            audio_array = self._preprocess_audio(audio_file_path)
        except Exception as e:
            # An unreadable clip says nothing about the model, so keep it cached
            # unless the caller asked not to.
            if not keep_loaded:
                self._cleanup()
            raise RuntimeError(f"Error processing audio with {model_name}: {str(e)}") from e

        try:
            self._load_model(model_name)

            if "whisper" in model_name:
                results = self._extract_whisper_probabilities(audio_array)
            elif "mms" in model_name or "wav2vec2" in model_name:
                results = self._extract_ctc_probabilities(audio_array)
            elif "seamless" in model_name:
                results = self._extract_seamless_probabilities(audio_array)
            else:
                raise ValueError(f"Unknown model type: {model_name}")

        except Exception as e:
            # After a load/inference failure the model state is uncertain: drop it.
            self._cleanup()
            raise RuntimeError(f"Error processing audio with {model_name}: {str(e)}") from e

        if not keep_loaded:
            self._cleanup()
        return results

print("✅ ASR Wrapper loaded")

# ============================================================================
# SECTION 3: EVALUATION UTILITIES
# ============================================================================

def compute_wer(reference: str, hypothesis: str) -> float:
    """Return the word error rate of *hypothesis* measured against *reference*.

    Both strings are tokenised by whitespace; WER is the word-level edit
    distance divided by the number of reference words. With an empty
    reference the result is 0.0 for an empty hypothesis and 1.0 otherwise.
    """
    ref_tokens = reference.split()
    hyp_tokens = hypothesis.split()

    if not ref_tokens:
        return 1.0 if hyp_tokens else 0.0

    return editdistance.eval(ref_tokens, hyp_tokens) / len(ref_tokens)

def compute_cer(reference: str, hypothesis: str) -> float:
    """Return the character error rate of *hypothesis* measured against *reference*.

    Leading/trailing whitespace is stripped; inner spaces count as characters.
    CER is the character-level edit distance divided by the reference length.
    With an empty reference the result is 0.0 for an empty hypothesis and
    1.0 otherwise.
    """
    ref_text = reference.strip()
    hyp_text = hypothesis.strip()

    if not ref_text:
        return 1.0 if hyp_text else 0.0

    return editdistance.eval(list(ref_text), list(hyp_text)) / len(ref_text)

def compute_ece(confidences: np.ndarray, accuracies: np.ndarray, n_bins: int = 10) -> float:
    """Expected Calibration Error of a set of predictions.

    Predictions are grouped into ``n_bins`` equal-width confidence bins over
    [0, 1]; ECE is the bin-size-weighted mean of
    ``|mean confidence - mean accuracy|`` over non-empty bins.

    Args:
        confidences: Per-prediction confidence scores, expected in [0, 1].
            Any array-like is accepted.
        accuracies: Per-prediction correctness (1.0 correct / 0.0 wrong),
            same length as ``confidences``.
        n_bins: Number of equal-width bins.

    Returns:
        ECE as a Python float; 0.0 for empty input.

    Raises:
        ValueError: if the two inputs have different lengths.
    """
    confidences = np.asarray(confidences, dtype=float).ravel()
    accuracies = np.asarray(accuracies, dtype=float).ravel()

    if confidences.shape != accuracies.shape:
        raise ValueError(
            f"confidences and accuracies differ in length "
            f"({confidences.size} vs {accuracies.size})"
        )
    if confidences.size == 0:
        return 0.0

    bin_boundaries = np.linspace(0, 1, n_bins + 1)
    ece = 0.0

    for i in range(n_bins):
        bin_lower = bin_boundaries[i]
        bin_upper = bin_boundaries[i + 1]

        # Bins are (lower, upper]; the first bin also includes its lower edge
        # so that a confidence of exactly 0.0 is not silently dropped.
        if i == 0:
            in_bin = (confidences >= bin_lower) & (confidences <= bin_upper)
        else:
            in_bin = (confidences > bin_lower) & (confidences <= bin_upper)
        prop_in_bin = in_bin.mean()

        if prop_in_bin > 0:
            accuracy_in_bin = accuracies[in_bin].mean()
            avg_confidence_in_bin = confidences[in_bin].mean()
            ece += np.abs(avg_confidence_in_bin - accuracy_in_bin) * prop_in_bin

    return float(ece)

print("✅ Evaluation utilities loaded")

# ============================================================================
# SECTION 4: DATASET LOADING
# ============================================================================

def _load_clip_durations(durations_file: Path) -> Dict[str, float]:
    """Map clip file name -> duration in seconds from a ``clip_durations.tsv`` file.

    Expects ``clip`` and ``duration[ms]`` columns (the layout assumed for
    Common Voice releases; verify for your release). Returns an empty dict
    when the file is absent or lacks those columns; unparsable rows are skipped.
    """
    import csv

    durations = {}
    if not durations_file.exists():
        return durations

    with open(durations_file, 'r', encoding='utf-8', newline='') as f:
        reader = csv.DictReader(f, delimiter='\t', quoting=csv.QUOTE_NONE)
        for row in reader:
            clip, millis = row.get('clip'), row.get('duration[ms]')
            if not clip or not millis:
                continue
            try:
                durations[clip] = float(millis) / 1000.0
            except ValueError:
                continue
    return durations


def load_common_voice_test_set(dataset_path: str, max_samples: int = 100):
    """Load up to *max_samples* Mozilla Common Voice Urdu clips that exist on disk.

    The first metadata file found among ``test.tsv``, ``validated.tsv`` and
    ``dev.tsv`` in *dataset_path* is read. Rows whose clip is missing from
    ``clips/`` are skipped and do not count towards *max_samples*.

    Returns:
        List of dicts with keys ``audio_id`` (the tsv ``path`` value),
        ``audio_path``, ``reference`` (the tsv ``sentence`` value) and
        ``duration`` in seconds. The duration is taken from a non-empty
        ``duration`` column when the tsv has one, else from
        ``clip_durations.tsv`` when present, else 0.0.

    Raises:
        FileNotFoundError: if none of the metadata files exists.
    """
    import csv

    dataset_path = Path(dataset_path)

    # Try different metadata files, in order of preference.
    tsv_file = None
    for fname in ("test.tsv", "validated.tsv", "dev.tsv"):
        potential_path = dataset_path / fname
        if potential_path.exists():
            tsv_file = potential_path
            break

    if tsv_file is None:
        raise FileNotFoundError(f"No metadata file found in {dataset_path}")

    print(f"Loading from: {tsv_file}")

    clip_durations = _load_clip_durations(dataset_path / "clip_durations.tsv")

    samples = []
    with open(tsv_file, 'r', encoding='utf-8', newline='') as f:
        # QUOTE_NONE: Common Voice sentences can contain literal '"' characters,
        # which default CSV quoting would treat as field delimiters.
        reader = csv.DictReader(f, delimiter='\t', quoting=csv.QUOTE_NONE)
        for row in reader:
            # Count loaded clips, not rows read, so missing clips don't shrink the set.
            if len(samples) >= max_samples:
                break

            audio_path = dataset_path / "clips" / row['path']
            if not audio_path.exists():
                continue

            raw_duration = (row.get('duration') or '').strip()
            try:
                duration = float(raw_duration) if raw_duration else clip_durations.get(row['path'], 0.0)
            except ValueError:
                duration = clip_durations.get(row['path'], 0.0)

            samples.append({
                'audio_id': row['path'],
                'audio_path': str(audio_path),
                'reference': row['sentence'],
                'duration': duration
            })

    print(f"Loaded {len(samples)} test samples")
    return samples

# ============================================================================
# SECTION 5: MODEL EVALUATION PIPELINE
# ============================================================================

def _hypothesis_word_correctness(ref_words: List[str], hyp_words: List[str]) -> List[float]:
    """Label each hypothesis word 1.0 if it aligns to an identical reference word, else 0.0.

    Uses ``difflib.SequenceMatcher`` matching blocks as an approximate word
    alignment, so a single inserted or deleted word does not mark every
    following word as wrong (as position-by-position comparison would).
    """
    import difflib

    correct = [0.0] * len(hyp_words)
    matcher = difflib.SequenceMatcher(a=ref_words, b=hyp_words, autojunk=False)
    for block in matcher.get_matching_blocks():
        for j in range(block.b, block.b + block.size):
            correct[j] = 1.0
    return correct


def evaluate_single_model(asr_wrapper, model_name: str, test_samples: List[Dict],
                          output_dir: Path) -> List[Dict]:
    """Transcribe every test sample with one model and score it.

    Args:
        asr_wrapper: Object exposing ``word_probabilities(audio_path, model_name)``
            returning ``(word, confidence)`` pairs (e.g. ``UrduASRWrapper``).
        model_name: Model key forwarded to ``asr_wrapper``.
        test_samples: Dicts with ``audio_id``, ``audio_path``, ``reference`` and
            ``duration`` keys, as built by ``load_common_voice_test_set``.
        output_dir: Not used by this function; kept for interface compatibility.

    Returns:
        One dict per successfully processed sample with WER, CER, mean
        confidence and a per-utterance ECE. Samples that raise are reported on
        stdout and skipped (best effort).
    """
    results = []

    print(f"\n{'='*60}")
    print(f"Evaluating: {model_name}")
    print(f"{'='*60}")

    for sample in tqdm(test_samples, desc=f"{model_name}"):
        try:
            # Transcribe
            word_probs = asr_wrapper.word_probabilities(
                sample['audio_path'],
                model_name
            )

            hyp_words = [w for w, _ in word_probs]
            confidences = [p for _, p in word_probs]
            hypothesis = ' '.join(hyp_words)
            reference = sample['reference']

            # Compute metrics
            wer = compute_wer(reference, hypothesis)
            cer = compute_cer(reference, hypothesis)
            avg_conf = np.mean(confidences) if confidences else 0.0

            # Word-level correctness labels for calibration, from an alignment
            # of hypothesis words against reference words.
            accuracies = _hypothesis_word_correctness(reference.split(), hyp_words)

            # Calibration metrics
            if confidences:
                ece = compute_ece(np.array(confidences), np.array(accuracies))
            else:
                ece = 0.0

            results.append({
                'audio_id': sample['audio_id'],
                'model_name': model_name,
                'reference': reference,
                'hypothesis': hypothesis,
                'wer': wer,
                'cer': cer,
                'avg_confidence': avg_conf,
                'ece': ece,
                'num_words': len(word_probs),
                'duration': sample['duration']
            })

        except Exception as e:
            print(f"\nError on {sample['audio_id']}: {str(e)}")
            continue

    return results

# ============================================================================
# SECTION 6: MAIN EVALUATION SCRIPT
# ============================================================================

def _save_iteration1_plots(df: pd.DataFrame, output_dir: Path) -> None:
    """Write the four Iteration 1 comparison figures as PNG files into *output_dir*."""
    # 1. WER Comparison
    fig, ax = plt.subplots(figsize=(10, 6))
    model_wer = df.groupby('model_name')['wer'].mean().sort_values()
    ax.barh(model_wer.index, model_wer.values, color='steelblue')
    ax.set_xlabel('Word Error Rate (WER)', fontsize=12)
    ax.set_title('Model Comparison: Average WER', fontsize=14, fontweight='bold')
    ax.grid(axis='x', alpha=0.3)
    plt.tight_layout()
    plt.savefig(output_dir / 'wer_comparison.png', dpi=300, bbox_inches='tight')
    plt.close()

    # 2. WER Distribution
    fig, ax = plt.subplots(figsize=(12, 6))
    df.boxplot(column='wer', by='model_name', ax=ax)
    ax.set_ylabel('WER', fontsize=12)
    ax.set_xlabel('Model', fontsize=12)
    ax.set_title('WER Distribution by Model', fontsize=14, fontweight='bold')
    plt.suptitle('')  # drop pandas' automatic "Boxplot grouped by ..." title
    plt.xticks(rotation=45)
    plt.tight_layout()
    plt.savefig(output_dir / 'wer_distribution.png', dpi=300, bbox_inches='tight')
    plt.close()

    # 3. Confidence vs WER
    fig, ax = plt.subplots(figsize=(10, 6))
    for model in df['model_name'].unique():
        model_data = df[df['model_name'] == model]
        ax.scatter(model_data['avg_confidence'], model_data['wer'],
                   label=model, alpha=0.6, s=50)
    ax.set_xlabel('Average Confidence', fontsize=12)
    ax.set_ylabel('WER', fontsize=12)
    ax.set_title('Confidence vs WER Analysis', fontsize=14, fontweight='bold')
    ax.legend()
    ax.grid(alpha=0.3)
    plt.tight_layout()
    plt.savefig(output_dir / 'confidence_vs_wer.png', dpi=300, bbox_inches='tight')
    plt.close()

    # 4. Calibration Analysis
    fig, ax = plt.subplots(figsize=(10, 6))
    model_ece = df.groupby('model_name')['ece'].mean().sort_values()
    ax.barh(model_ece.index, model_ece.values, color='coral')
    ax.set_xlabel('Expected Calibration Error (ECE)', fontsize=12)
    ax.set_title('Confidence Calibration by Model', fontsize=14, fontweight='bold')
    ax.grid(axis='x', alpha=0.3)
    plt.tight_layout()
    plt.savefig(output_dir / 'calibration_comparison.png', dpi=300, bbox_inches='tight')
    plt.close()


def _write_iteration1_report(report_file: Path, df: pd.DataFrame, num_samples: int,
                             num_models: int, total_duration: float) -> None:
    """Write the plain-text Iteration 1 report summarising *df* to *report_file*.

    *total_duration* is the audio duration of the test set itself (seconds),
    counted once rather than once per evaluated model.
    """
    model_wer = df.groupby('model_name')['wer'].mean()
    model_ece = df.groupby('model_name')['ece'].mean()
    best_model = model_wer.idxmin()
    best_wer = model_wer.min()
    best_cal_model = model_ece.idxmin()
    best_ece = model_ece.min()

    with open(report_file, 'w', encoding='utf-8') as f:
        f.write("="*80 + "\n")
        f.write("CORAL PROJECT - ITERATION 1 EVALUATION REPORT\n")
        f.write("Baseline Establishment & Confidence Calibration Analysis\n")
        f.write("="*80 + "\n\n")

        f.write(f"Evaluation Date: {pd.Timestamp.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
        f.write(f"Total Samples: {num_samples}\n")
        f.write(f"Models Evaluated: {num_models}\n")
        f.write(f"Total Audio Duration: {total_duration:.2f} seconds\n\n")

        f.write("-"*80 + "\n")
        f.write("1. BASELINE WORD ERROR RATES (WER)\n")
        f.write("-"*80 + "\n\n")

        wer_summary = df.groupby('model_name')['wer'].agg(['mean', 'std', 'min', 'max'])
        f.write(wer_summary.to_string())
        f.write("\n\n")

        f.write(f"BEST PERFORMING MODEL: {best_model}\n")
        f.write(f"Baseline WER: {best_wer:.4f} ({best_wer*100:.2f}%)\n\n")

        f.write("-"*80 + "\n")
        f.write("2. CHARACTER ERROR RATES (CER)\n")
        f.write("-"*80 + "\n\n")

        cer_summary = df.groupby('model_name')['cer'].agg(['mean', 'std'])
        f.write(cer_summary.to_string())
        f.write("\n\n")

        f.write("-"*80 + "\n")
        f.write("3. CONFIDENCE CALIBRATION METRICS\n")
        f.write("-"*80 + "\n\n")

        cal_summary = df.groupby('model_name').agg({
            'avg_confidence': ['mean', 'std'],
            'ece': ['mean', 'std']
        })
        f.write(cal_summary.to_string())
        f.write("\n\n")

        f.write(f"BEST CALIBRATED MODEL: {best_cal_model}\n")
        f.write(f"ECE: {best_ece:.4f}\n\n")

        f.write("-"*80 + "\n")
        f.write("4. KEY FINDINGS & INSIGHTS\n")
        f.write("-"*80 + "\n\n")

        f.write(f"• Current state-of-the-art WER: {best_wer*100:.2f}%\n")
        f.write(f"• WER improvement target for Iteration 2-4: < {(best_wer*0.9)*100:.2f}%\n")
        f.write(f"• Average confidence scores range: ")
        f.write(f"{df['avg_confidence'].min():.3f} - {df['avg_confidence'].max():.3f}\n")
        f.write(f"• Calibration quality varies across models (ECE range: ")
        f.write(f"{model_ece.min():.4f} - ")
        f.write(f"{model_ece.max():.4f})\n\n")

        f.write("-"*80 + "\n")
        f.write("5. ITERATION 1 DELIVERABLES CHECKLIST\n")
        f.write("-"*80 + "\n\n")

        f.write("✓ Multi-model ASR integration complete\n")
        f.write("✓ Word-level confidence extraction implemented\n")
        f.write("✓ Baseline WER established for all models\n")
        f.write("✓ Confidence calibration metrics computed\n")
        f.write("✓ Comparative visualizations generated\n")
        f.write("✓ Working pipeline produces confidence-annotated hypotheses\n\n")

        f.write("-"*80 + "\n")
        f.write("6. NEXT STEPS (ITERATION 2)\n")
        f.write("-"*80 + "\n\n")

        f.write("• Develop instruction prompts for black-box LLM\n")
        f.write("• Test different prompt formulations\n")
        f.write("• Implement confidence presentation formats\n")
        f.write("• Begin hypothesis fusion experiments\n")
        f.write(f"• Target: Reduce WER below {(best_wer*0.9)*100:.2f}%\n\n")

        f.write("="*80 + "\n")
        f.write("END OF REPORT\n")
        f.write("="*80 + "\n")


def run_complete_iteration1_evaluation(dataset_path: str,
                                      output_dir: str = "./iteration1_results",
                                      max_samples: int = 50,
                                      models_to_test: List[str] = None,
                                      device: str = None):
    """
    Complete Iteration 1 evaluation pipeline.

    Loads up to ``max_samples`` Common Voice clips, transcribes them with each
    model, and writes per-sample results, aggregate metrics, four comparison
    plots and a plain-text report into ``output_dir``.

    Deliverables:
    1. Baseline WER for each model
    2. Confidence calibration metrics
    3. Comparative analysis
    4. Visualizations
    5. Final report

    Args:
        dataset_path: Common Voice language directory (contains ``clips/`` and
            a ``test.tsv`` / ``validated.tsv`` / ``dev.tsv``).
        output_dir: Directory for all generated files (created if missing).
        max_samples: Maximum number of clips to evaluate.
        models_to_test: Keys of ``UrduASRWrapper.SUPPORTED_MODELS``; defaults to
            whisper-small, whisper-medium, wav2vec2-urdu and mms-300m.
        device: Torch device for inference; ``None`` lets the wrapper choose
            CUDA when available, else CPU.

    Returns:
        ``(df, aggregate)``: per-sample results and per-model aggregate metrics.

    Raises:
        RuntimeError: if no model produced a result for any sample.
    """

    # Setup
    output_dir = Path(output_dir)
    output_dir.mkdir(exist_ok=True, parents=True)

    print("="*80)
    print("CORAL ITERATION 1: BASELINE EVALUATION")
    print("="*80)

    # Load dataset
    print("\n[1/6] Loading test dataset...")
    test_samples = load_common_voice_test_set(dataset_path, max_samples)

    # Initialize ASR wrapper (device=None -> auto-select CUDA when available)
    print("\n[2/6] Initializing ASR models...")
    asr_wrapper = UrduASRWrapper(device=device)

    # Models to evaluate
    if models_to_test is None:
        models_to_test = [
            "whisper-small",
            "whisper-medium",
            "wav2vec2-urdu",
            # NOTE(review): facebook/mms-300m may be a pretrained-only checkpoint
            # without a CTC vocabulary; if every clip errors, try "mms-1b".
            "mms-300m"
        ]

    # Evaluate each model
    print("\n[3/6] Running model evaluations...")
    all_results = []

    try:
        for model in models_to_test:
            model_results = evaluate_single_model(asr_wrapper, model, test_samples, output_dir)
            all_results.extend(model_results)
    finally:
        # Release whatever model the wrapper may still hold.
        asr_wrapper._cleanup()

    # Convert to DataFrame
    df = pd.DataFrame(all_results)
    if df.empty:
        raise RuntimeError(
            "No transcriptions succeeded for any model; nothing to evaluate. "
            "See the per-sample errors printed above."
        )

    # Save detailed results
    results_file = output_dir / "detailed_results.csv"
    df.to_csv(results_file, index=False, encoding='utf-8')
    print(f"\nDetailed results saved: {results_file}")

    # Compute aggregate metrics
    print("\n[4/6] Computing aggregate metrics...")
    aggregate = df.groupby('model_name').agg({
        'wer': ['mean', 'std', 'min', 'max'],
        'cer': ['mean', 'std', 'min', 'max'],
        'avg_confidence': ['mean', 'std'],
        'ece': ['mean', 'std'],
        'duration': 'sum'
    }).round(4)

    aggregate_file = output_dir / "aggregate_metrics.csv"
    aggregate.to_csv(aggregate_file)
    print(f"Aggregate metrics saved: {aggregate_file}")
    print("\n" + "="*60)
    print("AGGREGATE METRICS")
    print("="*60)
    print(aggregate)

    # Generate visualizations
    print("\n[5/6] Generating visualizations...")
    _save_iteration1_plots(df, output_dir)
    print(f"Visualizations saved to: {output_dir}")

    # Generate final report. Duration is counted over the test set once;
    # summing df['duration'] would multiply it by the number of models.
    print("\n[6/6] Generating final report...")
    report_file = output_dir / "ITERATION1_REPORT.txt"
    total_duration = sum(sample['duration'] for sample in test_samples)
    _write_iteration1_report(report_file, df, len(test_samples),
                             len(models_to_test), total_duration)
    print(f"Final report saved: {report_file}")

    print("\n" + "="*80)
    print("ITERATION 1 EVALUATION COMPLETE!")
    print("="*80)
    print(f"\nResults directory: {output_dir.absolute()}")
    print("\nGenerated files:")
    print(f"  • detailed_results.csv - Per-sample evaluation results")
    print(f"  • aggregate_metrics.csv - Statistical summary by model")
    print(f"  • wer_comparison.png - Model performance comparison")
    print(f"  • wer_distribution.png - Error distribution analysis")
    print(f"  • confidence_vs_wer.png - Calibration visualization")
    print(f"  • calibration_comparison.png - ECE by model")
    print(f"  • ITERATION1_REPORT.txt - Comprehensive evaluation report")
    print("="*80)

    return df, aggregate

# ============================================================================
# SECTION 7: EXECUTION
# ============================================================================

if __name__ == "__main__":
    # Common Voice Urdu directory on Kaggle; it must contain clips/ and one of
    # test.tsv, validated.tsv or dev.tsv.
    DATASET_PATH = "/kaggle/input/urdudataset/15026046341 15026046337/cv-corpus-22.0-delta-2025-06-20/ur"

    # Run the Section 6 pipeline (start with 50 samples, increase for the final run).
    results_df, aggregate_metrics = run_complete_iteration1_evaluation(
        dataset_path=DATASET_PATH,
        output_dir="./iteration1_results",
        max_samples=50  # Increase to 100-200 for comprehensive evaluation
    )

    print("\n✅ Iteration 1 complete! Review the results in ./iteration1_results/")