In [None]:
!pip install wandb python-dotenv
import wandb
from dotenv import load_dotenv
import os
# Load variables from a local .env file (if any) into the process environment.
load_dotenv()
# Log in to Weights & Biases with the key stored under the WANDB_API variable;
# os.getenv yields None when that variable is not set.
wandb.login(key=os.getenv("WANDB_API"))

Collecting wandb
  Downloading wandb-0.19.11-py3-none-win_amd64.whl.metadata (10 kB)
Collecting python-dotenv
  Downloading python_dotenv-1.1.0-py3-none-any.whl.metadata (24 kB)
Collecting click!=8.0.0,>=7.1 (from wandb)
  Downloading click-8.2.1-py3-none-any.whl.metadata (2.5 kB)
Collecting docker-pycreds>=0.4.0 (from wandb)
  Downloading docker_pycreds-0.4.0-py2.py3-none-any.whl.metadata (1.8 kB)
Collecting gitpython!=3.1.29,>=1.0.0 (from wandb)
  Downloading GitPython-3.1.44-py3-none-any.whl.metadata (13 kB)
Collecting protobuf!=4.21.0,!=5.28.0,<7,>=3.19.0 (from wandb)
  Downloading protobuf-6.31.0-cp310-abi3-win_amd64.whl.metadata (593 bytes)
Collecting pyyaml (from wandb)
  Using cached PyYAML-6.0.2-cp312-cp312-win_amd64.whl.metadata (2.1 kB)
Collecting sentry-sdk>=2.0.0 (from wandb)
  Downloading sentry_sdk-2.29.1-py2.py3-none-any.whl.metadata (10 kB)
Collecting setproctitle (from wandb)
  Downloading setproctitle-1.3.6-cp312-cp312-win_amd64.whl.metadata (10 kB)
Collecting gitdb<

[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: C:\Users\NamHoang\_netrc
[34m[1mwandb[0m: Currently logged in as: [33mnamthse182380[0m ([33mnamthse182380-fpt-university[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [None]:
import numpy as np
import pandas as pd
import random
import glob
import librosa
import librosa.display
import matplotlib.pyplot as plt
from tqdm import tqdm
import time
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix, ConfusionMatrixDisplay, roc_auc_score, accuracy_score
from kaggle_secrets import UserSecretsClient

# --- Reproducibility ---
# One seed feeds every RNG the notebook uses: Python, NumPy and PyTorch.
SEED = 42
for _seed_rng in (random.seed, np.random.seed, torch.manual_seed):
    _seed_rng(SEED)

_HAS_CUDA = torch.cuda.is_available()
if _HAS_CUDA:
    # Seed every visible GPU and disable cuDNN autotuning in favour of
    # deterministic kernels.
    torch.cuda.manual_seed_all(SEED)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

# --- Device Configuration ---
DEVICE = torch.device("cuda") if _HAS_CUDA else torch.device("cpu")
print(f"Using device: {DEVICE}")

# --- Paths and Spectrogram Configuration ---
# Directories searched for '*.wav' files; files under REAL_AUDIO_PATH get
# label 0 and files under FAKE_AUDIO_PATH get label 1.
REAL_AUDIO_PATH = "/kaggle/input/the-lj-speech-dataset/LJSpeech-1.1/wavs"
FAKE_AUDIO_PATH = "/kaggle/input/wavefake-test/generated_audio/common_voices_prompts_from_conformer_fastspeech2_pwg_ljspeech"
# Target sample rate (Hz); audio loaded at another rate is resampled to this.
SR = 16000
# STFT window size and hop (in samples) used for the mel spectrogram.
N_FFT = 2048
HOP_LENGTH = 512
# Number of mel bands (spectrogram height).
N_MELS = 128
# Fixed spectrogram width in frames; shorter clips are padded, longer ones
# truncated. 313 frames * 512 hop / 16000 Hz is roughly 10 seconds.
MAX_FRAMES_SPEC = 313
# Mel filterbank frequency range; FMAX=None is replaced by sr/2 at call time.
FMIN = 0.0
FMAX = None
# --- Augmentation (time/frequency masking applied to training items) ---
APPLY_AUGMENTATION = True
NUM_TIME_MASKS = 1
NUM_FREQ_MASKS = 1
# Upper bounds on mask width, in frames and in mel bins respectively.
TIME_MASK_MAX_WIDTH = 40
FREQ_MASK_MAX_WIDTH = 15
# Added to the standard deviation during per-sample normalisation to avoid
# dividing by zero.
NORM_EPSILON = 1e-6
# Value written into masked regions.
MASK_REPLACEMENT_VALUE = 0.0
# Optional cap on the total number of files (split evenly per class); None = all.
LIMIT_FILES = None
# --- Split ratios ---
# NOTE(review): only VALIDATION_RATIO and TEST_RATIO are read by the visible
# split code; TRAIN_RATIO is implied as the remainder — confirm they sum to 1.
TRAIN_RATIO = 0.70
VALIDATION_RATIO = 0.15
TEST_RATIO = 0.15
# --- Training hyperparameters ---
BATCH_SIZE = 32
NUM_WORKERS = 2
LEARNING_RATE = 1e-4
EPOCHS = 20
WEIGHT_DECAY = 1e-4

# --- WandB Login with Kaggle Secrets ---
# First choice: the key stored as the Kaggle secret "wandb_api_key".
# Any failure along the way is reported and a plain wandb.login() is tried.
try:
    user_secrets = UserSecretsClient()
    wandb_api_key = user_secrets.get_secret("wandb_api_key")
    wandb.login(key=wandb_api_key)
except Exception as e:
    print(f"Failed to login to WandB via Kaggle Secrets: {e}. Falling back to environment variable or manual login.")
    wandb.login()
else:
    print("WandB login successful using Kaggle Secrets.")

# --- WandB Initialization for ViT ---
# Start the run and record the hyperparameters alongside it.
# NOTE(review): the vit_* / cnn_* values are literals here; keep them in sync
# with wherever the models are actually configured.
wandb.init(
    project="ASM01_DAT301m",
    name=f"ViT_CNN_Audio_Deepfake_{time.strftime('%Y%m%d_%H%M%S')}",
    config=dict(
        # optimisation
        learning_rate=LEARNING_RATE,
        epochs=EPOCHS,
        batch_size=BATCH_SIZE,
        weight_decay=WEIGHT_DECAY,
        # ViT architecture
        vit_patch_size=16,
        vit_embed_dim=192,
        vit_depth=6,
        vit_num_heads=6,
        vit_mlp_ratio=4.0,
        vit_drop_rate=0.1,
        vit_attn_drop_rate=0.1,
        # CNN architecture
        cnn_dropout_rate=0.4,
        # spectrogram geometry
        n_mels=N_MELS,
        max_frames_spec=MAX_FRAMES_SPEC,
        # augmentation
        apply_augmentation=APPLY_AUGMENTATION,
        num_time_masks=NUM_TIME_MASKS,
        num_freq_masks=NUM_FREQ_MASKS,
        time_mask_max_width=TIME_MASK_MAX_WIDTH,
        freq_mask_max_width=FREQ_MASK_MAX_WIDTH,
    ),
)

In [None]:
# --- 1. Data Loading and Preprocessing Functions ---
def get_audio_files_and_labels(real_dir, fake_dir, limit_files=None):
    """Collect '*.wav' paths from two directories and return them shuffled with labels.

    Args:
        real_dir: Directory whose wav files are labelled 0 (real).
        fake_dir: Directory whose wav files are labelled 1 (fake).
        limit_files: Optional total cap. When truthy, at most
            ``limit_files // 2`` files are randomly sampled from each class
            (so a value of 1 selects nothing).

    Returns:
        ``(filepaths, labels)`` as two parallel lists in a shuffled order, or
        ``([], [])`` when no file was selected.

    Sampling and shuffling use the global ``random`` module, so results are
    reproducible once that module has been seeded.
    """
    # glob yields entries in an arbitrary, filesystem-dependent order; sorting
    # first makes the seeded sample/shuffle below reproducible across machines.
    real_files = sorted(glob.glob(os.path.join(real_dir, '*.wav')))
    fake_files = sorted(glob.glob(os.path.join(fake_dir, '*.wav')))
    print(f"Found {len(real_files)} real audio files.")
    print(f"Found {len(fake_files)} fake audio files.")

    if not real_files and not fake_files:
        print("Warning: Could not find audio files in one or both directories.")
    elif not real_files:
        print("Warning: No real audio files found.")
    elif not fake_files:
        print("Warning: No fake audio files found.")

    if limit_files:
        per_class = limit_files // 2
        print(f"Limiting files. Attempting to sample up to {per_class} from each class.")
        real_files = random.sample(real_files, min(per_class, len(real_files)))
        fake_files = random.sample(fake_files, min(per_class, len(fake_files)))
        print(f"Selected {len(real_files)} real and {len(fake_files)} fake files after limiting.")

    filepaths = real_files + fake_files
    labels = [0] * len(real_files) + [1] * len(fake_files)
    if not filepaths:
        print("No files selected. Check limit_files and paths.")
        return [], []

    # Shuffle paths and labels together so each path keeps its label.
    paired = list(zip(filepaths, labels))
    random.shuffle(paired)
    return [path for path, _ in paired], [label for _, label in paired]

def audio_to_melspectrogram(filepath, sr=SR, n_fft=N_FFT, hop_length=HOP_LENGTH, n_mels=N_MELS, max_frames=MAX_FRAMES_SPEC, fmin=FMIN, fmax=FMAX):
    """Load an audio file and return a fixed-width log-mel spectrogram.

    Args:
        filepath: Path of the audio file to load.
        sr: Target sample rate; audio at another rate is resampled to it.
        n_fft, hop_length, n_mels, fmin, fmax: Mel-spectrogram parameters.
            ``fmax=None`` is replaced by ``sr / 2``.
        max_frames: Output width. Shorter spectrograms are right-padded with
            their own minimum value; longer ones are truncated.

    Returns:
        A 2-D array with ``max_frames`` columns (dB relative to the clip's
        peak power), or ``None`` if anything went wrong. Failures are reported
        on stdout rather than swallowed silently, so unreadable files can be
        noticed; callers are still expected to skip ``None`` results.
    """
    try:
        y, sr_orig = librosa.load(filepath, sr=None)
        if sr_orig != sr:
            y = librosa.resample(y, orig_sr=sr_orig, target_sr=sr)
        mel_spectrogram = librosa.feature.melspectrogram(
            y=y, sr=sr, n_fft=n_fft, hop_length=hop_length, n_mels=n_mels,
            fmin=fmin, fmax=fmax if fmax is not None else sr/2,
        )
        log_mel_spectrogram = librosa.power_to_db(mel_spectrogram, ref=np.max)

        current_frames = log_mel_spectrogram.shape[1]
        if current_frames < max_frames:
            # Pad with the quietest value present so padding reads as silence.
            pad_value = log_mel_spectrogram.min()
            pad_width = max_frames - current_frames
            return np.pad(log_mel_spectrogram, ((0, 0), (0, pad_width)), mode='constant', constant_values=pad_value)
        if current_frames > max_frames:
            return log_mel_spectrogram[:, :max_frames]
        return log_mel_spectrogram
    except Exception as e:
        # Best effort: keep going, but say which file failed and why.
        print(f"Warning: failed to convert '{filepath}' to a mel spectrogram ({type(e).__name__}: {e}); sample will be skipped.")
        return None

In [None]:
# --- 2. PyTorch Dataset ---
class AudioDataset(Dataset):
    """Dataset turning audio file paths into normalised spectrogram tensors.

    Each item is ``(spectrogram, label)`` with both tensors in float32, or
    ``None`` when ``transform_spectrogram_fn`` returns ``None`` for the file
    (the collate function is expected to drop such items).

    Args:
        filepaths: Sequence of audio file paths.
        labels: Sequence of numeric labels, parallel to ``filepaths``.
        transform_spectrogram_fn: Callable mapping a path to a 2-D array
            (first axis is masked as frequency, second as time) or ``None``.
        augment: Apply time/frequency masking when True.
        is_vit_input: When True the spectrogram is repeated on 3 channels,
            otherwise a single channel axis is added.
        time_mask_max_width, freq_mask_max_width: Maximum mask widths.
        num_time_masks, num_freq_masks: Number of masks of each kind.
        mask_replacement_value: Value written into masked regions. Masking is
            applied to the *normalised* spectrogram, so the default 0.0 equals
            the per-sample mean.
    """

    def __init__(self, filepaths, labels, transform_spectrogram_fn, augment=False, is_vit_input=False, time_mask_max_width=TIME_MASK_MAX_WIDTH, freq_mask_max_width=FREQ_MASK_MAX_WIDTH, num_time_masks=NUM_TIME_MASKS, num_freq_masks=NUM_FREQ_MASKS, mask_replacement_value=MASK_REPLACEMENT_VALUE):
        self.filepaths = filepaths
        self.labels = labels
        self.transform_spectrogram_fn = transform_spectrogram_fn
        self.augment = augment
        self.is_vit_input = is_vit_input
        self.time_mask_max_width = time_mask_max_width
        self.freq_mask_max_width = freq_mask_max_width
        self.num_time_masks = num_time_masks
        self.num_freq_masks = num_freq_masks
        self.mask_replacement_value = mask_replacement_value

    def __len__(self):
        """Return the number of files in the dataset."""
        return len(self.filepaths)

    def _apply_time_mask(self, spectrogram):
        """Return a copy with up to ``num_time_masks`` random column bands overwritten."""
        augmented_spec = np.copy(spectrogram)
        num_frames = augmented_spec.shape[1]
        for _ in range(self.num_time_masks):
            # Skip masking when the spectrogram is not wider than the largest mask.
            if self.time_mask_max_width > 0 and num_frames > self.time_mask_max_width:
                t = random.randint(1, self.time_mask_max_width)
                t0 = random.randint(0, num_frames - t)
                augmented_spec[:, t0:t0 + t] = self.mask_replacement_value
        return augmented_spec

    def _apply_freq_mask(self, spectrogram):
        """Return a copy with up to ``num_freq_masks`` random row bands overwritten."""
        augmented_spec = np.copy(spectrogram)
        num_mels = augmented_spec.shape[0]
        for _ in range(self.num_freq_masks):
            if self.freq_mask_max_width > 0 and num_mels > self.freq_mask_max_width:
                f = random.randint(1, self.freq_mask_max_width)
                f0 = random.randint(0, num_mels - f)
                augmented_spec[f0:f0 + f, :] = self.mask_replacement_value
        return augmented_spec

    def __getitem__(self, idx):
        """Load, normalise and (optionally) augment the item at ``idx``."""
        filepath = self.filepaths[idx]
        label = self.labels[idx]
        mel_spec = self.transform_spectrogram_fn(filepath)
        if mel_spec is None:
            return None

        # Per-sample standardisation to zero mean / unit variance.
        mean = np.mean(mel_spec)
        std = np.std(mel_spec)
        mel_spec_normalized = (mel_spec - mean) / (std + NORM_EPSILON)

        # Mask AFTER normalising. Masking the raw spectrogram with 0.0 would
        # paint the band at the spectrogram's peak level when the transform
        # returns peak-referenced dB (0 dB = maximum), and would also skew the
        # mean/std computed above.
        if self.augment:
            mel_spec_normalized = self._apply_time_mask(mel_spec_normalized)
            mel_spec_normalized = self._apply_freq_mask(mel_spec_normalized)

        if self.is_vit_input:
            mel_spec_final = np.stack([mel_spec_normalized]*3, axis=0)
        else:
            mel_spec_final = np.expand_dims(mel_spec_normalized, axis=0)
        mel_spec_tensor = torch.tensor(mel_spec_final, dtype=torch.float32)
        label_tensor = torch.tensor(label, dtype=torch.float32)
        return mel_spec_tensor, label_tensor

def _collate_skip_none(batch, channels):
    """Collate a batch after dropping ``None`` samples.

    When every sample was ``None`` an empty placeholder pair is returned:
    inputs shaped ``(0, channels, N_MELS, MAX_FRAMES_SPEC)`` and labels
    shaped ``(0,)``.
    """
    kept = [sample for sample in batch if sample is not None]
    if not kept:
        return torch.empty((0, channels, N_MELS, MAX_FRAMES_SPEC)), torch.empty((0,))
    return torch.utils.data.dataloader.default_collate(kept)


def collate_fn_skip_none_vit(batch):
    """Collate function for the 3-channel (ViT) loaders; skips failed samples."""
    return _collate_skip_none(batch, channels=3)


def collate_fn_skip_none_cnn(batch):
    """Collate function for the 1-channel (CNN) loaders; skips failed samples."""
    return _collate_skip_none(batch, channels=1)

In [None]:
# --- 3. Data Splitting ---
# Gather every labelled file, then carve out train / validation / test in two
# passes: train vs. hold-out first, then hold-out into validation vs. test.
filepaths_all, labels_all = get_audio_files_and_labels(REAL_AUDIO_PATH, FAKE_AUDIO_PATH, limit_files=LIMIT_FILES)
if not filepaths_all:
    raise ValueError("Halting: No audio files found or loaded.")
print(f"\nTotal samples before splitting: {len(filepaths_all)} (Labels: Real={labels_all.count(0)}, Fake={labels_all.count(1)})")

# Stratification is only requested when more than one class is present.
X_train_paths, X_temp_paths, y_train, y_temp = train_test_split(
    filepaths_all,
    labels_all,
    test_size=(VALIDATION_RATIO + TEST_RATIO),
    random_state=SEED,
    stratify=labels_all if len(set(labels_all)) > 1 else None,
)

if not X_temp_paths:
    X_val_paths, X_test_paths, y_val, y_test = [], [], [], []
else:
    # Share of the hold-out set that goes to test.
    relative_test_ratio = TEST_RATIO / (VALIDATION_RATIO + TEST_RATIO)
    X_val_paths, X_test_paths, y_val, y_test = train_test_split(
        X_temp_paths,
        y_temp,
        test_size=relative_test_ratio,
        random_state=SEED,
        stratify=y_temp if len(set(y_temp)) > 1 else None,
    )

# Report each split and accumulate the same numbers for WandB.
_split_stats = {
    "total_samples": len(filepaths_all),
    "real_samples": labels_all.count(0),
    "fake_samples": labels_all.count(1),
}
for _title, _key, _paths, _labels in (
    ("Training", "train", X_train_paths, y_train),
    ("Validation", "val", X_val_paths, y_val),
    ("Test", "test", X_test_paths, y_test),
):
    print(f"{_title} samples: {len(_paths)} (Real={_labels.count(0)}, Fake={_labels.count(1)})")
    _split_stats[f"{_key}_samples"] = len(_paths)
    _split_stats[f"{_key}_real"] = _labels.count(0)
    _split_stats[f"{_key}_fake"] = _labels.count(1)

# Log dataset statistics to WandB
wandb.log(_split_stats)

In [None]:
# --- 4. PyTorch Datasets and DataLoaders for ViT and CNN ---
# Pinned host memory only helps host-to-GPU copies; requesting it on a
# CPU-only machine makes PyTorch emit a warning and brings no benefit.
_PIN_MEMORY = DEVICE.type == "cuda"


def _make_dataset(paths, labels, augment, is_vit_input):
    """Build an AudioDataset over ``paths``/``labels`` using the mel-spectrogram transform."""
    return AudioDataset(
        paths, labels, transform_spectrogram_fn=audio_to_melspectrogram,
        augment=augment, is_vit_input=is_vit_input
    )


def _make_loader(dataset, shuffle, collate_fn):
    """Wrap ``dataset`` in a DataLoader with the notebook-wide batch/worker settings."""
    return DataLoader(
        dataset, batch_size=BATCH_SIZE, shuffle=shuffle,
        num_workers=NUM_WORKERS, pin_memory=_PIN_MEMORY, collate_fn=collate_fn
    )


# ViT Datasets (3-channel input); augmentation is applied to training data only.
train_dataset_vit = _make_dataset(X_train_paths, y_train, augment=APPLY_AUGMENTATION, is_vit_input=True)
val_dataset_vit = _make_dataset(X_val_paths, y_val, augment=False, is_vit_input=True)
train_loader_vit = _make_loader(train_dataset_vit, shuffle=True, collate_fn=collate_fn_skip_none_vit)
val_loader_vit = _make_loader(val_dataset_vit, shuffle=False, collate_fn=collate_fn_skip_none_vit)
if X_test_paths:
    test_dataset_vit = _make_dataset(X_test_paths, y_test, augment=False, is_vit_input=True)
    test_loader_vit = _make_loader(test_dataset_vit, shuffle=False, collate_fn=collate_fn_skip_none_vit)
else:
    test_loader_vit = None
    print("Test set is empty, test_loader_vit not created.")

# CNN Datasets (1-channel input)
train_dataset_cnn = _make_dataset(X_train_paths, y_train, augment=APPLY_AUGMENTATION, is_vit_input=False)
val_dataset_cnn = _make_dataset(X_val_paths, y_val, augment=False, is_vit_input=False)
train_loader_cnn = _make_loader(train_dataset_cnn, shuffle=True, collate_fn=collate_fn_skip_none_cnn)
val_loader_cnn = _make_loader(val_dataset_cnn, shuffle=False, collate_fn=collate_fn_skip_none_cnn)
if X_test_paths:
    test_dataset_cnn = _make_dataset(X_test_paths, y_test, augment=False, is_vit_input=False)
    test_loader_cnn = _make_loader(test_dataset_cnn, shuffle=False, collate_fn=collate_fn_skip_none_cnn)
else:
    test_loader_cnn = None
    print("Test set is empty, test_loader_cnn not created.")

# Test DataLoaders
def _preview_train_batch(loader, model_tag, title_detail, png_stem):
    """Pull one batch from ``loader``, print its shapes and log one spectrogram.

    Args:
        loader: Training DataLoader to sample from (may be falsy/empty).
        model_tag: Short model name used in the printed messages ("ViT", "CNN").
        title_detail: Text placed in the figure title before the label.
        png_stem: File stem of the saved image; also the WandB log key.

    Returns:
        ``(batch_x, batch_y)`` as fetched, or ``(None, None)`` when no batch
        could be obtained. Errors are reported on stdout, never raised.
    """
    batch_x = batch_y = None
    if not (loader and len(loader) > 0):
        print(f"Train loader for {model_tag} (PyTorch) is not available or empty.")
        return batch_x, batch_y

    print(f"\n--- Testing {model_tag} Training DataLoader (PyTorch) ---")
    try:
        batch_x, batch_y = next(iter(loader))
        if batch_x.nelement() > 0:
            print(f"PyTorch {model_tag} Train Batch X shape: {batch_x.shape}")
            print(f"PyTorch {model_tag} Train Batch Y shape: {batch_y.shape}")
            plt.figure(figsize=(10, 4))
            try:
                # First sample, first channel.
                img_to_show = batch_x[0, 0, :, :].cpu().numpy()
                librosa.display.specshow(img_to_show, sr=SR, hop_length=HOP_LENGTH, x_axis='time', y_axis='mel')
                # Batches are standardised per sample, so the colour scale is
                # in z-score units rather than dB.
                plt.colorbar(format='%+.1f', label='normalized log-mel (z-score)')
                plt.title(f'Sample Spectrogram ({title_detail}, Label: {batch_y[0].item():.0f})')
                plt.tight_layout()
                plt.savefig(f"{png_stem}.png")
                wandb.log({png_stem: wandb.Image(f"{png_stem}.png")})
            finally:
                # Always release the figure, even if plotting/logging failed.
                plt.close()
        else:
            print(f"Sample batch X from {model_tag} DataLoader is empty after filtering.")
    except StopIteration:
        print(f"Train loader for {model_tag} (PyTorch) is empty or all samples in first batch failed.")
    except Exception as e:
        print(f"Error during {model_tag} DataLoader test: {e}")
    return batch_x, batch_y


sample_batch_x_vit_pt, sample_batch_y_vit_pt = _preview_train_batch(
    train_loader_vit, "ViT", "ViT Train Batch, Ch 0", "sample_spectrogram_vit"
)
sample_batch_x_cnn_pt, sample_batch_y_cnn_pt = _preview_train_batch(
    train_loader_cnn, "CNN", "CNN Train Batch", "sample_spectrogram_cnn"
)

In [None]:
# --- 5. PyTorch Vision Transformer (ViT) Model ---
class PatchEmbed(nn.Module):
    """Cut a 4-D image batch into patches and project each patch to ``embed_dim``.

    The projection is a convolution whose kernel size and stride both equal
    ``patch_size``. ``grid_size`` and ``num_patches`` are derived from
    ``img_size`` by integer division.
    """

    def __init__(self, img_size=(N_MELS, MAX_FRAMES_SPEC), patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        rows = img_size[0] // patch_size
        cols = img_size[1] // patch_size
        self.img_size = img_size
        self.patch_size = patch_size
        self.grid_size = (rows, cols)
        self.num_patches = rows * cols
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        """Return patch embeddings; the input's last two dims must equal ``img_size``."""
        _, _, H, W = x.shape
        size_matches = (H == self.img_size[0]) and (W == self.img_size[1])
        assert size_matches, \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        # Conv output -> flatten the two spatial dims -> move them ahead of channels.
        return self.proj(x).flatten(2).transpose(1, 2)

class Attention(nn.Module):
    """Multi-head scaled dot-product self-attention over a 3-D ``(B, N, C)`` input."""

    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        # Scores are scaled by 1/sqrt(per-head width).
        self.scale = (dim // num_heads) ** -0.5
        # A single linear layer yields queries, keys and values together.
        self.qkv = nn.Linear(dim, 3 * dim, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        """Apply self-attention and the output projection; shape is preserved."""
        batch, tokens, channels = x.shape
        per_head = channels // self.num_heads
        packed = self.qkv(x).reshape(batch, tokens, 3, self.num_heads, per_head)
        # -> three tensors, each (batch, heads, tokens, per_head)
        q, k, v = packed.permute(2, 0, 3, 1, 4).unbind(0)
        scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        weights = self.attn_drop(scores.softmax(dim=-1))
        merged = torch.matmul(weights, v).transpose(1, 2).reshape(batch, tokens, channels)
        return self.proj_drop(self.proj(merged))

class Mlp(nn.Module):
    """Two-layer feed-forward network: linear, activation, dropout, linear, dropout.

    ``hidden_features`` and ``out_features`` fall back to ``in_features``
    when left unset (or falsy).
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        if not out_features:
            out_features = in_features
        if not hidden_features:
            hidden_features = in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        # One dropout module is shared by both dropout positions.
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        """Run the feed-forward stack on ``x``."""
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))

class Block(nn.Module):
    """Pre-norm transformer layer: attention and MLP sub-layers, each with a residual add."""

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        self.norm2 = norm_layer(dim)
        # The MLP hidden width is the embedding width scaled by mlp_ratio.
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=drop,
        )

    def forward(self, x):
        """Apply both residual sub-layers in order."""
        attended = self.attn(self.norm1(x))
        x = x + attended
        transformed = self.mlp(self.norm2(x))
        return x + transformed

class VisionTransformer(nn.Module):
    """Vision Transformer classifier over fixed-size multi-channel inputs.

    Pipeline: patch embedding, a prepended class token, learned position
    embeddings, ``depth`` transformer blocks, a final LayerNorm, and a linear
    head applied to the class-token output.
    """

    def __init__(self, img_size=(N_MELS, MAX_FRAMES_SPEC), patch_size=16, in_chans=3, num_classes=1,
                 embed_dim=192, depth=6, num_heads=6, mlp_ratio=4., qkv_bias=True,
                 drop_rate=0.1, attn_drop_rate=0.1):
        super().__init__()
        self.num_classes = num_classes
        self.embed_dim = embed_dim
        self.num_features = embed_dim

        self.patch_embed = PatchEmbed(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
        token_count = self.patch_embed.num_patches + 1  # patches plus the class token

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, token_count, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_rate)

        self.blocks = nn.ModuleList([
            Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias,
                  drop=drop_rate, attn_drop=attn_drop_rate)
            for _ in range(depth)
        ])
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

        # Parameter initialisation: embeddings first, then every sub-module.
        nn.init.trunc_normal_(self.pos_embed, std=.02)
        nn.init.trunc_normal_(self.cls_token, std=.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialise Linear layers (truncated normal, zero bias) and LayerNorms (weight 1, bias 0)."""
        if isinstance(m, nn.LayerNorm):
            nn.init.zeros_(m.bias)
            nn.init.ones_(m.weight)
        elif isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.zeros_(m.bias)

    def forward_features(self, x):
        """Return the normalised class-token embedding for each input in the batch."""
        tokens = self.patch_embed(x)
        cls_tokens = self.cls_token.expand(tokens.shape[0], -1, -1)
        tokens = torch.cat((cls_tokens, tokens), dim=1) + self.pos_embed
        tokens = self.pos_drop(tokens)
        for blk in self.blocks:
            tokens = blk(tokens)
        return self.norm(tokens)[:, 0]

    def forward(self, x):
        """Return the head's raw outputs (no activation is applied here)."""
        return self.head(self.forward_features(x))

In [None]:
# --- 6. PyTorch CNN Model ---
class AudioCNN(nn.Module):
    """Four-stage convolutional classifier for single-channel spectrograms.

    Each stage is Conv(3x3, padding 1), BatchNorm, ReLU, MaxPool(2) and
    Dropout2d, followed by three fully connected layers. The forward pass
    returns raw logits.

    Args:
        num_classes: Size of the output layer.
        dropout_rate: Dropout probability; the first two conv stages use half
            of it.
        input_shape: ``(height, width)`` of the input spectrograms, used to
            size the first fully connected layer. Defaults to
            ``(N_MELS, MAX_FRAMES_SPEC)`` so existing callers are unaffected.

    Raises:
        ValueError: If ``input_shape`` is too small to survive four 2x poolings.
    """

    def __init__(self, num_classes=1, dropout_rate=0.4, input_shape=None):
        super(AudioCNN, self).__init__()
        if input_shape is None:
            input_shape = (N_MELS, MAX_FRAMES_SPEC)
        in_height, in_width = input_shape

        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.pool1 = nn.MaxPool2d(kernel_size=2)
        self.drop1 = nn.Dropout2d(dropout_rate/2)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.pool2 = nn.MaxPool2d(kernel_size=2)
        self.drop2 = nn.Dropout2d(dropout_rate/2)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.pool3 = nn.MaxPool2d(kernel_size=2)
        self.drop3 = nn.Dropout2d(dropout_rate)
        self.conv4 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
        self.bn4 = nn.BatchNorm2d(256)
        self.pool4 = nn.MaxPool2d(kernel_size=2)
        self.drop4 = nn.Dropout2d(dropout_rate)

        # The convolutions keep the spatial size; each of the four poolings
        # halves it with floor rounding, i.e. an overall floor division by 16.
        height_after_convs = in_height // (2**4)
        width_after_convs = in_width // (2**4)
        if height_after_convs < 1 or width_after_convs < 1:
            raise ValueError(
                f"input_shape {tuple(input_shape)} is too small: each dimension must be at least 16."
            )

        self.fc1 = nn.Linear(256 * height_after_convs * width_after_convs, 512)
        self.bn_fc1 = nn.BatchNorm1d(512)
        self.drop_fc1 = nn.Dropout(dropout_rate)
        self.fc2 = nn.Linear(512, 128)
        self.bn_fc2 = nn.BatchNorm1d(128)
        self.drop_fc2 = nn.Dropout(dropout_rate)
        self.fc3 = nn.Linear(128, num_classes)

    def forward(self, x):
        """Map a ``(batch, 1, height, width)`` input to ``(batch, num_classes)`` logits."""
        x = F.relu(self.bn1(self.conv1(x)))
        x = self.pool1(x)
        x = self.drop1(x)
        x = F.relu(self.bn2(self.conv2(x)))
        x = self.pool2(x)
        x = self.drop2(x)
        x = F.relu(self.bn3(self.conv3(x)))
        x = self.pool3(x)
        x = self.drop3(x)
        x = F.relu(self.bn4(self.conv4(x)))
        x = self.pool4(x)
        x = self.drop4(x)
        # Flatten everything but the batch dimension for the dense layers.
        x = x.view(x.size(0), -1)
        x = F.relu(self.bn_fc1(self.fc1(x)))
        x = self.drop_fc1(x)
        x = F.relu(self.bn_fc2(self.fc2(x)))
        x = self.drop_fc2(x)
        x = self.fc3(x)
        return x

In [None]:
# --- 7. Initialize Models, Optimizers, Loss ---
# ViT hyperparameters
VIT_PATCH_SIZE = 16
VIT_EMBED_DIM = 192
VIT_DEPTH = 6
VIT_NUM_HEADS = 6
VIT_MLP_RATIO = 4.0
VIT_DROP_RATE = 0.1
VIT_ATTN_DROP_RATE = 0.1
# CNN hyperparameters
CNN_DROPOUT_RATE = 0.4


def _count_trainable_params(model):
    """Number of parameter elements in ``model`` that require gradients."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


# Warn when the spectrogram does not tile exactly into patches.
if any(extent % VIT_PATCH_SIZE != 0 for extent in (N_MELS, MAX_FRAMES_SPEC)):
    eff_H = (N_MELS // VIT_PATCH_SIZE) * VIT_PATCH_SIZE
    eff_W = (MAX_FRAMES_SPEC // VIT_PATCH_SIZE) * VIT_PATCH_SIZE
    print(f"Warning: N_MELS ({N_MELS}) or MAX_FRAMES_SPEC ({MAX_FRAMES_SPEC}) is not perfectly divisible by VIT_PATCH_SIZE ({VIT_PATCH_SIZE}).")
    print("The PatchEmbed layer will crop the input to the largest divisible dimensions.")
    print(f"Effective input to PatchEmbed will be ({eff_H}, {eff_W})")

# Models (the ViT is built first, then the CNN).
_vit_kwargs = dict(
    img_size=(N_MELS, MAX_FRAMES_SPEC),
    patch_size=VIT_PATCH_SIZE,
    in_chans=3,
    num_classes=1,
    embed_dim=VIT_EMBED_DIM,
    depth=VIT_DEPTH,
    num_heads=VIT_NUM_HEADS,
    mlp_ratio=VIT_MLP_RATIO,
    qkv_bias=True,
    drop_rate=VIT_DROP_RATE,
    attn_drop_rate=VIT_ATTN_DROP_RATE,
)
pytorch_vit_model = VisionTransformer(**_vit_kwargs).to(DEVICE)
pytorch_cnn_model = AudioCNN(num_classes=1, dropout_rate=CNN_DROPOUT_RATE).to(DEVICE)

# One AdamW optimizer and one logits-based binary cross-entropy loss per model.
optimizer_vit = optim.AdamW(pytorch_vit_model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
optimizer_cnn = optim.AdamW(pytorch_cnn_model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
criterion_vit = nn.BCEWithLogitsLoss()
criterion_cnn = nn.BCEWithLogitsLoss()

# Parameter counts, printed and logged.
total_params_vit = _count_trainable_params(pytorch_vit_model)
total_params_cnn = _count_trainable_params(pytorch_cnn_model)

print("\n--- PyTorch ViT Model Architecture (Simplified) ---")
print(f"Total trainable parameters (ViT): {total_params_vit:,}")
print(f"ViT Config: Patch={VIT_PATCH_SIZE}, EmbedDim={VIT_EMBED_DIM}, Depth={VIT_DEPTH}, Heads={VIT_NUM_HEADS}")

print("\n--- PyTorch CNN Model Architecture (Simplified) ---")
print(f"Total trainable parameters (CNN): {total_params_cnn:,}")

wandb.log({"vit_total_trainable_parameters": total_params_vit, "cnn_total_trainable_parameters": total_params_cnn})

In [None]:
# --- 8. Training Loop and Evaluation Function ---
def train_one_epoch(model, train_loader, criterion, optimizer, device, epoch_num, num_epochs, model_name="Model"):
    """Run one optimisation pass over ``train_loader``.

    Empty batches (no elements in the input tensor) are skipped. Predictions
    are ``sigmoid(output) > 0.5``.

    Returns:
        ``(epoch_loss, epoch_acc)``: sample-weighted mean loss and accuracy,
        or ``(0, 0)`` when no sample was processed.
    """
    model.train()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0
    bar = tqdm(train_loader, desc=f"Epoch {epoch_num+1}/{num_epochs} [{model_name} Training]", unit="batch")
    for inputs, labels in bar:
        if inputs.nelement() == 0:
            continue
        inputs = inputs.to(device)
        # Labels gain a trailing dim to line up with the model's output.
        labels = labels.to(device).unsqueeze(1)

        optimizer.zero_grad()
        logits = model(inputs)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        loss_sum += batch_loss * inputs.size(0)
        hits = (torch.sigmoid(logits) > 0.5) == labels
        n_correct += hits.sum().item()
        n_seen += labels.size(0)
        bar.set_postfix(loss=batch_loss, acc=n_correct/n_seen if n_seen > 0 else 0)

    if n_seen > 0:
        return loss_sum / n_seen, n_correct / n_seen
    return 0, 0

def evaluate_model_pytorch(model, val_loader, criterion, device, epoch_num=None, num_epochs=None, model_name="Model"):
    """Evaluate ``model`` on ``val_loader`` without gradient tracking.

    Args:
        model: Module producing one logit per sample.
        val_loader: Iterable of ``(inputs, labels)`` batches; batches whose
            input tensor has no elements are skipped.
        criterion: Loss callable taking ``(outputs, labels)``, or ``None`` to
            skip loss computation (metrics-only evaluation).
        device: Device the batches are moved to.
        epoch_num, num_epochs: When both are given, the progress bar is
            labelled as a per-epoch validation pass.
        model_name: Name shown in the progress bar.

    Returns:
        ``(epoch_loss, epoch_acc, labels, probs)``. ``epoch_loss`` is the
        sample-weighted mean loss, ``inf`` when no sample was processed, or
        ``nan`` when ``criterion`` is ``None``. ``labels`` and ``probs`` are
        flat NumPy arrays of targets and sigmoid probabilities.
    """
    model.eval()
    running_loss = 0.0
    correct_predictions = 0
    total_samples = 0
    all_labels = []
    all_preds_probs = []
    desc_str = f"{model_name} Evaluating"
    if epoch_num is not None and num_epochs is not None:
        desc_str = f"Epoch {epoch_num+1}/{num_epochs} [{model_name} Validation]"
    progress_bar = tqdm(val_loader, desc=desc_str, unit="batch")
    with torch.no_grad():
        for inputs, labels in progress_bar:
            if inputs.nelement() == 0:
                continue
            inputs, labels = inputs.to(device), labels.to(device).unsqueeze(1)
            outputs = model(inputs)

            probs = torch.sigmoid(outputs)
            preds = probs > 0.5
            correct_predictions += (preds == labels).sum().item()
            total_samples += labels.size(0)
            all_labels.extend(labels.cpu().numpy().flatten())
            all_preds_probs.extend(probs.cpu().numpy().flatten())

            running_acc = correct_predictions/total_samples if total_samples > 0 else 0
            if criterion is not None:
                batch_loss = criterion(outputs, labels).item()
                running_loss += batch_loss * inputs.size(0)
                progress_bar.set_postfix(loss=batch_loss, acc=running_acc)
            else:
                progress_bar.set_postfix(acc=running_acc)

    if total_samples == 0:
        epoch_loss = float('inf')
    elif criterion is None:
        # No loss was computed; signal "not available" instead of a fake 0.
        epoch_loss = float('nan')
    else:
        epoch_loss = running_loss / total_samples
    epoch_acc = correct_predictions / total_samples if total_samples > 0 else 0
    return epoch_loss, epoch_acc, np.array(all_labels), np.array(all_preds_probs)

def plot_history(train_losses, val_losses, train_accs, val_accs, model_name="Model"):
    """Plot loss and accuracy curves side by side, save them as PNG and log to W&B.

    The image is written to ``training_history_<model_name.lower()>.png`` in the
    current working directory and logged through ``wandb.log`` under the key
    ``training_history_<model_name.lower()>``.

    Args:
        train_losses: Per-epoch training losses; its length sets the x-axis range.
        val_losses: Per-epoch validation losses.
        train_accs: Per-epoch training accuracies.
        val_accs: Per-epoch validation accuracies.
        model_name: Name used in the plot titles, the file name and the W&B key.
    """
    epochs_range = range(1, len(train_losses) + 1)
    history_key = f"training_history_{model_name.lower()}"
    image_path = f"{history_key}.png"

    # Explicit figure/axes handles instead of the implicit pyplot "current
    # figure", so exactly this figure is closed below.
    fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(14, 5))
    try:
        loss_ax.plot(epochs_range, train_losses, label='Training Loss')
        loss_ax.plot(epochs_range, val_losses, label='Validation Loss')
        loss_ax.legend(loc='upper right')
        loss_ax.set_title(f'{model_name} - Training and Validation Loss')
        loss_ax.set_xlabel('Epoch')
        loss_ax.set_ylabel('Loss')

        acc_ax.plot(epochs_range, train_accs, label='Training Accuracy')
        acc_ax.plot(epochs_range, val_accs, label='Validation Accuracy')
        acc_ax.legend(loc='lower right')
        acc_ax.set_title(f'{model_name} - Training and Validation Accuracy')
        acc_ax.set_xlabel('Epoch')
        acc_ax.set_ylabel('Accuracy')

        fig.tight_layout()
        fig.savefig(image_path)
        wandb.log({history_key: wandb.Image(image_path)})
    finally:
        # Release the figure even when plotting, saving or logging raises.
        plt.close(fig)

In [None]:
# --- 9. Main Training Execution for ViT ---
# Per-epoch metric histories; plotted by plot_history() once training ends.
train_losses_vit_history = []
val_losses_vit_history = []
train_accs_vit_history = []
val_accs_vit_history = []
best_val_loss_vit = float('inf')
best_epoch_vit = -1  # 1-based epoch of the best checkpoint; stays -1 if none is ever saved
patience_counter_vit = 0  # consecutive epochs without a validation-loss improvement
patience_limit_vit = 7  # early stopping fires once the counter reaches this value

print(f"\n--- Starting ViT Model Training on {DEVICE} for {EPOCHS} epochs ---")
start_time_total_vit = time.time()

for epoch in range(EPOCHS):
    epoch_start_time = time.time()
    train_loss, train_acc = train_one_epoch(pytorch_vit_model, train_loader_vit, criterion_vit, optimizer_vit, DEVICE, epoch, EPOCHS, model_name="ViT")
    val_loss, val_acc, _, _ = evaluate_model_pytorch(pytorch_vit_model, val_loader_vit, criterion_vit, DEVICE, epoch, EPOCHS, model_name="ViT")
    epoch_duration = time.time() - epoch_start_time
    print(f"Epoch {epoch+1}/{EPOCHS} - ViT - "
          f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f} - "
          f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f} - "
          f"Duration: {epoch_duration:.2f}s")
    wandb.log({
        "epoch": epoch + 1,
        "vit_train_loss": train_loss,
        "vit_train_accuracy": train_acc,
        "vit_val_loss": val_loss,
        "vit_val_accuracy": val_acc,
        "vit_epoch_duration_seconds": epoch_duration
    })
    train_losses_vit_history.append(train_loss)
    val_losses_vit_history.append(val_loss)
    train_accs_vit_history.append(train_acc)
    val_accs_vit_history.append(val_acc)
    if val_loss < best_val_loss_vit:
        # New best validation loss: checkpoint the weights and reset patience.
        best_val_loss_vit = val_loss
        best_epoch_vit = epoch + 1
        torch.save(pytorch_vit_model.state_dict(), 'best_vit_model_pytorch.pth')
        print(f"Epoch {epoch+1}: ViT Val loss improved to {val_loss:.4f}. Model saved.")
        wandb.save('best_vit_model_pytorch.pth')
        patience_counter_vit = 0
    else:
        patience_counter_vit += 1
        print(f"Epoch {epoch+1}: ViT Val loss ({val_loss:.4f}) did not improve from {best_val_loss_vit:.4f}. Patience: {patience_counter_vit}/{patience_limit_vit}")
    if patience_counter_vit >= patience_limit_vit:
        print(f"ViT Early stopping triggered at epoch {epoch+1}.")
        break

total_training_time_vit = time.time() - start_time_total_vit
# Split whole elapsed seconds with divmod so the seconds field can never be
# rounded up to "60s" (119.7 s used to print as "1m 60s").
elapsed_min_vit, elapsed_sec_vit = divmod(int(total_training_time_vit), 60)
print("--- ViT Training Finished ---")
print(f"Total ViT Training Time: {elapsed_min_vit}m {elapsed_sec_vit}s")
print(f"Best ViT validation loss: {best_val_loss_vit:.4f} at epoch {best_epoch_vit}")
wandb.log({
    "vit_total_training_time_minutes": total_training_time_vit / 60,
    "vit_best_val_loss": best_val_loss_vit,
    "vit_best_epoch": best_epoch_vit
})
plot_history(train_losses_vit_history, val_losses_vit_history, train_accs_vit_history, val_accs_vit_history, "ViT")

In [None]:
# --- 10. Main Training Execution for CNN ---
# Per-epoch metric histories; plotted by plot_history() once training ends.
train_losses_cnn_history = []
val_losses_cnn_history = []
train_accs_cnn_history = []
val_accs_cnn_history = []
best_val_loss_cnn = float('inf')
best_epoch_cnn = -1  # 1-based epoch of the best checkpoint; stays -1 if none is ever saved
patience_counter_cnn = 0  # consecutive epochs without a validation-loss improvement
patience_limit_cnn = 7  # early stopping fires once the counter reaches this value

print(f"\n--- Starting CNN Model Training on {DEVICE} for {EPOCHS} epochs ---")
start_time_total_cnn = time.time()

for epoch in range(EPOCHS):
    epoch_start_time = time.time()
    train_loss, train_acc = train_one_epoch(pytorch_cnn_model, train_loader_cnn, criterion_cnn, optimizer_cnn, DEVICE, epoch, EPOCHS, model_name="CNN")
    val_loss, val_acc, _, _ = evaluate_model_pytorch(pytorch_cnn_model, val_loader_cnn, criterion_cnn, DEVICE, epoch, EPOCHS, model_name="CNN")
    epoch_duration = time.time() - epoch_start_time
    print(f"Epoch {epoch+1}/{EPOCHS} - CNN - "
          f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f} - "
          f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f} - "
          f"Duration: {epoch_duration:.2f}s")
    wandb.log({
        "epoch": epoch + 1,
        "cnn_train_loss": train_loss,
        "cnn_train_accuracy": train_acc,
        "cnn_val_loss": val_loss,
        "cnn_val_accuracy": val_acc,
        "cnn_epoch_duration_seconds": epoch_duration
    })
    train_losses_cnn_history.append(train_loss)
    val_losses_cnn_history.append(val_loss)
    train_accs_cnn_history.append(train_acc)
    val_accs_cnn_history.append(val_acc)
    if val_loss < best_val_loss_cnn:
        # New best validation loss: checkpoint the weights and reset patience.
        best_val_loss_cnn = val_loss
        best_epoch_cnn = epoch + 1
        torch.save(pytorch_cnn_model.state_dict(), 'best_cnn_model_pytorch.pth')
        print(f"Epoch {epoch+1}: CNN Val loss improved to {val_loss:.4f}. Model saved.")
        wandb.save('best_cnn_model_pytorch.pth')
        patience_counter_cnn = 0
    else:
        patience_counter_cnn += 1
        print(f"Epoch {epoch+1}: CNN Val loss ({val_loss:.4f}) did not improve from {best_val_loss_cnn:.4f}. Patience: {patience_counter_cnn}/{patience_limit_cnn}")
    if patience_counter_cnn >= patience_limit_cnn:
        print(f"CNN Early stopping triggered at epoch {epoch+1}.")
        break

total_training_time_cnn = time.time() - start_time_total_cnn
# Split whole elapsed seconds with divmod so the seconds field can never be
# rounded up to "60s" (119.7 s used to print as "1m 60s").
elapsed_min_cnn, elapsed_sec_cnn = divmod(int(total_training_time_cnn), 60)
print("--- CNN Training Finished ---")
print(f"Total CNN Training Time: {elapsed_min_cnn}m {elapsed_sec_cnn}s")
print(f"Best CNN validation loss: {best_val_loss_cnn:.4f} at epoch {best_epoch_cnn}")
wandb.log({
    "cnn_total_training_time_minutes": total_training_time_cnn / 60,
    "cnn_best_val_loss": best_val_loss_cnn,
    "cnn_best_epoch": best_epoch_cnn
})
plot_history(train_losses_cnn_history, val_losses_cnn_history, train_accs_cnn_history, val_accs_cnn_history, "CNN")

In [None]:
# --- 11. Evaluation on Test Set for ViT ---
# Rebuilds the ViT, loads the checkpoint written during training, scores it on
# the test loader and logs the metrics to a dedicated W&B run.
if test_loader_vit:
    print("\n--- Evaluating ViT on Test Set with the Best Model ---")
    best_model_vit = VisionTransformer(
        img_size=(N_MELS, MAX_FRAMES_SPEC), patch_size=VIT_PATCH_SIZE, in_chans=3, num_classes=1,
        embed_dim=VIT_EMBED_DIM, depth=VIT_DEPTH, num_heads=VIT_NUM_HEADS, mlp_ratio=VIT_MLP_RATIO,
        qkv_bias=True, drop_rate=VIT_DROP_RATE, attn_drop_rate=VIT_ATTN_DROP_RATE
    ).to(DEVICE)
    # Stays None until this cell has started its own W&B run, so the cleanup
    # in `finally` never finishes a run this cell did not open.
    vit_test_run = None
    try:
        best_model_vit.load_state_dict(torch.load('best_vit_model_pytorch.pth', map_location=DEVICE))
        print("Best ViT model weights loaded successfully.")
        vit_test_run = wandb.init(project="audio-deepfake-detection", name=f"ViT_Test_Evaluation_{time.strftime('%Y%m%d_%H%M%S')}")
        test_loss, test_acc, test_labels_true, test_preds_probs = evaluate_model_pytorch(best_model_vit, test_loader_vit, criterion_vit, DEVICE, model_name="ViT")
        print(f"ViT Test Set - Loss: {test_loss:.4f}, Accuracy: {test_acc:.4f}")
        wandb.log({
            "vit_test_loss": test_loss,
            "vit_test_accuracy": test_acc
        })
        if len(test_labels_true) > 0 and len(test_preds_probs) > 0:
            test_preds_binary = (test_preds_probs > 0.5).astype(int)
            # Pin the label set: without `labels`, scikit-learn infers the classes
            # from the data and classification_report raises when only one class
            # is present but two target names are supplied.
            class_ids = [0, 1]
            class_names = ['Real (0)', 'Fake (1)']
            print("\nClassification Report (ViT Test Set):")
            report = classification_report(test_labels_true, test_preds_binary, labels=class_ids, target_names=class_names, output_dict=True)
            print(classification_report(test_labels_true, test_preds_binary, labels=class_ids, target_names=class_names))
            wandb.log({
                "vit_test_precision_real": report['Real (0)']['precision'],
                "vit_test_recall_real": report['Real (0)']['recall'],
                "vit_test_f1_real": report['Real (0)']['f1-score'],
                "vit_test_precision_fake": report['Fake (1)']['precision'],
                "vit_test_recall_fake": report['Fake (1)']['recall'],
                "vit_test_f1_fake": report['Fake (1)']['f1-score'],
                "vit_test_macro_avg_precision": report['macro avg']['precision'],
                "vit_test_macro_avg_recall": report['macro avg']['recall'],
                "vit_test_macro_avg_f1": report['macro avg']['f1-score'],
                "vit_test_weighted_avg_precision": report['weighted avg']['precision'],
                "vit_test_weighted_avg_recall": report['weighted avg']['recall'],
                "vit_test_weighted_avg_f1": report['weighted avg']['f1-score']
            })
            try:
                roc_auc = roc_auc_score(test_labels_true, test_preds_probs)
                print(f"ROC AUC Score (ViT Test Set): {roc_auc:.4f}")
                wandb.log({"vit_test_roc_auc": roc_auc})
            except ValueError as e:
                print(f"Could not calculate ROC AUC for ViT: {e}")
            print("\nConfusion Matrix (ViT Test Set):")
            # Same pinned labels keep the matrix 2x2 so it matches display_labels.
            cm = confusion_matrix(test_labels_true, test_preds_binary, labels=class_ids)
            disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=['Real', 'Fake'])
            disp.plot(cmap=plt.cm.Blues)
            plt.title('Confusion Matrix - PyTorch ViT (Test Set)')
            plt.savefig("confusion_matrix_vit.png")
            plt.close()
            wandb.log({"vit_confusion_matrix": wandb.Image("confusion_matrix_vit.png")})
        else:
            print("Not enough data in ViT test results for report/matrix.")
    except FileNotFoundError:
        print("Error: 'best_vit_model_pytorch.pth' not found.")
    except Exception as e:
        print(f"An error occurred during ViT test set evaluation: {e}")
    finally:
        # Finish the evaluation run even if a step above raised; previously an
        # error after wandb.init() left the run open.
        if vit_test_run is not None:
            wandb.finish()
else:
    print("\nViT Test loader is not available. Skipping test set evaluation.")

In [None]:
# --- 12. Evaluation on Test Set for CNN ---
# Rebuilds the CNN, loads the checkpoint written during training, scores it on
# the test loader and logs the metrics to a dedicated W&B run.
if test_loader_cnn:
    print("\n--- Evaluating CNN on Test Set with the Best Model ---")
    best_model_cnn = AudioCNN(num_classes=1, dropout_rate=CNN_DROPOUT_RATE).to(DEVICE)
    # Stays None until this cell has started its own W&B run, so the cleanup
    # in `finally` never finishes a run this cell did not open.
    cnn_test_run = None
    try:
        best_model_cnn.load_state_dict(torch.load('best_cnn_model_pytorch.pth', map_location=DEVICE))
        print("Best CNN model weights loaded successfully.")
        cnn_test_run = wandb.init(project="audio-deepfake-detection", name=f"CNN_Test_Evaluation_{time.strftime('%Y%m%d_%H%M%S')}")
        test_loss, test_acc, test_labels_true, test_preds_probs = evaluate_model_pytorch(best_model_cnn, test_loader_cnn, criterion_cnn, DEVICE, model_name="CNN")
        print(f"CNN Test Set - Loss: {test_loss:.4f}, Accuracy: {test_acc:.4f}")
        wandb.log({
            "cnn_test_loss": test_loss,
            "cnn_test_accuracy": test_acc
        })
        if len(test_labels_true) > 0 and len(test_preds_probs) > 0:
            test_preds_binary = (test_preds_probs > 0.5).astype(int)
            # Pin the label set: without `labels`, scikit-learn infers the classes
            # from the data and classification_report raises when only one class
            # is present but two target names are supplied.
            class_ids = [0, 1]
            class_names = ['Real (0)', 'Fake (1)']
            print("\nClassification Report (CNN Test Set):")
            report = classification_report(test_labels_true, test_preds_binary, labels=class_ids, target_names=class_names, output_dict=True)
            print(classification_report(test_labels_true, test_preds_binary, labels=class_ids, target_names=class_names))
            wandb.log({
                "cnn_test_precision_real": report['Real (0)']['precision'],
                "cnn_test_recall_real": report['Real (0)']['recall'],
                "cnn_test_f1_real": report['Real (0)']['f1-score'],
                "cnn_test_precision_fake": report['Fake (1)']['precision'],
                "cnn_test_recall_fake": report['Fake (1)']['recall'],
                "cnn_test_f1_fake": report['Fake (1)']['f1-score'],
                "cnn_test_macro_avg_precision": report['macro avg']['precision'],
                "cnn_test_macro_avg_recall": report['macro avg']['recall'],
                "cnn_test_macro_avg_f1": report['macro avg']['f1-score'],
                "cnn_test_weighted_avg_precision": report['weighted avg']['precision'],
                "cnn_test_weighted_avg_recall": report['weighted avg']['recall'],
                "cnn_test_weighted_avg_f1": report['weighted avg']['f1-score']
            })
            try:
                roc_auc = roc_auc_score(test_labels_true, test_preds_probs)
                print(f"ROC AUC Score (CNN Test Set): {roc_auc:.4f}")
                wandb.log({"cnn_test_roc_auc": roc_auc})
            except ValueError as e:
                print(f"Could not calculate ROC AUC for CNN: {e}")
            print("\nConfusion Matrix (CNN Test Set):")
            # Same pinned labels keep the matrix 2x2 so it matches display_labels.
            cm = confusion_matrix(test_labels_true, test_preds_binary, labels=class_ids)
            disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=['Real', 'Fake'])
            disp.plot(cmap=plt.cm.Blues)
            plt.title('Confusion Matrix - PyTorch CNN (Test Set)')
            plt.savefig("confusion_matrix_cnn.png")
            plt.close()
            wandb.log({"cnn_confusion_matrix": wandb.Image("confusion_matrix_cnn.png")})
        else:
            print("Not enough data in CNN test results for report/matrix.")
    except FileNotFoundError:
        print("Error: 'best_cnn_model_pytorch.pth' not found.")
    except Exception as e:
        print(f"An error occurred during CNN test set evaluation: {e}")
    finally:
        # Finish the evaluation run even if a step above raised; previously an
        # error after wandb.init() left the run open.
        if cnn_test_run is not None:
            wandb.finish()
else:
    print("\nCNN Test loader is not available. Skipping test set evaluation.")

# Finish any W&B run that is still active at the end of the notebook
# (e.g. the main training run when the evaluation cells were skipped).
wandb.finish()