Link to Colab notebook: https://colab.research.google.com/drive/1YT3M4-fIakqLP0ITYTusXDcZ1BzNvwpu?usp=sharing

All code was run locally on one team member's device

Due to time and GPU usage constraints, code was not run on Colab. For more information on results of the code (i.e. print outputs), visit this link to view the separate Jupyter notebooks used to run the project: https://drive.google.com/drive/folders/1qKlqJc7a75tm3EPjvDzx8BFxgwBB1Wkk?usp=sharing

# Data preprocessing

## Downloading SONICS dataset

Downloaded files are zipped, which are then unzipped and stored locally on a home device

In [None]:
import os
from huggingface_hub import snapshot_download

In [None]:
# Fetch the SONICS dataset repository from the Hugging Face Hub into
# "<current working directory>/datasets".
datasets_dir = os.getcwd() + '/datasets'
test = snapshot_download(
    local_dir=datasets_dir,
    repo_type="dataset",
    repo_id="awsaf49/sonics",
)

## Downloading real music clips

The SONICS dataset comes with a real_songs.csv file, which contains the YouTube IDs of real songs. The yt-dlp `YoutubeDL` downloader is used here to download the audio of these videos and store it locally as mp3 files

In [None]:
import os
import pandas as pd
from yt_dlp import YoutubeDL
from concurrent.futures import ThreadPoolExecutor, as_completed

In [None]:
def download_single_task(args):
    """
    Download the audio of one YouTube video and convert it to MP3.

    `args` is a single tuple (youtube_id, output_dir, bitrate) so that the
    function can be handed to an executor as a one-argument task.

    Returns (youtube_id, success: bool, message: str). Exceptions raised
    during the download are not propagated; they are reported as
    success=False with the exception text as the message.
    """
    youtube_id, output_dir, bitrate = args

    # yt-dlp fills in %(ext)s; after the FFmpeg post-processor runs, the
    # file we expect to find is "<youtube_id>.mp3".
    expected_mp3 = os.path.join(output_dir, f"{youtube_id}.mp3")
    video_url = f"https://www.youtube.com/watch?v={youtube_id}"
    extract_audio = {
        'key': 'FFmpegExtractAudio',
        'preferredcodec': 'mp3',
        'preferredquality': str(bitrate),
    }
    options = {
        'format': 'bestaudio/best',
        'postprocessors': [extract_audio],
        'noplaylist': True,
        'outtmpl': os.path.join(output_dir, f"{youtube_id}.%(ext)s"),
    }

    try:
        with YoutubeDL(options) as downloader:
            downloader.download([video_url])
        if not os.path.isfile(expected_mp3):
            return youtube_id, False, "Missing file after download"
        return youtube_id, True, f"Downloaded: {expected_mp3}"
    except Exception as err:
        return youtube_id, False, str(err)


def download_and_convert(
    csv_path: str,
    output_dir: str,
    target_count: int = 10000,
    bitrate: int = 64,
    workers: int = 4
):
    """
    Download audio for the YouTube IDs listed in a CSV until `target_count`
    downloads have succeeded (or the IDs run out).

    Args:
        csv_path: CSV file with a 'youtube_id' column.
        output_dir: Directory for the resulting .mp3 files (created if missing).
        target_count: Number of successful downloads to aim for.
        bitrate: MP3 bitrate forwarded to download_single_task.
        workers: Number of worker threads.

    Returns:
        The number of successful downloads.

    Raises:
        ValueError: if the CSV has no 'youtube_id' column.
    """
    from concurrent.futures import FIRST_COMPLETED, wait

    os.makedirs(output_dir, exist_ok=True)
    df = pd.read_csv(csv_path)
    if 'youtube_id' not in df.columns:
        raise ValueError("CSV must contain a 'youtube_id' column")

    # dropna() first: str(NaN) is the truthy string 'nan', which would
    # otherwise be submitted as if it were a video ID.
    stripped = (str(y).strip() for y in df['youtube_id'].dropna())
    # dict.fromkeys removes duplicate IDs while keeping CSV order.
    ids = list(dict.fromkeys(i for i in stripped if i))
    print(f"Starting parallel download: target={target_count}, workers={workers}, total IDs available={len(ids)}")

    success_count = 0
    remaining_ids = iter(ids)
    in_flight = {}  # future -> youtube_id

    with ThreadPoolExecutor(max_workers=workers) as executor:
        while True:
            # Never keep more downloads running than successes still needed,
            # so the run cannot overshoot target_count.
            limit = min(workers, target_count - success_count)
            while len(in_flight) < limit:
                next_id = next(remaining_ids, None)
                if next_id is None:
                    break
                future = executor.submit(download_single_task, (next_id, output_dir, bitrate))
                in_flight[future] = next_id

            if not in_flight:
                # Either the target was reached or every ID has been tried.
                break

            finished, _ = wait(in_flight, return_when=FIRST_COMPLETED)
            for future in finished:
                youtube_id = in_flight.pop(future)
                success = False
                try:
                    _, success, msg = future.result()
                except Exception as e:
                    msg = str(e)
                if success:
                    success_count += 1
                print(f"[{success_count}/{target_count}] {youtube_id}: {msg}")

    print(f"Finished: {success_count} successful downloads (target was {target_count})")
    return success_count

In [None]:
# Build paths with os.path.join so the cell works on any OS, not only with
# Windows "\\" separators.
download_and_convert(
    csv_path=os.path.join(os.getcwd(), "datasets", "real_songs.csv"),
    output_dir=os.path.join(os.getcwd(), "datasets", "real_songs"),
    target_count=10000,
    workers=12
)

## Converting mp3 files to mel-spectrograms

The mp3 files are then converted into mel-spectrograms of dimension 256x256 (with 3 channels) with the help of the librosa library

In [None]:
import os
import glob
import librosa
import numpy as np
import torch

In [None]:
def process_melspectrograms(
    input_dir: str,
    output_dir: str,
    sample_rate: int = 22050,
    n_mels: int = 256,
    hop_length: int = 512,
    fmin: int = 0,
    fmax: int = None,
    top_db: int = 80,
    target_frames: int = 256,
    num_channels: int = 3,
) -> None:
    """
    Convert all .mp3 files in input_dir to normalized Mel-spectrogram
    tensors of shape (num_channels, n_mels, target_frames) and save each
    one as "<name>_mel.pt" in output_dir.

    - The time axis is cropped to the first `target_frames` frames, or
      right-padded with the minimum dB value when the clip is shorter.
    - Values are scaled to [0, 1], where 1 is the loudest bin of the file
      and 0 is `top_db` decibels below it.
    - num_channels=3 replicates the single grayscale channel three times.

    Files that cannot be decoded are reported and skipped.

    Raises:
        ValueError: if num_channels is neither 1 nor 3.
    """
    # Fail before any audio is decoded rather than after the first file.
    if num_channels not in (1, 3):
        raise ValueError(f"Unsupported num_channels={num_channels}")

    os.makedirs(output_dir, exist_ok=True)
    mp3_paths = sorted(glob.glob(os.path.join(input_dir, "*.mp3")))
    saved = 0

    for mp3_path in mp3_paths:
        fname = os.path.splitext(os.path.basename(mp3_path))[0]

        # 1-2. Load audio and compute the Mel-spectrogram. Decoder failures
        # vary by audio backend, hence the broad catch; one bad file should
        # not abort the whole batch.
        try:
            y, sr = librosa.load(mp3_path, sr=sample_rate)
            S = librosa.feature.melspectrogram(
                y=y, sr=sr,
                n_mels=n_mels,
                hop_length=hop_length,
                fmin=fmin,
                fmax=fmax
            )
        except Exception as exc:
            print(f"Skipping {mp3_path}: {exc}")
            continue

        # 3. Convert to dB relative to the file's peak. With ref=np.max the
        # result lies in [-top_db, 0], which is what the padding value and
        # the normalization below assume; a fixed ref=1.0 lets values exceed
        # 0 dB and they would all be clipped to 1.
        S_db = librosa.power_to_db(S, ref=np.max, top_db=top_db)

        # 4. Crop or pad time axis to target_frames
        S_db = S_db[:, :target_frames]
        pad_amt = target_frames - S_db.shape[1]
        if pad_amt > 0:
            S_db = np.pad(
                S_db,
                pad_width=((0, 0), (0, pad_amt)),
                mode="constant",
                constant_values=-top_db
            )

        # 5. Normalize to [0, 1]
        S_norm = np.clip((S_db + top_db) / top_db, 0.0, 1.0).astype(np.float32)

        # 6. To torch tensor, grayscale: (1, n_mels, target_frames)
        tensor = torch.from_numpy(S_norm).unsqueeze(0)

        # 7. Expand to 3 channels by replication when requested
        if num_channels == 3:
            tensor = tensor.repeat(3, 1, 1)

        # 8. Save
        out_path = os.path.join(output_dir, f"{fname}_mel.pt")
        torch.save(tensor, out_path)
        saved += 1

    print(f"Processed {saved}/{len(mp3_paths)} files → {output_dir} (shape: {num_channels},{n_mels},{target_frames})")

In [None]:
# Process fake songs to mel-spectrograms
# (os.path.join keeps the paths valid on any OS, not only Windows)
process_melspectrograms(
    input_dir=os.path.join(os.getcwd(), "datasets", "fake_songs"),
    output_dir=os.path.join(os.getcwd(), "datasets", "fake_songs_mel")
)

# Process real songs to mel-spectrograms
process_melspectrograms(
    input_dir=os.path.join(os.getcwd(), "datasets", "real_songs"),
    output_dir=os.path.join(os.getcwd(), "datasets", "real_songs_mel")
)

One example from each of the real and fake music datasets is plotted below

In [None]:
import matplotlib.pyplot as plt

# Load a Mel-spectrogram (path built with os.path.join so it works on any OS;
# weights_only=True because the file holds a plain tensor)
mel_tensor = torch.load(
    os.path.join(os.getcwd(), "datasets", "fake_songs_mel", "fake_00001_suno_0_mel.pt"),
    weights_only=True
)
mel_np = mel_tensor.numpy()  # shape: (channels, mel_bins, time) as saved by process_melspectrograms

# A single-channel tensor (1, H, W) is plotted as a 2-D image; imshow does
# not accept an (H, W, 1) array.
if mel_np.ndim == 3 and mel_np.shape[0] == 1:
    mel_np = mel_np[0]

# Plot it
plt.figure(figsize=(10, 4))
if mel_np.ndim == 2:
    # single-channel: use a colormap and show the value scale
    plt.imshow(mel_np, aspect="auto", origin="lower", cmap="magma")
    plt.colorbar(label="dB (normalized)")
else:
    # multi-channel RGB: move channels last, shape -> (H, W, C), no colormap
    mel_img = mel_np.transpose(1, 2, 0)
    plt.imshow(mel_img, aspect="auto", origin="lower")
plt.xlabel("Time frames")
plt.ylabel("Mel bins")
plt.title("Mel-Spectrogram")
plt.tight_layout()
plt.show()

In [None]:
import matplotlib.pyplot as plt

# Load a Mel-spectrogram (path built with os.path.join so it works on any OS;
# weights_only=True because the file holds a plain tensor)
mel_tensor = torch.load(
    os.path.join(os.getcwd(), "datasets", "real_songs_mel", "_4-0dnvknOY_mel.pt"),
    weights_only=True
)
mel_np = mel_tensor.numpy()  # shape: (channels, mel_bins, time) as saved by process_melspectrograms

# A single-channel tensor (1, H, W) is plotted as a 2-D image; imshow does
# not accept an (H, W, 1) array.
if mel_np.ndim == 3 and mel_np.shape[0] == 1:
    mel_np = mel_np[0]

# Plot it
plt.figure(figsize=(10, 4))
if mel_np.ndim == 2:
    # single-channel: use a colormap and show the value scale
    plt.imshow(mel_np, aspect="auto", origin="lower", cmap="magma")
    plt.colorbar(label="dB (normalized)")
else:
    # multi-channel RGB: move channels last, shape -> (H, W, C), no colormap
    mel_img = mel_np.transpose(1, 2, 0)
    plt.imshow(mel_img, aspect="auto", origin="lower")
plt.xlabel("Time frames")
plt.ylabel("Mel bins")
plt.title("Mel-Spectrogram")
plt.tight_layout()
plt.show()

## Creating train/val/test split

As a new dataset of 10,000 real and 10,000 fake songs was engineered (due to runtime constraints), we created a 60:20:20 train/val/test split

In [None]:
import os
import pandas as pd
from sklearn.model_selection import train_test_split
from pathlib import Path

In [None]:
def create_split_csvs(
    real_dir,
    fake_dir,
    output_dir,
    train_ratio: float = 0.6,
    val_ratio: float = 0.2,
    test_ratio: float = 0.2,
    random_state: int = 42
):
    """
    Generate train/val/test CSV files listing mel-spectrogram .pt file paths and labels.

    The split is stratified by label and reproducible for a fixed
    random_state. Each CSV has the columns 'file_path' and 'label'.

    Args:
        real_dir: Directory containing real mel .pt files (label=0).
        fake_dir: Directory containing fake mel .pt files (label=1).
        output_dir: Directory to save train.csv, val.csv, and test.csv.
        train_ratio: Fraction of samples for training.
        val_ratio: Fraction of samples for validation.
        test_ratio: Fraction of samples for testing.
        random_state: Seed passed to both train_test_split calls.

    Raises:
        ValueError: if the ratios are not positive or do not sum to 1, or
            if either input directory contains no .pt files.
    """
    # train_ratio only enters the computation through relative_val_size, so
    # inconsistent ratios would silently produce a different split than asked.
    ratios = (train_ratio, val_ratio, test_ratio)
    if any(r <= 0 for r in ratios) or abs(sum(ratios) - 1.0) > 1e-6:
        raise ValueError(
            f"train/val/test ratios must be positive and sum to 1, got {ratios}"
        )

    real_dir = Path(real_dir)
    fake_dir = Path(fake_dir)
    output_dir = Path(output_dir)

    # Sorted so the input order (and therefore the split) is deterministic.
    real_paths = sorted(real_dir.glob('*.pt'))
    fake_paths = sorted(fake_dir.glob('*.pt'))
    for label_name, directory, found in (
        ('real', real_dir, real_paths),
        ('fake', fake_dir, fake_paths),
    ):
        if not found:
            raise ValueError(f"No .pt files found for {label_name} class in {directory}")

    output_dir.mkdir(parents=True, exist_ok=True)

    # Store forward-slash paths so the CSVs read the same on every OS.
    all_paths = [p.as_posix() for p in real_paths + fake_paths]
    all_labels = [0] * len(real_paths) + [1] * len(fake_paths)

    # First split off test set
    train_val_paths, test_paths, train_val_labels, test_labels = train_test_split(
        all_paths,
        all_labels,
        test_size=test_ratio,
        random_state=random_state,
        stratify=all_labels
    )

    # Then split train and validation; the validation share is expressed
    # relative to what is left after removing the test set.
    relative_val_size = val_ratio / (train_ratio + val_ratio)
    train_paths, val_paths, train_labels, val_labels = train_test_split(
        train_val_paths,
        train_val_labels,
        test_size=relative_val_size,
        random_state=random_state,
        stratify=train_val_labels
    )

    # Create DataFrames
    train_df = pd.DataFrame({'file_path': train_paths, 'label': train_labels})
    val_df   = pd.DataFrame({'file_path': val_paths,   'label': val_labels})
    test_df  = pd.DataFrame({'file_path': test_paths,  'label': test_labels})

    # Save CSVs
    train_df.to_csv(output_dir / 'train.csv', index=False)
    val_df.to_csv(output_dir / 'val.csv',   index=False)
    test_df.to_csv(output_dir / 'test.csv', index=False)

    print(f"Saved splits to {output_dir}:")
    print(f"  train.csv: {len(train_df)} samples")
    print(f"  val.csv:   {len(val_df)} samples")
    print(f"  test.csv:  {len(test_df)} samples")

In [None]:
# Build the split CSVs from the two mel-spectrogram folders and write them
# next to those folders, directly under ./datasets.
base_dir = Path.cwd().joinpath('datasets')
create_split_csvs(
    output_dir=base_dir,
    real_dir=base_dir.joinpath('real_songs_mel'),
    fake_dir=base_dir.joinpath('fake_songs_mel'),
)

# Deep learning boilerplate code

Defining classes and functions essential for the training pipeline later on for both models

In [None]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import timm
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm

## MelCSVDataset

MelCSVDataset is a `Dataset` subclass that loads one spectrogram/label pair per CSV row, so that a `DataLoader` can assemble mini-batches from it later on

In [None]:
class MelCSVDataset(Dataset):
    """
    Dataset of mel-spectrogram tensors (.pt files) indexed by a CSV file.

    The CSV must provide a 'file_path' column (location of each .pt file)
    and a 'label' column (converted to int).

    Args:
      csv_path: path to the CSV listing the samples
      img_size:  side length the spectrogram is bilinearly resized to (HxW)
      num_channels: requested number of output channels
      unit_variance: if True, standardize each sample as (x - mean)/std;
                     otherwise rescale each sample with min-max to [0,1]
      eps: small constant added to the denominator to avoid div-by-zero
    """
    def __init__(
        self,
        csv_path: str,
        img_size: int = 256,
        num_channels: int = 3,
        unit_variance: bool = False,
        eps: float = 1e-6
    ):
        table = pd.read_csv(csv_path)
        self.paths = table['file_path'].tolist()
        self.labels = table['label'].astype(int).tolist()
        self.img_size, self.num_channels = img_size, num_channels
        self.unit_variance, self.eps = unit_variance, eps

    def __len__(self):
        return len(self.paths)

    def _match_channels(self, mel):
        """Replicate a 1-channel input to 3 channels, or drop surplus channels."""
        n_in = mel.size(0)
        if n_in == 1 and self.num_channels == 3:
            return mel.repeat(3, 1, 1)
        if n_in > self.num_channels:
            return mel[:self.num_channels]
        return mel

    def _normalize(self, mel):
        """Per-sample standardization or min-max scaling, depending on unit_variance."""
        if self.unit_variance:
            return (mel - mel.mean()) / (mel.std() + self.eps)
        lo, hi = mel.min(), mel.max()
        return (mel - lo) / (hi - lo + self.eps)

    def __getitem__(self, idx):
        # stored tensors are indexed as (C, H, W) below
        raw = torch.load(self.paths[idx], map_location='cpu', weights_only=True)
        target = self.labels[idx]

        # F.interpolate expects a batch dimension, so add one and remove it again
        batched = self._match_channels(raw).unsqueeze(0)
        resized = F.interpolate(
            batched,
            size=(self.img_size, self.img_size),
            mode='bilinear',
            align_corners=False
        )
        sample = self._normalize(resized.squeeze(0))

        return sample, torch.tensor(target, dtype=torch.long)

## train_one_epoch / validate_one_epoch

Boilerplate code for the training and validation of one epoch, with losses and accuracy computed and printed so progress can be monitored

In [None]:
# Training and validation loops
def train_one_epoch(model, loader, optimizer, criterion, device):
    """
    Run one optimization pass over `loader`.

    Puts the model in training mode, performs a forward/backward/step cycle
    per batch, prints the batch loss every 20 batches, and returns the mean
    of the per-batch losses.
    """
    model.train()
    running_loss = 0.0
    for step, (inputs, targets) in enumerate(tqdm(loader, desc='Train batches')):
        inputs = inputs.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        # clear gradients left over from the previous step, then update
        optimizer.zero_grad()
        batch_loss = criterion(model(inputs), targets)
        batch_loss.backward()
        optimizer.step()

        loss_value = batch_loss.item()
        running_loss += loss_value
        if step % 20 == 0:
            print(f"  Batch {step}/{len(loader)} — loss: {loss_value:.4f}")
    return running_loss / len(loader)

def validate_one_epoch(model, loader, criterion, device):
    """
    Evaluate the model on `loader` without computing gradients.

    Prints and returns (mean of per-batch losses, accuracy), where accuracy
    is the fraction of samples whose arg-max logit equals the label.
    """
    model.eval()
    loss_sum, n_correct, n_seen = 0.0, 0, 0
    with torch.no_grad():
        for inputs, targets in tqdm(loader, desc='Val batches'):
            inputs = inputs.to(device, non_blocking=True)
            targets = targets.to(device, non_blocking=True)
            outputs = model(inputs)
            loss_sum += criterion(outputs, targets).item()
            hits = outputs.argmax(dim=1) == targets
            n_correct += hits.sum().item()
            n_seen += targets.size(0)
    avg_loss = loss_sum / len(loader)
    acc = n_correct / n_seen
    print(f"Validation Loss: {avg_loss:.4f}, Accuracy: {acc:.4f}")
    return avg_loss, acc

# CNN

Model: 'lambda_resnet26rpt_256.c1_in1k'

In [None]:
# Main entry point
def CNN(
    train_csv: str,
    val_csv: str,
    batch_size: int = 32,
    workers: int = 0,
    epochs: int = 5,
    lr: float = 1e-5,
    model_name: str = 'lambda_resnet26rpt_256.c1_in1k'
):
    """
    Fine-tune a pretrained timm model as a 2-class classifier on the
    mel-spectrogram splits listed in train_csv / val_csv.

    Side effects (all relative to the current working directory):
      - logs/losses_lambda.csv: one row per epoch (appended across runs)
      - logs/loss_curve_lambda.png, logs/acc_curve_lambda.png: curves for
        the epochs trained in this call
      - checkpoints_lambda/resume.pt: model + optimizer state after every
        epoch; if present, training resumes after the stored epoch
      - checkpoints_lambda/best_model.pt: weights with the lowest
        validation loss seen so far

    Args:
        train_csv: CSV with 'file_path' and 'label' columns for training.
        val_csv: CSV with the same columns for validation.
        batch_size: Mini-batch size for both loaders.
        workers: DataLoader worker processes.
        epochs: Total number of epochs (including already-completed ones
            when resuming).
        lr: AdamW learning rate.
        model_name: timm model identifier.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Pinned host memory only helps host-to-GPU copies; skip it on CPU.
    pin_memory = device.type == 'cuda'

    # — Dataset & Dataloaders —
    train_loader = DataLoader(
        MelCSVDataset(train_csv, unit_variance=True), batch_size=batch_size,
        shuffle=True, num_workers=workers, pin_memory=pin_memory
    )
    val_loader = DataLoader(
        MelCSVDataset(val_csv, unit_variance=True), batch_size=batch_size,
        shuffle=False, num_workers=workers, pin_memory=pin_memory
    )

    # — Model, optimizer, criterion —
    model = timm.create_model(
        model_name,
        pretrained=True,
        num_classes=2,
        in_chans=3
    ).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    # — Logging setup —
    os.makedirs('logs', exist_ok=True)
    log_csv = os.path.join('logs', 'losses_lambda.csv')
    # Decide on the header from the log file itself, not from the checkpoint:
    # the two can get out of sync (e.g. a deleted checkpoint or a run that
    # stopped before its first checkpoint).
    if not os.path.exists(log_csv) or os.path.getsize(log_csv) == 0:
        with open(log_csv, 'w') as f:
            f.write('epoch,train_loss,val_loss,val_acc\n')

    # — Checkpoint setup —
    ckpt_dir        = 'checkpoints_lambda'
    os.makedirs(ckpt_dir, exist_ok=True)
    resume_path     = os.path.join(ckpt_dir, 'resume.pt')
    best_model_path = os.path.join(ckpt_dir, 'best_model.pt')

    start_epoch = 1
    best_val_loss = float('inf')

    # — Resume if possible —
    if os.path.exists(resume_path):
        ckpt = torch.load(resume_path, map_location=device)
        model.load_state_dict(ckpt['model_state'])
        optimizer.load_state_dict(ckpt['optim_state'])
        start_epoch = ckpt['epoch'] + 1
        best_val_loss = ckpt['best_val_loss']
        print(f"Resumed from epoch {ckpt['epoch']} with best_val_loss = {best_val_loss:.4f}")

    # — Storage for curves (only the epochs run in this call) —
    train_losses, val_losses, val_accs = [], [], []

    # — Training loop with checkpointing —
    for epoch in tqdm(range(start_epoch, epochs + 1), desc='Epochs'):
        # ---- train ----
        tl = train_one_epoch(model, train_loader, optimizer, criterion, device)
        # ---- validate ----
        vl, va = validate_one_epoch(model, val_loader, criterion, device)

        # ---- record logs ----
        train_losses.append(tl)
        val_losses.append(vl)
        val_accs.append(va)
        with open(log_csv, 'a') as f:
            f.write(f"{epoch},{tl:.6f},{vl:.6f},{va:.4f}\n")
        print(f"Epoch {epoch}/{epochs} - Train: {tl:.4f}, Val Loss: {vl:.4f}, Val Acc: {va:.4f}")

        # ---- save best-model snapshot first (if improved) ----
        if vl < best_val_loss:
            best_val_loss = vl
            torch.save(model.state_dict(), best_model_path)
            print(f"🎉 New best model saved at epoch {epoch} (val_loss {vl:.4f})")

        # ---- now save the resume checkpoint with the updated best_val_loss ----
        torch.save({
            'epoch'        : epoch,
            'model_state'  : model.state_dict(),
            'optim_state'  : optimizer.state_dict(),
            'best_val_loss': best_val_loss
        }, resume_path)

    # — After training, save final plots —
    if train_losses:
        # Label the x-axis with the real epoch numbers so that curves from a
        # resumed run are not shown as if they started at epoch 1.
        epochs_range = range(start_epoch, start_epoch + len(train_losses))
        plt.figure()
        plt.plot(epochs_range, train_losses, label='Train Loss')
        plt.plot(epochs_range, val_losses,   label='Val Loss')
        plt.xlabel('Epoch')
        plt.legend()
        plt.savefig(os.path.join('logs', 'loss_curve_lambda.png'))

        plt.figure()
        plt.plot(epochs_range, val_accs, label='Val Acc')
        plt.xlabel('Epoch')
        plt.legend()
        plt.savefig(os.path.join('logs', 'acc_curve_lambda.png'))
    else:
        # Nothing was trained in this call; keep the existing plots untouched.
        print(f"No epochs left to run: checkpoint is already at epoch {start_epoch - 1} of {epochs}.")

    print("Training complete. Best validation loss: {:.4f}".format(best_val_loss))

In [None]:
# Training for 5 epochs (paths built with os.path.join so they work on any OS)
CNN(
    train_csv=os.path.join(os.getcwd(), "datasets", "train.csv"),
    val_csv=os.path.join(os.getcwd(), "datasets", "val.csv")
)

After training, the fine-tuned CNN model is run on the test set

In [None]:
from sklearn.metrics import precision_score, recall_score, f1_score, confusion_matrix, roc_auc_score
import numpy as np

# Obtain testing accuracy and the other classification metrics
test_csv = os.path.join(os.getcwd(), "datasets", "test.csv")
batch_size = 32
workers = 0
model_name = 'lambda_resnet26rpt_256.c1_in1k'
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
ckpt_path = os.path.join(os.getcwd(), "checkpoints_lambda", "best_model.pt")

# Evaluation does not depend on sample order, so no shuffling is needed;
# pinned memory is only useful when copying to a GPU.
test_loader = DataLoader(
    MelCSVDataset(test_csv, unit_variance=True), batch_size=batch_size,
    shuffle=False, num_workers=workers, pin_memory=(device.type == 'cuda')
)

model = timm.create_model(
    model_name,
    pretrained=False,      # we’ll load our own weights
    num_classes=2,
    in_chans=3
).to(device)
# best_model.pt holds a plain state dict, so the restricted loader suffices.
state = torch.load(ckpt_path, map_location=device, weights_only=True)
model.load_state_dict(state)
model.eval()

y_true = []
y_pred = []
y_proba = []

with torch.no_grad():
    for x, y in tqdm(test_loader, desc='Test batches'):
        x = x.to(device, non_blocking=True)
        y = y.to(device, non_blocking=True)
        logits = model(x)
        # probability assigned to class 1 (the "fake" label in the split CSVs)
        probs = F.softmax(logits, dim=1)[:, 1]
        preds = logits.argmax(dim=1)

        y_true.extend(y.cpu().numpy())
        y_pred.extend(preds.cpu().numpy())
        y_proba.extend(probs.cpu().numpy())

y_true = np.array(y_true)
y_pred = np.array(y_pred)
y_proba = np.array(y_proba)

# Compute metrics
accuracy = float((y_true == y_pred).mean())
precision = precision_score(y_true, y_pred)
recall = recall_score(y_true, y_pred)  # Sensitivity
f1 = f1_score(y_true, y_pred)
# labels=[0, 1] forces a 2x2 matrix even if one class never appears, so the
# four-way unpacking below cannot fail.
tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()
specificity = tn / (tn + fp) if (tn + fp) else float('nan')
auc = roc_auc_score(y_true, y_proba)

# Print results
print(f"Accuracy: {accuracy:.4f}")
print(f"Precision: {precision:.4f}")
print(f"Recall (Sensitivity): {recall:.4f}")
print(f"F1 Score: {f1:.4f}")
print(f"Specificity: {specificity:.4f}")
print(f"AUC-ROC: {auc:.4f}")

# Vision Transformer

Model: 'swinv2_small_window16_256.ms_in1k'

In [None]:
# Main entry point
def ViT(
    train_csv: str,
    val_csv: str,
    batch_size: int = 20,
    workers: int = 0,
    epochs: int = 5,
    lr: float = 1e-4,
    model_name: str = 'swinv2_small_window16_256.ms_in1k'
):
    """
    Fine-tune a pretrained timm model as a 2-class classifier on the
    mel-spectrogram splits listed in train_csv / val_csv.

    Side effects (all relative to the current working directory):
      - logs/losses_swim.csv: one row per epoch (appended across runs)
      - logs/loss_curve_swim.png, logs/acc_curve_swim.png: curves for the
        epochs trained in this call
      - checkpoints_swim/resume.pt: model + optimizer state after every
        epoch; if present, training resumes after the stored epoch
      - checkpoints_swim/best_model.pt: weights with the lowest validation
        loss seen so far

    Args:
        train_csv: CSV with 'file_path' and 'label' columns for training.
        val_csv: CSV with the same columns for validation.
        batch_size: Mini-batch size for both loaders.
        workers: DataLoader worker processes.
        epochs: Total number of epochs (including already-completed ones
            when resuming).
        lr: AdamW learning rate.
        model_name: timm model identifier.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Pinned host memory only helps host-to-GPU copies; skip it on CPU.
    pin_memory = device.type == 'cuda'

    # — Dataset & Dataloaders —
    train_loader = DataLoader(
        MelCSVDataset(train_csv, unit_variance=True), batch_size=batch_size,
        shuffle=True, num_workers=workers, pin_memory=pin_memory
    )
    val_loader = DataLoader(
        MelCSVDataset(val_csv, unit_variance=True), batch_size=batch_size,
        shuffle=False, num_workers=workers, pin_memory=pin_memory
    )

    # — Model, optimizer, criterion —
    model = timm.create_model(
        model_name,
        pretrained=True,
        num_classes=2,
        in_chans=3
    ).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    # — Logging setup —
    os.makedirs('logs', exist_ok=True)
    log_csv = os.path.join('logs', 'losses_swim.csv')
    # Decide on the header from the log file itself, not from the checkpoint:
    # the two can get out of sync (e.g. a deleted checkpoint or a run that
    # stopped before its first checkpoint).
    if not os.path.exists(log_csv) or os.path.getsize(log_csv) == 0:
        with open(log_csv, 'w') as f:
            f.write('epoch,train_loss,val_loss,val_acc\n')

    # — Checkpoint setup —
    ckpt_dir        = 'checkpoints_swim'
    os.makedirs(ckpt_dir, exist_ok=True)
    resume_path     = os.path.join(ckpt_dir, 'resume.pt')
    best_model_path = os.path.join(ckpt_dir, 'best_model.pt')

    start_epoch = 1
    best_val_loss = float('inf')

    # — Resume if possible —
    if os.path.exists(resume_path):
        ckpt = torch.load(resume_path, map_location=device)
        model.load_state_dict(ckpt['model_state'])
        optimizer.load_state_dict(ckpt['optim_state'])
        start_epoch = ckpt['epoch'] + 1
        best_val_loss = ckpt['best_val_loss']
        print(f"Resumed from epoch {ckpt['epoch']} with best_val_loss = {best_val_loss:.4f}")

    # — Storage for curves (only the epochs run in this call) —
    train_losses, val_losses, val_accs = [], [], []

    # — Training loop with checkpointing —
    for epoch in tqdm(range(start_epoch, epochs + 1), desc='Epochs'):
        # ---- train ----
        tl = train_one_epoch(model, train_loader, optimizer, criterion, device)
        # ---- validate ----
        vl, va = validate_one_epoch(model, val_loader, criterion, device)

        # ---- record logs ----
        train_losses.append(tl)
        val_losses.append(vl)
        val_accs.append(va)
        with open(log_csv, 'a') as f:
            f.write(f"{epoch},{tl:.6f},{vl:.6f},{va:.4f}\n")
        print(f"Epoch {epoch}/{epochs} - Train: {tl:.4f}, Val Loss: {vl:.4f}, Val Acc: {va:.4f}")

        # ---- save best-model snapshot first (if improved) ----
        if vl < best_val_loss:
            best_val_loss = vl
            torch.save(model.state_dict(), best_model_path)
            print(f"🎉 New best model saved at epoch {epoch} (val_loss {vl:.4f})")

        # ---- now save the resume checkpoint with the updated best_val_loss ----
        torch.save({
            'epoch'        : epoch,
            'model_state'  : model.state_dict(),
            'optim_state'  : optimizer.state_dict(),
            'best_val_loss': best_val_loss
        }, resume_path)

    # — After training, save final plots —
    if train_losses:
        # Label the x-axis with the real epoch numbers so that curves from a
        # resumed run are not shown as if they started at epoch 1.
        epochs_range = range(start_epoch, start_epoch + len(train_losses))
        plt.figure()
        plt.plot(epochs_range, train_losses, label='Train Loss')
        plt.plot(epochs_range, val_losses,   label='Val Loss')
        plt.xlabel('Epoch')
        plt.legend()
        plt.savefig(os.path.join('logs', 'loss_curve_swim.png'))

        plt.figure()
        plt.plot(epochs_range, val_accs, label='Val Acc')
        plt.xlabel('Epoch')
        plt.legend()
        plt.savefig(os.path.join('logs', 'acc_curve_swim.png'))
    else:
        # Nothing was trained in this call; keep the existing plots untouched.
        print(f"No epochs left to run: checkpoint is already at epoch {start_epoch - 1} of {epochs}.")

    print("Training complete. Best validation loss: {:.4f}".format(best_val_loss))

In [None]:
# Training for 5 epochs (paths built with os.path.join so they work on any OS)
ViT(
    train_csv=os.path.join(os.getcwd(), "datasets", "train.csv"),
    val_csv=os.path.join(os.getcwd(), "datasets", "val.csv")
)

After training, the fine-tuned ViT model is run on the test set

In [None]:
from sklearn.metrics import precision_score, recall_score, f1_score, confusion_matrix, roc_auc_score
import numpy as np

# Obtain testing accuracy and the other classification metrics
test_csv = os.path.join(os.getcwd(), "datasets", "test.csv")
batch_size = 32
workers = 0
model_name = 'swinv2_small_window16_256.ms_in1k'
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
ckpt_path = os.path.join(os.getcwd(), "checkpoints_swim", "best_model.pt")

# Evaluation does not depend on sample order, so no shuffling is needed;
# pinned memory is only useful when copying to a GPU.
test_loader = DataLoader(
    MelCSVDataset(test_csv, unit_variance=True), batch_size=batch_size,
    shuffle=False, num_workers=workers, pin_memory=(device.type == 'cuda')
)

model = timm.create_model(
    model_name,
    pretrained=False,      # we’ll load our own weights
    num_classes=2,
    in_chans=3
).to(device)
# best_model.pt holds a plain state dict, so the restricted loader suffices.
state = torch.load(ckpt_path, map_location=device, weights_only=True)
model.load_state_dict(state)
model.eval()

y_true = []
y_pred = []
y_proba = []

with torch.no_grad():
    for x, y in tqdm(test_loader, desc='Test batches'):
        x = x.to(device, non_blocking=True)
        y = y.to(device, non_blocking=True)
        logits = model(x)
        # probability assigned to class 1 (the "fake" label in the split CSVs)
        probs = F.softmax(logits, dim=1)[:, 1]
        preds = logits.argmax(dim=1)

        y_true.extend(y.cpu().numpy())
        y_pred.extend(preds.cpu().numpy())
        y_proba.extend(probs.cpu().numpy())

y_true = np.array(y_true)
y_pred = np.array(y_pred)
y_proba = np.array(y_proba)

# Compute metrics
accuracy = float((y_true == y_pred).mean())
precision = precision_score(y_true, y_pred)
recall = recall_score(y_true, y_pred)  # Sensitivity
f1 = f1_score(y_true, y_pred)
# labels=[0, 1] forces a 2x2 matrix even if one class never appears, so the
# four-way unpacking below cannot fail.
tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()
specificity = tn / (tn + fp) if (tn + fp) else float('nan')
auc = roc_auc_score(y_true, y_proba)

# Print results
print(f"Accuracy: {accuracy:.4f}")
print(f"Precision: {precision:.4f}")
print(f"Recall (Sensitivity): {recall:.4f}")
print(f"F1 Score: {f1:.4f}")
print(f"Specificity: {specificity:.4f}")
print(f"AUC-ROC: {auc:.4f}")