# Rexia Final Project
See description at README.md

### Authors
- Gabriel Souza Lima
- Augustin Cobena

## Importing Libraries

In [18]:
import os
import gzip
import shutil
import h5py
import numpy as np
import matplotlib.pyplot as plt

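import tensorflow as tf
# Aliased so that torchvision's `models` (imported below) does not shadow the Keras module
from tensorflow.keras import layers, models as keras_models
from tensorflow.keras.models import load_model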

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models, transforms
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import ImageFolder
from pytorch_grad_cam import GradCAM

from tqdm import tqdm

import data_manipulation
import model
import explicability_pipeline

## Importing data

In [19]:
from pathlib import Path
import os

base = "https://zenodo.org/record/2546921/files"
files = [
    "camelyonpatch_level_2_split_train_x.h5",
    "camelyonpatch_level_2_split_train_y.h5",
    "camelyonpatch_level_2_split_valid_x.h5",
    "camelyonpatch_level_2_split_valid_y.h5",
    "camelyonpatch_level_2_split_test_x.h5",
    "camelyonpatch_level_2_split_test_y.h5",
]

os.makedirs("data", exist_ok=True)

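for fname in files:
    out_path = os.path.join("data", fname)
    if not os.path.exists(out_path):
        print(f"Downloading {fname}...")
        url = f"{base}/{fname}.gz?download=1"
        gz_path = out_path + ".gz"
        # The remote file is gzip-compressed, so save it under a .gz name first
        os.system(f"curl -L -o {gz_path} '{url}'")
        # Decompress to the .h5 path that the dataset class reads
        with gzip.open(gz_path, "rb") as src, open(out_path, "wb") as dst:
            shutil.copyfileobj(src, dst)
        os.remove(gz_path)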

In [2]:
# Preprocessing 
transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.RandomApply([transforms.ColorJitter(0.2, 0.2, 0.2)], p=0.5)
])

# Downloaded files 
train_x = "data/camelyonpatch_level_2_split_train_x.h5"
train_y = "data/camelyonpatch_level_2_split_train_y.h5"
valid_x = "data/camelyonpatch_level_2_split_valid_x.h5"
valid_y = "data/camelyonpatch_level_2_split_valid_y.h5"
test_x  = "data/camelyonpatch_level_2_split_test_x.h5"
test_y  = "data/camelyonpatch_level_2_split_test_y.h5"

# Instantiate Datasets 
train_dataset = data_manipulation.PatchCamelyonH5Dataset(train_x, train_y, transform=transform)
valid_dataset = data_manipulation.PatchCamelyonH5Dataset(valid_x, valid_y)
test_dataset  = data_manipulation.PatchCamelyonH5Dataset(test_x, test_y)

# Instantiate Loaders
batch_size = 64

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
val_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False, num_workers=0)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=0)

## Model

Training a model from scratch would take over 10 hours on local hardware.
To make the process more efficient and focused, we chose a different strategy: using a ResNet-18 backbone and training only its final classification head, which reduced training to about 2.5 hours.

This choice offers several advantages:
- ResNet-18 is a lightweight yet powerful convolutional network, known for its solid performance even on relatively small datasets.
- It has a simple and clean architecture, making it more interpretable than deeper models like ResNet-50 or ResNet-101.
- Its relatively low number of parameters also reduces overfitting risk, especially important when working with limited data.

We also restricted training to only 10% of the dataset, using stratified sampling so that the model is still exposed to examples from both classes despite the small training set. A smaller training set does not reduce overfitting by itself, so we rely on training only the classification head to limit it.
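
The stratified 10% subset could be drawn along these lines. This is a minimal sketch, not the project's code: `stratified_subset_indices` is a hypothetical helper, and the toy label array stands in for the labels stored in the `_y.h5` files.

```python
import numpy as np

def stratified_subset_indices(labels, fraction=0.10, seed=0):
    """Return sorted indices that keep `fraction` of the samples of each class."""
    labels = np.asarray(labels).ravel()
    rng = np.random.default_rng(seed)
    keep = []
    for cls in np.unique(labels):
        cls_idx = np.flatnonzero(labels == cls)
        # Keep the same proportion of every class (at least one sample)
        n_keep = max(1, int(round(fraction * cls_idx.size)))
        keep.append(rng.choice(cls_idx, size=n_keep, replace=False))
    return np.sort(np.concatenate(keep))

# Toy labels: 700 negatives and 300 positives
labels = np.array([0] * 700 + [1] * 300)
subset = stratified_subset_indices(labels, fraction=0.10)
```

The resulting indices keep the original class ratio and could then be wrapped with `torch.utils.data.Subset(train_dataset, subset.tolist())`.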

In addition, we applied data augmentation:
- Random horizontal flips and random rotations (up to 10°) to simulate different orientations of tissue patches.
- Color jitter (applied with probability 0.5) to make the model more robust to variations in color and brightness that naturally occur in histology images.

### Why ResNet-18?

ResNet-18 strikes a balance between complexity and interpretability.
It is deep enough to capture relevant spatial features but shallow enough to allow methods like Grad-CAM and Integrated Gradients to produce meaningful and localized explanations.
Deeper models, while potentially more accurate, often produce more diffuse and harder-to-interpret explanation maps.

Overall, ResNet-18 provides a solid foundation for balancing performance, training time, and explainability, which is crucial given the goals of this project.

In [5]:
# Set to True only if you want to train again
TRAIN = False

device, checkpoint_path = model.setup_device_and_paths()
resnet18, optimizer, criterion = model.initialize_model(device)
resnet18, optimizer, start_epoch, best_val_acc = model.load_checkpoint(resnet18, optimizer, checkpoint_path, device)

if TRAIN:
    resnet18 = model.train_model(resnet18, optimizer, criterion, device,
                        checkpoint_path, train_loader, val_loader,
                        start_epoch=start_epoch, num_epochs=5, best_val_acc=best_val_acc)

Starting training from epoch 0 to 5.


Epoch 1/5: 100%|██████████| 4096/4096 [37:54<00:00,  1.80it/s]   


Epoch [1/5] Loss: 0.3644
Validation Accuracy: 0.8072
Best model saved at epoch 1 with val_acc 0.8072


Epoch 2/5: 100%|██████████| 4096/4096 [36:48<00:00,  1.85it/s]   


Epoch [2/5] Loss: 0.2491
Validation Accuracy: 0.8711
Best model saved at epoch 2 with val_acc 0.8711


Epoch 3/5: 100%|██████████| 4096/4096 [32:43<00:00,  2.09it/s]  


Epoch [3/5] Loss: 0.2069
Validation Accuracy: 0.8219
EarlyStopping counter: 1 of 3


Epoch 4/5: 100%|██████████| 4096/4096 [28:45<00:00,  2.37it/s]


Epoch [4/5] Loss: 0.1831
Validation Accuracy: 0.8719
Best model saved at epoch 4 with val_acc 0.8719


Epoch 5/5: 100%|██████████| 4096/4096 [24:37<00:00,  2.77it/s]


Epoch [5/5] Loss: 0.1659
Validation Accuracy: 0.8637
EarlyStopping counter: 1 of 3
Training completed. Final model saved.


## Explainability

To interpret the model’s predictions and understand where it focuses during decision-making, we selected three different explanation techniques:
- Integrated Gradients
- Grad-CAM
- LIME

Each method provides a different perspective on the model’s behavior, and together they give a broader view of the model’s reasoning process.

### Integrated Gradients

Integrated Gradients attributes the prediction by accumulating the gradients of the model’s output with respect to the input, along a straight path from a baseline (usually a black image) to the actual input.
In simpler terms, it tells us which pixels contributed the most to the model’s decision.

Pros:
- Provides fine-grained, pixel-level attributions.
- Theoretically well-founded, satisfying important axioms like sensitivity and implementation invariance.
- Does not require any modification to the model architecture.

Cons:
- Produces dense attribution maps, which can be harder to visually interpret without further processing.
- Requires defining a suitable baseline (black image, blurred image, etc.), and results can be sensitive to this choice.
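
The path integral can be approximated with a simple Riemann sum. The sketch below is not Captum's implementation: it uses a toy logistic model with an analytic gradient (`w`, `f`, and `grad_f` are illustrative assumptions) and a midpoint rule, and it checks the completeness property, i.e. that the attributions sum to the difference between the output at the input and at the baseline.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def integrated_gradients(grad_fn, x, baseline, n_steps=50):
    """Midpoint Riemann approximation of Integrated Gradients."""
    # Midpoints of n_steps equal sub-intervals of [0, 1]
    alphas = (np.arange(n_steps) + 0.5) / n_steps
    # Gradient of the model at each point on the straight path baseline -> x
    grads = np.stack([grad_fn(baseline + a * (x - baseline)) for a in alphas])
    # Average gradient along the path, scaled by the input difference
    return (x - baseline) * grads.mean(axis=0)

# Toy model: logistic regression with fixed weights
w = np.array([2.0, -1.0, 0.0, 0.5])
f = lambda x: sigmoid(w @ x)
grad_f = lambda x: f(x) * (1.0 - f(x)) * w

x = np.array([1.0, 1.0, 1.0, 1.0])
baseline = np.zeros_like(x)
attr = integrated_gradients(grad_f, x, baseline)
completeness_gap = abs(attr.sum() - (f(x) - f(baseline)))
```

The feature with zero weight receives zero attribution, and the sign of each attribution follows the sign of the corresponding weight.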

### Grad-CAM

Grad-CAM (Gradient-weighted Class Activation Mapping) uses the gradients flowing into the last convolutional layer to produce a heatmap that highlights the important regions of the image for a given decision.
Instead of focusing on individual pixels, it emphasizes higher-level spatial regions.

Pros:
- Produces smooth and localized explanations over meaningful areas.
- Intuitive and easy to visualize, especially for CNNs.
- Computationally efficient, since it reuses activations already computed during the forward pass.

Cons:
- Can be less precise at the pixel level: it highlights regions rather than fine structures.
- Sensitive to the choice of the convolutional layer used for the explanation.
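
The core Grad-CAM computation is small enough to write out in NumPy. This is a minimal sketch on hand-made arrays, not the `pytorch_grad_cam` implementation: `activations` and `gradients` stand for the feature maps of the chosen layer and the gradients of the class score with respect to them.

```python
import numpy as np

def grad_cam(activations, gradients):
    """activations, gradients: arrays of shape (K, H, W) for a single image."""
    # Channel weights: global average of the gradients over the spatial dimensions
    weights = gradients.mean(axis=(1, 2))             # shape (K,)
    # Weighted sum of the feature maps, then ReLU to keep positive evidence only
    cam = np.tensordot(weights, activations, axes=1)  # shape (H, W)
    cam = np.maximum(cam, 0.0)
    # Scale to [0, 1] for display
    if cam.max() > 0:
        cam = cam / cam.max()
    return cam

# Toy example: channel 0 fires top-left and supports the class,
# channel 1 fires bottom-right and argues against it
activations = np.zeros((2, 3, 3))
activations[0, 0, 0] = 4.0
activations[1, 2, 2] = 4.0
gradients = np.stack([np.ones((3, 3)), -np.ones((3, 3))])
cam = grad_cam(activations, gradients)
```

In practice the coarse map is upsampled to the input resolution, which is why Grad-CAM heatmaps look smooth and can appear blocky at patch borders.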

### LIME

LIME (Local Interpretable Model-agnostic Explanations) explains a single prediction by learning a simple, interpretable model (like a linear model) that approximates the original model’s behavior in the neighborhood of that prediction.
It does so by perturbing the input and observing how the prediction changes.

Pros:
- Model-agnostic: can be applied to any classifier, not only neural networks.
- Focuses on superpixels rather than individual pixels, making explanations visually intuitive.
- Good for highlighting compact, highly informative regions.

Cons:
- Computationally expensive, as it requires many forward passes with perturbed inputs.
- Results can vary depending on the choice of perturbations and parameters like the number of samples.
- Sometimes explanations can be unstable if the local approximation is not faithful enough.
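
The perturb-and-fit idea behind LIME can be sketched without the `lime` package. The code below is a simplified illustration under assumptions: `lime_segment_weights` is a hypothetical helper, the "image" is reduced to a binary vector saying which super-pixels are kept, the proximity kernel is a stand-in for the one the library uses, and the surrogate is a plain weighted least-squares fit.

```python
import numpy as np

def lime_segment_weights(predict_fn, n_segments, n_samples=500,
                         kernel_width=0.25, seed=0):
    """Fit a weighted linear surrogate over on/off super-pixel masks."""
    rng = np.random.default_rng(seed)
    # Each row switches every segment on (1) or off (0)
    z = rng.integers(0, 2, size=(n_samples, n_segments))
    z[0] = 1  # include the unperturbed instance
    y = np.array([predict_fn(row) for row in z])
    # Proximity weight: samples with fewer segments removed count more
    dist = 1.0 - z.mean(axis=1)
    sample_w = np.exp(-(dist ** 2) / kernel_width ** 2)
    # Weighted least squares with an intercept column
    X = np.hstack([np.ones((n_samples, 1)), z])
    root = np.sqrt(sample_w)
    coef, *_ = np.linalg.lstsq(X * root[:, None], y * root, rcond=None)
    return coef[1:]  # one importance value per segment

# Toy black box: segment 2 matters most, segment 0 a little, the rest not at all
black_box = lambda mask: 0.1 + 0.6 * mask[2] + 0.2 * mask[0]
weights = lime_segment_weights(black_box, n_segments=5)
```

The segments with the largest coefficients are the ones an explainer would outline, which is what `num_features` selects in `get_image_and_mask`.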

In [6]:
device, checkpoint_path = model.setup_device_and_paths()
resnet18, optimizer, _ = model.initialize_model(device)
resnet18, optimizer, start_epoch, best_val_acc = model.load_checkpoint(resnet18, optimizer, checkpoint_path, device)

transform = explicability_pipeline.build_transforms()
ig, cam, explainer = explicability_pipeline.build_explainers(resnet18)

explicability_pipeline.run_explanations(resnet18, val_loader, ig, cam, explainer, transform, device, max_images=10)

Checkpoint found. Loading model...
Model loaded. Resuming from epoch 4, best validation acc = 0.8719


100%|██████████| 1000/1000 [00:04<00:00, 210.68it/s]
100%|██████████| 1000/1000 [00:04<00:00, 214.98it/s]
100%|██████████| 1000/1000 [00:04<00:00, 212.82it/s]
100%|██████████| 1000/1000 [00:04<00:00, 217.30it/s]
100%|██████████| 1000/1000 [00:04<00:00, 217.01it/s]
100%|██████████| 1000/1000 [00:04<00:00, 217.97it/s]
100%|██████████| 1000/1000 [00:04<00:00, 220.20it/s]
100%|██████████| 1000/1000 [00:04<00:00, 218.58it/s]
100%|██████████| 1000/1000 [00:04<00:00, 219.27it/s]
100%|██████████| 1000/1000 [00:04<00:00, 216.92it/s]


## Analysing Results
<!-- Fig 1 -->
<div align="center">
<img src="comparative_outputs/comparison_0_new.png" alt="Grad-CAM / IG / LIME – image 0">
<span style="font-size:0.85em"><b>Figure 1 — Panel 0</b></span>
</div>

<!-- Fig 2 -->
<div align="center">
<img src="comparative_outputs/comparison_3_new.png" alt="Grad-CAM / IG / LIME – image 1">
<span style="font-size:0.85em"><b>Figure 2 — Panel 1</b></span>
</div>

<!-- Fig 3  (note: file-name updated) -->
<div align="center">
<img src="comparative_outputs/comparison_6_new.png" alt="Grad-CAM / IG / LIME – image 2">
<span style="font-size:0.85em"><b>Figure 3 — Panel 2</b></span>
</div>

<!-- Fig 4 -->
<div align="center">
<img src="comparative_outputs/comparison_9_new.png" alt="Grad-CAM / IG / LIME – image 3">
<span style="font-size:0.85em"><b>Figure 4 — Panel 3</b></span>
</div>

---

### Fig. 1
- Grad-CAM shows a strong, rounded hot-spot in the upper-left corner.
- Integrated Gradients also concentrates energy in that same corner, with a tight cluster of bright pixels.
- LIME draws a broad yellow contour that encloses almost the same anatomical region (although its outline is less precise).

Take-away: all three methods agree on the upper-left patch being the main driver of the prediction; the overlap is remarkably tight, so this region is very likely relevant for the network.

---

### Fig. 2
- Grad-CAM highlights a tall triangular ridge running from bottom-left to top-centre.
- Integrated Gradients places its brightest pixels inside that ridge, but in a much smaller inner core.
- LIME detects only a thin super-pixel contour on the lower-right edge, not touching the Grad-CAM ridge.

Interpretation: IG partially corroborates Grad-CAM (same ridge, smaller focus), whereas LIME largely misses it, possibly because the SLIC segmentation breaks the ridge into many small pieces and only one of them meets LIME’s “top features” threshold.

---

### Fig. 3
- Grad-CAM lights up an L-shaped corner (bottom-right) with a very smooth gradient; that pattern looks like an up-sampling artefact rather than tissue detail.
- Integrated Gradients instead highlights the opposite corner (top-left) with a dense block of bright pixels.
- LIME outlines a long vertical band on the far right side; it overlaps the Grad-CAM corner, not the IG cluster.

Interpretation: the methods disagree, with IG pointing to the corner opposite the one marked by Grad-CAM and LIME. Given the blocky shape in Grad-CAM and the mismatch with IG, the highlighted regions may be artefactual (patch border, staining variation) rather than true histological signal. Switching Grad-CAM to a slightly earlier convolutional layer could clarify the map.

---

### Fig. 4
- Grad-CAM shows a wide colour band sweeping diagonally across the patch.
- Integrated Gradients scatters multiple bright hotspots along that same band: good positional agreement, albeit noisier.
- LIME draws two disconnected contours sitting on the upper-right and lower-left ends of the band.

Interpretation: all three methods locate the same diagonal structure, but they “slice” it differently: Grad-CAM gives the full band, IG pin-points small sub-regions, and LIME captures only the extreme ends (largest super-pixels).

---

### Cross-figure observations
1. Scale hierarchy holds everywhere.
   Grad-CAM paints the broadest areas, IG the finest details, and LIME sits in between (super-pixels).
2. Agreement is high in Fig. 1 and Fig. 4, partial in Fig. 2, and poor in Fig. 3.
   Divergence usually indicates either segmentation artefacts (LIME) or coarse feature-map artefacts (Grad-CAM).
3. When IG and Grad-CAM coincide (Fig. 1), confidence in that region being class-discriminative is strong.
4. LIME is sensitive to super-pixel size: in Fig. 2 and Fig. 3 it misses narrow structures that IG detects.
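
The visual notion of "agreement" used above could also be measured. One option, which is not part of this notebook's pipeline, is the intersection-over-union (IoU) of the most salient pixels of two maps. In the sketch below, `top_fraction_mask` and `iou` are hypothetical helpers and the three maps are synthetic bumps, not real explanations.

```python
import numpy as np

def top_fraction_mask(saliency, fraction=0.2):
    """Boolean mask of the most salient pixels (ties at the threshold are kept)."""
    k = max(1, int(round(fraction * saliency.size)))
    thresh = np.sort(saliency.ravel())[-k]
    return saliency >= thresh

def iou(mask_a, mask_b):
    """Intersection over union of two boolean masks."""
    union = np.logical_or(mask_a, mask_b).sum()
    return float(np.logical_and(mask_a, mask_b).sum() / union) if union else 0.0

# Synthetic 10x10 maps: two bumps near the top-left, one near the bottom-right
yy, xx = np.mgrid[0:10, 0:10]
map_a = np.exp(-((yy - 2) ** 2 + (xx - 2) ** 2) / 8.0)
map_b = np.exp(-((yy - 2) ** 2 + (xx - 3) ** 2) / 8.0)
map_c = np.exp(-((yy - 7) ** 2 + (xx - 7) ** 2) / 8.0)

iou_ab = iou(top_fraction_mask(map_a), top_fraction_mask(map_b))
iou_ac = iou(top_fraction_mask(map_a), top_fraction_mask(map_c))
```

A high IoU would correspond to the "tight agreement" of Fig. 1, and an IoU near zero to the disagreement seen in Fig. 3.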

In [38]:
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from lime import lime_image
from captum.attr import IntegratedGradients

# Grad-CAM wrapper
target_layers = [resnet18.layer4[-1]]
cam = GradCAM(model=resnet18, target_layers=target_layers)

def gc_expl(x, cls):
    return cam(x, targets=[ClassifierOutputTarget(cls)])[0] # H×W

ig = IntegratedGradients(resnet18)

def ig_expl(x, cls):
    """Return a 2-D NumPy saliency map (H×W)."""
    baseline = torch.zeros_like(x).to(x.device)
    # attr: 1×C×H×W  (tensor on the same device as x)
    attr = ig.attribute(x, target=cls, baselines=baseline, n_steps=50)
    # collapse channel-dimension → H×W, detach, move to CPU, convert to NumPy
    sal = attr.squeeze(0).abs().mean(0).detach().cpu().numpy()
    return sal                   

# LIME wrapper
lime_exp = lime_image.LimeImageExplainer()

def lime_expl(x, cls):
    # tensor → uint8 RGB; clone first so the in-place de-normalisation below
    # does not modify the caller's tensor (squeeze/cpu can return a view of x)
    d = x.detach().squeeze(0).cpu().clone()
    d *= torch.tensor([0.229,0.224,0.225]).view(3,1,1)
    d += torch.tensor([0.485,0.456,0.406]).view(3,1,1)
    img = (d.clamp(0,1).permute(1,2,0).numpy()*255).astype(np.uint8)

    exp = lime_exp.explain_instance(img,
              lambda ims: explicability_pipeline.predict_fn(ims, resnet18, transform, device),
              top_labels=2, hide_color=0, num_samples=1000)
    # keep only the five super-pixels that most support class `cls`
    mask = exp.get_image_and_mask(cls, positive_only=True,
                                  num_features=5, hide_rest=False)[1]
    return mask.astype(np.float32) # H×W

In [44]:
for name, fn in [("Grad-CAM", gc_expl),
                 ("Integrated Gradients", ig_expl),
                 ("LIME", lime_expl)]:

    del_auc, ins_auc = explicability_pipeline.evaluate_method(
        model=resnet18,
        loader=val_loader,
        explainer_fn=fn,
        name=name,
        device=device,
        max_imgs=30, # evaluate on 30 validation images
        steps=100 # 100 points on the curve
    )

    print(f"{name:<20}  Deletion AUC: {del_auc:.4f}  |  Insertion AUC: {ins_auc:.4f}")

Evaluating Grad-CAM:   0%|          | 0/512 [01:02<?, ?it/s]


Grad-CAM              Deletion AUC: 0.3885  |  Insertion AUC: 0.8014


Evaluating Integrated Gradients:   0%|          | 0/512 [01:06<?, ?it/s]


Integrated Gradients  Deletion AUC: 0.6218  |  Insertion AUC: 0.6359


100%|██████████| 1000/1000 [00:04<00:00, 213.84it/s]
100%|██████████| 1000/1000 [00:04<00:00, 212.77it/s]
100%|██████████| 1000/1000 [00:04<00:00, 217.09it/s]
100%|██████████| 1000/1000 [00:04<00:00, 216.25it/s]
100%|██████████| 1000/1000 [00:04<00:00, 216.39it/s]
100%|██████████| 1000/1000 [00:04<00:00, 216.26it/s]
100%|██████████| 1000/1000 [00:04<00:00, 205.19it/s]
100%|██████████| 1000/1000 [00:04<00:00, 219.00it/s]
100%|██████████| 1000/1000 [00:04<00:00, 215.31it/s]
100%|██████████| 1000/1000 [00:04<00:00, 216.71it/s]
100%|██████████| 1000/1000 [00:04<00:00, 218.12it/s]
100%|██████████| 1000/1000 [00:04<00:00, 215.77it/s]
100%|██████████| 1000/1000 [00:04<00:00, 217.50it/s]
100%|██████████| 1000/1000 [00:04<00:00, 215.42it/s]
100%|██████████| 1000/1000 [00:04<00:00, 212.34it/s]
100%|██████████| 1000/1000 [00:04<00:00, 216.57it/s]
100%|██████████| 1000/1000 [00:04<00:00, 216.42it/s]
100%|██████████| 1000/1000 [00:04<00:00, 217.21it/s]

LIME                  Deletion AUC: 0.3428  |  Insertion AUC: 0.7215





## Quantitative sanity-check: Deletion & Insertion AUCs

| Method                 | Deletion AUC ↓ (lower = better) | Insertion AUC ↑ (higher = better) |
|-------------------------|---------------------------------|-----------------------------------|
| **Grad-CAM**            | 0.39                            | 0.80                             |
| **Integrated Gradients**| 0.62                            | 0.64                             |
| **LIME**                | 0.34                            | 0.72                             |

### Interpretation rule of thumb
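- Deletion AUC: area under the curve of model confidence as the most important pixels are progressively removed. Lower means confidence drops faster, i.e. the saliency map really captured what the model needs.
- Insertion AUC: area under the curve of model confidence as the most important pixels are progressively re-inserted. Higher means confidence rises faster, i.e. the map contains true evidence.
- Both scores were computed on 30 validation images with 100 points per curve.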
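
Both metrics can be illustrated on a toy example. The sketch below is an assumption-laden stand-in for `explicability_pipeline.evaluate_method`, whose internals are not shown here: `deletion_insertion_auc` is a hypothetical helper, removed pixels are replaced by a constant `fill` value, the image is a 2-D array, and the "model" simply averages the top-left block.

```python
import numpy as np

def _trapezoid_area(y, x):
    y = np.asarray(y, dtype=float)
    return float(np.sum((y[1:] + y[:-1]) / 2.0 * np.diff(x)))

def deletion_insertion_auc(predict_fn, image, saliency, steps=100, fill=0.0):
    flat = image.ravel()
    order = np.argsort(saliency.ravel())[::-1]   # most important pixels first
    bounds = np.linspace(0, order.size, steps + 1).astype(int)
    deleted = flat.copy()                        # starts as the full image
    inserted = np.full_like(flat, fill)          # starts as an empty image
    del_curve = [predict_fn(deleted.reshape(image.shape))]
    ins_curve = [predict_fn(inserted.reshape(image.shape))]
    for start, stop in zip(bounds[:-1], bounds[1:]):
        idx = order[start:stop]
        deleted[idx] = fill                      # remove the next pixels
        inserted[idx] = flat[idx]                # restore the same pixels
        del_curve.append(predict_fn(deleted.reshape(image.shape)))
        ins_curve.append(predict_fn(inserted.reshape(image.shape)))
    x = np.linspace(0.0, 1.0, steps + 1)
    return _trapezoid_area(del_curve, x), _trapezoid_area(ins_curve, x)

# Toy "model": confidence is the mean intensity of the top-left 2x2 block
image = np.ones((4, 4))
toy_model = lambda img: float(img[:2, :2].mean())

good_map = np.zeros((4, 4))
good_map[:2, :2] = 1.0       # ranks the decisive block first
bad_map = 1.0 - good_map     # ranks the decisive block last

good_del, good_ins = deletion_insertion_auc(toy_model, image, good_map, steps=16)
bad_del, bad_ins = deletion_insertion_auc(toy_model, image, bad_map, steps=16)
```

The map that ranks the decisive block first gets the lower deletion AUC and the higher insertion AUC, which is the pattern the rule of thumb describes.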

---

### What the numbers tell us
1. Grad-CAM dominates insertion (0.80)
- The model’s confidence recovers quickly once the pixels ranked highest by Grad-CAM are restored, so its blobs do cover decisive regions.
- This echoes Fig. 1 and Fig. 4, where the Grad-CAM region matched the other two methods.
2. LIME wins the deletion race (0.34)
- Once LIME’s super-pixels are blacked out, confidence collapses fastest.
- Even though LIME sometimes missed fine structures (Fig. 2), the segments it did choose are clearly essential.
3. Integrated Gradients scores worst on both metrics
- Its deletion (0.62) and insertion (0.64) AUCs are almost equal, which suggests that its pixel-wise heat map spreads importance thinly: removing or adding the top-ranked pixels barely moves the needle.
- That matches our qualitative notes that IG looks “noisier” (Fig. 4) and sometimes highlights a different region from the other methods (Fig. 3).

---

### Linking back to the four visual panels

| Panel | Qualitative verdict | Consistency with the aggregate AUCs |
|-------|---------------------|-------------------------------------|
| **Fig. 1 (tight agreement)** | All three masks centred on the same hot-spot | Consistent with the high insertion AUCs of Grad-CAM and LIME |
| **Fig. 2 (LIME under-covers ridge)** | Grad-CAM and IG agree, LIME partial | Grad-CAM insertion is high; LIME can still score well on deletion if its few segments are truly critical |
| **Fig. 3 (Grad-CAM artefact)** | Methods disagree; Grad-CAM blocky | Grad-CAM insertion stays high (0.80), but its deletion AUC (0.39) is worse than LIME’s (0.34), which may reflect the cost of such artefacts |
| **Fig. 4 (broad diagonal band)** | Methods align again | Consistent with strong Grad-CAM metrics and good LIME scores |