In [None]:
import sys
print("Installing dependencies...")
!{sys.executable} -m pip install -q editdistance

print("Loading libraries...")
import torch
import gc
import librosa
import numpy as np
import pandas as pd
from pathlib import Path
from typing import List, Tuple, Dict
import warnings
import json
import csv
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
import editdistance
from collections import defaultdict
from datetime import datetime

# Silence every Python warning for the rest of the session (keeps notebook
# output short, but also hides deprecation notices from the libraries below).
warnings.filterwarnings('ignore')
# Apply seaborn's "whitegrid" theme to all matplotlib figures created later.
sns.set_style("whitegrid")

# Processor/model class pairs used by UrduASRWrapper._load_model, one pair per
# supported model family (Whisper, wav2vec2, SeamlessM4T, and the Auto* classes
# used for MMS checkpoints).
from transformers import (
    WhisperProcessor, WhisperForConditionalGeneration,
    Wav2Vec2Processor, Wav2Vec2ForCTC,
    SeamlessM4TForSpeechToText, SeamlessM4TProcessor,
    AutoProcessor, AutoModelForCTC
)

# Report whether PyTorch can see a CUDA device.
print(f"Device: {'CUDA' if torch.cuda.is_available() else 'CPU'}")

class UrduASRWrapper:
    """Run one of several Hugging Face ASR models on an audio file.

    Only one model is held in memory at a time: asking for a different model
    unloads the current one first.  Decoded audio is cached per file path so
    that evaluating several models on the same clips decodes each clip once.

    The confidence attached to each word is *utterance-level*: every word of a
    transcription receives the same score (the mean of the per-step or
    per-frame maximum softmax probability).
    """

    # Short name -> Hugging Face Hub model id.
    # NOTE(review): the MMS checkpoints are loaded without selecting a
    # target-language adapter, and "facebook/mms-300m" may be a pretraining-only
    # checkpoint without a CTC head/tokenizer -- confirm against the model cards.
    SUPPORTED_MODELS = {
        "whisper-large": "openai/whisper-large-v3",
        "whisper-medium": "openai/whisper-medium",
        "whisper-small": "openai/whisper-small",
        "seamless-large": "facebook/seamless-m4t-v2-large",
        "seamless-medium": "facebook/seamless-m4t-medium",
        "mms-1b": "facebook/mms-1b-all",
        "mms-300m": "facebook/mms-300m",
        "wav2vec2-urdu": "kingabzpro/wav2vec2-large-xls-r-300m-Urdu"
    }

    def __init__(self, device: str = None, use_fp16: bool = True):
        """Create a wrapper with no model loaded.

        Args:
            device: Torch device string (e.g. "cpu", "cuda", "cuda:0").  When
                None, "cuda" is used if available, otherwise "cpu".
            use_fp16: Request half precision.  Only honoured on CUDA devices.
        """
        if device is None:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        else:
            self.device = device
        # Indexed devices such as "cuda:0" count as CUDA too.
        self.use_fp16 = use_fp16 and self._on_cuda()
        print(f"ASR Wrapper initialized on: {self.device} (FP16: {self.use_fp16})")
        self.current_model = None
        self.processor = None
        self.current_model_name = None
        self.audio_cache = {}

    def _on_cuda(self) -> bool:
        """Return True when the configured device is a CUDA device."""
        return str(self.device).startswith("cuda")

    def _preprocess_audio(self, file_path: str, target_sr: int = 16000) -> np.ndarray:
        """Load a clip as mono float32 at ``target_sr`` and peak-normalise it.

        Results are memoised in ``self.audio_cache`` keyed by ``file_path``.

        Raises:
            ValueError: If the file cannot be loaded or processed.
        """
        if file_path in self.audio_cache:
            return self.audio_cache[file_path]

        try:
            audio, sr = librosa.load(file_path, sr=target_sr, mono=True)
            if audio.dtype != np.float32:
                audio = audio.astype(np.float32)
            # Scale to a peak of 1.0; an all-zero clip is left untouched.
            max_val = np.abs(audio).max()
            if max_val > 0:
                audio = audio / max_val
            self.audio_cache[file_path] = audio
            return audio
        except Exception as e:
            raise ValueError(f"Error loading audio file {file_path}: {str(e)}") from e

    def _load_model(self, model_name: str):
        """Make ``model_name`` the resident model, unloading any other one.

        Raises:
            ValueError: If ``model_name`` is not a key of SUPPORTED_MODELS.  The
                currently loaded model is left in place in that case.
            RuntimeError: If downloading/initialising the model fails.
        """
        if self.current_model_name == model_name:
            return

        # Validate before unloading so a typo does not throw away a good model.
        if model_name not in self.SUPPORTED_MODELS:
            raise ValueError(f"Model {model_name} not supported. Choose from: {list(self.SUPPORTED_MODELS.keys())}")

        self._cleanup()

        model_id = self.SUPPORTED_MODELS[model_name]
        print(f"Loading {model_name} ({model_id})...")

        try:
            if "whisper" in model_name:
                self.processor = WhisperProcessor.from_pretrained(model_id)
                self.current_model = WhisperForConditionalGeneration.from_pretrained(model_id)
            elif "seamless" in model_name:
                self.processor = SeamlessM4TProcessor.from_pretrained(model_id)
                self.current_model = SeamlessM4TForSpeechToText.from_pretrained(model_id)
            elif "mms" in model_name:
                self.processor = AutoProcessor.from_pretrained(model_id)
                self.current_model = AutoModelForCTC.from_pretrained(model_id)
            elif "wav2vec2" in model_name:
                self.processor = Wav2Vec2Processor.from_pretrained(model_id)
                self.current_model = Wav2Vec2ForCTC.from_pretrained(model_id)
            else:
                raise ValueError(f"Unknown model type: {model_name}")

            self.current_model = self.current_model.to(self.device)
            if self.use_fp16:
                self.current_model = self.current_model.half()
            self.current_model.eval()
            self.current_model_name = model_name
            print(f"{model_name} loaded successfully")
        except Exception as e:
            # Drop whatever was partially loaded so it does not pin memory.
            self._cleanup()
            raise RuntimeError(f"Failed to load model {model_name}: {str(e)}") from e

    @staticmethod
    def _uniform_word_confidences(transcription: str, scores, default: float) -> List[Tuple[str, float]]:
        """Pair every word of ``transcription`` with one shared confidence.

        The confidence is the mean, over generation steps, of the largest
        softmax probability in each step's score tensor.  When ``scores`` is
        missing or empty, ``default`` is used instead.
        """
        words = transcription.strip().split()
        if not scores:
            return [(word, default) for word in words]
        step_peaks = [torch.softmax(step, dim=-1).max().item() for step in scores]
        mean_peak = float(np.mean(step_peaks))
        return [(word, mean_peak) for word in words]

    def _extract_whisper_probabilities(self, audio_array: np.ndarray) -> List[Tuple[str, float]]:
        """Transcribe with the loaded Whisper model; see _uniform_word_confidences."""
        input_features = self.processor(audio_array, sampling_rate=16000, return_tensors="pt").input_features.to(self.device)
        if self.use_fp16:
            input_features = input_features.half()

        # NOTE(review): no language/task is passed to generate(), so the
        # model's own defaults decide the output language -- confirm this is
        # intended for an Urdu benchmark.
        with torch.inference_mode():
            predicted_ids = self.current_model.generate(input_features, return_dict_in_generate=True, output_scores=True)

        transcription = self.processor.batch_decode(predicted_ids.sequences, skip_special_tokens=True)[0]
        return self._uniform_word_confidences(transcription, getattr(predicted_ids, 'scores', None), 0.8)

    def _extract_ctc_probabilities(self, audio_array: np.ndarray) -> List[Tuple[str, float]]:
        """Greedy-decode with the loaded CTC model.

        Each word gets the mean over all frames of the per-frame maximum
        softmax probability.  Returns an empty list for an empty transcription.
        """
        inputs = self.processor(audio_array, sampling_rate=16000, return_tensors="pt", padding=True)
        input_values = inputs.input_values.to(self.device)
        if self.use_fp16:
            input_values = input_values.half()

        with torch.inference_mode():
            logits = self.current_model(input_values).logits

        predicted_ids = torch.argmax(logits, dim=-1)
        transcription = self.processor.batch_decode(predicted_ids)[0]
        words = transcription.strip().split()
        if not words:
            return []

        frame_peaks = torch.softmax(logits, dim=-1).max(dim=-1).values
        avg_confidence = frame_peaks.mean().item()
        return [(word, avg_confidence) for word in words]

    def _extract_seamless_probabilities(self, audio_array: np.ndarray) -> List[Tuple[str, float]]:
        """Transcribe to Urdu with the loaded SeamlessM4T model; see _uniform_word_confidences."""
        audio_inputs = self.processor(audios=audio_array, sampling_rate=16000, return_tensors="pt").to(self.device)
        if self.use_fp16:
            # The model was converted with .half(); its floating-point inputs
            # must match, while integer tensors (e.g. masks) stay as they are.
            audio_inputs = {
                name: tensor.half() if torch.is_tensor(tensor) and torch.is_floating_point(tensor) else tensor
                for name, tensor in audio_inputs.items()
            }

        with torch.inference_mode():
            output = self.current_model.generate(**audio_inputs, tgt_lang="urd", return_dict_in_generate=True, output_scores=True)

        transcription = self.processor.decode(output.sequences[0].tolist(), skip_special_tokens=True)
        return self._uniform_word_confidences(transcription, getattr(output, 'scores', None), 0.7)

    def _cleanup(self):
        """Unload the current model and processor and release their memory."""
        self.current_model = None
        self.processor = None
        self.current_model_name = None
        # Collect first so tensors that are now unreachable are actually freed
        # before the CUDA caching allocator is asked to return its blocks.
        gc.collect()
        if self._on_cuda():
            torch.cuda.empty_cache()

    def clear_audio_cache(self):
        """Forget all decoded audio held in ``self.audio_cache``."""
        self.audio_cache.clear()
        gc.collect()

    def word_probabilities(self, audio_file_path: str, model_name: str) -> List[Tuple[str, float]]:
        """Transcribe ``audio_file_path`` with ``model_name``.

        Returns:
            A list of ``(word, confidence)`` pairs in transcription order.

        Raises:
            RuntimeError: On any failure (audio loading, unknown model, model
                loading or inference); the original error is attached as
                ``__cause__``.
        """
        try:
            audio_array = self._preprocess_audio(audio_file_path)
            self._load_model(model_name)

            if "whisper" in model_name:
                results = self._extract_whisper_probabilities(audio_array)
            elif "mms" in model_name or "wav2vec2" in model_name:
                results = self._extract_ctc_probabilities(audio_array)
            elif "seamless" in model_name:
                results = self._extract_seamless_probabilities(audio_array)
            else:
                raise ValueError(f"Unknown model type: {model_name}")

            return results
        except Exception as e:
            raise RuntimeError(f"Error processing audio with {model_name}: {str(e)}") from e

def compute_wer(reference: str, hypothesis: str) -> float:
    """Return the word error rate of ``hypothesis`` against ``reference``.

    Both strings are split on whitespace; the word-level edit distance is
    divided by the number of reference words.  With no reference words the
    result is 0.0 when the hypothesis is also empty and 1.0 otherwise.
    """
    expected = reference.split()
    produced = hypothesis.split()
    if expected:
        return editdistance.eval(expected, produced) / len(expected)
    # Nothing to compare against: any output at all counts as a full error.
    return 1.0 if produced else 0.0

def compute_cer(reference: str, hypothesis: str) -> float:
    """Return the character error rate of ``hypothesis`` against ``reference``.

    The character-level edit distance is divided by the reference length.
    With an empty reference the result is 0.0 when the hypothesis is also
    empty and 1.0 otherwise.
    """
    ref_length = len(reference)
    if ref_length > 0:
        return editdistance.eval(reference, hypothesis) / ref_length
    # Empty reference: any hypothesis character counts as a full error.
    return 1.0 if len(hypothesis) > 0 else 0.0

def compute_ece(confidences: np.ndarray, accuracies: np.ndarray, n_bins: int = 10) -> float:
    """Return the Expected Calibration Error over equal-width confidence bins.

    The interval [0, 1] is split into ``n_bins`` bins.  For every non-empty
    bin, the absolute gap between mean confidence and mean accuracy is
    weighted by the fraction of items in that bin, and the weighted gaps are
    summed.

    Args:
        confidences: Predicted confidences, expected in [0, 1].  Any
            array-like is accepted.
        accuracies: Per-item correctness (1.0 / 0.0), same shape as
            ``confidences``.
        n_bins: Number of equal-width bins.

    Returns:
        The ECE as a Python float; 0.0 for empty input.

    Raises:
        ValueError: If the two inputs differ in shape.
    """
    confidences = np.asarray(confidences, dtype=float)
    accuracies = np.asarray(accuracies, dtype=float)
    if confidences.size == 0:
        return 0.0
    if confidences.shape != accuracies.shape:
        raise ValueError(
            f"confidences and accuracies must have the same shape, "
            f"got {confidences.shape} and {accuracies.shape}"
        )

    bin_boundaries = np.linspace(0, 1, n_bins + 1)
    ece = 0.0
    for i in range(n_bins):
        bin_lower = bin_boundaries[i]
        bin_upper = bin_boundaries[i + 1]
        # Bins are (lower, upper], except the first, which is closed on the
        # left so that a confidence of exactly 0.0 is not silently dropped.
        above_lower = confidences >= bin_lower if i == 0 else confidences > bin_lower
        in_bin = above_lower & (confidences <= bin_upper)
        prop_in_bin = in_bin.mean()
        if prop_in_bin > 0:
            accuracy_in_bin = accuracies[in_bin].mean()
            avg_confidence_in_bin = confidences[in_bin].mean()
            ece += np.abs(avg_confidence_in_bin - accuracy_in_bin) * prop_in_bin
    return float(ece)

# Run-wide settings read by run_iteration1().
CONFIG = {
    # Directory expected to contain "other.tsv" and a "clips/" subfolder.
    'DATASET_PATH': "/kaggle/input/coraldataset/ur",
    # Sample limit handed to load_test_samples().
    'MAX_SAMPLES': 10,
    # Where CSVs, plots and the text report are written.
    'OUTPUT_DIR': './iteration1_results',
    # Keys of UrduASRWrapper.SUPPORTED_MODELS, evaluated in this order.
    'MODELS': [
        "whisper-small",
        "whisper-medium",
        "whisper-large",
        "wav2vec2-urdu",
        "mms-300m",
    ],
    'DEVICE': 'cuda' if torch.cuda.is_available() else 'cpu',
    # Passed to UrduASRWrapper as use_fp16.
    'USE_FP16': True,
    # When true, run_iteration1() runs gc / CUDA cache cleanup after each model.
    'BATCH_CLEANUP': True
}

# Echo the effective configuration so it is recorded in the notebook output.
print(f"\nConfiguration:")
for key, value in CONFIG.items():
    print(f"  {key}: {value}")

def load_test_samples(dataset_path, max_samples):
    """Read up to ``max_samples`` usable samples from ``<dataset_path>/other.tsv``.

    Each TSV row must provide a ``path`` column (file name under
    ``<dataset_path>/clips``) and a ``sentence`` column (reference text).
    Rows whose audio file does not exist are skipped and do not count towards
    the limit, so the function keeps reading until it has collected
    ``max_samples`` samples or the file ends.

    Args:
        dataset_path: Dataset root directory (str or Path).
        max_samples: Maximum number of samples to return; ``None`` means no
            limit.

    Returns:
        A list of dicts with keys ``audio_id``, ``audio_path``, ``reference``
        and ``duration`` (always 0.0; the duration is not measured here).

    Raises:
        FileNotFoundError: If ``other.tsv`` is missing.
    """
    dataset_path = Path(dataset_path)
    tsv_file = dataset_path / "other.tsv"

    if not tsv_file.exists():
        raise FileNotFoundError(f"other.tsv not found at {tsv_file}")

    clips_dir = dataset_path / "clips"
    samples = []
    # newline='' lets the csv module handle line endings itself.
    with open(tsv_file, 'r', encoding='utf-8', newline='') as f:
        reader = csv.DictReader(f, delimiter='\t')
        for row in reader:
            # Count collected samples, not rows read, so that missing clips
            # do not shrink the evaluation set below the requested size.
            if max_samples is not None and len(samples) >= max_samples:
                break
            audio_path = clips_dir / row['path']
            if not audio_path.exists():
                continue
            samples.append({
                'audio_id': row['path'],
                'audio_path': str(audio_path),
                'reference': row['sentence'],
                'duration': 0.0
            })
    return samples

def evaluate_model(asr_wrapper, model_name, test_samples):
    """Score one model on every sample and return one result row per sample.

    For each sample the wrapper's (word, confidence) output is joined into a
    hypothesis and compared with the reference to obtain WER, CER, the mean
    confidence and an ECE value.  A word counts as correct only when it equals
    the reference word at the same position.  Samples that raise are reported
    on stdout and skipped.  The wrapper's model is always unloaded on exit.
    """
    collected = []
    try:
        for sample in tqdm(test_samples, desc=model_name):
            try:
                scored_words = asr_wrapper.word_probabilities(sample['audio_path'], model_name)
                reference = sample['reference']
                expected = reference.split()

                # Single pass over the output: gather words, confidences and
                # position-wise correctness flags together.
                tokens = []
                confidences = []
                accuracies = []
                for position, (token, confidence) in enumerate(scored_words):
                    tokens.append(token)
                    confidences.append(confidence)
                    matches = position < len(expected) and token == expected[position]
                    accuracies.append(1.0 if matches else 0.0)

                hypothesis = ' '.join(tokens)
                wer = compute_wer(reference, hypothesis)
                cer = compute_cer(reference, hypothesis)

                if scored_words:
                    mean_confidence = np.mean(confidences)
                    calibration_error = compute_ece(np.array(confidences), np.array(accuracies))
                else:
                    mean_confidence = 0.0
                    calibration_error = 0.0

                row = {
                    'audio_id': sample['audio_id'],
                    'model_name': model_name,
                    'reference': reference,
                    'hypothesis': hypothesis,
                    'wer': wer,
                    'cer': cer,
                    'avg_confidence': mean_confidence,
                    'ece': calibration_error,
                    'duration': sample['duration']
                }
                collected.append(row)
            except Exception as e:
                # Best effort: one bad clip must not abort the whole model run.
                print(f"\nError on {sample['audio_id']}: {str(e)}")
    finally:
        asr_wrapper._cleanup()

    return collected

def _save_mean_barh(df, metric, color, xlabel, title, out_path):
    """Save a horizontal bar chart of the per-model mean of ``metric``."""
    per_model = df.groupby('model_name')[metric].mean().sort_values()
    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        ax.barh(per_model.index, per_model.values, color=color)
        ax.set_xlabel(xlabel)
        ax.set_title(title, fontweight='bold')
        ax.grid(axis='x', alpha=0.3)
        fig.tight_layout()
        fig.savefig(out_path, dpi=300, bbox_inches='tight')
    finally:
        # Always release the figure, even if drawing or saving failed.
        plt.close(fig)


def generate_plots(df, output_dir):
    """Write the three comparison plots for an evaluation run.

    Creates ``wer_comparison.png``, ``wer_distribution.png`` and
    ``calibration.png`` inside ``output_dir`` (created if necessary).

    Args:
        df: Result rows with at least the columns ``model_name``, ``wer`` and
            ``ece``.
        output_dir: Destination directory (str or Path).
    """
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    _save_mean_barh(
        df, 'wer', 'steelblue',
        'Word Error Rate (WER)', 'Model Comparison: Average WER',
        output_dir / 'wer_comparison.png',
    )

    fig, ax = plt.subplots(figsize=(12, 6))
    try:
        # Pass the axes explicitly: with `by=` and no `ax`, pandas draws on a
        # figure of its own, leaving the one created here empty and unclosed.
        df.boxplot(column='wer', by='model_name', ax=ax)
        ax.set_ylabel('WER')
        ax.set_title('WER Distribution by Model', fontweight='bold')
        # Remove the automatic "Boxplot grouped by ..." super-title.
        fig.suptitle('')
        ax.tick_params(axis='x', labelrotation=45)
        fig.tight_layout()
        fig.savefig(output_dir / 'wer_distribution.png', dpi=300, bbox_inches='tight')
    finally:
        plt.close(fig)

    _save_mean_barh(
        df, 'ece', 'coral',
        'Expected Calibration Error (ECE)', 'Confidence Calibration by Model',
        output_dir / 'calibration.png',
    )

def run_iteration1():
    """Run the full baseline evaluation described by ``CONFIG``.

    Steps: load the test samples, evaluate every configured model, write the
    detailed and aggregate CSVs, generate the plots and write a text report,
    all under ``CONFIG['OUTPUT_DIR']``.

    Returns:
        ``(df, aggregate)``: the per-sample results DataFrame and the
        per-model aggregate metrics DataFrame.

    Raises:
        RuntimeError: If no test sample could be loaded, or if no model
            produced a single result.
    """
    print("\n" + "="*80)
    print("CORAL ITERATION 1: BASELINE EVALUATION")
    print("="*80)
    print(f"\nTimestamp: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")

    output_dir = Path(CONFIG['OUTPUT_DIR'])
    output_dir.mkdir(exist_ok=True, parents=True)

    print(f"\n[1/5] Loading test dataset from {CONFIG['DATASET_PATH']}...")
    test_samples = load_test_samples(CONFIG['DATASET_PATH'], CONFIG['MAX_SAMPLES'])
    print(f"Loaded {len(test_samples)} test samples")
    if not test_samples:
        # Fail here with a clear message instead of deep inside pandas later.
        raise RuntimeError(f"No usable test samples found under {CONFIG['DATASET_PATH']}")

    print(f"\n[2/5] Initializing ASR wrapper on {CONFIG['DEVICE']}...")
    asr_wrapper = UrduASRWrapper(device=CONFIG['DEVICE'], use_fp16=CONFIG['USE_FP16'])

    print(f"\n[3/5] Evaluating {len(CONFIG['MODELS'])} models...")
    all_results = []
    try:
        for model in CONFIG['MODELS']:
            print(f"\n{'='*60}")
            print(f"Model: {model}")
            print(f"{'='*60}")
            model_results = evaluate_model(asr_wrapper, model, test_samples)
            all_results.extend(model_results)
            if CONFIG['BATCH_CLEANUP']:
                gc.collect()
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
    finally:
        # Release the decoded audio even if a model run aborts.
        asr_wrapper.clear_audio_cache()

    if not all_results:
        # An empty DataFrame has no 'model_name' column, so every groupby
        # below would fail with an unhelpful KeyError.
        raise RuntimeError("No results were produced: every sample failed for every model")

    df = pd.DataFrame(all_results)

    print(f"\n[4/5] Saving results...")
    df.to_csv(output_dir / 'detailed_results.csv', index=False, encoding='utf-8')

    aggregate = df.groupby('model_name').agg({
        'wer': ['mean', 'std', 'min', 'max'],
        'cer': ['mean', 'std'],
        'avg_confidence': ['mean', 'std'],
        'ece': ['mean', 'std'],
        'duration': 'sum'
    }).round(4)
    aggregate.to_csv(output_dir / 'aggregate_metrics.csv')

    print("\n" + "="*80)
    print("AGGREGATE METRICS")
    print("="*80)
    print(aggregate)

    print(f"\n[5/5] Generating visualizations...")
    generate_plots(df, output_dir)

    report_file = output_dir / 'ITERATION1_REPORT.txt'
    # Lowest mean WER wins; compute the per-model means once.
    mean_wer_by_model = df.groupby('model_name')['wer'].mean()
    best_model = mean_wer_by_model.idxmin()
    best_wer = mean_wer_by_model.min()

    with open(report_file, 'w', encoding='utf-8') as f:
        f.write("="*80 + "\n")
        f.write("CORAL PROJECT - ITERATION 1 EVALUATION REPORT\n")
        f.write("="*80 + "\n\n")
        f.write(f"Date: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
        f.write(f"Samples: {len(test_samples)}\n")
        f.write(f"Models: {len(CONFIG['MODELS'])}\n\n")
        f.write("-"*80 + "\n")
        f.write("BASELINE WER BY MODEL\n")
        f.write("-"*80 + "\n\n")
        f.write(df.groupby('model_name')['wer'].describe().to_string())
        f.write("\n\n")
        f.write(f"BEST MODEL: {best_model}\n")
        f.write(f"BASELINE WER: {best_wer:.4f} ({best_wer*100:.2f}%)\n\n")
        f.write("-"*80 + "\n")
        f.write("CALIBRATION ANALYSIS\n")
        f.write("-"*80 + "\n\n")
        f.write(df.groupby('model_name')['ece'].describe().to_string())
        f.write("\n\n")

    print("\n" + "="*80)
    print("ITERATION 1 COMPLETE")
    print("="*80)
    print(f"\nResults saved to: {output_dir.absolute()}")
    print(f"Best Model: {best_model}")
    print(f"Baseline WER: {best_wer*100:.2f}%")
    # Note: this is the number of result rows, i.e. one per (sample, model).
    print(f"Samples Evaluated: {len(df)}")
    print("="*80 + "\n")

    return df, aggregate

# Entry point: run the evaluation when executed as a script / notebook cell.
# results_df and aggregate_metrics stay available as module-level names.
if __name__ == "__main__":
    try:
        results_df, aggregate_metrics = run_iteration1()
        print("\nEvaluation successful!")
        print("Review the results in ./iteration1_results/")
    except Exception as e:
        # Print a short summary, then re-raise so the full traceback is shown.
        print(f"\n\nERROR: {str(e)}")
        raise
