## Importing Libraries

In [1]:
import os
import gzip
import shutil
import h5py
import numpy as np
import matplotlib.pyplot as plt

import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.models import load_model

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models, transforms
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import ImageFolder
from pytorch_grad_cam import GradCAM

from tqdm import tqdm

from xplique.attributions import (
    Saliency, IntegratedGradients,
    Occlusion, KernelShap, Lime
)
from xplique.metrics import Insertion, Deletion

import data_manipulation
import model
import explicability_pipeline

  from .autonotebook import tqdm as notebook_tqdm


## Importing data

In [19]:
from pathlib import Path
import os

base = "https://zenodo.org/record/2546921/files"
files = [
    "camelyonpatch_level_2_split_train_x.h5",
    "camelyonpatch_level_2_split_train_y.h5",
    "camelyonpatch_level_2_split_valid_x.h5",
    "camelyonpatch_level_2_split_valid_y.h5",
    "camelyonpatch_level_2_split_test_x.h5",
    "camelyonpatch_level_2_split_test_y.h5",
]

os.makedirs("data", exist_ok=True)

for fname in files:
    out_path = os.path.join("data", fname)
    if not os.path.exists(out_path):
        print(f"Downloading {fname}...")
        url = f"{base}/{fname}.gz?download=1"
        os.system(f"curl -L -o {out_path} '{url}'")

In [2]:
# Preprocessing 
transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.RandomApply([transforms.ColorJitter(0.2, 0.2, 0.2)], p=0.5)
])

# Downloaded files 
train_x = "data/camelyonpatch_level_2_split_train_x.h5"
train_y = "data/camelyonpatch_level_2_split_train_y.h5"
valid_x = "data/camelyonpatch_level_2_split_valid_x.h5"
valid_y = "data/camelyonpatch_level_2_split_valid_y.h5"
test_x  = "data/camelyonpatch_level_2_split_test_x.h5"
test_y  = "data/camelyonpatch_level_2_split_test_y.h5"

# Instantiate Datasets 
train_dataset = data_manipulation.PatchCamelyonH5Dataset(train_x, train_y, transform=transform)
valid_dataset = data_manipulation.PatchCamelyonH5Dataset(valid_x, valid_y)
test_dataset  = data_manipulation.PatchCamelyonH5Dataset(test_x, test_y)

# Instatiate Loaders
batch_size = 64

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
val_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False, num_workers=0)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=0)

## Model

Training a model from scratch would require a significant amount of time — over 10 hours on local hardware.
To make the process more efficient and focused, we chose a different strategy: using a ResNet-18 backbone and training only its final classification head.

This choice offers several advantages:
- ResNet-18 is a lightweight yet powerful convolutional network, known for its solid performance even on relatively small datasets.
- It has a simple and clean architecture, making it more interpretable than deeper models like ResNet-50 or ResNet-101.
- Its relatively low number of parameters also reduces overfitting risk, especially important when working with limited data.

To further mitigate overfitting, we restricted training to only 10% of the dataset, while ensuring stratified sampling. This way, the model is still exposed to examples from all classes, despite the small training set.

In addition, we applied data augmentation:
- Random rotations to simulate different orientations of tissue patches.
- Color jitter to make the model more robust to variations in color and brightness that naturally occur in histology images.

### Why ResNet-18?

ResNet-18 strikes a balance between complexity and interpretability.
It is deep enough to capture relevant spatial features but shallow enough to allow methods like Grad-CAM and Integrated Gradients to produce meaningful and localized explanations.
Deeper models, while potentially more accurate, often create more diffused and harder-to-interpret explanation maps.

Overall, ResNet-18 provides a solid foundation for balancing performance, training time, and explainability, which is crucial given the goals of this project.

In [5]:
# Set to True only if you want to train again
TRAIN = True

device, checkpoint_path = model.setup_device_and_paths()
resnet18, optimizer, criterion = model.initialize_model(device)
resnet18, optimizer, start_epoch, best_val_acc = model.load_checkpoint(resnet18, optimizer, checkpoint_path, device)

if TRAIN:
    resnet18 = model.train_model(resnet18, optimizer, criterion, device,
                        checkpoint_path, train_loader, val_loader,
                        start_epoch=start_epoch, num_epochs=5, best_val_acc=best_val_acc)

Starting training from epoch 0 to 5.


Epoch 1/5: 100%|██████████| 4096/4096 [37:54<00:00,  1.80it/s]   


Epoch [1/5] Loss: 0.3644
Validation Accuracy: 0.8072
Best model saved at epoch 1 with val_acc 0.8072


Epoch 2/5: 100%|██████████| 4096/4096 [36:48<00:00,  1.85it/s]   


Epoch [2/5] Loss: 0.2491
Validation Accuracy: 0.8711
Best model saved at epoch 2 with val_acc 0.8711


Epoch 3/5: 100%|██████████| 4096/4096 [32:43<00:00,  2.09it/s]  


Epoch [3/5] Loss: 0.2069
Validation Accuracy: 0.8219
EarlyStopping counter: 1 of 3


Epoch 4/5: 100%|██████████| 4096/4096 [28:45<00:00,  2.37it/s]


Epoch [4/5] Loss: 0.1831
Validation Accuracy: 0.8719
Best model saved at epoch 4 with val_acc 0.8719


Epoch 5/5: 100%|██████████| 4096/4096 [24:37<00:00,  2.77it/s]


Epoch [5/5] Loss: 0.1659
Validation Accuracy: 0.8637
EarlyStopping counter: 1 of 3
Training completed. Final model saved.


## Explicability

To interpret the model’s predictions and understand where it focuses during decision-making, we selected three different explanation techniques:
- Integrated Gradients
- Grad-CAM
- LIME

Each method provides a different perspective on the model’s behavior, and together they give a broader view of the model’s reasoning process.

### Integrated Gradients

Integrated Gradients attribute the prediction by accumulating the gradients of the model’s output with respect to the input, along a straight path from a baseline (usually a black image) to the actual input.
In simpler terms, it tells us which pixels contributed the most to the model’s decision.

Pros:
- Provides fine-grained, pixel-level attributions.
- Theoretically well-founded, satisfying important axioms like sensitivity and implementation invariance.
- Does not require any modification to the model architecture.

Cons:
- Produces dense attribution maps, which can be harder to visually interpret without further processing.
- Requires defining a suitable baseline (black image, blurred image, etc.), and results can be sensitive to this choice.

### Grad-CAM

Grad-CAM (Gradient-weighted Class Activation Mapping) uses the gradients flowing into the last convolutional layers to produce a heatmap that highlights the important regions of the image for a given decision.
Instead of focusing on individual pixels, it emphasizes higher-level spatial regions.

Pros:
- Produces smooth and localized explanations over meaningful areas.
- Intuitive and easy to visualize, especially for CNNs.
- Computationally efficient since it uses activations already computed during forward pass.

Cons:
- Can be less precise at the pixel level — focuses more on regions rather than fine structures.
- Sensitive to the choice of the convolutional layer used for the explanation.

### LIME

LIME (Local Interpretable Model-agnostic Explanations) explains a single prediction by learning a simple, interpretable model (like a linear model) that approximates the original model’s behavior in the neighborhood of that prediction.
It does so by perturbing the input and observing how the prediction changes.

Pros:
- Model-agnostic: can be applied to any classifier, not only neural networks.
- Focuses on superpixels rather than individual pixels, making explanations visually intuitive.
- Good for highlighting compact, highly-informative regions.

Cons:
- Computationally expensive, as it requires many forward passes with perturbed inputs.
- Results can vary depending on the choice of perturbations and parameters like the number of samples.
- Sometimes explanations can be unstable if the local approximation is not faithful enough.

Summary

Each method offers a complementary perspective:
- Integrated Gradients focuses on detailed pixel contributions.
- Grad-CAM highlights important spatial regions.
- LIME captures local, interpretable areas based on perturbation analysis.

Using all three methods allows for a more complete and reliable understanding of the model’s behavior, especially in critical domains such as medical imaging, where interpretability is essential.

In [None]:
device, checkpoint_path = model.setup_device_and_paths()
resnet18, optimizer, _ = model.initialize_model(device)
resnet18, optimizer, start_epoch, best_val_acc = model.load_checkpoint(resnet18, optimizer, checkpoint_path, device)

transform = explicability_pipeline.build_transforms()
ig, cam, explainer = explicability_pipeline.build_explainers(resnet18)

explicability_pipeline.run_explanations(resnet18, val_loader, ig, cam, explainer, transform, device, max_images=10)

## Analysing Results

<!-- Fig 1 -->
<div align="center">
<img src="comparative_outputs/comparison_0.png" alt="Grad-CAM / IG / LIME – image 0">
<span style="font-size:0.85em"><b>Figure 1 — Panel 0</b></span>
</div>

<!-- Fig 2 -->
<div align="center">
<img src="comparative_outputs/comparison_1.png" alt="Grad-CAM / IG / LIME – image 1">
<span style="font-size:0.85em"><b>Figure 2 — Panel 1</b></span>
</div>

<!-- Fig 3  (note: file-name updated) -->
<div align="center">
<img src="comparative_outputs/comparison_6.png" alt="Grad-CAM / IG / LIME – image 2">
<span style="font-size:0.85em"><b>Figure 3 — Panel 2</b></span>
</div>

<!-- Fig 4 -->
<div align="center">
<img src="comparative_outputs/comparison_4.png" alt="Grad-CAM / IG / LIME – image 3">
<span style="font-size:0.85em"><b>Figure 4 — Panel 3</b></span>
</div>

### Quick visual read-out
- Fig 1 – All three methods converge on the same triangular region at the bottom-centre ⇒ strong, trustworthy evidence.
- Fig 2 – Grad-CAM lights up the upper-left border while LIME outlines the lower-left edge; IG is almost uniform.
Likely means the model is partially focusing on patch borders rather than tissue.
- Fig 3 – Grad-CAM shows a coarse “cross” pattern (typical up-sampling artefact); IG is noisy; LIME marks two faint areas at the far edges.
- Fig 4 – Consistent highlight: V-shaped region is marked by all three methods, reinforcing its anatomical relevance.

### Shared trends vs. differences
- Scale - Grad-CAM ≫ LIME (super-pixels) ≫ IG (pixels).
- Noise - IG is inherently speckled; Grad-CAM/LIME are smoother.
- Artefacts - Borders (Fig 2) and checkerboard blocks (Fig 3) are typical failure modes to watch for.

In [None]:
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from lime import lime_image

# Grad-CAM wrapper
target_layers = [resnet18.layer4[-1]]
cam = GradCAM(model=resnet18, target_layers=target_layers, use_cuda=(device.type == "cuda"))

def gc_expl(x, cls):
    return cam(x, targets=[ClassifierOutputTarget(cls)])[0] # H×W

# Integrated Gradients wrapper
ig = IntegratedGradients(resnet18)

def ig_expl(x, cls):
    base = torch.zeros_like(x).to(device)
    attr = ig.attribute(x, base, target=cls, n_steps=50)
    attr = attr.squeeze().abs().mean(0).cpu().numpy() # H×W
    return attr


# LIME wrapper
lime_exp = lime_image.LimeImageExplainer()

def lime_expl(x, cls):
    # tensor → uint8 RGB
    d = x.squeeze().cpu()
    d *= torch.tensor([0.229,0.224,0.225]).view(3,1,1)
    d += torch.tensor([0.485,0.456,0.406]).view(3,1,1)
    img = (d.clamp(0,1).permute(1,2,0).numpy()*255).astype(np.uint8)

    exp = lime_exp.explain_instance(img,
              lambda ims: explicability_pipeline.predict_fn(ims, resnet18, transform, device),
              top_labels=2, hide_color=0, num_samples=1000)
    mask = exp.get_image_and_mask(cls, positive_only=True,
                                  num_features=5, hide_rest=False)[1]
    return mask.astype(np.float32) # H×W

: 

In [None]:
from metrics_expl import evaluate_method

for name, fn in [("Grad-CAM", gc_expl),
                 ("Integrated Gradients", ig_expl),
                 ("LIME", lime_expl)]:

    del_auc, ins_auc = evaluate_method(
        model=resnet18,
        loader=val_loader,
        explainer_fn=fn,
        name=name,
        device=device,
        max_imgs=30,          # evaluate on 30 validation images
        steps=100             # 100 points on the curve
    )

    print(f"{name:<20}  Deletion AUC: {del_auc:.4f}  |  Insertion AUC: {ins_auc:.4f}")

In [4]:
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget

# 1.  Instantiate Grad-CAM
target_layers = [resnet18.layer4[-1]]
cam = GradCAM(model=resnet18, target_layers=target_layers)

def gradcam_explainer(input_tensor: torch.Tensor, label: int) -> np.ndarray:
    # Grad-CAM returns (1, H, W); we take [0]
    return cam(input_tensor, targets=[ClassifierOutputTarget(label)])[0]

# 2.  Run evaluation
del_auc, ins_auc = explicability_pipeline.evaluate_method(
    model=resnet18,
    val_loader=val_loader,
    explainer_fn=gradcam_explainer,
    device=device,
    method_name="Grad-CAM",
    steps=100,
    max_images=30
)

print(f"Grad-CAM  Deletion AUC:  {del_auc:.4f}")
print(f"Grad-CAM  Insertion AUC: {ins_auc:.4f}")

Evaluating Grad-CAM:   0%|          | 0/512 [01:23<?, ?it/s]

Grad-CAM  Deletion AUC:  0.4580
Grad-CAM  Insertion AUC: 0.5393





In [6]:
# LIME wrapper (superpixel explanation)
from lime import lime_image
from skimage.segmentation import mark_boundaries
lime_exp = lime_image.LimeImageExplainer()

def lime_explainer(input_tensor, label):
    # Convert tensor to uint8 RGB
    denorm = input_tensor.squeeze().cpu().clone()
    denorm *= torch.tensor([0.229, 0.224, 0.225]).view(3,1,1)
    denorm += torch.tensor([0.485, 0.456, 0.406]).view(3,1,1)
    denorm = (denorm.clamp(0,1).permute(1,2,0).numpy()*255).astype(np.uint8)

    explanation = lime_exp.explain_instance(
        denorm, explicability_pipeline.predict_fn, top_labels=2, hide_color=0, num_samples=1000)
    mask = explanation.get_image_and_mask(label, positive_only=True,
                                          num_features=5, hide_rest=False)[1]
    return mask.astype(np.float32)              # H × W  (0/1 mask)

del_auc, ins_auc = explicability_pipeline.evaluate_method(
    model=resnet18,
    val_loader=val_loader,
    explainer_fn=lime_explainer,
    device=device,
    method_name="Lime",
    steps=100,
    max_images=30
)

print(f"Lime  Deletion AUC:  {del_auc:.4f}")
print(f"Lime  Insertion AUC: {ins_auc:.4f}")

  1%|          | 9/1000 [00:00<00:00, 2056.70it/s]?it/s]
Evaluating Lime:   0%|          | 0/512 [00:00<?, ?it/s]


TypeError: predict_fn() missing 3 required positional arguments: 'model', 'transform', and 'device'

In [None]:
# Integrated Gradients wrapper
from captum.attr import IntegratedGradients
ig = IntegratedGradients(model)

def ig_explainer(input_tensor, label):
    baseline = torch.zeros_like(input_tensor).to(device)
    attributions = ig.attribute(input_tensor, baseline, target=label, n_steps=50)
    attr = attributions.squeeze().abs().mean(dim=0).cpu().numpy()  # H × W
    return attr

del_auc, ins_auc = explicability_pipeline.evaluate_method(
    model=resnet18,
    val_loader=val_loader,
    explainer_fn=gradcam_explainer,
    device=device,
    method_name="Grad-CAM",
    steps=100,
    max_images=30
)

print(f"Grad-CAM  Deletion AUC:  {del_auc:.4f}")
print(f"Grad-CAM  Insertion AUC: {ins_auc:.4f}")