In [3]:
import sys
!{sys.executable} pip install fvcore

Channels:
 - facebookresearch
 - defaults
 - conda-forge
 - anaconda
 - pytorch
Platform: linux-64
Collecting package metadata (repodata.json): failed

UnavailableInvalidChannel: HTTP 404 NOT FOUND for channel facebookresearch <https://conda.anaconda.org/facebookresearch>

The channel is not accessible or is invalid.

You will need to adjust your conda configuration to proceed.
Use `conda config --show channels` to view your configuration's current state,
and use `conda config --show-sources` to view config file locations.



Note: you may need to restart the kernel to use updated packages.


In [1]:
import os
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchaudio
from torch.utils.data import Dataset, DataLoader
import glob
import numpy as np
from sklearn.metrics import accuracy_score, f1_score, recall_score, roc_auc_score, roc_curve
import torchaudio.transforms as T
from torch.optim.lr_scheduler import CosineAnnealingLR, CyclicLR
from torch.optim.swa_utils import AveragedModel, SWALR
import librosa  # Added import
import warnings
warnings.filterwarnings('ignore')

# Dataset Analysis Function
def analyze_dataset(protocols_dir, audio_dir):
    """Summarise class balance, clip durations and sample rates of the training protocols.

    Every ``train_fold*.csv`` under ``protocols_dir`` is read. Column 0 is used as the
    clip name (resolved to ``<audio_dir>/<clip>.wav``) and column 1 as the label: a
    label whose lower-cased text contains ``'truth'`` or ``'0'`` counts as class 0,
    anything else as class 1. Rows whose audio file is missing or unreadable are
    counted in the total but excluded from the statistics.

    Args:
        protocols_dir: Directory containing the ``train_fold*.csv`` files.
        audio_dir: Directory containing the ``.wav`` files.

    Returns:
        dict with ``class_counts`` ([class0, class1]), ``duration_stats``
        (min/max/mean/median/std in seconds), ``sample_rates`` (set) and
        ``valid_ratio`` (readable files / protocol rows).

    Raises:
        ValueError: if not a single audio file could be analysed, because duration
            statistics are undefined in that case.
    """
    train_files = sorted(glob.glob(f"{protocols_dir}/train_fold*.csv"))

    class_counts = [0, 0]
    durations = []
    sample_rates = set()
    valid_files = 0
    total_files = 0

    for fold_file in train_files:
        annos = pd.read_csv(fold_file)
        for row in annos.itertuples(index=False):
            total_files += 1
            audio_path = os.path.join(audio_dir, f"{row[0]}.wav")
            if not os.path.exists(audio_path):
                continue
            try:
                metadata = torchaudio.info(audio_path)
                duration = metadata.num_frames / metadata.sample_rate
            except Exception as e:
                # A single corrupt file must not abort the whole analysis.
                print(f"Error analyzing {audio_path}: {e}")
                continue

            durations.append(duration)
            sample_rates.add(metadata.sample_rate)
            label = 0 if any(k in str(row[1]).lower() for k in ['truth', '0']) else 1
            class_counts[label] += 1
            valid_files += 1

    if not durations:
        # Previously this surfaced as "min() arg is an empty sequence".
        raise ValueError(
            f"No readable audio files found for {total_files} protocol entries "
            f"(protocols_dir={protocols_dir!r}, audio_dir={audio_dir!r})"
        )

    duration_stats = {
        'min': min(durations),
        'max': max(durations),
        'mean': np.mean(durations),
        'median': np.median(durations),
        'std': np.std(durations)
    }
    valid_ratio = valid_files / total_files
    # With no class-0 clips the ratio is unbounded; report inf instead of dividing by zero.
    class_ratio = class_counts[1] / class_counts[0] if class_counts[0] else float('inf')

    print("\nDataset Analysis Results:")
    print(f"Total files: {total_files}")
    print(f"Valid files: {valid_files} ({valid_ratio:.1%})")
    print(f"Class distribution: {class_counts} (Real: {class_counts[0]}, Fake: {class_counts[1]})")
    print(f"Class ratio: {class_ratio:.2f}:1")
    print(f"Sample rates found: {sample_rates}")
    print("Duration statistics (seconds):")
    for k, v in duration_stats.items():
        print(f"  {k}: {v:.2f}")

    return {
        'class_counts': class_counts,
        'duration_stats': duration_stats,
        'sample_rates': sample_rates,
        'valid_ratio': valid_ratio
    }

# Enhanced AudioDataset Class
class AudioDataset(Dataset):
    """Fixed-length, 16 kHz mono clips listed in a protocol CSV.

    Column 0 of the CSV is the clip name (``<audio_dir>/<clip>.wav``) and column 1 the
    label. Rows whose audio file does not exist are dropped at construction time.

    Args:
        annotations_file: Path of the protocol CSV.
        audio_dir: Directory holding the ``.wav`` files.
        target_length: Number of samples per item (after resampling to 16 kHz).
            Longer clips are randomly cropped, shorter ones are zero-padded.
        mode: ``'waveform'``, ``'spectrogram'`` or ``'both'``; selects what
            ``__getitem__`` returns as data.
        transform: Optional callable applied to the waveform.
        spec_transform: Optional callable applied to the spectrogram.
        spec_params: Mel-spectrogram settings (``n_mels``, ``n_fft``, ``hop_length``,
            ``f_min``, ``f_max``); defaults are used when omitted.

    Raises:
        ValueError: if ``mode`` is not one of the supported values. (This used to be
            raised per item and swallowed, producing a dataset that yields nothing.)
    """

    SAMPLE_RATE = 16000
    VALID_MODES = ('waveform', 'spectrogram', 'both')

    def __init__(self, annotations_file, audio_dir, target_length=16000, mode='spectrogram', 
                 transform=None, spec_transform=None, spec_params=None):
        if mode not in self.VALID_MODES:
            raise ValueError("Invalid mode")

        self.annos = pd.read_csv(annotations_file)
        self.audio_dir = audio_dir
        self.target_length = target_length
        self.mode = mode
        self.transform = transform
        self.spec_transform = spec_transform
        self.spec_params = spec_params or {
            'n_mels': 128,
            'n_fft': 2048,
            'hop_length': 512,
            'f_min': 20,
            'f_max': 8000
        }
        # Spectrogram transforms are built once on first use instead of once per sample.
        self._mel = None
        self._to_db = None

        self.valid_data = []
        for row in self.annos.itertuples(index=False):
            audio_path = os.path.join(self.audio_dir, f"{row[0]}.wav")
            if os.path.exists(audio_path):
                self.valid_data.append((audio_path, self._process_label(row[1])))

        print(f"Found {len(self.valid_data)}/{len(self.annos)} valid files.")

    def __len__(self):
        return len(self.valid_data)

    def _fit_length(self, waveform):
        """Randomly crop or right-pad a 1-D waveform to ``target_length`` samples."""
        excess = waveform.shape[0] - self.target_length
        if excess > 0:
            # randint's upper bound is exclusive; +1 makes the last window reachable.
            start = torch.randint(0, excess + 1, (1,)).item()
            return waveform[start:start + self.target_length]
        return torch.nn.functional.pad(waveform, (0, -excess))

    def __getitem__(self, idx):
        """Return ``(data, label_tensor)``, or ``(None, None)`` if loading fails.

        Failed items are expected to be filtered out by the collate function.
        """
        audio_path, label = self.valid_data[idx]
        try:
            waveform, sample_rate = torchaudio.load(audio_path)
            waveform = waveform.mean(dim=0)  # Mono
            waveform = torchaudio.functional.resample(waveform, sample_rate, self.SAMPLE_RATE)
            waveform = self._fit_length(waveform)

            if self.transform:
                waveform = self.transform(waveform)

            if self.mode == 'waveform':
                data = waveform
            elif self.mode in ('spectrogram', 'both'):
                spectrogram = self._create_spectrogram(waveform)
                if self.spec_transform:
                    spectrogram = self.spec_transform(spectrogram)
                data = spectrogram if self.mode == 'spectrogram' else (waveform, spectrogram)
            else:
                raise ValueError("Invalid mode")

            return data, torch.tensor(label)
        except Exception as e:
            print(f"[ERROR] Error loading {audio_path}: {e}")
            return None, None

    def _create_spectrogram(self, waveform):
        """Return the dB-scaled mel spectrogram of ``waveform`` with a leading channel axis."""
        if getattr(self, '_mel', None) is None:
            self._mel = T.MelSpectrogram(
                sample_rate=self.SAMPLE_RATE,
                n_mels=self.spec_params['n_mels'],
                n_fft=self.spec_params['n_fft'],
                hop_length=self.spec_params['hop_length'],
                f_min=self.spec_params['f_min'],
                f_max=self.spec_params['f_max']
            )
            self._to_db = T.AmplitudeToDB()
        return self._to_db(self._mel(waveform)).unsqueeze(0)

    def _process_label(self, label_str):
        """Map a raw label to 0 when its text contains 'truth' or '0', else to 1."""
        str_label = str(label_str).strip().lower()
        return 0 if any(k in str_label for k in ['truth', '0']) else 1

# Collate Function
def audio_collate_fn(batch):
    """Stack dataset items into a batch, discarding items whose data is ``None``.

    Items are ``(data, label)`` pairs where ``data`` is either one tensor or a
    ``(waveform, spectrogram)`` tuple. If nothing usable remains, a single all-zero
    placeholder of shape ``(1, 16000)`` with label 0 is returned.
    """
    usable = [sample for sample in batch if sample[0] is not None]
    if len(usable) == 0:
        return torch.zeros(1, 16000), torch.tensor([0])

    features, targets = zip(*usable)
    paired = isinstance(features[0], tuple)
    if paired:
        wave_parts, spec_parts = zip(*features)
        stacked = (torch.stack(wave_parts), torch.stack(spec_parts))
    else:
        stacked = torch.stack(features)
    return stacked, torch.stack(targets)

# Advanced Audio Transform
import torch
import torchaudio
import librosa
import numpy as np

class AdvancedAudioTransform:
    """Random waveform augmentation: time-stretch, pitch-shift, additive noise, gain.

    Works on tensors whose LAST axis is time, so both a single clip ``[T]`` and a
    batch ``[B, T]`` are handled, on CPU or on an accelerator.

    Args:
        sample_rate: Sample rate passed to the pitch-shift.
        target_length: If given, time-stretching is enabled (probability 0.5) and the
            stretched signal is cropped/padded back to this many samples. If ``None``
            no time-stretching is done.

    Any failure inside the augmentation chain is reported with ``print`` and the
    waveform as it was at that point is returned, so a caller never receives ``None``.
    """

    def __init__(self, sample_rate=16000, target_length=None):
        self.sample_rate = sample_rate
        self.target_length = target_length

    def _fit_length(self, signal):
        """Randomly crop or right-pad the time (last) axis to ``target_length``."""
        n_samples = signal.shape[-1]
        if n_samples > self.target_length:
            # +1: randint's upper bound is exclusive, the last window must be reachable.
            start = torch.randint(0, n_samples - self.target_length + 1, (1,)).item()
            return signal[..., start:start + self.target_length]
        return torch.nn.functional.pad(signal, (0, self.target_length - n_samples))

    def __call__(self, waveform):
        try:
            # Ensure waveform is a tensor
            if not isinstance(waveform, torch.Tensor):
                raise ValueError("Input must be a PyTorch tensor")

            # Time stretch
            if self.target_length is not None and torch.rand(1) < 0.5:
                rate = 0.8 + torch.rand(1) * 0.4  # 0.8-1.2
                # librosa needs a CPU numpy array; move the result back afterwards.
                stretched = librosa.effects.time_stretch(
                    waveform.detach().cpu().numpy(), rate=rate.item()
                )
                stretched = torch.from_numpy(stretched).float().to(waveform.device)
                waveform = self._fit_length(stretched)

            # Pitch shift
            if torch.rand(1) < 0.5:
                n_steps = -4 + torch.rand(1) * 8  # -4 to 4 steps
                waveform = torchaudio.functional.pitch_shift(waveform, self.sample_rate, n_steps.item())

            # Background noise
            if torch.rand(1) < 0.3:
                noise = torch.randn_like(waveform) * 0.005
                waveform = waveform + noise

            # Random gain
            if torch.rand(1) < 0.3:
                gain = 0.8 + torch.rand(1) * 0.4  # 0.8-1.2
                # Use a Python float so the multiply also works for non-CPU waveforms.
                waveform = waveform * gain.item()

            return waveform
        except Exception as e:
            print(f"Error in AdvancedAudioTransform: {e}")
            return waveform  # Return the waveform on failure to avoid None

# Compute Class Weights
def compute_class_weights(protocols_dir, audio_dir):
    """Compute balanced weights ``total / (2 * count)`` for the two classes.

    Labels come from column 1 of every ``train_fold*.csv`` in ``protocols_dir``
    (class 0 when the lower-cased text contains ``'truth'`` or ``'0'``, else class 1);
    only rows whose ``<audio_dir>/<clip>.wav`` exists are counted.

    Returns:
        Float tensor of shape ``(2,)``. A class with no samples is treated as having
        one sample so its weight stays finite. With no samples at all, uniform
        weights ``[1, 1]`` are returned.
    """
    labels = []
    for fold_file in sorted(glob.glob(f"{protocols_dir}/train_fold*.csv")):
        annos = pd.read_csv(fold_file)
        for row in annos.itertuples(index=False):
            audio_path = os.path.join(audio_dir, f"{row[0]}.wav")
            if os.path.exists(audio_path):
                is_class0 = any(k in str(row[1]).lower() for k in ['truth', '0'])
                labels.append(0 if is_class0 else 1)

    total = len(labels)
    if total == 0:
        return torch.ones(2, dtype=torch.float)

    # minlength=2 keeps the vector at two entries even if class 1 never occurs.
    class_counts = np.bincount(np.asarray(labels, dtype=np.int64), minlength=2)
    weights = total / (2.0 * np.maximum(class_counts, 1))
    return torch.tensor(weights, dtype=torch.float)

# Models from Updated Code
class HTSAT(nn.Module):
    """Hierarchical Token-Semantic Audio Transformer (HTS-AT) style classifier.

    Pipeline: mel-spectrogram (dB) -> two stride-2 conv blocks -> 6-layer transformer
    encoder over the flattened feature map -> mean pooling -> linear head.

    ``forward`` accepts either
      * a waveform batch ``[B, T]`` (the spectrogram is computed internally), or
      * an already computed spectrogram batch ``[B, 1, n_mels, frames]`` (4-D input),
        which is passed straight to the convolutional stem.
    """
    def __init__(self, num_classes=2):
        super().__init__()
        self.spec = T.MelSpectrogram(sample_rate=16000, n_fft=2048, hop_length=512, n_mels=128)
        self.amplitude_to_db = T.AmplitudeToDB()
        
        self.conv = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.SiLU(),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.SiLU()
        )
        
        # batch_first=True: tokens are laid out as [B, N, C] in forward(). With the
        # default (False) the encoder would treat the batch axis as the sequence and
        # let different samples attend to each other.
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=128, nhead=8, dim_feedforward=512,
                                       batch_first=True),
            num_layers=6
        )
        
        self.classifier = nn.Sequential(
            nn.Linear(128, num_classes)
        )
    
    def forward(self, x):
        if x.dim() != 4:
            # Raw waveform: build the dB mel-spectrogram and add the channel axis.
            x = self.spec(x)
            x = self.amplitude_to_db(x)
            x = x.unsqueeze(1)
        x = self.conv(x)
        b, c, h, w = x.shape
        # [B, C, H, W] -> [B, H*W, C]: one token per spatial position.
        x = x.permute(0, 2, 3, 1).reshape(b, h * w, c)
        x = self.transformer(x)
        x = x.mean(dim=1)
        return self.classifier(x)

class S3D(nn.Module):
    """S3D video backbone fetched through ``torch.hub`` and re-headed for waveform input.

    The first block's convolution is replaced by a 1-input-channel ``Conv3d`` and the
    projection of block 4 by a ``Linear(1024, num_classes)``; both replacements are
    freshly initialised, so only the remaining layers keep pretrained weights.

    NOTE(review): the recorded notebook output shows this hub load failing with
    ``ModuleNotFoundError: No module named 'fvcore'`` when that package is missing.
    NOTE(review): confirm that the ``facebookresearch/pytorchvideo`` hub actually
    exposes an ``'s3d'`` entry point and that the returned model has ``blocks[0].conv``
    and ``blocks[4].proj`` attributes - none of this can be verified from this file.
    """
    def __init__(self, num_classes=2):
        super().__init__()
        # Loads (and on first use downloads) the model definition and weights.
        self.s3d = torch.hub.load('facebookresearch/pytorchvideo', 's3d', pretrained=True)
        # Modify for audio input (assuming waveform input)
        self.s3d.blocks[0].conv = nn.Conv3d(1, 64, kernel_size=(5, 7, 7), stride=(1, 2, 2), padding=(2, 3, 3))
        self.s3d.blocks[4].proj = nn.Linear(1024, num_classes)
    
    def forward(self, x):
        # Adds a channel axis and two trailing singleton spatial axes.
        # NOTE(review): presumably x is [B, T]; spatial size 1x1 with 7x7 stride-2
        # kernels and any downstream pooling should be checked against the backbone.
        x = x.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)  # [B, 1, T, 1, 1]
        return self.s3d(x)

class AudioSpectrogramTransformer(nn.Module):
    """Modified AST with learnable positional embeddings.

    The spectrogram ``[B, 1, H, W]`` is cut into non-overlapping square patches by a
    strided convolution, a learned positional embedding is added, and the token
    sequence is processed by a transformer encoder, mean-pooled and classified.

    Args:
        input_size: ``(H, W)`` the positional embedding grid is laid out for. Inputs
            with a different patch grid are still accepted: the embedding grid is
            bilinearly resized to the actual grid.
        num_classes: Number of output logits.
        patch_size: Side length of a patch (also the conv stride).
        hidden_size: Token/embedding width.
        num_heads: Attention heads per layer.
        num_layers: Number of encoder layers.
        dim_feedforward: Width of the feed-forward sub-layer.
    """
    def __init__(self, input_size=(128, 256), num_classes=2, patch_size=16,
                 hidden_size=768, num_heads=12, num_layers=12, dim_feedforward=3072):
        super().__init__()
        self.patch_size = patch_size
        self.hidden_size = hidden_size
        
        self.conv = nn.Conv2d(1, self.hidden_size, kernel_size=self.patch_size, stride=self.patch_size)
        
        # Patch grid (rows, cols) the positional embedding is stored for.
        self.grid_size = (input_size[0] // self.patch_size, input_size[1] // self.patch_size)
        seq_len = self.grid_size[0] * self.grid_size[1]
        self.pos_embed = nn.Parameter(torch.randn(1, seq_len, self.hidden_size))
        
        # batch_first=True because forward() produces tokens as [B, N, C]; the default
        # layout would make attention run across the batch instead of across patches.
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=self.hidden_size, nhead=num_heads,
                                       dim_feedforward=dim_feedforward, batch_first=True),
            num_layers=num_layers
        )
        
        self.classifier = nn.Linear(self.hidden_size, num_classes)
    
    def _positional_embedding(self, grid_h, grid_w):
        """Return the positional embedding as ``[1, grid_h * grid_w, C]``."""
        if (grid_h, grid_w) == tuple(self.grid_size):
            return self.pos_embed
        # Resize the stored grid to the grid of the current input. Token order is
        # row-major (h * W + w), matching flatten(2) in forward().
        pos = self.pos_embed.reshape(1, self.grid_size[0], self.grid_size[1], self.hidden_size)
        pos = pos.permute(0, 3, 1, 2)  # [1, C, H, W]
        pos = F.interpolate(pos, size=(grid_h, grid_w), mode='bilinear', align_corners=False)
        return pos.flatten(2).permute(0, 2, 1)
    
    def forward(self, x):
        x = self.conv(x)  # [B, C, H, W]
        grid_h, grid_w = x.shape[-2], x.shape[-1]
        x = x.flatten(2).permute(0, 2, 1)  # [B, N, C]
        x = x + self._positional_embedding(grid_h, grid_w)
        x = self.transformer(x)
        x = x.mean(dim=1)
        return self.classifier(x)

class WavegramCNN(nn.Module):
    """1-D convolutional classifier operating directly on raw waveforms.

    A three-stage Conv1d/MaxPool1d front end (stored as ``sincnet``) is followed by
    global average pooling and a two-layer linear head (stored as ``classifier``).
    """
    def __init__(self, num_classes=2):
        super().__init__()
        front_end = [
            nn.Conv1d(in_channels=1, out_channels=64, kernel_size=251, stride=80, padding=125),
            nn.MaxPool1d(kernel_size=4),
            nn.Conv1d(in_channels=64, out_channels=128, kernel_size=5),
            nn.MaxPool1d(kernel_size=4),
            nn.Conv1d(in_channels=128, out_channels=256, kernel_size=3),
            nn.MaxPool1d(kernel_size=4),
        ]
        self.sincnet = nn.Sequential(*front_end)

        head = [
            nn.AdaptiveAvgPool1d(output_size=1),
            nn.Flatten(),
            nn.Linear(in_features=256, out_features=128),
            nn.ReLU(),
            nn.Linear(in_features=128, out_features=num_classes),
        ]
        self.classifier = nn.Sequential(*head)

    def forward(self, x):
        # Insert the channel axis expected by Conv1d, then run both stages.
        features = self.sincnet(x.unsqueeze(1))
        return self.classifier(features)

class HybridModel(nn.Module):
    """Waveform + Spectrogram Hybrid Model.

    Two branches each emit a 128-dimensional feature vector - ``WavegramCNN`` on the
    raw waveform and ``HTSAT`` on a mel-spectrogram - which are concatenated and
    classified by a small MLP.

    ``forward`` takes the ``(waveform, spectrogram)`` pair produced by the dataset in
    ``'both'`` mode. The spectrogram branch is fed the WAVEFORM, because ``HTSAT``
    computes its own mel-spectrogram in ``forward``; handing it the precomputed
    spectrogram would apply the mel transform a second time and fail.
    """
    def __init__(self, num_classes=2):
        super().__init__()
        self.wave_model = WavegramCNN(num_classes=128)  # Intermediate features
        self.spec_model = HTSAT(num_classes=128)        # Intermediate features
        
        self.classifier = nn.Sequential(
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(128, num_classes)
        )
    
    def forward(self, inputs):
        # The second element (precomputed spectrogram) is accepted for interface
        # compatibility with the data pipeline but is not needed here.
        x_wave, _x_spec = inputs
        wave_feat = self.wave_model(x_wave)
        spec_feat = self.spec_model(x_wave)
        combined = torch.cat([wave_feat, spec_feat], dim=1)
        return self.classifier(combined)

# Mixup Function
def mixup_data(x, y, alpha=1.0):
    """Apply mixup to a batch.

    Args:
        x: Input batch - a tensor, or a tuple/list of tensors that share the batch
            axis (e.g. ``(waveform, spectrogram)``); every element is mixed with the
            same permutation and the same coefficient.
        y: Label tensor of the batch.
        alpha: Beta-distribution parameter; ``alpha <= 0`` disables mixing (lam = 1).

    Returns:
        ``(mixed_x, y_a, y_b, lam)`` where ``mixed_x = lam * x + (1 - lam) * x[perm]``,
        ``y_a`` are the original labels, ``y_b`` the permuted labels and ``lam`` a
        Python float/int.
    """
    # float(): keep lam a plain Python number rather than a numpy scalar.
    lam = float(np.random.beta(alpha, alpha)) if alpha > 0 else 1

    # The batch size/device must be read from a tensor, so look inside tuple inputs
    # first (calling .size() on the tuple itself raised AttributeError).
    is_multi = isinstance(x, (tuple, list))
    reference = x[0] if is_multi else x
    index = torch.randperm(reference.size(0)).to(reference.device)

    if is_multi:
        mixed_x = tuple(lam * xi + (1 - lam) * xi[index] for xi in x)
    else:
        mixed_x = lam * x + (1 - lam) * x[index]
    y_a, y_b = y, y[index]
    return mixed_x, y_a, y_b, lam

# Compute Metrics with EER
def compute_metrics(preds, labels):
    """Compute classification metrics for a two-class problem.

    Args:
        preds: ``[N, 2]`` tensor of per-class scores. Column 1 is used as the
            positive-class score for AUC/EER, and the tensor is passed unchanged to
            ``nn.CrossEntropyLoss`` (which expects logits) for the ``'loss'`` entry.
        labels: ``[N]`` tensor of integer class labels.

    Returns:
        dict with keys, in this order: ``loss``, ``acc``, ``f1``, ``recall``, ``auc``,
        ``eer``. ``auc`` and ``eer`` are ``nan`` when ``labels`` contains fewer than
        two distinct classes, because a ROC curve is undefined in that case.
    """
    preds_np = preds.argmax(dim=1).cpu().numpy()
    probs_np = preds[:, 1].cpu().numpy()
    labels_np = labels.cpu().numpy()
    
    acc = accuracy_score(labels_np, preds_np)
    f1 = f1_score(labels_np, preds_np)
    recall = recall_score(labels_np, preds_np)
    if np.unique(labels_np).size < 2:
        # roc_auc_score raises ValueError for single-class targets; report nan
        # instead of aborting a training run after the epoch has been computed.
        auc = float('nan')
        eer = float('nan')
    else:
        auc = roc_auc_score(labels_np, probs_np)
        eer = compute_eer(labels_np, probs_np)
    
    return {
        'loss': nn.CrossEntropyLoss()(preds, labels).item(),
        'acc': acc,
        'f1': f1,
        'recall': recall,
        'auc': auc,
        'eer': eer
    }

def compute_eer(labels, probs):
    """Return the equal error rate for binary ``labels`` and positive-class ``probs``.

    The ROC operating point where the false-negative and false-positive rates are
    closest is located, and the mean of the two rates at that point is returned.
    """
    false_pos_rate, true_pos_rate, _thresholds = roc_curve(labels, probs)
    false_neg_rate = 1 - true_pos_rate
    rate_gap = np.abs(false_neg_rate - false_pos_rate)
    crossing = np.nanargmin(rate_gap)
    return (false_pos_rate[crossing] + false_neg_rate[crossing]) / 2

# Validation with TTA
def validate_with_tta(model, loader, device, tta_steps=5, target_length=None):
    """Evaluate ``model`` on ``loader`` with test-time augmentation (TTA).

    For every batch the softmax outputs of ``tta_steps`` forward passes are averaged:
    the first pass uses the clean input, the remaining passes use waveform-augmented
    copies. Augmentation is applied clip by clip on CPU and only to waveform inputs
    (a 2-D ``[B, T]`` tensor, or the first element of a ``(waveform, spectrogram)``
    tuple). Spectrogram-only inputs get a single clean pass, since the waveform
    augmentations do not apply to them.

    Args:
        model: Module to evaluate (put into eval mode here).
        loader: Iterable of ``(inputs, labels)`` batches.
        device: Device the batches are moved to.
        tta_steps: Total number of passes per batch (values below 1 are treated as 1).
        target_length: Forwarded to ``AdvancedAudioTransform``.

    Returns:
        The dict from ``compute_metrics`` with ``'loss'`` replaced by the negative
        log-likelihood of the averaged probabilities.
    """
    model.eval()
    transform = AdvancedAudioTransform(sample_rate=16000, target_length=target_length)
    n_passes = max(1, tta_steps)

    def _augment_batch(wave_batch):
        # Augment each clip on its own so crop/pad act on the time axis and every
        # clip draws its own random parameters.
        augmented = []
        for clip in wave_batch.detach().cpu():
            candidate = transform(clip)
            # Fall back to the clean clip if augmentation failed or changed its shape
            # (the clips must stay stackable).
            if not isinstance(candidate, torch.Tensor) or candidate.shape != clip.shape:
                candidate = clip
            augmented.append(candidate)
        return torch.stack(augmented).to(device)

    all_preds = []
    all_labels = []
    with torch.no_grad():
        for inputs, labels in loader:
            labels = labels.to(device)
            is_pair = isinstance(inputs, (tuple, list))
            if is_pair:
                inputs = tuple(part.to(device) for part in inputs)
                waveform = inputs[0]
            else:
                inputs = inputs.to(device)
                waveform = inputs if inputs.dim() == 2 else None

            preds = [F.softmax(model(inputs), dim=1)]
            if waveform is not None:
                for _ in range(n_passes - 1):
                    aug_wave = _augment_batch(waveform)
                    # Keep any non-waveform elements (spectrogram) unchanged.
                    aug_inputs = (aug_wave,) + tuple(inputs[1:]) if is_pair else aug_wave
                    preds.append(F.softmax(model(aug_inputs), dim=1))

            all_preds.append(torch.stack(preds).mean(0))
            all_labels.append(labels)

    all_preds = torch.cat(all_preds)
    all_labels = torch.cat(all_labels)
    metrics = compute_metrics(all_preds, all_labels)
    # compute_metrics applies CrossEntropyLoss, which expects logits; all_preds are
    # already probabilities, so recompute the loss as their negative log-likelihood.
    metrics['loss'] = F.nll_loss(torch.log(all_preds.clamp_min(1e-12)), all_labels).item()
    return metrics

# Training Function
def train_model(config):
    """Train and validate ``config['model']`` on every fold found in the protocol dir.

    For each ``train_fold<k>.csv`` / ``test_fold<k>.csv`` pair a fresh model is
    trained with AdamW, optional mixup, gradient clipping, an exponential moving
    average (EMA) of the weights used for validation, and stochastic weight averaging
    (SWA) from epoch ``swa_start`` on. Per-epoch metrics are appended to
    ``<checkpoint_dir>/metrics.csv``. The best EMA weights (by validation AUC) and,
    if SWA ran at all, the SWA weights are saved to ``checkpoint_dir``.

    Required config keys: ``model``, ``checkpoint_dir``, ``protocols_dir``,
    ``audio_dir``, ``target_length``, ``batch_size``, ``num_epochs``,
    ``learning_rate`` (and ``patience`` when ``early_stopping`` is true).
    Optional keys (default): ``mode`` ('spectrogram'), ``class_weighting`` (False),
    ``augmentations`` (False), ``mixup_alpha`` (0), ``label_smoothing`` (0.0),
    ``weight_decay`` (0.01), ``scheduler`` (None | 'cosine' | 'cyclic'),
    ``swa_start`` (25), ``tta_steps`` (5), ``early_stopping`` (False),
    ``min_delta`` (0.0), ``ema_decay`` (0.99), ``spec_params`` (None).
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using {device}")

    mode = config.get('mode', 'spectrogram')
    mixup_alpha = config.get('mixup_alpha', 0)
    swa_start = config.get('swa_start', 25)
    min_delta = config.get('min_delta', 0.0)
    ema_decay = config.get('ema_decay', 0.99)
    use_mixup = mixup_alpha > 0

    # Compute class weights
    class_weights = None
    if config.get('class_weighting', False):
        class_weights = compute_class_weights(config['protocols_dir'], config['audio_dir'])
        class_weights = class_weights.to(device)
        print(f"Class weights: {class_weights}")

    # Create checkpoint directory and the metrics file (header only)
    os.makedirs(config['checkpoint_dir'], exist_ok=True)
    metrics_file = os.path.join(config['checkpoint_dir'], 'metrics.csv')
    metric_names = ['loss', 'acc', 'f1', 'recall', 'auc', 'eer']
    metric_columns = (['epoch', 'fold']
                      + [f'train_{name}' for name in metric_names]
                      + [f'val_{name}' for name in metric_names])
    if not os.path.exists(metrics_file):
        pd.DataFrame(columns=metric_columns).to_csv(metrics_file, index=False)

    # Fold ids, sorted numerically where possible ("10" must come after "2").
    train_files = glob.glob(f"{config['protocols_dir']}/train_fold*.csv")
    fold_numbers = [os.path.basename(path)[len("train_fold"):-len(".csv")] for path in train_files]
    fold_numbers.sort(key=lambda s: (0, int(s), s) if s.isdigit() else (1, 0, s))

    def _to_device(batch_inputs):
        if isinstance(batch_inputs, (tuple, list)):
            return tuple(part.to(device) for part in batch_inputs)
        return batch_inputs.to(device)

    def _ema_avg(averaged, current, num_averaged):
        # Exponential moving average; AveragedModel's default is an equal-weight mean.
        return ema_decay * averaged + (1.0 - ema_decay) * current

    def _sync_buffers(averaged_model, source_model):
        # Averaging covers parameters only; copy buffers (e.g. BatchNorm running
        # statistics) so the averaged model does not evaluate with stale ones.
        for b_avg, b_src in zip(averaged_model.module.buffers(), source_model.buffers()):
            b_avg.detach().copy_(b_src.detach().to(b_avg.device))

    for fold in fold_numbers:
        print(f"\n=== Training Fold {fold} ===")

        # Data loading
        train_file = f"{config['protocols_dir']}/train_fold{fold}.csv"
        val_file = f"{config['protocols_dir']}/test_fold{fold}.csv"

        transform = None
        if config.get('augmentations', False):
            transform = AdvancedAudioTransform(sample_rate=16000, target_length=config['target_length'])
        train_dataset = AudioDataset(
            train_file, config['audio_dir'], config['target_length'],
            mode=mode,
            transform=transform,
            spec_transform=None,
            spec_params=config.get('spec_params')
        )
        val_dataset = AudioDataset(
            val_file, config['audio_dir'], config['target_length'],
            mode=mode,
            transform=None,
            spec_transform=None,
            spec_params=config.get('spec_params')
        )
        if len(train_dataset) == 0 or len(val_dataset) == 0:
            # Nothing to train or validate on; avoids division by zero further down.
            print(f"Fold {fold}: no usable audio files, skipping fold")
            continue

        train_loader = DataLoader(train_dataset, batch_size=config['batch_size'],
                                  collate_fn=audio_collate_fn, shuffle=True)
        val_loader = DataLoader(val_dataset, batch_size=config['batch_size'],
                                collate_fn=audio_collate_fn)

        # Initialize model
        model = config['model']().to(device)

        # Optimizer and criterion
        optimizer = optim.AdamW(model.parameters(), lr=config['learning_rate'],
                                weight_decay=config.get('weight_decay', 0.01))
        criterion = nn.CrossEntropyLoss(weight=class_weights,
                                        label_smoothing=config.get('label_smoothing', 0.0))

        # Scheduler
        scheduler = None
        step_per_batch = False
        if config.get('scheduler') == 'cosine':
            scheduler = CosineAnnealingLR(optimizer, T_max=config['num_epochs'])
        elif config.get('scheduler') == 'cyclic':
            # CyclicLR is a per-batch schedule. cycle_momentum is disabled because
            # AdamW has no 'momentum' hyper-parameter to cycle.
            scheduler = CyclicLR(optimizer, base_lr=1e-5, max_lr=1e-3, step_size_up=500,
                                 cycle_momentum=False)
            step_per_batch = True

        # EMA and SWA
        ema_model = AveragedModel(model, avg_fn=_ema_avg)
        swa_model = AveragedModel(model)
        swa_scheduler = SWALR(optimizer, swa_lr=config['learning_rate'] / 10)

        best_val_auc = 0
        patience_counter = 0

        for epoch in range(config['num_epochs']):
            in_swa_phase = epoch >= swa_start
            model.train()
            epoch_loss = 0
            train_preds, train_labels = [], []

            for inputs, labels in train_loader:
                inputs = _to_device(inputs)
                labels = labels.to(device)

                # Mixup
                if use_mixup:
                    inputs, labels_a, labels_b, lam = mixup_data(inputs, labels, mixup_alpha)

                optimizer.zero_grad()
                outputs = model(inputs)
                if use_mixup:
                    loss = lam * criterion(outputs, labels_a) + (1 - lam) * criterion(outputs, labels_b)
                else:
                    loss = criterion(outputs, labels)
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                if step_per_batch and not in_swa_phase:
                    scheduler.step()

                # Update EMA
                ema_model.update_parameters(model)
                _sync_buffers(ema_model, model)

                epoch_loss += loss.item()
                train_preds.append(outputs.detach())
                train_labels.append(labels)

            # Learning-rate schedule: once SWA has started only SWALR drives the LR;
            # stepping both schedulers would make them overwrite each other.
            if in_swa_phase:
                swa_model.update_parameters(model)
                _sync_buffers(swa_model, model)
                swa_scheduler.step()
            elif scheduler is not None and not step_per_batch:
                scheduler.step()

            # Validation with TTA
            val_metrics = validate_with_tta(ema_model, val_loader, device,
                                            tta_steps=config.get('tta_steps', 5),
                                            target_length=config['target_length'])

            # Compute train metrics
            train_metrics = compute_metrics(torch.cat(train_preds), torch.cat(train_labels))
            train_loss = epoch_loss / len(train_loader)

            # Save metrics. Columns are selected by name so the row always lines up
            # with the CSV header. 'train_loss' is the mean optimised loss of the
            # epoch (the value printed below), not the plain cross-entropy that
            # compute_metrics reports.
            row = {'epoch': epoch + 1, 'fold': fold}
            row.update({f'train_{k}': v for k, v in train_metrics.items()})
            row['train_loss'] = train_loss
            row.update({f'val_{k}': v for k, v in val_metrics.items()})
            pd.DataFrame([row], columns=metric_columns).to_csv(
                metrics_file, mode='a', header=False, index=False)

            print(f"Fold {fold} Epoch {epoch+1}/{config['num_epochs']}: "
                  f"Train Loss: {train_loss:.4f}, Val AUC: {val_metrics['auc']:.4f}")

            # Early stopping and model saving (an improvement must exceed min_delta)
            if val_metrics['auc'] > best_val_auc + min_delta:
                best_val_auc = val_metrics['auc']
                patience_counter = 0

                # Save best EMA model
                torch.save({
                    'epoch': epoch + 1,
                    'model_state_dict': ema_model.module.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'val_auc': val_metrics['auc'],
                    'config': config
                }, os.path.join(config['checkpoint_dir'], f'best_ema_model_fold{fold}.pth'))

                print(f"Fold {fold} Epoch {epoch+1}: New best model saved (val_auc: {val_metrics['auc']:.4f})")
            else:
                patience_counter += 1
                if config.get('early_stopping', False) and patience_counter >= config['patience']:
                    print(f"Early stopping triggered for fold {fold} at epoch {epoch + 1}")
                    break

        # Save SWA model - only if at least one SWA update happened; otherwise
        # swa_model still holds the untrained initial weights.
        if int(swa_model.n_averaged) > 0:
            torch.save({
                'model_state_dict': swa_model.module.state_dict(),
                'config': config
            }, os.path.join(config['checkpoint_dir'], f'swa_model_fold{fold}.pth'))
        else:
            print(f"Fold {fold}: training ended before swa_start={swa_start}; no SWA checkpoint written")

# Main Execution
if __name__ == "__main__":
    # Locations used by both the analysis and every experiment.
    protocols_dir = '/home/azureuser/cloudfiles/code/Users/yashika22csu235/research/train_protocol'
    audio_dir = '/home/azureuser/cloudfiles/code/Users/yashika22csu235/research/audio_files'
    checkpoint_root = '/home/azureuser/cloudfiles/code/Users/yashika22csu235/research/train_new/experiment4/'

    # Analyze dataset
    dataset_stats = analyze_dataset(protocols_dir=protocols_dir, audio_dir=audio_dir)
    real_count, fake_count = dataset_stats['class_counts'][0], dataset_stats['class_counts'][1]
    median_seconds = dataset_stats['duration_stats']['median']

    # Base configuration shared by all experiments
    base_config = {
        'batch_size': 64,
        'num_epochs': 100,
        'learning_rate': 3e-4,
        'audio_dir': audio_dir,
        'protocols_dir': protocols_dir,
        'base_checkpoint_dir': checkpoint_root,
        # Clip length in samples: median clip duration at 16 kHz.
        'target_length': int(median_seconds * 16000),
        'augmentations': True,
        # Weight the loss only when the two classes are not perfectly balanced.
        'class_weighting': real_count != fake_count,
        'early_stopping': True,
        'patience': 7,
        'min_delta': 0.001,
        'mixup_alpha': 0.4,
        'label_smoothing': 0.1,
        'scheduler': 'cosine',
        'weight_decay': 0.05,
        'swa_start': 25,
        'tta_steps': 5,
        'spec_params': {
            'n_mels': 128,
            'n_fft': 2048,
            'hop_length': 512,
            'f_min': 20,
            'f_max': 8000
        }
    }

    # Experiments: each entry overrides parts of base_config.
    # Currently disabled:
    #   {'name': 'HTSAT_Advanced', 'model': HTSAT,
    #    'config': {'mode': 'waveform',  # HTSAT computes spectrogram internally
    #               'learning_rate': 2e-4}},
    experiments = [
        {
            'name': 'S3D_Pretrained',
            'model': S3D,
            'config': {'mode': 'waveform', 'learning_rate': 3e-5}
        },
        {
            'name': 'AST_Deep',
            'model': AudioSpectrogramTransformer,
            'config': {'mode': 'spectrogram', 'learning_rate': 5e-5}
        },
        {
            'name': 'WavegramCNN',
            'model': WavegramCNN,
            'config': {'mode': 'waveform', 'learning_rate': 1e-4}
        },
        {
            'name': 'Hybrid_WaveSpec',
            'model': HybridModel,
            'config': {'mode': 'both', 'learning_rate': 1e-4, 'scheduler': 'cyclic'}
        }
    ]

    for exp in experiments:
        experiment_name = exp['name']
        # Later entries win: experiment overrides, then the model class and output dir.
        config = {
            **base_config,
            **exp['config'],
            'model': exp['model'],
            'checkpoint_dir': os.path.join(base_config['base_checkpoint_dir'], experiment_name),
        }
        print(f"\nRunning experiment: {experiment_name}")
        train_model(config)


Dataset Analysis Results:
Total files: 3291
Valid files: 2523 (76.7%)
Class distribution: [1186, 1337] (Real: 1186, Fake: 1337)
Class ratio: 1.13:1
Sample rates found: {44100}
Duration statistics (seconds):
  min: 2.00
  max: 1569.00
  mean: 7.84
  median: 5.00
  std: 48.78

Running experiment: S3D_Pretrained
Using cpu
Class weights: tensor([1.0637, 0.9435])

=== Training Fold 1 ===


Using cache found in /home/azureuser/.cache/torch/hub/facebookresearch_pytorchvideo_main


ModuleNotFoundError: No module named 'fvcore'