In [None]:
"""
Urdu ASR Wrapper for Multiple Models
Supports 8 diverse ASR models for Urdu speech recognition
Optimized for Kaggle CPU/GPU notebooks with one-at-a-time loading
"""

import torch
import gc
import librosa
import soundfile as sf
import numpy as np
from pathlib import Path
from typing import List, Tuple, Dict
import warnings
warnings.filterwarnings('ignore')

# Transformers imports
from transformers import (
    WhisperProcessor, 
    WhisperForConditionalGeneration,
    Wav2Vec2Processor, 
    Wav2Vec2ForCTC,
    SeamlessM4TForSpeechToText,
    SeamlessM4TProcessor,
    AutoProcessor,
    AutoModelForCTC
)


class UrduASRWrapper:
    """
    Unified wrapper for multiple Urdu ASR models.
    Handles audio preprocessing, model loading, and word-probability extraction.

    Every call to `word_probabilities` decodes the audio file, loads exactly
    one model, transcribes, and releases the model again, so several large
    checkpoints can be tried one after another on a memory-constrained machine.

    Note on the confidence values: none of the extractors aligns scores to
    individual words. Every word of an utterance receives the same
    utterance-level confidence (a mean of per-step / per-frame maximum
    probabilities).
    """

    SUPPORTED_MODELS = {
        "whisper-large": "openai/whisper-large-v3",
        "whisper-medium": "openai/whisper-medium",
        "whisper-small": "openai/whisper-small",
        "seamless-large": "facebook/seamless-m4t-v2-large",
        "seamless-medium": "facebook/seamless-m4t-medium",
        "mms-1b": "facebook/mms-1b-all",
        # NOTE(review): facebook/mms-300m looks like a pretraining-only
        # checkpoint (no CTC head / tokenizer) -- confirm it can be used for
        # transcription or swap in a fine-tuned checkpoint.
        "mms-300m": "facebook/mms-300m",
        "wav2vec2-urdu": "kingabzpro/wav2vec2-large-xls-r-300m-Urdu"
    }

    # Sample rate (Hz) handed to every processor / feature-extractor call.
    SAMPLE_RATE = 16000

    # Whisper is multilingual; pin language and task so it neither guesses a
    # different language for Urdu speech nor translates the audio.
    WHISPER_LANGUAGE = "ur"
    WHISPER_TASK = "transcribe"

    # Target-language code passed to Seamless-M4T generation.
    SEAMLESS_TARGET_LANG = "urd"

    # Multilingual MMS checkpoints ship per-language adapters; without one the
    # checkpoint's default language is used. Maps model key -> adapter code.
    MMS_ADAPTER_LANGS = {
        "mms-1b": "urd-script_arabic",
    }

    def __init__(self, device: str = None):
        """
        Initialize the wrapper.

        Args:
            device: 'cuda', 'cpu', or None (auto-detect: CUDA when available,
                otherwise CPU)
        """
        if device is None:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        else:
            self.device = device

        print(f"🚀 ASR Wrapper initialized on: {self.device}")

        # No model is held until `_load_model` is called.
        self.current_model = None
        self.processor = None
        self.current_model_name = None

    def _preprocess_audio(self, file_path: str, target_sr: int = 16000) -> np.ndarray:
        """
        Decode an audio file into a mono, peak-normalized float32 waveform.

        Args:
            file_path: Path to audio file
            target_sr: Target sample rate (default 16kHz)

        Returns:
            1-D float32 array resampled to `target_sr`; scaled so the largest
            absolute sample is 1.0 (all-zero audio is returned unchanged)

        Raises:
            ValueError: if the file cannot be decoded or contains no samples
        """
        try:
            audio, _ = librosa.load(file_path, sr=target_sr, mono=True)
        except Exception as e:
            raise ValueError(f"Error loading audio file {file_path}: {str(e)}") from e

        audio = np.asarray(audio, dtype=np.float32)

        # An empty waveform would otherwise fail inside `.max()` with an
        # unrelated-looking numpy reduction error.
        if audio.size == 0:
            raise ValueError(f"Error loading audio file {file_path}: audio contains no samples")

        # Peak-normalize; skip pure silence to avoid dividing by zero.
        max_val = np.abs(audio).max()
        if max_val > 0:
            audio = audio / max_val

        return audio

    def _load_model(self, model_name: str):
        """
        Load a specific ASR model and its processor onto `self.device`.

        Args:
            model_name: Key from SUPPORTED_MODELS

        Raises:
            ValueError: if `model_name` is not a key of SUPPORTED_MODELS
            RuntimeError: if the processor or model cannot be loaded; the
                wrapper is reset to the "nothing loaded" state first
        """
        if model_name not in self.SUPPORTED_MODELS:
            raise ValueError(f"Model {model_name} not supported. Choose from: {list(self.SUPPORTED_MODELS.keys())}")

        model_id = self.SUPPORTED_MODELS[model_name]
        print(f"📥 Loading {model_name} ({model_id})...")

        try:
            # Load based on model family
            if "whisper" in model_name:
                self.processor = WhisperProcessor.from_pretrained(model_id)
                self.current_model = WhisperForConditionalGeneration.from_pretrained(model_id)

            elif "seamless" in model_name:
                self.processor = SeamlessM4TProcessor.from_pretrained(model_id)
                if "v2" in model_id:
                    # v2 checkpoints use a different architecture and must be
                    # loaded with the v2 model class. Imported locally so the
                    # file-level import block does not need to change.
                    from transformers import SeamlessM4Tv2ForSpeechToText
                    self.current_model = SeamlessM4Tv2ForSpeechToText.from_pretrained(model_id)
                else:
                    self.current_model = SeamlessM4TForSpeechToText.from_pretrained(model_id)

            elif "mms" in model_name:
                self.processor = AutoProcessor.from_pretrained(model_id)
                self.current_model = AutoModelForCTC.from_pretrained(model_id)
                adapter_lang = self.MMS_ADAPTER_LANGS.get(model_name)
                if adapter_lang is not None:
                    # Switch both the vocabulary and the adapter weights to Urdu.
                    self.processor.tokenizer.set_target_lang(adapter_lang)
                    self.current_model.load_adapter(adapter_lang)

            elif "wav2vec2" in model_name:
                self.processor = Wav2Vec2Processor.from_pretrained(model_id)
                self.current_model = Wav2Vec2ForCTC.from_pretrained(model_id)

            else:
                raise ValueError(f"Unknown model type: {model_name}")

            # Move to device
            self.current_model = self.current_model.to(self.device)
            self.current_model.eval()
            self.current_model_name = model_name

            print(f"✅ {model_name} loaded successfully")

        except Exception as e:
            # Do not leave a processor without a model (or vice versa) behind.
            self._cleanup()
            raise RuntimeError(f"Failed to load model {model_name}: {str(e)}") from e

    @staticmethod
    def _uniform_word_probs(transcription: str, scores, fallback: float) -> List[Tuple[str, float]]:
        """
        Pair every whitespace-separated word with one shared confidence value.

        Args:
            transcription: Decoded text
            scores: Per-step score tensors returned by `generate`, or None/empty
            fallback: Confidence used when no scores are available

        Returns:
            List of (word, probability) tuples; empty if there are no words
        """
        words = transcription.strip().split()
        if not scores:
            return [(word, fallback) for word in words]
        if not words:
            return []

        # Highest softmax probability of each generation step, then the mean
        # over steps. NOTE(review): this equals the chosen token's probability
        # only for greedy decoding with a single sequence -- confirm if beam
        # search is enabled in the model's generation config.
        step_probs = [torch.softmax(step, dim=-1).max().item() for step in scores]
        avg_prob = float(np.mean(step_probs))
        return [(word, avg_prob) for word in words]

    def _extract_whisper_probabilities(self, audio_array: np.ndarray) -> List[Tuple[str, float]]:
        """
        Transcribe with a Whisper model and attach an utterance-level confidence.

        Args:
            audio_array: Preprocessed audio

        Returns:
            List of (word, probability) tuples; every word carries the same
            probability (mean of the per-step maximum token probabilities)
        """
        # Prepare input
        input_features = self.processor(
            audio_array,
            sampling_rate=self.SAMPLE_RATE,
            return_tensors="pt"
        ).input_features.to(self.device)

        # Generate Urdu text and keep the per-step scores
        with torch.no_grad():
            generated = self.current_model.generate(
                input_features,
                language=self.WHISPER_LANGUAGE,
                task=self.WHISPER_TASK,
                return_dict_in_generate=True,
                output_scores=True
            )

        # Decode transcription
        transcription = self.processor.batch_decode(
            generated.sequences,
            skip_special_tokens=True
        )[0]

        return self._uniform_word_probs(
            transcription,
            getattr(generated, "scores", None),
            fallback=0.8
        )

    def _extract_ctc_probabilities(self, audio_array: np.ndarray) -> List[Tuple[str, float]]:
        """
        Transcribe with a CTC model (MMS, Wav2Vec2) and attach an
        utterance-level confidence.

        The confidence is the mean of the per-frame maximum probability, taken
        over frames whose best label is not the CTC blank (read from the
        model config's `pad_token_id`). If the blank id is unavailable, all
        frames are used.

        Args:
            audio_array: Preprocessed audio

        Returns:
            List of (word, probability) tuples; empty if nothing was decoded
        """
        # Prepare input
        inputs = self.processor(
            audio_array,
            sampling_rate=self.SAMPLE_RATE,
            return_tensors="pt",
            padding=True
        )

        input_values = inputs.input_values.to(self.device)

        # Get logits
        with torch.no_grad():
            logits = self.current_model(input_values).logits

        # Best label and its probability for every frame (greedy CTC path)
        frame_conf, predicted_ids = torch.softmax(logits, dim=-1).max(dim=-1)
        transcription = self.processor.batch_decode(predicted_ids)[0]

        words = transcription.strip().split()
        if not words:
            return []

        # Only one utterance is processed per call.
        frame_conf = frame_conf[0]
        frame_ids = predicted_ids[0]

        # Blank frames are typically predicted with very high probability and
        # would inflate the average, so leave them out when we can tell which
        # label is the blank.
        blank_id = getattr(getattr(self.current_model, "config", None), "pad_token_id", None)
        if blank_id is not None:
            emitting = frame_ids != blank_id
            if bool(emitting.any()):
                frame_conf = frame_conf[emitting]

        avg_confidence = frame_conf.mean().item()
        return [(word, avg_confidence) for word in words]

    def _extract_seamless_probabilities(self, audio_array: np.ndarray) -> List[Tuple[str, float]]:
        """
        Transcribe with a Seamless-M4T model and attach an utterance-level
        confidence.

        Args:
            audio_array: Preprocessed audio

        Returns:
            List of (word, probability) tuples; every word carries the same
            probability (mean of the per-step maximum token probabilities)
        """
        # Prepare audio input
        audio_inputs = self.processor(
            audios=audio_array,
            sampling_rate=self.SAMPLE_RATE,
            return_tensors="pt"
        ).to(self.device)

        # Generate transcription
        with torch.no_grad():
            output = self.current_model.generate(
                **audio_inputs,
                tgt_lang=self.SEAMLESS_TARGET_LANG,
                return_dict_in_generate=True,
                output_scores=True
            )

        # Decode transcription
        transcription = self.processor.decode(
            output.sequences[0].tolist(),
            skip_special_tokens=True
        )

        return self._uniform_word_probs(
            transcription,
            getattr(output, "scores", None),
            fallback=0.7
        )

    def _cleanup(self):
        """Drop the loaded model/processor and release their memory."""
        self.current_model = None
        self.processor = None
        self.current_model_name = None

        # Collect first so objects that were only kept alive by reference
        # cycles are really freed; only then can the CUDA caching allocator
        # hand their memory back.
        gc.collect()
        # Also matches indexed devices such as "cuda:0".
        if str(self.device).startswith("cuda"):
            torch.cuda.empty_cache()

    def word_probabilities(
        self,
        audio_file_path: str,
        model_name: str
    ) -> List[Tuple[str, float]]:
        """
        Main function: Process audio and return word-probability pairs.

        The model is loaded for this call only and is always released
        afterwards, whether the call succeeds or fails.

        Args:
            audio_file_path: Path to audio file (MP3, MP4, WAV, etc.)
            model_name: Model to use (key from SUPPORTED_MODELS)

        Returns:
            List of (word, probability) tuples. All words of one utterance
            share the same probability.
            Example: [("سلام", 0.91), ("دنیا", 0.91), ("میں", 0.91)]

        Raises:
            RuntimeError: wraps any failure during preprocessing, loading or
                inference
        """
        try:
            print(f"\n{'='*60}")
            print(f"🎯 Processing: {Path(audio_file_path).name}")
            print(f"🤖 Model: {model_name}")
            print(f"{'='*60}")

            # Step 1: Preprocess audio
            print("📊 Preprocessing audio...")
            audio_array = self._preprocess_audio(audio_file_path)
            print(f"✅ Audio loaded: {len(audio_array)/self.SAMPLE_RATE:.2f} seconds")

            # Step 2: Load model
            self._load_model(model_name)

            # Step 3: Extract probabilities based on model type
            print("🔄 Running inference...")

            if "whisper" in model_name:
                results = self._extract_whisper_probabilities(audio_array)
            elif "mms" in model_name or "wav2vec2" in model_name:
                results = self._extract_ctc_probabilities(audio_array)
            elif "seamless" in model_name:
                results = self._extract_seamless_probabilities(audio_array)
            else:
                raise ValueError(f"Unknown model type: {model_name}")

            print(f"✅ Transcription complete: {len(results)} words")
            preview = ' '.join([w for w, p in results[:5]])
            # Only hint at more text when there actually is more.
            suffix = "..." if len(results) > 5 else ""
            print(f"📝 Preview: {preview}{suffix}")

            return results

        except Exception as e:
            raise RuntimeError(f"Error processing audio with {model_name}: {str(e)}") from e

        finally:
            # Step 4: Cleanup (runs on success and on failure)
            self._cleanup()
            print("🧹 Memory cleaned")


# ============================================================================
# USAGE EXAMPLE FOR KAGGLE
# ============================================================================

def demo_usage(audio_path: str = "test_urdu_audio.mp4", models=None, device: str = 'cpu'):
    """
    Run one audio file through several models and collect the results.

    Called without arguments it behaves like the original demo: the file
    "test_urdu_audio.mp4" is transcribed on CPU with every supported model.

    Args:
        audio_path: Audio file to transcribe
        models: Iterable of model keys to try; None means every key of
            UrduASRWrapper.SUPPORTED_MODELS, in declaration order
        device: Device passed to UrduASRWrapper ('cpu', 'cuda', or None to
            auto-detect)

    Returns:
        Dict mapping each model key to its list of (word, probability)
        tuples; a model that failed maps to an empty list
    """
    wrapper = UrduASRWrapper(device=device)

    if models is None:
        models_to_test = list(UrduASRWrapper.SUPPORTED_MODELS)
    else:
        models_to_test = list(models)

    all_results = {}

    for model in models_to_test:
        try:
            results = wrapper.word_probabilities(audio_path, model)
        except Exception as e:
            # Keep going: one broken model must not abort the comparison.
            print(f"❌ Error with {model}: {str(e)}")
            all_results[model] = []
            continue

        all_results[model] = results

        # Display results
        print(f"\n{model.upper()} Results:")
        print(f"Transcription: {' '.join([w for w, p in results])}")
        if results:
            print(f"Avg Confidence: {np.mean([p for w, p in results]):.3f}")
        else:
            # The mean of an empty list is NaN; say so explicitly instead.
            print("Avg Confidence: n/a (no words recognised)")

    return all_results


# Quick test function
def test_single_model(audio_path: str, model_name: str = "whisper-small"):
    """Quick test with a single model"""
    wrapper = UrduASRWrapper()
    results = wrapper.word_probabilities(audio_path, model_name)
    
    print("\n" + "="*60)
    print("RESULTS:")
    print("="*60)
    for word, prob in results:
        print(f"{word:20s} | Confidence: {prob:.3f}")
    
    return results

In [None]:
if __name__ == "__main__":
    # Sanity banner: confirms the class is defined and lists the model keys
    # that can be passed to `word_probabilities`.
    print(
        "Urdu ASR Wrapper - Ready for use!",
        f"Supported models: {[*UrduASRWrapper.SUPPORTED_MODELS]}",
        sep="\n",
    )

In [None]:
# Shared wrapper instance for the cells below. With no `device` argument the
# wrapper picks CUDA when available and falls back to CPU otherwise.
asr = UrduASRWrapper()

In [None]:
# Sample clip from the Common Voice Urdu corpus mounted as a Kaggle input dataset.
file_path = "/kaggle/input/urdudataset/15026046341 15026046337/cv-corpus-22.0-delta-2025-06-20/ur/clips/common_voice_ur_42810146.mp3"

# Transcribe the clip with the largest Whisper checkpoint; the model is loaded
# for this call and released again before the call returns.
probs = asr.word_probabilities(
    audio_file_path=file_path,
    model_name="whisper-large",
)

In [None]:
# Show the (word, probability) pairs returned for the sample clip.
print(probs)