In [1]:
# ============================================================
# TransUNet-AEO style segmentation on HAM10000
# Dataset structure:
#   /kaggle/input/ham1000-segmentation-and-classification/
#       images/
#       masks/
#       GroundTruth.csv   (not needed for segmentation)
# ============================================================

import os
import random
from pathlib import Path

import numpy as np
import pandas as pd
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, classification_report
import matplotlib.pyplot as plt

# If albumentations is not installed, uncomment:
# !pip install -q albumentations==1.4.3

import albumentations as A
from albumentations.pytorch import ToTensorV2

# ----------------- CONFIG -----------------
# Dataset root; images and masks are read from its "images" and "masks"
# sub-directories.
DATA_ROOT = Path("/kaggle/input/ham1000-segmentation-and-classification")
IMAGES_DIR = DATA_ROOT / "images"
MASKS_DIR = DATA_ROOT / "masks"

# Directory for everything this script writes (checkpoint, history CSV,
# plots, classification report). Created up front so later saves cannot
# fail on a missing folder.
OUTPUT_DIR = Path("/kaggle/working/ham_transunet_outputs")
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# Side length passed to A.Resize and to the model's img_size argument.
IMG_SIZE = 256
# Batch size used by both the training and the validation DataLoader.
BATCH_SIZE = 8
# Number of passes of the training loop.
NUM_EPOCHS = 50
# Learning rate handed to the Adam optimizer.
LR = 1e-4
# Fraction of image ids held out for validation by train_test_split.
VAL_SPLIT = 0.2
# Seed shared by the RNG seeding below and by train_test_split.
SEED = 42
# Worker count for both DataLoaders.
NUM_WORKERS = 2
# Plain string ("cuda" or "cpu"); compared as a string further down.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

print("Using device:", DEVICE)

# Seed Python, NumPy and PyTorch RNGs so the split and weight init repeat
# across runs.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if DEVICE == "cuda":
    torch.cuda.manual_seed_all(SEED)

# ----------------- DATASET -----------------

class HAMSegmentationDataset(Dataset):
    """
    Image/mask pairs for binary lesion segmentation.

    Expected file layout:
        images: <img_dir>/<id>.jpg
        masks : <mask_dir>/<id>_segmentation.png

    Args:
        image_ids: Sequence of image ids (file stems without extension).
        img_dir: Directory holding the images (``str`` or ``Path``).
        mask_dir: Directory holding the masks (``str`` or ``Path``).
        transforms: Optional callable invoked as
            ``transforms(image=..., mask=...)`` that returns a mapping with
            ``"image"`` and ``"mask"`` entries. The ``"mask"`` entry must
            support ``unsqueeze`` (i.e. be a tensor); a falsy value selects
            the plain tensor-conversion fallback.

    Each item is an ``(image, mask)`` tuple. Mask values are 0.0 / 1.0 and
    the mask carries a leading channel dimension.
    """
    def __init__(self, image_ids, img_dir, mask_dir, transforms=None):
        self.image_ids = image_ids
        # Coerce to Path so plain string directories work with the "/" joins
        # in __getitem__.
        self.img_dir = Path(img_dir)
        self.mask_dir = Path(mask_dir)
        self.transforms = transforms

    def __len__(self):
        return len(self.image_ids)

    def __getitem__(self, idx):
        img_id = self.image_ids[idx]

        img_path = self.img_dir / f"{img_id}.jpg"
        mask_path = self.mask_dir / f"{img_id}_segmentation.png"

        # Context managers close the underlying files deterministically
        # instead of relying on garbage collection inside worker processes.
        with Image.open(img_path) as img_file:
            image = np.array(img_file.convert("RGB"))
        with Image.open(mask_path) as mask_file:
            mask = np.array(mask_file.convert("L"))

        # Any non-zero grey level counts as lesion -> binary 0/1 float mask.
        mask = (mask > 0).astype("float32")

        if self.transforms:
            # The image is handed over as an HWC array, the mask as HW.
            augmented = self.transforms(image=image, mask=mask)
            image = augmented["image"]
            mask = augmented["mask"].unsqueeze(0)  # add channel dim
        else:
            # Fallback: CxHxW float image scaled to [0, 1], 1xHxW mask.
            image = torch.from_numpy(image).permute(2, 0, 1).float() / 255.0
            mask = torch.from_numpy(mask).unsqueeze(0).float()

        return image, mask


# --------------- TRANSFORMS ---------------

def _resize_step():
    """Return a fresh resize to IMG_SIZE x IMG_SIZE."""
    return A.Resize(IMG_SIZE, IMG_SIZE)


def _normalize_and_tensorize():
    """Return fresh per-channel normalisation and tensor-conversion steps."""
    return [
        A.Normalize(mean=(0.485, 0.456, 0.406),
                    std=(0.229, 0.224, 0.225)),
        ToTensorV2(),
    ]


# Random augmentations applied to training samples only.
_train_only_steps = [
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    A.RandomRotate90(p=0.5),
    A.ShiftScaleRotate(
        shift_limit=0.05, scale_limit=0.1, rotate_limit=15,
        border_mode=0, p=0.5,
    ),
    A.GaussNoise(p=0.2),
]

# Training: resize -> random augmentations -> normalise -> tensor.
train_transforms = A.Compose(
    [_resize_step(), *_train_only_steps, *_normalize_and_tensorize()]
)

# Validation: deterministic resize -> normalise -> tensor.
val_transforms = A.Compose(
    [_resize_step(), *_normalize_and_tensorize()]
)

# --------------- COLLECT IMAGE IDS ---------------

# Image filenames look like ISIC_0024306.jpg; the stem is the sample id.
all_img_files = sorted(
    f for f in IMAGES_DIR.iterdir() if f.suffix.lower() == ".jpg"
)

# Keep only images that have a matching mask on disk. Otherwise a missing
# mask would surface as a FileNotFoundError inside a DataLoader worker in
# the middle of an epoch.
_paired_img_files = [
    f for f in all_img_files
    if (MASKS_DIR / f"{f.stem}_segmentation.png").is_file()
]
_n_unpaired = len(all_img_files) - len(_paired_img_files)
if _n_unpaired:
    print(f"Warning: skipping {_n_unpaired} image(s) without a matching mask")
all_img_files = _paired_img_files
all_ids = [f.stem for f in all_img_files]

# Seeded split so train/val membership is the same on every run.
train_ids, val_ids = train_test_split(
    all_ids, test_size=VAL_SPLIT, random_state=SEED
)

print(f"Total images: {len(all_ids)}")
print(f"Train: {len(train_ids)}, Val: {len(val_ids)}")

train_dataset = HAMSegmentationDataset(train_ids, IMAGES_DIR, MASKS_DIR, train_transforms)
val_dataset   = HAMSegmentationDataset(val_ids,   IMAGES_DIR, MASKS_DIR, val_transforms)

# Pinned host memory only helps host-to-GPU copies; requesting it on a
# CPU-only run just triggers a warning.
_pin_memory = DEVICE == "cuda"

train_loader = DataLoader(
    train_dataset, batch_size=BATCH_SIZE, shuffle=True,
    num_workers=NUM_WORKERS, pin_memory=_pin_memory
)
val_loader = DataLoader(
    val_dataset, batch_size=BATCH_SIZE, shuffle=False,
    num_workers=NUM_WORKERS, pin_memory=_pin_memory
)

# --------------- MODEL (TransUNet-like) ---------------

class ConvBlock(nn.Module):
    """Two consecutive 3x3 conv -> batch-norm -> ReLU stages.

    The first stage maps ``in_ch`` to ``out_ch`` channels and the second keeps
    ``out_ch``. Padding of 1 leaves the spatial size unchanged.
    """
    def __init__(self, in_ch, out_ch):
        super().__init__()
        stages = []
        # Layers are created in the same order as they run, so the indices
        # inside the Sequential stay 0..5.
        for stage_in in (in_ch, out_ch):
            stages.extend([
                nn.Conv2d(stage_in, out_ch, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ])
        self.block = nn.Sequential(*stages)

    def forward(self, x):
        features = self.block(x)
        return features


class UpBlock(nn.Module):
    """Decoder stage: 2x transposed-conv upsample, skip concat, ConvBlock.

    The transposed convolution maps ``in_ch`` to ``out_ch`` channels. The
    following ConvBlock expects ``in_ch`` channels, so the skip tensor must
    contribute ``in_ch - out_ch`` channels.
    """
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_ch, out_ch, kernel_size=2, stride=2)
        self.conv = ConvBlock(in_ch, out_ch)

    def forward(self, x, skip):
        upsampled = self.up(x)

        # Align spatial size with the skip tensor (sizes can differ by a pixel
        # when an encoder dimension was odd before pooling).
        gap_h = skip.size(-2) - upsampled.size(-2)
        gap_w = skip.size(-1) - upsampled.size(-1)
        if gap_h != 0 or gap_w != 0:
            left = gap_w // 2
            top = gap_h // 2
            upsampled = F.pad(
                upsampled, [left, gap_w - left, top, gap_h - top]
            )

        merged = torch.cat([skip, upsampled], dim=1)
        return self.conv(merged)


class TransUNetAEO(nn.Module):
    """
    U-Net style encoder/decoder with a transformer encoder at the bottleneck.

    'AEO' (optimizer) is not embedded here; you would tune hyperparams outside.

    Args:
        img_size: Only used to record ``num_tokens = (img_size // 16) ** 2``;
            the forward pass reads spatial sizes from the tensor itself.
        in_ch: Channels of the input image.
        num_classes: Channels of the output logit map.
        base_ch: Width of the first encoder stage; later stages use
            2x, 4x, 8x and the bottleneck 16x this value.
        num_heads: Attention heads per transformer layer.
        transformer_layers: Number of stacked transformer encoder layers.
        dropout: Dropout rate inside the transformer layers.
    """
    def __init__(self,
                 img_size=256,
                 in_ch=3,
                 num_classes=1,
                 base_ch=64,
                 num_heads=4,
                 transformer_layers=4,
                 dropout=0.1):
        super().__init__()

        # Channel widths: four encoder stages followed by the bottleneck.
        c1, c2, c3, c4, c_mid = (base_ch * mult for mult in (1, 2, 4, 8, 16))

        # Encoder: each stage is a ConvBlock followed by 2x max pooling.
        self.enc1 = ConvBlock(in_ch, c1)
        self.pool1 = nn.MaxPool2d(2)

        self.enc2 = ConvBlock(c1, c2)
        self.pool2 = nn.MaxPool2d(2)

        self.enc3 = ConvBlock(c2, c3)
        self.pool3 = nn.MaxPool2d(2)

        self.enc4 = ConvBlock(c3, c4)
        self.pool4 = nn.MaxPool2d(2)

        self.bottleneck_conv = ConvBlock(c4, c_mid)

        # Four 2x poolings shrink each side by a factor of 16.
        side = img_size // 16
        self.num_tokens = side * side

        # Transformer over bottleneck positions; inputs are laid out as
        # (sequence, batch, embedding) because batch_first is False.
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=c_mid,
                nhead=num_heads,
                dim_feedforward=c_mid * 4,
                dropout=dropout,
                batch_first=False,
            ),
            num_layers=transformer_layers
        )

        # Decoder mirrors the encoder widths.
        self.up4 = UpBlock(c_mid, c4)
        self.up3 = UpBlock(c4, c3)
        self.up2 = UpBlock(c3, c2)
        self.up1 = UpBlock(c2, c1)

        self.final_conv = nn.Conv2d(c1, num_classes, kernel_size=1)

    def forward(self, x):
        """Return raw logits for ``x`` (no sigmoid is applied here)."""
        encoder_stages = (
            (self.enc1, self.pool1),
            (self.enc2, self.pool2),
            (self.enc3, self.pool3),
            (self.enc4, self.pool4),
        )

        # Encoder pass; remember each pre-pool feature map for the skips.
        skips = []
        feat = x
        for encode, pool in encoder_stages:
            stage_out = encode(feat)
            skips.append(stage_out)
            feat = pool(stage_out)

        feat = self.bottleneck_conv(feat)  # B x C x H x W

        # Flatten spatial positions into a token sequence, run the
        # transformer, then restore the feature-map layout.
        batch, channels, height, width = feat.shape
        tokens = feat.flatten(2).permute(2, 0, 1)   # (S, B, C)
        tokens = self.transformer(tokens)
        feat = tokens.permute(1, 2, 0).view(batch, channels, height, width)

        # Decoder pass, consuming skips from deepest to shallowest.
        decoder_stages = (self.up4, self.up3, self.up2, self.up1)
        for decode, skip in zip(decoder_stages, reversed(skips)):
            feat = decode(feat, skip)

        return self.final_conv(feat)  # logits


# --------------- LOSS & METRICS ---------------

class DiceBCELoss(nn.Module):
    """Sum of BCE-with-logits and soft Dice loss.

    The Dice term is computed per sample and channel on sigmoid
    probabilities, reduced over dims 2 and 3, then averaged.

    Args:
        smooth: Constant added to the Dice numerator and denominator.
    """
    def __init__(self, smooth=1.0):
        super().__init__()
        self.bce = nn.BCEWithLogitsLoss()
        self.smooth = smooth

    def forward(self, logits, targets):
        spatial_dims = (2, 3)
        bce_term = self.bce(logits, targets)

        probs = torch.sigmoid(logits)
        overlap = (probs * targets).sum(dim=spatial_dims)
        total = probs.sum(dim=spatial_dims) + targets.sum(dim=spatial_dims)
        dice = (2 * overlap + self.smooth) / (total + self.smooth)

        dice_term = 1 - dice.mean()
        return bce_term + dice_term


def dice_coef(logits, targets, smooth=1.0):
    """Return the mean hard Dice coefficient as a Python float.

    Logits are passed through a sigmoid and thresholded at 0.5. Dice is
    computed per sample and channel over dims 2 and 3 with ``smooth`` added
    to numerator and denominator, then averaged.
    """
    preds = (torch.sigmoid(logits) > 0.5).float()
    spatial_dims = (2, 3)
    overlap = (preds * targets).sum(dim=spatial_dims)
    total = preds.sum(dim=spatial_dims) + targets.sum(dim=spatial_dims)
    per_item = (2 * overlap + smooth) / (total + smooth)
    return per_item.mean().item()


def iou_score(logits, targets, smooth=1.0):
    """Return the mean hard IoU (Jaccard index) as a Python float.

    Logits are passed through a sigmoid and thresholded at 0.5. Intersection
    and union are taken per sample and channel over dims 2 and 3, with
    ``smooth`` added to both before dividing.
    """
    preds = (torch.sigmoid(logits) > 0.5).float()
    spatial_dims = (2, 3)
    overlap = (preds * targets).sum(dim=spatial_dims)
    combined = preds.sum(dim=spatial_dims) + targets.sum(dim=spatial_dims) - overlap
    per_item = (overlap + smooth) / (combined + smooth)
    return per_item.mean().item()


def pixel_accuracy(logits, targets):
    """Return the fraction of elements whose thresholded prediction
    (sigmoid > 0.5) equals the target, as a Python float."""
    preds = (torch.sigmoid(logits) > 0.5).float()
    matches = (preds == targets).float()
    return matches.mean().item()

# --------------- TRAIN / EVAL LOOPS ---------------

def train_one_epoch(model, loader, optimizer, loss_fn):
    """Run one optimisation pass over ``loader`` and return averaged metrics.

    Args:
        model: Network to train; switched to train mode here.
        loader: Iterable yielding ``(imgs, masks)`` tensor batches.
        optimizer: Optimizer stepped once per batch.
        loss_fn: Callable ``loss_fn(logits, masks)`` returning a scalar tensor.

    Returns:
        Dict with keys ``"loss"``, ``"dice"``, ``"iou"``, ``"acc"``. Each is a
        per-sample mean over the epoch: batch values are weighted by batch
        size so a short final batch is not over-represented.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    model.train()
    totals = {"loss": 0.0, "dice": 0.0, "iou": 0.0, "acc": 0.0}
    n_samples = 0

    for imgs, masks in loader:
        imgs = imgs.to(DEVICE)
        masks = masks.to(DEVICE)

        optimizer.zero_grad()
        logits = model(imgs)
        loss = loss_fn(logits, masks)
        loss.backward()
        optimizer.step()

        # Metrics are for monitoring only; detach so they do not extend the
        # autograd graph.
        batch_size = imgs.size(0)
        detached_logits = logits.detach()
        totals["loss"] += loss.item() * batch_size
        totals["dice"] += dice_coef(detached_logits, masks) * batch_size
        totals["iou"] += iou_score(detached_logits, masks) * batch_size
        totals["acc"] += pixel_accuracy(detached_logits, masks) * batch_size
        n_samples += batch_size

    if n_samples == 0:
        raise ValueError("train_one_epoch: loader yielded no samples")

    return {name: total / n_samples for name, total in totals.items()}


def evaluate(model, loader, loss_fn, collect_for_cm=False, max_batches_cm=30):
    """Evaluate ``model`` on ``loader`` without gradient tracking.

    Args:
        model: Network to evaluate; switched to eval mode here and left in
            eval mode afterwards.
        loader: Iterable yielding ``(imgs, masks)`` tensor batches.
        loss_fn: Callable ``loss_fn(logits, masks)`` returning a scalar tensor.
        collect_for_cm: When true, also gather flattened per-pixel labels and
            predictions for a confusion matrix.
        max_batches_cm: Only the first ``max_batches_cm`` batches are gathered,
            which bounds memory use.

    Returns:
        ``(metrics, y_true, y_pred)``. ``metrics`` has keys ``"loss"``,
        ``"dice"``, ``"iou"``, ``"acc"``, each a per-sample mean (batch values
        weighted by batch size). ``y_true`` / ``y_pred`` are 1-D NumPy arrays
        of 0.0 / 1.0 values, or ``None`` when nothing was collected.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    model.eval()
    totals = {"loss": 0.0, "dice": 0.0, "iou": 0.0, "acc": 0.0}
    n_samples = 0

    y_true_all = []
    y_pred_all = []

    with torch.no_grad():
        for batch_idx, (imgs, masks) in enumerate(loader):
            imgs = imgs.to(DEVICE)
            masks = masks.to(DEVICE)

            logits = model(imgs)
            loss = loss_fn(logits, masks)

            batch_size = imgs.size(0)
            totals["loss"] += loss.item() * batch_size
            totals["dice"] += dice_coef(logits, masks) * batch_size
            totals["iou"] += iou_score(logits, masks) * batch_size
            totals["acc"] += pixel_accuracy(logits, masks) * batch_size
            n_samples += batch_size

            if collect_for_cm and batch_idx < max_batches_cm:
                preds = (torch.sigmoid(logits) > 0.5).float()
                # Flatten to one label per pixel.
                y_true_all.append(masks.cpu().numpy().ravel())
                y_pred_all.append(preds.cpu().numpy().ravel())

    if n_samples == 0:
        raise ValueError("evaluate: loader yielded no samples")

    metrics = {name: total / n_samples for name, total in totals.items()}

    if collect_for_cm and y_true_all:
        return metrics, np.concatenate(y_true_all), np.concatenate(y_pred_all)
    return metrics, None, None


# --------------- INITIALIZE ---------------

# Architecture settings gathered in one place, then unpacked into the model.
model_kwargs = {
    "img_size": IMG_SIZE,
    "in_ch": 3,
    "num_classes": 1,
    "base_ch": 64,
    "num_heads": 4,
    "transformer_layers": 4,
    "dropout": 0.1,
}
model = TransUNetAEO(**model_kwargs).to(DEVICE)

loss_fn = DiceBCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LR)

# Report the total parameter count in millions.
n_params = sum(p.numel() for p in model.parameters())
print("Model parameters:", n_params / 1e6, "M")

# --------------- TRAINING LOOP ---------------

history = []

best_val_dice = 0.0
best_model_path = OUTPUT_DIR / "best_transunet_aeo.pth"
history_csv_path = OUTPUT_DIR / "training_history.csv"

for epoch in range(1, NUM_EPOCHS + 1):
    train_metrics = train_one_epoch(model, train_loader, optimizer, loss_fn)
    val_metrics, _, _ = evaluate(model, val_loader, loss_fn, collect_for_cm=False)

    # One flat row per epoch: epoch, train_<metric>..., val_<metric>...
    epoch_info = {"epoch": epoch}
    for split, split_metrics in (("train", train_metrics), ("val", val_metrics)):
        for key in ("loss", "dice", "iou", "acc"):
            epoch_info[f"{split}_{key}"] = split_metrics[key]
    history.append(epoch_info)

    # Write the log after every epoch so an interrupted session still leaves
    # the history of the completed epochs on disk.
    pd.DataFrame(history).to_csv(history_csv_path, index=False)

    print(
        f"Epoch {epoch:02d}/{NUM_EPOCHS} | "
        f"Train Loss: {train_metrics['loss']:.4f} | Val Loss: {val_metrics['loss']:.4f} | "
        f"Val Dice: {val_metrics['dice']:.4f} | Val IoU: {val_metrics['iou']:.4f} | "
        f"Val Acc: {val_metrics['acc']:.4f}"
    )

    # Save best model by Dice.
    improved = val_metrics["dice"] > best_val_dice
    if improved:
        best_val_dice = val_metrics["dice"]
    # The first epoch is always checkpointed so the file loaded for final
    # evaluation comes from this run even if Dice never beats the initial
    # 0.0 (for example when the metrics are NaN).
    if improved or epoch == 1:
        torch.save(model.state_dict(), best_model_path)

# Save training history
history_df = pd.DataFrame(history)
history_df.to_csv(history_csv_path, index=False)
print("Saved training history to:", history_csv_path)

# --------------- PLOTS (Loss & Dice) ---------------

loss_fig_path = OUTPUT_DIR / "loss_curves.png"
dice_fig_path = OUTPUT_DIR / "dice_curves.png"

# One figure per metric: train and validation curves against the epoch.
# Tuple layout: history column suffix, legend suffix, y-axis label, title,
# output path, message printed after saving.
_curve_specs = (
    ("loss", "Loss", "Loss", "Loss Curves",
     loss_fig_path, "Saved loss curves to:"),
    ("dice", "Dice", "Dice Coefficient", "Dice Curves",
     dice_fig_path, "Saved dice curves to:"),
)

for _column, _legend, _ylabel, _title, _fig_path, _message in _curve_specs:
    plt.figure()
    plt.plot(history_df["epoch"], history_df[f"train_{_column}"],
             label=f"Train {_legend}")
    plt.plot(history_df["epoch"], history_df[f"val_{_column}"],
             label=f"Val {_legend}")
    plt.xlabel("Epoch")
    plt.ylabel(_ylabel)
    plt.title(_title)
    plt.legend()
    plt.savefig(_fig_path, dpi=150, bbox_inches="tight")
    plt.close()
    print(_message, _fig_path)

# --------------- FINAL EVAL: CONFUSION MATRIX + CLASSIF REPORT ---------------

# Load the best checkpoint written during training.
model.load_state_dict(torch.load(best_model_path, map_location=DEVICE))

val_metrics, y_true, y_pred = evaluate(
    model, val_loader, loss_fn, collect_for_cm=True, max_batches_cm=50
)

print("Final Val metrics:", val_metrics)

# y_true / y_pred are pixel-wise labels (0/1) taken from the first 50
# validation batches at most (not a random sample).
if y_true is not None:
    class_labels = [0, 1]
    class_names = ["Background", "Lesion"]

    cm = confusion_matrix(y_true, y_pred, labels=class_labels)
    print("Confusion matrix:\n", cm)

    cm_fig_path = OUTPUT_DIR / "confusion_matrix.png"
    plt.figure()
    plt.imshow(cm, interpolation="nearest")
    plt.title("Pixel-wise Confusion Matrix")
    plt.colorbar()
    tick_marks = np.arange(len(class_names))
    plt.xticks(tick_marks, class_names)
    plt.yticks(tick_marks, class_names)
    plt.xlabel("Predicted")
    plt.ylabel("True")

    # Annotate cells; switch text colour on dark cells for legibility.
    thresh = cm.max() / 2.0
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            plt.text(
                j, i, format(cm[i, j], "d"),
                horizontalalignment="center",
                color="white" if cm[i, j] > thresh else "black",
            )

    plt.tight_layout()
    plt.savefig(cm_fig_path, dpi=150)
    plt.close()
    print("Saved confusion matrix plot to:", cm_fig_path)

    # Classification report. Passing the same explicit labels as the
    # confusion matrix keeps the two consistent and avoids a ValueError
    # when only one class occurs in the collected pixels while two
    # target names are supplied.
    clf_report = classification_report(
        y_true, y_pred,
        labels=class_labels,
        target_names=class_names
    )
    print("Classification report:\n", clf_report)

    report_path = OUTPUT_DIR / "classification_report.txt"
    with open(report_path, "w", encoding="utf-8") as f:
        f.write(clf_report)
    print("Saved classification report to:", report_path)

print("All done.")

Using device: cuda
Total images: 10015
Train: 8012, Val: 2003


  original_init(self, **validated_kwargs)


Model parameters: 81.428417 M
Epoch 01/50 | Train Loss: 0.6331 | Val Loss: 0.8105 | Val Dice: 0.7331 | Val IoU: 0.6108 | Val Acc: 0.8304
Epoch 02/50 | Train Loss: 0.3900 | Val Loss: 0.3082 | Val Dice: 0.8901 | Val IoU: 0.8186 | Val Acc: 0.9447
Epoch 03/50 | Train Loss: 0.3194 | Val Loss: 0.3376 | Val Dice: 0.8773 | Val IoU: 0.8038 | Val Acc: 0.9401
Epoch 04/50 | Train Loss: 0.3726 | Val Loss: 0.4055 | Val Dice: 0.8549 | Val IoU: 0.7763 | Val Acc: 0.9192
Epoch 05/50 | Train Loss: 0.3770 | Val Loss: 0.3672 | Val Dice: 0.8638 | Val IoU: 0.7865 | Val Acc: 0.9288
Epoch 06/50 | Train Loss: 0.3858 | Val Loss: 0.3671 | Val Dice: 0.8610 | Val IoU: 0.7829 | Val Acc: 0.9293
Epoch 07/50 | Train Loss: 0.3545 | Val Loss: 0.3243 | Val Dice: 0.8779 | Val IoU: 0.8060 | Val Acc: 0.9385
Epoch 08/50 | Train Loss: 0.3380 | Val Loss: 0.3173 | Val Dice: 0.8803 | Val IoU: 0.8074 | Val Acc: 0.9399
Epoch 09/50 | Train Loss: 0.3518 | Val Loss: 0.3201 | Val Dice: 0.8794 | Val IoU: 0.8079 | Val Acc: 0.9378
Epoch 1