# Skin Lesion Classification Model Explanations

This notebook demonstrates the implementation of explainable AI techniques (LIME and SHAP) for the skin lesion classification model. We'll train a CNN model based on the architecture described in the paper "Skin lesion classification of dermoscopic images using machine learning and convolutional neural network" and explain its predictions.


In [None]:
# Import libraries
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from PIL import Image
import time
from tqdm.notebook import tqdm
import sys
from pathlib import Path
import random
import seaborn as sns
from skimage.segmentation import mark_boundaries

In [None]:
# For explainability
import lime
from lime import lime_image
import shap

In [None]:
# Add the project root directory to the Python path.
# '..' is the parent of the current working directory (presumably the notebook's
# folder). Resolve it to an absolute path so imports keep working if the
# working directory changes later, and guard the append so re-running this
# cell does not push duplicate entries onto sys.path.
PROJECT_ROOT = str(Path('..').resolve())
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

In [None]:
# Import project modules
from XAI.config import CLASS_NAMES, MODEL_INPUT_SIZE, RANDOM_SEED, BATCH_SIZE
from XAI.modeling.model import SkinLesionCNN
from XAI.dataset import prepare_data, get_transforms

In [None]:
# Set plotting style
%matplotlib inline
plt.style.use('ggplot')
plt.rcParams['figure.figsize'] = (12, 8)

In [None]:
# Seed every RNG the notebook touches so runs are reproducible
def set_seed(seed=RANDOM_SEED):
    """Seed Python's `random`, NumPy's global RNG and PyTorch (CPU and, when
    available, every CUDA device), and force deterministic cuDNN kernels.

    Args:
        seed: Integer seed; defaults to the project-wide RANDOM_SEED.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    # Trade cuDNN autotuning speed for run-to-run determinism
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

set_seed()

## 1. Load and Prepare Data

In [None]:
# Build the train / validation / test DataLoaders (prepare_data's `balanced` option enabled)
(train_loader, val_loader, test_loader) = prepare_data(balanced=True)

In [None]:
# Report how many batches each split yields
for split_name, split_loader in (("training", train_loader),
                                 ("validation", val_loader),
                                 ("test", test_loader)):
    print(f"Number of batches in {split_name} set: {len(split_loader)}")

In [None]:
# View a batch of training data
def show_batch(loader, num_samples=9, ncols=3):
    """Plot the first images of one batch from `loader` with their class names.

    Images are de-normalised with the mean/std hard-coded below before display
    (assumes the dataset transforms normalise with these values — confirm
    against get_transforms).

    Args:
        loader: Iterable yielding ``(images, labels)`` batches of CHW tensors.
        num_samples: Maximum number of images to show; clipped to the batch size.
        ncols: Number of grid columns; the row count is derived from ``num_samples``.
    """
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    class_names = list(CLASS_NAMES.values())

    images, labels = next(iter(loader))
    # Never index past the batch: a batch smaller than num_samples used to raise IndexError
    num_samples = min(num_samples, len(images))
    if num_samples == 0:
        return

    # CHW tensors -> HWC numpy, then undo the normalisation for display
    images = images[:num_samples].cpu().numpy().transpose(0, 2, 3, 1)
    images = np.clip(images * std + mean, 0, 1)
    labels = labels[:num_samples].cpu().numpy()

    # Size the grid from num_samples instead of a fixed 3x3 (which broke for > 9)
    nrows = -(-num_samples // ncols)  # ceiling division
    fig, axes = plt.subplots(nrows, ncols, figsize=(10 * ncols / 3, 10 * nrows / 3), squeeze=False)
    for ax in axes.flat:
        ax.axis('off')  # also hides spare cells when num_samples doesn't fill the grid
    for ax, image, label in zip(axes.flat, images, labels):
        ax.imshow(image)
        ax.set_title(f"Class: {class_names[label]}")

    fig.tight_layout()
    plt.show()

show_batch(train_loader)

## 2. Model Creation and Training

In [None]:
# Prefer the GPU whenever PyTorch can see one
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

In [None]:
# Instantiate the CNN, move it to the selected device and print its layers
model = SkinLesionCNN()
model = model.to(device)
print(model)

In [None]:
# Define training hyperparameters
num_epochs = 30  # Reduced for demonstration
learning_rate = 0.001
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
# Divide the LR by 10 once validation loss has not improved for 5 epochs.
# `verbose=True` is deprecated for LR schedulers (PyTorch >= 2.2 warns on it);
# read the current rate from optimizer.param_groups instead.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5)

In [None]:
# Define training and validation functions
def train_epoch(model, loader, criterion, optimizer, device):
    """Run one optimisation pass over `loader`.

    Args:
        model: Network to train (put into train mode here).
        loader: Iterable of ``(inputs, labels)`` batches.
        criterion: Loss function. Each batch's value is weighted by the batch
            size when accumulated, which assumes a mean-reduced loss.
        optimizer: Optimizer that updates ``model``'s parameters.
        device: Device the batches are moved to.

    Returns:
        ``(mean_loss, accuracy)`` over every sample seen in the epoch.
    """
    model.train()
    loss_sum, n_correct, n_seen = 0.0, 0, 0

    for batch_inputs, batch_labels in tqdm(loader, desc="Training"):
        batch_inputs = batch_inputs.to(device)
        batch_labels = batch_labels.to(device)

        optimizer.zero_grad()
        logits = model(batch_inputs)
        batch_loss = criterion(logits, batch_labels)
        batch_loss.backward()
        optimizer.step()

        loss_sum += batch_loss.item() * batch_inputs.size(0)
        n_seen += batch_labels.size(0)
        n_correct += (torch.max(logits, 1)[1] == batch_labels).sum().item()

    return loss_sum / n_seen, n_correct / n_seen

In [None]:
def validate(model, loader, criterion, device):
    """Evaluate `model` on `loader` without tracking gradients.

    Args:
        model: Network to evaluate (put into eval mode here).
        loader: Iterable of ``(inputs, labels)`` batches.
        criterion: Loss function; weighted by batch size when accumulated.
        device: Device the batches are moved to.

    Returns:
        ``(mean_loss, accuracy)`` over every sample in ``loader``.
    """
    model.eval()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    with torch.no_grad():
        for batch_inputs, batch_labels in tqdm(loader, desc="Validating"):
            batch_inputs, batch_labels = batch_inputs.to(device), batch_labels.to(device)
            logits = model(batch_inputs)
            loss_sum += criterion(logits, batch_labels).item() * batch_inputs.size(0)
            n_correct += (logits.max(1).indices == batch_labels).sum().item()
            n_seen += batch_labels.size(0)

    return loss_sum / n_seen, n_correct / n_seen

In [None]:
# Train the model
def train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, device, num_epochs,
                checkpoint_path='../models/best_model.pth'):
    """Train for `num_epochs`, checkpointing the weights with the best validation accuracy.

    Args:
        model, criterion, optimizer, device: forwarded to train_epoch / validate.
        train_loader, val_loader: training and validation batch iterables.
        scheduler: stepped once per epoch with the validation loss
            (the ReduceLROnPlateau calling convention).
        num_epochs: number of epochs to run.
        checkpoint_path: file the best ``state_dict`` is written to; its parent
            directory is created if it does not exist.

    Returns:
        dict of per-epoch lists: 'train_loss', 'train_acc', 'val_loss',
        'val_acc', and 'lr' (learning rate in effect during that epoch).
    """
    history = {
        'train_loss': [],
        'train_acc': [],
        'val_loss': [],
        'val_acc': [],
        'lr': [],
    }

    # Create the checkpoint directory here so training does not depend on a
    # separate cell having run first.
    checkpoint_dir = os.path.dirname(checkpoint_path)
    if checkpoint_dir:
        os.makedirs(checkpoint_dir, exist_ok=True)

    # Start below any achievable accuracy so the first epoch always writes a
    # checkpoint. With 0.0 and a strict `>`, a run stuck at 0% accuracy never
    # saved anything and loading the "best" model afterwards failed.
    best_val_acc = float('-inf')

    for epoch in range(num_epochs):
        print(f"Epoch {epoch+1}/{num_epochs}")
        current_lr = optimizer.param_groups[0]['lr']

        train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)
        val_loss, val_acc = validate(model, val_loader, criterion, device)

        # Adjust learning rate based on the monitored validation loss
        scheduler.step(val_loss)

        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc)
        history['lr'].append(current_lr)

        print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}")
        print(f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")

        # Report LR changes ourselves (scheduler `verbose` is deprecated)
        new_lr = optimizer.param_groups[0]['lr']
        if new_lr != current_lr:
            print(f"Learning rate changed: {current_lr:.2e} -> {new_lr:.2e}")

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), checkpoint_path)
            print(f"Model saved with val_acc: {val_acc:.4f}")

        print("-" * 50)

    return history

In [None]:
# Ensure the checkpoint directory (and any missing parents) exists before training writes to it
Path('../models').mkdir(parents=True, exist_ok=True)

In [None]:
# Run the full training loop; per-epoch metrics come back in `history`
history = train_model(
    model, train_loader, val_loader,
    criterion, optimizer, scheduler,
    device, num_epochs,
)

In [None]:
# Plot training history: loss (left) and accuracy (right) for train vs validation
fig, axes = plt.subplots(1, 2, figsize=(12, 5))
for ax, (metric, title) in zip(axes, [('loss', 'Loss'), ('acc', 'Accuracy')]):
    ax.plot(history['train_' + metric], label='Train')
    ax.plot(history['val_' + metric], label='Validation')
    ax.set_title(title)
    ax.set_xlabel('Epoch')
    ax.set_ylabel(title)
    ax.legend()

fig.tight_layout()
plt.show()

In [None]:
# Load the best checkpoint written during training.
# map_location puts the tensors on the current device even if the checkpoint
# was written from another one. weights_only=True restricts unpickling to
# tensors and plain containers, which is all a state_dict needs, and avoids
# executing arbitrary pickled code.
best_model_path = '../models/best_model.pth'
state_dict = torch.load(best_model_path, map_location=device, weights_only=True)
model.load_state_dict(state_dict)
model.eval()

In [None]:
# Score the restored best model on the held-out test split
test_loss, test_acc = validate(model, test_loader, criterion, device)
for metric_name, metric_value in (("Loss", test_loss), ("Accuracy", test_acc)):
    print(f"Test {metric_name}: {metric_value:.4f}")

## 3. Model Explanations with LIME

In [None]:
# Take the first test batch and move images and labels to the model's device
test_images, test_labels = (t.to(device) for t in next(iter(test_loader)))

In [None]:
# Predicted class = arg-max of the model's outputs for each test image
with torch.no_grad():
    outputs = model(test_images)
    predicted = outputs.max(dim=1).indices

In [None]:
# Display some test images and their predictions.
# The figure is created in the same cell that draws into it: Jupyter's inline
# backend closes open figures at the end of every cell, so a figure created in
# a separate cell is shown empty and the subplots land in a new default-size
# figure instead.
num_images_to_show = min(6, len(test_images))
class_names = list(CLASS_NAMES.values())
fig, axes = plt.subplots(2, 3, figsize=(15, 10))

images_to_explain = []
for i, ax in enumerate(axes.flat):
    ax.axis('off')
    if i >= num_images_to_show:
        continue  # batch smaller than the grid: leave the cell blank

    # CHW normalised tensor -> HWC image in [0, 1], used for display and by LIME
    img = test_images[i].cpu().numpy().transpose(1, 2, 0)
    img = img * np.array([0.229, 0.224, 0.225]) + np.array([0.485, 0.456, 0.406])
    img = np.clip(img, 0, 1)
    images_to_explain.append(img)

    ax.imshow(img)
    ax.set_title(f"True: {class_names[test_labels[i]]}, \nPred: {class_names[predicted[i]]}")

fig.tight_layout()
plt.show()

In [None]:
# Define a function to make predictions for LIME
def predict_fn(images):
    """Classifier function for LIME: map a batch of HWC images to class probabilities.

    Args:
        images: Sequence/array of HWC RGB images with values in [0, 1]. Here
            these are the float64 de-normalised test images and LIME's
            perturbations of them.

    Returns:
        ``np.ndarray`` with one row per image and one softmax probability
        column per model output.
    """
    to_tensor = transforms.ToTensor()
    normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    resize = transforms.Resize(MODEL_INPUT_SIZE)

    # ToTensor only rescales/casts uint8 input; float64 arrays stay float64,
    # which a model with (default) float32 weights rejects with a dtype
    # mismatch error. Cast the whole batch to float32 explicitly.
    batch = torch.stack([to_tensor(img) for img in images]).float()

    # Normalize and Resize both accept a (B, C, H, W) batch, so apply them once
    # instead of image by image (same order as before: normalise, then resize).
    batch = resize(normalize(batch)).to(device)

    with torch.no_grad():
        logits = model(batch)
        probs = torch.nn.functional.softmax(logits, dim=1).cpu().numpy()

    return probs

In [None]:
# Initialize LIME explainer with its own seeded RNG. Without random_state it
# draws from NumPy's global RNG, so its perturbations would depend on how much
# randomness earlier cells consumed rather than being reproducible on
# Restart & Run All.
explainer = lime_image.LimeImageExplainer(random_state=RANDOM_SEED)

In [None]:
# Function to explain an image with LIME
def explain_with_lime(image, explainer, predict_fn, num_samples=1000):
    """Explain one image with LIME, producing explanations for every class.

    Args:
        image: HWC image passed straight to ``explain_instance``.
        explainer: Object exposing ``explain_instance`` (a LimeImageExplainer here).
        predict_fn: Batch classifier returning per-class probabilities.
        num_samples: Number of perturbed samples LIME evaluates.

    Returns:
        Whatever ``explain_instance`` returns. Hidden superpixels are filled
        with 0 (``hide_color=0``), and ``top_labels`` covers all CLASS_NAMES entries.
    """
    lime_options = dict(top_labels=len(CLASS_NAMES), hide_color=0, num_samples=num_samples)
    return explainer.explain_instance(image, predict_fn, **lime_options)

In [None]:
# Get LIME explanations for the first three displayed images (LIME is slow)
explanations = []
for idx, img in enumerate(images_to_explain[:3], start=1):
    print(f"Explaining image {idx}/3...")
    explanations.append(explain_with_lime(img, explainer, predict_fn))

In [None]:
# Visualize LIME explanations
def show_lime_explanations(images, explanations, predictions, true_labels, num_features=5):
    """Plot, per image: the original, the superpixels LIME found supporting the
    predicted class, and those it found opposing it.

    Args:
        images: HWC images in [0, 1] that were explained.
        explanations: LIME explanation objects, aligned with ``images``.
        predictions: Predicted class indices, aligned with ``images``.
        true_labels: Ground-truth class indices, aligned with ``images``.
        num_features: Number of superpixels outlined in each explanation panel.
    """
    num_images = len(images)
    if num_images == 0:
        return
    class_names = list(CLASS_NAMES.values())

    # squeeze=False keeps `axs` 2-D even for a single image; by default it
    # would be 1-D and `axs[i, 0]` would raise IndexError.
    fig, axs = plt.subplots(num_images, 3, figsize=(18, 6 * num_images), squeeze=False)

    for i in range(num_images):
        pred_label = predictions[i]

        # Original image
        axs[i, 0].imshow(images[i])
        axs[i, 0].set_title(f"Original Image\nTrue: {class_names[true_labels[i]]}\nPred: {class_names[pred_label]}")

        # Positive explanation (features supporting the prediction)
        temp, mask = explanations[i].get_image_and_mask(
            pred_label, positive_only=True, num_features=num_features, hide_rest=False
        )
        axs[i, 1].imshow(mark_boundaries(temp, mask))
        axs[i, 1].set_title("Positive Influence\nHighlighting regions supporting prediction")

        # Negative explanation (features against the prediction)
        temp, mask = explanations[i].get_image_and_mask(
            pred_label, positive_only=False, negative_only=True, num_features=num_features, hide_rest=False
        )
        axs[i, 2].imshow(mark_boundaries(temp, mask))
        axs[i, 2].set_title("Negative Influence\nHighlighting regions against prediction")

    for ax in axs.flat:
        ax.axis('off')

    fig.tight_layout()
    plt.show()

In [None]:
# Predicted and true class indices of the three LIME-explained images, as NumPy arrays
predictions, true_labels = (t[:3].cpu().numpy() for t in (predicted, test_labels))

In [None]:
# Render original / supporting / opposing panels for each LIME-explained image
show_lime_explanations(
    images=images_to_explain[:3],
    explanations=explanations,
    predictions=predictions,
    true_labels=true_labels,
)

## 4. Model Explanations with SHAP

In [None]:
# Create a DeepExplainer (SHAP)
# DeepExplainer needs a background (reference) set: use up to the first 10
# images of the test batch. Slicing also works when the batch holds fewer than
# 10 images, where indexing range(10) raised IndexError.
# NOTE(review): this background overlaps the images explained below; drawing
# it from the training data may be preferable — confirm intent.
num_background = 10
background = test_images[:num_background].clone()

In [None]:
# Define a wrapper for the model to handle SHAP inputs properly
class ModelWrapper(nn.Module):
    """Thin ``nn.Module`` that forwards its input unchanged to a wrapped model.

    It adds no computation; it only hands SHAP a separate module object that
    delegates to ``model``.
    """

    def __init__(self, model):
        super(ModelWrapper, self).__init__()
        self.model = model

    def forward(self, x):
        """Return the wrapped model's output for ``x``."""
        output = self.model(x)
        return output

# Wrapped copy of the trained model that is handed to SHAP
model_wrapper = ModelWrapper(model=model)

In [None]:
# Build the SHAP DeepExplainer over the wrapped model, with `background` as the reference set
shap_explainer = shap.DeepExplainer(model_wrapper, background)

In [None]:
# Explain only the first three test images to keep DeepExplainer's runtime down
num_shap_images = 3
images_to_explain_shap = test_images[:num_shap_images]

In [None]:
# Generate SHAP values
print("Generating SHAP values... (this may take a while)")
raw_shap_values = shap_explainer.shap_values(images_to_explain_shap)

# Older SHAP releases return a list with one (N, C, H, W) array per class;
# recent releases return a single array with the class index on the LAST axis.
# Normalise to the list-per-class layout that the cells below index as
# shap_values[class][image].
if isinstance(raw_shap_values, list):
    shap_values = raw_shap_values
elif raw_shap_values.ndim == images_to_explain_shap.dim() + 1:
    shap_values = list(np.moveaxis(raw_shap_values, -1, 0))
else:
    shap_values = [raw_shap_values]  # single-output model

In [None]:
# Print shapes to understand the SHAP output.
# Works with both SHAP layouts: a list holding one array per class, or a single
# array whose last axis indexes the class (len() of that would count images).
if isinstance(shap_values, list):
    num_shap_classes = len(shap_values)
    per_class_shape = shap_values[0].shape
else:
    num_shap_classes = shap_values.shape[-1]
    per_class_shape = shap_values.shape[:-1]
print(f"Number of classes: {num_shap_classes}")
print(f"SHAP values shape for first class: {per_class_shape}")

In [None]:
# Channel-first NumPy copy of the SHAP-explained images, for plotting
test_images_np = images_to_explain_shap.detach().cpu().numpy()

In [None]:
# Visualize SHAP values
def plot_shap_explanations(images, shap_values, predictions, true_labels, alpha=0.7):
    """Plot each image next to a heat-map of |SHAP| for its predicted class.

    Args:
        images: NumPy array of channel-first (C, H, W) images, normalised with
            the mean/std used below.
        shap_values: SHAP output for ``images``. Either a list with one
            (N, C, H, W) array per class, or a single array with the class
            index on the last axis.
        predictions: Predicted class index per image.
        true_labels: Ground-truth class index per image.
        alpha: Weight of the heat-map when blending it over the image.
    """
    num_images = len(images)
    if num_images == 0:
        return
    class_names = list(CLASS_NAMES.values())
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])

    # squeeze=False keeps `axs` 2-D when only one image is plotted
    fig, axs = plt.subplots(num_images, 2, figsize=(15, 5 * num_images), squeeze=False)

    for i in range(num_images):
        pred_idx = predictions[i]
        true_idx = true_labels[i]

        # De-normalise the CHW array into an HWC image in [0, 1]
        img = np.clip(images[i].transpose(1, 2, 0) * std + mean, 0, 1)

        axs[i, 0].imshow(img)
        axs[i, 0].set_title(f"Original Image\nTrue: {class_names[true_idx]}\nPred: {class_names[pred_idx]}")
        axs[i, 0].axis('off')

        # SHAP values of image i for its predicted class, in either SHAP layout
        if isinstance(shap_values, list):
            class_shap = shap_values[pred_idx][i]
        else:
            class_shap = shap_values[i, ..., pred_idx]

        # Pixel importance = |SHAP| summed over colour channels, scaled to [0, 1].
        # Skip the scaling for an all-zero map, which would divide by zero (NaNs).
        importance = np.abs(class_shap).sum(axis=0)
        peak = importance.max()
        if peak > 0:
            importance = importance / peak

        heatmap = plt.cm.hot(importance)[:, :, :3]  # drop the alpha channel
        blended = img * (1 - alpha) + heatmap * alpha

        axs[i, 1].imshow(blended)
        axs[i, 1].set_title("SHAP Explanation\nHighlighting influential regions for prediction")
        axs[i, 1].axis('off')

    fig.tight_layout()
    plt.show()

In [None]:
# Model predictions (arg-max of outputs) for the SHAP-explained images
with torch.no_grad():
    outputs = model(images_to_explain_shap)
    shap_predictions = outputs.max(dim=1).indices

In [None]:
# NumPy copies of predicted and true labels for the SHAP plots
predictions_shap = shap_predictions.cpu().numpy()
true_labels_shap = test_labels[:3].cpu().numpy()

In [None]:
# Render image / SHAP heat-map pairs for the explained images
plot_shap_explanations(
    images=test_images_np,
    shap_values=shap_values,
    predictions=predictions_shap,
    true_labels=true_labels_shap,
)

In [None]:
# Create a summary plot for one image, with one panel per class.
# shap.image_plot expects channel-last arrays with a leading batch axis,
# (N, H, W, C), for both the SHAP values and the pixel values. Our arrays are
# channel-first (C, H, W), so transpose them and add the batch axis.
if isinstance(shap_values, list):
    per_class_first = [sv[0] for sv in shap_values]               # each (C, H, W)
else:
    per_class_first = list(np.moveaxis(shap_values[0], -1, 0))    # each (C, H, W)
shap_values_to_plot = [sv.transpose(1, 2, 0)[np.newaxis] for sv in per_class_first]

# Show the de-normalised image rather than the negated normalised tensor
first_image = test_images_np[0].transpose(1, 2, 0)
first_image = first_image * np.array([0.229, 0.224, 0.225]) + np.array([0.485, 0.456, 0.406])
first_image = np.clip(first_image, 0, 1)[np.newaxis]

# image_plot builds its own figure, so no extra plt.figure() beforehand
shap.image_plot(shap_values_to_plot, first_image, show=False)
plt.suptitle(f"SHAP Summary Plot for {list(CLASS_NAMES.values())[predictions_shap[0]]}")
plt.show()

In [None]:
# Compare LIME and SHAP for the first image.
# All drawing happens in this single cell: the inline backend closes figures
# at the end of each cell. Axes created in an earlier cell can therefore no
# longer be shown, and plt.show() would render an empty figure instead.
fig, axs = plt.subplots(1, 3, figsize=(18, 6))
class_names = list(CLASS_NAMES.values())
pred_label = predictions[0]

# Original image
img0 = images_to_explain[0]
axs[0].imshow(img0)
axs[0].set_title(f"Original Image\nTrue: {class_names[true_labels[0]]}\nPred: {class_names[pred_label]}")

# LIME explanation: superpixels supporting the predicted class
temp, mask = explanations[0].get_image_and_mask(
    pred_label, positive_only=True, num_features=5, hide_rest=False
)
axs[1].imshow(mark_boundaries(temp, mask))
axs[1].set_title("LIME Explanation\nHighlighting regions supporting prediction")

# SHAP explanation: |SHAP| for the predicted class, summed over colour channels.
# Accept either SHAP layout (list per class, or class index on the last axis).
if isinstance(shap_values, list):
    shap_values_pred_class = shap_values[pred_label][0]
else:
    shap_values_pred_class = shap_values[0, ..., pred_label]
abs_shap = np.abs(shap_values_pred_class).sum(axis=0)
if abs_shap.max() > 0:
    abs_shap = abs_shap / abs_shap.max()  # scale to [0, 1]; skipped for an all-zero map
shap_img = plt.cm.hot(abs_shap)[:, :, :3]  # drop the alpha channel
alpha = 0.7
blended = img0 * (1 - alpha) + shap_img * alpha
axs[2].imshow(blended)
axs[2].set_title("SHAP Explanation\nHighlighting influential regions for prediction")

for ax in axs:
    ax.axis('off')

fig.tight_layout()
plt.show()