### Imports

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
from einops import rearrange
from einops.layers.torch import Rearrange
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score, confusion_matrix
import numpy as np
import pandas as pd
import os
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
from datetime import datetime
import warnings
import wandb
from dataclasses import dataclass

# Set random seed for reproducibility
torch.manual_seed(42)
np.random.seed(42)

# WandB login
from kaggle_secrets import UserSecretsClient

user_secrets = UserSecretsClient()
try:
    secret_value_0 = user_secrets.get_secret("wandb_api_key")
    wandb.login(key=secret_value_0)
    print("WandB login successful using wandb_api_key.")
except Exception as e:
    print(f"Failed to login to WandB: {e}. Please ensure WANDB_API_KEY is set.")
    raise


[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mnekloyh[0m ([33mnekloyh-none[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


WandB login successful using wandb_api_key.


### Configurations

In [2]:
@dataclass
class Config:
    # Data processing (các tham số này là chung cho dataset, không thay đổi giữa các model)
    SEED: int = 42
    SR: int = 16000
    N_FFT: int = 2048
    HOP_LENGTH: int = 512
    N_MELS: int = 128
    FMIN: float = 0.0
    FMAX: float = 8000.0
    NUM_TIME_MASKS: int = 2
    NUM_FREQ_MASKS: int = 2
    TIME_MASK_MAX_WIDTH: int = 30
    FREQ_MASK_MAX_WIDTH: int = 15
    MASK_REPLACEMENT_VALUE: float = -80.0
    NORM_EPSILON: float = 1e-6
    LOUDNESS_LUFS: float = -23.0
    USE_GLOBAL_NORMALIZATION: bool = True
    USE_RANDOM_CROPPING: bool = True
    # CHỖ NÀY LÀ ĐƯỜNG DẪN CHUNG CHO CẢ HAI LOẠI MODEL
    CACHE_DIR_BASE: str = "/kaggle/input/vit-3s-dataset" # Đường dẫn gốc
    DATASET_SUBDIR: str = "vit_3s_dataset" # Thư mục con cụ thể của dataset
    train_dir: str = "train"
    val_dir: str = "val"
    test_dir: str = "test"
    metadata_file: str = "kaggle_metadata.csv"

    # Model architecture (các tham số này sẽ khác nhau tùy theo model_size)
    img_size: int = 224
    patch_size: int = 16
    num_classes: int = 2
    in_channels: int = 1
    dim: int = 128
    depth: int = 4
    heads: int = 4
    mlp_dim: int = 256
    dropout: float = 0.1

    # Training (có thể là chung hoặc khác nhau tùy theo model)
    learning_rate: float = 1e-4
    batch_size: int = 32
    epochs: int = 20
    weight_decay: float = 1e-4
    num_workers: int = 4

    # Data augmentation
    apply_augmentation: bool = True
    augmentation_prob: float = 0.5
    audio_length_seconds: float = 3.0
    overlap_ratio: float = 0.5

    # Thuộc tính để lưu trữ tên model và dataset cho mục đích cấu hình và logging
    model_size: str = ""
    dataset_name: str = "" # Tên logic của dataset, ví dụ "vit_3s_dataset"

    def validate(self):
        assert self.img_size % self.patch_size == 0, ("img_size must be divisible by patch_size")
        assert self.dim % self.heads == 0, "dim must be divisible by heads"
        assert self.learning_rate > 0, "learning_rate must be positive"
        assert self.batch_size > 0, "batch_size must be positive"
        assert self.epochs > 0, "epochs must be positive"
        assert self.num_workers >= 0, "num_workers must be non-negative"

    # Hàm trợ giúp để tạo đường dẫn cache đầy đủ
    def get_full_cache_dir(self):
        return os.path.join(self.CACHE_DIR_BASE, self.DATASET_SUBDIR)

In [3]:
BASE_CONFIG = Config()

# Lấy tất cả các tham số từ BASE_CONFIG ngoại trừ các tham số kiến trúc và tên model/dataset
# mà chúng ta muốn ghi đè riêng cho từng loại model
base_params = {
    f.name: getattr(BASE_CONFIG, f.name)
    for f in BASE_CONFIG.__dataclass_fields__.values()
    if f.init and f.name not in ['dim', 'depth', 'heads', 'mlp_dim', 'model_size', 'dataset_name']
}

ALL_MODEL_CONFIGS = {
    "ViT_Small": Config(
        **base_params, # Giải nén các tham số chung từ BASE_CONFIG
        # Ghi đè các tham số cụ thể cho ViT_Small
        dim=128,
        depth=4,
        heads=4,
        mlp_dim=256,
        model_size="ViT_Small", # Đặt tên model_size
        dataset_name="vit_3s_dataset", # Đặt tên dataset logic
    ),
    "ViT_Large": Config(
        **base_params, # Giải nén các tham số chung từ BASE_CONFIG
        # Ghi đè các tham số cụ thể cho ViT_Large
        dim=384,
        depth=6,
        heads=8,
        mlp_dim=768,
        model_size="ViT_Large", # Đặt tên model_size
        dataset_name="vit_3s_dataset", # Đặt tên dataset logic
    ),
}

### Model Definition

In [4]:
class ViT_Audio(nn.Module):
    def __init__(self, img_size, patch_size, num_classes, in_channels, dim, depth, heads, mlp_dim, dropout: float = 0.1): # THÊM dropout VÀO ĐÂY
        super().__init__()
        assert img_size % patch_size == 0, "Image dimensions must be divisible by the patch size."
        num_patches = (img_size // patch_size) ** 2
        patch_dim = in_channels * patch_size * patch_size

        self.patch_size = patch_size

        self.pos_embed = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))

        self.to_patch_embedding = nn.Sequential(
            nn.Conv2d(in_channels, patch_dim, kernel_size=patch_size, stride=patch_size),
            Rearrange('b c h w -> b (h w) c'),
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim),
            nn.LayerNorm(dim),
        )

        transformer_layer = nn.TransformerEncoderLayer(
            d_model=dim,
            nhead=heads,
            dim_feedforward=mlp_dim,
            dropout=dropout, # TRUYỀN dropout VÀO ĐÂY
            batch_first=True
        )
        self.transformer = nn.TransformerEncoder(transformer_layer, num_layers=depth)

        self.ln = nn.LayerNorm(dim)
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    def forward(self, x):
        x = self.to_patch_embedding(x)
        b, n, _ = x.shape

        cls_tokens = self.cls_token.expand(b, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + self.pos_embed # Positional embedding
        x = self.transformer(x) # Chú ý rằng PyTorch's TransformerEncoderLayer/Encoder tự xử lý dropout nội bộ

        cls_token_final = x[:, 0]
        x = self.ln(cls_token_final)
        return self.mlp_head(x)

### Dataset

In [5]:
class AudioDataset(Dataset):
    # Đã cập nhật: Loại bỏ 'max_frames_spec' khỏi __init__ vì nó sẽ không được sử dụng để cắt/đệm trực tiếp
    def __init__(self, cache_dir: str, set_type: str, n_mels: int, config: Config):
        self.cache_path = os.path.join(cache_dir, getattr(config, f"{set_type}_dir"))
        self.metadata_path = os.path.join(self.cache_path, config.metadata_file)
        self.n_mels = n_mels
        # self.target_frames = max_frames_spec # Dòng này không còn được sử dụng để cắt/đệm trước
        self.training = set_type == "train"
        self.config = config

        if not os.path.exists(self.metadata_path):
            raise FileNotFoundError(f"Metadata file not found: {self.metadata_path}")
        self.metadata = pd.read_csv(self.metadata_path)

    def __len__(self):
        return len(self.metadata)

    def __getitem__(self, idx):
        row = self.metadata.iloc[idx]
        npy_path = os.path.join(self.cache_path, row["npy_path"])
        label = int(row["label"])

        try:
            if not os.path.exists(npy_path):
                raise FileNotFoundError(f"Spectrogram file not found: {npy_path}")
            spectrogram = np.load(npy_path)
            spectrogram = self._preprocess_spectrogram(spectrogram)
        except Exception as e:
            print(f"Error loading {npy_path}: {e}")
            return None

        return spectrogram, torch.tensor(label).long()

    def _preprocess_spectrogram(self, spec):
        if isinstance(spec, np.ndarray):
            spec = torch.from_numpy(spec).float()

        if spec.ndim == 2:
            spec = spec.unsqueeze(0)
        elif spec.ndim == 4:
            spec = spec.squeeze(0)

        # Đã cập nhật: Loại bỏ logic cắt/đệm dựa trên target_frames (max_frames_spec)
        # Giờ đây chỉ nội suy trực tiếp đến kích thước img_size x img_size của ViT
        if spec.shape[-2:] != (self.config.img_size, self.config.img_size):
            spec = F.interpolate(
                spec.unsqueeze(0),
                size=(self.config.img_size, self.config.img_size),
                mode="bilinear",
                align_corners=False,
            ).squeeze(0)

        if self.config.USE_GLOBAL_NORMALIZATION:
            mean = spec.mean()
            std = spec.std() + self.config.NORM_EPSILON
            spec = (spec - mean) / std

        if self.training and self.config.apply_augmentation:
            # Augmentation chỉ áp dụng sau khi đã nội suy về kích thước cuối cùng
            for _ in range(self.config.NUM_FREQ_MASKS):
                freq_mask_width = torch.randint(
                    0, self.config.FREQ_MASK_MAX_WIDTH, (1,)
                ).item()
                # Đảm bảo tần số mask không vượt quá kích thước img_size (chiều cao)
                freq_start = torch.randint(
                    0, max(1, spec.shape[-2] - freq_mask_width), (1,)
                ).item()
                spec[:, freq_start : freq_start + freq_mask_width, :] = (
                    self.config.MASK_REPLACEMENT_VALUE
                )
            for _ in range(self.config.NUM_TIME_MASKS):
                time_mask_width = torch.randint(
                    0, self.config.TIME_MASK_MAX_WIDTH, (1,)
                ).item()
                # Đảm bảo thời gian mask không vượt quá kích thước img_size (chiều rộng)
                time_start = torch.randint(
                    0, max(1, spec.shape[-1] - time_mask_width), (1,)
                ).item()
                spec[:, :, time_start : time_start + time_mask_width] = (
                    self.config.MASK_REPLACEMENT_VALUE
                )

        return spec


def custom_collate_fn(batch):
    valid_batch = [item for item in batch if item is not None]

    if not valid_batch:
        print("Warning: Empty batch after filtering")
        return torch.empty(0, 1, 224, 224), torch.empty(0, dtype=torch.long)

    data_list, label_list = zip(*valid_batch)
    expected_shape = (1, 224, 224)
    valid_data = []
    valid_labels = []

    for data, label in zip(data_list, label_list):
        if isinstance(data, torch.Tensor) and data.shape == expected_shape:
            valid_data.append(data)
            valid_labels.append(label)
        else:
            print(
                f"Warning: Skipping invalid shape {data.shape if hasattr(data, 'shape') else type(data)}"
            )

    if not valid_data:
        print("Warning: No valid data in batch")
        return torch.empty(0, 1, 224, 224), torch.empty(0, dtype=torch.long)

    return torch.stack(valid_data, dim=0), torch.stack(valid_labels, dim=0)


### Training

In [6]:
def validate_dataset(dataset, name):
    invalid_count = 0
    invalid_files = []
    for idx in range(len(dataset)):
        row = dataset.metadata.iloc[idx]
        npy_path = os.path.join(dataset.cache_path, row["npy_path"])
        if not os.path.exists(npy_path):
            invalid_count += 1
            invalid_files.append(npy_path)
    if invalid_count > 0:
        print(f"Warning: {invalid_count} invalid samples found in {name} dataset")
        for f in invalid_files[:5]:
            print(f"  - Missing file: {f}")
        if len(invalid_files) > 5:
            print(f"  ... and {len(invalid_files) - 5} more")
    return invalid_count


def train_model(
    model, train_loader, val_loader, optimizer, criterion, device, num_epochs, run_name
):
    model.to(device)
    best_val_f1 = -1
    patience = 5
    patience_counter = 0
    warmup_epochs = 3

    # Đã cập nhật: Đảm bảo T_max được tính toán chính xác
    # Lịch trình Cosine Annealing cho phần sau của quá trình đào tạo
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=num_epochs - warmup_epochs, eta_min=1e-6
    )
    # Lịch trình Warmup cho các epoch đầu tiên
    warmup_scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer,
        lr_lambda=lambda epoch: (epoch + 1) / warmup_epochs
        if epoch < warmup_epochs
        else 1.0,
    )

    for epoch in range(num_epochs):
        model.train()
        total_loss = 0
        batch_count = 0

        pbar = tqdm(train_loader, desc=f"Epoch {epoch + 1}/{num_epochs}")
        for batch_idx, batch in enumerate(pbar):
            if batch is None or len(batch[0]) == 0:
                print(f"Warning: Skipping empty batch at index {batch_idx}")
                continue

            data, labels = batch
            data, labels = data.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(data)
            loss = criterion(outputs, labels)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            total_loss += loss.item()
            batch_count += 1

            pbar.set_postfix({"loss": f"{loss.item():.4f}"})

        if batch_count == 0:
            print(f"Error: No valid batches in epoch {epoch + 1}")
            continue

        if epoch < warmup_epochs:
            warmup_scheduler.step()
        else:
            scheduler.step()

        val_loss, val_preds, val_labels, val_probs = evaluate_model(
            model, val_loader, criterion, device
        )
        val_acc = accuracy_score(val_labels, val_preds)
        val_f1 = f1_score(val_labels, val_preds, average="binary")
        val_roc_auc = roc_auc_score(val_labels, val_probs[:, 1])

        print(
            f"Epoch {epoch + 1}: Train Loss: {total_loss / batch_count:.4f}, "
            f"Val Loss: {val_loss:.4f}, Val F1: {val_f1:.4f}"
        )

        wandb.log(
            {
                "epoch": epoch,
                "train_loss": total_loss / batch_count,
                "val_loss": val_loss,
                "val_f1": val_f1,
                "val_accuracy": val_acc,
                "val_roc_auc": val_roc_auc,
                "learning_rate": optimizer.param_groups[0]["lr"],
                "warmup_phase": epoch < warmup_epochs,
            }
        )

        if val_f1 > best_val_f1:
            best_val_f1 = val_f1
            patience_counter = 0
            model_save_path = f"best_model_{run_name}.pth"
            torch.save(
                {
                    "model_state_dict": model.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                    "scheduler_state_dict": scheduler.state_dict(),
                    "epoch": epoch,
                    "best_val_f1": best_val_f1,
                },
                model_save_path,
            )
            print(f"Saved best model with F1: {best_val_f1:.4f}")
        else:
            patience_counter += 1

        if patience_counter >= patience:
            print(f"Early stopping at epoch {epoch + 1}")
            break

    return model


def evaluate_model(model, loader, criterion, device, return_cm=False):
    model.eval()
    total_loss = 0
    all_preds = []
    all_labels = []
    all_probs = []

    with torch.no_grad():
        pbar = tqdm(loader, desc="Evaluating", leave=False)
        for batch in pbar:
            if batch is None or len(batch[0]) == 0:
                continue
            data, labels = batch
            if -1 in labels.cpu().numpy():
                continue

            data, labels = data.to(device), labels.to(device)
            outputs = model(data)
            loss = criterion(outputs, labels)
            total_loss += loss.item()

            probs = torch.softmax(outputs, dim=1)
            _, preds = torch.max(outputs, 1)

            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
            all_probs.extend(probs.cpu().numpy())

    if len(all_labels) < 2:
        print("Warning: Too few samples for reliable evaluation")
        return (
            float("inf"),
            [],
            [],
            np.array([]),
            np.zeros((2, 2)) if return_cm else None,
        )

    avg_loss = total_loss / len(loader) if len(loader) > 0 else 0.0
    if return_cm:
        cm = confusion_matrix(all_labels, all_preds)
        return avg_loss, all_preds, all_labels, np.array(all_probs), cm
    return avg_loss, all_preds, all_labels, np.array(all_probs)


def plot_confusion_matrix(cm, run_name, save_dir="results"):
    os.makedirs(save_dir, exist_ok=True)
    fig, ax = plt.subplots(figsize=(8, 6))
    sns.heatmap(
        cm,
        annot=True,
        fmt="d",
        cmap="Blues",
        xticklabels=["Real", "Fake"],
        yticklabels=["Real", "Fake"],
        ax=ax,
    )
    ax.set_xlabel("Predicted")
    ax.set_ylabel("True")
    ax.set_title(f"Confusion Matrix - {run_name}")
    cm_plot_path = os.path.join(save_dir, f"cm_{run_name}.png")
    fig.savefig(cm_plot_path, dpi=150, bbox_inches="tight")
    plt.close(fig)
    return cm_plot_path


def run_training(training_params): # training_params sẽ là dictionary chứa model_size, epochs, v.v.
    torch.manual_seed(Config.SEED) # Vẫn dùng Config.SEED chung
    np.random.seed(Config.SEED)

    model_size = training_params["model_size"]
    epochs = training_params["epochs"]
    learning_rate = training_params["learning_rate"]
    batch_size = training_params["batch_size"]
    num_workers = training_params["num_workers"]

    # Lấy cấu hình đầy đủ cho model_size cụ thể từ ALL_MODEL_CONFIGS
    if model_size not in ALL_MODEL_CONFIGS:
        print(f"Error: Model size '{model_size}' not found in ALL_MODEL_CONFIGS.")
        return

    config = ALL_MODEL_CONFIGS[model_size] # config bây giờ chứa tất cả params cho model đó

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Khởi tạo mô hình dựa trên các tham số từ 'config'
    model = ViT_Audio(
        img_size=config.img_size,
        patch_size=config.patch_size,
        num_classes=config.num_classes,
        in_channels=config.in_channels,
        dim=config.dim, # Lấy dim từ config
        depth=config.depth, # Lấy depth từ config
        heads=config.heads, # Lấy heads từ config
        mlp_dim=config.mlp_dim, # Lấy mlp_dim từ config
        dropout=config.dropout,
    )
    model = model.to(device)
    # summary(model, input_size=(1, 224, 224)) # Đảm bảo dòng này đã được comment hoặc thay thế
    param_count = sum(p.numel() for p in model.parameters())
    print(f"Configuring {model_size} model with {param_count:,} parameters...")

    # Sử dụng hàm get_full_cache_dir để tạo đường dẫn dataset
    model_cache_dir = config.get_full_cache_dir()
    print(f"Loading data from: {model_cache_dir}")

    # Khởi tạo dataset với 'config' hoàn chỉnh
    train_dataset = AudioDataset(model_cache_dir, "train", config.N_MELS, config)
    val_dataset = AudioDataset(model_cache_dir, "val", config.N_MELS, config)
    test_dataset = AudioDataset(model_cache_dir, "test", config.N_MELS, config)

    for dataset, name in [
        (train_dataset, "train"),
        (val_dataset, "val"),
        (test_dataset, "test"),
    ]:
        invalid_count = validate_dataset(dataset, name)
        if invalid_count == len(dataset):
            print(f"Error: All samples in {name} dataset are invalid")
            return

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True,
        collate_fn=custom_collate_fn,
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True,
        collate_fn=custom_collate_fn,
    )
    test_loader = DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True,
        collate_fn=custom_collate_fn,
    )

    print(f"Train samples: {len(train_dataset)}")
    print(f"Validation samples: {len(val_dataset)}")
    print(f"Test samples: {len(test_dataset)}")
    print(f"Using Batch size: {batch_size}")

    class_counts = np.bincount(
        [
            train_dataset[i][1].item()
            for i in range(len(train_dataset))
            if train_dataset[i] is not None
        ]
    )
    if 0 in class_counts:
        print(
            f"Error: Class {np.argwhere(class_counts == 0).flatten()} has no samples in training dataset"
        )
        return
    class_weights = torch.tensor(
        [1.0 / max(count, 1e-6) for count in class_counts], dtype=torch.float
    ).to(device)
    criterion = nn.CrossEntropyLoss(weight=class_weights)
    optimizer = optim.Adam(
        model.parameters(), lr=learning_rate, weight_decay=config.weight_decay
    )

    run_name = f"{model_size}_{config.dataset_name}_{datetime.now().strftime('%H%M%S')}"
    wandb.init(project="audio-deepfake-detection", name=run_name, config=training_params) # Logging training_params
    
    trained_model = train_model(
        model,
        train_loader,
        val_loader,
        optimizer,
        criterion,
        device,
        epochs,
        run_name,
    )

    print(f"\n--- Evaluating {model_size} on Test Set ({config.dataset_name}) ---")
    test_loss, test_preds, test_labels, test_probs, test_cm = evaluate_model(
        trained_model, test_loader, criterion, device, return_cm=True
    )

    test_acc = accuracy_score(test_labels, test_preds)
    test_f1 = f1_score(test_labels, test_preds, average="binary")
    test_roc_auc = roc_auc_score(test_labels, test_probs[:, 1])

    print(f"Test Loss: {test_loss:.4f}")
    print(f"Test Accuracy: {test_acc:.4f}")
    print(f"Test F1-score: {test_f1:.4f}")
    print(f"Test ROC AUC: {test_roc_auc:.4f}")

    cm_plot_path = plot_confusion_matrix(test_cm, run_name=run_name, save_dir="results")
    wandb.log(
        {
            "test_loss": test_loss,
            "test_accuracy": test_acc,
            "test_f1_score": test_f1,
            "test_roc_auc": test_roc_auc,
            "confusion_matrix": wandb.Image(cm_plot_path),
        }
    )
    wandb.finish()

    return trained_model


### Run Training

In [7]:
training_params_small = {
    "model_size": "ViT_Small", 
    "epochs": 20,
    "learning_rate": 1e-4,
    "batch_size": 32,
    "num_workers": 4,
}

training_params_large = {
    "model_size": "ViT_Large", 
    "epochs": 20,
    "learning_rate": 1e-4,
    "batch_size": 32,
    "num_workers": 4,
}


print("=== Training ViT_Small ===")
trained_model_small = run_training(training_params_small)

print("\n=== Training ViT_Large ===")
trained_model_large = run_training(training_params_large)

=== Training ViT_Small ===
Using device: cuda
Configuring ViT_Small model with 655,490 parameters...
Loading data from: /kaggle/input/vit-3s-dataset/vit_3s_dataset
Train samples: 102896
Validation samples: 6996
Test samples: 14066
Using Batch size: 32


[34m[1mwandb[0m: Tracking run with wandb version 0.19.9
[34m[1mwandb[0m: Run data is saved locally in [35m[1m/kaggle/working/wandb/run-20250609_040441-docydruq[0m
[34m[1mwandb[0m: Run [1m`wandb offline`[0m to turn off syncing.
[34m[1mwandb[0m: Syncing run [33mViT_Small_vit_3s_dataset_040441[0m
[34m[1mwandb[0m: ⭐️ View project at [34m[4mhttps://wandb.ai/nekloyh-none/audio-deepfake-detection[0m
[34m[1mwandb[0m: 🚀 View run at [34m[4mhttps://wandb.ai/nekloyh-none/audio-deepfake-detection/runs/docydruq[0m
Epoch 1/20: 100%|██████████| 3216/3216 [01:43<00:00, 30.93it/s, loss=0.2755]
                                                             

Epoch 1: Train Loss: 0.4096, Val Loss: 0.2894, Val F1: 0.8710
Saved best model with F1: 0.8710


Epoch 2/20: 100%|██████████| 3216/3216 [01:42<00:00, 31.34it/s, loss=0.2086]
                                                             

Epoch 2: Train Loss: 0.3297, Val Loss: 0.2893, Val F1: 0.8667


Epoch 3/20: 100%|██████████| 3216/3216 [01:41<00:00, 31.60it/s, loss=0.2333]
                                                             

Epoch 3: Train Loss: 0.2992, Val Loss: 0.2051, Val F1: 0.9032
Saved best model with F1: 0.9032


Epoch 4/20: 100%|██████████| 3216/3216 [01:41<00:00, 31.60it/s, loss=0.2517]
                                                             

Epoch 4: Train Loss: 0.2735, Val Loss: 0.1591, Val F1: 0.9334
Saved best model with F1: 0.9334


Epoch 5/20: 100%|██████████| 3216/3216 [01:41<00:00, 31.60it/s, loss=0.2979]
                                                             

Epoch 5: Train Loss: 0.2575, Val Loss: 0.1482, Val F1: 0.9356
Saved best model with F1: 0.9356


Epoch 6/20: 100%|██████████| 3216/3216 [01:41<00:00, 31.56it/s, loss=0.3246]
                                                             

Epoch 6: Train Loss: 0.2427, Val Loss: 0.1283, Val F1: 0.9458
Saved best model with F1: 0.9458


Epoch 7/20: 100%|██████████| 3216/3216 [01:42<00:00, 31.33it/s, loss=0.2582]
                                                             

Epoch 7: Train Loss: 0.2305, Val Loss: 0.1239, Val F1: 0.9543
Saved best model with F1: 0.9543


Epoch 8/20: 100%|██████████| 3216/3216 [01:42<00:00, 31.46it/s, loss=0.2188]
                                                             

Epoch 8: Train Loss: 0.2223, Val Loss: 0.1090, Val F1: 0.9595
Saved best model with F1: 0.9595


Epoch 9/20: 100%|██████████| 3216/3216 [01:42<00:00, 31.51it/s, loss=0.3748]
                                                             

Epoch 9: Train Loss: 0.2103, Val Loss: 0.1486, Val F1: 0.9406


Epoch 10/20: 100%|██████████| 3216/3216 [01:41<00:00, 31.53it/s, loss=0.0452]
                                                             

Epoch 10: Train Loss: 0.2059, Val Loss: 0.1118, Val F1: 0.9458


Epoch 11/20: 100%|██████████| 3216/3216 [01:42<00:00, 31.49it/s, loss=0.2826]
                                                             

Epoch 11: Train Loss: 0.1963, Val Loss: 0.0998, Val F1: 0.9622
Saved best model with F1: 0.9622


Epoch 12/20: 100%|██████████| 3216/3216 [01:42<00:00, 31.33it/s, loss=0.0654]
                                                             

Epoch 12: Train Loss: 0.1906, Val Loss: 0.0957, Val F1: 0.9622


Epoch 13/20: 100%|██████████| 3216/3216 [01:42<00:00, 31.43it/s, loss=0.2777]
                                                             

Epoch 13: Train Loss: 0.1849, Val Loss: 0.0910, Val F1: 0.9653
Saved best model with F1: 0.9653


Epoch 14/20: 100%|██████████| 3216/3216 [01:42<00:00, 31.30it/s, loss=0.0996]
                                                             

Epoch 14: Train Loss: 0.1792, Val Loss: 0.0858, Val F1: 0.9636


Epoch 15/20: 100%|██████████| 3216/3216 [01:42<00:00, 31.47it/s, loss=0.2421]
                                                             

Epoch 15: Train Loss: 0.1734, Val Loss: 0.0806, Val F1: 0.9671
Saved best model with F1: 0.9671


Epoch 16/20: 100%|██████████| 3216/3216 [01:42<00:00, 31.45it/s, loss=0.0800]
                                                             

Epoch 16: Train Loss: 0.1696, Val Loss: 0.0805, Val F1: 0.9693
Saved best model with F1: 0.9693


Epoch 17/20: 100%|██████████| 3216/3216 [01:42<00:00, 31.36it/s, loss=0.1161]
                                                             

Epoch 17: Train Loss: 0.1681, Val Loss: 0.0792, Val F1: 0.9659


Epoch 18/20: 100%|██████████| 3216/3216 [01:42<00:00, 31.33it/s, loss=0.6537]
                                                             

Epoch 18: Train Loss: 0.1617, Val Loss: 0.0757, Val F1: 0.9686


Epoch 19/20: 100%|██████████| 3216/3216 [01:42<00:00, 31.42it/s, loss=0.0990]
                                                             

Epoch 19: Train Loss: 0.1616, Val Loss: 0.0785, Val F1: 0.9662


Epoch 20/20: 100%|██████████| 3216/3216 [01:42<00:00, 31.33it/s, loss=0.1252]
                                                             

Epoch 20: Train Loss: 0.1599, Val Loss: 0.0734, Val F1: 0.9702
Saved best model with F1: 0.9702

--- Evaluating ViT_Small on Test Set (vit_3s_dataset) ---


                                                             

Test Loss: 0.0733
Test Accuracy: 0.9698
Test F1-score: 0.9696
Test ROC AUC: 0.9969


[34m[1mwandb[0m:                                                                                
[34m[1mwandb[0m: 
[34m[1mwandb[0m: Run history:
[34m[1mwandb[0m:         epoch ▁▁▂▂▂▃▃▄▄▄▅▅▅▆▆▇▇▇██
[34m[1mwandb[0m: learning_rate ▆████▇▇▇▆▅▅▄▄▃▂▂▂▁▁▁
[34m[1mwandb[0m: test_accuracy ▁
[34m[1mwandb[0m: test_f1_score ▁
[34m[1mwandb[0m:     test_loss ▁
[34m[1mwandb[0m:  test_roc_auc ▁
[34m[1mwandb[0m:    train_loss █▆▅▄▄▃▃▃▂▂▂▂▂▂▁▁▁▁▁▁
[34m[1mwandb[0m:  val_accuracy ▁▁▄▆▆▇▇▇▆▇██████████
[34m[1mwandb[0m:        val_f1 ▁▁▃▆▆▆▇▇▆▆▇▇████████
[34m[1mwandb[0m:      val_loss ██▅▄▃▃▃▂▃▂▂▂▂▁▁▁▁▁▁▁
[34m[1mwandb[0m:   val_roc_auc ▁▅▅▆▇▇▇▇▇███████████
[34m[1mwandb[0m: 
[34m[1mwandb[0m: Run summary:
[34m[1mwandb[0m:         epoch 19
[34m[1mwandb[0m: learning_rate 0.0
[34m[1mwandb[0m: test_accuracy 0.96979
[34m[1mwandb[0m: test_f1_score 0.96965
[34m[1mwandb[0m:     test_loss 0.07328
[34m[1mwandb[0m:  test_roc_auc 0.99691
[34m[1mwandb[0m: 


=== Training ViT_Large ===
Using device: cuda
Configuring ViT_Large model with 7,347,330 parameters...
Loading data from: /kaggle/input/vit-3s-dataset/vit_3s_dataset
Train samples: 102896
Validation samples: 6996
Test samples: 14066
Using Batch size: 32


[34m[1mwandb[0m: Tracking run with wandb version 0.19.9
[34m[1mwandb[0m: Run data is saved locally in [35m[1m/kaggle/working/wandb/run-20250609_044740-5hq2eoul[0m
[34m[1mwandb[0m: Run [1m`wandb offline`[0m to turn off syncing.
[34m[1mwandb[0m: Syncing run [33mViT_Large_vit_3s_dataset_044740[0m
[34m[1mwandb[0m: ⭐️ View project at [34m[4mhttps://wandb.ai/nekloyh-none/audio-deepfake-detection[0m
[34m[1mwandb[0m: 🚀 View run at [34m[4mhttps://wandb.ai/nekloyh-none/audio-deepfake-detection/runs/5hq2eoul[0m
Epoch 1/20: 100%|██████████| 3216/3216 [08:40<00:00,  6.18it/s, loss=0.3738]
                                                             

Epoch 1: Train Loss: 0.3577, Val Loss: 0.1940, Val F1: 0.9163
Saved best model with F1: 0.9163


Epoch 2/20: 100%|██████████| 3216/3216 [08:40<00:00,  6.18it/s, loss=0.0524]
                                                             

Epoch 2: Train Loss: 0.2861, Val Loss: 0.1454, Val F1: 0.9371
Saved best model with F1: 0.9371


Epoch 3/20: 100%|██████████| 3216/3216 [08:39<00:00,  6.19it/s, loss=0.2782]
                                                             

Epoch 3: Train Loss: 0.2654, Val Loss: 0.1493, Val F1: 0.9325


Epoch 4/20: 100%|██████████| 3216/3216 [08:37<00:00,  6.22it/s, loss=0.1506]
                                                             

Epoch 4: Train Loss: 0.2400, Val Loss: 0.1214, Val F1: 0.9493
Saved best model with F1: 0.9493


Epoch 5/20: 100%|██████████| 3216/3216 [08:34<00:00,  6.26it/s, loss=0.2821]
                                                             

Epoch 5: Train Loss: 0.2250, Val Loss: 0.1231, Val F1: 0.9522
Saved best model with F1: 0.9522


Epoch 6/20: 100%|██████████| 3216/3216 [08:30<00:00,  6.30it/s, loss=0.0994]
                                                             

Epoch 6: Train Loss: 0.2118, Val Loss: 0.1041, Val F1: 0.9609
Saved best model with F1: 0.9609


Epoch 7/20: 100%|██████████| 3216/3216 [08:27<00:00,  6.34it/s, loss=0.2316]
                                                             

Epoch 7: Train Loss: 0.1982, Val Loss: 0.1011, Val F1: 0.9558


Epoch 8/20: 100%|██████████| 3216/3216 [08:29<00:00,  6.31it/s, loss=0.0324]
                                                             

Epoch 8: Train Loss: 0.1880, Val Loss: 0.0963, Val F1: 0.9593


Epoch 9/20: 100%|██████████| 3216/3216 [08:27<00:00,  6.34it/s, loss=0.2732]
                                                             

Epoch 9: Train Loss: 0.1790, Val Loss: 0.1121, Val F1: 0.9506


Epoch 10/20: 100%|██████████| 3216/3216 [08:27<00:00,  6.34it/s, loss=0.0460]
                                                             

Epoch 10: Train Loss: 0.1694, Val Loss: 0.0975, Val F1: 0.9602


Epoch 11/20: 100%|██████████| 3216/3216 [08:26<00:00,  6.35it/s, loss=0.0954]
                                                             

Epoch 11: Train Loss: 0.1611, Val Loss: 0.0858, Val F1: 0.9647
Saved best model with F1: 0.9647


Epoch 12/20: 100%|██████████| 3216/3216 [08:28<00:00,  6.33it/s, loss=0.3018]
                                                             

Epoch 12: Train Loss: 0.1548, Val Loss: 0.0663, Val F1: 0.9733
Saved best model with F1: 0.9733


Epoch 13/20: 100%|██████████| 3216/3216 [08:25<00:00,  6.36it/s, loss=0.2208]
                                                             

Epoch 13: Train Loss: 0.1484, Val Loss: 0.0732, Val F1: 0.9706


Epoch 14/20: 100%|██████████| 3216/3216 [08:26<00:00,  6.34it/s, loss=0.1989]
                                                             

Epoch 14: Train Loss: 0.1395, Val Loss: 0.0633, Val F1: 0.9733
Saved best model with F1: 0.9733


Epoch 15/20: 100%|██████████| 3216/3216 [08:26<00:00,  6.35it/s, loss=0.1457]
                                                             

Epoch 15: Train Loss: 0.1338, Val Loss: 0.0569, Val F1: 0.9763
Saved best model with F1: 0.9763


Epoch 16/20: 100%|██████████| 3216/3216 [08:26<00:00,  6.35it/s, loss=0.1722]
                                                             

Epoch 16: Train Loss: 0.1281, Val Loss: 0.0494, Val F1: 0.9806
Saved best model with F1: 0.9806


Epoch 17/20: 100%|██████████| 3216/3216 [08:26<00:00,  6.35it/s, loss=0.0465]
                                                             

Epoch 17: Train Loss: 0.1234, Val Loss: 0.0583, Val F1: 0.9757


Epoch 18/20: 100%|██████████| 3216/3216 [08:25<00:00,  6.36it/s, loss=0.0296]
                                                             

Epoch 18: Train Loss: 0.1183, Val Loss: 0.0549, Val F1: 0.9770


Epoch 19/20: 100%|██████████| 3216/3216 [08:26<00:00,  6.35it/s, loss=0.1006]
                                                             

Epoch 19: Train Loss: 0.1156, Val Loss: 0.0479, Val F1: 0.9812
Saved best model with F1: 0.9812


Epoch 20/20: 100%|██████████| 3216/3216 [08:26<00:00,  6.35it/s, loss=0.0319]
                                                             

Epoch 20: Train Loss: 0.1124, Val Loss: 0.0493, Val F1: 0.9794

--- Evaluating ViT_Large on Test Set (vit_3s_dataset) ---


                                                             

Test Loss: 0.0529
Test Accuracy: 0.9775
Test F1-score: 0.9776
Test ROC AUC: 0.9983


[34m[1mwandb[0m:                                                                                
[34m[1mwandb[0m: 
[34m[1mwandb[0m: Run history:
[34m[1mwandb[0m:         epoch ▁▁▂▂▂▃▃▄▄▄▅▅▅▆▆▇▇▇██
[34m[1mwandb[0m: learning_rate ▆████▇▇▇▆▅▅▄▄▃▂▂▂▁▁▁
[34m[1mwandb[0m: test_accuracy ▁
[34m[1mwandb[0m: test_f1_score ▁
[34m[1mwandb[0m:     test_loss ▁
[34m[1mwandb[0m:  test_roc_auc ▁
[34m[1mwandb[0m:    train_loss █▆▅▅▄▄▃▃▃▃▂▂▂▂▂▁▁▁▁▁
[34m[1mwandb[0m:  val_accuracy ▁▄▃▅▅▆▆▆▅▆▆▇▇▇██▇███
[34m[1mwandb[0m:        val_f1 ▁▃▃▅▅▆▅▆▅▆▆▇▇▇▇█▇███
[34m[1mwandb[0m:      val_loss █▆▆▅▅▄▄▃▄▃▃▂▂▂▁▁▁▁▁▁
[34m[1mwandb[0m:   val_roc_auc ▁▃▄▅▆▆▆▆▇▇▇█▇███████
[34m[1mwandb[0m: 
[34m[1mwandb[0m: Run summary:
[34m[1mwandb[0m:         epoch 19
[34m[1mwandb[0m: learning_rate 0.0
[34m[1mwandb[0m: test_accuracy 0.97753
[34m[1mwandb[0m: test_f1_score 0.97756
[34m[1mwandb[0m:     test_loss 0.05292
[34m[1mwandb[0m:  test_roc_auc 0.99835
[34m[1mwandb[0m: 