In [None]:
# ================================================================
# 1. SETUP AND IMPORTS
# ================================================================

# Import necessary libraries
import os
import torch
import torch.nn as nn
import torchaudio
import soundfile as sf
import pandas as pd
import numpy as np
import whisper
import librosa
from torch.utils.data import Dataset, DataLoader
from torch.nn import functional as F
from sklearn.metrics import f1_score, accuracy_score, confusion_matrix, ConfusionMatrixDisplay
import torch.optim as optim
import warnings
import matplotlib.pyplot as plt
import json
from concurrent.futures import ProcessPoolExecutor
from tqdm import tqdm

# Add this helper function for progress bars
def safe_tqdm(iterable, desc=None, **kwargs):
    """Wrap ``iterable`` in a tqdm progress bar, falling back to the bare iterable.

    Args:
        iterable: Any iterable to loop over.
        desc: Optional label shown next to the progress bar.
        **kwargs: Extra keyword arguments forwarded to ``tqdm``.

    Returns:
        The tqdm-wrapped iterable, or ``iterable`` itself when the progress
        bar could not be created (a notice is printed in that case).
    """
    try:
        # Refresh on every iteration, but redraw at most twice per second.
        wrapped = tqdm(iterable, desc=desc, miniters=1, mininterval=0.5, **kwargs)
    except Exception as e:
        # Progress reporting is optional: never let it break the actual loop.
        print(f"Progress bar disabled: {str(e)}")
        wrapped = iterable
    return wrapped

# Silence noisy warnings from the audio/ML stack. Every filter uses the
# "ignore" action, so this only hides messages; it does not change behaviour.
for _category in (FutureWarning, UserWarning):
    warnings.filterwarnings("ignore", category=_category)
for _message in (
    "PySoundFile failed",
    "n_fft=.*is too large for input signal",
    "librosa.core.audio.__audioread_load",
    "frames must be specified for non-seekable files",
):
    warnings.filterwarnings("ignore", message=_message)

# Report CUDA status once, then pick the compute device (GPU when present).
_cuda_ok = torch.cuda.is_available()
print(f"CUDA available: {_cuda_ok}")
if _cuda_ok:
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")

DEVICE = "cuda" if _cuda_ok else "cpu"
print(f"Using device: {DEVICE}")

# Dataset and output locations (relative to the working directory)
MP3_DIR = "G44/data/mp3s/Language Detection Dataset"  # source MP3 files
WAV_DIR = "G44/wavs"                                 # converted WAV files
SPLITS_DIR = "G44/splits"                            # train/val/test CSVs
OUTPUT_DIR = "G44/models"                            # model checkpoints

# Make sure every output directory exists before anything writes to it.
for _out_dir in (WAV_DIR, SPLITS_DIR, OUTPUT_DIR):
    os.makedirs(_out_dir, exist_ok=True)

# ================================================================
# 2. LANGUAGES AND MAPPINGS
# ================================================================

# Languages covered by the project. The casing matches the on-disk
# directory names, so these strings are used directly to build paths.
LANGS = [
    'Bengali',
    'Gujarati',
    'Hindi',
    'Kannada',
    'Malayalam',
    'Marathi',
    'Punjabi',
    'Tamil',
    'Telugu',
    'Urdu',
]

# Lookup tables between lowercase language names and integer class ids.
# The id of a language is its position in LANGS.
LANG2ID = {name.lower(): class_id for class_id, name in enumerate(LANGS)}
ID2LANG = {class_id: key for key, class_id in LANG2ID.items()}

print("Languages:", LANGS)
print("Total languages:", len(LANGS))

# ================================================================
# 3. AUDIO AUGMENTATION CLASSES
# ================================================================

class RandomApply(nn.Module):
    """Apply a sequence of transforms, as a group, with probability ``p``.

    Args:
        transforms: Iterable of callables, applied in order when selected.
        p: Probability of applying the whole sequence.
    """
    def __init__(self, transforms, p=0.5):
        super().__init__()
        self.transforms = transforms
        self.p = p

    def forward(self, x):
        # One uniform draw decides for the whole group of transforms.
        skip = self.p < torch.rand(1)
        if not skip:
            for transform in self.transforms:
                x = transform(x)
        return x

class AddBackgroundNoise(nn.Module):
    """Mix Gaussian white noise into a signal at a random SNR.

    Args:
        snr_db_range: ``(low, high)`` bounds, in dB, from which the
            signal-to-noise ratio is drawn uniformly on every call.
    """
    def __init__(self, snr_db_range=(5, 20)):
        super().__init__()
        self.snr_db_range = snr_db_range

    def forward(self, x):
        # Unit-variance white noise with the same shape/dtype/device as x
        white = torch.randn_like(x)

        # Mean power of the signal and of the raw noise
        power_signal = (x ** 2).mean()
        power_noise = (white ** 2).mean()

        # Draw the target SNR (dB) and convert it to a linear power ratio
        snr_db = torch.tensor(
            np.random.uniform(*self.snr_db_range)
        )
        snr_linear = 10 ** (snr_db / 10)

        # Gain that brings the noise to the requested SNR relative to x
        gain = torch.sqrt(power_signal / (power_noise * snr_linear))

        return x + gain * white

class SpeedPerturb(nn.Module):
    """Randomly time-stretch audio while keeping the original length.

    The stretch itself runs on the CPU through librosa; a tensor input is
    copied to NumPy for that step and the result is returned on the input's
    device.

    Args:
        speed_range: ``(low, high)`` bounds for the random speed factor.
        sample_rate: Stored on the instance; not used by ``forward``.
    """
    def __init__(self, speed_range=(0.9, 1.1), sample_rate=16000):
        super().__init__()
        self.speed_range = speed_range
        self.sample_rate = sample_rate

    def forward(self, x):
        # Draw the speed factor for this call
        factor = np.random.uniform(*self.speed_range)

        # librosa works on NumPy arrays
        is_tensor = isinstance(x, torch.Tensor)
        samples = x.cpu().numpy() if is_tensor else x

        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            stretched = librosa.effects.time_stretch(samples, rate=1.0/factor)

        out = torch.tensor(
            stretched,
            dtype=torch.float32,
            device=x.device if is_tensor else None,
        )

        # Stretching changes the length: trim or zero-pad back to the input size
        target_len = len(x)
        current_len = len(out)
        if current_len > target_len:
            return out[:target_len]
        if current_len < target_len:
            pad = torch.zeros(target_len - current_len, device=out.device)
            return torch.cat([out, pad])
        return out

class PitchShift(nn.Module):
    """Randomly shift the pitch of audio while keeping the original length.

    The shift runs on the CPU through librosa; a tensor input is copied to
    NumPy for that step and the result is returned on the input's device.

    Args:
        pitch_range: ``(low, high)`` bounds, in semitones, for the random shift.
        sample_rate: Sample rate passed to ``librosa.effects.pitch_shift``.
    """
    def __init__(self, pitch_range=(-2, 2), sample_rate=16000):
        super().__init__()
        self.pitch_range = pitch_range
        self.sample_rate = sample_rate

    def forward(self, x):
        # Draw the shift (in semitones) for this call
        steps = np.random.uniform(*self.pitch_range)

        # librosa works on NumPy arrays
        is_tensor = isinstance(x, torch.Tensor)
        samples = x.cpu().numpy() if is_tensor else x

        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            shifted = librosa.effects.pitch_shift(
                samples,
                sr=self.sample_rate,
                n_steps=steps
            )

        out = torch.tensor(
            shifted,
            dtype=torch.float32,
            device=x.device if is_tensor else None,
        )

        # Guard against length drift: trim or zero-pad back to the input size
        target_len = len(x)
        current_len = len(out)
        if current_len > target_len:
            return out[:target_len]
        if current_len < target_len:
            pad = torch.zeros(target_len - current_len, device=out.device)
            return torch.cat([out, pad])
        return out

# ================================================================
# 4. DATASET IMPLEMENTATION
# ================================================================

class WhisperFeatureDataset(Dataset):
    """Dataset that turns audio files into Whisper log-mel spectrograms.

    Each item is a ``(mel, label_id)`` pair, where ``label_id`` is looked up
    in ``LANG2ID`` from the lowercased ``lang`` column of the CSV.
    """

    # Sample rate every waveform is brought to before augmentation and
    # feature extraction.
    _TARGET_SR = 16000

    def __init__(self, csv_path, augment=False, max_duration=30):
        """
        Args:
            csv_path: Path to CSV file with 'path' and 'lang' columns
            augment: Whether to apply augmentation to audio
            max_duration: Maximum audio duration in seconds (for padding/truncation)
        """
        df = pd.read_csv(csv_path)
        self.paths = df['path'].tolist()
        self.labels = [LANG2ID[l.lower()] for l in df['lang']]
        self.augment = augment
        self.max_duration = max_duration

        # Use Whisper's feature extraction
        self.mel_fn = whisper.audio.log_mel_spectrogram

        # Audio augmentations; each stage fires independently with its own
        # probability. Only used by __getitem__ when ``augment`` is True.
        self.transforms = nn.Sequential(
            RandomApply([AddBackgroundNoise()], p=0.5),
            RandomApply([SpeedPerturb()], p=0.4),
            RandomApply([PitchShift()], p=0.3),
        )

    @staticmethod
    def _read_with_soundfile(path):
        """Read ``path`` with soundfile; return (mono float32 array, sample rate)."""
        with sf.SoundFile(path) as sound_file:
            wav = sound_file.read(dtype='float32')
            sr = sound_file.samplerate
        # soundfile returns (frames, channels) for multi-channel audio
        if wav.ndim > 1:
            wav = wav.mean(axis=1)
        return wav, sr

    @staticmethod
    def _read_with_torchaudio(path):
        """Read ``path`` with torchaudio; return (mono array, sample rate)."""
        try:
            # Best effort only: this call is missing or a no-op on newer
            # torchaudio releases, which must not prevent loading.
            torchaudio.set_audio_backend("sox_io")
        except Exception:
            pass
        waveform, sr = torchaudio.load(path)
        # torchaudio returns (channels, frames)
        if waveform.shape[0] > 1:
            return waveform.mean(dim=0).numpy(), sr
        return waveform.squeeze(0).numpy(), sr

    @staticmethod
    def _read_with_librosa(path):
        """Read ``path`` with librosa at its native rate; return (mono array, sample rate)."""
        return librosa.load(
            path,
            sr=None,
            mono=True,
            offset=0.0,
            duration=None,
            dtype=np.float32,
            res_type='kaiser_best'
        )

    def load_audio(self, path):
        """Load ``path`` as a mono float tensor, trying several backends.

        Backends are tried in order: soundfile, torchaudio, librosa. If all
        of them fail, the last error is printed and one second of silence at
        16 kHz is returned so a single bad file does not abort an epoch.

        Args:
            path: Path of the audio file to read.

        Returns:
            Tuple ``(waveform, sample_rate)`` where ``waveform`` is a 1-D
            float32 tensor.
        """
        readers = (
            self._read_with_soundfile,
            self._read_with_torchaudio,
            self._read_with_librosa,
        )
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")

            last_error = None
            for reader in readers:
                try:
                    wav, sr = reader(path)
                    break
                except Exception as err:
                    # Deliberately broad: any backend failure just means
                    # "try the next one". KeyboardInterrupt/SystemExit are
                    # not Exception subclasses and still propagate.
                    last_error = err
            else:
                print(f"Failed to load {path}: {last_error}")
                # Return silent audio as fallback
                sr = self._TARGET_SR
                wav = np.zeros(sr, dtype=np.float32)

        return torch.tensor(wav).float(), sr

    def __len__(self):
        """Return the number of audio files listed in the CSV."""
        return len(self.paths)

    def __getitem__(self, idx):
        """Return ``(mel, label_id)`` for the file at position ``idx``."""
        wav, sr = self.load_audio(self.paths[idx])

        # Resample first so the augmentations (configured for 16 kHz) and the
        # noise SNR apply to the signal the model actually sees.
        if sr != self._TARGET_SR:
            wav = torchaudio.functional.resample(
                wav, orig_freq=sr, new_freq=self._TARGET_SR
            )

        if self.augment:
            wav = self.transforms(wav)

        # Pad or trim to a fixed number of samples
        wav = whisper.audio.pad_or_trim(
            wav, length=self._TARGET_SR * self.max_duration
        )

        # Extract mel spectrogram features
        mel = self.mel_fn(wav)

        return mel, self.labels[idx]

def get_dataloaders(train_csv, val_csv, test_csv=None, batch_size=32, num_workers=4):
    """Build the train / validation / (optional) test dataloaders.

    Args:
        train_csv: CSV listing the training files; loaded with augmentation.
        val_csv: CSV listing the validation files; no augmentation.
        test_csv: Optional CSV listing the test files; no augmentation.
        batch_size: Batch size shared by all loaders.
        num_workers: Worker processes per loader.

    Returns:
        ``(train_loader, val_loader, test_loader)``; ``test_loader`` is
        ``None`` when ``test_csv`` is falsy.
    """
    train_set = WhisperFeatureDataset(train_csv, augment=True)
    val_set = WhisperFeatureDataset(val_csv, augment=False)

    # Settings shared by the evaluation-style loaders (fixed order).
    eval_options = dict(batch_size=batch_size, shuffle=False, num_workers=num_workers)

    # Only the training loader shuffles and pins memory.
    train_loader = DataLoader(
        train_set,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True,
    )
    val_loader = DataLoader(val_set, **eval_options)

    if test_csv:
        test_set = WhisperFeatureDataset(test_csv, augment=False)
        test_loader = DataLoader(test_set, **eval_options)
    else:
        test_loader = None

    return train_loader, val_loader, test_loader

# ================================================================
# 5. MODEL ARCHITECTURE
# ================================================================

class LIDWhisper(nn.Module):
    """
    Language-identification model built on a pretrained Whisper encoder.

    Only the encoder of the Whisper model is kept; a small MLP classifier
    takes the place of the decoder.
    """
    def __init__(self, whisper_model_name='base', num_languages=10, freeze_encoder=True):
        """
        Args:
            whisper_model_name: Whisper model size ('tiny', 'base', 'small', etc.)
            num_languages: Number of languages to classify
            freeze_encoder: Whether to freeze the Whisper encoder weights
        """
        super().__init__()

        # Pretrained backbone; only its encoder is stored on this module
        backbone = whisper.load_model(whisper_model_name)
        self.encoder = backbone.encoder

        # Width of the encoder's output features
        hidden_size = backbone.dims.n_audio_state

        # Optionally keep the encoder weights fixed during training
        if freeze_encoder:
            self.encoder.requires_grad_(False)

        # Classification head: hidden_size -> 256 -> num_languages
        head_layers = [
            nn.Linear(hidden_size, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, num_languages),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, mel):
        """Return class logits for a batch of mel spectrograms."""
        features = self.encoder(mel)
        # Average over dim 1 (assumed to be the time axis of the encoder
        # output) to get one vector per example, then classify.
        return self.classifier(features.mean(dim=1))

def create_model(whisper_size='base', num_languages=10, freeze_encoder=True, device='cuda'):
    """
    Build a Whisper-based language identification model on a given device.

    Args:
        whisper_size: Whisper model size ('tiny', 'base', 'small', etc.)
        num_languages: Number of languages to classify
        freeze_encoder: Whether to freeze the encoder weights
        device: Device to load the model onto

    Returns:
        model: LIDWhisper model, moved to ``device``
    """
    lid_model = LIDWhisper(
        whisper_model_name=whisper_size,
        num_languages=num_languages,
        freeze_encoder=freeze_encoder,
    )
    lid_model = lid_model.to(device)
    return lid_model

# ================================================================
# 6. TRAINING AND EVALUATION FUNCTIONS
# ================================================================

def train_epoch(model, dataloader, criterion, optimizer, device):
    """Run one optimisation pass over ``dataloader``.

    Args:
        model: Model to train; switched to training mode.
        dataloader: Yields ``(mel, labels)`` batches.
        criterion: Loss function called as ``criterion(logits, labels)``.
        optimizer: Optimizer stepped once per batch.
        device: Device the batches are moved to.

    Returns:
        Dict with the mean per-batch ``"loss"`` and the sample-level
        ``"accuracy"`` over the epoch.
    """
    model.train()
    running_loss = 0
    n_correct = 0
    n_seen = 0

    batches = safe_tqdm(dataloader, desc="Training")
    for step, (mel, labels) in enumerate(batches, start=1):
        mel = mel.to(device)
        labels = labels.to(device)

        # Forward pass
        logits = model(mel)
        loss = criterion(logits, labels)

        # Backward pass and parameter update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Running statistics
        running_loss += loss.item()
        predicted = torch.max(logits, 1)[1]
        n_seen += labels.size(0)
        n_correct += (predicted == labels).sum().item()

        # Periodic plain-text status line
        if step % 1000 == 0:
            print(f"Batch {step}: loss={running_loss/step:.4f}, acc={100*n_correct/n_seen:.2f}%")

    return {
        "loss": running_loss / len(dataloader),
        "accuracy": n_correct / n_seen
    }

def evaluate(model, dataloader, criterion, device):
    """Score the model on a validation or test set without updating it.

    Args:
        model: Model to evaluate; switched to eval mode.
        dataloader: Yields ``(mel, labels)`` batches.
        criterion: Loss function called as ``criterion(logits, labels)``.
        device: Device the batches are moved to.

    Returns:
        Dict with the mean per-batch ``"loss"``, ``"accuracy"``,
        ``"macro_f1"``, and the flat ``"preds"`` / ``"labels"`` lists.
    """
    model.eval()
    loss_sum = 0
    predictions = []
    targets = []

    # No gradients are needed for evaluation
    with torch.no_grad():
        for mel, labels in safe_tqdm(dataloader, desc="Evaluating"):
            mel = mel.to(device)
            labels = labels.to(device)

            logits = model(mel)
            batch_loss = criterion(logits, labels)

            # Collect the predicted class and the ground truth per sample
            batch_preds = torch.max(logits, 1)[1]
            predictions.extend(batch_preds.cpu().numpy())
            targets.extend(labels.cpu().numpy())

            loss_sum += batch_loss.item()

    return {
        "loss": loss_sum / len(dataloader),
        "accuracy": accuracy_score(targets, predictions),
        "macro_f1": f1_score(targets, predictions, average="macro"),
        "preds": predictions,
        "labels": targets
    }

def save_model(model, path):
    """Write the model's ``state_dict`` to ``path`` and report it."""
    weights = model.state_dict()
    torch.save(weights, path)
    print(f"Model saved to {path}")
    
def load_model(model, path, device):
    """Load a ``state_dict`` checkpoint from ``path`` into ``model``.

    Args:
        model: Model whose parameters are overwritten in place.
        path: Checkpoint file written by ``torch.save``.
        device: Passed as ``map_location`` so tensors land on this device.

    Returns:
        The same ``model`` instance, for call chaining.
    """
    state = torch.load(path, map_location=device)
    model.load_state_dict(state)
    print(f"Model loaded from {path}")
    return model

# ================================================================
# 7. DATA PREPROCESSING FUNCTIONS
# ================================================================

def convert_mp3_to_wav(mp3_path, wav_path, sr=16000):
    """Convert one audio file to a mono WAV file at the given sample rate.

    Loaders are tried in order until one succeeds: torchaudio (only for
    paths ending in ``.mp3``), soundfile, then librosa. The audio is
    down-mixed to mono, resampled to ``sr`` when needed, and written with
    soundfile.

    Args:
        mp3_path: Path of the input audio file.
        wav_path: Path of the WAV file to write; missing parent directories
            are created.
        sr: Target sample rate in Hz.

    Returns:
        ``(True, mp3_path)`` on success, or ``(False, message)`` where
        ``message`` starts with ``mp3_path`` and describes the failure.
        This function never raises for ordinary errors, so it is safe to
        run in worker processes.
    """
    try:
        audio, orig_sr = None, None

        # Suppress warnings
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")

            # Try torchaudio first for MP3 files
            if mp3_path.lower().endswith('.mp3'):
                try:
                    try:
                        # Best effort only: missing or a no-op on newer
                        # torchaudio releases.
                        torchaudio.set_audio_backend("sox_io")
                    except Exception:
                        pass

                    waveform, orig_sr = torchaudio.load(mp3_path)
                    # Convert to mono if needed
                    if waveform.shape[0] > 1:
                        audio = waveform.mean(dim=0).numpy()
                    else:
                        audio = waveform.squeeze(0).numpy()
                except Exception:
                    # Reset partial state so the next loader starts clean
                    audio, orig_sr = None, None

            # If torchaudio failed or file isn't MP3, try soundfile
            if audio is None:
                try:
                    with sf.SoundFile(mp3_path) as sound_file:
                        audio = sound_file.read(dtype='float32')
                        orig_sr = sound_file.samplerate
                        # Convert to mono if stereo
                        if audio.ndim > 1:
                            audio = audio.mean(axis=1)
                except Exception:
                    audio, orig_sr = None, None

            # Last resort: librosa, keeping the file's native sample rate
            if audio is None:
                try:
                    audio, orig_sr = librosa.load(
                        mp3_path,
                        sr=None,
                        mono=True,
                        duration=None,  # No duration limit
                        offset=0.0,     # Start from beginning
                        dtype=np.float32,
                        res_type='kaiser_best'  # High quality resampling
                    )
                except Exception as e:
                    return False, f"{mp3_path}: Failed to load with librosa: {str(e)}"

        # If we still don't have audio, return failure
        if audio is None:
            return False, f"{mp3_path}: Failed to load audio with all methods"

        # Resample if needed
        if orig_sr != sr:
            audio = librosa.resample(
                y=audio,
                orig_sr=orig_sr,
                target_sr=sr,
                res_type='kaiser_best'
            )

        # Make sure the output directory exists. dirname is "" when wav_path
        # has no directory part, and os.makedirs("") would raise.
        out_dir = os.path.dirname(wav_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)

        # Save as wav
        sf.write(wav_path, audio, sr)
        return True, mp3_path
    except Exception as e:
        # Report instead of raising so one bad file cannot kill a batch job
        return False, f"{mp3_path}: {str(e)}"

def convert_dataset(in_dir, out_dir, n_workers=4):
    """Convert every MP3 under the per-language folders of ``in_dir`` to WAV.

    For each language in ``LANGS``, ``in_dir/<language>`` is walked
    recursively and every ``.mp3`` file (any letter case) is converted with
    ``convert_mp3_to_wav`` in a process pool. The directory layout below
    ``in_dir`` is mirrored under ``out_dir``.

    Args:
        in_dir: Root directory containing one sub-directory per language.
        out_dir: Root directory for the converted WAV files.
        n_workers: Number of worker processes.

    Returns:
        List of error messages, one per file that failed to convert. The
        list is empty when everything succeeded, when ``in_dir`` does not
        exist or cannot be listed, or when no MP3 files were found.
    """
    # Check if the root directory exists
    if not os.path.exists(in_dir):
        print(f"Error: Input directory {in_dir} does not exist")
        return []

    print(f"Searching for MP3 files in {in_dir}")

    # List the top-level contents (diagnostic output only)
    try:
        contents = os.listdir(in_dir)
        print(f"Contents of {in_dir}: {contents}")
    except OSError as e:
        print(f"Error listing directory contents: {e}")
        return []

    # Collect (source, destination) pairs for every language directory
    mp3_files = []
    for lang in LANGS:
        lang_dir = os.path.join(in_dir, lang)
        if not os.path.exists(lang_dir):
            print(f"Warning: {lang_dir} does not exist, skipping")
            continue

        print(f"Found language directory: {lang}")

        for root, _, files in os.walk(lang_dir):
            for file in files:
                # Case-insensitive match, as in convert_mp3_to_wav
                if not file.lower().endswith('.mp3'):
                    continue
                mp3_path = os.path.join(root, file)
                # Relative path preserves the directory structure
                rel_path = os.path.relpath(mp3_path, in_dir)
                # Swap only the final extension; str.replace would also
                # rewrite ".mp3" inside directory names.
                rel_stem, _ext = os.path.splitext(rel_path)
                wav_path = os.path.join(out_dir, rel_stem + '.wav')
                mp3_files.append((mp3_path, wav_path))

    print(f"Found {len(mp3_files)} MP3 files to convert")

    # Create output directory
    os.makedirs(out_dir, exist_ok=True)

    # If no files found, return early
    if not mp3_files:
        print("No MP3 files found to convert")
        return []

    # Convert files in parallel
    results = []
    with ProcessPoolExecutor(max_workers=n_workers) as executor:
        futures = [
            (mp3_path, executor.submit(convert_mp3_to_wav, mp3_path, wav_path))
            for mp3_path, wav_path in mp3_files
        ]

        total = len(futures)
        print(f"Converting {total} MP3 files to WAV...")
        for processed, (mp3_path, future) in enumerate(futures, start=1):
            try:
                success, msg = future.result()
            except Exception as e:
                # A crashed worker counts as a failed file instead of
                # aborting the whole conversion run.
                success, msg = False, f"{mp3_path}: {str(e)}"
            if not success:
                results.append(msg)
            if processed % 1000 == 0 or processed == total:
                print(f"Processed {processed}/{total} files ({processed/total*100:.1f}%)")

    # Report errors
    if results:
        print(f"{len(results)} files failed to convert:")
        for msg in results[:10]:  # Show only first 10 errors
            print(f"  - {msg}")
        if len(results) > 10:
            print(f"  - ... and {len(results) - 10} more")

    print(f"Successfully converted {len(mp3_files) - len(results)} out of {len(mp3_files)} files")
    return results

def create_data_splits(data_dir, out_dir, split_ratio=(0.8, 0.1, 0.1), seed=42):
    """Build per-language train/val/test splits from WAV files and save them.

    WAV files are discovered under ``data_dir/<language>`` for every entry
    of ``LANGS``. Each language is split separately: the first two entries
    of ``split_ratio`` size the train and validation parts, and all
    remaining files go to the test part. The three splits are written to
    ``train.csv``, ``val.csv`` and ``test.csv`` in ``out_dir``.

    Args:
        data_dir: Root directory with one sub-directory per language.
        out_dir: Directory the CSV files are written to.
        split_ratio: Train/val/test fractions; only the first two are read.
        seed: Random seed used for every shuffle.

    Returns:
        ``(train_df, val_df, test_df)`` DataFrames with 'path' and 'lang'
        columns. Empty DataFrames are returned (and nothing is written) if
        ``data_dir`` is missing or contains no WAV files.
    """
    # Without a data directory there is nothing to split
    if not os.path.exists(data_dir):
        print(f"Error: Data directory {data_dir} does not exist")
        empty_df = pd.DataFrame(columns=['path', 'lang'])
        return empty_df, empty_df, empty_df

    # Gather one record per WAV file, labelled with its language folder
    records = []
    for lang in LANGS:
        lang_dir = os.path.join(data_dir, lang)
        if not os.path.exists(lang_dir):
            print(f"Warning: {lang_dir} does not exist, skipping")
            continue

        records.extend(
            {'path': os.path.join(root, name), 'lang': lang}
            for root, _, names in os.walk(lang_dir)
            for name in names
            if name.endswith('.wav')
        )

    df = pd.DataFrame(records)
    print(f"Total samples: {len(df)}")

    if len(df) == 0:
        print("No WAV files found. Skipping data splits creation.")
        empty_df = pd.DataFrame(columns=['path', 'lang'])
        return empty_df, empty_df, empty_df

    print(df['lang'].value_counts())

    # Global shuffle so each language's rows are in random order
    df = df.sample(frac=1, random_state=seed).reset_index(drop=True)

    # Slice every language on its own so all splits keep the class balance
    train_parts, val_parts, test_parts = [], [], []
    for lang in LANGS:
        subset = df[df['lang'] == lang]
        if len(subset) == 0:
            continue

        train_end = int(len(subset) * split_ratio[0])
        val_end = train_end + int(len(subset) * split_ratio[1])

        train_parts.append(subset.iloc[:train_end])
        val_parts.append(subset.iloc[train_end:val_end])
        test_parts.append(subset.iloc[val_end:])  # whatever is left over

    def _combine(parts):
        # Merge the per-language slices and shuffle them together
        if parts:
            return pd.concat(parts).sample(frac=1, random_state=seed).reset_index(drop=True)
        return pd.DataFrame(columns=['path', 'lang'])

    train_df = _combine(train_parts)
    val_df = _combine(val_parts)
    test_df = _combine(test_parts)

    # Persist the splits
    os.makedirs(out_dir, exist_ok=True)
    for filename, frame in (('train.csv', train_df), ('val.csv', val_df), ('test.csv', test_df)):
        frame.to_csv(os.path.join(out_dir, filename), index=False)

    print(f"Splits created and saved to {out_dir}:")
    print(f"  - Train: {len(train_df)} samples")
    print(f"  - Validation: {len(val_df)} samples")
    print(f"  - Test: {len(test_df)} samples")

    return train_df, val_df, test_df

# ================================================================
# 8. MAIN TRAINING PIPELINE
# ================================================================

def train_model(train_csv, val_csv, test_csv=None,
                whisper_size='base', freeze_encoder=True,
                batch_size=32, epochs=5, lr=1e-4, weight_decay=1e-4,
                patience=3, num_workers=4, device='cuda', output_dir='models'):
    """Run the full pipeline: data loading, training, plotting and testing.

    The model is trained with AdamW and a cosine learning-rate schedule. A
    checkpoint is written whenever the validation macro-F1 improves, and
    training stops early after ``patience`` epochs without improvement. If a
    test CSV is given, the best checkpoint is reloaded and evaluated, and a
    confusion matrix plus ``results.json`` are written to ``output_dir``.

    Args:
        train_csv: CSV with the training files ('path' and 'lang' columns).
        val_csv: CSV with the validation files.
        test_csv: Optional CSV with the test files.
        whisper_size: Whisper model size ('tiny', 'base', 'small', etc.).
        freeze_encoder: Whether to freeze the Whisper encoder weights.
        batch_size: Batch size for all dataloaders.
        epochs: Maximum number of training epochs.
        lr: Learning rate for AdamW.
        weight_decay: Weight decay for AdamW.
        patience: Epochs without validation macro-F1 improvement before
            training stops early.
        num_workers: Worker processes per dataloader.
        device: Device string, e.g. 'cuda' or 'cpu'.
        output_dir: Directory for checkpoints, plots and ``results.json``.

    Returns:
        ``(model, results)``. ``results`` is a dict of test metrics, or an
        empty dict when no test set was evaluated. ``(None, None)`` is
        returned when the training or validation CSV is missing or empty.
    """
    # Create output directory
    os.makedirs(output_dir, exist_ok=True)

    # Refuse to start without both a training and a validation CSV
    for split_name, csv_file in (("Training", train_csv), ("Validation", val_csv)):
        if not os.path.exists(csv_file) or os.path.getsize(csv_file) == 0:
            print(f"Error: {split_name} data file {csv_file} does not exist or is empty")
            return None, None

    # Set device
    device = torch.device(device)
    print(f"Using device: {device}")

    # Get dataloaders
    train_loader, val_loader, test_loader = get_dataloaders(
        train_csv,
        val_csv,
        test_csv,
        batch_size=batch_size,
        num_workers=num_workers
    )

    # Create model
    model = create_model(
        whisper_size=whisper_size,
        num_languages=len(LANGS),
        freeze_encoder=freeze_encoder,
        device=device
    )

    # Print number of trainable parameters
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total_params = sum(p.numel() for p in model.parameters())
    print(f"Model: Whisper-{whisper_size} for LID")
    print(f"Trainable parameters: {trainable_params:,} / {total_params:,}")

    # Loss function and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(
        model.parameters(),
        lr=lr,
        weight_decay=weight_decay
    )

    # Learning rate scheduler (stepped once per epoch)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(
        optimizer,
        T_max=epochs
    )

    # Best-checkpoint bookkeeping. The path is remembered explicitly so that
    # only a checkpoint written by THIS run is ever reloaded.
    best_f1 = 0
    best_epoch = 0
    best_model_path = None
    patience_counter = 0

    # Track metrics
    train_losses = []
    train_accs = []
    val_losses = []
    val_accs = []
    val_f1s = []

    for epoch in range(epochs):
        print(f"\nEpoch {epoch+1}/{epochs}")

        # Train
        train_metrics = train_epoch(model, train_loader, criterion, optimizer, device)
        print(f"Train Loss: {train_metrics['loss']:.4f}, Accuracy: {train_metrics['accuracy']:.4f}")

        # Evaluate
        val_metrics = evaluate(model, val_loader, criterion, device)
        print(f"Val Loss: {val_metrics['loss']:.4f}, Accuracy: {val_metrics['accuracy']:.4f}, Macro F1: {val_metrics['macro_f1']:.4f}")

        # Track metrics
        train_losses.append(train_metrics['loss'])
        train_accs.append(train_metrics['accuracy'])
        val_losses.append(val_metrics['loss'])
        val_accs.append(val_metrics['accuracy'])
        val_f1s.append(val_metrics['macro_f1'])

        # Step scheduler
        scheduler.step()

        # Save the model if it's the best so far
        if val_metrics['macro_f1'] > best_f1:
            best_f1 = val_metrics['macro_f1']
            best_epoch = epoch
            patience_counter = 0

            # Save model
            best_model_path = os.path.join(output_dir, f"lid_whisper_{whisper_size}_epoch{epoch+1}_f1{best_f1:.4f}.pt")
            torch.save(model.state_dict(), best_model_path)
            print(f"Saved best model to {best_model_path}")
        else:
            patience_counter += 1
            print(f"No improvement for {patience_counter} epochs (best F1: {best_f1:.4f} at epoch {best_epoch+1})")

        # Early stopping
        if patience_counter >= patience:
            print(f"Early stopping after {epoch+1} epochs")
            break

    # Plot training curves
    plt.figure(figsize=(15, 5))

    plt.subplot(1, 3, 1)
    plt.plot(train_losses, label='Train')
    plt.plot(val_losses, label='Validation')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.title('Loss Curves')

    plt.subplot(1, 3, 2)
    plt.plot(train_accs, label='Train')
    plt.plot(val_accs, label='Validation')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.legend()
    plt.title('Accuracy Curves')

    plt.subplot(1, 3, 3)
    plt.plot(val_f1s)
    plt.xlabel('Epoch')
    plt.ylabel('Macro F1')
    plt.title('Validation F1 Score')

    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, 'training_curves.png'))
    plt.show()

    # Reload the best checkpoint (if one was saved) and evaluate on the test set
    results = {}
    if best_model_path is not None and os.path.exists(best_model_path):
        # map_location keeps the load working if the tensors were saved
        # from a different device than the one in use now.
        model.load_state_dict(torch.load(best_model_path, map_location=device))

        if test_loader is not None:
            print("\nEvaluating on test set...")
            test_metrics = evaluate(model, test_loader, criterion, device)
            print(f"Test Accuracy: {test_metrics['accuracy']:.4f}, Macro F1: {test_metrics['macro_f1']:.4f}")

            # Plot confusion matrix. The label list is fixed so the matrix
            # always has one row/column per language, even when a class is
            # absent from the test set; otherwise display_labels would not
            # match the matrix size.
            class_ids = list(range(len(LANGS)))
            cm = confusion_matrix(test_metrics['labels'], test_metrics['preds'], labels=class_ids)
            fig, ax = plt.subplots(figsize=(10, 8))
            disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=LANGS)
            # Draw onto our axes so the requested figure size is honoured
            disp.plot(ax=ax, cmap='Blues', values_format='d', xticks_rotation=45)
            ax.set_title('Confusion Matrix (Test Set)')
            fig.tight_layout()
            fig.savefig(os.path.join(output_dir, 'confusion_matrix.png'))
            plt.show()

            # Collect test results; float() keeps the values JSON-friendly
            results = {
                "test_accuracy": float(test_metrics['accuracy']),
                "test_macro_f1": float(test_metrics['macro_f1']),
                "best_val_f1": float(best_f1),
                "best_epoch": best_epoch + 1,
                "model_path": best_model_path,
                "whisper_size": whisper_size,
                "freeze_encoder": freeze_encoder,
            }

            # Save results to JSON
            with open(os.path.join(output_dir, "results.json"), "w") as f:
                json.dump(results, f, indent=2)

    return model, results

# ================================================================
# 9. INFERENCE FUNCTIONS
# ================================================================

def _load_audio_for_inference(audio_path):
    """Load ``audio_path`` as mono samples, trying several backends in turn.

    Backends are tried in this order: soundfile, librosa, torchaudio. The
    first one that succeeds wins. If all of them fail, the error from the
    last backend is printed and one second of silence at 16 kHz is returned,
    so the caller does not crash on an unreadable file.

    Args:
        audio_path: Path of the audio file to read.

    Returns:
        tuple: ``(audio, sr)`` where ``audio`` is a 1-D NumPy array of samples
        and ``sr`` is the sample rate reported by the backend (or 16000 for
        the silence fallback).
    """
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")

        # 1) soundfile
        try:
            with sf.SoundFile(audio_path) as sound_file:
                audio = sound_file.read(dtype='float32')
                sr = sound_file.samplerate
            if audio.ndim > 1:  # Multi-channel: average over the channel axis
                audio = audio.mean(axis=1)
            return audio, sr
        except Exception:
            pass  # Deliberate: fall through to the next backend

        # 2) librosa (sr=None keeps the file's native sample rate)
        try:
            return librosa.load(audio_path, sr=None, mono=True, dtype=np.float32)
        except Exception:
            pass  # Deliberate: fall through to the last backend

        # 3) torchaudio
        try:
            try:
                # Best-effort only: this selector may be unavailable depending
                # on the installed torchaudio version.
                torchaudio.set_audio_backend("sox_io")
            except Exception:
                pass
            audio_tensor, sr = torchaudio.load(audio_path)
            if audio_tensor.shape[0] > 1:  # Multi-channel, convert to mono
                audio = audio_tensor.mean(dim=0).numpy()
            else:
                audio = audio_tensor.squeeze(0).numpy()
            return audio, sr
        except Exception as e:
            print(f"Failed to load {audio_path}: {e}")

    # Every backend failed: substitute one second of silence at 16 kHz.
    sr = 16000
    return np.zeros(sr, dtype=np.float32), sr


def predict_language(model, audio_path, device='cuda', top_k=3):
    """Predict the language of an audio file.

    The audio is loaded, resampled to 16 kHz if necessary, padded/trimmed and
    converted to a log-mel spectrogram with Whisper's audio utilities, then
    passed through ``model``.

    Args:
        model: Module called as ``model(mel)`` on a batched mel spectrogram
            and returning per-class logits with classes on dimension 1.
        audio_path: Path of the audio file to classify.
        device: Device the mel input is moved to. This function does not move
            ``model``; it is expected to be on the same device already.
        top_k: Number of ranked predictions to return. Clamped to the number
            of classes the model outputs. Defaults to 3.

    Returns:
        list: ``(language, probability)`` tuples, most probable first, with
        languages looked up in ``ID2LANG``.
    """
    audio, sr = _load_audio_for_inference(audio_path)
    audio = torch.tensor(audio).float()

    # Resample if needed
    if sr != 16000:
        audio = torchaudio.functional.resample(audio, orig_freq=sr, new_freq=16000)

    # Pad or trim to Whisper's fixed input length, then extract the mel spectrogram
    audio = whisper.audio.pad_or_trim(audio)
    mel = whisper.audio.log_mel_spectrogram(audio)

    # Add batch dimension and move to device
    mel = mel.unsqueeze(0).to(device)

    # Run inference
    model.eval()
    with torch.no_grad():
        logits = model(mel)
        probabilities = F.softmax(logits, dim=1)

    # torch.topk raises if k exceeds the number of classes, so clamp it.
    k = min(top_k, probabilities.shape[1])
    probs, indices = torch.topk(probabilities, k)

    # Convert class indices to language names
    return [
        (ID2LANG[lang_idx], prob)
        for lang_idx, prob in zip(indices[0].tolist(), probs[0].tolist())
    ]

def demo_inference(model, audio_path, device='cuda'):
    """Classify one audio file, print the ranked languages and plot them.

    Args:
        model: Model forwarded unchanged to ``predict_language``.
        audio_path: Path of the audio file to classify.
        device: Device forwarded unchanged to ``predict_language``.

    Returns:
        None. The predictions are printed and shown as a bar chart.
    """
    predictions = predict_language(model, audio_path, device=device)
    file_name = os.path.basename(audio_path)

    # Text summary, one line per predicted language
    print(f"Results for {file_name}:")
    for language, p in predictions:
        print(f"  - {language.capitalize()}: {p*100:.2f}%")

    # Split the (language, probability) pairs into two parallel sequences
    names, probs = zip(*predictions)
    labels = [name.capitalize() for name in names]

    # Bar chart of the predicted probabilities
    plt.figure(figsize=(10, 5))
    plt.bar(labels, probs, color='skyblue')
    plt.xlabel('Language')
    plt.ylabel('Probability')
    plt.title(f'Language Prediction for {file_name}')
    plt.ylim(0, 1)

    # Write the percentage slightly above the top of each bar
    bar_index = 0
    for p in probs:
        plt.text(bar_index, p + 0.02, f'{p*100:.1f}%', ha='center')
        bar_index += 1

    plt.tight_layout()
    plt.show()

# ================================================================
# 10. EXECUTE THE PIPELINE
# ================================================================

# Report the working directory and whether the MP3 source directory is visible
print(f"Current working directory: {os.getcwd()}")
mp3_dir_found = os.path.exists(MP3_DIR)
print(f"Checking if MP3_DIR exists: {mp3_dir_found}")

if not mp3_dir_found:
    print("MP3_DIR does not exist or is not accessible")

    # Look under "G44" for the dataset folder and repoint MP3_DIR at the
    # first match found.
    dataset_dirname = "Language Detection Dataset"
    for parent, subdirs, _filenames in os.walk("G44"):
        if dataset_dirname not in subdirs:
            continue
        print(f"Found 'Language Detection Dataset' in {parent}")
        MP3_DIR = os.path.join(parent, dataset_dirname)
        print(f"Updated MP3_DIR to {MP3_DIR}")
        break
else:
    print(f"Contents of MP3_DIR: {os.listdir(MP3_DIR)}")

# Count the .wav files that already exist under each language directory
wav_count = 0
for lang in LANGS:
    lang_dir = os.path.join(WAV_DIR, lang)
    if not os.path.exists(lang_dir):
        continue
    for _dirpath, _subdirs, filenames in os.walk(lang_dir):
        wav_count += len([name for name in filenames if name.endswith('.wav')])

# Run preprocessing steps
print("Starting data preprocessing...")

# Fewer than 1000 WAV files is taken to mean the conversion has not been done
needs_conversion = wav_count < 1000
if not needs_conversion:
    print(f"Skipping MP3 to WAV conversion, found {wav_count} existing WAV files")
else:
    print("Converting MP3 files to WAV...")
    convert_dataset(MP3_DIR, WAV_DIR, n_workers=4)

# Build the train/validation/test splits from the WAV directory
print("Creating data splits...")
train_df, val_df, test_df = create_data_splits(WAV_DIR, SPLITS_DIR, split_ratio=(0.8, 0.1, 0.1))

train_csv_path = os.path.join(SPLITS_DIR, "train.csv")
val_csv_path = os.path.join(SPLITS_DIR, "val.csv")
test_csv_path = os.path.join(SPLITS_DIR, "test.csv")

# Training needs a non-empty train split and a non-empty validation split
if len(train_df) > 0 and len(val_df) > 0:
    print("Starting model training...")
    model, results = train_model(
        train_csv=train_csv_path,
        val_csv=val_csv_path,
        test_csv=test_csv_path,
        whisper_size="tiny",  # Use tiny for faster training; options: tiny, base, small, medium, large
        freeze_encoder=True,
        batch_size=16,        # Reduced batch size for GPU memory
        epochs=5,
        lr=1e-4,
        patience=3,
        num_workers=2,        # Reduced workers for compatibility
        device=DEVICE,
        output_dir=OUTPUT_DIR
    )

    if model is None:
        print("Model training failed. Please check error messages.")
    else:
        print("Pipeline completed successfully!")

        # Pick the first test file of each language as demo candidates
        test_files = []
        if os.path.exists(test_csv_path):
            test_df = pd.read_csv(test_csv_path)
            if len(test_df) > 0:
                for lang in LANGS:
                    lang_files = test_df[test_df['lang'] == lang]['path'].values
                    if len(lang_files) > 0:
                        test_files.append(lang_files[0])

        # Run demo inference on at most three of the candidates
        if test_files:
            print("\nDemonstrating inference on sample files:")
            for audio_file in test_files[:3]:
                if not os.path.exists(audio_file):
                    continue
                print(f"\nInference on {audio_file}:")
                demo_inference(model, audio_file, device=DEVICE)
else:
    # Nothing to train on: explain how to set the dataset up by hand
    print("Insufficient data to train model. Please check your dataset paths.")
    print("You may need to manually download and organize the dataset before running this script.")
    print(f"Expected MP3 files in: {MP3_DIR}")

    print("\nSuggested steps to set up the dataset:")
    print("1. Download the dataset from https://www.kaggle.com/datasets/hbchaitanyabharadwaj/audio-dataset-with-10-indian-languages")
    print(f"2. Extract all files to: {MP3_DIR}")
    print(f"3. Ensure the directory structure is: {MP3_DIR}/[Language]/[audio_files.mp3]")
    print("4. Run this script again after setting up the dataset")

# Optional: Add a simple way to run inference on a specific file
def run_inference_on_file(model_path, audio_path, device=DEVICE, whisper_size="tiny"):
    """Run inference on a specific audio file using a trained model.

    Builds a fresh model with ``create_model``, loads the saved weights into
    it and hands it to ``demo_inference``.

    Args:
        model_path: Path of the saved ``state_dict`` checkpoint.
        audio_path: Path of the audio file to classify.
        device: Device used to build the model, map the checkpoint and run
            inference. Defaults to the module-level ``DEVICE``.
        whisper_size: Whisper variant passed to ``create_model``. It must
            match the size the checkpoint was trained with, otherwise
            ``load_state_dict`` will reject the weights. Defaults to "tiny".

    Returns:
        None. Results are printed and plotted by ``demo_inference``.
    """
    # Load model
    model = create_model(whisper_size=whisper_size, num_languages=len(LANGS), device=device)
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    state_dict = torch.load(model_path, map_location=device)
    model.load_state_dict(state_dict)
    model.eval()

    # Run demo inference
    demo_inference(model, audio_path, device=device)

# Example usage of the inference function:
# model_path = os.path.join(OUTPUT_DIR, "lid_whisper_tiny_epoch5_f10.8976.pt")
# audio_path = "path/to/test/audio.wav"
# run_inference_on_file(model_path, audio_path)

CUDA available: True
CUDA device: NVIDIA RTX A5000
Using device: cuda
Languages: ['Bengali', 'Gujarati', 'Hindi', 'Kannada', 'Malayalam', 'Marathi', 'Punjabi', 'Tamil', 'Telugu', 'Urdu']
Total languages: 10
Current working directory: /home/teaching
Checking if MP3_DIR exists: True
Contents of MP3_DIR: ['Kannada', 'Marathi', 'Punjabi', 'Telugu', 'Gujarati', 'Malayalam', 'Urdu', 'Tamil', 'Hindi', 'Bengali']
Starting data preprocessing...
Skipping MP3 to WAV conversion, found 256826 existing WAV files
Creating data splits...
Total samples: 256826
lang
Urdu         31959
Bengali      27258
Gujarati     26439
Punjabi      26227
Hindi        25462
Marathi      25378
Tamil        24195
Malayalam    24044
Telugu       23656
Kannada      22208
Name: count, dtype: int64
Splits created and saved to G44/splits:
  - Train: 205457 samples
  - Validation: 25676 samples
  - Test: 25693 samples
Starting model training...
Using device: cuda
Model: Whisper-tiny for LID
Trainable parameters: 101,130 / 7,7

Training:   8%|████▌                                                      | 1000/12842 [02:56<44:19,  4.45it/s]

Batch 1000: loss=2.1901, acc=22.21%


Training:  16%|█████████▏                                                 | 1999/12842 [05:52<32:06,  5.63it/s]

Batch 2000: loss=2.0502, acc=29.99%


Training:  23%|█████████████▊                                             | 2999/12842 [08:45<29:40,  5.53it/s]

Batch 3000: loss=1.9200, acc=35.38%


Training:  31%|██████████████████▍                                        | 4000/12842 [11:40<26:51,  5.49it/s]

Batch 4000: loss=1.8078, acc=39.25%


Training:  39%|██████████████████████▉                                    | 4997/12842 [14:31<22:37,  5.78it/s]

Batch 5000: loss=1.7161, acc=42.16%


Training:  47%|███████████████████████████▌                               | 6000/12842 [17:27<19:36,  5.82it/s]

Batch 6000: loss=1.6390, acc=44.50%


Training:  55%|████████████████████████████████▏                          | 7000/12842 [20:21<17:06,  5.69it/s]

Batch 7000: loss=1.5741, acc=46.42%


Training:  62%|████████████████████████████████████▋                      | 7999/12842 [23:12<13:23,  6.03it/s]

Batch 8000: loss=1.5173, acc=48.13%


Training:  70%|█████████████████████████████████████████▎                 | 8997/12842 [26:05<11:14,  5.70it/s]

Batch 9000: loss=1.4694, acc=49.51%


Training:  75%|███████████████████████████████████████████▉               | 9568/12842 [27:44<09:16,  5.88it/s]