# In [None]:
"""
File to train and output a simple custom CNN baseline for HAM10k lesion classification
"""
import os
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision.io import read_image
from torchvision import transforms

# Constants
SEED = 4  # random_state passed to the train/val/test splits in main()
DATA_DIR = "data/ham10k_data/"  # root folder for the metadata CSV and the image folder
# Metadata CSV; main() reads its "image_id" and "dx" columns.
PATH_TO_METADATA_FILE = os.path.join(DATA_DIR, "HAM10000_metadata.csv")
# Folder containing one "<image_id>.jpg" file per metadata row.
PATH_TO_IMAGES = os.path.join(DATA_DIR, "HAM10000_images/")
BATCH_SIZE = 64  # samples per batch for the train/val/test DataLoaders
EPOCHS = 10  # number of training epochs
NUM_CLASSES = 7  # number of diagnosis classes

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class HAMDataset(Dataset):
    """Dataset serving ``(image, label)`` pairs described by a metadata frame.

    Every row of ``df`` must provide an ``image_id`` (the stem of a ``.jpg``
    file inside ``img_dir``) and a ``dx`` value that converts to ``int``.
    """

    def __init__(self, df, img_dir, transform=None):
        """Keep the image folder, the optional transform and a re-indexed ``df``."""
        self.transform = transform
        self.img_dir = img_dir
        # A fresh 0..n-1 index lets __getitem__ look rows up by sample number.
        self.df = df.reset_index(drop=True)

    def __len__(self):
        """Return the number of metadata rows."""
        return len(self.df)

    def __getitem__(self, idx):
        """Load sample ``idx`` from disk and return ``(image, label)``."""
        rows = self.df
        stem = rows.loc[idx, "image_id"]
        target = int(rows.loc[idx, "dx"])
        image_path = os.path.join(self.img_dir, stem + ".jpg")
        pixels = read_image(image_path)  # tensor laid out as C×H×W

        # Without a transform the decoded tensor is handed back untouched.
        if not self.transform:
            return pixels, target
        return self.transform(pixels), target


class SimpleCNN(nn.Module):
    """Small CNN baseline: three 3x3 convolutions, one max-pool, one linear head.

    The convolutions use ``padding=1`` so they keep the spatial size; the
    single 2x2 max-pool halves it (rounding down). The classifier is a single
    linear layer over the flattened feature map, so its width depends on the
    input resolution.

    Args:
        num_classes: Number of output logits.
        input_size: Height and width in pixels of the square input images.
            Defaults to 224, which gives a 112x112 pooled feature map.

    Raises:
        ValueError: If ``input_size`` is smaller than 2, because the pooled
            feature map would be empty.
    """

    def __init__(self, num_classes, input_size=224):
        super().__init__()
        if input_size < 2:
            raise ValueError(f"input_size must be at least 2, got {input_size}")

        # 3 Conv layers: 3->32, 32->64, 64->128
        self.conv_layers = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )
        # single pooling + dropout
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.dropout = nn.Dropout(p=0.5)

        # classifier: the convolutions keep input_size, pooling floors it to half
        pooled_size = input_size // 2
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(128 * pooled_size * pooled_size, num_classes),
        )

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)`` for images ``x``.

        ``x`` must be ``(batch, 3, input_size, input_size)``; any other spatial
        size makes the flattened features mismatch the linear layer.
        """
        x = self.conv_layers(x)
        x = self.pool(x)
        x = self.dropout(x)
        x = self.classifier(x)
        return x



def _run_epoch(model, loader, criterion, run_device, optimizer=None):
    """Run one pass over ``loader`` and return ``(mean_loss, accuracy_percent)``.

    With an ``optimizer`` the model is put in training mode and updated after
    every batch. Without one the model is put in eval mode and gradient
    tracking is disabled. An empty loader yields ``(nan, nan)`` instead of
    dividing by zero.
    """
    training = optimizer is not None
    model.train(training)  # train(False) is equivalent to eval()
    loss_sum, correct, seen = 0.0, 0, 0

    with torch.set_grad_enabled(training):
        for imgs, labels in loader:
            imgs, labels = imgs.to(run_device), labels.to(run_device)
            if training:
                optimizer.zero_grad()
            outputs = model(imgs)
            loss = criterion(outputs, labels)
            if training:
                loss.backward()
                optimizer.step()

            # Weight the batch loss by the batch size so a smaller final
            # batch does not skew the epoch average.
            loss_sum += loss.item() * imgs.size(0)
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            seen += labels.size(0)

    if seen == 0:
        return float("nan"), float("nan")
    return loss_sum / seen, correct / seen * 100


def train_model(model, train_loader, val_loader, criterion, optimizer, epochs):
    """Train ``model`` for ``epochs`` epochs, validating after each one.

    Args:
        model: Network to optimise; it is updated in place and left in eval
            mode when the function returns.
        train_loader: Iterable of ``(inputs, labels)`` training batches.
        val_loader: Iterable of ``(inputs, labels)`` validation batches.
        criterion: Loss called as ``criterion(outputs, labels)``.
        optimizer: Optimizer stepped once per training batch.
        epochs: Number of epochs to run.

    Returns:
        Dict with the keys ``train_loss``, ``train_acc``, ``val_loss`` and
        ``val_acc``, each a list holding one value per epoch. Accuracies are
        percentages. An epoch over an empty loader records ``nan``.
    """
    # Move batches to wherever the model actually lives; the module-level
    # ``device`` is only a fallback for a model without parameters.
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else device

    history = {"train_loss": [], "train_acc": [], "val_loss": [], "val_acc": []}

    for epoch in range(1, epochs + 1):
        train_loss, train_acc = _run_epoch(
            model, train_loader, criterion, run_device, optimizer
        )
        val_loss, val_acc = _run_epoch(model, val_loader, criterion, run_device)

        history["train_loss"].append(train_loss)
        history["train_acc"].append(train_acc)
        history["val_loss"].append(val_loss)
        history["val_acc"].append(val_acc)

        print(f"Epoch {epoch}/{epochs} | "
              f"Train: loss={train_loss:.4f}, acc={train_acc:.2f}% | "
              f"Val:   loss={val_loss:.4f}, acc={val_acc:.2f}%")

    return history


def _save_curves(curves, out_path):
    """Plot every ``label -> values`` entry of ``curves`` on one figure and save it."""
    fig = plt.figure()
    for label, values in curves.items():
        plt.plot(values, label=label)
    plt.legend()
    plt.savefig(out_path)
    plt.close(fig)  # release the figure so repeated runs do not accumulate open figures


def main() -> None:
    """Train the baseline CNN end to end and write its artifacts to disk.

    Steps: read the metadata CSV, encode the ``dx`` labels as integers, split
    the rows 70/15/15 into train/test/val, train ``SimpleCNN`` for ``EPOCHS``
    epochs, save ``accuracy_curve.png`` and ``loss_curve.png``, write the
    test-set predictions to ``predictions.csv`` and print the test accuracy.

    Raises:
        ValueError: If the metadata holds more distinct ``dx`` labels than
            ``NUM_CLASSES``, which the model could not represent.
    """
    # Seed torch's global RNG as well as the splits, so weight initialisation
    # and the shuffled batch order are repeatable from run to run.
    torch.manual_seed(SEED)

    # read in metadata
    metadata_df = pd.read_csv(PATH_TO_METADATA_FILE)

    # map label to numeric, in order of first appearance in the file
    label_map = {label: i for i, label in enumerate(metadata_df["dx"].unique())}
    print(label_map)
    if len(label_map) > NUM_CLASSES:
        # Fail early with a clear message rather than inside the loss.
        raise ValueError(
            f"Found {len(label_map)} distinct 'dx' labels "
            f"but NUM_CLASSES is {NUM_CLASSES}"
        )
    metadata_df["dx"] = metadata_df["dx"].map(label_map)

    # split: 70% train, 30% temp
    # NOTE(review): the splits are not stratified by "dx"; consider passing
    # stratify= if class balance across the splits matters.
    train_df, temp_df = train_test_split(
        metadata_df, test_size=0.3, shuffle=True, random_state=SEED
    )
    # split temp into 50% test, 50% val (each 15% overall)
    test_df, val_df = train_test_split(
        temp_df, test_size=0.5, shuffle=True, random_state=SEED
    )

    # transforms & datasets
    transform = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    ])

    train_ds = HAMDataset(train_df, PATH_TO_IMAGES, transform)
    val_ds = HAMDataset(val_df, PATH_TO_IMAGES, transform)
    test_ds = HAMDataset(test_df, PATH_TO_IMAGES, transform)

    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False)
    test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False)

    # model, loss, optimizer
    model = SimpleCNN(num_classes=NUM_CLASSES).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    # train
    history = train_model(model, train_loader, val_loader,
                          criterion, optimizer, epochs=EPOCHS)

    # plot & save
    _save_curves(
        {"Train Acc": history["train_acc"], "Val Acc": history["val_acc"]},
        "accuracy_curve.png",
    )
    _save_curves(
        {"Train Loss": history["train_loss"], "Val Loss": history["val_loss"]},
        "loss_curve.png",
    )

    # test predictions
    model.eval()
    preds = []
    with torch.no_grad():
        for imgs, _ in test_loader:
            imgs = imgs.to(device)
            outputs = model(imgs)
            preds.extend(outputs.argmax(dim=1).cpu().tolist())

    # save preds; test_loader is unshuffled, so preds line up with test_df rows
    filenames = test_df["image_id"].tolist()
    out_df = pd.DataFrame({"img_id": filenames, "pred_label": preds})
    out_df.to_csv("predictions.csv", index=False)
    print("Test predictions saved to predictions.csv")

    # test accuracy
    true_labels = test_df["dx"].values
    acc = (np.array(preds) == true_labels).mean() * 100
    print(f"Test Accuracy: {acc:.2f}%")

# Script entry point: run the pipeline only when executed directly, not on import.
if __name__ == "__main__":
    main()
