In [1]:
# ============================================================
# Dual-Branch Lesion-Aware Classifier (DB-LARNet)
# Branch 1: full image
# Branch 2: segmentation-based lesion crop
# ============================================================

import os
from pathlib import Path
import random
import numpy as np
import pandas as pd
from PIL import Image
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler

from torchvision import models, transforms
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, classification_report, roc_curve, auc

# ----------------- CONFIG -----------------
DATA_ROOT = Path("/kaggle/input/ham1000-segmentation-and-classification")
FULL_IMG_DIR = DATA_ROOT / "images"   # full dermoscopic images
CROP_IMG_DIR = Path("/kaggle/working/ham_lesion_crops")  # lesion crops from masks
CSV_PATH = DATA_ROOT / "GroundTruth.csv"

OUTPUT_DIR = Path("/kaggle/working/ham_dualbranch_cls_outputs")
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

IMG_SIZE = 224
BATCH_SIZE = 32
NUM_EPOCHS = 100
LR = 1e-4
VAL_SPLIT = 0.2
SEED = 42
NUM_WORKERS = 2
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

CLASS_NAMES = ["MEL", "NV", "BCC", "AKIEC", "BKL", "DF", "VASC"]
NUM_CLASSES = len(CLASS_NAMES)

print("Using device:", DEVICE)

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if DEVICE == "cuda":
    torch.cuda.manual_seed_all(SEED)

# ----------------- LOAD LABELS -----------------
df = pd.read_csv(CSV_PATH)

# convert one-hot to single index
label_array = df[CLASS_NAMES].values  # shape: (N,7)
df["label_idx"] = label_array.argmax(axis=1)

print(df.head())

# ----------------- TRAIN / VAL SPLIT -----------------
train_df, val_df = train_test_split(
    df,
    test_size=VAL_SPLIT,
    random_state=SEED,
    stratify=df["label_idx"]
)

print(f"Total samples: {len(df)}")
print(f"Train: {len(train_df)}, Val: {len(val_df)}")

# ----------------- DATASET -----------------

class DualBranchHAMDataset(Dataset):
    """
    Returns:
      full_img_tensor, crop_img_tensor, label
    For each row, it loads:
      - full image from FULL_IMG_DIR / <image>.jpg
      - cropped lesion from CROP_IMG_DIR / <image>.jpg
    """
    def __init__(self, dataframe, full_img_dir, crop_img_dir, transform_full=None, transform_crop=None):
        self.df = dataframe.reset_index(drop=True)
        self.full_img_dir = full_img_dir
        self.crop_img_dir = crop_img_dir
        self.transform_full = transform_full
        self.transform_crop = transform_crop

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        img_id = row["image"]
        label = int(row["label_idx"])

        full_path = self.full_img_dir / f"{img_id}.jpg"
        crop_path = self.crop_img_dir / f"{img_id}.jpg"

        # fallback if crop is missing
        full_img = Image.open(full_path).convert("RGB")
        if crop_path.exists():
            crop_img = Image.open(crop_path).convert("RGB")
        else:
            crop_img = full_img.copy()

        if self.transform_full:
            full_img = self.transform_full(full_img)
        if self.transform_crop:
            crop_img = self.transform_crop(crop_img)

        return full_img, crop_img, label

# ----------------- TRANSFORMS -----------------
train_transform_full = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomRotation(20),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    ),
])

# you can use slightly lighter aug for crops if you like
train_transform_crop = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(20),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    ),
])

val_transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    ),
])

train_dataset = DualBranchHAMDataset(
    train_df, FULL_IMG_DIR, CROP_IMG_DIR,
    transform_full=train_transform_full,
    transform_crop=train_transform_crop
)
val_dataset = DualBranchHAMDataset(
    val_df, FULL_IMG_DIR, CROP_IMG_DIR,
    transform_full=val_transform,
    transform_crop=val_transform
)

# ----------------- CLASS IMBALANCE HANDLING -----------------
class_counts = train_df["label_idx"].value_counts().sort_index().values
print("Class counts (train):", dict(zip(range(NUM_CLASSES), class_counts)))

class_weights = 1.0 / class_counts
sample_weights = class_weights[train_df["label_idx"].values]
sample_weights = torch.from_numpy(sample_weights).float()

sampler = WeightedRandomSampler(
    weights=sample_weights,
    num_samples=len(sample_weights),
    replacement=True
)

train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    sampler=sampler,
    num_workers=NUM_WORKERS,
    pin_memory=True,
)

val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=True,
)

# ----------------- MODEL: Dual-Branch ResNet-18 -----------------

class DualBranchResNet(nn.Module):
    """
    DB-LARNet:
      Branch A: full image -> ResNet-18 backbone
      Branch B: lesion crop -> ResNet-18 backbone
      Fuse: concat pooled features -> FC -> logits
    """
    def __init__(self, num_classes=7, pretrained=True):
        super().__init__()
        # Branch for full image
        base_full = models.resnet18(
            weights=models.ResNet18_Weights.IMAGENET1K_V1 if pretrained else None
        )
        # Branch for crop image
        base_crop = models.resnet18(
            weights=models.ResNet18_Weights.IMAGENET1K_V1 if pretrained else None
        )

        # remove FC layers, keep feature extractors
        self.full_backbone = nn.Sequential(*list(base_full.children())[:-1])  # -> (B,512,1,1)
        self.crop_backbone = nn.Sequential(*list(base_crop.children())[:-1])  # -> (B,512,1,1)

        self.dropout = nn.Dropout(p=0.5)
        self.classifier = nn.Sequential(
            nn.Linear(512 * 2, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Linear(512, num_classes)
        )

    def forward(self, full_img, crop_img):
        # full_img, crop_img: (B,3,H,W)
        full_feat = self.full_backbone(full_img)   # (B,512,1,1)
        crop_feat = self.crop_backbone(crop_img)   # (B,512,1,1)

        full_feat = full_feat.view(full_feat.size(0), -1)  # (B,512)
        crop_feat = crop_feat.view(crop_feat.size(0), -1)  # (B,512)

        fused = torch.cat([full_feat, crop_feat], dim=1)   # (B,1024)
        fused = self.dropout(fused)
        logits = self.classifier(fused)                    # (B,num_classes)
        return logits

model = DualBranchResNet(num_classes=NUM_CLASSES, pretrained=True).to(DEVICE)

class_weights_tensor = torch.tensor(class_weights, dtype=torch.float32).to(DEVICE)
criterion = nn.CrossEntropyLoss(weight=class_weights_tensor)
optimizer = torch.optim.Adam(model.parameters(), lr=LR)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='max', factor=0.5, patience=3, verbose=True
)

print("Model params (M):", sum(p.numel() for p in model.parameters()) / 1e6)

# ----------------- TRAIN & EVAL -----------------

def train_one_epoch(model, loader, optimizer, criterion):
    model.train()
    running_loss = 0.0
    running_corrects = 0
    total = 0

    for full_img, crop_img, labels in loader:
        full_img = full_img.to(DEVICE)
        crop_img = crop_img.to(DEVICE)
        labels = labels.to(DEVICE)

        optimizer.zero_grad()
        outputs = model(full_img, crop_img)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        _, preds = torch.max(outputs, 1)
        running_loss += loss.item() * labels.size(0)
        running_corrects += (preds == labels).sum().item()
        total += labels.size(0)

    return running_loss / total, running_corrects / total


def evaluate(model, loader, criterion, collect_probs=False):
    model.eval()
    running_loss = 0.0
    running_corrects = 0
    total = 0

    all_labels = []
    all_preds = []
    all_probs = []

    with torch.no_grad():
        for full_img, crop_img, labels in loader:
            full_img = full_img.to(DEVICE)
            crop_img = crop_img.to(DEVICE)
            labels = labels.to(DEVICE)

            outputs = model(full_img, crop_img)
            loss = criterion(outputs, labels)

            probs = F.softmax(outputs, dim=1)
            _, preds = torch.max(outputs, 1)

            running_loss += loss.item() * labels.size(0)
            running_corrects += (preds == labels).sum().item()
            total += labels.size(0)

            all_labels.append(labels.cpu().numpy())
            all_preds.append(preds.cpu().numpy())
            if collect_probs:
                all_probs.append(probs.cpu().numpy())

    epoch_loss = running_loss / total
    epoch_acc = running_corrects / total

    all_labels = np.concatenate(all_labels)
    all_preds = np.concatenate(all_preds)
    if collect_probs:
        all_probs = np.concatenate(all_probs)
    else:
        all_probs = None

    return epoch_loss, epoch_acc, all_labels, all_preds, all_probs

# ----------------- TRAINING LOOP -----------------

history = []
best_val_acc = 0.0
best_model_path = OUTPUT_DIR / "best_dualbranch_cls_model.pth"

for epoch in range(1, NUM_EPOCHS + 1):
    train_loss, train_acc = train_one_epoch(model, train_loader, optimizer, criterion)
    val_loss, val_acc, _, _, _ = evaluate(model, val_loader, criterion, collect_probs=False)

    scheduler.step(val_acc)

    history.append({
        "epoch": epoch,
        "train_loss": train_loss,
        "train_acc": train_acc,
        "val_loss": val_loss,
        "val_acc": val_acc,
    })

    print(
        f"Epoch {epoch:02d}/{NUM_EPOCHS} - "
        f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f} | "
        f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}"
    )

    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), best_model_path)
        print(f"  -> New best model saved (val_acc={best_val_acc:.4f})")

# save history
history_df = pd.DataFrame(history)
history_csv_path = OUTPUT_DIR / "training_history.csv"
history_df.to_csv(history_csv_path, index=False)
print("Saved training history to:", history_csv_path)

# ----------------- PLOTS -----------------

plt.figure()
plt.plot(history_df["epoch"], history_df["train_loss"], label="Train Loss")
plt.plot(history_df["epoch"], history_df["val_loss"], label="Val Loss")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.title("DB-LARNet Loss Curves")
plt.legend()
loss_fig_path = OUTPUT_DIR / "loss_curves.png"
plt.savefig(loss_fig_path, dpi=150, bbox_inches="tight")
plt.close()
print("Saved loss curves to:", loss_fig_path)

plt.figure()
plt.plot(history_df["epoch"], history_df["train_acc"], label="Train Acc")
plt.plot(history_df["epoch"], history_df["val_acc"], label="Val Acc")
plt.xlabel("Epoch")
plt.ylabel("Accuracy")
plt.title("DB-LARNet Accuracy Curves")
plt.legend()
acc_fig_path = OUTPUT_DIR / "acc_curves.png"
plt.savefig(acc_fig_path, dpi=150, bbox_inches="tight")
plt.close()
print("Saved accuracy curves to:", acc_fig_path)

# ----------------- FINAL EVAL: CONFUSION MATRIX, REPORT, ROC -----------------

model.load_state_dict(torch.load(best_model_path, map_location=DEVICE))
val_loss, val_acc, y_true, y_pred, y_prob = evaluate(
    model, val_loader, criterion, collect_probs=True
)

print(f"Final Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")

cm = confusion_matrix(y_true, y_pred)
print("Confusion matrix:\n", cm)

cm_fig_path = OUTPUT_DIR / "confusion_matrix.png"
plt.figure(figsize=(6, 5))
plt.imshow(cm, interpolation="nearest")
plt.title("DB-LARNet Confusion Matrix")
plt.colorbar()
tick_marks = np.arange(NUM_CLASSES)
plt.xticks(tick_marks, CLASS_NAMES, rotation=45)
plt.yticks(tick_marks, CLASS_NAMES)
thresh = cm.max() / 2.0
for i in range(cm.shape[0]):
    for j in range(cm.shape[1]):
        plt.text(
            j, i, format(cm[i, j], "d"),
            horizontalalignment="center",
            color="white" if cm[i, j] > thresh else "black",
        )
plt.ylabel("True label")
plt.xlabel("Predicted label")
plt.tight_layout()
plt.savefig(cm_fig_path, dpi=150)
plt.close()
print("Saved confusion_matrix to:", cm_fig_path)

report = classification_report(
    y_true, y_pred,
    target_names=CLASS_NAMES, digits=4
)
print("Classification report:\n", report)

report_path = OUTPUT_DIR / "classification_report.txt"
with open(report_path, "w") as f:
    f.write(report)
print("Saved classification_report to:", report_path)

if y_prob is not None:
    roc_fig_path = OUTPUT_DIR / "roc_curves.png"
    plt.figure(figsize=(8, 6))
    for i, cls_name in enumerate(CLASS_NAMES):
        y_true_bin = (y_true == i).astype(int)
        fpr, tpr, _ = roc_curve(y_true_bin, y_prob[:, i])
        roc_auc = auc(fpr, tpr)
        plt.plot(fpr, tpr, label=f"{cls_name} (AUC = {roc_auc:.3f})")

    plt.plot([0, 1], [0, 1], "k--", label="Random")
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.title("DB-LARNet ROC Curves (One-vs-Rest)")
    plt.legend(loc="lower right")
    plt.tight_layout()
    plt.savefig(roc_fig_path, dpi=150)
    plt.close()
    print("Saved ROC curves to:", roc_fig_path)

print("All done (dual-branch classification).")

Using device: cuda
          image  MEL   NV  BCC  AKIEC  BKL   DF  VASC  label_idx
0  ISIC_0024306  0.0  1.0  0.0    0.0  0.0  0.0   0.0          1
1  ISIC_0024307  0.0  1.0  0.0    0.0  0.0  0.0   0.0          1
2  ISIC_0024308  0.0  1.0  0.0    0.0  0.0  0.0   0.0          1
3  ISIC_0024309  0.0  1.0  0.0    0.0  0.0  0.0   0.0          1
4  ISIC_0024310  1.0  0.0  0.0    0.0  0.0  0.0   0.0          0
Total samples: 10015
Train: 8012, Val: 2003
Class counts (train): {0: 890, 1: 5364, 2: 411, 3: 262, 4: 879, 5: 92, 6: 114}


Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 159MB/s]


Model params (M): 22.881415




Epoch 01/100 - Train Loss: 0.5864, Train Acc: 0.5111 | Val Loss: 1.2437, Val Acc: 0.1962
  -> New best model saved (val_acc=0.1962)
Epoch 02/100 - Train Loss: 0.2255, Train Acc: 0.7114 | Val Loss: 1.2188, Val Acc: 0.2726
  -> New best model saved (val_acc=0.2726)
Epoch 03/100 - Train Loss: 0.1458, Train Acc: 0.7849 | Val Loss: 0.8367, Val Acc: 0.5177
  -> New best model saved (val_acc=0.5177)
Epoch 04/100 - Train Loss: 0.1244, Train Acc: 0.8142 | Val Loss: 0.7872, Val Acc: 0.5597
  -> New best model saved (val_acc=0.5597)
Epoch 05/100 - Train Loss: 0.0892, Train Acc: 0.8538 | Val Loss: 0.6925, Val Acc: 0.6385
  -> New best model saved (val_acc=0.6385)
Epoch 06/100 - Train Loss: 0.0906, Train Acc: 0.8608 | Val Loss: 0.8687, Val Acc: 0.5567
Epoch 07/100 - Train Loss: 0.0745, Train Acc: 0.8684 | Val Loss: 0.8064, Val Acc: 0.5467
Epoch 08/100 - Train Loss: 0.0959, Train Acc: 0.8580 | Val Loss: 0.7425, Val Acc: 0.6246
Epoch 09/100 - Train Loss: 0.0564, Train Acc: 0.8897 | Val Loss: 0.8218, 