# BioSignal-X Baseline Benchmarks
This notebook compares EfficientNet-B0 + metadata fusion vs. ViT-based fusion on ISIC/HAM10000 templates.
Metrics: AUC, sensitivity, specificity, Brier score, and ECE, plus a Grad-CAM visualization. Results are written to `results/benchmark_metrics.csv`.

In [None]:
# Imports
import os, math, json, random, time, pathlib
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.metrics import roc_auc_score, confusion_matrix, f1_score, brier_score_loss
import matplotlib.pyplot as plt
from pathlib import Path

# Local imports
import sys
sys.path.append(str(Path('src').resolve()))
from data_loader import SkinDataset, create_sample_dataset
from models.biosignal_model import BioSignalModel
from utils.gradcam import GradCAM

# Run on the GPU when PyTorch reports one is available, otherwise on the CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'

# Output directory for benchmark artifacts; created eagerly so later cells can write to it.
RESULTS_DIR = Path('results')
RESULTS_DIR.mkdir(parents=True, exist_ok=True)

### Data Templates (ISIC/HAM10000)
Update the paths below to point to your local ISIC or HAM10000 dataset. By default (`DATASET_NAME = 'SYNTHETIC'`), the synthetic sample dataset is used for quick smoke tests; the ISIC/HAM10000 loaders are not implemented yet.

In [None]:
# Dataset selection (edit as needed)
DATASET_NAME = 'SYNTHETIC'  # options: SYNTHETIC, ISIC, HAM10000
DATA_ROOT = Path('data')

# Only the synthetic path is wired up; anything else stops the notebook here.
if DATASET_NAME.upper() != 'SYNTHETIC':
    # TODO: Implement real ISIC/HAM10000 readers mapping to SkinDataset-compatible DataFrame
    raise NotImplementedError('Provide ISIC/HAM10000 loaders and set DATASET_NAME accordingly.')

# Build the synthetic sample table, then wrap it in a SkinDataset.
df = create_sample_dataset(n=256, out_dir=str(DATA_ROOT / 'synthetic_benchmark'))
from torchvision import transforms

# Resize to 224x224, convert to a tensor, and normalize each channel
# (these are the commonly used ImageNet channel statistics).
channel_mean = [0.485, 0.456, 0.406]
channel_std = [0.229, 0.224, 0.225]
tfm = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=channel_mean, std=channel_std),
])
ds = SkinDataset(
    df,
    img_col='image_path',
    label_col='label',
    metadata_cols=['age', 'sex_M', 'sex_F'],
    transform=tfm,
)

# 80/20 train/validation split over a seeded shuffle of the sample indices.
indices = np.arange(len(ds))
np.random.seed(42)
np.random.shuffle(indices)
split = int(0.8 * len(ds))
train_idx = indices[:split]
val_idx = indices[split:]

from torch.utils.data import Subset, DataLoader
train_loader = DataLoader(Subset(ds, train_idx), batch_size=16, shuffle=True, num_workers=0)
val_loader = DataLoader(Subset(ds, val_idx), batch_size=32, shuffle=False, num_workers=0)

n_classes = 2

# Infer the metadata width from the first sample; the dataset may yield either
# a dict with a 'metadata' entry or a tuple whose third element is the metadata.
first_sample = ds[0]
if isinstance(first_sample, dict):
    n_meta = first_sample['metadata'].shape[0]
else:
    n_meta = len(first_sample[2])

### Models: EfficientNet-B0 + metadata fusion vs. ViT-based fusion

In [None]:
# Project model (EfficientNet-based), built for the class count and metadata
# width determined above, then moved to the selected device.
eff_model = BioSignalModel(num_classes=n_classes, metadata_dim=n_meta)
eff_model = eff_model.to(DEVICE)

# ViT-based fusion baseline using timm
import timm
class ViTFusion(nn.Module):
    """Fusion classifier: timm image backbone features concatenated with an MLP embedding of metadata.

    Args:
        num_classes: Width of the output logit vector.
        metadata_dim: Number of metadata features per sample.
        vit_name: timm model identifier used for the image backbone.
    """

    def __init__(self, num_classes, metadata_dim, vit_name='vit_base_patch16_224'):
        super().__init__()
        meta_width = 64

        # Backbone is created without pretrained weights. num_classes=0 is
        # presumably timm's way of returning features without a classifier
        # head -- confirm against the timm docs.
        self.vit = timm.create_model(vit_name, pretrained=False, num_classes=0)
        vis_dim = self.vit.num_features

        # Two-layer MLP that embeds the metadata vector into `meta_width` dims.
        self.meta = nn.Sequential(
            nn.Linear(metadata_dim, meta_width),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(meta_width, meta_width),
            nn.ReLU(),
        )

        # Classification head over the concatenated [image, metadata] features.
        self.head = nn.Sequential(
            nn.Linear(vis_dim + meta_width, 128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, num_classes),
        )

    def forward(self, x, m):
        """Return logits for image batch `x` and metadata batch `m`."""
        image_features = self.vit(x)
        meta_features = self.meta(m)
        fused = torch.cat((image_features, meta_features), dim=1)
        return self.head(fused)

# ViT fusion baseline with the same class count and metadata width, moved to the selected device.
vit_model = ViTFusion(num_classes=n_classes, metadata_dim=n_meta)
vit_model = vit_model.to(DEVICE)

### Training/Evaluation Utilities and Calibration

In [None]:
def ece_score(probs, labels, n_bins=10):
    """Expected calibration error of binary positive-class probabilities.

    Predictions are grouped into `n_bins` equal-width bins over [0, 1]. For
    each non-empty bin the absolute gap between the observed positive rate and
    the mean predicted probability is weighted by the bin's share of samples.

    Args:
        probs: Predicted probabilities of the positive class (array-like).
        labels: Ground-truth labels, expected to be 0/1 (array-like, same length).
        n_bins: Number of equal-width probability bins.

    Returns:
        The ECE as a float; 0.0 for empty input.

    Raises:
        ValueError: If `probs` and `labels` differ in length.
    """
    probs = np.asarray(probs, dtype=float).ravel()
    labels = np.asarray(labels, dtype=float).ravel()
    if probs.shape != labels.shape:
        raise ValueError('probs and labels must have the same length')
    if probs.size == 0:
        return 0.0

    edges = np.linspace(0.0, 1.0, n_bins + 1)
    # Digitize against the interior edges only and clip, so that probs == 1.0
    # lands in the last bin instead of falling outside every bin.
    bin_ids = np.clip(np.digitize(probs, edges[1:-1]), 0, n_bins - 1)

    ece = 0.0
    for b in range(n_bins):
        in_bin = bin_ids == b
        count = int(in_bin.sum())
        if count == 0:
            continue
        # Compare like with like: `probs` is P(y=1), so the matching empirical
        # quantity is the fraction of positives, not thresholded accuracy.
        observed_pos_rate = labels[in_bin].mean()
        mean_confidence = probs[in_bin].mean()
        ece += (count / probs.size) * abs(observed_pos_rate - mean_confidence)
    return float(ece)

@torch.no_grad()
def evaluate(model, loader, device=DEVICE):
    """Run `model` over `loader` and compute binary classification metrics.

    Batches may be dicts with 'image', 'metadata' and 'label' entries, or
    (image, label, metadata) tuples. When the model emits more than one logit
    per sample, the softmax probability of column 1 is used as the
    positive-class score; a single logit is passed through a sigmoid.

    Args:
        model: Callable as `model(images, metadata)` returning logits.
        loader: Iterable of batches in one of the two formats above.
        device: Device the tensors are moved to before the forward pass.

    Returns:
        Dict with float values for 'auc', 'sensitivity', 'specificity',
        'brier' and 'ece'. 'auc' is NaN when `roc_auc_score` rejects the
        inputs (e.g. only one class present in the labels).
    """
    model.eval()
    y_true, y_prob = [], []
    for batch in loader:
        if isinstance(batch, dict):
            x = batch['image'].to(device)
            m = batch['metadata'].to(device, dtype=torch.float)
            y = batch['label'].to(device)
        else:
            x, y, m = batch
            x, y, m = x.to(device), y.to(device), m.to(device, dtype=torch.float)
        logits = model(x, m)
        probs = torch.softmax(logits, dim=1)[:, 1] if logits.shape[1] > 1 else torch.sigmoid(logits.squeeze(1))
        y_true.append(y.detach().cpu().numpy())
        y_prob.append(probs.detach().cpu().numpy())
    y_true = np.concatenate(y_true)
    y_prob = np.concatenate(y_prob)

    try:
        auc = roc_auc_score(y_true, y_prob)
    except ValueError:
        # AUC is undefined for degenerate label sets; report NaN rather than abort.
        auc = float('nan')

    y_pred = (y_prob >= 0.5).astype(int)
    # Pin the label set so the matrix is always 2x2. Without it, a split or a
    # prediction set containing a single class yields a 1x1 matrix and the
    # four-way unpack below fails.
    tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()
    # The small epsilon avoids division by zero when a class is absent.
    sens = tp / (tp + fn + 1e-8)
    spec = tn / (tn + fp + 1e-8)
    brier = brier_score_loss(y_true, y_prob)
    ece = ece_score(y_prob, y_true)
    return {
        'auc': float(auc), 'sensitivity': float(sens), 'specificity': float(spec),
        'brier': float(brier), 'ece': float(ece)
    }

def save_metrics(name, metrics, path=RESULTS_DIR/"benchmark_metrics.csv"):
    """Write one metrics row for model `name` to the CSV at `path`.

    The file is created with a header row if it does not exist yet; otherwise
    the row is appended without a header.

    Args:
        name: Value stored in the 'model' column.
        metrics: Mapping of metric name to value; merged into the row.
        path: Destination CSV file.

    Returns:
        `path`, unchanged.
    """
    record = {'model': name, **metrics}
    frame = pd.DataFrame([record])
    already_there = Path(path).exists()
    frame.to_csv(
        path,
        mode='a' if already_there else 'w',
        header=not already_there,
        index=False,
    )
    return path

### Quick Training (1 epoch) and Evaluation

In [None]:
def train_one_epoch(model, loader, device=DEVICE, lr=1e-3):
    """Run a single optimization pass of `model` over `loader`.

    A fresh AdamW optimizer is created on every call. The loss is
    cross-entropy when the module-level `n_classes` is greater than 1,
    otherwise binary cross-entropy on the squeezed logits.

    Args:
        model: Callable as `model(images, metadata)` returning logits.
        loader: Iterable of batches, either dicts with 'image', 'metadata'
            and 'label' entries or (image, label, metadata) tuples.
        device: Device the tensors are moved to.
        lr: Learning rate for AdamW.
    """
    model.train()
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    multiclass = n_classes > 1
    loss_fn = nn.CrossEntropyLoss() if multiclass else nn.BCEWithLogitsLoss()

    for batch in loader:
        # Normalize both supported batch layouts to three tensors.
        if isinstance(batch, dict):
            imgs, metas, labels = batch['image'], batch['metadata'], batch['label']
        else:
            imgs, labels, metas = batch
        imgs = imgs.to(device)
        metas = metas.to(device, dtype=torch.float)
        labels = labels.to(device)

        optimizer.zero_grad()
        logits = model(imgs, metas)
        if multiclass:
            batch_loss = loss_fn(logits, labels)
        else:
            batch_loss = loss_fn(logits.squeeze(1), labels.float())
        batch_loss.backward()
        optimizer.step()

# Train each model for one short epoch, score it on the validation loader,
# append its row to the results CSV, and print the metrics.
benchmark_runs = (
    ('efficientnet_b0_fusion', 'EfficientNet metrics:', eff_model),
    ('vit_base_patch16_fusion', 'ViT metrics:', vit_model),
)
collected_metrics = {}
for run_name, banner, net in benchmark_runs:
    train_one_epoch(net, train_loader)
    run_metrics = evaluate(net, val_loader)
    save_metrics(run_name, run_metrics)
    print(banner, run_metrics)
    collected_metrics[run_name] = run_metrics

# Keep the per-model result names available to later cells.
eff_metrics = collected_metrics['efficientnet_b0_fusion']
vit_metrics = collected_metrics['vit_base_patch16_fusion']

### Grad-CAM Visualization

In [None]:
# Grad-CAM for the first sample of the first validation batch (EfficientNet model).
batch = next(iter(val_loader))
if not isinstance(batch, dict):
    img, lbl, meta = batch
else:
    img = batch['image']
    meta = batch['metadata']

# Slice instead of index so the batch dimension is kept.
img = img[:1].to(DEVICE)
meta = meta[:1].to(DEVICE, dtype=torch.float)

eff_model.eval()
cam = GradCAM(eff_model)
cam_map = cam(img, meta)  # assumed to be a 2-D array that imshow can draw -- TODO confirm

plt.figure(figsize=(4, 4))
plt.imshow(cam_map, cmap='jet')
plt.title('Grad-CAM (EfficientNet Fusion)')
plt.axis('off')
plt.show()