In [None]:
import warnings
warnings.filterwarnings('ignore')
warnings.simplefilter('ignore')

from datasets import load_dataset

import random
import json
from PIL import Image
import pandas as pd
from tqdm import tqdm
import time

import numpy as np
import matplotlib.pyplot as plt

import torch
import torchvision

import quantus

#from xai_methods import GradCAMHeatmap, LaFAM, LFLRP, RELAX, Randomized7x7
from utils import SquareCropAndResize, imagenet_transform, inverse_transform, target_transform
from utils import evaluate, get_layer_idx


def choose_device() -> str:
    """Pick the best available torch device string.

    Preference order is CUDA (first GPU), then the Apple MPS backend when this
    torch build exposes it and reports it as available, and finally the CPU.

    Returns:
        One of ``"cuda:0"``, ``"mps"`` or ``"cpu"``.
    """
    if torch.cuda.is_available():
        return "cuda:0"

    # Older torch builds do not ship ``torch.backends.mps`` at all.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return "mps"

    return "cpu"

device = torch.device(choose_device())
# torch.cuda.get_device_name() only accepts CUDA devices, so calling it
# unconditionally fails on the "mps"/"cpu" fallbacks choose_device() can return.
if device.type == "cuda":
    print(torch.cuda.get_device_name(device))
else:
    print(f"Using device: {device}")

# Pin every random number generator used in this notebook so that a fresh
# "Restart & Run All" reproduces the same results.
seed = 123
for _seed_rng in (
    random.seed,                 # Python stdlib RNG
    np.random.seed,              # NumPy global RNG
    torch.manual_seed,           # torch default generator
    torch.cuda.manual_seed,      # current CUDA device
    torch.cuda.manual_seed_all,  # every CUDA device (multi-GPU)
):
    _seed_rng(seed)


In [1]:

def _to_float_tensor(value, device):
    """Return ``value`` as a float32 tensor on ``device``."""
    if isinstance(value, torch.Tensor):
        return value.to(device).float()
    return torch.tensor(value, device=device, dtype=torch.float32)


def _minmax_normalize(heatmap):
    """Rescale ``heatmap`` to [0, 1]; the epsilon guards constant maps."""
    return (heatmap - heatmap.min()) / (heatmap.max() - heatmap.min() + 1e-8)


def continuity(x, explainer, explanation_base, device, sigma=0.1, n_perturbations=10):
    """Score how stable an explanation is under Gaussian input noise.

    The input is perturbed ``n_perturbations`` times with zero-mean Gaussian
    noise of standard deviation ``sigma``. Each perturbed explanation is
    min-max normalised and compared against the min-max normalised baseline;
    the score is ``1 - d`` where ``d`` is the largest absolute element-wise
    difference seen over all perturbations.

    Args:
        x: Input tensor passed (after perturbation) to ``explainer``.
        explainer: Callable mapping an input tensor to an explanation
            (tensor or array-like).
        explanation_base: Explanation of the unperturbed ``x`` (tensor or
            array-like).
        device: Device on which all tensors are placed.
        sigma: Standard deviation of the additive Gaussian noise.
        n_perturbations: Number of noisy copies of ``x`` to evaluate.

    Returns:
        ``1 - max_diff``; ``1`` when ``n_perturbations`` is 0.

    Raises:
        ValueError: If a perturbed explanation has more than two dimensions
            and a leading dimension that is not of size 1 (e.g. a
            multi-channel map), which cannot be reduced to a single 2-D map.
    """
    x = x.to(device)
    expl_base = _to_float_tensor(explanation_base, device)
    # The baseline does not change between perturbations, so normalise it once.
    expl_base_norm = _minmax_normalize(expl_base)

    max_diff = 0
    for _ in range(n_perturbations):
        noise = torch.randn_like(x) * sigma
        x_perturbed = (x + noise).detach()
        # Gradient-based explainers need a leaf tensor that tracks gradients.
        x_perturbed.requires_grad_()

        expl_perturbed = _to_float_tensor(explainer(x_perturbed), device)

        # Strip leading singleton (batch/channel) axes down to a 2-D map.
        # squeeze(0) is a no-op on a non-singleton axis, so bail out explicitly
        # instead of spinning forever.
        while expl_perturbed.dim() > 2:
            if expl_perturbed.shape[0] != 1:
                raise ValueError(
                    "Cannot reduce explanation of shape "
                    f"{tuple(expl_perturbed.shape)} to a 2-D map"
                )
            expl_perturbed = expl_perturbed.squeeze(0)

        # Bring the perturbed map to the baseline's spatial resolution.
        if expl_perturbed.shape[-2:] != expl_base.shape[-2:]:
            expl_perturbed = torch.nn.functional.interpolate(
                expl_perturbed.unsqueeze(0).unsqueeze(0),  # shape: (1,1,H,W)
                size=expl_base.shape[-2:],  # target spatial size
                mode='nearest'
            ).squeeze(0).squeeze(0)

        expl_perturbed = _minmax_normalize(expl_perturbed)

        diff = torch.abs(expl_base_norm - expl_perturbed).max().item()
        if diff > max_diff:
            max_diff = diff

    continuity_score = 1 - max_diff
    return continuity_score


In [2]:
# Consistency Metric

from sklearn.metrics.pairwise import cosine_similarity
from itertools import combinations


# Collect, for a random sample of images, the model prediction and one
# flattened heatmap per XAI method.
# NOTE(review): `xai_methods`, `pascal_df`, `pascal_ds` and `resnet` are not
# defined in the visible part of this notebook; they are assumed to come from
# earlier cells. `xai_methods` is iterated as (name, callable) pairs.
method_explanations = {name: [] for name, _ in xai_methods}
method_predictions = []

df_sample = pascal_df.sample(10)
print(f"Evaluating {len(df_sample)} samples...")

for i, row in enumerate(tqdm(df_sample.itertuples(), total=len(df_sample))):
    # The dataset item is unpacked as (image tensor, second element); the
    # second element is not used below.
    img_tensor, seg_img = pascal_ds[int(row.dataset_idx)]
    img_tensor = img_tensor.unsqueeze(0).to(device)  # add a leading batch axis

    # The prediction only needs a forward pass, so gradients are disabled here.
    with torch.no_grad():
        pred = resnet(img_tensor).argmax().item()
    method_predictions.append(pred)

    for name, method in xai_methods:
        try:
            # The explainer is called outside no_grad.
            heatmap = method(img_tensor)
            # Keep positive evidence only and scale by the maximum.
            heatmap = torch.relu(heatmap)
            heatmap = heatmap / (heatmap.max() + 1e-8)

            heatmap_np = heatmap.squeeze().detach().cpu().numpy()

            # Min-max normalize the heatmap
            heatmap_np = (heatmap_np - np.min(heatmap_np)) / (np.max(heatmap_np) - np.min(heatmap_np) + 1e-8)

            # Check for NaNs or infs
            if np.isnan(heatmap_np).any() or np.isinf(heatmap_np).any():
                raise ValueError("Heatmap contains NaN or inf values.")

            # Resize to a common 224x224 grid through an 8-bit PIL image so all
            # vectors have equal length; values end up in the 0-255 range.
            heatmap_pil = Image.fromarray((heatmap_np * 255).astype(np.uint8))
            heatmap_pil = heatmap_pil.resize((224, 224), Image.NEAREST)
            heatmap = np.array(heatmap_pil).astype(np.float32).flatten()

            method_explanations[name].append(heatmap)

        except Exception as e:
            # Best effort: record a None placeholder so list positions stay
            # aligned with `method_predictions`.
            print(f"[{name}] Sample {i} failed: {e}")
            method_explanations[name].append(None)
            continue


def _cosine_similarity_1d(vec_a, vec_b):
    """Cosine similarity of two vectors; 0.0 when either has zero norm."""
    a = np.asarray(vec_a, dtype=np.float64).ravel()
    b = np.asarray(vec_b, dtype=np.float64).ravel()
    denom = np.linalg.norm(a) * np.linalg.norm(b)
    if denom == 0.0:
        # An all-zero heatmap has no direction; treat it as dissimilar.
        return 0.0
    return float(np.dot(a, b) / denom)


def compute_consistency_from_vectors(explanations, predictions, similarity_threshold=0.9):
    """Fraction of similar-explanation pairs that share the same prediction.

    Every unordered pair of samples is considered. A pair counts only when
    both explanations are present and their cosine similarity is at least
    ``similarity_threshold``; among those pairs, the score is the share whose
    predictions are equal.

    Cosine similarity is computed with NumPy, so this function does not
    require scikit-learn.

    Args:
        explanations: Sequence of 1-D explanation vectors; ``None`` entries
            mark failed samples and are skipped.
        predictions: Sequence of predictions aligned index-by-index with
            ``explanations``.
        similarity_threshold: Minimum cosine similarity for a pair to count.

    Returns:
        ``consistent_pairs / total_pairs`` as a float, or ``None`` when no
        pair reaches the threshold.
    """
    total_pairs = 0
    consistent_pairs = 0

    for i, j in combinations(range(len(explanations)), 2):
        exp_i, exp_j = explanations[i], explanations[j]
        if exp_i is None or exp_j is None:
            continue

        if _cosine_similarity_1d(exp_i, exp_j) >= similarity_threshold:
            total_pairs += 1
            consistent_pairs += int(predictions[i] == predictions[j])

    if total_pairs == 0:
        return None
    return consistent_pairs / total_pairs


# Score every XAI method and gather the scores into one table.
consistency_results = []
for name, explanations in method_explanations.items():
    # Number of usable heatmaps for this method (handy when debugging).
    valid_expls = sum(1 for e in explanations if e is not None)
    #print(f"\n{name}: {valid_expls} valid explanations")

    consistency = compute_consistency_from_vectors(
        explanations,
        method_predictions,
        similarity_threshold=0.9,
    )
    consistency_results.append(dict(xai_method=name, consistency=consistency))

df_consistency = pd.DataFrame(consistency_results)
print("Consistency Results:")
print(df_consistency)


ModuleNotFoundError: No module named 'sklearn'

In [3]:


def local_lipschitz(
    x,
    explainer,
    explanation_base,
    n_samples=50,
    noise_std=0.01,
    norm_type=2,
    device='cpu',
    random_seed=None
):
    """Estimate the local Lipschitz constant of an explainer around ``x``.

    ``x`` is perturbed ``n_samples`` times with Gaussian noise of standard
    deviation ``noise_std``. For each perturbation the ratio
    ``||E(x') - E(x)|| / ||x' - x||`` is computed with the ``norm_type`` norm,
    and the largest ratio is returned.

    Each perturbed explanation is passed through ReLU and divided by its
    maximum before comparison. ``explanation_base`` is used as given, so it
    presumably should already be normalised the same way - TODO confirm with
    the caller.

    Args:
        x: Input tensor.
        explainer: Callable mapping an input tensor to an explanation tensor.
        explanation_base: Explanation of the unperturbed ``x`` (tensor or
            array-like).
        n_samples: Number of noisy samples to draw.
        noise_std: Standard deviation of the Gaussian noise.
        norm_type: Order ``p`` of the norm passed to ``torch.norm``.
        device: Device on which the computation runs.
        random_seed: If given, reseeds the global torch and NumPy RNGs
            (a side effect visible to the rest of the notebook).

    Returns:
        The maximum observed ratio as a float (``0.0`` if no sample
        contributed).
    """
    if random_seed is not None:
        torch.manual_seed(random_seed)
        np.random.seed(random_seed)

    x = x.to(device)
    if not torch.is_tensor(explanation_base):
        explanation_base = torch.tensor(explanation_base, device=device)
    else:
        # A tensor baseline may live on another device than the explanations.
        explanation_base = explanation_base.to(device)

    max_lipschitz = 0.0

    for _ in range(n_samples):
        noise = torch.randn_like(x) * noise_std
        x_perturbed = (x + noise).detach().clone().requires_grad_()

        # Get explanation for perturbed input
        explanation_perturbed = explainer(x_perturbed)
        explanation_perturbed = explanation_perturbed.to(device)
        explanation_perturbed = torch.relu(explanation_perturbed)
        explanation_perturbed = explanation_perturbed / (explanation_perturbed.max() + 1e-8)
        explanation_perturbed = explanation_perturbed.squeeze()

        # Author's note (translated): this was written while working with
        # images of different sizes, so every explanation is resized to the
        # baseline's spatial size here. If all your explanations already share
        # one size this branch is simply skipped; the other metrics need the
        # same kind of handling.
        if explanation_perturbed.shape[-2:] != explanation_base.shape[-2:]:
            explanation_perturbed = torch.nn.functional.interpolate(
                explanation_perturbed.unsqueeze(0).unsqueeze(0),
                size=explanation_base.shape[-2:],
                mode='nearest'
            ).squeeze()

        # Compute norm of explanation difference (broadcasts over any leading
        # singleton axes of the baseline).
        diff_exp = explanation_perturbed - explanation_base
        norm_exp = torch.norm(diff_exp.reshape(-1), p=norm_type)

        # Compute norm of input difference
        diff_x = x_perturbed - x
        norm_x = torch.norm(diff_x.reshape(-1), p=norm_type)

        # Skip degenerate samples (e.g. noise_std == 0) to avoid dividing by 0.
        if norm_x.item() > 0:
            lipschitz_estimate = (norm_exp / norm_x).item()
            if lipschitz_estimate > max_lipschitz:
                max_lipschitz = lipschitz_estimate

    return max_lipschitz


In [None]:
def avg_sensitivity(
    x,
    explainer,
    explanation_base,
    n_samples=50,
    noise_std=0.01,
    norm_type=1,
    device='cpu',
    random_seed=None
):
    """Average sensitivity of an explainer to Gaussian input noise.

    ``x`` is perturbed ``n_samples`` times with Gaussian noise of standard
    deviation ``noise_std``. For each perturbation the ratio
    ``||E(x') - E(x)|| / ||noise||`` is computed with the ``norm_type`` norm,
    and the mean of these ratios is returned.

    Each perturbed explanation is passed through ReLU and divided by its
    maximum before comparison. ``explanation_base`` is used as given, so it
    presumably should already be normalised the same way - TODO confirm with
    the caller.

    Args:
        x: Input tensor.
        explainer: Callable mapping an input tensor to an explanation tensor.
        explanation_base: Explanation of the unperturbed ``x`` (tensor or
            array-like).
        n_samples: Number of noisy samples to draw.
        noise_std: Standard deviation of the Gaussian noise.
        norm_type: Order ``p`` of the norm passed to ``torch.norm``.
        device: Device on which the computation runs.
        random_seed: If given, reseeds the global torch and NumPy RNGs
            (a side effect visible to the rest of the notebook).

    Returns:
        Mean ratio over all samples with non-zero noise norm, or NaN when no
        sample contributed (``n_samples == 0`` or all-zero noise).
    """
    if random_seed is not None:
        torch.manual_seed(random_seed)
        np.random.seed(random_seed)

    x = x.to(device)

    if not torch.is_tensor(explanation_base):
        explanation_base = torch.tensor(explanation_base, device=device)
    else:
        explanation_base = explanation_base.to(device)

    sensitivities = []

    for _ in range(n_samples):
        noise = torch.randn_like(x) * noise_std
        x_perturbed = (x + noise).detach().clone().requires_grad_()

        explanation_perturbed = explainer(x_perturbed)
        explanation_perturbed = explanation_perturbed.to(device)  # ensure device alignment

        explanation_perturbed = torch.relu(explanation_perturbed)
        explanation_perturbed = explanation_perturbed / (explanation_perturbed.max() + 1e-8)
        explanation_perturbed = explanation_perturbed.squeeze()

        # Author's note (translated): this was written while working with
        # images of different sizes, so every explanation is resized to the
        # baseline's spatial size here. If all your explanations already share
        # one size this branch is simply skipped; the other metrics need the
        # same kind of handling.
        if explanation_perturbed.shape[-2:] != explanation_base.shape[-2:]:
            explanation_perturbed = torch.nn.functional.interpolate(
                explanation_perturbed.unsqueeze(0).unsqueeze(0),
                size=explanation_base.shape[-2:],
                mode='nearest'
            ).squeeze()

        # Broadcasts over any leading singleton axes of the baseline.
        diff = explanation_perturbed - explanation_base
        diff_norm = torch.norm(diff.reshape(-1), p=norm_type)
        delta_norm = torch.norm(noise.reshape(-1), p=norm_type)

        # Skip degenerate samples (e.g. noise_std == 0) to avoid dividing by 0.
        if delta_norm.item() > 0:
            sensitivities.append((diff_norm / delta_norm).item())

    if not sensitivities:
        # The mean of nothing is undefined; say so explicitly instead of
        # relying on np.mean([]) and its RuntimeWarning.
        return float("nan")
    return np.mean(sensitivities)
