<a target="_blank" href="https://colab.research.google.com/github/compomics/ML-course-VIB-2024/blob/master/notebooks/Melanoma_CNN.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

*Disclaimer: this notebook extends the analysis of the [Flowers Image Classification notebook](https://www.tensorflow.org/tutorials/images/classification) to melanoma images, using PyTorch Lightning instead of TensorFlow.*

# Skin lesion image classification

The data will be downloaded from the [ISIC 2018: Skin Lesion Analysis Towards Melanoma Detection challenge](https://challenge2018.isic-archive.com/) [1].

The goal of this recurring challenge is to help participants develop image analysis tools to enable the automated diagnosis of melanoma from dermoscopic images.

The lesion images come from the HAM10000 Dataset [2], and were acquired with a variety of dermatoscope types, from all anatomic sites (excluding mucosa and nails), from a historical sample of patients presented for skin cancer screening, from several different institutions. Images were collected with approval of the Ethics Review Committee of University of Queensland (Protocol-No. 2017001223) and Medical University of Vienna (Protocol-No. 1804/2017).

There are 7 classes:

- MEL: “Melanoma” diagnosis confidence
- NV: “Melanocytic nevus” diagnosis confidence
- BCC: “Basal cell carcinoma” diagnosis confidence
- AKIEC: “Actinic keratosis / Bowen’s disease (intraepithelial carcinoma)” diagnosis confidence
- BKL: “Benign keratosis (solar lentigo / seborrheic keratosis / lichen planus-like keratosis)” diagnosis confidence
- DF: “Dermatofibroma” diagnosis confidence
- VASC: “Vascular lesion” diagnosis confidence

The distribution of disease states represents a modified “real world” setting: there are more benign lesions than malignant ones, but malignancies are over-represented relative to their real-world prevalence.

Here are some examples (taken from the ISIC2018 website):

<br/>
<br/>
<img src="https://miro.medium.com/v2/resize:fit:1400/format:webp/1*le3-EQ-rpTLKtgB4G8jKkw.png">
<br/>
<br/>

[1] Codella, N., Rotemberg, V., Tschandl, P., Celebi, M. E., Dusza, S., Gutman, D., Helba, B., Kalloo, A., Liopyris, K., Marchetti, M., Kittler, H. & Halpern, A. Skin Lesion Analysis Toward Melanoma Detection 2018: A Challenge Hosted by the International Skin Imaging Collaboration (ISIC). Preprint at https://arxiv.org/abs/1902.03368 (2018).

[2] Tschandl, P., Rosendahl, C. & Kittler, H. The HAM10000 dataset, a large collection of multi-source dermatoscopic images of common pigmented skin lesions. Sci. Data 5, 180161 doi:10.1038/sdata.2018.161 (2018).


In [None]:
!pip install torch-xla pytorch-lightning torch==2.3.0 torchvision lightning numpy

## Import PyTorch and other libraries

In [None]:
import random
import os
import pathlib

from tqdm import tqdm

import matplotlib.pyplot as plt

import torch
from torch.utils.data import random_split, DataLoader
from torch import nn
from torch.nn import functional as F

from torchvision import transforms, datasets

import pytorch_lightning as pl
from pytorch_lightning import Trainer

## Download and parse the dataset

In [None]:
!wget https://isic-challenge-data.s3.amazonaws.com/2018/ISIC2018_Task3_Training_Input.zip
!unzip ISIC2018_Task3_Training_Input.zip
!wget https://isic-challenge-data.s3.amazonaws.com/2018/ISIC2018_Task3_Training_GroundTruth.zip
!unzip ISIC2018_Task3_Training_GroundTruth.zip
!mkdir data

In [None]:
# Define directories
data_dir = pathlib.Path("data/")
input_dir = pathlib.Path("ISIC2018_Task3_Training_Input")
ground_truth_file = "ISIC2018_Task3_Training_GroundTruth/ISIC2018_Task3_Training_GroundTruth.csv"

# Read the CSV header: the first column is the image ID, the remaining
# columns are the seven one-hot encoded classes
with open(ground_truth_file) as f:
    header = f.readline().strip().split(",")

# Create one directory per class, the layout expected by torchvision's ImageFolder
for dir_class in header[1:]:  # Skip the first column (image ID)
    class_dir = data_dir / dir_class
    class_dir.mkdir(parents=True, exist_ok=True)

# Move each image into the directory of its class
with open(ground_truth_file) as f:
    next(f)  # Skip the header
    for row in tqdm(f, desc="Organizing images"):
        row = row.strip().split(",")
        image_id = row[0]
        labels = row[1:]

        src = input_dir / f"{image_id}.jpg"
        if not src.exists():  # Already moved, or missing from the archive
            continue

        for idx, label in enumerate(labels):
            if float(label) == 1.0:  # The column marked 1.0 is the image's class
                dst = data_dir / header[idx + 1] / f"{image_id}.jpg"
                os.rename(src, dst)  # Move the file into its class folder
                break  # Each image belongs to exactly one class
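The folder layout above follows from the one-hot format of the ground-truth CSV: each row holds an image ID followed by one 0.0/1.0 column per class. A minimal sketch of that row-to-class mapping, run on a hypothetical two-row excerpt (the image IDs and values are made up for illustration; only the column layout mirrors the real file):

```python
import csv
import io

# Hypothetical excerpt in the same layout as the ground-truth CSV
sample_csv = """image,MEL,NV,BCC,AKIEC,BKL,DF,VASC
ISIC_0000000,0.0,1.0,0.0,0.0,0.0,0.0,0.0
ISIC_0000001,1.0,0.0,0.0,0.0,0.0,0.0,0.0
"""

def row_to_class(header, row):
    """Return (image_id, class_name) for one one-hot encoded CSV row."""
    image_id, *labels = row
    # The class is the column whose value is 1.0
    idx = [float(v) for v in labels].index(1.0)
    return image_id, header[idx + 1]  # +1 skips the image ID column

reader = csv.reader(io.StringIO(sample_csv))
header = next(reader)
mapping = dict(row_to_class(header, row) for row in reader)
print(mapping)  # {'ISIC_0000000': 'NV', 'ISIC_0000001': 'MEL'}
```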

After organizing, every image sits in the folder of its class under `data/`. Count them:

In [None]:
image_count = len(list(data_dir.glob('*/*.jpg')))
print(image_count)
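The total hides how the images are spread over the seven classes. A small helper, sketched here as a hypothetical `count_per_class` function, counts files per class folder; it is demonstrated on a made-up temporary folder tree so it runs on its own:

```python
import pathlib
import tempfile
from collections import Counter

def count_per_class(root):
    """Count .jpg files in each class subfolder of an ImageFolder-style root."""
    root = pathlib.Path(root)
    return Counter(p.parent.name for p in root.glob("*/*.jpg"))

# Tiny made-up folder tree so the sketch is self-contained
with tempfile.TemporaryDirectory() as tmp:
    for cls, n in {"MEL": 2, "NV": 5, "DF": 1}.items():
        (pathlib.Path(tmp) / cls).mkdir()
        for i in range(n):
            (pathlib.Path(tmp) / cls / f"img_{i}.jpg").touch()
    counts = count_per_class(tmp)

print(counts.most_common())  # [('NV', 5), ('MEL', 2), ('DF', 1)]
```

On the real data, `count_per_class(data_dir)` makes the class imbalance described in the introduction visible.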

## Create a dataset

Define the loader parameters and fix the random seed for reproducibility:

In [None]:
batch_size = 32
img_height = 64
img_width = 64
pl.seed_everything(42)

In [None]:
# Define data augmentation and preprocessing
transform = transforms.Compose([
    transforms.Resize((img_height, img_width)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

In [None]:
# Load dataset
data_dir = "data/"
dataset = datasets.ImageFolder(data_dir, transform=transform)

In [None]:
# Split into training and validation datasets
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

In [None]:
# Define DataLoaders
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)
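Each iteration over these loaders yields a tuple `(images, labels)` with images of shape `(batch_size, 3, img_height, img_width)`; the last batch is smaller when the dataset size is not a multiple of the batch size. A sketch with random tensors standing in for the lesion images:

```python
import math

import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical stand-in: 100 random 3x64x64 "images" with labels in 0..6
images = torch.randn(100, 3, 64, 64)
labels = torch.randint(0, 7, (100,))
loader = DataLoader(TensorDataset(images, labels), batch_size=32, shuffle=True)

batch_sizes = [x.shape[0] for x, _ in loader]
print(len(loader), batch_sizes)  # 4 [32, 32, 32, 4]

# Batches per epoch: ceil(N / batch_size), unless drop_last=True
print(len(loader) == math.ceil(100 / 32))  # True

x, y = next(iter(loader))
print(x.shape, y.shape)  # torch.Size([32, 3, 64, 64]) torch.Size([32])
```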

In [None]:
# Display a random image from the dataset
class_names = dataset.classes
def show_random_image(dataset):
    idx = random.randint(0, len(dataset) - 1)
    image, label = dataset[idx]
    # Undo Normalize(mean=0.5, std=0.5) and reorder CHW -> HWC for matplotlib
    plt.imshow(image.permute(1, 2, 0).numpy() * 0.5 + 0.5)
    plt.title(f"Image {idx}, label: {class_names[label]}")
    plt.show()

show_random_image(dataset)

In [None]:
class MelanomaClassifier(pl.LightningModule):
    def __init__(self, num_classes):
        super().__init__()
        # Three convolutional blocks with 16, 32 and 64 filters; 3x3 kernels
        # with padding=1 keep the spatial size unchanged before pooling
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)

        # 2x2 max pooling halves height and width after each block (64 -> 32 -> 16 -> 8)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Global average pooling collapses each of the 64 feature maps to a single value
        self.global_avg_pool = nn.AdaptiveAvgPool2d((1, 1))

        # Classifier head: 64 pooled features -> 64 hidden units -> num_classes logits
        self.fc1 = nn.Linear(64, 64)
        self.fc2 = nn.Linear(64, num_classes)

        # Dropout regularizes the classifier head
        self.dropout = nn.Dropout(0.3)

        # Cross-entropy loss takes raw logits and integer class labels
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, x):
        # Convolutional and pooling layers
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = self.pool(F.relu(self.conv3(x)))

        # Global average pooling
        x = self.global_avg_pool(x)

        # Flatten for the fully connected layers
        x = torch.flatten(x, 1)  # (Batch size, Channels)

        # Fully connected layers
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return x

    def training_step(self, batch, batch_idx):
        inputs, labels = batch
        outputs = self(inputs)
        loss = self.criterion(outputs, labels)
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def validation_step(self, batch, batch_idx):
        inputs, labels = batch
        outputs = self(inputs)
        loss = self.criterion(outputs, labels)
        preds = torch.argmax(outputs, dim=1)
        acc = (preds == labels).float().mean()
        self.log("val_loss", loss, on_epoch=True, prog_bar=True, logger=True)
        self.log("val_acc", acc, on_epoch=True, prog_bar=True, logger=True)

    def predict_step(self, batch, batch_idx):
        # Return logits for predictions
        inputs, _ = batch  # Ignore labels during prediction
        return self(inputs)

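    def configure_optimizers(self):
        # 1e-3 is Adam's default learning rate; much larger values such as 0.1
        # tend to make training unstable
        return torch.optim.Adam(self.parameters(), lr=1e-3)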
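With 64x64 inputs, the three pooling steps shrink the feature maps from 64 to 32, 16 and 8 pixels, and global average pooling leaves a 64-dimensional vector per image, which is why `fc1` takes 64 inputs. A shape trace that rebuilds the same layers as plain PyTorch modules (so it runs without Lightning or the dataset; dropout is omitted because it does not change shapes), assuming 7 classes:

```python
import torch
from torch import nn
from torch.nn import functional as F

# Same layer sizes as MelanomaClassifier, built from plain modules
conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)
conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
conv3 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
pool = nn.MaxPool2d(2, 2)
gap = nn.AdaptiveAvgPool2d((1, 1))
fc1, fc2 = nn.Linear(64, 64), nn.Linear(64, 7)

x = torch.randn(2, 3, 64, 64)  # a batch of two 64x64 RGB images
shapes = []
for conv in (conv1, conv2, conv3):
    x = pool(F.relu(conv(x)))
    shapes.append(tuple(x.shape))
x = torch.flatten(gap(x), 1)
shapes.append(tuple(x.shape))
logits = fc2(F.relu(fc1(x)))
shapes.append(tuple(logits.shape))

for s in shapes:
    print(s)
# (2, 16, 32, 32)
# (2, 32, 16, 16)
# (2, 64, 8, 8)
# (2, 64)
# (2, 7)

# Total number of trainable parameters in these layers
n_params = sum(p.numel() for m in (conv1, conv2, conv3, fc1, fc2) for p in m.parameters())
print(n_params)  # 28199
```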

In [None]:
# Instantiate the model
num_classes = len(class_names)
model = MelanomaClassifier(num_classes=num_classes)

# Initialize the Trainer (Lightning selects the available accelerator automatically)
trainer = Trainer(
    max_epochs=5,
)

# Train the model
trainer.fit(model, train_loader, val_loader)
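Since the classes are imbalanced, the unweighted cross-entropy loss is dominated by the most frequent class. One common option, not used in this notebook, is to weight each class by its inverse frequency through the `weight` argument of `nn.CrossEntropyLoss`. A sketch with made-up class counts:

```python
import torch
from torch import nn

# Made-up per-class image counts, for illustration only
counts = torch.tensor([300.0, 500.0, 1000.0, 100.0, 1100.0, 6000.0, 140.0])

# "Balanced" inverse-frequency weights: w_c = N / (K * n_c), so every class
# contributes the same total weight N / K over one pass through the data
num_classes = len(counts)
weights = counts.sum() / (num_classes * counts)

criterion = nn.CrossEntropyLoss(weight=weights)

# Per-sample view: with identical (uniform) logits, an error on a rare class
# now costs more than an error on the majority class
per_sample = nn.CrossEntropyLoss(weight=weights, reduction="none")
logits = torch.zeros(2, num_classes)
labels = torch.tensor([5, 3])  # majority class 5, rare class 3
losses = per_sample(logits, labels)
print(losses[1] / losses[0])  # ratio is 60 (= 6000 / 100)
```

`nn.CrossEntropyLoss` stores `weight` as a buffer, so when such a criterion is assigned as `self.criterion` in a LightningModule it moves to the model's device together with the rest of the module.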

In [None]:
# Use the model trained above: a freshly constructed MelanomaClassifier would
# have random, untrained weights.
# To reuse a saved checkpoint instead, uncomment the next line:
# model = MelanomaClassifier.load_from_checkpoint("path/to/checkpoint.ckpt", num_classes=num_classes)

# trainer.predict returns a list with one tensor of logits per batch
predictions = trainer.predict(model, dataloaders=val_loader)

# Combine the per-batch logits into a single (num_val_images, num_classes) tensor
predictions = torch.cat(predictions)

# Inspect the predictions
print(predictions.shape)
print(predictions[:5])


In [None]:
# Convert logits to probabilities
probabilities = torch.softmax(predictions, dim=1)

# Get predicted class labels
predicted_labels = torch.argmax(probabilities, dim=1)

# Display predictions
print("Predicted probabilities:", probabilities)
print("Predicted labels:", predicted_labels)
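Softmax rescales each row of logits into probabilities that sum to 1, but it preserves the ordering within a row, so the argmax of the logits already gives the same predicted labels; the softmax step is only needed to report confidences. A small check with made-up logits:

```python
import torch

logits = torch.tensor([[2.0, 0.5, -1.0],
                       [0.1, 0.2, 3.0]])
probabilities = torch.softmax(logits, dim=1)

print(probabilities.sum(dim=1))             # each row sums to 1
print(torch.argmax(probabilities, dim=1))   # tensor([0, 2])

# Same labels whether the argmax is taken over logits or probabilities
print(torch.equal(torch.argmax(probabilities, dim=1), torch.argmax(logits, dim=1)))  # True
```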


In [None]:
# Collect the validation images and true labels in the same order as the
# predictions computed above (val_loader uses shuffle=False, so the order is stable)
val_images = []
val_labels = []

for inputs, labels in val_loader:
    val_images.append(inputs)
    val_labels.append(labels)

val_images = torch.cat(val_images)
val_labels = torch.cat(val_labels)
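With true and predicted labels aligned, overall accuracy and a confusion matrix summarize the validation results. Per-class recall matters with imbalanced data, because overall accuracy can look good while minority classes are mostly missed. A sketch of a hypothetical `confusion_matrix` helper built on `torch.bincount`, shown on made-up labels for a 3-class problem:

```python
import torch

def confusion_matrix(true_labels, predicted_labels, num_classes):
    """Rows are true classes, columns are predicted classes."""
    # Encode each (true, pred) pair as one index, then count occurrences
    idx = true_labels * num_classes + predicted_labels
    counts = torch.bincount(idx, minlength=num_classes ** 2)
    return counts.reshape(num_classes, num_classes)

# Made-up labels, for illustration only
true_labels = torch.tensor([0, 0, 1, 1, 2, 2, 2])
predicted_labels = torch.tensor([0, 1, 1, 1, 2, 0, 2])

cm = confusion_matrix(true_labels, predicted_labels, num_classes=3)
accuracy = (true_labels == predicted_labels).float().mean().item()
per_class_recall = cm.diag().float() / cm.sum(dim=1).float()

print(cm)
# tensor([[1, 1, 0],
#         [0, 2, 0],
#         [1, 0, 2]])
print(round(accuracy, 3))  # 0.714
```

On the notebook's outputs, the call would be `confusion_matrix(val_labels, predicted_labels, num_classes)`.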


In [None]:
def show_predictions(images, true_labels, predicted_labels, probabilities, class_names, num_images=5):
    """Displays a few predictions with their probabilities."""
    plt.figure(figsize=(3 * num_images, 4))  # One row of num_images panels
    for i in range(num_images):
        ax = plt.subplot(1, num_images, i + 1)
        img = images[i].permute(1, 2, 0).numpy() * 0.5 + 0.5  # Undo Normalize; CHW -> HWC
        plt.imshow(img)
        true_label = class_names[true_labels[i].item()]
        predicted_label = class_names[predicted_labels[i].item()]
        confidence = probabilities[i][predicted_labels[i]].item()
        plt.title(f"True: {true_label}\nPred: {predicted_label}\nConf: {confidence:.2f}")
        plt.axis("off")
    plt.tight_layout()
    plt.show()

# Display predictions for the first few images
class_names = dataset.classes  # Get class names from the dataset
show_predictions(
    images=val_images[:5],
    true_labels=val_labels[:5],
    predicted_labels=predicted_labels[:5],
    probabilities=probabilities[:5],
    class_names=class_names
)
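The first five validation images are an arbitrary sample. Looking at the most confident mistakes is often more informative. A sketch of a hypothetical `misclassified_indices` helper that picks them, shown on made-up probabilities:

```python
import torch

def misclassified_indices(true_labels, predicted_labels, probabilities, k=5):
    """Indices of the k most confident wrong predictions."""
    wrong = (true_labels != predicted_labels).nonzero(as_tuple=True)[0]
    # Confidence assigned to the (wrong) predicted class of each misclassified sample
    conf = probabilities[wrong, predicted_labels[wrong]]
    order = torch.argsort(conf, descending=True)
    return wrong[order[:k]]

# Made-up example: 4 samples, 3 classes
probabilities = torch.tensor([[0.9, 0.05, 0.05],
                              [0.2, 0.7, 0.1],
                              [0.1, 0.1, 0.8],
                              [0.6, 0.3, 0.1]])
predicted_labels = probabilities.argmax(dim=1)   # [0, 1, 2, 0]
true_labels = torch.tensor([0, 2, 2, 1])        # samples 1 and 3 are wrong

print(misclassified_indices(true_labels, predicted_labels, probabilities, k=2))
# tensor([1, 3])
```

With the notebook's tensors, `idx = misclassified_indices(val_labels, predicted_labels, probabilities)` could be passed to `show_predictions` as `images=val_images[idx]`, `true_labels=val_labels[idx]`, and so on, with `num_images=len(idx)`.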
