In [1]:
# ============================================================
# Segmentation-Guided ResNet-34 (SG-ResNet)
# Input: 4 channels = RGB + lesion mask
# ============================================================

import os
from pathlib import Path
import random
import numpy as np
import pandas as pd
from PIL import Image
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler

from torchvision import models, transforms
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, classification_report, roc_curve, auc

# ----------------- CONFIG -----------------
# Input layout: <root>/images/<id>.jpg, <root>/masks/<id>_segmentation.png and
# a label table (GroundTruth.csv) with one column per class name.
DATA_ROOT = Path("/kaggle/input/ham1000-segmentation-and-classification")
IMAGES_DIR = DATA_ROOT / "images"
MASKS_DIR = DATA_ROOT / "masks"
CSV_PATH = DATA_ROOT / "GroundTruth.csv"

# Checkpoint, history CSV, figures and the text report are written here.
OUTPUT_DIR = Path("/kaggle/working/ham_sg_resnet_outputs")
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

IMG_SIZE = 224          # side length of the square network input
BATCH_SIZE = 32
NUM_EPOCHS = 100
LR = 1e-4               # initial Adam learning rate
VAL_SPLIT = 0.2         # fraction of samples held out for validation
SEED = 42
NUM_WORKERS = 2         # DataLoader worker processes
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Column order here defines the integer class index used everywhere below.
CLASS_NAMES = ["MEL", "NV", "BCC", "AKIEC", "BKL", "DF", "VASC"]
NUM_CLASSES = len(CLASS_NAMES)

print("Using device:", DEVICE)

# Seed every RNG the script touches so the split and initialisation repeat.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if DEVICE == "cuda":
    torch.cuda.manual_seed_all(SEED)

# ----------------- LOAD LABELS -----------------
df = pd.read_csv(CSV_PATH)

label_array = df[CLASS_NAMES].values
# argmax would silently turn an all-zero (or NaN) row into class 0 (MEL), so
# check that every row has exactly one positive entry before deriving indices.
positives_per_row = (label_array > 0).sum(axis=1)
if not (positives_per_row == 1).all():
    bad_rows = int((positives_per_row != 1).sum())
    raise ValueError(
        f"{bad_rows} row(s) in {CSV_PATH} do not have exactly one positive "
        f"label among {CLASS_NAMES}"
    )
df["label_idx"] = label_array.argmax(axis=1)

print(df.head())

# ----------------- TRAIN / VAL SPLIT -----------------
# Stratify on the class index so both splits keep the class proportions.
train_df, val_df = train_test_split(
    df,
    test_size=VAL_SPLIT,
    random_state=SEED,
    stratify=df["label_idx"]
)

print(f"Total samples: {len(df)}")
print(f"Train: {len(train_df)}, Val: {len(val_df)}")

# ----------------- DATASET -----------------

class MaskGuidedHAMDataset(Dataset):
    """Dataset that stacks each RGB image with its binary lesion mask.

    Every item is ``(x_4ch, label_idx)``. ``x_4ch`` is a float tensor of shape
    ``(4, img_size, img_size)`` laid out as ``[R, G, B, mask]``: the RGB planes
    are resized and normalized with the mean/std constants below, the mask
    plane holds 0.0 / 1.0. When no mask file exists for an image, the mask
    plane is all zeros.

    Args:
        dataframe: Table with an ``image`` column (file stem) and a
            ``label_idx`` column (integer class index).
        img_dir: Directory holding ``<image>.jpg`` files (``str`` or ``Path``).
        mask_dir: Directory holding ``<image>_segmentation.png`` files
            (``str`` or ``Path``).
        img_size: Side length of the square output.
    """

    def __init__(self, dataframe, img_dir, mask_dir, img_size=224):
        self.df = dataframe.reset_index(drop=True)
        # __getitem__ joins paths with "/", which only works on Path objects;
        # coercing here lets callers pass plain strings as well.
        self.img_dir = Path(img_dir)
        self.mask_dir = Path(mask_dir)
        self.img_size = img_size

        self.img_transform = transforms.Compose([
            transforms.Resize((img_size, img_size)),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]
            ),
        ])

        # The mask is resized separately (nearest neighbour) in __getitem__
        # and is not normalized.

    def __len__(self):
        """Return the number of rows in the backing dataframe."""
        return len(self.df)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(x_4ch, label)``."""
        row = self.df.iloc[idx]
        img_id = row["image"]
        label = int(row["label_idx"])

        img_path = self.img_dir / f"{img_id}.jpg"
        mask_path = self.mask_dir / f"{img_id}_segmentation.png"

        # convert() returns a new, fully loaded image, so the file handle can
        # be released as soon as the with-block ends.
        with Image.open(img_path) as raw_img:
            img = self.img_transform(raw_img.convert("RGB"))  # (3,H,W)

        # load mask (if missing -> zeros)
        if mask_path.exists():
            with Image.open(mask_path) as raw_mask:
                mask = raw_mask.convert("L").resize(
                    (self.img_size, self.img_size), resample=Image.NEAREST
                )
            # Any non-zero pixel counts as lesion.
            mask_bin = (np.array(mask) > 0).astype("float32")  # 0/1
        else:
            mask_bin = np.zeros((self.img_size, self.img_size), dtype="float32")

        mask_tensor = torch.from_numpy(mask_bin).unsqueeze(0)  # (1,H,W)

        # concatenate as 4-channel input
        x_4ch = torch.cat([img, mask_tensor], dim=0)  # (4,H,W)

        return x_4ch, label

# ----------------- DATA LOADERS -----------------
train_dataset = MaskGuidedHAMDataset(train_df, IMAGES_DIR, MASKS_DIR, IMG_SIZE)
val_dataset   = MaskGuidedHAMDataset(val_df,   IMAGES_DIR, MASKS_DIR, IMG_SIZE)

# class imbalance
# bincount with a fixed minimum length keeps position i tied to class i even
# when a class has no training samples; value_counts() would drop that class
# and shift every later entry, breaking the class_weights[label] lookup below.
train_labels = train_df["label_idx"].to_numpy()
class_counts = np.bincount(train_labels, minlength=NUM_CLASSES)
print("Class counts (train):", dict(zip(range(NUM_CLASSES), class_counts)))

# Inverse-frequency weights; a class with zero samples gets weight 0 rather
# than an infinite weight from dividing by zero.
class_weights = np.divide(
    1.0,
    class_counts,
    out=np.zeros(len(class_counts), dtype=np.float64),
    where=class_counts > 0,
)
sample_weights = class_weights[train_labels]
sample_weights = torch.from_numpy(sample_weights).float()

# Draw (with replacement) one epoch's worth of samples, favouring rare classes.
sampler = WeightedRandomSampler(
    weights=sample_weights,
    num_samples=len(sample_weights),
    replacement=True
)

train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    sampler=sampler,
    num_workers=NUM_WORKERS,
    pin_memory=True,
)

# Validation keeps the natural class distribution and a fixed order.
val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=True,
)

# ----------------- MODEL: 4-CHANNEL RESNET-34 -----------------

def build_4ch_resnet34(num_classes=7):
    """Return a pretrained ResNet-34 adapted to 4-channel input.

    The stem convolution is replaced by a 4-input-channel layer with the same
    geometry. Its first three input planes reuse the pretrained filters and
    the fourth plane starts as the mean of those three. The classifier head is
    swapped for a new linear layer with ``num_classes`` outputs.
    """
    backbone = models.resnet34(weights=models.ResNet34_Weights.IMAGENET1K_V1)

    # Same geometry as the pretrained stem, just one extra input plane.
    stem = backbone.conv1
    wide_stem = nn.Conv2d(
        4,
        stem.out_channels,
        kernel_size=stem.kernel_size,
        stride=stem.stride,
        padding=stem.padding,
        bias=stem.bias is not None,
    )

    pretrained_filters = stem.weight
    with torch.no_grad():
        # Planes 0..2: pretrained filters, copied verbatim.
        wide_stem.weight[:, :3].copy_(pretrained_filters)
        # Plane 3: average of the three pretrained planes.
        wide_stem.weight[:, 3:].copy_(
            pretrained_filters.mean(dim=1, keepdim=True)
        )
    backbone.conv1 = wide_stem

    # New classification head sized for this task.
    backbone.fc = nn.Linear(backbone.fc.in_features, num_classes)
    return backbone

model = build_4ch_resnet34(num_classes=NUM_CLASSES).to(DEVICE)

# NOTE(review): class imbalance is compensated twice. WeightedRandomSampler
# already draws training batches with these inverse-frequency weights, and the
# same weights scale the loss here (for validation batches too). Confirm the
# double correction is intended; it may over-emphasise rare classes.
class_weights_tensor = torch.tensor(
    class_weights, dtype=torch.float32, device=DEVICE
)
criterion = nn.CrossEntropyLoss(weight=class_weights_tensor)
optimizer = torch.optim.Adam(model.parameters(), lr=LR)
# mode="max" because the scheduler is stepped with validation accuracy.
# The `verbose` argument is deprecated for PyTorch LR schedulers and is no
# longer passed; the current LR is available via optimizer.param_groups.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="max", factor=0.5, patience=3
)

print("Model params (M):", sum(p.numel() for p in model.parameters()) / 1e6)

# ----------------- TRAIN & EVAL -----------------

def train_one_epoch(model, loader, optimizer, criterion):
    """Run one optimisation pass over ``loader``.

    Returns:
        ``(mean_loss, accuracy)`` over all samples seen. Each batch loss is
        weighted by its batch size before averaging.
    """
    model.train()
    loss_sum, n_correct, n_seen = 0.0, 0, 0

    for batch, targets in loader:
        batch, targets = batch.to(DEVICE), targets.to(DEVICE)
        batch_size = targets.size(0)

        # Standard step: clear grads, forward, backward, update.
        optimizer.zero_grad()
        logits = model(batch)
        batch_loss = criterion(logits, targets)
        batch_loss.backward()
        optimizer.step()

        predicted = logits.max(dim=1).indices
        loss_sum += batch_loss.item() * batch_size
        n_correct += (predicted == targets).sum().item()
        n_seen += batch_size

    mean_loss = loss_sum / n_seen
    accuracy = n_correct / n_seen
    return mean_loss, accuracy


def evaluate(model, loader, criterion, collect_probs=False):
    """Evaluate ``model`` on ``loader`` without updating any weights.

    Args:
        model: Network to evaluate; switched to eval mode.
        loader: Iterable yielding ``(inputs, labels)`` batches.
        criterion: Loss applied to ``(logits, labels)``.
        collect_probs: When True, also return per-class softmax probabilities.

    Returns:
        ``(loss, accuracy, labels, preds, probs)``. ``loss`` and ``accuracy``
        are averaged over all samples, ``labels`` and ``preds`` are 1-D NumPy
        arrays in loader order, and ``probs`` is an ``(N, C)`` NumPy array or
        ``None`` when ``collect_probs`` is False.
    """
    model.eval()
    running_loss = 0.0
    running_corrects = 0
    total = 0

    all_labels = []
    all_preds = []
    all_probs = []

    with torch.no_grad():
        for x_4ch, labels in loader:
            x_4ch = x_4ch.to(DEVICE)
            labels = labels.to(DEVICE)

            outputs = model(x_4ch)
            loss = criterion(outputs, labels)
            preds = outputs.argmax(dim=1)

            batch_size = labels.size(0)
            running_loss += loss.item() * batch_size
            running_corrects += (preds == labels).sum().item()
            total += batch_size

            all_labels.append(labels.cpu().numpy())
            all_preds.append(preds.cpu().numpy())
            if collect_probs:
                # Softmax is only needed for the probability output, so skip
                # it on the per-epoch validation passes that discard it.
                all_probs.append(F.softmax(outputs, dim=1).cpu().numpy())

    epoch_loss = running_loss / total
    epoch_acc = running_corrects / total

    all_labels = np.concatenate(all_labels)
    all_preds = np.concatenate(all_preds)
    all_probs = np.concatenate(all_probs) if collect_probs else None

    return epoch_loss, epoch_acc, all_labels, all_preds, all_probs

# ----------------- TRAINING LOOP -----------------

history = []
# Start below any reachable accuracy so the first epoch always writes a
# checkpoint; the final evaluation loads best_model_path unconditionally and
# would fail if validation accuracy never rose above 0.
best_val_acc = float("-inf")
best_model_path = OUTPUT_DIR / "best_sg_resnet_model.pth"

for epoch in range(1, NUM_EPOCHS + 1):
    train_loss, train_acc = train_one_epoch(model, train_loader, optimizer, criterion)
    val_loss, val_acc, _, _, _ = evaluate(model, val_loader, criterion, collect_probs=False)

    # The scheduler watches validation accuracy; report any LR change here so
    # the log does not depend on scheduler-side printing.
    lr_before = optimizer.param_groups[0]["lr"]
    scheduler.step(val_acc)
    lr_after = optimizer.param_groups[0]["lr"]
    if lr_after != lr_before:
        print(f"  -> Learning rate changed: {lr_before:.2e} -> {lr_after:.2e}")

    history.append({
        "epoch": epoch,
        "train_loss": train_loss,
        "train_acc": train_acc,
        "val_loss": val_loss,
        "val_acc": val_acc,
    })

    print(
        f"Epoch {epoch:02d}/{NUM_EPOCHS} - "
        f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f} | "
        f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}"
    )

    # Keep only the weights of the best epoch by validation accuracy.
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), best_model_path)
        print(f"  -> New best model saved (val_acc={best_val_acc:.4f})")

# save history
history_df = pd.DataFrame(history)
history_csv_path = OUTPUT_DIR / "training_history.csv"
history_df.to_csv(history_csv_path, index=False)
print("Saved training history to:", history_csv_path)

# ----------------- PLOTS -----------------

def _save_history_curves(series, y_label, title, out_path):
    """Plot ``(column, legend_label)`` pairs from history_df against epoch and save."""
    plt.figure()
    for column, legend_label in series:
        plt.plot(history_df["epoch"], history_df[column], label=legend_label)
    plt.xlabel("Epoch")
    plt.ylabel(y_label)
    plt.title(title)
    plt.legend()
    plt.savefig(out_path, dpi=150, bbox_inches="tight")
    plt.close()


# Train vs. validation loss per epoch.
loss_fig_path = OUTPUT_DIR / "loss_curves.png"
_save_history_curves(
    [("train_loss", "Train Loss"), ("val_loss", "Val Loss")],
    "Loss",
    "SG-ResNet Loss Curves",
    loss_fig_path,
)
print("Saved loss curves to:", loss_fig_path)

# Train vs. validation accuracy per epoch.
acc_fig_path = OUTPUT_DIR / "acc_curves.png"
_save_history_curves(
    [("train_acc", "Train Acc"), ("val_acc", "Val Acc")],
    "Accuracy",
    "SG-ResNet Accuracy Curves",
    acc_fig_path,
)
print("Saved accuracy curves to:", acc_fig_path)

# ----------------- FINAL EVAL -----------------

# Re-evaluate with the best checkpoint rather than the last-epoch weights.
model.load_state_dict(torch.load(best_model_path, map_location=DEVICE))
val_loss, val_acc, y_true, y_pred, y_prob = evaluate(
    model, val_loader, criterion, collect_probs=True
)

print(f"Final Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")

# Pass the full label list explicitly so the matrix is always
# NUM_CLASSES x NUM_CLASSES and lines up with CLASS_NAMES, even if some class
# appears in neither y_true nor y_pred.
class_indices = list(range(NUM_CLASSES))
cm = confusion_matrix(y_true, y_pred, labels=class_indices)
print("Confusion matrix:\n", cm)

cm_fig_path = OUTPUT_DIR / "confusion_matrix.png"
plt.figure(figsize=(6, 5))
plt.imshow(cm, interpolation="nearest")
plt.title("SG-ResNet Confusion Matrix")
plt.colorbar()
tick_marks = np.arange(NUM_CLASSES)
plt.xticks(tick_marks, CLASS_NAMES, rotation=45)
plt.yticks(tick_marks, CLASS_NAMES)
# Switch the annotation colour at half the maximum count for readability.
thresh = cm.max() / 2.0
for i in range(cm.shape[0]):
    for j in range(cm.shape[1]):
        plt.text(
            j, i, format(cm[i, j], "d"),
            horizontalalignment="center",
            color="white" if cm[i, j] > thresh else "black",
        )
plt.ylabel("True label")
plt.xlabel("Predicted label")
plt.tight_layout()
plt.savefig(cm_fig_path, dpi=150)
plt.close()
print("Saved confusion_matrix to:", cm_fig_path)

# Same explicit label list here: target_names must match the labels one-to-one.
report = classification_report(
    y_true, y_pred,
    labels=class_indices,
    target_names=CLASS_NAMES, digits=4
)
print("Classification report:\n", report)

report_path = OUTPUT_DIR / "classification_report.txt"
with open(report_path, "w") as f:
    f.write(report)
print("Saved classification_report to:", report_path)

if y_prob is not None:
    roc_fig_path = OUTPUT_DIR / "roc_curves.png"
    plt.figure(figsize=(8, 6))
    # One-vs-rest ROC: each class in turn is the positive class.
    for i, cls_name in enumerate(CLASS_NAMES):
        y_true_bin = (y_true == i).astype(int)
        fpr, tpr, _ = roc_curve(y_true_bin, y_prob[:, i])
        roc_auc = auc(fpr, tpr)
        plt.plot(fpr, tpr, label=f"{cls_name} (AUC = {roc_auc:.3f})")

    plt.plot([0, 1], [0, 1], "k--", label="Random")
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.title("SG-ResNet ROC Curves (One-vs-Rest)")
    plt.legend(loc="lower right")
    plt.tight_layout()
    plt.savefig(roc_fig_path, dpi=150)
    plt.close()
    print("Saved ROC curves to:", roc_fig_path)

print("All done (SG-ResNet classification).")

Using device: cuda
          image  MEL   NV  BCC  AKIEC  BKL   DF  VASC  label_idx
0  ISIC_0024306  0.0  1.0  0.0    0.0  0.0  0.0   0.0          1
1  ISIC_0024307  0.0  1.0  0.0    0.0  0.0  0.0   0.0          1
2  ISIC_0024308  0.0  1.0  0.0    0.0  0.0  0.0   0.0          1
3  ISIC_0024309  0.0  1.0  0.0    0.0  0.0  0.0   0.0          1
4  ISIC_0024310  1.0  0.0  0.0    0.0  0.0  0.0   0.0          0
Total samples: 10015
Train: 8012, Val: 2003
Class counts (train): {0: 890, 1: 5364, 2: 411, 3: 262, 4: 879, 5: 92, 6: 114}


Downloading: "https://download.pytorch.org/models/resnet34-b627a593.pth" to /root/.cache/torch/hub/checkpoints/resnet34-b627a593.pth
100%|██████████| 83.3M/83.3M [00:00<00:00, 192MB/s]


Model params (M): 21.291399




Epoch 01/100 - Train Loss: 0.2511, Train Acc: 0.7274 | Val Loss: 0.8834, Val Acc: 0.5082
  -> New best model saved (val_acc=0.5082)
Epoch 02/100 - Train Loss: 0.0684, Train Acc: 0.8848 | Val Loss: 0.7920, Val Acc: 0.6191
  -> New best model saved (val_acc=0.6191)
Epoch 03/100 - Train Loss: 0.0657, Train Acc: 0.8889 | Val Loss: 0.7722, Val Acc: 0.6510
  -> New best model saved (val_acc=0.6510)
Epoch 04/100 - Train Loss: 0.0424, Train Acc: 0.9071 | Val Loss: 1.0072, Val Acc: 0.5522
Epoch 05/100 - Train Loss: 0.0384, Train Acc: 0.9165 | Val Loss: 0.7311, Val Acc: 0.6860
  -> New best model saved (val_acc=0.6860)
Epoch 06/100 - Train Loss: 0.0173, Train Acc: 0.9453 | Val Loss: 0.7814, Val Acc: 0.6995
  -> New best model saved (val_acc=0.6995)
Epoch 07/100 - Train Loss: 0.0371, Train Acc: 0.9324 | Val Loss: 0.8134, Val Acc: 0.7049
  -> New best model saved (val_acc=0.7049)
Epoch 08/100 - Train Loss: 0.0323, Train Acc: 0.9319 | Val Loss: 1.2257, Val Acc: 0.4953
Epoch 09/100 - Train Loss: 0.0