In [1]:
# ============================================================
# Patch-Level Attention MIL Network (PLA-MIL) for HAM10000
# - Uses lesion crops (224x224)
# - Splits into 4x4 patches (16 patches)
# - Encodes patches with ResNet-18 backbone
# - Attention pooling over patches -> classification
# ============================================================

import os
from pathlib import Path
import random
import numpy as np
import pandas as pd
from PIL import Image
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler

from torchvision import models, transforms
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, classification_report, roc_curve, auc

# ----------------- CONFIG -----------------
DATA_ROOT = Path("/kaggle/input/ham1000-segmentation-and-classification")
CROP_IMG_DIR = Path("/kaggle/input/ham10000-segment-data/ham_lesion_crops")  # lesion-centered crops/kaggle/input/ham10000-segment-data/ham_lesion_crops
CSV_PATH = DATA_ROOT / "GroundTruth.csv"

OUTPUT_DIR = Path("/kaggle/working/ham_plamil_outputs")
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

IMG_SIZE = 224               # input size for lesion crop
PATCH_GRID = 4               # 4x4 -> 16 patches
PATCH_SIZE = IMG_SIZE // PATCH_GRID  # 56
BATCH_SIZE = 16
NUM_EPOCHS = 100
LR = 1e-4
VAL_SPLIT = 0.2
SEED = 42
NUM_WORKERS = 2
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

CLASS_NAMES = ["MEL", "NV", "BCC", "AKIEC", "BKL", "DF", "VASC"]
NUM_CLASSES = len(CLASS_NAMES)

print("Using device:", DEVICE)

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if DEVICE == "cuda":
    torch.cuda.manual_seed_all(SEED)

# ----------------- LOAD LABELS -----------------
df = pd.read_csv(CSV_PATH)
label_array = df[CLASS_NAMES].values
df["label_idx"] = label_array.argmax(axis=1)

print(df.head())

# ----------------- TRAIN / VAL SPLIT -----------------
train_df, val_df = train_test_split(
    df,
    test_size=VAL_SPLIT,
    random_state=SEED,
    stratify=df["label_idx"]
)

print(f"Total samples: {len(df)}")
print(f"Train: {len(train_df)}, Val: {len(val_df)}")

# ----------------- DATASET -----------------

class PatchMILDataset(Dataset):
    """
    Dataset that:
      - loads lesion crop (224x224)
      - applies transforms
      - splits into 4x4 patches of size 56x56
    Returns:
      patches: (num_patches, 3, PATCH_SIZE, PATCH_SIZE)
      label_idx: int
    """
    def __init__(self, dataframe, img_dir, img_size=224, patch_grid=4, transform=None):
        self.df = dataframe.reset_index(drop=True)
        self.img_dir = img_dir
        self.img_size = img_size
        self.patch_grid = patch_grid
        self.patch_size = img_size // patch_grid
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        img_id = row["image"]
        label = int(row["label_idx"])

        img_path = self.img_dir / f"{img_id}.jpg"
        img = Image.open(img_path).convert("RGB")

        if self.transform:
            img = self.transform(img)  # (3,H,W) with H=W=img_size

        # split into patches using unfold
        # img: (C,H,W)
        C, H, W = img.shape
        assert H == self.img_size and W == self.img_size, "Image not resized correctly."

        patches = img.unfold(1, self.patch_size, self.patch_size) \
                    .unfold(2, self.patch_size, self.patch_size)
        # patches: (C, grid_y, grid_x, patch_size, patch_size)
        patches = patches.permute(1, 2, 0, 3, 4)  # (grid_y, grid_x, C, patch_size, patch_size)
        patches = patches.reshape(-1, C, self.patch_size, self.patch_size)  # (num_patches, C, H_p, W_p)

        return patches, label

# ----------------- TRANSFORMS -----------------
train_transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomRotation(20),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    ),
])

val_transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    ),
])

train_dataset = PatchMILDataset(train_df, CROP_IMG_DIR, IMG_SIZE, PATCH_GRID, transform=train_transform)
val_dataset   = PatchMILDataset(val_df,   CROP_IMG_DIR, IMG_SIZE, PATCH_GRID, transform=val_transform)

# ----------------- CLASS IMBALANCE HANDLING -----------------
class_counts = train_df["label_idx"].value_counts().sort_index().values
print("Class counts (train):", dict(zip(range(NUM_CLASSES), class_counts)))

class_weights = 1.0 / class_counts
sample_weights = class_weights[train_df["label_idx"].values]
sample_weights = torch.from_numpy(sample_weights).float()

sampler = WeightedRandomSampler(
    weights=sample_weights,
    num_samples=len(sample_weights),
    replacement=True
)

def collate_patches(batch):
    """
    batch: list of (patches_i, label_i)
    patches_i: (num_patches, C, H, W)
    returns:
      patches: (B, num_patches, C, H, W)
      labels: (B,)
    """
    patches_list, labels_list = zip(*batch)
    patches = torch.stack(patches_list, dim=0)   # (B, num_patches, C, H, W)
    labels = torch.tensor(labels_list, dtype=torch.long)
    return patches, labels

train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    sampler=sampler,
    num_workers=NUM_WORKERS,
    pin_memory=True,
    collate_fn=collate_patches,
)

val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=True,
    collate_fn=collate_patches,
)

# ----------------- MODEL: Patch-Level Attention MIL -----------------

class PatchEncoder(nn.Module):
    """
    CNN encoder for patches using ResNet-18 backbone (without final FC).
    """
    def __init__(self, out_dim=256):
        super().__init__()
        base = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
        # remove the final fc; keep everything up to global pooling
        self.features = nn.Sequential(*list(base.children())[:-1])  # -> (B,512,1,1)
        self.proj = nn.Linear(512, out_dim)

    def forward(self, x):
        # x: (B_patches, 3, H, W)
        feat = self.features(x)  # (B_patches, 512, 1, 1)
        feat = feat.view(feat.size(0), -1)  # (B_patches, 512)
        feat = self.proj(feat)             # (B_patches, out_dim)
        return feat


class AttentionMIL(nn.Module):
    """
    Attention-based MIL pooling.
    Given patch features H: (B, N, D)
    Computes attention weights and bag representation.
    """
    def __init__(self, dim, hidden_dim=128):
        super().__init__()
        self.attention_a = nn.Linear(dim, hidden_dim)
        self.attention_b = nn.Linear(hidden_dim, 1)

    def forward(self, H):
        # H: (B, N, D)
        A = torch.tanh(self.attention_a(H))      # (B,N,hidden)
        A = self.attention_b(A).squeeze(-1)      # (B,N)
        A = torch.softmax(A, dim=1)              # (B,N)
        # Weighted sum
        bag_repr = torch.sum(H * A.unsqueeze(-1), dim=1)  # (B,D)
        return bag_repr, A


class PatchMILLesionClassifier(nn.Module):
    """
    PLA-MIL:
      - encodes patches with CNN
      - attention-pooling to get image-level representation
      - classification head
    """
    def __init__(self, num_classes=7, patch_dim=256):
        super().__init__()
        self.encoder = PatchEncoder(out_dim=patch_dim)
        self.mil_pool = AttentionMIL(dim=patch_dim, hidden_dim=128)
        self.classifier = nn.Sequential(
            nn.Linear(patch_dim, 256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(256, num_classes)
        )

    def forward(self, patches):
        """
        patches: (B, N, C, H, W)
        """
        B, N, C, H, W = patches.shape
        patches = patches.view(B * N, C, H, W)   # (B*N, C, H, W)
        patch_feats = self.encoder(patches)      # (B*N, D)
        D = patch_feats.shape[1]
        patch_feats = patch_feats.view(B, N, D)  # (B, N, D)

        bag_repr, att_weights = self.mil_pool(patch_feats)  # (B,D), (B,N)
        logits = self.classifier(bag_repr)                  # (B,num_classes)
        return logits, att_weights

model = PatchMILLesionClassifier(num_classes=NUM_CLASSES, patch_dim=256).to(DEVICE)

class_weights_tensor = torch.tensor(class_weights, dtype=torch.float32).to(DEVICE)
criterion = nn.CrossEntropyLoss(weight=class_weights_tensor)
optimizer = torch.optim.Adam(model.parameters(), lr=LR)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='max', factor=0.5, patience=3, verbose=True
)

print("Model params (M):", sum(p.numel() for p in model.parameters()) / 1e6)

# ----------------- TRAIN & EVAL -----------------

def train_one_epoch(model, loader, optimizer, criterion):
    model.train()
    running_loss = 0.0
    running_corrects = 0
    total = 0

    for patches, labels in loader:
        patches = patches.to(DEVICE)   # (B,N,C,H,W)
        labels = labels.to(DEVICE)

        optimizer.zero_grad()
        outputs, att = model(patches)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        _, preds = torch.max(outputs, 1)
        running_loss += loss.item() * labels.size(0)
        running_corrects += (preds == labels).sum().item()
        total += labels.size(0)

    return running_loss / total, running_corrects / total


def evaluate(model, loader, criterion, collect_probs=False):
    model.eval()
    running_loss = 0.0
    running_corrects = 0
    total = 0

    all_labels = []
    all_preds = []
    all_probs = []

    with torch.no_grad():
        for patches, labels in loader:
            patches = patches.to(DEVICE)
            labels = labels.to(DEVICE)

            outputs, att = model(patches)
            loss = criterion(outputs, labels)

            probs = F.softmax(outputs, dim=1)
            _, preds = torch.max(outputs, 1)

            running_loss += loss.item() * labels.size(0)
            running_corrects += (preds == labels).sum().item()
            total += labels.size(0)

            all_labels.append(labels.cpu().numpy())
            all_preds.append(preds.cpu().numpy())
            if collect_probs:
                all_probs.append(probs.cpu().numpy())

    epoch_loss = running_loss / total
    epoch_acc = running_corrects / total

    all_labels = np.concatenate(all_labels)
    all_preds = np.concatenate(all_preds)
    if collect_probs:
        all_probs = np.concatenate(all_probs)
    else:
        all_probs = None

    return epoch_loss, epoch_acc, all_labels, all_preds, all_probs

# ----------------- TRAINING LOOP -----------------

history = []
best_val_acc = 0.0
best_model_path = OUTPUT_DIR / "best_plamil_model.pth"

for epoch in range(1, NUM_EPOCHS + 1):
    train_loss, train_acc = train_one_epoch(model, train_loader, optimizer, criterion)
    val_loss, val_acc, _, _, _ = evaluate(model, val_loader, criterion, collect_probs=False)

    scheduler.step(val_acc)

    history.append({
        "epoch": epoch,
        "train_loss": train_loss,
        "train_acc": train_acc,
        "val_loss": val_loss,
        "val_acc": val_acc,
    })

    print(
        f"Epoch {epoch:02d}/{NUM_EPOCHS} - "
        f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f} | "
        f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}"
    )

    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), best_model_path)
        print(f"  -> New best model saved (val_acc={best_val_acc:.4f})")

# save history
history_df = pd.DataFrame(history)
history_csv_path = OUTPUT_DIR / "training_history.csv"
history_df.to_csv(history_csv_path, index=False)
print("Saved training history to:", history_csv_path)

# ----------------- PLOTS -----------------

plt.figure()
plt.plot(history_df["epoch"], history_df["train_loss"], label="Train Loss")
plt.plot(history_df["epoch"], history_df["val_loss"], label="Val Loss")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.title("PLA-MIL Loss Curves")
plt.legend()
loss_fig_path = OUTPUT_DIR / "loss_curves.png"
plt.savefig(loss_fig_path, dpi=150, bbox_inches="tight")
plt.close()
print("Saved loss curves to:", loss_fig_path)

plt.figure()
plt.plot(history_df["epoch"], history_df["train_acc"], label="Train Acc")
plt.plot(history_df["epoch"], history_df["val_acc"], label="Val Acc")
plt.xlabel("Epoch")
plt.ylabel("Accuracy")
plt.title("PLA-MIL Accuracy Curves")
plt.legend()
acc_fig_path = OUTPUT_DIR / "acc_curves.png"
plt.savefig(acc_fig_path, dpi=150, bbox_inches="tight")
plt.close()
print("Saved accuracy curves to:", acc_fig_path)

# ----------------- FINAL EVAL: CONFUSION MATRIX / REPORT / ROC -----------------

model.load_state_dict(torch.load(best_model_path, map_location=DEVICE))
val_loss, val_acc, y_true, y_pred, y_prob = evaluate(
    model, val_loader, criterion, collect_probs=True
)

print(f"Final Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")

cm = confusion_matrix(y_true, y_pred)
print("Confusion matrix:\n", cm)

cm_fig_path = OUTPUT_DIR / "confusion_matrix.png"
plt.figure(figsize=(6, 5))
plt.imshow(cm, interpolation="nearest")
plt.title("PLA-MIL Confusion Matrix")
plt.colorbar()
tick_marks = np.arange(NUM_CLASSES)
plt.xticks(tick_marks, CLASS_NAMES, rotation=45)
plt.yticks(tick_marks, CLASS_NAMES)
thresh = cm.max() / 2.0
for i in range(cm.shape[0]):
    for j in range(cm.shape[1]):
        plt.text(
            j, i, format(cm[i, j], "d"),
            horizontalalignment="center",
            color="white" if cm[i, j] > thresh else "black",
        )
plt.ylabel("True label")
plt.xlabel("Predicted label")
plt.tight_layout()
plt.savefig(cm_fig_path, dpi=150)
plt.close()
print("Saved confusion_matrix to:", cm_fig_path)

report = classification_report(
    y_true, y_pred,
    target_names=CLASS_NAMES, digits=4
)
print("Classification report:\n", report)

report_path = OUTPUT_DIR / "classification_report.txt"
with open(report_path, "w") as f:
    f.write(report)
print("Saved classification_report to:", report_path)

if y_prob is not None:
    roc_fig_path = OUTPUT_DIR / "roc_curves.png"
    plt.figure(figsize=(8, 6))
    for i, cls_name in enumerate(CLASS_NAMES):
        y_true_bin = (y_true == i).astype(int)
        fpr, tpr, _ = roc_curve(y_true_bin, y_prob[:, i])
        roc_auc = auc(fpr, tpr)
        plt.plot(fpr, tpr, label=f"{cls_name} (AUC = {roc_auc:.3f})")

    plt.plot([0, 1], [0, 1], "k--", label="Random")
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.title("PLA-MIL ROC Curves (One-vs-Rest)")
    plt.legend(loc="lower right")
    plt.tight_layout()
    plt.savefig(roc_fig_path, dpi=150)
    plt.close()
    print("Saved ROC curves to:", roc_fig_path)

print("All done (PLA-MIL classification).")

Using device: cuda
          image  MEL   NV  BCC  AKIEC  BKL   DF  VASC  label_idx
0  ISIC_0024306  0.0  1.0  0.0    0.0  0.0  0.0   0.0          1
1  ISIC_0024307  0.0  1.0  0.0    0.0  0.0  0.0   0.0          1
2  ISIC_0024308  0.0  1.0  0.0    0.0  0.0  0.0   0.0          1
3  ISIC_0024309  0.0  1.0  0.0    0.0  0.0  0.0   0.0          1
4  ISIC_0024310  1.0  0.0  0.0    0.0  0.0  0.0   0.0          0
Total samples: 10015
Train: 8012, Val: 2003
Class counts (train): {0: 890, 1: 5364, 2: 411, 3: 262, 4: 879, 5: 92, 6: 114}


Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 211MB/s]


Model params (M): 11.408456




Epoch 01/100 - Train Loss: 0.7718, Train Acc: 0.4466 | Val Loss: 1.6730, Val Acc: 0.1213
  -> New best model saved (val_acc=0.1213)
Epoch 02/100 - Train Loss: 0.4791, Train Acc: 0.5724 | Val Loss: 1.3487, Val Acc: 0.2077
  -> New best model saved (val_acc=0.2077)
Epoch 03/100 - Train Loss: 0.4061, Train Acc: 0.6228 | Val Loss: 1.6592, Val Acc: 0.2147
  -> New best model saved (val_acc=0.2147)
Epoch 04/100 - Train Loss: 0.3696, Train Acc: 0.6460 | Val Loss: 1.1781, Val Acc: 0.4338
  -> New best model saved (val_acc=0.4338)
Epoch 05/100 - Train Loss: 0.3178, Train Acc: 0.6739 | Val Loss: 1.0854, Val Acc: 0.4333
Epoch 06/100 - Train Loss: 0.2911, Train Acc: 0.7061 | Val Loss: 1.1079, Val Acc: 0.4738
  -> New best model saved (val_acc=0.4738)
Epoch 07/100 - Train Loss: 0.2685, Train Acc: 0.7263 | Val Loss: 0.9794, Val Acc: 0.4788
  -> New best model saved (val_acc=0.4788)
Epoch 08/100 - Train Loss: 0.2404, Train Acc: 0.7410 | Val Loss: 1.1556, Val Acc: 0.4224
Epoch 09/100 - Train Loss: 0.2