# Explainable AI for Skin Cancer Detection

This notebook demonstrates how to use the explainability methods in the XAI package to understand the predictions of skin lesion classification models.

In [None]:
import os
import sys
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import torch
from PIL import Image

# Add the project directory to the path
# This assumes the notebook is in the notebooks/ directory
project_dir = os.path.abspath(os.path.join(os.getcwd(), '..'))
if project_dir not in sys.path:
    sys.path.append(project_dir)

# Import project modules
from XAI.config import MODELS_DIR, FIGURES_DIR, CLASS_NAMES, MODEL_INPUT_SIZE
from XAI.dataset import get_transforms
from XAI.modeling.ResizeLayer import ResizedModel
from XAI.modeling.AllModels import dl_models, device
from XAI.modeling.train import load_best_model
from XAI.explainers import LimeExplainer, ShapExplainer, GradCamExplainer

# Set up matplotlib for inline plotting
%matplotlib inline
plt.rcParams['figure.figsize'] = (12, 8)

## 1. Load a Trained Model

First, we'll load one of our trained models. To explain a different model, change `model_idx`.

In [None]:
# Choose a model to explain (0 for SkinLesionCNN, 2 for CustomCNN, etc.)
model_idx = 0

# Get model class
model_class = dl_models[model_idx]
model_name = model_class.name()
print(f"Using model: {model_name}")

# Create model with proper input size
model = ResizedModel(model_class.inputSize(), model_class()).to(device)

# Load the best model weights
best_model_path, checkpoint = load_best_model(model_name)

if checkpoint is not None:
    # Load model weights
    model.load_state_dict(checkpoint["model_state_dict"])
    print(f"Loaded model from {best_model_path}")
else:
    print(f"No saved model found for {model_name}, using untrained model")

# Set model to evaluation mode
model.eval();

## 2. Load and Preprocess an Image

Now, let's load a sample image to explain. You can replace this with any skin lesion image.
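
For orientation, preprocessing for an image classifier typically scales pixel values, normalizes each channel, and reorders the axes to channels-first. The sketch below shows that idea in plain NumPy. The function name `to_model_input` and the ImageNet mean and standard deviation are assumptions for illustration; the project's `get_transforms("val")` may resize, crop, or normalize differently.

```python
import numpy as np

# Assumed normalization constants (ImageNet statistics); the project's
# get_transforms("val") may use different values or extra steps.
MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)

def to_model_input(image_hwc_uint8):
    """Convert an H x W x 3 uint8 image to a normalized 3 x H x W float32 array."""
    scaled = image_hwc_uint8.astype(np.float32) / 255.0   # [0, 255] -> [0, 1]
    normalized = (scaled - MEAN) / STD                     # per-channel standardization
    return np.transpose(normalized, (2, 0, 1))             # HWC -> CHW

dummy = np.full((4, 6, 3), 255, dtype=np.uint8)  # a tiny all-white stand-in image
chw = to_model_input(dummy)
print(chw.shape, chw.dtype)
```

The channels-first layout matters later in the notebook: the batch dimension is added in front of it before the tensor is passed to the model.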

In [None]:
# Replace with the path to your sample image
# You can use a sample from the HAM10000 dataset in data/interim/organized_by_class/
image_path = "../data/interim/organized_by_class/mel/ISIC_0024306.jpg"  # Example melanoma image

# Load image
image = np.array(Image.open(image_path).convert("RGB"))

# Display the image
plt.figure(figsize=(8, 8))
plt.imshow(image)
plt.title("Sample Image")
plt.axis("off")
plt.show()

# Preprocess the image
transform = get_transforms("val")
image_tensor = transform(image)

## 3. Make a Prediction

Let's get the model's prediction for this image.
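
The prediction cell converts the model's raw outputs (logits) into probabilities with a softmax and then picks the most probable class. As a minimal sketch of that step, here is the same computation in NumPy on made-up logits; the values and the three-class setup are purely illustrative.

```python
import numpy as np

def softmax(logits):
    """Numerically stable softmax over a 1-D array of logits."""
    shifted = logits - np.max(logits)   # subtracting the max avoids overflow in exp
    exps = np.exp(shifted)
    return exps / exps.sum()

logits = np.array([2.0, 0.5, -1.0])     # made-up scores for three classes
probs = softmax(logits)
predicted = int(np.argmax(probs))       # index of the most probable class
print(probs.round(3), predicted)
```

Softmax preserves the ranking of the logits, so the predicted class is the same whether `argmax` is taken before or after it; the probabilities are only needed for the bar chart.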

In [None]:
# Add batch dimension and move to device
batch_tensor = image_tensor.unsqueeze(0).to(device)

# Make prediction
with torch.no_grad():
    outputs = model(batch_tensor)
    probabilities = torch.nn.functional.softmax(outputs, dim=1)[0]
    predicted_class = torch.argmax(probabilities).item()

# Get class name (assumes CLASS_NAMES is ordered to match the model's output indices)
class_keys = list(CLASS_NAMES.keys())
class_name = CLASS_NAMES[class_keys[predicted_class]]
print(f"Predicted class: {class_name} (index: {predicted_class})")

# Plot probabilities
plt.figure(figsize=(10, 6))
plt.bar(range(len(CLASS_NAMES)), probabilities.cpu().numpy())
plt.xlabel("Class")
plt.ylabel("Probability")
plt.title("Class Probabilities")
plt.xticks(range(len(CLASS_NAMES)), [CLASS_NAMES[cls] for cls in class_keys], rotation=45, ha="right")
plt.tight_layout()
plt.show()

## 4. Explain the Prediction with LIME

Now, let's use LIME to explain why the model made this prediction.
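
Conceptually, LIME switches image segments (superpixels) on and off, records how the model's score changes, and fits a simple weighted linear model to those samples. The following is a simplified sketch of that idea, not the `lime` library's or `LimeExplainer`'s actual implementation: the `black_box` function, the six segments, and the kernel width are all hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
n_segments = 6            # pretend the image was split into 6 superpixels
n_samples = 500

# Hypothetical black box: the score depends mostly on segment 2,
# slightly on segment 4, and on nothing else.
def black_box(mask):
    return 0.7 * mask[2] + 0.2 * mask[4] + 0.05

# 1. Sample random on/off patterns for the superpixels.
masks = rng.integers(0, 2, size=(n_samples, n_segments)).astype(float)
scores = np.array([black_box(m) for m in masks])

# 2. Weight each sample by its similarity to the unperturbed image (all ones).
distances = 1.0 - masks.mean(axis=1)              # fraction of segments switched off
weights = np.exp(-(distances ** 2) / 0.25 ** 2)   # exponential kernel, width 0.25 (assumed)

# 3. Fit a weighted linear surrogate: score ~ intercept + sum(coef_i * mask_i).
X = np.hstack([np.ones((n_samples, 1)), masks])
sw = np.sqrt(weights)[:, None]
coef, *_ = np.linalg.lstsq(X * sw, scores * sw[:, 0], rcond=None)

segment_importance = coef[1:]                     # drop the intercept
print(np.argsort(-segment_importance)[:2])        # the two most influential segments
```

Because the surrogate is fitted on a finite number of random perturbations, a larger `num_samples` generally gives a more stable explanation at the cost of more model evaluations.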

In [None]:
# Define preprocessing function for LIME
def preprocess_fn(img):
    return transform(img)

# Initialize LIME explainer
lime_explainer = LimeExplainer(model, device, CLASS_NAMES, preprocess_fn)

# Generate explanation
print("Generating LIME explanation (this may take a minute)...")
lime_exp = lime_explainer.explain(image, num_samples=500)  # Reduce samples for speed in notebook

# Visualize explanation
lime_fig, _ = lime_explainer.visualize(lime_exp, image, label=predicted_class)
plt.show()

## 5. Explain the Prediction with SHAP

Next, let's use SHAP to provide another explanation.
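
SHAP builds on Shapley values: each feature's attribution is its marginal contribution to the output, averaged over all orders in which features could be added. The toy example below computes exact Shapley values for three hypothetical features with a made-up value function. It is only a sketch of the underlying definition; practical SHAP explainers approximate this, since exact enumeration is infeasible for images.

```python
from itertools import permutations

players = ["border", "color", "texture"]   # hypothetical image features

# Hypothetical value function: the model score when only this subset is present.
def value(subset):
    present = frozenset(subset)
    score = 0.0
    if "border" in present:
        score += 0.3
    if "color" in present:
        score += 0.1
    if "border" in present and "color" in present:
        score += 0.2          # interaction: the two features reinforce each other
    return score

def shapley_values(players, value):
    """Average each player's marginal contribution over all orderings."""
    totals = {p: 0.0 for p in players}
    orderings = list(permutations(players))
    for order in orderings:
        present = []
        for p in order:
            before = value(present)
            present.append(p)
            totals[p] += value(present) - before
    return {p: t / len(orderings) for p, t in totals.items()}

phi = shapley_values(players, value)
print(phi)
```

The attributions sum to the difference between the score with all features present and the score with none, and the interaction bonus is split evenly between the two features that produce it.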

In [None]:
# Initialize SHAP explainer
shap_explainer = ShapExplainer(model, device, CLASS_NAMES, preprocess_fn)

# Generate explanation (a small n_samples keeps the notebook fast)
print("Generating SHAP explanation (this may take a minute)...")
shap_values = shap_explainer.explain(image, n_samples=25)

# Visualize explanation
shap_fig, _ = shap_explainer.visualize(shap_values, image, label=predicted_class)
plt.show()

## 6. Explain the Prediction with GradCAM

Finally, let's use GradCAM to highlight the regions of the image that influenced the prediction.
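
In its standard formulation, Grad-CAM averages the gradients of the class score over each feature map of a convolutional layer, uses those averages as channel weights, and keeps only the positive part of the weighted sum. The sketch below reproduces that arithmetic in NumPy on random stand-in arrays. The `grad_cam` helper is hypothetical and is not the project's `GradCamExplainer`, which may choose layers and post-process differently.

```python
import numpy as np

def grad_cam(activations, gradients):
    """Combine feature maps (K, H, W) with the gradients of the class score
    with respect to them (K, H, W) into one (H, W) heatmap in [0, 1]."""
    channel_weights = gradients.mean(axis=(1, 2))             # global-average-pool the gradients
    cam = np.tensordot(channel_weights, activations, axes=1)  # weighted sum over channels
    cam = np.maximum(cam, 0.0)                                # ReLU: keep positive evidence only
    peak = cam.max()
    return cam / peak if peak > 0 else cam                    # scale to [0, 1] when possible

# Toy inputs standing in for a real layer's activations and gradients.
rng = np.random.default_rng(0)
activations = rng.random((8, 7, 7))
gradients = rng.normal(size=(8, 7, 7))
heatmap = grad_cam(activations, gradients)
print(heatmap.shape, float(heatmap.min()), float(heatmap.max()))
```

The resulting map has the spatial size of the chosen layer, so in practice it is upsampled to the input resolution before being overlaid on the image.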

In [None]:
# Initialize GradCAM explainer
try:
    gradcam_explainer = GradCamExplainer(model)
    
    # Generate explanation
    print("Generating GradCAM explanation...")
    gradcam_heatmap = gradcam_explainer.explain(batch_tensor)
    
    # Convert image to [0, 1] range for visualization
    normalized_image = image.astype(float) / 255
    
    # Visualize
    cam_image = gradcam_explainer.visualize(gradcam_heatmap, normalized_image, class_name=class_name)
    plt.show()
except Exception as e:
    print(f"Error generating GradCAM explanation: {e}")
    print("This may happen if the model architecture is not compatible with GradCAM.")

## 7. Comparison of Explanations

Let's compare all the explanations side by side.

In [None]:
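# mark_boundaries draws the LIME superpixel outlines used below
from skimage.segmentation import mark_boundaries

# Create a figure for all explanations
plt.figure(figsize=(16, 6))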

# Original image
plt.subplot(1, 4, 1)
plt.imshow(image)
plt.title(f"Original\nPrediction: {class_name}")
plt.axis("off")

# LIME explanation
plt.subplot(1, 4, 2)
temp, mask = lime_exp.get_image_and_mask(
    predicted_class, 
    positive_only=True, 
    num_features=5, 
    hide_rest=False
)
plt.imshow(mark_boundaries(temp, mask))
plt.title("LIME Explanation")
plt.axis("off")

# SHAP explanation
plt.subplot(1, 4, 3)
# For visualization, sum absolute SHAP values across channels
shap_combined = np.abs(shap_values[predicted_class][0]).sum(axis=0)
# Normalize to [0, 1] for visualization (epsilon guards against an all-zero map)
shap_normalized = shap_combined / (shap_combined.max() + 1e-8)
plt.imshow(shap_normalized, cmap='hot')
plt.title("SHAP Values")
plt.axis("off")

# GradCAM explanation (cam_image is undefined if the GradCAM cell above failed)
plt.subplot(1, 4, 4)
try:
    plt.imshow(cam_image)
    plt.title("GradCAM Heatmap")
except NameError:
    plt.title("GradCAM unavailable")
plt.axis("off")

plt.tight_layout()
plt.show()

## 8. Conclusion

In this notebook, we've demonstrated how to use three different explainability methods to understand the predictions of our skin lesion classification model:

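1. **LIME**: Fits a simple local surrogate model over perturbed image segments to show which regions support or contradict the prediction.
2. **SHAP**: Attributes the prediction to input pixels using Shapley values from cooperative game theory.
3. **GradCAM**: Uses the gradients of the class score with respect to a convolutional layer's feature maps to highlight the image regions that most influenced the prediction.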

These explanations provide complementary views of the model's decision-making process, which is crucial for building trust in the model, especially in medical applications.
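
One rough way to check how far two explanation maps agree is to threshold each map at its most important pixels and measure the overlap. The sketch below does this with intersection over union on random stand-in heatmaps. The helper names, the 20% threshold, and the assumption that both maps share the same resolution are illustrative choices, not part of the XAI package.

```python
import numpy as np

def top_fraction_mask(heatmap, fraction=0.2):
    """Boolean mask of the pixels whose value lies in the top `fraction` of the map."""
    threshold = np.quantile(heatmap, 1.0 - fraction)
    return heatmap >= threshold

def iou(mask_a, mask_b):
    """Intersection over union of two boolean masks of the same shape."""
    union = np.logical_or(mask_a, mask_b).sum()
    return float(np.logical_and(mask_a, mask_b).sum() / union) if union else 0.0

# Toy heatmaps standing in for two explanation maps of the same size.
rng = np.random.default_rng(0)
map_a = rng.random((32, 32))
map_b = 0.5 * map_a + 0.5 * rng.random((32, 32))   # partly correlated with map_a
print(iou(top_fraction_mask(map_a), top_fraction_mask(map_b)))
```

A low overlap does not by itself mean one method is wrong, since the methods answer slightly different questions, but it is a useful prompt to inspect the image more closely.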