# ML Basics with MedMNIST (Binary 2D) — Colab
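Use either **breastmnist** or **pneumoniamnist** (both are binary 2D classification tasks). In Colab, set the runtime type to **GPU** (Runtime → Change runtime type).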


In [None]:
!pip -q install torch torchvision medmnist scikit-learn matplotlib tqdm


In [ ]:
import os, random, json, numpy as np
import torch, torch.nn as nn
from torch.utils.data import DataLoader, Subset
import torchvision.transforms as T
import torchvision.models as tvm
import matplotlib.pyplot as plt
from sklearn.metrics import roc_auc_score, accuracy_score
import medmnist
from medmnist import INFO

def set_seed(seed: int = 42):
    """Seed every RNG the notebook relies on and make cuDNN deterministic.

    Seeds Python's ``random``, NumPy's global RNG, PyTorch's CPU RNG and the
    RNGs of all CUDA devices (in that order), then disables cuDNN autotuning
    in favour of deterministic kernels.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
    cudnn = torch.backends.cudnn
    # Deterministic kernels + no benchmark autotuning => repeatable conv results.
    cudnn.deterministic, cudnn.benchmark = True, False

class AverageMeter:
    """Running, count-weighted mean of scalar values (e.g. per-batch losses)."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear the accumulated weighted total and the sample count."""
        self.sum, self.cnt = 0.0, 0

    def update(self, val, n=1):
        """Record value *val* observed over *n* samples."""
        self.sum = self.sum + float(val) * n
        self.cnt = self.cnt + n

    @property
    def avg(self):
        """Weighted mean so far; 0.0 before any update (count is floored at 1)."""
        return self.sum / max(self.cnt, 1)

def compute_metrics(y_true, y_prob, n_classes):
    """Compute accuracy and AUROC from predicted class probabilities.

    Args:
        y_true: integer class labels in any shape that flattens to (N,), e.g.
            the (N, 1) label arrays MedMNIST returns.
        y_prob: (N, n_classes) array of class probabilities.
        n_classes: number of classes. With 2, AUROC is computed on the
            positive-class column (index 1). Otherwise it is the macro average
            of per-class one-vs-rest AUROCs.

    Returns:
        dict with keys "acc" and "auroc". "auroc" is NaN when sklearn reports
        it as undefined (e.g. only one class present in y_true).
    """
    # Flatten labels: an (N, 1) array would make np.eye(n)[y] produce an
    # (N, 1, n) tensor and break the multiclass AUROC.
    y_true = np.asarray(y_true).reshape(-1).astype(int)
    y_prob = np.asarray(y_prob)
    acc = accuracy_score(y_true, np.argmax(y_prob, axis=1))
    try:
        if n_classes == 2:
            auroc = roc_auc_score(y_true, y_prob[:, 1])
        else:
            y_true_1hot = np.eye(n_classes)[y_true]
            auroc = roc_auc_score(y_true_1hot, y_prob, average="macro", multi_class="ovr")
    except ValueError:
        # sklearn raises ValueError when AUROC is undefined (e.g. a single class).
        auroc = np.nan
    return {"acc": acc, "auroc": auroc}

def reliability_diagram(y_true, y_prob, n_bins=10, plot=True):
    """Compute the expected calibration error (ECE), optionally plotting a reliability diagram.

    Predictions are binned by confidence (the max class probability) into
    ``n_bins`` equal-width bins on [0, 1]. The first bin is closed on both
    ends; every later bin is (lo, hi]. ECE is the sum over non-empty bins of
    |bin accuracy - bin mean confidence| weighted by the fraction of samples
    in the bin.

    Args:
        y_true: integer class labels in any shape that flattens to (N,), such
            as the (N, 1) arrays MedMNIST returns.
        y_prob: (N, C) array of class probabilities.
        n_bins: number of equal-width confidence bins.
        plot: when True (default), draw and show the diagram with matplotlib.

    Returns:
        The ECE as a Python float.
    """
    # Flatten labels: comparing (N,) predictions with an (N, 1) label array
    # would broadcast to an (N, N) matrix and corrupt every bin accuracy.
    y_true = np.asarray(y_true).reshape(-1)
    y_prob = np.asarray(y_prob)
    confidences = np.max(y_prob, axis=1)
    preds = np.argmax(y_prob, axis=1)
    correct = (preds == y_true).astype(np.float32)
    bins = np.linspace(0, 1, n_bins + 1)
    xs = np.linspace(0.5 / n_bins, 1 - 0.5 / n_bins, n_bins)  # bin centres for plotting
    ece = 0.0
    bin_accs, bin_confs = [], []
    for i in range(n_bins):
        lo, hi = bins[i], bins[i + 1]
        # Bin 0 also includes confidence == 0; later bins are half-open (lo, hi].
        lower_ok = (confidences >= lo) if i == 0 else (confidences > lo)
        mask = lower_ok & (confidences <= hi)
        if mask.sum() == 0:
            # Empty bin: zero-height bar, confidence marker at the bin centre.
            bin_accs.append(0.0); bin_confs.append((lo + hi) / 2)
            continue
        acc_i = correct[mask].mean()
        conf_i = confidences[mask].mean()
        frac_i = mask.mean()
        ece += abs(acc_i - conf_i) * frac_i
        bin_accs.append(acc_i); bin_confs.append(conf_i)
    if plot:
        plt.figure(); plt.plot([0,1],[0,1], linestyle='--')
        plt.bar(xs, bin_accs, width=1.0/n_bins, alpha=0.6, edgecolor='k')
        plt.plot(xs, bin_confs, marker='o')
        plt.xlabel("Confidence"); plt.ylabel("Accuracy")
        plt.title(f"Reliability Diagram (ECE={ece:.3f})")
        plt.tight_layout(); plt.show()
    return float(ece)

def get_medmnist_dataset(key: str, split: str, as_rgb=True, size=64, download=True):
    """Instantiate one split of a MedMNIST dataset with resize + tensor transforms.

    Args:
        key: dataset key looked up in ``INFO`` (e.g. 'pneumoniamnist').
        split: split name forwarded to the dataset class ('train'/'val'/'test').
        as_rgb: if True, single-channel tensors are repeated to 3 channels.
        size: images are resized to (size, size) before ``ToTensor``.
        download: forwarded to the dataset constructor.
    """
    dataset_cls = getattr(medmnist, INFO[key]['python_class'])
    steps = [T.Resize((size, size)), T.ToTensor()]
    if as_rgb:
        # Replicate a lone grey channel so 3-channel (ImageNet-style) models accept it.
        steps.append(T.Lambda(lambda img: img if img.shape[0] != 1 else img.repeat(3, 1, 1)))
    return dataset_cls(split=split, transform=T.Compose(steps), download=download)

def get_loaders(key, batch_size=128, num_workers=2, label_frac=1.0, seed=42):
    """Build train/val/test DataLoaders for a MedMNIST dataset.

    Args:
        key: MedMNIST dataset key (an entry of ``INFO``), e.g. 'pneumoniamnist'.
        batch_size: batch size used by all three loaders.
        num_workers: number of DataLoader worker processes.
        label_frac: if strictly between 0 and 1, keep this fraction of each
            training class (at least one sample per class), drawn with NumPy's
            global RNG after seeding. Any other value keeps the full train set.
        seed: passed to ``set_seed`` before subsampling.

    Returns:
        (train_loader, val_loader, test_loader, n_classes). Only the train
        loader shuffles.
    """
    set_seed(seed)
    ds_train = get_medmnist_dataset(key, 'train')
    ds_val   = get_medmnist_dataset(key, 'val')
    ds_test  = get_medmnist_dataset(key, 'test')
    n_classes = len(INFO[key]['label'])
    if 0 < label_frac < 1.0:
        # Prefer the dataset's raw label array. Iterating the dataset would run
        # every image transform (resize, ToTensor) just to read the labels.
        labels = getattr(ds_train, 'labels', None)
        if labels is None:
            labels = [np.asarray(t[1]).item() for t in ds_train]
        y = np.asarray(labels).reshape(-1).astype(int)
        idxs = []
        for c in np.unique(y):
            cls_idx = np.where(y == c)[0]
            k = max(1, int(len(cls_idx) * label_frac))
            idxs.extend(np.random.choice(cls_idx, size=k, replace=False))
        ds_train = Subset(ds_train, sorted(idxs))
    # Pinned host memory only speeds up copies to a GPU; skip it on CPU-only runs.
    pin = torch.cuda.is_available()
    train_loader = DataLoader(ds_train, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=pin)
    val_loader   = DataLoader(ds_val, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=pin)
    test_loader  = DataLoader(ds_test, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=pin)
    return train_loader, val_loader, test_loader, n_classes

class SmallCNN(nn.Module):
    """Compact 3-stage CNN: conv(3x3)+ReLU blocks with widths 32/64/128, then a linear head.

    The first two stages downsample with 2x2 max-pooling. The last stage
    global-average-pools to a 128-d feature vector, so any input spatial size
    large enough for two 2x poolings works.
    """

    def __init__(self, in_ch=3, n_classes=2):
        super().__init__()
        widths = (in_ch, 32, 64, 128)
        n_stages = len(widths) - 1
        layers = []
        for stage, (c_in, c_out) in enumerate(zip(widths[:-1], widths[1:])):
            pool = nn.AdaptiveAvgPool2d(1) if stage == n_stages - 1 else nn.MaxPool2d(2)
            layers += [nn.Conv2d(c_in, c_out, 3, padding=1), nn.ReLU(), pool]
        self.net = nn.Sequential(*layers)
        self.fc = nn.Linear(widths[-1], n_classes)

    def forward(self, x):
        # (B, 128, 1, 1) pooled features -> (B, 128) -> class logits.
        return self.fc(torch.flatten(self.net(x), 1))

def make_resnet18(n_classes=2, in_ch=3, pretrained=True):
    """Build a torchvision ResNet-18 with a fresh ``n_classes``-way linear head.

    Args:
        n_classes: output size of the replacement ``fc`` layer.
        in_ch: input channels. For any value other than 3, ``conv1`` is
            replaced by a new 7x7/stride-2 conv. With exactly 1 channel, its
            weights are the original RGB filters summed over the channel axis.
        pretrained: load ``ResNet18_Weights.IMAGENET1K_V1`` when True.
    """
    weights = tvm.ResNet18_Weights.IMAGENET1K_V1 if pretrained else None
    net = tvm.resnet18(weights=weights)
    if in_ch != 3:
        rgb_filters = net.conv1.weight
        net.conv1 = torch.nn.Conv2d(in_ch, 64, kernel_size=7, stride=2, padding=3, bias=False)
        if in_ch == 1:
            # Collapse the 3 input-channel filters into one by summing them.
            with torch.no_grad():
                net.conv1.weight.copy_(rgb_filters.sum(dim=1, keepdim=True))
    feature_dim = net.fc.in_features
    net.fc = torch.nn.Linear(feature_dim, n_classes)
    return net

# Printed only if every import and helper definition above ran without raising.
print('✅ Setup ready')

In [ ]:
# ---- Config (choose ONE binary dataset) ----
set_seed(42)
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print('Device:', device)

# Data / model choices
DATASET_KEY = 'pneumoniamnist'   # or 'breastmnist'
MODEL_NAME = 'resnet18'          # 'smallcnn' if slow

# Optimisation
EPOCHS, BATCH_SIZE = 5, 128
LR, WEIGHT_DECAY = 3e-4, 1e-4
FINETUNE = 'head'                # 'head' or 'all' (for resnet18)

# Calibration: number of confidence bins for the reliability diagram / ECE
ECE_BINS = 15

In [ ]:
# ---- Train & evaluate once ----
train_loader, val_loader, test_loader, n_classes = get_loaders(DATASET_KEY, batch_size=BATCH_SIZE, label_frac=1.0, seed=42)

if MODEL_NAME == 'smallcnn':
    model = SmallCNN(in_ch=3, n_classes=n_classes)
    params = model.parameters()
else:
    model = make_resnet18(n_classes=n_classes, in_ch=3, pretrained=True)
    if FINETUNE == 'head':
        # Linear probe: freeze the backbone and optimise only the new classifier head.
        # NOTE(review): model.train() below still updates BatchNorm running stats in the
        # frozen backbone; put BN layers in eval mode if a strict linear probe is wanted.
        for p in model.parameters(): p.requires_grad = False
        for p in model.fc.parameters(): p.requires_grad = True
        params = model.fc.parameters()
    else:
        params = model.parameters()

model.to(device)
opt = torch.optim.AdamW(params, lr=LR, weight_decay=WEIGHT_DECAY)
sch = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=EPOCHS)
criterion = torch.nn.CrossEntropyLoss()


def _predict_proba(loader):
    """Run `model` in eval mode over `loader` and return (labels, softmax probabilities).

    MedMNIST yields targets shaped (B, 1); they are flattened so the returned
    labels are (N,) and align row-for-row with the (N, n_classes) probabilities.
    """
    model.eval()
    labels, probs = [], []
    with torch.no_grad():
        for x, y in loader:
            logits = model(x.to(device))
            probs.append(torch.softmax(logits, dim=1).cpu().numpy())
            labels.append(y.numpy().reshape(-1))
    return np.concatenate(labels, 0), np.concatenate(probs, 0)


best_score = -1.0; best_state = None
for epoch in range(1, EPOCHS+1):
    # train
    model.train()
    loss_sum = 0.0; n_sum = 0
    for x, y in train_loader:
        # CrossEntropyLoss needs 1-D class-index targets; MedMNIST batches are (B, 1).
        x = x.to(device); y = y.long().view(-1).to(device)
        logits = model(x)
        loss = criterion(logits, y)
        opt.zero_grad(); loss.backward(); opt.step()
        loss_sum += loss.item() * x.size(0); n_sum += x.size(0)
    # val: rank epochs by AUROC, falling back to accuracy when AUROC is undefined (NaN)
    yv, pv = _predict_proba(val_loader)
    val_metrics = compute_metrics(yv, pv, n_classes)
    acc, auroc = val_metrics["acc"], val_metrics["auroc"]
    score = acc if np.isnan(auroc) else auroc
    if score > best_score:
        best_score = score
        # Clone to CPU: state_dict() returns live references that later steps would overwrite.
        best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
    print(f"[{epoch:02d}] loss={loss_sum/max(1,n_sum):.4f} | val acc={acc:.4f} auroc={auroc:.4f}")
    sch.step()

# test: evaluate the checkpoint with the best validation score
if best_state is not None:
    model.load_state_dict(best_state, strict=True)
yt, pt = _predict_proba(test_loader)
test_metrics = compute_metrics(yt, pt, n_classes)
acc, auroc = test_metrics["acc"], test_metrics["auroc"]
print(f"TEST | acc={acc:.4f} auroc={auroc:.4f}")

from math import isnan
def reliability_diagram_inline(y_true, y_prob, n_bins=10):
    """Plot a reliability diagram for (y_true, y_prob) and return its ECE.

    Kept for backward compatibility. It delegates the binning and plotting to
    ``reliability_diagram`` instead of duplicating it. Labels are flattened
    first, so MedMNIST's (N, 1) label arrays are compared element-wise with
    the (N,) argmax predictions instead of broadcasting to an (N, N) matrix.
    """
    return reliability_diagram(np.asarray(y_true).reshape(-1), y_prob, n_bins=n_bins)


# Use the configured bin count rather than a hard-coded value.
ece = reliability_diagram_inline(yt, pt, n_bins=ECE_BINS)
print('ECE =', ece)