In [1]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score, confusion_matrix
import numpy as np
import pandas as pd
import os
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
from datetime import datetime
import warnings

import wandb
from einops import rearrange, repeat
from einops.layers.torch import Rearrange

# Seed the PyTorch and NumPy global RNGs so repeated runs start from the same state.
# The NumPy call stays last: it returns None, so the cell shows no output.
_NOTEBOOK_SEED = 42
torch.manual_seed(_NOTEBOOK_SEED)
np.random.seed(_NOTEBOOK_SEED)

In [2]:
# --- WandB login ---
# Try the API key stored in wandb_api_key.txt first; if reading the file or the
# keyed login raises, report why and fall back to wandb's default login flow.
try:
    with open("wandb_api_key.txt", "r") as file:
        wandb_api_key = file.read().strip()
    wandb.login(key=wandb_api_key)
except Exception as e:
    print(f"Failed to login to WandB: {e}. Falling back to manual login.")
    wandb.login()
else:
    print("WandB login successful using wandb_api_key.")

wandb: ERROR Failed to detect the name of this notebook. You can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
wandb: Appending key for api.wandb.ai to your netrc file: C:\Users\Nekloyh\_netrc
wandb: Currently logged in as: nekloyh (nekloyh-none) to https://api.wandb.ai. Use `wandb login --relogin` to force relogin


WandB login successful using wandb_api_key.


In [None]:
# --- Configuration Section ---
class DataConfig:
    """Configuration for audio data processing parameters.

    All values are class attributes and are read directly from the class
    (e.g. ``DataConfig.SR``); the class is not meant to be instantiated.
    """
    # General audio processing settings
    SEED = 42  # Random seed for reproducibility

    SR = 16000  # Sample rate (Hz)
    N_FFT = 2048  # FFT window size
    HOP_LENGTH = 512  # Hop length for spectrogram
    N_MELS = 128  # Number of Mel bands
    FMIN = 0.0  # Minimum frequency (Hz)
    FMAX = 8000.0  # Maximum frequency (Hz)

    # Augmentation settings
    NUM_TIME_MASKS = 2  # Number of time masks for SpecAugment
    NUM_FREQ_MASKS = 2  # Number of frequency masks for SpecAugment
    TIME_MASK_MAX_WIDTH = 60  # Maximum width of time mask
    FREQ_MASK_MAX_WIDTH = 25  # Maximum width of frequency mask
    MASK_REPLACEMENT_VALUE = -80.0  # Value for masked regions in spectrogram
    NORM_EPSILON = 1e-6  # Small value to prevent division by zero
    LOUDNESS_LUFS = -23.0  # Target loudness (LUFS)

    # Dataset and processing options
    USE_GLOBAL_NORMALIZATION = False  # Use global mean/std for normalization
    USE_RANDOM_CROPPING = True  # Apply random cropping to spectrograms

    # Directory for processed data. The default is the original machine-specific
    # path; set the DEEPFAKE_AUDIO_CACHE_DIR environment variable (before this
    # cell runs) to use another location without editing the notebook.
    CACHE_DIR = os.environ.get(
        "DEEPFAKE_AUDIO_CACHE_DIR",
        "F:\\Deepfake-Audio-Detector\\processed_dataset",
    )


In [None]:
class ModelConfig:
    """Per-dataset settings: clip length, overlap, augmentation flags, frame count.

    ``max_frames_spec`` is derived from the clip length using ``DataConfig.SR``
    and ``DataConfig.HOP_LENGTH`` and is rounded up to a multiple of
    ``patch_width``.
    """

    def __init__(
        self,
        name: str,
        audio_length_seconds: float,
        overlap_ratio: float,
        apply_augmentation: bool = False,
        apply_waveform_augmentation: bool = False,
        patch_width: int = 16,  # Must match ViTConfig patch_width
    ):
        self.name = name
        self.audio_length_seconds = audio_length_seconds
        self.overlap_ratio = overlap_ratio
        self.apply_augmentation = apply_augmentation
        self.apply_waveform_augmentation = apply_waveform_augmentation

        # samples -> hop-sized frames -> whole patches (rounded up) -> frames
        samples_per_clip = audio_length_seconds * DataConfig.SR
        frames_per_clip = samples_per_clip / DataConfig.HOP_LENGTH
        patches_across = np.ceil(frames_per_clip / patch_width)
        self.max_frames_spec = int(patches_across * patch_width)


In [None]:
# Available dataset variants, keyed by each config's own name (insertion order
# follows the tuple below).
ALL_MODEL_CONFIGS = {
    dataset_config.name: dataset_config
    for dataset_config in (
        ModelConfig(
            name="vit_balanced_dataset",
            audio_length_seconds=8.192,
            overlap_ratio=0.5,
            apply_augmentation=True,
            apply_waveform_augmentation=True,
        ),
        ModelConfig(
            name="vit_performance_dataset",
            audio_length_seconds=10.24,
            overlap_ratio=0.0,
            apply_augmentation=True,
            apply_waveform_augmentation=True,
        ),
    )
}


In [None]:
# --- 1. ViT model configuration ---
class ViTConfig:
    """Hyper-parameters plus derived patch geometry for a ViT model.

    ``image_size`` and ``patch_size`` are ``(height, width)`` pairs. After
    validation the instance also exposes ``num_patches`` (patch-grid cell count)
    and ``patch_dim`` (``channels * patch_height * patch_width``).

    Raises:
        AssertionError: if the image is not divisible by the patch size, or if
            ``dim``, ``depth``, ``heads`` or ``mlp_dim`` is not positive.
    """

    def __init__(
        self,
        name: str,
        image_size: tuple,
        patch_size: tuple,
        dim: int,
        depth: int,
        heads: int,
        mlp_dim: int,
        dropout: float = 0.1,
        emb_dropout: float = 0.1,
        channels: int = 1,
        num_classes: int = 2,
    ):
        img_h, img_w = image_size
        patch_h, patch_w = patch_size

        self.name = name
        self.image_height, self.image_width = img_h, img_w
        self.patch_height, self.patch_width = patch_h, patch_w
        self.dim = dim
        self.depth = depth
        self.heads = heads
        self.mlp_dim = mlp_dim
        self.dropout = dropout
        self.emb_dropout = emb_dropout
        self.channels = channels
        self.num_classes = num_classes

        # The patch grid must tile the image exactly in both directions.
        height_tiles = img_h % patch_h == 0
        assert height_tiles and img_w % patch_w == 0, \
            "Image dimensions must be divisible by the patch size."

        # Size-like hyper-parameters must be strictly positive (checked in this order).
        positive_checks = (
            (dim, "Embedding dimension must be positive."),
            (depth, "Number of transformer layers must be positive."),
            (heads, "Number of attention heads must be positive."),
            (mlp_dim, "MLP dimension must be positive."),
        )
        for value, message in positive_checks:
            assert value > 0, message

        rows = img_h // patch_h
        cols = img_w // patch_w
        self.num_patches = rows * cols
        self.patch_dim = channels * patch_h * patch_w


In [None]:
# --- 2. Vision Transformer architecture (AudioViT) ---
# Restructured to use einops
class PreNorm(nn.Module):
    """Apply LayerNorm over the last ``dim`` features, then call ``fn``."""

    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        # Extra keyword arguments are passed through to the wrapped module untouched.
        normalized = self.norm(x)
        return self.fn(normalized, **kwargs)


class FeedForward(nn.Module):
    """Two-layer MLP: ``dim -> hidden_dim -> dim`` with GELU and dropout.

    Dropout is applied after the activation and again after the output
    projection.
    """

    def __init__(self, dim, hidden_dim, dropout=0.0):
        super().__init__()
        stages = [
            nn.Linear(dim, hidden_dim),  # expand
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),  # project back
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        return self.net(x)


class Attention(nn.Module):
    """Multi-head self-attention over a ``(batch, tokens, dim)`` input.

    Queries, keys and values come from a single bias-free linear layer.
    Scores are scaled by ``dim_head ** -0.5``, soft-maxed over the key axis
    and passed through dropout. The merged heads are projected back to ``dim``
    unless there is a single head whose width already equals ``dim``, in which
    case the output projection is the identity.
    """

    def __init__(self, dim, heads=8, dim_head=64, dropout=0.0):
        super().__init__()
        self.heads = heads
        self.scale = dim_head**-0.5

        self.attend = nn.Softmax(dim=-1)
        self.dropout = nn.Dropout(dropout)

        inner_dim = dim_head * heads
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)

        single_full_width_head = heads == 1 and dim_head == dim
        if single_full_width_head:
            self.to_out = nn.Identity()
        else:
            self.to_out = nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout))

    def forward(self, x):
        # One projection, split three ways, each reshaped to (batch, heads, tokens, dim_head).
        q, k, v = (
            rearrange(part, "b n (h d) -> b h n d", h=self.heads)
            for part in self.to_qkv(x).chunk(3, dim=-1)
        )

        scores = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        weights = self.dropout(self.attend(scores))

        # Weighted sum of values, then fold the heads back into the feature axis.
        merged = rearrange(torch.matmul(weights, v), "b h n d -> b n (h d)")
        return self.to_out(merged)


class Transformer(nn.Module):
    """Stack of ``depth`` pre-norm encoder blocks.

    Each block is an ``nn.ModuleList`` of ``[PreNorm(Attention), PreNorm(FeedForward)]``;
    both sub-layers are applied with a residual connection. Every attention
    head gets ``dim // heads`` features.
    """

    def __init__(self, dim, depth, heads, mlp_dim, dropout=0.0):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            # Attention is built before the feed-forward so parameters are
            # created in the same order for every block.
            attention = PreNorm(
                dim,
                Attention(dim, heads=heads, dim_head=dim // heads, dropout=dropout),
            )
            feed_forward = PreNorm(dim, FeedForward(dim, mlp_dim, dropout=dropout))
            self.layers.append(nn.ModuleList([attention, feed_forward]))

    def forward(self, x):
        for attention, feed_forward in self.layers:
            x = attention(x) + x  # residual around attention
            x = feed_forward(x) + x  # residual around the MLP
        return x


class AudioViT(nn.Module):
    """Vision Transformer classifier built from a ``ViTConfig``.

    The input is cut into non-overlapping ``patch_height x patch_width``
    patches (layout ``b c (h p1) (w p2)``), each patch is flattened and linearly
    embedded, a learnable CLS token is prepended, positional embeddings are
    added, and the encoder output at the CLS position is fed to the
    classification head. ``forward`` returns one row of ``num_classes`` logits
    per input item.
    """

    def __init__(self, *, config: ViTConfig):
        super().__init__()
        self.config = config

        # (b, c, H, W) -> (b, num_patches, patch_dim) -> (b, num_patches, dim)
        patchify = Rearrange(
            "b c (h p1) (w p2) -> b (h w) (p1 p2 c)",
            p1=config.patch_height,
            p2=config.patch_width,
        )
        self.to_patch_embedding = nn.Sequential(
            patchify,
            nn.LayerNorm(config.patch_dim),
            nn.Linear(config.patch_dim, config.dim),
            nn.LayerNorm(config.dim),
        )

        # One positional vector per patch plus one for the CLS token.
        self.pos_embedding = nn.Parameter(
            torch.randn(1, config.num_patches + 1, config.dim)
        )
        self.cls_token = nn.Parameter(torch.randn(1, 1, config.dim))
        self.dropout = nn.Dropout(config.emb_dropout)

        self.transformer = Transformer(
            config.dim, config.depth, config.heads, config.mlp_dim, config.dropout
        )

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(config.dim), nn.Linear(config.dim, config.num_classes)
        )

    def forward(self, img):
        tokens = self.to_patch_embedding(img)
        batch_size, num_tokens, _ = tokens.shape

        # Prepend one copy of the CLS token to every sequence in the batch.
        cls_tokens = repeat(self.cls_token, "1 1 d -> b 1 d", b=batch_size)
        tokens = torch.cat((cls_tokens, tokens), dim=1)
        tokens += self.pos_embedding[:, : (num_tokens + 1)]
        tokens = self.dropout(tokens)

        encoded = self.transformer(tokens)

        # Take the output at the CLS token position
        cls_output = encoded[:, 0]

        return self.mlp_head(cls_output)


In [None]:
# --- 3. Dataset class for cached (preprocessed) data ---
class AudioDataset(Dataset):
    """Dataset over cached spectrogram ``.npy`` files listed in ``metadata.csv``.

    The split lives in ``<cache_dir>/<set_type>``. Its ``metadata.csv`` must
    provide an ``npy_path`` column (relative to the split directory) and a
    ``label`` column. ``n_mels`` and ``max_frames_spec`` are stored on the
    instance but not used when loading items.

    Raises:
        FileNotFoundError: if ``metadata.csv`` is missing for the split.
    """

    def __init__(self, cache_dir: str, set_type: str, n_mels: int, max_frames_spec: int):
        split_dir = os.path.join(cache_dir, set_type)
        self.cache_path = split_dir
        self.metadata_path = os.path.join(split_dir, "metadata.csv")
        self.n_mels = n_mels
        self.max_frames_spec = max_frames_spec

        metadata_found = os.path.exists(self.metadata_path)
        if not metadata_found:
            raise FileNotFoundError(
                f"Metadata file not found: {self.metadata_path}. Please run data_preprocessing.py first."
            )

        self.metadata = pd.read_csv(self.metadata_path)

    def __len__(self):
        return len(self.metadata)

    def __getitem__(self, idx):
        """Return ``(spectrogram, label)`` tensors, or ``None`` if loading fails.

        The spectrogram is a float32 tensor with a leading channel axis of
        size 1; the label is an int64 scalar tensor. Any error while loading
        or reshaping is reported as a warning and turned into ``None``.
        """
        record = self.metadata.iloc[idx]
        npy_path = os.path.join(self.cache_path, record["npy_path"])
        label = record["label"]

        try:
            array = np.load(npy_path)
            is_multichannel = array.ndim == 3 and array.shape[0] != 1
            if array.ndim == 2:
                # Add the single channel axis expected by the ViT: (1, height, width)
                array = np.expand_dims(array, axis=0)
            elif is_multichannel:
                raise ValueError(
                    f"Unexpected spectrogram shape: {array.shape}. Expected (1, N_MELS, N_FRAMES)."
                )

            # float32 tensor view/copy of the loaded array
            spectrogram = torch.from_numpy(array).float()

        except Exception as e:
            warnings.warn(f"Error loading or processing {npy_path}: {e}")
            return None  # Return None to be filtered by collate_fn

        return spectrogram, torch.tensor(label).long()

# Collate function that discards samples the dataset returned as None
def custom_collate_fn(batch):
    """Collate the non-``None`` items of ``batch``; return ``None`` if none remain."""
    usable = [sample for sample in batch if sample is not None]
    if len(usable) == 0:
        return None
    return torch.utils.data.dataloader.default_collate(usable)


In [None]:
# --- 4. Training and evaluation functions ---
def train_model(
    model,
    train_loader,
    val_loader,
    optimizer,
    criterion,
    device,
    num_epochs,
    run_name,
    dataset_name,
):
    """Train ``model`` for ``num_epochs`` and log validation metrics to W&B.

    A wandb run named ``run_name`` is started here (the caller is responsible
    for finishing it). After every epoch the model is scored on ``val_loader``
    with ``evaluate_model``; whenever the validation F1 improves, the state
    dict is written to ``best_vit_model_{run_name}.pth`` and uploaded to wandb.

    Batches that the collate function returned as ``None`` and batches that
    contain a ``-1`` label are skipped. The reported training loss is the mean
    over the batches that were actually used (NaN if none were).

    Args:
        model: module exposing a ``config`` attribute (its ``__dict__`` is logged).
        train_loader: iterable of ``(data, labels)`` batches or ``None``; must
            expose ``batch_size``.
        val_loader: loader passed to ``evaluate_model``.
        optimizer: optimizer whose ``defaults["lr"]`` is logged.
        criterion: loss called as ``criterion(outputs, labels)``.
        device: device the model and batches are moved to.
        num_epochs: number of passes over ``train_loader``.
        run_name: wandb run name, also used in the checkpoint file name.
        dataset_name: logged to the wandb config.

    Returns:
        The same ``model`` object.
        NOTE(review): it carries the weights of the last epoch, not those of
        the best-F1 checkpoint saved to disk — confirm this is intended.
    """
    model.to(device)
    best_val_f1 = -1

    # Initialize wandb run
    wandb.init(
        project="audio-deepfake-detection",
        name=run_name,
        config={
            "learning_rate": optimizer.defaults["lr"],
            "epochs": num_epochs,
            "batch_size": train_loader.batch_size,
            "model_name": model.__class__.__name__,
            "model_config": model.config.__dict__,
            "dataset_name": dataset_name,
            "device": str(device),
        },
    )
    wandb.watch(model, log_freq=100)  # Log gradients and model parameters

    print(f"Starting training for {num_epochs} epochs on {device}...")

    for epoch in range(num_epochs):
        model.train()
        total_loss = 0.0
        batches_used = 0  # only batches that went through an optimizer step
        pbar = tqdm(
            train_loader, desc=f"Epoch {epoch + 1}/{num_epochs} [Train]", leave=False
        )
        for batch in pbar:
            if batch is None:  # Skip empty batches from collate_fn
                continue
            data, labels = batch
            if (labels == -1).any():  # Skip the batch if any -1 label exists
                continue

            data, labels = data.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(data)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            batch_loss = loss.item()
            total_loss += batch_loss
            batches_used += 1
            pbar.set_postfix(loss=batch_loss)

        # Average over the batches actually trained on; skipped batches must not
        # dilute the mean, and an all-skipped epoch must not divide by zero.
        avg_train_loss = total_loss / batches_used if batches_used else float("nan")

        # Validation
        val_loss, val_preds, val_labels, val_probs = evaluate_model(
            model, val_loader, criterion, device
        )
        val_acc = accuracy_score(val_labels, val_preds)
        val_f1 = f1_score(val_labels, val_preds, average="binary")
        val_roc_auc = roc_auc_score(val_labels, val_probs[:, 1])

        print(
            f"Epoch {epoch + 1} Train Loss: {avg_train_loss:.4f}, Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}, Val F1: {val_f1:.4f}, Val ROC AUC: {val_roc_auc:.4f}"
        )

        # Log metrics to wandb
        wandb.log(
            {
                "epoch": epoch,
                "train_loss": avg_train_loss,
                "val_loss": val_loss,
                "val_accuracy": val_acc,
                "val_f1_score": val_f1,
                "val_roc_auc": val_roc_auc,
            }
        )

        # Save best model
        if val_f1 > best_val_f1:
            best_val_f1 = val_f1
            model_save_path = f"best_vit_model_{run_name}.pth"
            torch.save(model.state_dict(), model_save_path)
            print(f"Saved best model with F1: {best_val_f1:.4f} to {model_save_path}")
            wandb.save(model_save_path)  # Save model to wandb

    return model


def evaluate_model(model, data_loader, criterion, device, return_cm=False):
    """Score ``model`` on ``data_loader`` without gradient tracking.

    Batches that are ``None`` or that contain a ``-1`` label are skipped. The
    average loss is taken over the batches that were actually scored, so
    skipped batches do not dilute it; if nothing was scored the loss is NaN and
    the returned arrays are empty.

    Args:
        model: module called as ``model(data)`` to produce per-class logits.
        data_loader: iterable of ``(data, labels)`` batches or ``None``.
        criterion: loss called as ``criterion(outputs, labels)``.
        device: device the batches are moved to.
        return_cm: when True, also return the confusion matrix.

    Returns:
        ``(avg_loss, preds, labels, probs)`` as a float and NumPy arrays, with
        ``probs`` holding the softmax over the class axis; when ``return_cm``
        is True the confusion matrix is appended as a fifth element.
    """
    model.eval()
    total_loss = 0.0
    batches_scored = 0
    all_preds = []
    all_labels = []
    all_probs = []

    with torch.no_grad():
        pbar = tqdm(data_loader, desc="Evaluation", leave=False)
        for batch in pbar:
            if batch is None:  # Skip empty batches from collate_fn
                continue
            data, labels = batch
            if (labels == -1).any():  # Skip the batch if any -1 label exists
                continue

            data, labels = data.to(device), labels.to(device)
            outputs = model(data)
            loss = criterion(outputs, labels)
            total_loss += loss.item()
            batches_scored += 1

            probs = torch.softmax(outputs, dim=1)
            preds = torch.argmax(probs, dim=1)

            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
            all_probs.extend(probs.cpu().numpy())

    # Divide by the batches that contributed a loss, not by len(data_loader).
    avg_loss = total_loss / batches_scored if batches_scored else float("nan")

    if return_cm:
        cm = confusion_matrix(all_labels, all_preds)
        return (
            avg_loss,
            np.array(all_preds),
            np.array(all_labels),
            np.array(all_probs),
            cm,
        )
    else:
        return avg_loss, np.array(all_preds), np.array(all_labels), np.array(all_probs)


def plot_confusion_matrix(cm, labels=("Real", "Fake"), run_name="", save_dir="."):
    """Draw ``cm`` as an annotated heatmap, save it as a PNG and show it.

    Args:
        cm: confusion matrix of integer counts (cells are formatted with ``"d"``).
        labels: tick labels used for both axes.
        run_name: used in the plot title and in the output file name.
        save_dir: directory for the PNG; created if it does not exist.

    Returns:
        Path of the saved image, ``<save_dir>/confusion_matrix_<run_name>.png``.
    """
    tick_labels = list(labels)  # immutable default, list handed to seaborn
    fig = plt.figure(figsize=(8, 6))
    sns.heatmap(
        cm,
        annot=True,
        fmt="d",
        cmap="Blues",
        xticklabels=tick_labels,
        yticklabels=tick_labels,
    )
    plt.title(f"Confusion Matrix for {run_name}")
    plt.xlabel("Predicted Label")
    plt.ylabel("True Label")

    # exist_ok avoids the check-then-create race of a separate os.path.exists test
    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, f"confusion_matrix_{run_name}.png")
    plt.savefig(save_path)
    print(f"Confusion matrix saved to {save_path}")
    plt.show()
    # Release the figure so repeated runs in one kernel do not accumulate open figures.
    plt.close(fig)
    return save_path


In [None]:
# --- Training Configuration ---
class TrainingConfig:
    """Validated bundle of the settings needed for one training run.

    Raises:
        AssertionError: if ``model_size`` is not one of the supported sizes,
            ``dataset_name`` is not a key of ``ALL_MODEL_CONFIGS``, or a
            numeric setting is out of range.
    """

    def __init__(
        self,
        model_size: str,
        dataset_name: str,
        epochs: int,
        learning_rate: float,
        batch_size: int,
        num_workers: int,
    ):
        # Validate first (same order and messages as the stored fields need)
        supported_sizes = ("ViT_Small", "ViT_Medium", "ViT_Large")
        assert model_size in supported_sizes, f"Model size '{model_size}' not found"
        assert dataset_name in ALL_MODEL_CONFIGS, f"Dataset name '{dataset_name}' not found in ALL_MODEL_CONFIGS"
        assert batch_size > 0, "Batch size must be positive"
        assert epochs > 0, "Number of epochs must be positive"
        assert learning_rate > 0, "Learning rate must be positive"
        assert num_workers >= 0, "Number of workers must be non-negative"

        # Then store the settings
        self.model_size = model_size
        self.dataset_name = dataset_name
        self.epochs = epochs
        self.learning_rate = learning_rate
        self.batch_size = batch_size
        self.num_workers = num_workers


In [None]:
# Per-size ViT hyper-parameters. The image size depends on the dataset and the
# patch size is fixed at 16x16; both are supplied in run_training.
_VIT_PRESETS = {
    "ViT_Small": {"dim": 224, "depth": 8, "heads": 8, "mlp_dim": 224 * 4},
    "ViT_Medium": {"dim": 256, "depth": 12, "heads": 8, "mlp_dim": 256 * 4},
    "ViT_Large": {"dim": 384, "depth": 16, "heads": 12, "mlp_dim": 384 * 4},
}


def _make_loader(dataset, batch_size, num_workers, shuffle, pin_memory):
    """Build a DataLoader whose collate step drops samples that failed to load."""
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=pin_memory,
        collate_fn=custom_collate_fn,
    )


def run_training(training_config):
    """Main function to run the training process.

    Seeds the RNGs, builds the ViT configuration and the train/val/test
    loaders for the chosen dataset, trains with Adam and cross-entropy,
    evaluates on the test split, saves the confusion-matrix plot under
    ``results`` and logs the test metrics to wandb.

    Args:
        training_config: object exposing ``model_size``, ``dataset_name``,
            ``epochs``, ``learning_rate``, ``batch_size`` and ``num_workers``.

    Returns:
        The trained model, or ``None`` (after printing an error) when the
        dataset name or the model size is unknown.

    The wandb run opened by ``train_model`` is always closed: normally on
    success, and with a non-zero exit code if training or evaluation raises.
    """

    # Set random seed for reproducibility
    torch.manual_seed(DataConfig.SEED)
    np.random.seed(DataConfig.SEED)

    # Get configuration from TrainingConfig
    model_size = training_config.model_size
    dataset_name = training_config.dataset_name
    epochs = training_config.epochs
    learning_rate = training_config.learning_rate
    batch_size = training_config.batch_size
    num_workers = training_config.num_workers

    # Get the ModelConfig for the chosen dataset
    if dataset_name not in ALL_MODEL_CONFIGS:
        print(f"Error: Dataset name '{dataset_name}' not found in ALL_MODEL_CONFIGS.")
        print("Please ensure 'data_preprocessing.py' defines this dataset name.")
        return

    current_dataset_model_config = ALL_MODEL_CONFIGS[dataset_name]
    max_frames_spec = current_dataset_model_config.max_frames_spec

    # Device configuration
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Determine ViT configuration based on model_size
    if model_size not in _VIT_PRESETS:
        print(
            f"Error: Invalid model_size '{model_size}'. Choose from {list(_VIT_PRESETS.keys())}"
        )
        return

    vit_config = ViTConfig(
        name=model_size,
        image_size=(DataConfig.N_MELS, max_frames_spec),
        patch_size=(16, 16),
        **_VIT_PRESETS[model_size],
    )
    print(f"Configuring {vit_config.name} model...")
    print(
        f"Image size: {vit_config.image_height}x{vit_config.image_width}, Patch size: {vit_config.patch_height}x{vit_config.patch_width}"
    )
    print(
        f"Dim: {vit_config.dim}, Depth: {vit_config.depth}, Heads: {vit_config.heads}"
    )

    # Dataset paths
    base_cache_dir = DataConfig.CACHE_DIR
    model_cache_dir = os.path.join(base_cache_dir, dataset_name)

    print(f"Loading data from: {model_cache_dir}")

    # Datasets
    train_dataset = AudioDataset(
        model_cache_dir, "train", DataConfig.N_MELS, max_frames_spec
    )
    val_dataset = AudioDataset(
        model_cache_dir, "val", DataConfig.N_MELS, max_frames_spec
    )
    test_dataset = AudioDataset(
        model_cache_dir, "test", DataConfig.N_MELS, max_frames_spec
    )

    # DataLoaders. Pinned host memory only helps host-to-GPU copies, so it is
    # requested only when training on CUDA.
    pin_memory = device.type == "cuda"
    train_loader = _make_loader(
        train_dataset, batch_size, num_workers, shuffle=True, pin_memory=pin_memory
    )
    val_loader = _make_loader(
        val_dataset, batch_size, num_workers, shuffle=False, pin_memory=pin_memory
    )
    test_loader = _make_loader(
        test_dataset, batch_size, num_workers, shuffle=False, pin_memory=pin_memory
    )

    print(f"Train samples: {len(train_dataset)}")
    print(f"Validation samples: {len(val_dataset)}")
    print(f"Test samples: {len(test_dataset)}")
    print(f"Using Batch size: {batch_size}")

    # Initialize model, loss, and optimizer
    model = AudioViT(config=vit_config)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    # Run name for W&B
    run_name = f"{model_size}_{dataset_name}_{datetime.now().strftime('%H%M%S')}"

    run_succeeded = False
    try:
        # Train the model (this opens the wandb run)
        trained_model = train_model(
            model,
            train_loader,
            val_loader,
            optimizer,
            criterion,
            device,
            epochs,
            run_name,
            dataset_name,
        )

        # Evaluate on test set
        print(f"\n--- Evaluating {vit_config.name} on Test Set ({dataset_name}) ---")
        test_loss, test_preds, test_labels, test_probs, test_cm = evaluate_model(
            trained_model, test_loader, criterion, device, return_cm=True
        )

        test_acc = accuracy_score(test_labels, test_preds)
        test_f1 = f1_score(test_labels, test_preds, average="binary")
        test_roc_auc = roc_auc_score(test_labels, test_probs[:, 1])

        print(f"Test Loss: {test_loss:.4f}")
        print(f"Test Accuracy: {test_acc:.4f}")
        print(f"Test F1-score: {test_f1:.4f}")
        print(f"Test ROC AUC: {test_roc_auc:.4f}")

        # Plot and save confusion matrix
        cm_plot_path = plot_confusion_matrix(
            test_cm, run_name=run_name, save_dir="results"
        )

        # Log test metrics to W&B
        wandb.log(
            {
                "test_loss": test_loss,
                "test_accuracy": test_acc,
                "test_f1_score": test_f1,
                "test_roc_auc": test_roc_auc,
                "confusion_matrix": wandb.Image(cm_plot_path),
            }
        )
        run_succeeded = True
    finally:
        # Always close the wandb run so a failure does not leave it open in the
        # kernel; mark it as failed when an exception is propagating.
        if run_succeeded:
            wandb.finish()
        else:
            wandb.finish(exit_code=1)

    return trained_model


In [None]:
# Hyper-parameters for this run
# NOTE(review): the cache path in DataConfig looks like a Windows drive; confirm
# that DataLoader worker processes (num_workers > 0) can load a Dataset class
# defined inside this notebook on that platform.
training_params = dict(
    model_size="ViT_Small",
    dataset_name="vit_balanced_dataset",
    epochs=20,
    learning_rate=1e-4,
    batch_size=32,
    num_workers=8,
)

# Validate the parameters by building the TrainingConfig
training_config = TrainingConfig(**training_params)

# Echo the configuration before starting the run
print("=== ViT Training Configuration ===")
print(
    "\n".join(
        [
            f"Model Size: {training_config.model_size}",
            f"Dataset: {training_config.dataset_name}",
            f"Epochs: {training_config.epochs}",
            f"Learning Rate: {training_config.learning_rate}",
            f"Batch Size: {training_config.batch_size}",
            f"Num Workers: {training_config.num_workers}",
        ]
    )
)
print("==================================")

# Train, evaluate on the test split and keep the resulting model
trained_model = run_training(training_config)
