# XAI: Grad-CAM (TensorFlow)

In [None]:
import sys
# Make the TheDuneAI segmentation package (ContourPilot) importable.
# NOTE(review): this root (/usr/users/.../Project-lung-cancer) differs from the
# /Users/constance/... root used for the model and data paths — confirm which
# machine this notebook targets.
sys.path.append('/usr/users/pred_lung_cancer/piquet_con/Project-lung-cancer/src/Segmentation')
from TheDuneAI import ContourPilot as cp
import tensorflow as tf
import numpy as np
from tensorflow.keras.models import Model
import matplotlib.pyplot as plt
%matplotlib inline
from tqdm import tqdm
import cv2

# Imports for the viewer
import nrrd
import napari

# Imports for Xplique
# NOTE(review): requests, PIL.Image and the torchvision helpers are not used in
# the cells shown here. The explained model is Keras, so torch is otherwise only
# used for `device` and torch.cuda.empty_cache() calls.
import requests
from PIL import Image
# Batch size passed to the Xplique metrics / stability evaluation.
BATCH_SIZE = 8
import torch
from torchvision.transforms.functional import to_pil_image, to_tensor, resize
# NOTE(review): `device` is not referenced in the cells shown here.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
import xplique


# Visualize the data for one patient

In [None]:
# Local data locations, all derived from one project root. Edit PROJECT_ROOT
# (or set the LUNG_CANCER_PROJECT_ROOT environment variable) rather than each path.
import os

PROJECT_ROOT = os.environ.get(
    'LUNG_CANCER_PROJECT_ROOT', '/Users/constance/Documents/Project_lung_cancer'
)
# Directory passed to ContourPilot as the model location.
model_path = f'{PROJECT_ROOT}/src/Segmentation/model_files'
# Directory of converted NRRD volumes passed to ContourPilot. The trailing slash
# is kept from the original value in case the consumer concatenates file names.
path_to_test_data = f'{PROJECT_ROOT}/NIH dataset_raw/NRRD/converted_nrrds/'
# Output directory passed to ContourPilot.
save_path = f'{PROJECT_ROOT}/NIH dataset_raw/Processed'

In [None]:
# Build the ContourPilot wrapper, then print index, name and class of every layer
# of its Keras model so a Grad-CAM target layer can be chosen by name.
model = cp(model_path, path_to_test_data, save_path, verbosity=True)
for layer_idx, keras_layer in enumerate(model.model1.layers):
    print(layer_idx, keras_layer.name, type(keras_layer).__name__)


In [None]:
# Pull the first patient from ContourPilot's generator. `gen_iterator` is kept so
# further patients can be fetched with next(gen_iterator).
gen_with_progress = tqdm(model.Patients_gen, desc='Progress')
gen_iterator = iter(gen_with_progress)  # turn the progress-wrapped generator into an iterator
img, _, filename, params = next(gen_iterator)
print(f"Image shape: {img.shape}")

# Drop singleton axes; the slice indexing below assumes a (slices, H, W) volume.
img = img.squeeze()
print(f"Image shape after squeeze: {img.shape}")
assert img.ndim == 3, f"Expected a (slices, H, W) volume after squeeze, got {img.shape}"

# --- Select the middle slice ---
middle_index = img.shape[0] // 2
print(f"Middle slice index: {middle_index}")
middle_slice = img[middle_index, :, :]  # 2D (H, W) slice

# Predict a mask for every slice. The trailing channel axis is added explicitly,
# matching the (..., 1) layout the Grad-CAM and Xplique cells feed to the model.
mask = model.model1.predict(img[..., np.newaxis], batch_size=1, verbose=1)


In [None]:
# Show the CT volume and the predicted mask in napari.
# `img` and `mask` are already in-memory NumPy arrays, so they are passed
# directly; nrrd.read() takes a file path and cannot read an array.
data_image = np.asarray(img)
# Drop singleton axes (e.g. a trailing channel axis from predict) so the mask
# layer has the same layout as the image layer.
data_mask = np.squeeze(np.asarray(mask))

# Start napari
viewer = napari.Viewer()

# Add the CT image with a grayscale colormap.
viewer.add_image(data_image, name='CT Image', colormap='gray', blending='additive')

# Add the mask as a semi-transparent layer on top of the CT.
viewer.add_image(data_mask, name='DL Mask', blending='additive', opacity=0.5)

napari.run()

# Grad-CAM (basic version)

In [None]:
def compute_gradcam(model, image, target_layer_name):
    """Compute a Grad-CAM heatmap for a single input image.

    The model output is reduced to a scalar by taking its mean; for a
    segmentation model the heatmap therefore explains the overall predicted map.
    The gradient of that scalar is taken with respect to the activations of
    ``target_layer_name``. Those activations are then combined, using the
    spatially averaged gradients as per-channel weights.

    Args:
        model: Keras model to explain.
        image: One input without a batch axis, e.g. ``(H, W, C)``. A batch axis
            is added before calling the model.
        target_layer_name: Name of a layer of ``model``. Its output must be a
            channels-last feature map ``(1, h, w, channels)``.

    Returns:
        ``np.float32`` array of shape ``(H, W)``, resized to the image's spatial
        size, with values in ``[0, 1]``.

    Raises:
        ValueError: if the model output does not depend on the target layer, in
            which case TensorFlow returns no gradient.
    """
    grad_model = tf.keras.models.Model(
        [model.input],
        [model.get_layer(target_layer_name).output, model.output]
    )

    with tf.GradientTape() as tape:
        conv_outputs, predictions = grad_model(np.expand_dims(image, axis=0))
        # Use the mean prediction over the output map as the scalar to explain.
        loss = tf.reduce_mean(predictions)

    grads = tape.gradient(loss, conv_outputs)
    if grads is None:
        raise ValueError(
            f"Layer '{target_layer_name}' is not connected to the model output; "
            "no gradient is available for Grad-CAM."
        )

    # Drop the batch axis -> (h, w, channels).
    conv_outputs = conv_outputs[0]
    grads = grads[0]

    # One weight per channel: the gradient averaged over the spatial axes.
    weights = tf.reduce_mean(grads, axis=(0, 1))
    # Weighted sum of the feature maps over channels (vectorized instead of a
    # Python loop over channels). float32 keeps cv2.resize happy.
    cam = tf.reduce_sum(conv_outputs * weights, axis=-1).numpy().astype(np.float32)

    # Keep positive evidence only and scale to [0, 1]; the epsilon avoids a
    # division by zero when the map is all zeros.
    cam = np.maximum(cam, 0)
    cam = cam / (np.max(cam) + 1e-8)
    # cv2.resize takes the target size as (width, height).
    cam = cv2.resize(cam, (image.shape[1], image.shape[0]))

    return cam

def show_gradcam_overlay(image, cam, alpha=0.4):
    """Display a Grad-CAM heatmap blended over an image.

    Args:
        image: 2D grayscale image ``(H, W)``, a ``(H, W, 1)`` image, or an RGB
            ``(H, W, 3)`` image. Values are expected in ``[0, 1]``. If any value
            falls outside that range, the image is min-max rescaled for display,
            so the uint8 conversion cannot wrap around.
        cam: Heatmap with the same ``(H, W)`` size and values in ``[0, 1]``, e.g.
            the output of ``compute_gradcam``.
        alpha: Weight of the heatmap in the blend; the image gets ``1 - alpha``.
    """
    image = np.asarray(image, dtype=np.float32)
    if image.ndim == 3 and image.shape[-1] == 1:
        image = image[..., 0]
    lo, hi = float(image.min()), float(image.max())
    if lo < 0 or hi > 1:
        # e.g. raw intensities: rescale to [0, 1] for display only.
        image = (image - lo) / (hi - lo + 1e-8)

    # Clip before the uint8 cast so out-of-range CAM values cannot wrap around.
    heatmap = cv2.applyColorMap(np.uint8(255 * np.clip(cam, 0, 1)), cv2.COLORMAP_JET)  # BGR
    # Bring the image to 3-channel BGR so it can be blended with the heatmap.
    if image.ndim == 2:
        image_color = cv2.cvtColor(np.uint8(image * 255), cv2.COLOR_GRAY2BGR)
    else:
        image_color = cv2.cvtColor(np.uint8(image * 255), cv2.COLOR_RGB2BGR)
    overlay = cv2.addWeighted(image_color, 1 - alpha, heatmap, alpha, 0)

    # OpenCV works in BGR while matplotlib expects RGB. Without this step the JET
    # colormap would be displayed with red and blue swapped.
    overlay_rgb = cv2.cvtColor(overlay, cv2.COLOR_BGR2RGB)

    _, ax = plt.subplots(figsize=(6, 6))
    ax.imshow(overlay_rgb)
    ax.axis('off')
    ax.set_title("Grad-CAM Overlay")
    plt.show()


In [None]:
# Layer whose activations Grad-CAM explains. Pick it from the layer listing
# printed when the model was built.
GRADCAM_LAYER = "conv2d_23"

print(f"Middle slice shape: {middle_slice.shape}")
# --- Prepare the slice for the model: add a trailing channel axis -> (H, W, 1) ---
input_slice = middle_slice[..., np.newaxis].astype(np.float32)

# --- Compute Grad-CAM ---
cam = compute_gradcam(model.model1, input_slice, target_layer_name=GRADCAM_LAYER)

# --- Visualize overlay ---
show_gradcam_overlay(input_slice.squeeze(), cam)

# Xplique attribution methods

In [None]:
from xplique.attributions import (Saliency, GradientInput, IntegratedGradients, SmoothGrad, VarGrad, SquareGrad,
                                  Occlusion, Rise, SobolAttributionMethod, HsicAttributionMethod)

from xplique.plots import plot_attributions

# Model inputs and targets with an explicit trailing channel axis.
img_tf = tf.convert_to_tensor(img, dtype=tf.float32)
img_tf = tf.expand_dims(img_tf, axis=-1)  # (slices, H, W, 1), e.g. [38, 512, 512, 1]
# NOTE(review): assumes `mask` from predict is (slices, H, W, 1) like img_tf — confirm.
mask_tf = tf.convert_to_tensor(mask, dtype=tf.float32)

# Background for the attribution plots: the CT blended with the predicted mask.
mask_alpha = 0.5
images_with_masks = (1 - mask_alpha) * img_tf + mask_alpha * mask_tf

# Attribution methods to run and their hyper-parameters.
explainers = {
    Saliency: {},
    GradientInput: {},
    IntegratedGradients: {"steps": 20},
    SmoothGrad: {"nb_samples": 50, "noise": 0.75},
    VarGrad: {"nb_samples": 50, "noise": 0.75},
    SquareGrad: {"nb_samples": 100, "noise": 0.5},
    # Disabled: too long to run.
    # Occlusion: {"patch_size": 40, "patch_stride": 10, "occlusion_value": 0},
    # Rise: {"nb_samples": 4000, "grid_size": 13},
    SobolAttributionMethod: {"nb_design": 32, "grid_size": 13},
    HsicAttributionMethod: {"nb_design": 1500, "grid_size": 13},
}

explanations = {}
for explainer_class, explainer_params in explainers.items():
    # NOTE(review): this only frees PyTorch's CUDA cache. The explained model is
    # Keras, so it likely has no effect — confirm and remove.
    torch.cuda.empty_cache()
    print(explainer_class.__name__)

    # instantiate explainer
    explainer = explainer_class(model.model1, operator=xplique.Tasks.SEMANTIC_SEGMENTATION,
                                batch_size=1, **explainer_params)

    # Explain the channel-last tensors that are also used for plotting and for
    # the metrics, rather than the channel-less `img` / raw `mask`.
    explanation = explainer(img_tf, mask_tf)

    # show explanations for a method
    plot_attributions(explanation, images_with_masks,
                      img_size=4., cols=img_tf.shape[0],
                      cmap='jet', alpha=0.3, absolute_value=False, clip_percentile=0.5)
    plt.show()

    # keep explanations in memory for metrics
    explanations[explainer_class.__name__] = explanation

In [None]:
from xplique.metrics import Deletion, MuFidelity, Insertion, AverageStability
from xplique.plots.metrics import barplot

# Number of leading slices of img_tf / mask_tf used for evaluation (memory/time).
N_METRIC_SLICES = 3
# Activation Xplique applies to the model output before scoring.
# NOTE(review): if model.model1 outputs a single channel, a softmax over that one
# channel is constant (1.0) and these fidelity metrics become uninformative.
# Confirm the output layout and consider None or "sigmoid".
METRIC_ACTIVATION = "softmax"
# AverageStability re-runs every explainer and is very expensive; it is
# excluded unless explicitly enabled.
RUN_AVERAGE_STABILITY = False

metrics = {}

# Explanation metrics (MuFidelity's nb_samples was reduced for memory needs).
explanations_metrics = {
    Deletion: {"baseline_mode": 0, "steps": 10, "max_percentage_perturbed": 0.5},
    MuFidelity: {"baseline_mode": 0, "nb_samples": 5, "subset_percent": 0.2, "grid_size": 13},
    Insertion: {"baseline_mode": 0, "steps": 10, "max_percentage_perturbed": 0.5},
}
metric_inputs = np.array(img_tf[:N_METRIC_SLICES])
metric_targets = np.array(mask_tf[:N_METRIC_SLICES])

for metric_class, metric_params in explanations_metrics.items():
    torch.cuda.empty_cache()
    # instantiate the metric
    metric = metric_class(model.model1, metric_inputs, metric_targets,
                          operator=xplique.Tasks.SEMANTIC_SEGMENTATION,
                          activation=METRIC_ACTIVATION, batch_size=BATCH_SIZE, **metric_params)

    # Score every method's explanations on the same slices.
    metrics[metric_class.__name__] = {
        method: metric(explanation[:N_METRIC_SLICES])
        for method, explanation in explanations.items()
    }

# Explainer metric: evaluated once, outside the metric loop, for every explainer.
if RUN_AVERAGE_STABILITY:
    stability_metric = AverageStability(model.model1, img_tf[:N_METRIC_SLICES], mask_tf[:N_METRIC_SLICES],
                                        batch_size=BATCH_SIZE, nb_samples=20, radius=0.1, distance="l2")
    metrics["AverageStability"] = {}
    for explainer_class, explainer_params in explainers.items():
        torch.cuda.empty_cache()
        # instantiate explainer
        explainer = explainer_class(model.model1, operator=xplique.Tasks.SEMANTIC_SEGMENTATION,
                                    batch_size=BATCH_SIZE, **explainer_params)
        metrics["AverageStability"][explainer_class.__name__] = stability_metric.evaluate(explainer)

barplot(metrics, sort_metric="Deletion", ascending=True)
plt.show()