# BioSignal-X Baseline Benchmarks
This notebook compares EfficientNet-B0 + metadata fusion vs. ViT-based fusion on ISIC/HAM10000 templates.
Metrics: AUC, sensitivity, specificity, Brier score, and ECE, plus a Grad-CAM visualization. Aggregate results are written to `results/benchmark_metrics.csv`.

In [None]:
# Imports
import os, math, json, random, time, pathlib
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.metrics import roc_auc_score, confusion_matrix, f1_score, brier_score_loss
import matplotlib.pyplot as plt
from pathlib import Path

# Local imports
import sys
sys.path.append(str(Path('src').resolve()))
from data_loader import build_dataset
from models.biosignal_model import BioSignalModel
from utils.gradcam import GradCAM

# Run on the GPU when PyTorch can see one; otherwise stay on the CPU.
DEVICE = 'cpu'
if torch.cuda.is_available():
    DEVICE = 'cuda'

# Benchmark artifacts are written beneath this directory; create it up front.
RESULTS_DIR = Path('results')
RESULTS_DIR.mkdir(parents=True, exist_ok=True)

### Data Templates (ISIC/HAM10000)
Place the ISIC/HAM10000 data under `data/`, or update the paths below to point to your local copies. If no real dataset is found, the synthetic sample dataset is used so the notebook can run as a quick smoke test.

In [None]:
# Auto-detect which dataset is available on disk (the dataset object is built below)
DATA_ROOT = Path('data')


def detect_dataset():
    """Return ``(key, root)`` for the first dataset location found under DATA_ROOT.

    Candidates are probed in priority order: isic2019, isic2018, ham10000,
    then the synthetic sample. A candidate matches when its ``metadata.csv``
    or its root directory exists. When nothing matches, the synthetic sample
    location is returned even though it is not on disk.
    """
    folder_by_key = {
        'isic2019': 'isic2019',
        'isic2018': 'isic2018',
        'ham10000': 'ham10000',
        'sample': 'synthetic_benchmark',
    }
    for key, folder in folder_by_key.items():
        candidate_root = DATA_ROOT / folder
        has_metadata = (candidate_root / 'metadata.csv').exists()
        if has_metadata or candidate_root.exists():
            return key, candidate_root
    # Nothing found: fall back to the sample dataset location.
    return 'sample', DATA_ROOT / 'synthetic_benchmark'

dataset_key, root = detect_dataset()
print(f'Using dataset: {dataset_key} at {root}')

# Build dataset (creates sample or indexes if needed)
ds = build_dataset(root, dataset_key)

from torch.utils.data import Subset, DataLoader

# Reproducible 80/20 train/validation split over shuffled sample indices.
indices = np.arange(len(ds))
np.random.seed(42)
np.random.shuffle(indices)
split = int(0.8 * len(ds))
train_idx = indices[:split]
val_idx = indices[split:]

train_loader = DataLoader(
    Subset(ds, train_idx),
    batch_size=16,
    shuffle=True,
    num_workers=0,
)
val_loader = DataLoader(
    Subset(ds, val_idx),
    batch_size=32,
    shuffle=False,
    num_workers=0,
)

# The benchmark is set up as a two-class problem.
n_classes = 2
# Metadata vector length, read from the second element of the first sample.
n_meta = ds[0][1].shape[0]

### Models: EfficientNet-B0 + metadata fusion vs. ViT-based fusion

In [None]:
# EfficientNet-based fusion model provided by the project
eff_model = BioSignalModel(metadata_dim=n_meta, num_classes=n_classes).to(DEVICE)

# ViT-based fusion baseline; the backbone is created through timm
import timm


class ViTFusion(nn.Module):
    """ViT image encoder fused with a small MLP over the metadata vector.

    The timm backbone is created without pretrained weights and with
    ``num_classes=0`` (presumably so it returns features rather than class
    scores; confirm against the timm docs). Its output is concatenated with
    a 64-dim metadata embedding and fed to a two-layer classification head.
    ``forward`` returns the head output (logits) as a single tensor.
    """

    def __init__(self, num_classes, metadata_dim, vit_name='vit_base_patch16_224'):
        super().__init__()
        self.vit = timm.create_model(vit_name, pretrained=False, num_classes=0)
        fused_dim = self.vit.num_features + 64
        self.meta = nn.Sequential(
            nn.Linear(metadata_dim, 64),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(64, 64),
            nn.ReLU(),
        )
        self.head = nn.Sequential(
            nn.Linear(fused_dim, 128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, num_classes),
        )

    def forward(self, x, m):
        image_features = self.vit(x)
        meta_features = self.meta(m)
        fused = torch.cat([image_features, meta_features], dim=1)
        return self.head(fused)


vit_model = ViTFusion(n_classes, n_meta).to(DEVICE)

### Training/Evaluation Utilities and Calibration

In [None]:
def ece_score(probs, labels, n_bins=10):
    """Expected calibration error of positive-class probabilities.

    Predictions are grouped into ``n_bins`` equal-width bins over [0, 1]. For
    every non-empty bin, the gap between the mean predicted probability and
    the observed fraction of positive labels is weighted by the bin's share
    of samples; the weighted gaps are summed.

    Args:
        probs: 1-D array-like of predicted probabilities for class 1.
        labels: 1-D array-like of binary ground-truth labels (0/1), same length.
        n_bins: number of equal-width probability bins.

    Returns:
        The ECE as a float; 0.0 for empty input.
    """
    probs = np.asarray(probs, dtype=float)
    labels = np.asarray(labels, dtype=float)
    if probs.size == 0:
        return 0.0
    edges = np.linspace(0.0, 1.0, n_bins + 1)
    # digitize places p == 1.0 past the last bin; clip so that fully
    # confident predictions are counted instead of silently dropped.
    bin_ids = np.clip(np.digitize(probs, edges) - 1, 0, n_bins - 1)
    ece = 0.0
    for b in range(n_bins):
        in_bin = bin_ids == b
        count = int(in_bin.sum())
        if count == 0:
            continue
        # probs is P(class 1), so the matching reference is the observed
        # positive rate, not the accuracy of the thresholded prediction.
        positive_rate = labels[in_bin].mean()
        mean_prob = probs[in_bin].mean()
        ece += (count / probs.size) * abs(positive_rate - mean_prob)
    return float(ece)

def fairness_groups(raw_df: pd.DataFrame):
    """Map subgroup names to the row index labels of their members.

    Groups are built from whichever of the ``gender``, ``age`` and
    ``skin_type`` columns are present. Ages are bucketed into 0-30, 31-50,
    51-70 and 71+. Rows with a missing value in a column are left out of that
    column's groups.

    Args:
        raw_df: frame with one row per sample.

    Returns:
        dict mapping ``'<column>:<value>'`` to a list of ``raw_df.index`` labels.
    """
    def members(mask):
        # Index labels (not positions) of the rows selected by a boolean mask.
        return raw_df.index[mask.to_numpy()].tolist()

    groups = {}
    if 'gender' in raw_df.columns:
        for g in raw_df['gender'].dropna().unique():
            groups[f'gender:{g}'] = members(raw_df['gender'] == g)
    if 'age' in raw_df.columns:
        # include_lowest keeps age 0 in the '0-30' bucket; without it the
        # first interval is left-open and age 0 becomes NaN.
        age_bins = pd.cut(
            raw_df['age'],
            bins=[0, 30, 50, 70, 200],
            labels=['0-30', '31-50', '51-70', '71+'],
            include_lowest=True,
        )
        # dropna avoids an empty 'age:nan' group for missing/out-of-range ages.
        for ab in age_bins.dropna().unique():
            groups[f'age:{ab}'] = members(age_bins == ab)
    if 'skin_type' in raw_df.columns:
        for st in raw_df['skin_type'].dropna().unique():
            groups[f'skin_type:{st}'] = members(raw_df['skin_type'] == st)
    return groups

def _safe_auc(y_true, y_score):
    """Return ROC AUC, or NaN when scikit-learn rejects the input (e.g. one class only)."""
    try:
        return float(roc_auc_score(y_true, y_score))
    except ValueError:
        return float('nan')


@torch.no_grad()
def evaluate(model, loader, device=DEVICE, raw_df: pd.DataFrame | None = None, subset_indices=None):
    """Score a binary classifier on ``loader``, optionally broken down by subgroup.

    Args:
        model: called as ``model(imgs, meta)``; may return the logits tensor
            directly or a tuple/list whose first element is the logits.
        loader: iterable of batches whose first three items are images,
            metadata and labels.
        device: device the batch tensors are moved to.
        raw_df: optional frame handed to ``fairness_groups``; when given, a
            ``'fairness'`` entry with per-group AUC/sensitivity/specificity
            is added for every group with at least 5 evaluated samples.
        subset_indices: dataset-level index of each prediction, in loader
            order. When omitted, ``loader.dataset.indices`` is used if that
            attribute exists (as on ``torch.utils.data.Subset``); otherwise
            predictions are assumed to be in frame order.

    Returns:
        dict with 'auc', 'sensitivity', 'specificity', 'brier', 'ece' and,
        when ``raw_df`` is given, 'fairness'.

    Note:
        The subgroup mapping is only meaningful for a non-shuffling loader,
        and assumes ``raw_df`` index labels equal dataset positions -- TODO
        confirm against the dataset class.
    """
    model.eval()
    y_true, y_prob = [], []
    for batch in loader:
        imgs, meta, labels = batch[0], batch[1], batch[2]
        imgs = imgs.to(device)
        # Same float cast as in training, so the metadata branch sees the same dtype.
        meta = meta.to(device, dtype=torch.float)
        labels = labels.to(device)
        output = model(imgs, meta)
        # Accept both (logits, ...) tuples and a bare logits tensor; unpacking
        # a bare (batch, classes) tensor into two names fails for batch != 2.
        logits = output[0] if isinstance(output, (tuple, list)) else output
        probs = torch.softmax(logits, dim=1)[:, 1]
        y_true.append(labels.detach().cpu().numpy())
        y_prob.append(probs.detach().cpu().numpy())
    y_true = np.concatenate(y_true)
    y_prob = np.concatenate(y_prob)

    y_pred = (y_prob >= 0.5).astype(int)
    # labels=[0, 1] guarantees a 2x2 matrix even when one class is absent.
    tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()
    metrics = {
        'auc': _safe_auc(y_true, y_prob),
        'sensitivity': float(tp / (tp + fn + 1e-8)),
        'specificity': float(tn / (tn + fp + 1e-8)),
        'brier': float(brier_score_loss(y_true, y_prob)),
        'ece': float(ece_score(y_prob, y_true)),
    }
    if raw_df is None:
        return metrics

    if subset_indices is None:
        subset_indices = getattr(getattr(loader, 'dataset', None), 'indices', None)
    if subset_indices is None:
        position_of = None
    else:
        # dataset-level index -> position of that sample in y_true / y_prob
        position_of = {
            int(ds_idx): pos
            for pos, ds_idx in enumerate(subset_indices)
            if pos < len(y_true)
        }

    fairness = {}
    for name, idxs in fairness_groups(raw_df).items():
        if position_of is None:
            local = [i for i in idxs if i < len(y_true)]
        else:
            local = [position_of[i] for i in idxs if i in position_of]
        if len(local) < 5:
            continue
        yt = y_true[local]
        yp = y_prob[local]
        ypp = (yp >= 0.5).astype(int)
        tn2, fp2, fn2, tp2 = confusion_matrix(yt, ypp, labels=[0, 1]).ravel()
        fairness[name] = {
            'auc': _safe_auc(yt, yp),
            # A rate is undefined (NaN) when the group has no samples of that class.
            'sens': float(tp2 / (tp2 + fn2 + 1e-8)) if (tp2 + fn2) > 0 else float('nan'),
            'spec': float(tn2 / (tn2 + fp2 + 1e-8)) if (tn2 + fp2) > 0 else float('nan'),
        }
    metrics['fairness'] = fairness
    return metrics

def save_metrics(name, metrics, path=RESULTS_DIR/"benchmark_metrics.csv"):
    """Append one model's scalar metrics to a CSV and dump its fairness breakdown.

    The CSV row holds ``model`` plus every entry of ``metrics`` except
    ``'fairness'``. A header is written only when the file does not exist
    yet. When ``metrics`` has a ``'fairness'`` key, its value is written to
    ``<RESULTS_DIR>/fairness/<name>_fairness.json``.

    Returns:
        ``path``, unchanged.
    """
    row = {'model': name}
    for key, value in metrics.items():
        if key != 'fairness':
            row[key] = value

    # Append without a header to an existing file; otherwise start a new one.
    appending = Path(path).exists()
    pd.DataFrame([row]).to_csv(
        path,
        mode='a' if appending else 'w',
        header=not appending,
        index=False,
    )

    if 'fairness' not in metrics:
        return path

    # Fairness results are nested, so they go to a separate JSON file.
    fairness_dir = RESULTS_DIR/"fairness"
    fairness_dir.mkdir(exist_ok=True)
    with open(fairness_dir/f"{name}_fairness.json", 'w', encoding='utf-8') as fh:
        json.dump(metrics['fairness'], fh, indent=2)
    return path

### Quick Training (1 epoch) and Evaluation

In [None]:
def train_one_epoch(model, loader, device=DEVICE, lr=1e-3):
    """Run a single pass over ``loader``, updating ``model`` in place.

    A fresh AdamW optimizer and a cross-entropy loss are created on each call.

    Args:
        model: called as ``model(imgs, meta)``; may return the logits tensor
            directly or a tuple/list whose first element is the logits.
        loader: iterable yielding ``(imgs, meta, labels)`` batches.
        device: device the batch tensors are moved to.
        lr: AdamW learning rate.
    """
    model.train()
    opt = torch.optim.AdamW(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    for imgs, meta, labels in loader:
        imgs = imgs.to(device)
        meta = meta.to(device, dtype=torch.float)
        labels = labels.to(device)
        opt.zero_grad()
        output = model(imgs, meta)
        # ViTFusion.forward returns a bare logits tensor; unpacking it as
        # `logits, _` raises for any batch size other than 2, so accept both
        # a tuple/list (first element = logits) and a plain tensor.
        logits = output[0] if isinstance(output, (tuple, list)) else output
        loss = criterion(logits, labels)
        loss.backward()
        opt.step()

# Train tiny epoch and evaluate both models
# Pass num_classes explicitly: evaluation reads the class-1 softmax column and
# builds a binary confusion matrix, so the head must match n_classes rather
# than rely on the model's default.
eff_model = BioSignalModel(num_classes=n_classes, metadata_dim=n_meta).to(DEVICE)
train_one_epoch(eff_model, train_loader)
# ds.frame (when the dataset exposes one) enables the per-subgroup breakdown.
eff_metrics = evaluate(eff_model, val_loader, device=DEVICE, raw_df=getattr(ds, 'frame', None))
save_metrics('efficientnet_b0_fusion', eff_metrics)
print('EfficientNet metrics:', eff_metrics)

import timm


class ViTFusion(nn.Module):
    """ViT image encoder fused with a small MLP over the metadata vector.

    The timm backbone is created without pretrained weights and with
    ``num_classes=0`` (presumably so it returns features rather than class
    scores; confirm against the timm docs). Its output is concatenated with
    a 64-dim metadata embedding and fed to a two-layer classification head.
    ``forward`` returns the head output (logits) as a single tensor.
    """

    def __init__(self, num_classes, metadata_dim, vit_name='vit_base_patch16_224'):
        super().__init__()
        self.vit = timm.create_model(vit_name, pretrained=False, num_classes=0)
        fused_dim = self.vit.num_features + 64
        self.meta = nn.Sequential(
            nn.Linear(metadata_dim, 64),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(64, 64),
            nn.ReLU(),
        )
        self.head = nn.Sequential(
            nn.Linear(fused_dim, 128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, num_classes),
        )

    def forward(self, x, m):
        image_features = self.vit(x)
        meta_features = self.meta(m)
        fused = torch.cat([image_features, meta_features], dim=1)
        return self.head(fused)

# Fresh ViT fusion model: one quick training epoch, then validation metrics.
vit_model = ViTFusion(n_classes, n_meta).to(DEVICE)
train_one_epoch(vit_model, train_loader)
# Thin forwarding wrapper around the model; not referenced again in the visible cells.
vit_logits = lambda imgs, metas: vit_model(imgs, metas)
vit_metrics = evaluate(
    vit_model,
    val_loader,
    device=DEVICE,
    raw_df=getattr(ds, 'frame', None),
)
save_metrics('vit_base_patch16_fusion', vit_metrics)
print(f'ViT metrics: {vit_metrics}')

### Grad-CAM Visualization

In [None]:
# Grad-CAM for the first sample of the first validation batch (EfficientNet)
batch = next(iter(val_loader))
# Slicing with 0:1 keeps the batch dimension on each tensor.
img = batch[0][0:1].to(DEVICE)
meta = batch[1][0:1].to(DEVICE, dtype=torch.float)
lbl = batch[2][0:1]

eff_model.eval()
cam = GradCAM(eff_model)
cam_map = cam(img, meta)

# Render the activation map on its own small square figure.
_fig, _ax = plt.subplots(figsize=(4, 4))
_ax.imshow(cam_map, cmap='jet')
_ax.set_title('Grad-CAM (EfficientNet Fusion)')
_ax.axis('off')
plt.show()

In [None]:
# Inter-site performance variance analysis
import json, csv
from pathlib import Path as _P
sites_col = None
raw_df = getattr(ds, 'frame', None)
if raw_df is not None:
    # Sites are encoded either as binary flag columns ('site_*') or as a
    # single 'site' column; str() guards against non-string column labels.
    site_flag_cols = [c for c in raw_df.columns if str(c).lower().startswith('site_')]
    if site_flag_cols:
        site_groups = {}
        for c in site_flag_cols:
            members = raw_df[raw_df[c]==1].index.tolist()
            if members:
                site_groups[c] = members
    elif 'site' in raw_df.columns:
        # dropna: a missing site would only produce an empty 'site:nan' group.
        site_groups = {
            f"site:{s}": raw_df[raw_df['site']==s].index.tolist()
            for s in raw_df['site'].dropna().unique()
        }
    else:
        site_groups = {}
    inter_rows = []
    for site_name, idxs in site_groups.items():
        # Skip sites with too few samples for meaningful metrics.
        if len(idxs) < 5:
            continue
        # NOTE(review): idxs are frame index labels used as dataset positions,
        # and they cover the whole dataset (training samples included) --
        # confirm this is intended for a held-out comparison.
        subset = Subset(ds, idxs)
        loader_site = DataLoader(subset, batch_size=32, shuffle=False, num_workers=0)
        # Only the aggregate metrics are used below, so no raw_df is passed
        # and the per-subgroup breakdown is not computed for every site.
        m_eff = evaluate(eff_model, loader_site, device=DEVICE)
        inter_rows.append({'site': site_name, 'auc': m_eff['auc'], 'sens': m_eff['sensitivity'], 'spec': m_eff['specificity']})
    if inter_rows:
        import pandas as _pd
        inter_df = _pd.DataFrame(inter_rows)
        # parents=True so the cell also works when 'results' does not exist yet.
        (RESULTS_DIR / 'plots').mkdir(parents=True, exist_ok=True)
        inter_df.to_csv(RESULTS_DIR / 'inter_site_variability.csv', index=False)
        # Boxplot: one box per metric, each summarising the spread across sites.
        plt.figure(figsize=(6,4))
        metric_positions = {'auc': 1, 'sens': 2, 'spec': 3}
        for metric, position in metric_positions.items():
            plt.boxplot(inter_df[metric].dropna(), positions=[position], widths=0.6)
        plt.xticks([1,2,3], ['AUC','Sensitivity','Specificity'])
        plt.title('Inter-Site Performance Variance')
        plt.savefig(RESULTS_DIR / 'plots' / 'inter_site_variance.png', dpi=150, bbox_inches='tight')
        plt.close()
        print('Inter-site variance artifacts written.')
else:
    print('No raw dataframe available for inter-site variance analysis.')