In [1]:
# ============================================================
# HAM10000 classification with ResNet-34
# Uses GroundTruth.csv with columns:
#  image, MEL, NV, BCC, AKIEC, BKL, DF, VASC
# ============================================================

import os
from pathlib import Path
import random
import numpy as np
import pandas as pd
from PIL import Image
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler

from torchvision import models, transforms

from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    confusion_matrix, classification_report,
    roc_curve, auc
)

# ----------------- CONFIG -----------------
# Run settings and hyper-parameters.
SEED = 42
IMG_SIZE = 224            # square side length images are resized to
BATCH_SIZE = 32
NUM_EPOCHS = 100
LR = 1e-4
VAL_SPLIT = 0.2           # fraction of samples held out for validation
NUM_WORKERS = 2           # DataLoader worker processes

# Inputs: the label CSV, and the directory of CROPPED lesion images that is
# used instead of the original full-frame images.
DATA_ROOT = Path("/kaggle/input/ham1000-segmentation-and-classification")
CSV_PATH = DATA_ROOT.joinpath("GroundTruth.csv")
IMAGES_DIR = Path("/kaggle/input/ham10000-segment-data/ham_lesion_crops")

# Outputs: checkpoints, history CSV and figures all land here.
OUTPUT_DIR = Path("/kaggle/working/ham_cls_outputs_lesion_crops")
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device:", DEVICE)

# Seed every RNG the script touches (Python, NumPy, PyTorch CPU, then CUDA).
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(SEED)
if DEVICE == "cuda":
    torch.cuda.manual_seed_all(SEED)

# ----------------- LOAD LABELS -----------------
df = pd.read_csv(CSV_PATH)

# columns holding the one-hot labels; list position == class index
CLASS_NAMES = ["MEL", "NV", "BCC", "AKIEC", "BKL", "DF", "VASC"]
num_classes = len(CLASS_NAMES)

# Convert the one-hot columns to a single label index.
label_array = df[CLASS_NAMES].to_numpy()  # shape: (N, 7)

# argmax() returns 0 for an all-zero (or NaN) row, which would silently label
# the sample as MEL, so reject anything that is not strictly one-hot first.
_is_one_hot = np.isin(label_array, (0, 1)).all(axis=1) & (label_array.sum(axis=1) == 1)
if not _is_one_hot.all():
    raise ValueError(
        f"{int((~_is_one_hot).sum())} row(s) in {CSV_PATH} are not one-hot over "
        f"{CLASS_NAMES}; cannot derive a single class index for them"
    )
df["label_idx"] = label_array.argmax(axis=1)

print(df.head())

# ----------------- TRAIN / VAL SPLIT -----------------
# Stratified so both splits keep the class proportions of the full table.
# NOTE(review): the split is per image. If the dataset contains several images
# of the same lesion, they can land on both sides of the split -- this CSV has
# no lesion identifier to group by; confirm whether that matters here.
train_df, val_df = train_test_split(
    df,
    test_size=VAL_SPLIT,
    random_state=SEED,
    stratify=df["label_idx"]
)

print(f"Total samples: {len(df)}")
print(f"Train: {len(train_df)}, Val: {len(val_df)}")

# ----------------- DATASET -----------------

class HAMClassificationDataset(Dataset):
    """Map-style dataset yielding ``(image, label)`` pairs.

    Args:
        dataframe: Table with an ``image`` column (file stem, no extension)
            and a ``label_idx`` column (integer class index).
        img_dir: Directory containing ``<image>.jpg`` files; ``str`` or ``Path``.
        transform: Optional callable applied to the loaded RGB ``PIL.Image``.
    """

    def __init__(self, dataframe, img_dir, transform=None):
        self.df = dataframe.reset_index(drop=True)
        # Accept plain strings as well: the ``/`` join in __getitem__ needs a Path.
        self.img_dir = Path(img_dir)
        self.transform = transform
        # Extract the two columns once. ``DataFrame.iloc[idx]`` materialises a
        # new Series for every sample, which is wasted work on the loader path.
        self._image_ids = self.df["image"].tolist()
        self._labels = [int(value) for value in self.df["label_idx"]]

    def __len__(self):
        """Return the number of samples."""
        return len(self._image_ids)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, label)``.

        The image is converted to RGB and passed through ``transform`` when
        one was supplied; ``label`` is a plain ``int``.
        """
        img_path = self.img_dir / f"{self._image_ids[idx]}.jpg"
        # Image.open is lazy and keeps the file open; convert() produces an
        # independent in-memory image, so the handle can be closed right away
        # instead of waiting for garbage collection.
        with Image.open(img_path) as raw:
            image = raw.convert("RGB")

        if self.transform is not None:
            image = self.transform(image)

        return image, self._labels[idx]

# ----------------- TRANSFORMS -----------------
# Channel statistics used for normalisation in both pipelines.
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]


def _resize_step():
    """Return a fresh resize transform to the square network input size."""
    return transforms.Resize((IMG_SIZE, IMG_SIZE))


def _tensor_steps():
    """Return the tail shared by both pipelines: PIL -> tensor -> normalise."""
    return [
        transforms.ToTensor(),
        transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
    ]


# Training: resize, then random flips / rotation / colour jitter, then tensor.
_train_augmentations = [
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomRotation(20),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
]
train_transform = transforms.Compose(
    [_resize_step(), *_train_augmentations, *_tensor_steps()]
)

# Validation: deterministic -- resize and normalise only.
val_transform = transforms.Compose([_resize_step(), *_tensor_steps()])

train_dataset = HAMClassificationDataset(
    dataframe=train_df, img_dir=IMAGES_DIR, transform=train_transform
)
val_dataset = HAMClassificationDataset(
    dataframe=val_df, img_dir=IMAGES_DIR, transform=val_transform
)

# ----------------- CLASS IMBALANCE HANDLING -----------------

# Per-class sample counts in the training set, indexed by class id.
# bincount(minlength=num_classes) keeps one slot per class even when a class
# is missing from the split; value_counts() would drop it and shift every
# later index, so class_weights[label] would pick the wrong weight.
train_labels = train_df["label_idx"].to_numpy()
class_counts = np.bincount(train_labels, minlength=num_classes)
print("Class counts (train):", dict(enumerate(class_counts.tolist())))

# Weight for each class = 1 / class_count; a class with no samples gets 0
# instead of dividing by zero.
class_weights = np.zeros(num_classes, dtype=np.float64)
_has_samples = class_counts > 0
class_weights[_has_samples] = 1.0 / class_counts[_has_samples]

# Weight for each sample = weight of its class.
sample_weights = class_weights[train_labels]
sample_weights = torch.from_numpy(sample_weights).float()

# Draw len(train) samples per epoch with replacement, with probability
# proportional to the sample weight.
sampler = WeightedRandomSampler(
    weights=sample_weights,
    num_samples=len(sample_weights),
    replacement=True
)

train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    sampler=sampler,          # sampler and shuffle are mutually exclusive
    num_workers=NUM_WORKERS,
    pin_memory=True,
)

val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,            # keep validation order fixed
    num_workers=NUM_WORKERS,
    pin_memory=True,
)

# ----------------- MODEL (ResNet-34) -----------------
# ImageNet-pretrained backbone with the final FC layer replaced by a fresh
# num_classes-way linear head.
resnet = models.resnet34(weights=models.ResNet34_Weights.IMAGENET1K_V1)
in_features = resnet.fc.in_features
resnet.fc = nn.Linear(in_features, num_classes)

model = resnet.to(DEVICE)

# Inverse-frequency class weights for the loss.
# NOTE(review): the training loader already rebalances classes through
# WeightedRandomSampler using the same 1/count weights, so rare classes are
# compensated twice (once by sampling, once in the loss). Confirm this is
# intended; usually only one of the two mechanisms is used.
class_weights_tensor = torch.tensor(class_weights, dtype=torch.float32, device=DEVICE)
criterion = nn.CrossEntropyLoss(weight=class_weights_tensor)

optimizer = torch.optim.Adam(model.parameters(), lr=LR)
# mode='max' because the scheduler is stepped with validation accuracy.
# `verbose` is intentionally not passed: it is deprecated since PyTorch 2.2.
# Read optimizer.param_groups[i]["lr"] to observe learning-rate changes.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='max', factor=0.5, patience=3
)

print("Model params (M):", sum(p.numel() for p in model.parameters()) / 1e6)

# ----------------- TRAIN & EVAL FUNCTIONS -----------------

def train_one_epoch(model, loader, optimizer, criterion):
    """Run one optimisation pass over ``loader``.

    Args:
        model: Network to train; it is switched to training mode here.
        loader: Iterable of ``(images, labels)`` tensor batches.
        optimizer: Optimizer that updates ``model``'s parameters.
        criterion: Loss, called as ``criterion(logits, labels)``.

    Returns:
        ``(epoch_loss, epoch_acc)``: the batch losses averaged with each batch
        weighted by its size, and the fraction of samples whose argmax
        prediction equals the label.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    model.train()

    # Move batches to wherever the model actually lives rather than trusting a
    # global; fall back to DEVICE only for a model without parameters.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device(DEVICE)

    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    for images, labels in loader:
        # non_blocking lets pinned-memory batches overlap the copy with compute;
        # it is a no-op for CPU-to-CPU moves.
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        optimizer.zero_grad()
        logits = model(images)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()

        batch_size = labels.size(0)
        # criterion returns a per-batch mean; re-weight by batch size so a
        # short final batch does not skew the epoch average.
        loss_sum += loss.item() * batch_size
        n_correct += (logits.argmax(dim=1) == labels).sum().item()
        n_seen += batch_size

    if n_seen == 0:
        raise ValueError("train_one_epoch() received a loader that yielded no samples")

    return loss_sum / n_seen, n_correct / n_seen


def evaluate(model, loader, criterion, collect_probs=False):
    """Evaluate ``model`` on ``loader`` without updating any weights.

    Args:
        model: Network to evaluate; it is switched to eval mode here and left
            in eval mode on return.
        loader: Iterable of ``(images, labels)`` tensor batches.
        criterion: Loss, called as ``criterion(logits, labels)``.
        collect_probs: When true, also return per-class softmax probabilities.

    Returns:
        ``(epoch_loss, epoch_acc, all_labels, all_preds, all_probs)`` where the
        loss is the batch-size-weighted mean of the batch losses, the accuracy
        is the fraction of correct argmax predictions, ``all_labels`` and
        ``all_preds`` are 1-D NumPy arrays in loader order, and ``all_probs``
        is a 2-D NumPy array of softmax outputs, or ``None`` when
        ``collect_probs`` is false.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    model.eval()

    # Move batches to wherever the model actually lives rather than trusting a
    # global; fall back to DEVICE only for a model without parameters.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device(DEVICE)

    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    label_chunks = []
    pred_chunks = []
    prob_chunks = []

    with torch.no_grad():
        for images, labels in loader:
            images = images.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)

            logits = model(images)
            loss = criterion(logits, labels)
            preds = logits.argmax(dim=1)

            batch_size = labels.size(0)
            loss_sum += loss.item() * batch_size
            n_correct += (preds == labels).sum().item()
            n_seen += batch_size

            label_chunks.append(labels.cpu().numpy())
            pred_chunks.append(preds.cpu().numpy())
            if collect_probs:
                # Softmax is only needed for the returned probabilities, so it
                # is skipped entirely on the per-epoch validation calls.
                prob_chunks.append(F.softmax(logits, dim=1).cpu().numpy())

    if n_seen == 0:
        raise ValueError("evaluate() received a loader that yielded no samples")

    all_labels = np.concatenate(label_chunks)
    all_preds = np.concatenate(pred_chunks)
    all_probs = np.concatenate(prob_chunks) if collect_probs else None

    return loss_sum / n_seen, n_correct / n_seen, all_labels, all_preds, all_probs

# ----------------- TRAINING LOOP -----------------

history = []
# Start below any reachable accuracy so the first epoch always writes a
# checkpoint; starting at 0.0 leaves no file to reload if accuracy stays at 0.
best_val_acc = float("-inf")
best_model_path = OUTPUT_DIR / "best_cls_model.pth"

for epoch in range(1, NUM_EPOCHS + 1):
    train_loss, train_acc = train_one_epoch(model, train_loader, optimizer, criterion)
    val_loss, val_acc, _, _, _ = evaluate(model, val_loader, criterion, collect_probs=False)

    # Read the LR from the optimizer around the scheduler step so reductions
    # are logged without depending on scheduler-side printing.
    lr_used = optimizer.param_groups[0]["lr"]
    scheduler.step(val_acc)
    lr_next = optimizer.param_groups[0]["lr"]

    history.append({
        "epoch": epoch,
        "train_loss": train_loss,
        "train_acc": train_acc,
        "val_loss": val_loss,
        "val_acc": val_acc,
        "lr": lr_used,  # learning rate this epoch was trained with
    })

    print(
        f"Epoch {epoch:02d}/{NUM_EPOCHS} - "
        f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f} | "
        f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}"
    )

    if lr_next != lr_used:
        print(f"  -> Learning rate changed: {lr_used:.2e} -> {lr_next:.2e}")

    # Checkpoint only on a strict improvement in validation accuracy.
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), best_model_path)
        print(f"  -> New best model saved (val_acc={best_val_acc:.4f})")

# save training history
history_df = pd.DataFrame(history)
history_csv_path = OUTPUT_DIR / "training_history.csv"
history_df.to_csv(history_csv_path, index=False)
print("Saved training history to:", history_csv_path)

# ----------------- PLOTS: LOSS & ACC -----------------

def _save_history_curves(series, ylabel, title, filename):
    """Plot the given history columns against epoch and save the figure.

    ``series`` is a sequence of ``(history column, legend label)`` pairs.
    Returns the path the PNG was written to, inside OUTPUT_DIR.
    """
    fig, ax = plt.subplots()
    for column, legend_label in series:
        ax.plot(history_df["epoch"], history_df[column], label=legend_label)
    ax.set_xlabel("Epoch")
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.legend()
    target = OUTPUT_DIR / filename
    fig.savefig(target, dpi=150, bbox_inches="tight")
    plt.close(fig)
    return target


loss_fig_path = _save_history_curves(
    series=(("train_loss", "Train Loss"), ("val_loss", "Val Loss")),
    ylabel="Loss",
    title="Classification Loss Curves",
    filename="loss_curves.png",
)
print("Saved loss curves to:", loss_fig_path)

acc_fig_path = _save_history_curves(
    series=(("train_acc", "Train Acc"), ("val_acc", "Val Acc")),
    ylabel="Accuracy",
    title="Classification Accuracy Curves",
    filename="acc_curves.png",
)
print("Saved accuracy curves to:", acc_fig_path)

# ----------------- FINAL EVAL: CONFUSION MATRIX, REPORT, ROC -----------------

# Restore the checkpoint with the best validation accuracy, then re-run
# validation once more, this time keeping the softmax probabilities.
best_state = torch.load(best_model_path, map_location=DEVICE)
model.load_state_dict(best_state)

final_metrics = evaluate(model, val_loader, criterion, collect_probs=True)
val_loss, val_acc, y_true, y_pred, y_prob = final_metrics

print(f"Final Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")

# Confusion matrix (rows: true class, columns: predicted class).
# Passing labels= pins the matrix to num_classes x num_classes in CLASS_NAMES
# order; otherwise sklearn sizes it from the labels that happen to occur and
# the tick labels below could end up attached to the wrong rows/columns.
cm = confusion_matrix(y_true, y_pred, labels=np.arange(num_classes))
print("Confusion matrix:\n", cm)

cm_fig_path = OUTPUT_DIR / "confusion_matrix.png"

plt.figure(figsize=(6, 5))
plt.imshow(cm, interpolation="nearest")
plt.title("Confusion Matrix")
plt.colorbar()
tick_marks = np.arange(num_classes)
plt.xticks(tick_marks, CLASS_NAMES, rotation=45)
plt.yticks(tick_marks, CLASS_NAMES)
# Cells above half the maximum count get white text, the rest black.
thresh = cm.max() / 2.0
for i, j in np.ndindex(cm.shape):
    plt.text(
        j, i, format(cm[i, j], "d"),
        horizontalalignment="center",
        color="white" if cm[i, j] > thresh else "black",
    )
plt.ylabel("True label")
plt.xlabel("Predicted label")
plt.tight_layout()
plt.savefig(cm_fig_path, dpi=150)
plt.close()
print("Saved confusion_matrix to:", cm_fig_path)

# Classification report (per-class precision / recall / F1).
# labels= keeps the rows aligned with CLASS_NAMES even if some class never
# occurs in y_true or y_pred; without it sklearn raises because the number of
# observed classes no longer matches len(target_names).
report = classification_report(
    y_true, y_pred,
    labels=np.arange(num_classes),
    target_names=CLASS_NAMES, digits=4
)
print("Classification report:\n", report)

report_path = OUTPUT_DIR / "classification_report.txt"
# Explicit encoding so the file does not depend on the platform default.
with open(report_path, "w", encoding="utf-8") as f:
    f.write(report)
print("Saved classification_report to:", report_path)

# ROC curves, one per class, each treating that class as positive and all
# other classes as negative (one-vs-rest).
if y_prob is not None:
    roc_fig_path = OUTPUT_DIR / "roc_curves.png"
    roc_fig, roc_ax = plt.subplots(figsize=(8, 6))

    for class_idx, cls_name in enumerate(CLASS_NAMES):
        # 1 where the sample belongs to this class, 0 otherwise
        y_true_bin = (y_true == class_idx).astype(int)
        class_scores = y_prob[:, class_idx]
        fpr, tpr, _ = roc_curve(y_true_bin, class_scores)
        roc_auc = auc(fpr, tpr)
        roc_ax.plot(fpr, tpr, label=f"{cls_name} (AUC = {roc_auc:.3f})")

    # Diagonal reference line
    roc_ax.plot([0, 1], [0, 1], "k--", label="Random")
    roc_ax.set_xlim([0.0, 1.0])
    roc_ax.set_ylim([0.0, 1.05])
    roc_ax.set_xlabel("False Positive Rate")
    roc_ax.set_ylabel("True Positive Rate")
    roc_ax.set_title("ROC Curves (One-vs-Rest)")
    roc_ax.legend(loc="lower right")
    roc_fig.tight_layout()
    roc_fig.savefig(roc_fig_path, dpi=150)
    plt.close(roc_fig)
    print("Saved ROC curves to:", roc_fig_path)

print("All done (classification).")

Using device: cuda
          image  MEL   NV  BCC  AKIEC  BKL   DF  VASC  label_idx
0  ISIC_0024306  0.0  1.0  0.0    0.0  0.0  0.0   0.0          1
1  ISIC_0024307  0.0  1.0  0.0    0.0  0.0  0.0   0.0          1
2  ISIC_0024308  0.0  1.0  0.0    0.0  0.0  0.0   0.0          1
3  ISIC_0024309  0.0  1.0  0.0    0.0  0.0  0.0   0.0          1
4  ISIC_0024310  1.0  0.0  0.0    0.0  0.0  0.0   0.0          0
Total samples: 10015
Train: 8012, Val: 2003
Class counts (train): {0: 890, 1: 5364, 2: 411, 3: 262, 4: 879, 5: 92, 6: 114}


Downloading: "https://download.pytorch.org/models/resnet34-b627a593.pth" to /root/.cache/torch/hub/checkpoints/resnet34-b627a593.pth
100%|██████████| 83.3M/83.3M [00:00<00:00, 213MB/s]


Model params (M): 21.288263




Epoch 01/100 - Train Loss: 0.4312, Train Acc: 0.6485 | Val Loss: 0.8868, Val Acc: 0.5477
  -> New best model saved (val_acc=0.5477)
Epoch 02/100 - Train Loss: 0.2170, Train Acc: 0.7560 | Val Loss: 0.8971, Val Acc: 0.4978
Epoch 03/100 - Train Loss: 0.1664, Train Acc: 0.7976 | Val Loss: 0.7957, Val Acc: 0.6061
  -> New best model saved (val_acc=0.6061)
Epoch 04/100 - Train Loss: 0.1285, Train Acc: 0.8240 | Val Loss: 0.8491, Val Acc: 0.5347
Epoch 05/100 - Train Loss: 0.1120, Train Acc: 0.8354 | Val Loss: 0.7634, Val Acc: 0.5931
Epoch 06/100 - Train Loss: 0.1008, Train Acc: 0.8499 | Val Loss: 0.9037, Val Acc: 0.5831
Epoch 07/100 - Train Loss: 0.0917, Train Acc: 0.8576 | Val Loss: 0.6931, Val Acc: 0.6016
Epoch 08/100 - Train Loss: 0.0611, Train Acc: 0.8834 | Val Loss: 0.6176, Val Acc: 0.6640
  -> New best model saved (val_acc=0.6640)
Epoch 09/100 - Train Loss: 0.0431, Train Acc: 0.9036 | Val Loss: 0.6213, Val Acc: 0.6710
  -> New best model saved (val_acc=0.6710)
Epoch 10/100 - Train Loss: 