<a href="https://colab.research.google.com/github/nyp-sit/iti121-2025s2/blob/main/L4/explainable_cnn.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Explainable AI

This lab explores techniques such as Grad-CAM and LIME for explaining the output of a CNN image model.

## Part 1: Grad-CAM

In this exercise, we will use Grad-CAM to visualize which image regions most influence the model's prediction.
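For reference, Grad-CAM weights each feature map $A^k$ of the chosen convolutional layer by the spatially averaged gradient of the class score $y^c$ (before softmax), sums the weighted maps, and keeps only the positive evidence. Here $Z$ is the number of spatial positions $(i, j)$ in the feature map:

$$
\alpha_k^c = \frac{1}{Z}\sum_{i}\sum_{j}\frac{\partial y^c}{\partial A_{ij}^k},
\qquad
L_{\text{Grad-CAM}}^c = \mathrm{ReLU}\!\left(\sum_k \alpha_k^c A^k\right)
$$

The resulting coarse map (for example $7 \times 7$ for ResNet's last stage at a $224 \times 224$ input) is then upsampled to the input size and overlaid on the image.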

We will use the [pytorch-grad-cam](https://github.com/jacobgil/pytorch-grad-cam) package for this exercise.

In [None]:
%pip install grad-cam

In [None]:
from PIL import Image
import torch
import os
import cv2
import numpy as np
from torchvision import models
import torchvision.transforms as transforms
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import (
    show_cam_on_image, deprocess_image, preprocess_image
)
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget, ClassifierOutputReST

We load the image with the cv2 (OpenCV) library, convert it from OpenCV's BGR channel order to RGB, and scale its pixel values to the range [0, 1]. We then normalize each channel with the mean and standard deviation that ResNet expects, since we will use ResNet for our image classification task.

Different pretrained networks require different preprocessing (rescaling, resizing, normalization, etc.).

You can find the transformation a model requires by calling the `transforms()` method bundled with its weights:

```python

from torchvision.models import resnet50, ResNet50_Weights

weights = ResNet50_Weights.DEFAULT
preprocess = weights.transforms()

```



In [None]:
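# Read the image with OpenCV (BGR channel order) and reverse the channels to get RGB
rgb_img = cv2.imread('cockatoo.jpeg', cv2.IMREAD_COLOR)[:, :, ::-1]
# Scale pixel values to [0, 1]; show_cam_on_image also expects an image in this range
rgb_img = np.float32(rgb_img) / 255
# Normalize with the ImageNet mean/std used by ResNet and add a batch dimension: (1, 3, H, W)
input_tensor = preprocess_image(rgb_img,
                                mean=[0.485, 0.456, 0.406],
                                std=[0.229, 0.224, 0.225]).to("cpu")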
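To make the normalization step concrete, here is a sketch of the same arithmetic in plain NumPy, assuming `preprocess_image` amounts to "convert to a channels-first tensor, subtract the per-channel mean, divide by the per-channel standard deviation, add a batch dimension". The helper `normalize_for_resnet` is hypothetical and only for illustration.

```python
import numpy as np

IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)

def normalize_for_resnet(rgb_img: np.ndarray) -> np.ndarray:
    """Normalize an (H, W, 3) float image in [0, 1] and return a (1, 3, H, W) batch."""
    normalized = (rgb_img - IMAGENET_MEAN) / IMAGENET_STD  # broadcasts over the channel axis
    chw = np.transpose(normalized, (2, 0, 1))              # HWC -> CHW, the layout PyTorch uses
    return chw[np.newaxis, ...].astype(np.float32)         # add a batch dimension

# An image whose every pixel equals the ImageNet mean normalizes to all zeros
dummy = np.tile(IMAGENET_MEAN, (4, 5, 1))  # shape (4, 5, 3)
batch = normalize_for_resnet(dummy)
print(batch.shape)
```

A pixel exactly at the channel mean maps to 0, and one standard deviation above it maps to 1, which is the input distribution the pretrained weights were trained on.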


We will use the pretrained ResNet-50 network to classify the image.

You will need to choose the target layer you want to compute the visualization for.
Usually this will be the last convolutional layer in the model.
Some common choices can be:
- ResNet-18 and ResNet-50: model.layer4
- VGG, densenet161: model.features[-1]
- mnasnet1_0: model.layers[-1]

You can print the model to help choose the layer.
You can also pass a list of several target layers;
in that case, a CAM is computed for each layer and the results are aggregated.

In [None]:
# specify the output directory to write the visualization to
output_dir = "output"
output_file = "GradCam_cam.jpg"

# use CPU for computation
device = torch.device("cpu")

In [None]:
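# Load ResNet-50 with ImageNet weights (the `weights` argument replaces the deprecated `pretrained=True`)
model = models.resnet50(weights=models.ResNet50_Weights.DEFAULT).to(device).eval()
# For ResNet, the last convolutional stage (layer4) is the usual Grad-CAM target
target_layers = [model.layer4]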


In [None]:
# targets = None    # If targets is None, the highest-scoring category (for every image in the batch) is used.
targets = [ClassifierOutputTarget(89)]   # Take the gradient of the score for class 89 (sulphur-crested cockatoo) w.r.t. the convolutional activations.
# targets = [ClassifierOutputReST(89)]  # Highlight regions that specifically distinguish one class from the rest; this often produces sharper, more discriminative heatmaps.


cam_algorithm = GradCAM
with cam_algorithm(model=model,
                    target_layers=target_layers) as cam:

    # cam.batch_size = 32
    grayscale_cam = cam(input_tensor=input_tensor,
                        targets=targets,
                        aug_smooth=True,  # Apply test-time augmentation to smooth the CAM
                        eigen_smooth=True) # Reduce noise by taking the first principal component

    # Heatmap for the first (and only) image in the batch, shape (H, W)
    grayscale_cam = grayscale_cam[0, :]

    # Overlay the Grad-CAM heatmap on the original RGB image so you can see where the model is focusing
    cam_image = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True)

    # cv2.imwrite expects BGR channel order, so convert from RGB before saving
    cam_image = cv2.cvtColor(cam_image, cv2.COLOR_RGB2BGR)

os.makedirs(output_dir, exist_ok=True)
cam_output_path = os.path.join(output_dir, output_file)

cv2.imwrite(cam_output_path, cam_image)


Let's display the resulting visualization.

In [None]:
import matplotlib.pyplot as plt

img = plt.imread(cam_output_path)
plt.imshow(img)

## Part 2: LIME

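In this part of the exercise, we will use LIME (Local Interpretable Model-agnostic Explanations), a perturbation-based, black-box technique, to explain the output of the image model. Because it is black-box, LIME only needs the model's predictions, not its gradients or internal layers.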

In [None]:
%pip install lime

In [None]:
from lime import lime_image
from skimage.segmentation import mark_boundaries
from torchvision import models, transforms
from PIL import Image
import numpy as np
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Load a pretrained ResNet-18 model (the `weights` argument replaces the deprecated `pretrained=True`)
model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT).to(device).eval()

### Preprocess image

Before passing an image to the model, we need to preprocess it (resize, convert to a tensor, normalize, etc.) to match what the model saw during training. For example, ResNet expects each channel to be normalized with mean (0.485, 0.456, 0.406) and standard deviation (0.229, 0.224, 0.225).

In [None]:
# Preprocess the image
preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])


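Next, we define a classifier function that LIME calls to obtain predicted class probabilities. It receives a batch of images as a NumPy array of shape (N, H, W, 3) and must return an array of shape (N, num_classes).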

In [None]:
def batch_predict(images):
    """Convert numpy arrays to tensor batch and predict"""
    batch = torch.stack([preprocess(Image.fromarray(img.astype('uint8')))
                         for img in images], dim=0).to(device)

    with torch.no_grad():
        logits = model(batch)
        probs = torch.nn.functional.softmax(logits, dim=1)

    return probs.cpu().numpy()

Now we will use `LimeImageExplainer` to perturb the image (by hiding superpixels, i.e. groups of similar neighboring pixels) and fit a simple linear surrogate model whose coefficients indicate how important each superpixel is to the prediction.

`num_samples` sets how many perturbed versions of the input image LIME generates when building its local surrogate model. A higher `num_samples` gives smoother, more stable explanations at the cost of longer runtime.
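To make the perturb-and-fit idea concrete, here is a simplified from-scratch sketch of LIME's core loop for images. It is not the `lime` package's implementation: it assumes a precomputed superpixel map `segments` (LIME computes one with a segmentation algorithm, quickshift by default), weights samples with an exponential kernel on cosine distance (kernel width 0.25), and fits plain weighted least squares where LIME uses ridge regression and optional feature selection. `lime_sketch` and `toy_classifier` are hypothetical.

```python
import numpy as np

def lime_sketch(image, segments, classifier_fn, label,
                num_samples=500, hide_color=0, kernel_width=0.25, seed=0):
    """Return one importance weight per superpixel for `label` (simplified LIME)."""
    rng = np.random.default_rng(seed)
    n_features = segments.max() + 1
    # Each row is a binary mask over superpixels: 1 = keep, 0 = hide
    Z = rng.integers(0, 2, size=(num_samples, n_features))
    Z[0, :] = 1  # the first sample is the unperturbed image

    perturbed = []
    for z in Z:
        img = image.copy()
        hidden = np.isin(segments, np.flatnonzero(z == 0))
        img[hidden] = hide_color  # paint hidden superpixels with hide_color
        perturbed.append(img)
    probs = classifier_fn(np.stack(perturbed))[:, label]

    # Cosine distance between each mask and the all-ones mask, then an exponential kernel
    distances = 1.0 - np.sqrt(Z.sum(axis=1) / n_features)
    sample_weights = np.sqrt(np.exp(-(distances ** 2) / kernel_width ** 2))

    # Weighted least squares with an intercept column: scale each row by sqrt(weight)
    X = np.hstack([Z, np.ones((num_samples, 1))]).astype(float)
    w = np.sqrt(sample_weights)
    coef, *_ = np.linalg.lstsq(X * w[:, None], probs * w, rcond=None)
    return coef[:-1]  # drop the intercept; one coefficient per superpixel

# Toy check: a 4x4 image split into four 2x2 superpixels, and a fake classifier
# whose class-1 probability is the mean brightness of the top-left superpixel only
toy_image = np.ones((4, 4, 3))
toy_segments = np.array([[0, 0, 1, 1],
                         [0, 0, 1, 1],
                         [2, 2, 3, 3],
                         [2, 2, 3, 3]])

def toy_classifier(images):
    p = images[:, :2, :2, :].mean(axis=(1, 2, 3))
    return np.stack([1 - p, p], axis=1)

importance = lime_sketch(toy_image, toy_segments, toy_classifier, label=1)
print(np.round(importance, 3))
```

Because the toy classifier depends only on superpixel 0, the fitted surrogate assigns it a weight close to 1 and the other superpixels weights close to 0, which is exactly the kind of attribution LIME aims to recover.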

In [None]:
# Load example image
img = Image.open("cockatoo.jpeg").convert("RGB")

# Initialize LIME explainer
explainer = lime_image.LimeImageExplainer()

# Explain a prediction
explanation = explainer.explain_instance(
    np.array(img),
    classifier_fn=batch_predict,
    top_labels=1,  # LIME will only explain the top predicted label (the class with highest probability).
    hide_color=0, # When LIME “hides” a superpixel, it replaces its pixels with this value (color). In this case, it is black
    num_samples=1000 # Number of perturbed samples (versions of the image) to generate.
)


Now let's visualize which regions are more important for the prediction.

In [None]:
# Visualize result
from matplotlib import pyplot as plt
temp, mask = explanation.get_image_and_mask(
    label=explanation.top_labels[0],  # which class to explain
    positive_only=True, # Show only features (superpixels) that increase the probability of that class
    hide_rest=True, # If True, hide non-important regions (fill them with black); if False, keep the full image visible
    num_features=5, # Number of most influential superpixels to highlight
    min_weight=0.0 # Minimum importance threshold
)

plt.imshow(mark_boundaries(temp / 255.0, mask))
plt.title("LIME Explanation")
plt.axis("off")
plt.show()
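The mask above only shows the top positive superpixels. To see supporting and opposing evidence at once, one option is to paint every superpixel with its LIME weight and plot the result with a diverging colormap centered at zero. This sketch assumes the `ImageExplanation` attributes `segments` (the superpixel map) and `local_exp` (a dict from label to a list of `(superpixel id, weight)` pairs); `lime_weight_heatmap` is a hypothetical helper, demonstrated on toy data.

```python
import numpy as np

def lime_weight_heatmap(segments, local_exp_for_label):
    """Map each superpixel's LIME weight onto its pixels.

    segments: (H, W) int array of superpixel ids (like explanation.segments)
    local_exp_for_label: list of (superpixel_id, weight) pairs (like explanation.local_exp[label])
    """
    lookup = np.zeros(segments.max() + 1, dtype=float)
    for seg_id, weight in local_exp_for_label:
        lookup[seg_id] = weight
    return lookup[segments]  # fancy indexing paints every pixel with its superpixel's weight

# Toy example: superpixel 2 has no listed weight, so it stays 0
toy_segments = np.array([[0, 1],
                         [2, 1]])
toy_heatmap = lime_weight_heatmap(toy_segments, [(0, 0.5), (1, -0.2)])
print(toy_heatmap)

# In the notebook (assuming `explanation` and `plt` from the cells above):
# label = explanation.top_labels[0]
# heatmap = lime_weight_heatmap(explanation.segments, explanation.local_exp[label])
# limit = np.abs(heatmap).max()
# plt.imshow(heatmap, cmap="RdBu", vmin=-limit, vmax=limit); plt.colorbar(); plt.show()
```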