# Libraries

In [None]:
# Clone the project repository only when it is not already present, so that
# re-running the notebook does not fail with "destination path 'clxai' already exists".
import os
import subprocess

if not os.path.isdir('clxai'):
    subprocess.run(['git', 'clone', 'https://github.com/LeonardoArrighi/clxai'], check=True)

fatal: destination path 'clxai' already exists and is not an empty directory.


In [None]:
from tqdm import tqdm

import numpy as np
import torch
import matplotlib.pyplot as plt

import torch.nn as nn

from clxai.src.models.resnet import get_model
from clxai.src.utils.data import get_data_loaders, get_num_classes

In [None]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

Using device: cuda


In [None]:
# Mount Google Drive (Colab only). Checkpoints are read from, and results written to,
# paths under /content/drive/MyDrive in the cells below.
from google.colab import drive

drive.mount(mountpoint='/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
# --- Experiment configuration (edit here) ------------------------------------
# model_version: checkpoint family; the path cell handles 'ce' and 'scl'
#   (presumably cross-entropy vs. supervised contrastive -- confirm).
# train_flag (scl only): True fits a new MLP head on the frozen encoder and saves it;
#   False loads the previously saved head.
# test_flag: not referenced in the cells reviewed here -- TODO confirm it is still needed.
model_version, train_flag, test_flag = 'scl', False, False  # THIS

In [None]:
# Resolve checkpoint paths for the selected configuration.
#   model_path  -- file loaded by the model-loading cells.
#   weight_path -- where the MLP-head training cell saves its state_dict (scl + training only).
# Note: the scl "load" path is the same file the scl "train" run writes to.
if model_version == 'ce':
    model_path = '/content/drive/MyDrive/clxai/weights/ce_seed0_best_model.pt'
elif model_version == 'scl':
    if train_flag:
        model_path = '/content/drive/MyDrive/clxai/weights/scl_seed1_best_model.pt'
        weight_path = '/content/drive/MyDrive/clxai/scl_mlp/scl_seed1_model_mlp.pt'
    else:
        model_path = '/content/drive/MyDrive/clxai/scl_mlp/scl_seed1_model_mlp.pt'
else:
    # Fail fast: otherwise model_path would be undefined (or silently stale from an
    # earlier run of this cell) and the error would only surface much later.
    raise ValueError(f"Unsupported model_version: {model_version!r} (expected 'ce' or 'scl')")

In [None]:
# res_path = '/content/drive/MyDrive/clxai/results_faith/test'

In [None]:
# Dataset and backbone; the number of classes is looked up from the dataset name.
dataset, architecture = 'cifar10', 'resnet18'
num_classes = get_num_classes(dataset)

In [None]:
# Train/test loaders without augmentation. data_dir is not passed, so the function's
# default is used. For download=True see line 340 of the `data` utilities module
# (presumably clxai/src/utils/data.py).
train_loader, test_loader = get_data_loaders(
    dataset=dataset,
    batch_size=128,
    num_workers=2,
    augment=False,
    download=True,
)

# Model

In [None]:
# Build the backbone and, depending on the configuration, load pretrained weights.
model = get_model(architecture=architecture, num_classes=num_classes)

if model_version == 'ce':
    # NOTE: weights_only=False unpickles arbitrary objects -- only load trusted checkpoints.
    checkpoint = torch.load(model_path, map_location=device, weights_only=False)
    loaded_state_dict = checkpoint['model_state_dict']

    # strict=False silently skips mismatching keys; report them so a wrong key layout
    # does not leave the model randomly initialised without anyone noticing.
    incompatible = model.load_state_dict(loaded_state_dict, strict=False)
    if incompatible.missing_keys or incompatible.unexpected_keys:
        print(f"WARNING: {len(incompatible.missing_keys)} missing and "
              f"{len(incompatible.unexpected_keys)} unexpected keys when loading {model_path}")

if model_version == 'scl' and train_flag:
    # NOTE: weights_only=False unpickles arbitrary objects -- only load trusted checkpoints.
    checkpoint = torch.load(model_path, map_location=device, weights_only=False)
    loaded_state_dict = checkpoint['model_state_dict']

    # First attempt: load the keys as stored. NOTE(review): presumably a no-op when the
    # checkpoint keys lack the 'encoder.' prefix (strict=False still raises on shape mismatch).
    model.load_state_dict(loaded_state_dict, strict=False)

    # Remap the checkpoint keys under this model's 'encoder.' sub-module and drop the
    # checkpoint's own 'fc.*' head (a new head is attached in the next section).
    new_state_dict = {
        'encoder.' + k: v
        for k, v in loaded_state_dict.items()
        if not k.startswith('fc.')
    }

    # Before loading, check how many remapped keys actually exist in the model.
    loaded_keys = set(new_state_dict.keys())
    model_keys = set(model.state_dict().keys())
    matched_keys = loaded_keys.intersection(model_keys)

    # Messages (Italian): "parameters loaded successfully" / "parameters missing".
    print(f"Parametri caricati con successo: {len(matched_keys)}")
    print(f"Parametri mancanti: {len(model_keys - matched_keys)}")

    if len(matched_keys) == 0:
        # "WARNING: no key matches!"
        print("ATTENZIONE: Nessuna chiave corrisponde!")

    model.load_state_dict(new_state_dict, strict=False)

model.to(device)
model.eval()

ResNet18(
  (encoder): ResNet18Encoder(
    (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 

## For CL (contrastive-learning) training: fit an MLP head on the frozen encoder

In [None]:
# Replace the classifier with a small MLP head and freeze the rest of the network
# (every parameter whose name does not contain "fc"), so only the new head trains.

if model_version == 'scl' and train_flag:
    in_features, hidden_dim = model.fc.in_features, 256

    model.fc = nn.Sequential(
        nn.Linear(in_features, hidden_dim),
        nn.ReLU(),
        nn.Dropout(0.2),
        nn.Linear(hidden_dim, num_classes),
    )

    # Freeze: leave "fc" parameters untouched, disable gradients everywhere else.
    for param_name, weight in model.named_parameters():
        if "fc" in param_name:
            continue
        weight.requires_grad = False

    model.to(device)

In [None]:
# Train the final MLP head on top of the frozen encoder (SCL + training only).
if model_version == 'scl' and train_flag:

    import torch.optim as optim

    criterion = nn.CrossEntropyLoss()
    # Only the head's parameters are optimised; the encoder was frozen in the previous cell.
    optimizer = optim.Adam(model.fc.parameters(), lr=1e-3)

    num_epochs = 20

    for epoch in range(num_epochs):
        model.train()
        # Keep the frozen encoder in eval mode: model.train() would otherwise switch its
        # BatchNorm layers to training mode and keep updating their running statistics,
        # silently modifying the "frozen" encoder (and the weights saved below).
        model.encoder.eval()
        running_loss = 0.0
        correct = 0
        total = 0

        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()

            outputs = model(inputs)
            loss = criterion(outputs, labels)

            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

        print(f"Epoch {epoch+1}/{num_epochs} - Loss: {running_loss/len(train_loader):.4f} - Acc: {100.*correct/total:.2f}%")

        # Checkpoint after every epoch; the file is overwritten, so the last epoch wins.
        torch.save(model.state_dict(), weight_path)

    # Leave the model in inference mode (disables the head's Dropout) for later cells.
    model.eval()

In [None]:
# SCL without training: rebuild the model with the same MLP head and load the
# state_dict (encoder + head) previously saved by the head-training cell.
if model_version == 'scl' and not train_flag:
    model = get_model(architecture=architecture, num_classes=num_classes)
    in_features = model.fc.in_features
    hidden_dim = 256

    # Must mirror the head built in the training cell, otherwise keys/shapes differ.
    model.fc = nn.Sequential(
        nn.Linear(in_features, hidden_dim),
        nn.ReLU(),
        nn.Dropout(0.2),
        nn.Linear(hidden_dim, num_classes)
    )

    # NOTE: weights_only=False unpickles arbitrary objects -- only load trusted files.
    loaded_weights = torch.load(model_path, map_location=device, weights_only=False)
    # strict=False silently skips mismatching keys; report them so a wrong or stale
    # file does not leave parts of the model randomly initialised unnoticed.
    incompatible = model.load_state_dict(loaded_weights, strict=False)
    if incompatible.missing_keys or incompatible.unexpected_keys:
        print(f"WARNING: {len(incompatible.missing_keys)} missing and "
              f"{len(incompatible.unexpected_keys)} unexpected keys when loading {model_path}")
    model.to(device)
    model.eval()

# XAI

In [None]:
!pip install grad-cam



In [None]:
from pytorch_grad_cam import GradCAM, EigenCAM, HiResCAM, AblationCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image

# Quantus

In [None]:
!pip install quantus



In [None]:
import quantus
import pandas as pd
import os
import matplotlib.pyplot as plt
from pytorch_grad_cam.utils.image import show_cam_on_image
import matplotlib

# Switch to the non-interactive Agg backend: figures are written with savefig instead
# of being rendered inline, which keeps long plotting loops from hanging.
matplotlib.use(backend='Agg')

In [None]:
import copy

# Deep copy of the classifier with every parameter trainable again (the SCL head-training
# cell froze the encoder), presumably so CAM methods can back-propagate through it.
# The original `model` is left untouched.
unfrozen_model = copy.deepcopy(model)
unfrozen_model.requires_grad_(True)

unfrozen_model.to(device)

ResNet18(
  (encoder): ResNet18Encoder(
    (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 

## Pixel flipping

In [None]:
def xai_scores_pf(model, test_loader, device, res_path="results", plot_flag=False,
                  max_batches=None,
                  plot_dir="/content/drive/MyDrive/clxai/results_faith/test/pf",
                  tag=None):
    """Score GradCAM / EigenCAM / AblationCAM explanations with Quantus pixel flipping.

    For each test image, the three CAM methods explain the model's *predicted* class.
    ``quantus.PixelFlipping`` (black baseline, 32 features per step) then produces a
    curve, which is summarised by its trapezoidal area divided by the x-range.

    Parameters
    ----------
    model : torch.nn.Module
        Classifier exposing ``model.encoder.layer4`` (its last block is the CAM target layer).
    test_loader : iterable
        Yields ``(images, labels)`` batches.
    device : torch.device
        Device the images are moved to.
    res_path : str
        Directory created on entry (kept for backward compatibility).
    plot_flag : bool
        If True, save one 1x4 figure per sample (input + three CAM overlays).
    max_batches : int or None
        Process at most this many batches; ``None`` (default) processes the whole loader.
        The previous version returned after the first batch because of an indentation
        slip; use ``max_batches=1`` to reproduce that quick run.
    plot_dir : str
        Directory for the per-sample figures; created if missing.
    tag : str or None
        Figure filename prefix; defaults to the notebook-level ``model_version``.

    Returns
    -------
    pandas.DataFrame
        One row per sample with ``idx``, ``true``, ``pred`` and ``<method>_PF_AUC``.
    """
    import itertools

    # np.trapz is deprecated in NumPy 2.x in favour of np.trapezoid.
    trapezoid = getattr(np, "trapezoid", None) or np.trapz

    model.eval()
    os.makedirs(res_path, exist_ok=True)
    if plot_flag:
        # savefig does not create missing directories.
        os.makedirs(plot_dir, exist_ok=True)
        if tag is None:
            tag = model_version  # notebook-level global, as in the original code
    target_layers = [model.encoder.layer4[-1]]

    pixel_flipping_metric = quantus.PixelFlipping(
        perturb_baseline="black",
        features_in_step=32,
        disable_warnings=True,
    )

    all_scores = []
    global_idx = 0

    batches = test_loader if max_batches is None else itertools.islice(test_loader, max_batches)
    for images, labels in batches:
        images, labels = images.to(device), labels.to(device)

        with torch.no_grad():
            preds = model(images).argmax(dim=1)

        for i in range(images.size(0)):
            p_idx = preds[i].item()
            g_idx = labels[i].item()
            img_tensor = images[i:i+1]

            record = {"idx": global_idx, "true": g_idx, "pred": p_idx}

            # Fresh CAM objects per sample so no hooks/state carry over between samples.
            curr_methods = {
                "GradCAM": GradCAM(model=model, target_layers=target_layers),
                "EigenCAM": EigenCAM(model=model, target_layers=target_layers),
                "AblationCAM": AblationCAM(model=model, target_layers=target_layers)
            }

            if plot_flag:
                fig, axes = plt.subplots(1, 4, figsize=(18, 5))
                img_np = images[i].permute(1, 2, 0).cpu().numpy()
                img_np = (img_np - img_np.min()) / (img_np.max() - img_np.min() + 1e-8)
                axes[0].imshow(img_np)
                axes[0].axis('off')

            for m_idx, (name, cam_obj) in enumerate(curr_methods.items()):
                # The caller may be inside no_grad; CAMs need gradients.
                with torch.enable_grad():
                    grayscale_cam = cam_obj(input_tensor=img_tensor,
                                            targets=[ClassifierOutputTarget(p_idx)])[0, :]

                # NOTE(review): the CAM explains the predicted class (p_idx) but pixel
                # flipping tracks the ground-truth class (g_idx); for misclassified
                # samples these differ -- confirm this is intended.
                pf_score = pixel_flipping_metric(
                    model=model,
                    x_batch=img_tensor.detach().cpu().numpy(),
                    y_batch=np.array([g_idx]),
                    a_batch=grayscale_cam[np.newaxis, ...],
                    device=device
                )[0]

                # Area under the curve (unit spacing) divided by the x-range; the max()
                # guards against a single-point curve.
                auc_val = trapezoid(pf_score) / max(len(pf_score) - 1, 1)
                record[f"{name}_PF_AUC"] = auc_val

                if plot_flag:
                    viz = show_cam_on_image(img_np, grayscale_cam, use_rgb=True)
                    axes[m_idx+1].imshow(viz)
                    axes[m_idx+1].set_title(f"{name}\nAUC: {auc_val:.3f}")
                    axes[m_idx+1].axis('off')

                # Remove the forward/backward hooks the CAM object registered on the model.
                if hasattr(cam_obj, 'activations_and_grads'):
                    cam_obj.activations_and_grads.release()

            if plot_flag:
                fig.savefig(os.path.join(plot_dir, f"{tag}_sample_{global_idx}.png"))
                # close() is sufficient; a following plt.clf() would create a new empty
                # figure on every sample and leak figures.
                plt.close(fig)

            all_scores.append(record)
            global_idx += 1

            del curr_methods

        # End-of-batch cleanup.
        del images, labels, preds
        torch.cuda.empty_cache()

    return pd.DataFrame(all_scores)

In [None]:
# df_scores_pf = xai_scores_pf(model=unfrozen_model,
#            test_loader=test_loader,
#            device=device,
#            plot_flag = True) #4 min

In [None]:
# df_scores_pf.to_csv(f"/content/drive/MyDrive/clxai/results_faith/test_2/{model_version}_pf_xai_scores.csv", index=False)

## IROF

### IROF metric class (collapsed — no need to open)


In [None]:
import sys
from typing import Any, Callable, Dict, List, Optional

import numpy as np
import torch

from quantus.functions.perturb_func import baseline_replacement_by_indices
from quantus.helpers import asserts, utils, warn
from quantus.helpers.enums import (
    DataType,
    EvaluationCategory,
    ModelType,
    ScoreDirection,
)
from quantus.helpers.model.model_interface import ModelInterface
from quantus.helpers.perturbation_utils import make_perturb_func
from quantus.metrics.base import Metric

if sys.version_info >= (3, 8):
    from typing import final
else:
    from typing_extensions import final


@final
class IROF(Metric[List[float]]):
    """
    Implementation of IROF (Iterative Removal of Features) by Rieger et al., 2020.

    The metric computes the area over the curve per class for sorted mean importances
    of feature segments (superpixels) as they are iteratively removed (and prediction scores are collected),
    averaged over several test samples.

    This copy extends the Quantus implementation with two options:
        - return_scores=True: evaluate_instance returns ``(aoc, curve)`` instead of only
          ``aoc``, so the perturbation curves can be inspected or plotted.
        - distance_based=True: instead of the score of class ``y``, each step records the
          Euclidean distance between the model embedding and the centroid of the predicted
          class (centroids come from a KNN classification head). The ratios are inverted
          before the AOC so that, as with probabilities, larger values are better.

    Assumptions:
        - The original metric definition relies on image-segmentation functionality. Therefore, only apply the
        metric to 3-dimensional (image) data. To extend the applicability to other data domains,
        adjustments to the current implementation might be necessary.
        - distance_based=True requires the wrapped torch module to expose ``embedding`` and
        ``classification_head.knn`` (see distance_to_centroid and compute_centroids).

    References:
        1) Laura Rieger and Lars Kai Hansen. "Irof: a low resource evaluation metric for
        explanation methods." arXiv preprint arXiv:2003.08747 (2020).

    Attributes:
        - name: The name of the metric.
        - data_applicability: The data types that the metric implementation currently supports.
        - model_applicability: The model types that this metric can work with.
        - score_direction: How to interpret the scores, whether higher/ lower values are considered better.
        - evaluation_category: What property/ explanation quality that this metric measures.
    """

    name = "IROF"
    data_applicability = {DataType.IMAGE}
    model_applicability = {ModelType.TORCH, ModelType.TF}
    score_direction = ScoreDirection.HIGHER
    evaluation_category = EvaluationCategory.FAITHFULNESS

    def __init__(
        self,
        segmentation_method: str = "slic",
        abs: bool = False,
        normalise: bool = True,
        normalise_func: Optional[Callable[[np.ndarray], np.ndarray]] = None,
        normalise_func_kwargs: Optional[Dict[str, Any]] = None,
        perturb_func: Optional[Callable] = None,
        perturb_baseline: str = "mean",
        perturb_func_kwargs: Optional[Dict[str, Any]] = None,
        return_aggregate: bool = True,
        aggregate_func: Optional[Callable] = None,
        default_plot_func: Optional[Callable] = None,
        disable_warnings: bool = False,
        display_progressbar: bool = False,
        return_scores: bool = False,  # new: also return the perturbation curve (for plotting)
        distance_based: bool = False,  # new: score by distance to the predicted-class centroid
        **kwargs,
    ):
        """
        Parameters
        ----------
        segmentation_method: string
            Image segmentation method:'slic' or 'felzenszwalb', default="slic".
        abs: boolean
            Indicates whether absolute operation is applied on the attribution, default=False.
        normalise: boolean
            Indicates whether normalise operation is applied on the attribution, default=True.
        normalise_func: callable
            Attribution normalisation function applied in case normalise=True.
            If normalise_func=None, the default value is used, default=normalise_by_max.
        normalise_func_kwargs: dict
            Keyword arguments to be passed to normalise_func on call, default={}.
        perturb_func: callable
            Input perturbation function. If None, the default value is used,
            default=baseline_replacement_by_indices.
        perturb_baseline: string
            Indicates the type of baseline: "mean", "random", "uniform", "black" or "white",
            default="mean".
        perturb_func_kwargs: dict
            Keyword arguments to be passed to perturb_func, default={}.
        return_aggregate: boolean
            Indicates if an aggregated score should be computed over all instances.
        aggregate_func: callable
            Callable that aggregates the scores given an evaluation call.
        default_plot_func: callable
            Callable that plots the metrics result.
        disable_warnings: boolean
            Indicates whether the warnings are printed, default=False.
        display_progressbar: boolean
            Indicates whether a tqdm-progress-bar is printed, default=False.
        return_scores: boolean
            If True, each instance yields ``(aoc, curve)`` instead of ``aoc``, default=False.
        distance_based: boolean
            If True, track the distance to the predicted-class centroid instead of the
            score of class ``y``, default=False.
        kwargs: optional
            Keyword arguments.
        """
        super().__init__(
            abs=abs,
            normalise=normalise,
            normalise_func=normalise_func,
            normalise_func_kwargs=normalise_func_kwargs,
            return_aggregate=return_aggregate,
            aggregate_func=aggregate_func,
            default_plot_func=default_plot_func,
            display_progressbar=display_progressbar,
            disable_warnings=disable_warnings,
            **kwargs,
        )

        if perturb_func is None:
            perturb_func = baseline_replacement_by_indices

        # Save metric-specific attributes.
        self.return_scores = return_scores
        self.distance_based = distance_based
        self.segmentation_method = segmentation_method
        self.nr_channels = None
        self.perturb_func = make_perturb_func(
            perturb_func, perturb_func_kwargs, perturb_baseline=perturb_baseline
        )

        # Asserts and warnings.
        if not self.disable_warnings:
            warn.warn_parameterisation(
                metric_name=self.__class__.__name__,
                sensitive_params=(
                    "baseline value 'perturb_baseline' and the method to segment "
                    "the image 'segmentation_method' (including all its associated"
                    " hyperparameters), also, IROF only works with image data"
                ),
                data_domain_applicability=(
                    f"Also, the current implementation only works for 3-dimensional (image) data."
                ),
                citation=(
                    "Rieger, Laura, and Lars Kai Hansen. 'Irof: a low resource evaluation metric "
                    "for explanation methods.' arXiv preprint arXiv:2003.08747 (2020)"
                ),
            )

    def __call__(
        self,
        model,
        x_batch: np.array,
        y_batch: np.array,
        a_batch: Optional[np.ndarray] = None,
        s_batch: Optional[np.ndarray] = None,
        channel_first: Optional[bool] = None,
        explain_func: Optional[Callable] = None,
        explain_func_kwargs: Optional[Dict] = None,
        model_predict_kwargs: Optional[Dict] = None,
        softmax: Optional[bool] = True,
        device: Optional[str] = None,
        batch_size: int = 64,
        **kwargs,
    ) -> List[float]:
        """
        This implementation represents the main logic of the metric and makes the class object callable.
        It completes instance-wise evaluation of explanations (a_batch) with respect to input data (x_batch),
        output labels (y_batch) and a torch or tensorflow model (model).

        Delegates to the base class, which preprocesses the inputs, runs evaluate_batch()
        (and therefore evaluate_instance() on each instance), saves the results to
        evaluation_scores and then calls custom_postprocess().

        Parameters
        ----------
        model: torch.nn.Module, tf.keras.Model
            A torch or tensorflow model that is subject to explanation.
        x_batch: np.ndarray
            A np.ndarray which contains the input data that are explained.
        y_batch: np.ndarray
            A np.ndarray which contains the output labels that are explained.
        a_batch: np.ndarray, optional
            A np.ndarray which contains pre-computed attributions i.e., explanations.
        s_batch: np.ndarray, optional
            A np.ndarray which contains segmentation masks that matches the input.
        channel_first: boolean, optional
            Indicates of the image dimensions are channel first, or channel last.
            Inferred from the input shape if None.
        explain_func: callable
            Callable generating attributions.
        explain_func_kwargs: dict, optional
            Keyword arguments to be passed to explain_func on call.
        model_predict_kwargs: dict, optional
            Keyword arguments to be passed to the model's predict method.
        softmax: boolean
            Indicates whether to use softmax probabilities or logits in model prediction.
            This is used for this __call__ only and won't be saved as attribute. If None, self.softmax is used.
        device: string
            Indicates the device on which a torch.Tensor is or will be allocated: "cpu" or "gpu".
        batch_size: int
            Batch size forwarded to the base class, default=64.
        kwargs: optional
            Keyword arguments.

        Returns
        -------
        evaluation_scores: list
            a list of Any with the evaluation scores of the concerned batch.

        Examples:
        --------
            # Minimal imports.
            >> import quantus
            >> from quantus import LeNet
            >> import torch

            # Enable GPU.
            >> device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

            # Load a pre-trained LeNet classification model (architecture at quantus/helpers/models).
            >> model = LeNet()
            >> model.load_state_dict(torch.load("tutorials/assets/pytests/mnist_model"))

            # Load MNIST datasets and make loaders.
            >> test_set = torchvision.datasets.MNIST(root='./sample_data', download=True)
            >> test_loader = torch.utils.data.DataLoader(test_set, batch_size=24)

            # Load a batch of inputs and outputs to use for XAI evaluation.
            >> x_batch, y_batch = next(iter(test_loader))
            >> x_batch, y_batch = x_batch.cpu().numpy(), y_batch.cpu().numpy()

            # Generate Saliency attributions of the test set batch of the test set.
            >> a_batch_saliency = Saliency(model).attribute(inputs=x_batch, target=y_batch, abs=True).sum(axis=1)
            >> a_batch_saliency = a_batch_saliency.cpu().numpy()

            # Initialise the metric and evaluate explanations by calling the metric instance.
            >> metric = IROF(abs=True, normalise=False)
            >> scores = metric(model=model, x_batch=x_batch, y_batch=y_batch, a_batch=a_batch_saliency)
        """
        return super().__call__(
            model=model,
            x_batch=x_batch,
            y_batch=y_batch,
            a_batch=a_batch,
            s_batch=s_batch,
            custom_batch=None,
            channel_first=channel_first,
            explain_func=explain_func,
            explain_func_kwargs=explain_func_kwargs,
            softmax=softmax,
            device=device,
            model_predict_kwargs=model_predict_kwargs,
            batch_size=batch_size,
            **kwargs,
        )

    def distance_to_centroid(self,
            model: ModelInterface,
            all_centroids: np.ndarray,
            y_pred: np.ndarray,
            x_input: np.ndarray) -> np.ndarray:
        """
        Euclidean distance between the model's current embedding and the centroid of the predicted class.

        The embedding is read from ``model.model.embedding``, i.e. whatever the most recent
        forward pass left there. Callers therefore run ``model.predict`` on ``x_input``
        immediately before calling this method. NOTE(review): assumes the wrapped module
        stores its embedding on forward -- confirm against the model definition.

        Parameters
        ----------
        model : ModelInterface
            Wrapped model whose torch module exposes ``embedding``.
        all_centroids : np.ndarray
            Per-class centroids, indexable by class index (see compute_centroids).
        y_pred : np.ndarray or torch.Tensor
            Model output for the input; its argmax is taken as the predicted class.
        x_input : np.ndarray
            The input the embedding belongs to. Not used in the computation; kept for
            interface compatibility.

        Returns
        -------
        np.ndarray
            One distance per row of the (batch-first, flattened) embedding.
        """
        y_pred_embedding = model.model.embedding
        if isinstance(y_pred_embedding, torch.Tensor):
            y_pred_embedding = y_pred_embedding.detach().cpu().numpy()
        # Flatten everything except the leading (batch) dimension.
        y_pred_embedding = np.reshape(y_pred_embedding, (y_pred_embedding.shape[0], -1))

        if isinstance(y_pred, torch.Tensor):
            y_pred = y_pred.detach().cpu().numpy()

        # argmax over the flattened output: intended for a single-instance prediction.
        pred_class = np.argmax(y_pred)
        pred_centroid = all_centroids[pred_class]
        return np.linalg.norm(y_pred_embedding - pred_centroid, axis=1)

    def compute_centroids(self,
            model: ModelInterface) -> np.ndarray:
        """
        Compute one centroid per class from the samples stored in the KNN classification head.

        Uses scikit-learn private attributes of the KNN (``_fit_X``: stored samples, ``_y``:
        their labels). NOTE(review): assumes ``_y`` holds class indices 0..n_classes-1,
        matching the loop below -- confirm for the scikit-learn version in use.

        Parameters
        ----------
        model : ModelInterface
            Wrapped model whose torch module exposes ``classification_head.knn``.

        Returns
        -------
        np.ndarray
            Row ``i`` is the mean of the stored samples labelled ``i``.
        """
        knn = model.model.classification_head.knn
        n_classes = len(knn.classes_)
        centroids = [np.mean(knn._fit_X[knn._y == i], axis=0) for i in range(n_classes)]
        return np.asarray(centroids)

    def evaluate_instance(
        self,
        model: ModelInterface,
        x: np.ndarray,
        y: np.ndarray,
        a: np.ndarray,
    ) -> float:
        """
        Evaluate instance gets model and data for a single instance as input and returns the evaluation result.

        Superpixels are removed cumulatively, in descending order of mean attribution. After
        each removal the model is queried again, and the score is stored as a ratio to the
        score on the unperturbed input.

        Parameters
        ----------
        model: ModelInterface
            A ModelInterface that is subject to explanation.
        x: np.ndarray
            The input to be evaluated on an instance-basis.
        y: np.ndarray
            The output to be evaluated on an instance-basis.
        a: np.ndarray
            The explanation to be evaluated on an instance-basis.
        Returns
        -------
        float
            The area over the curve (AOC). If ``return_scores`` is True, a tuple
            ``(aoc, curve)`` is returned instead, where ``curve`` holds the normalised
            value after each removal step.
        """
        # Predict on x.
        x_input = model.shape_input(x, x.shape, channel_first=True)

        # Reference value on the unperturbed input.
        # Author's note: the distance is meant to behave "similar to a probability". It
        # might be more interesting with a (k+1)-class model, since a fully erased image
        # makes no sense classified as a cat or a dog -- open question: how to train it?
        if self.distance_based:
            y_pred = model.predict(x_input)
            # Computed once per instance and reused for every perturbation step below.
            centroid_array = self.compute_centroids(model=model)
            y_pred = self.distance_to_centroid(
                model=model, all_centroids=centroid_array, y_pred=y_pred, x_input=x_input
            )
        else:
            # .item() instead of float(): float() on a size-1 array is deprecated in NumPy.
            y_pred = model.predict(x_input)[:, y].item()

        # Segment image.
        segments = utils.get_superpixel_segments(
            img=np.moveaxis(x, 0, -1).astype("double"),
            segmentation_method=self.segmentation_method,
        )
        nr_segments = len(np.unique(segments))
        asserts.assert_nr_segments(nr_segments=nr_segments)

        # Mean attribution of each segment (assumes segment labels are 0..nr_segments-1).
        att_segs = np.zeros(nr_segments)
        for s in range(nr_segments):
            att_segs[s] = np.mean(a[:, segments == s])

        # Sort segments based on the mean attribution (descending order).
        s_indices = np.argsort(-att_segs)

        preds = []
        x_prev_perturbed = x

        for s_ix in s_indices:
            # Perturb this segment on top of all previously perturbed ones.
            a_ix = np.nonzero((segments == s_ix).flatten())[0]

            x_perturbed = self.perturb_func(
                arr=x_prev_perturbed,
                indices=a_ix,
                indexed_axes=self.a_axes,
            )
            warn.warn_perturbation_caused_no_change(
                x=x_prev_perturbed, x_perturbed=x_perturbed
            )

            # Predict on perturbed input x.
            x_input = model.shape_input(x_perturbed, x.shape, channel_first=True)

            if self.distance_based:
                y_pred_perturb = model.predict(x_input)
                y_pred_perturb = self.distance_to_centroid(
                    model=model, all_centroids=centroid_array, y_pred=y_pred_perturb, x_input=x_input
                )
            else:
                y_pred_perturb = model.predict(x_input)[:, y].item()

            # Normalise by the unperturbed value (the ratio can exceed 1 if the value rises).
            preds.append(np.asarray(y_pred_perturb / y_pred).item())
            x_prev_perturbed = x_perturbed

        if self.distance_based:
            # A larger distance is worse, so invert the ratios: the AOC below then behaves
            # as in the probability case, where higher is better.
            preds = 1.0 / (np.array(preds) + 1e-8)

        # Calculate the area over the curve (AOC) score.
        aoc = len(preds) - utils.calculate_auc(np.array(preds))
        if self.return_scores:
            return aoc, preds
        return aoc

    def custom_preprocess(
        self,
        x_batch: np.ndarray,
        **kwargs,
    ) -> None:
        """
        Implementation of custom_preprocess_batch: records the number of input channels.

        Parameters
        ----------
        x_batch: np.ndarray
            A np.ndarray which contains the input data that are explained
            (channel axis at index 1).
        kwargs:
            Other batch data passed by the base class (model, y_batch, a_batch, s_batch,
            custom_batch); unused here.

        Returns
        -------
        None
        """
        # Infer number of input channels.
        self.nr_channels = x_batch.shape[1]

    @property
    def get_aoc_score(self):
        """Mean area over the curve (AOC) over all evaluated instances.

        Works for both plain AOC scores and the ``(aoc, curve)`` tuples produced when
        ``return_scores=True``. np.mean cannot handle those ragged tuples directly.
        """
        scores = [s[0] if isinstance(s, tuple) else s for s in self.evaluation_scores]
        return np.mean(scores)

    def evaluate_batch(
        self,
        model: ModelInterface,
        x_batch: np.ndarray,
        y_batch: np.ndarray,
        a_batch: np.ndarray,
        **kwargs,
    ) -> List[float]:
        """
        This method performs XAI evaluation on a single batch of explanations.
        For more information on the specific logic, we refer the metric’s initialisation docstring.

        Parameters
        ----------
        model: ModelInterface
            A ModelInterface that is subject to explanation.
        x_batch: np.ndarray
            The input to be evaluated on a batch-basis.
        y_batch: np.ndarray
            The output to be evaluated on a batch-basis.
        a_batch: np.ndarray
            The explanation to be evaluated on a batch-basis.
        kwargs:
            Unused.

        Returns
        -------
        scores_batch:
            The evaluation results (one entry per instance; tuples if return_scores=True).
        """
        return [
            self.evaluate_instance(model=model, x=x, y=y, a=a)
            for x, y, a in zip(x_batch, y_batch, a_batch)
        ]

    def custom_postprocess(self, **kwargs):
        """
        When return_scores is True, return the per-instance ``(aoc, curve)`` results, passed
        through default_plot_func if one was given. Otherwise return None.

        NOTE(review): confirm whether the quantus base class uses this return value; the
        results remain available in self.evaluation_scores either way.
        """
        if not self.return_scores:
            return None
        if self.default_plot_func is not None:
            return self.default_plot_func(self.evaluation_scores)
        return self.evaluation_scores

### Application: IROF scores for the CAM explanations

In [None]:
def xai_scores_irof(model, test_loader, device, plot_flag=False,
                    segmentation_method="slic", perturb_baseline="black", distance_based=False,
                    max_batches=None,
                    save_dir="/content/drive/MyDrive/clxai/results_faith/test/irof"):
    """Score GradCAM, EigenCAM and AblationCAM explanations with the IROF metric.

    For every test image, a CAM is computed for the model's predicted class and
    evaluated with ``IROF`` (``return_scores=True``, ``return_aggregate=False``),
    scored against the ground-truth label.

    Parameters
    ----------
    model : torch.nn.Module
        Classifier exposing ``model.encoder.layer4``; its last block is the CAM
        target layer.
    test_loader : iterable of (images, labels)
        Batches to evaluate.
    device : torch.device
        Device on which images are evaluated.
    plot_flag : bool
        If True, save one figure per sample (input + three CAM overlays) into
        ``save_dir``.
    segmentation_method, perturb_baseline, distance_based
        Forwarded to ``IROF``.
    max_batches : int or None
        Stop after this many batches. ``None`` (default) processes the whole
        loader. The previous version returned after the first batch; pass
        ``max_batches=1`` to reproduce it. With ``plot_flag=True`` one PNG is
        written per processed sample.
    save_dir : str
        Directory for the per-sample figures.

    Returns
    -------
    pandas.DataFrame
        One row per sample with ``idx``, ``true`` and ``pred``. For each CAM
        method it also has ``<name>_IROF_AOC`` (NaN when IROF failed) and, when
        the metric returned a curve, ``<name>_IROF_curve``.
    """
    import traceback

    model.eval()
    target_layers = [model.encoder.layer4[-1]]

    irof_metric = IROF(
        segmentation_method=segmentation_method,
        perturb_baseline=perturb_baseline,
        abs=False,
        normalise=True,
        disable_warnings=True,
        return_scores=True,  # get both the AOC and the curve
        distance_based=distance_based,
        return_aggregate=False,  # IMPORTANT: disable automatic aggregation
    )

    all_scores = []
    global_idx = 0

    for batch_idx, (images, labels) in enumerate(test_loader):
        if max_batches is not None and batch_idx >= max_batches:
            break
        images, labels = images.to(device), labels.to(device)

        with torch.no_grad():
            preds = model(images).argmax(dim=1)

        for i in range(images.size(0)):
            p_idx = preds[i].item()
            g_idx = labels[i].item()
            img_tensor = images[i:i+1]

            record = {"idx": global_idx, "true": g_idx, "pred": p_idx}

            # Fresh CAM objects per sample; their hooks are released after use below.
            curr_methods = {
                "GradCAM": GradCAM(model=model, target_layers=target_layers),
                "EigenCAM": EigenCAM(model=model, target_layers=target_layers),
                "AblationCAM": AblationCAM(model=model, target_layers=target_layers)
            }

            if plot_flag:
                fig, axes = plt.subplots(1, 4, figsize=(18, 5))
                img_np = images[i].permute(1, 2, 0).cpu().numpy()
                # Min-max rescale for display only.
                img_np = (img_np - img_np.min()) / (img_np.max() - img_np.min() + 1e-8)
                axes[0].imshow(img_np)
                axes[0].axis('off')
                axes[0].set_title("Original Image")

            for m_idx, (name, cam_obj) in enumerate(curr_methods.items()):
                # Saliency map for the predicted class.
                with torch.enable_grad():
                    grayscale_cam = cam_obj(
                        input_tensor=img_tensor,
                        targets=[ClassifierOutputTarget(p_idx)]
                    )[0, :]  # Shape: [H, W]

                # Build the IROF inputs.
                x_np = img_tensor.detach().cpu().numpy()  # [1, C, H, W]

                # Repeat the single-channel CAM over every input channel.
                a_np = np.repeat(grayscale_cam[np.newaxis, :, :], x_np.shape[1], axis=0)  # [C, H, W]
                a_np = a_np[np.newaxis, ...]  # [1, C, H, W]

                try:
                    # NOTE(review): the CAM explains p_idx, but IROF is scored against
                    # g_idx. The two agree only for correctly classified samples.
                    irof_result = irof_metric(
                        model=model,
                        x_batch=x_np,
                        y_batch=np.array([g_idx]),
                        a_batch=a_np,
                        device=device
                    )

                    # With return_scores=True and return_aggregate=False the result is
                    # expected to be a list like [(aoc, curve)]; a bare AOC is also accepted.
                    if isinstance(irof_result, list) and len(irof_result) > 0:
                        result_item = irof_result[0]

                        if isinstance(result_item, tuple) and len(result_item) == 2:
                            aoc_score, irof_curve = result_item
                            record[f"{name}_IROF_AOC"] = float(aoc_score)

                            # Store the curve as a plain list so it serialises to CSV.
                            if isinstance(irof_curve, np.ndarray):
                                record[f"{name}_IROF_curve"] = irof_curve.tolist()
                            else:
                                record[f"{name}_IROF_curve"] = list(irof_curve)
                        else:
                            # Fallback: AOC only.
                            aoc_score = float(result_item)
                            record[f"{name}_IROF_AOC"] = aoc_score
                    else:
                        print(f"Formato risultato inaspettato per {name}, sample {global_idx}")
                        record[f"{name}_IROF_AOC"] = np.nan

                except Exception as e:
                    # Best-effort: log, mark as NaN and fall through. There is deliberately
                    # no `continue`, so the panel still shows "IROF: Error" and the CAM hooks
                    # below are still released.
                    print(f"Errore IROF per {name}, sample {global_idx}: {str(e)}")
                    traceback.print_exc()
                    record[f"{name}_IROF_AOC"] = np.nan

                if plot_flag:
                    viz = show_cam_on_image(img_np, grayscale_cam, use_rgb=True)
                    axes[m_idx+1].imshow(viz)

                    # Show the AOC in the panel title.
                    aoc_val = record.get(f"{name}_IROF_AOC", np.nan)
                    if not np.isnan(aoc_val):
                        axes[m_idx+1].set_title(f"{name}\nIROF AOC: {aoc_val:.3f}")
                    else:
                        axes[m_idx+1].set_title(f"{name}\nIROF: Error")
                    axes[m_idx+1].axis('off')

                # Remove the forward/backward hooks this CAM object registered.
                if hasattr(cam_obj, 'activations_and_grads'):
                    cam_obj.activations_and_grads.release()

            if plot_flag:
                plt.tight_layout()
                plt.savefig(f"{save_dir}/{model_version}_sample_{global_idx}.png", bbox_inches='tight', dpi=150)
                # close() is sufficient; a plt.clf() afterwards would create a new empty figure.
                plt.close(fig)

            all_scores.append(record)
            global_idx += 1

            del curr_methods
            torch.cuda.empty_cache()

        # Per-batch cleanup.
        del images, labels, preds
        torch.cuda.empty_cache()

    # Return only after the batch loop has finished.
    return pd.DataFrame(all_scores)

In [None]:
# irof_scores = xai_scores_irof(
#     model = unfrozen_model,
#     test_loader = test_loader,
#     device = device,
#     plot_flag = True,
# )

In [None]:
# irof_scores.to_csv(f"/content/drive/MyDrive/clxai/results_faith/test_2/{model_version}_irof_xai_scores.csv", index=False)

## sparseness

In [None]:
def xai_scores_sparsity(model, test_loader, device, plot_flag=False, max_batches=None,
                        save_dir="/content/drive/MyDrive/clxai/results_complex/test/sparseness"):
    """Score GradCAM, EigenCAM and AblationCAM explanations with ``quantus.Sparseness``.

    For each test image, a CAM is computed for the predicted class and passed to
    the Sparseness metric together with the ground-truth label.

    Parameters
    ----------
    model : torch.nn.Module
        Classifier exposing ``model.encoder.layer4``; its last block is the CAM
        target layer.
    test_loader : iterable of (images, labels)
        Batches to evaluate.
    device : torch.device
        Device on which images are evaluated.
    plot_flag : bool
        If True, save one figure per sample into ``save_dir``.
    max_batches : int or None
        Stop after this many batches. ``None`` (default) processes the whole
        loader. The previous version returned after the first batch; pass
        ``max_batches=1`` to reproduce it.
    save_dir : str
        Directory for the per-sample figures.

    Returns
    -------
    pandas.DataFrame
        One row per sample with ``idx``, ``true``, ``pred`` and
        ``<name>_Sparsity`` for each CAM method.
    """
    model.eval()
    target_layers = [model.encoder.layer4[-1]]

    sparsity_metric = quantus.Sparseness(
        disable_warnings=True,
    )

    all_scores = []
    global_idx = 0

    for batch_idx, (images, labels) in enumerate(test_loader):
        if max_batches is not None and batch_idx >= max_batches:
            break
        images, labels = images.to(device), labels.to(device)

        with torch.no_grad():
            preds = model(images).argmax(dim=1)

        for i in range(images.size(0)):
            p_idx = preds[i].item()
            g_idx = labels[i].item()
            img_tensor = images[i:i+1]

            record = {"idx": global_idx, "true": g_idx, "pred": p_idx}

            curr_methods = {
                "GradCAM": GradCAM(model=model, target_layers=target_layers),
                "EigenCAM": EigenCAM(model=model, target_layers=target_layers),
                "AblationCAM": AblationCAM(model=model, target_layers=target_layers)
            }

            if plot_flag:
                fig, axes = plt.subplots(1, 4, figsize=(18, 5))
                img_np = images[i].permute(1, 2, 0).cpu().numpy()
                # Min-max rescale for display only.
                img_np = (img_np - img_np.min()) / (img_np.max() - img_np.min() + 1e-8)
                axes[0].imshow(img_np)
                axes[0].axis('off')

            for m_idx, (name, cam_obj) in enumerate(curr_methods.items()):
                # Saliency map for the predicted class.
                with torch.enable_grad():
                    grayscale_cam = cam_obj(input_tensor=img_tensor,
                                           targets=[ClassifierOutputTarget(p_idx)])[0, :]

                # Sparseness: a single float between 0 and 1 per explanation,
                # so take the only entry of the returned list.
                sparsity_score = sparsity_metric(
                    model=model,
                    x_batch=img_tensor.detach().cpu().numpy(),
                    y_batch=np.array([g_idx]),
                    a_batch=grayscale_cam[np.newaxis, ...],
                    device=device
                )[0]

                record[f"{name}_Sparsity"] = sparsity_score

                if plot_flag:
                    viz = show_cam_on_image(img_np, grayscale_cam, use_rgb=True)
                    axes[m_idx+1].imshow(viz)
                    axes[m_idx+1].set_title(f"{name}\nSparse: {sparsity_score:.3f}")
                    axes[m_idx+1].axis('off')

                # Remove the hooks this CAM object registered.
                if hasattr(cam_obj, 'activations_and_grads'):
                    cam_obj.activations_and_grads.release()

            if plot_flag:
                plt.savefig(f"{save_dir}/{model_version}_sample_{global_idx}_sparsity.png")
                plt.close(fig)

            all_scores.append(record)
            global_idx += 1
            del curr_methods

        del images, labels, preds
        torch.cuda.empty_cache()

    # Return only after the batch loop has finished.
    return pd.DataFrame(all_scores)

In [None]:
# sparse_scores = xai_scores_sparsity(
#     model = unfrozen_model,
#     test_loader = test_loader,
#     device = device,
#     plot_flag = True,
# )

In [None]:
# sparse_scores.to_csv(f"/content/drive/MyDrive/clxai/results_complex/test2/{model_version}_sparse_xai_scores.csv", index=False)

## complexity

In [None]:
def xai_scores_complexity(model, test_loader, device, plot_flag=False, max_batches=None,
                          save_dir="/content/drive/MyDrive/clxai/results_complex/test/complexity"):
    """Score GradCAM, EigenCAM and AblationCAM explanations with ``quantus.Complexity``.

    For each test image, a CAM is computed for the predicted class and passed to
    the Complexity metric together with the ground-truth label.

    Parameters
    ----------
    model : torch.nn.Module
        Classifier exposing ``model.encoder.layer4``; its last block is the CAM
        target layer.
    test_loader : iterable of (images, labels)
        Batches to evaluate.
    device : torch.device
        Device on which images are evaluated.
    plot_flag : bool
        If True, save one figure per sample into ``save_dir``.
    max_batches : int or None
        Stop after this many batches. ``None`` (default) processes the whole
        loader. The previous version returned after the first batch; pass
        ``max_batches=1`` to reproduce it.
    save_dir : str
        Directory for the per-sample figures.

    Returns
    -------
    pandas.DataFrame
        One row per sample with ``idx``, ``true``, ``pred`` and
        ``<name>_Complexity`` for each CAM method.
    """
    model.eval()
    target_layers = [model.encoder.layer4[-1]]

    complexity_metric = quantus.Complexity(
        disable_warnings=True,
    )

    all_scores = []
    global_idx = 0

    for batch_idx, (images, labels) in enumerate(test_loader):
        if max_batches is not None and batch_idx >= max_batches:
            break
        images, labels = images.to(device), labels.to(device)

        with torch.no_grad():
            preds = model(images).argmax(dim=1)

        for i in range(images.size(0)):
            p_idx = preds[i].item()
            g_idx = labels[i].item()
            img_tensor = images[i:i+1]

            record = {"idx": global_idx, "true": g_idx, "pred": p_idx}

            curr_methods = {
                "GradCAM": GradCAM(model=model, target_layers=target_layers),
                "EigenCAM": EigenCAM(model=model, target_layers=target_layers),
                "AblationCAM": AblationCAM(model=model, target_layers=target_layers)
            }

            if plot_flag:
                fig, axes = plt.subplots(1, 4, figsize=(18, 5))
                img_np = images[i].permute(1, 2, 0).cpu().numpy()
                # Min-max rescale for display only.
                img_np = (img_np - img_np.min()) / (img_np.max() - img_np.min() + 1e-8)
                axes[0].imshow(img_np)
                axes[0].axis('off')

            for m_idx, (name, cam_obj) in enumerate(curr_methods.items()):
                # Saliency map for the predicted class.
                with torch.enable_grad():
                    grayscale_cam = cam_obj(input_tensor=img_tensor,
                                           targets=[ClassifierOutputTarget(p_idx)])[0, :]

                # One score per explanation; take the only entry.
                complexity_score = complexity_metric(
                    model=model,
                    x_batch=img_tensor.detach().cpu().numpy(),
                    y_batch=np.array([g_idx]),
                    a_batch=grayscale_cam[np.newaxis, ...],
                    device=device
                )[0]

                record[f"{name}_Complexity"] = complexity_score

                if plot_flag:
                    viz = show_cam_on_image(img_np, grayscale_cam, use_rgb=True)
                    axes[m_idx+1].imshow(viz)
                    # Label the panel with the metric it actually shows.
                    axes[m_idx+1].set_title(f"{name}\nComplexity: {complexity_score:.3f}")
                    axes[m_idx+1].axis('off')

                # Remove the hooks this CAM object registered.
                if hasattr(cam_obj, 'activations_and_grads'):
                    cam_obj.activations_and_grads.release()

            if plot_flag:
                plt.savefig(f"{save_dir}/{model_version}_sample_{global_idx}_complexity.png")
                plt.close(fig)

            all_scores.append(record)
            global_idx += 1
            del curr_methods

        del images, labels, preds
        torch.cuda.empty_cache()

    # Return only after the batch loop has finished.
    return pd.DataFrame(all_scores)

In [None]:
# complex_scores = xai_scores_complexity(
#     model = unfrozen_model,
#     test_loader = test_loader,
#     device = device,
#     plot_flag = True,
# )

In [None]:
# complex_scores.to_csv(f"/content/drive/MyDrive/clxai/results_complex/test2/{model_version}_complex_xai_scores.csv", index=False)

## robustness (SSIM of CAMs under input noise)

In [None]:
from skimage.metrics import structural_similarity as ssim

def xai_scores_ssim_robustness(model, test_loader, device, plot_flag=False, noise_level=0.02,
                               clamp_range=(0, 1), max_batches=None,
                               save_dir="/content/drive/MyDrive/clxai/results_contrast/test/robustness"):
    """Measure how stable each CAM is under small Gaussian input noise (SSIM).

    For every sample, Gaussian noise with standard deviation ``noise_level`` is
    added to the input. Both the clean and the noisy input are explained for the
    clean prediction, and the two maps are compared with SSIM
    (``data_range=1.0``). Higher means more robust.

    Parameters
    ----------
    model : torch.nn.Module
        Classifier exposing ``model.encoder.layer4``; its last block is the CAM
        target layer.
    test_loader : iterable of (images, labels)
        Batches to evaluate.
    device : torch.device
        Device on which images are evaluated.
    plot_flag : bool
        If True, save one 2x4 figure per sample into ``save_dir``.
    noise_level : float
        Standard deviation of the additive Gaussian noise.
    clamp_range : tuple(float, float) or None
        ``(low, high)`` bounds applied to the noisy input; ``None`` disables
        clamping. Defaults to ``(0, 1)`` as before.
    max_batches : int or None
        Stop after this many batches. ``None`` (default) processes the whole
        loader. The previous version returned after the first batch; pass
        ``max_batches=1`` to reproduce it.
    save_dir : str
        Directory for the per-sample figures.

    Returns
    -------
    pandas.DataFrame
        One row per sample with ``idx``, ``true``, ``pred`` and
        ``<name>_SSIM_Robustness`` for each CAM method.
    """
    model.eval()
    target_layers = [model.encoder.layer4[-1]]
    all_scores = []
    global_idx = 0

    for batch_idx, (images, labels) in enumerate(test_loader):
        if max_batches is not None and batch_idx >= max_batches:
            break
        images, labels = images.to(device), labels.to(device)

        with torch.no_grad():
            preds = model(images).argmax(dim=1)

        for i in range(images.size(0)):

            # Re-seed with the sample's global position, so the noise for a given
            # sample does not depend on how many random numbers were drawn before it.
            torch.manual_seed(global_idx)
            np.random.seed(global_idx)

            p_idx = preds[i].item()
            g_idx = labels[i].item()
            img_tensor = images[i:i+1]

            perturbed_tensor = img_tensor + torch.randn_like(img_tensor) * noise_level
            if clamp_range is not None:
                # NOTE(review): the display code below min-max rescales the inputs, which
                # suggests they are normalised rather than in [0, 1]. If so, clamping to
                # (0, 1) clips far more than the noise changes; confirm against the loader.
                perturbed_tensor = torch.clamp(perturbed_tensor, clamp_range[0], clamp_range[1])

            record = {"idx": global_idx, "true": g_idx, "pred": p_idx}

            methods_list = [
                ("GradCAM", GradCAM(model=model, target_layers=target_layers)),
                ("EigenCAM", EigenCAM(model=model, target_layers=target_layers)),
                ("AblationCAM", AblationCAM(model=model, target_layers=target_layers))
            ]

            if plot_flag:
                fig, axes = plt.subplots(2, 4, figsize=(20, 10))

                # Min-max rescale both images for display only.
                img_np = images[i].permute(1, 2, 0).cpu().numpy()
                img_np = (img_np - img_np.min()) / (img_np.max() - img_np.min() + 1e-8)

                pert_np = perturbed_tensor[0].permute(1, 2, 0).cpu().numpy()
                pert_np = (pert_np - pert_np.min()) / (pert_np.max() - pert_np.min() + 1e-8)

                axes[0, 0].imshow(img_np)
                axes[0, 0].set_title("Original Input")
                axes[1, 0].imshow(pert_np)
                axes[1, 0].set_title(f"Perturbed (Noise: {noise_level})")
                for row in range(2):
                    axes[row, 0].axis('off')

            for m_idx, (name, cam_obj) in enumerate(methods_list):
                # Both maps target the prediction made on the clean input.
                with torch.enable_grad():
                    cam_orig = cam_obj(input_tensor=img_tensor,
                                       targets=[ClassifierOutputTarget(p_idx)])[0, :]

                    cam_pert = cam_obj(input_tensor=perturbed_tensor,
                                       targets=[ClassifierOutputTarget(p_idx)])[0, :]

                robustness_score = ssim(cam_orig, cam_pert, data_range=1.0)
                record[f"{name}_SSIM_Robustness"] = robustness_score

                if plot_flag:
                    viz_orig = show_cam_on_image(img_np, cam_orig, use_rgb=True)
                    axes[0, m_idx+1].imshow(viz_orig)
                    axes[0, m_idx+1].set_title(f"{name} (Orig)")

                    viz_pert = show_cam_on_image(pert_np, cam_pert, use_rgb=True)
                    axes[1, m_idx+1].imshow(viz_pert)
                    axes[1, m_idx+1].set_title(f"{name} (Perturbed)\nSSIM Robustness: {robustness_score:.3f}")

                    for row in range(2):
                        axes[row, m_idx+1].axis('off')

                # Remove the hooks this CAM object registered.
                if hasattr(cam_obj, 'activations_and_grads'):
                    cam_obj.activations_and_grads.release()

            if plot_flag:
                plt.tight_layout()
                plt.savefig(f"{save_dir}/{model_version}_ssim_robustness_sample_{global_idx}.png")
                plt.close(fig)

            all_scores.append(record)
            global_idx += 1

        del images, labels, preds
        torch.cuda.empty_cache()

    # Return only after the batch loop has finished.
    return pd.DataFrame(all_scores)

In [None]:
# robust_scores = xai_scores_ssim_robustness(
#     model = unfrozen_model,
#     test_loader = test_loader,
#     device = device,
#     plot_flag = True,
# )

In [None]:
# robust_scores.to_csv(f"/content/drive/MyDrive/clxai/results_contrast/test2/{model_version}_robust_xai_scores.csv", index=False)

## contrastivity (predicted-class vs. random contrast-class CAMs)

In [None]:
from skimage.metrics import structural_similarity as ssim

def xai_scores_ssim_contrastivity(model, test_loader, device, plot_flag=False, max_batches=None,
                                  seed=None,
                                  save_dir="/content/drive/MyDrive/clxai/results_contrast/test/contrastivity"):
    """Compare each CAM for the predicted class against a CAM for a random other class.

    For every sample, a contrast class is drawn uniformly among the classes that
    differ from the prediction. The factual map (predicted class) and the
    contrastive map (contrast class) are then compared with SSIM and with the
    L2 norm of their difference.

    Parameters
    ----------
    model : torch.nn.Module
        Classifier exposing ``model.encoder.layer4``; its last block is the CAM
        target layer.
    test_loader : iterable of (images, labels)
        Batches to evaluate.
    device : torch.device
        Device on which images are evaluated.
    plot_flag : bool
        If True, save one 2x4 figure per sample into ``save_dir``.
    max_batches : int or None
        Stop after this many batches. ``None`` (default) processes the whole
        loader. The previous version returned after the first batch; pass
        ``max_batches=1`` to reproduce it.
    seed : int or None
        When given, contrast classes are drawn from a dedicated CPU generator
        seeded with this value, so reruns pick the same classes. ``None`` keeps
        the previous behaviour and uses torch's global RNG.
    save_dir : str
        Directory for the per-sample figures.

    Returns
    -------
    pandas.DataFrame
        One row per sample with ``idx``, ``true``, ``pred``, ``contrast_class``
        and, for each CAM method, ``<name>_SSIM_Contrastivity`` and
        ``<name>_L2_Contrastivity``.
    """
    model.eval()
    target_layers = [model.encoder.layer4[-1]]
    all_scores = []
    global_idx = 0

    contrast_gen = torch.Generator().manual_seed(seed) if seed is not None else None

    for batch_idx, (images, labels) in enumerate(test_loader):
        if max_batches is not None and batch_idx >= max_batches:
            break
        images, labels = images.to(device), labels.to(device)

        with torch.no_grad():
            logits = model(images)
            preds = logits.argmax(dim=1)

            # Draw from C-1 values, then shift offsets >= pred up by one. This picks a
            # class uniformly among those different from the prediction.
            N, C = logits.shape
            if contrast_gen is None:
                offsets = torch.randint(0, C - 1, (N,), device=device)
            else:
                offsets = torch.randint(0, C - 1, (N,), generator=contrast_gen).to(device)
            contrast_preds = offsets + (offsets >= preds).long()

        for i in range(images.size(0)):
            p_idx = preds[i].item()
            c_idx = contrast_preds[i].item()  # contrast class
            g_idx = labels[i].item()
            img_tensor = images[i:i+1]

            record = {
                "idx": global_idx,
                "true": g_idx,
                "pred": p_idx,
                "contrast_class": c_idx
            }

            methods_list = [
                ("GradCAM", GradCAM(model=model, target_layers=target_layers)),
                ("EigenCAM", EigenCAM(model=model, target_layers=target_layers)),
                ("AblationCAM", AblationCAM(model=model, target_layers=target_layers))
            ]

            if plot_flag:
                fig, axes = plt.subplots(2, 4, figsize=(20, 10))
                img_np = images[i].permute(1, 2, 0).cpu().numpy()
                # Min-max rescale for display only.
                img_np = (img_np - img_np.min()) / (img_np.max() - img_np.min() + 1e-8)

                axes[0, 0].imshow(img_np)
                axes[0, 0].set_title(f"Input (Label: {g_idx})")
                axes[1, 0].axis('off')
                axes[0, 0].axis('off')

            for m_idx, (name, cam_obj) in enumerate(methods_list):
                with torch.enable_grad():
                    cam_factual = cam_obj(input_tensor=img_tensor,
                                          targets=[ClassifierOutputTarget(p_idx)])[0, :]

                    cam_contrastive = cam_obj(input_tensor=img_tensor,
                                              targets=[ClassifierOutputTarget(c_idx)])[0, :]

                # SSIM contrastivity: 1 = identical maps, 0 = unrelated, < 0 = opposite.
                contrast_ssim = ssim(cam_factual, cam_contrastive, data_range=1.0)

                # L2 contrastivity: Euclidean norm of the map difference.
                contrast_l2 = np.linalg.norm(cam_factual - cam_contrastive)

                record[f"{name}_SSIM_Contrastivity"] = contrast_ssim
                record[f"{name}_L2_Contrastivity"] = contrast_l2

                if plot_flag:
                    viz_f = show_cam_on_image(img_np, cam_factual, use_rgb=True)
                    axes[0, m_idx+1].imshow(viz_f)
                    axes[0, m_idx+1].set_title(f"{name}\nPred: {p_idx}")

                    viz_c = show_cam_on_image(img_np, cam_contrastive, use_rgb=True)
                    axes[1, m_idx+1].imshow(viz_c)
                    axes[1, m_idx+1].set_title(f"\nContrast: {c_idx}\nSSIM Diff: {contrast_ssim:.3f}\nL2 Norm: {contrast_l2:.3f}")

                    for row in range(2):
                        axes[row, m_idx+1].axis('off')

                # Remove the hooks this CAM object registered.
                if hasattr(cam_obj, 'activations_and_grads'):
                    cam_obj.activations_and_grads.release()

            if plot_flag:
                plt.tight_layout()
                plt.savefig(f"{save_dir}/{model_version}_contrastivity_{global_idx}.png")
                plt.close(fig)

            all_scores.append(record)
            global_idx += 1

        del images, labels, preds, contrast_preds
        torch.cuda.empty_cache()

    # Return only after the batch loop has finished.
    return pd.DataFrame(all_scores)

In [None]:
# Contrastivity scores for the currently loaded model: SSIM and L2 between the
# predicted-class CAM and the CAM of a random other class, one row per sample.
contrastivity_scores = xai_scores_ssim_contrastivity(
    model=unfrozen_model,
    test_loader=test_loader,
    device=device,
    plot_flag=True,
)

100%|██████████| 16/16 [00:00<00:00, 79.91it/s]
100%|██████████| 16/16 [00:00<00:00, 95.25it/s]
100%|██████████| 16/16 [00:00<00:00, 98.29it/s]
100%|██████████| 16/16 [00:00<00:00, 102.75it/s]
100%|██████████| 16/16 [00:00<00:00, 84.49it/s]
100%|██████████| 16/16 [00:00<00:00, 83.05it/s]
100%|██████████| 16/16 [00:00<00:00, 80.69it/s]
100%|██████████| 16/16 [00:00<00:00, 85.57it/s]
100%|██████████| 16/16 [00:00<00:00, 61.03it/s]
100%|██████████| 16/16 [00:00<00:00, 88.67it/s]
100%|██████████| 16/16 [00:00<00:00, 50.26it/s]
100%|██████████| 16/16 [00:00<00:00, 39.76it/s]
100%|██████████| 16/16 [00:00<00:00, 49.71it/s]
100%|██████████| 16/16 [00:00<00:00, 52.00it/s]
100%|██████████| 16/16 [00:00<00:00, 60.40it/s]
100%|██████████| 16/16 [00:00<00:00, 70.76it/s]
100%|██████████| 16/16 [00:00<00:00, 54.92it/s]
100%|██████████| 16/16 [00:00<00:00, 70.50it/s]
100%|██████████| 16/16 [00:00<00:00, 60.00it/s]
100%|██████████| 16/16 [00:00<00:00, 80.73it/s]
100%|██████████| 16/16 [00:00<00:00, 59

In [None]:
# Persist the per-sample contrastivity scores; the analysis section reads them back.
contrastivity_scores.to_csv(
    f"/content/drive/MyDrive/clxai/results_contrast/test2/{model_version}_contrastivity_xai_scores.csv",
    index=False,
)

In [None]:
# Presumably a deliberate stop: the exception halts "Run All" here, so the
# score-generation cells above are not chained into the analysis below (which
# only reads previously saved CSVs).
1 / 0

ZeroDivisionError: division by zero

# Analysis

## faithfulness

In [None]:
mode = "irof"  # faithfulness metric whose CSVs are loaded (used in the file names below)

In [None]:
# Per-model CSV paths for the selected faithfulness metric.
ce_path, scl_path = (
    f'/content/drive/MyDrive/clxai/results_faith/test_2/{tag}_{mode}_xai_scores.csv'
    for tag in ('ce', 'scl')
)

In [None]:
# Load both models' faithfulness scores (CE first, then SCL).
ce_df, scl_df = (pd.read_csv(csv_path) for csv_path in (ce_path, scl_path))

In [None]:
# Accuracy (%) of each model on the evaluated samples.
# The inner quotes differ from the f-string delimiters: reusing the same quote
# type inside an f-string only parses on Python >= 3.12 (PEP 701).
print(f"CE: {round((ce_df['true'] == ce_df['pred']).mean() * 100, 2)}")
print(f"SCL: {round((scl_df['true'] == scl_df['pred']).mean() * 100, 2)}")

In [None]:
# Keep only samples that BOTH models classify correctly, so row i of each filtered
# frame refers to the same test image (assumes both CSVs list samples in the same order).
mask1, mask2 = (
    ce_df['true'] == ce_df['pred'],
    scl_df['true'] == scl_df['pred'],
)
final_mask = mask1 & mask2
ce_filtered, scl_filtered = (
    frame[final_mask].copy().reset_index(drop=True)
    for frame in (ce_df, scl_df)
)

In [None]:
# SCL, GradCAM: mean IROF AOC (NaN entries from failed IROF runs are skipped).
scl_filtered["GradCAM_IROF_AOC"].mean(skipna=True)

In [None]:
# CE, GradCAM: mean IROF AOC (NaN entries from failed IROF runs are skipped).
ce_filtered["GradCAM_IROF_AOC"].mean(skipna=True)

In [None]:
# SCL, AblationCAM: mean IROF AOC (NaN entries from failed IROF runs are skipped).
scl_filtered["AblationCAM_IROF_AOC"].mean(skipna=True)

In [None]:
# CE, AblationCAM: mean IROF AOC (NaN entries from failed IROF runs are skipped).
ce_filtered["AblationCAM_IROF_AOC"].mean(skipna=True)

In [None]:
# SCL, EigenCAM: mean IROF AOC (NaN entries from failed IROF runs are skipped).
scl_filtered["EigenCAM_IROF_AOC"].mean(skipna=True)

In [None]:
# CE, EigenCAM: mean IROF AOC (NaN entries from failed IROF runs are skipped).
ce_filtered["EigenCAM_IROF_AOC"].mean(skipna=True)

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Per-sample CE vs SCL comparison, one figure per CAM method. Samples (those both
# models classify correctly) are ordered by the CE - SCL difference of `sort_metric`.
#
# The score column names depend on the loaded CSVs: xai_scores_irof writes
# "<CAM>_IROF_AOC". Any other mode falls back to the "<CAM>_PF_AUC" columns this
# cell used before, so the IROF frames no longer raise a KeyError.
score_suffix = "IROF_AOC" if mode == "irof" else "PF_AUC"
metrics = [f"{cam}_{score_suffix}" for cam in ("GradCAM", "EigenCAM", "AblationCAM")]
sort_metric = metrics[0]

diff_sort = ce_filtered[sort_metric] - scl_filtered[sort_metric]
sorted_indices = diff_sort.sort_values().index
# The filtered frames were reset_index'ed, so their index is positional; the
# original test-set sample id is kept in the 'idx' column.
sample_ids = ce_filtered.loc[sorted_indices, "idx"].values

for metric in metrics:
    mean_ce = ce_filtered[metric].mean()
    mean_scl = scl_filtered[metric].mean()

    # Reorder both models' values by the shared sort order.
    ce_values = ce_filtered.loc[sorted_indices, metric].values
    scl_values = scl_filtered.loc[sorted_indices, metric].values

    current_diff = ce_values - scl_values

    plt.figure(figsize=(25, 10))
    x_positions = np.arange(len(ce_values))

    # Green segment: CE score below SCL score; red otherwise.
    colors = ['#2ca02c' if d < 0 else '#d62728' for d in current_diff]

    plt.vlines(x_positions, ymin=ce_values, ymax=scl_values,
               colors=colors, linewidth=2, alpha=0.6, zorder=2)

    # Plot the points.
    plt.scatter(x_positions, ce_values, color='royalblue', alpha=1.0,
                label=f'ce_{mode}_df', s=50, zorder=4, edgecolors='white')
    plt.scatter(x_positions, scl_values, color='darkorange', alpha=1.0,
                label=f'scl_{mode}_df', s=50, zorder=4, edgecolors='white')

    # Vertical grid boundaries between samples.
    boundary_positions = x_positions - 0.5
    for x in boundary_positions[1:]:
        plt.axvline(x=x, color='gray', linestyle='-', linewidth=0.5, alpha=0.2, zorder=1)

    # Labels and ticks: tick labels are the original test-set sample ids.
    plt.xticks(ticks=x_positions, labels=sample_ids, rotation=90, fontsize=12)
    plt.xlim(-0.7, len(ce_values) - 0.3)
    plt.grid(axis='y', linestyle='--', alpha=0.3, zorder=1)

    plt.title(f"{metric}, mean CE: {mean_ce:0.3f} vs mean SCL: {mean_scl:0.3f}",
              fontsize=25, fontweight='bold', pad=20)
    plt.xlabel("Original Sample ID", fontsize=18)
    plt.ylabel(metric, fontsize=18)

    plt.legend(loc='upper left', frameon=True, facecolor='white', framealpha=1, fontsize=15)

    plt.tight_layout()
    plt.savefig(f"/content/drive/MyDrive/clxai/results_faith/test_2/{mode}_diff_line_plot_{metric}_via_{sort_metric}.png", bbox_inches='tight', dpi=300)
    plt.show()
    plt.close()

In [None]:
# import matplotlib.pyplot as plt
# import numpy as np

# # 1. Calculate the mean difference for each metric
# # (ce_df - scl_df) -> Negative means ce_df is lower/better
# metrics = ["GradCAM_PF_AUC", "EigenCAM_PF_AUC", "HiResCAM_PF_AUC"]
# mean_diffs = [(ce_df[m] - scl_df[m]).mean() for m in metrics]

# # 2. Setup colors based on improvement or regression
# # Green for negative (improvement), Red for positive (regression)
# bar_colors = ['#2ca02c' if diff < 0 else '#d62728' for diff in mean_diffs]

# plt.figure(figsize=(10, 7))

# # Create the bars
# bars = plt.bar(metrics, mean_diffs, color=bar_colors, edgecolor='black', alpha=0.8)

# # Add a horizontal line at 0 for reference
# plt.axhline(0, color='black', linewidth=1.5, linestyle='-')

# # 3. Add text labels on top/bottom of bars for the exact values
# for bar in bars:
#     height = bar.get_height()
#     plt.text(bar.get_x() + bar.get_width()/2., height,
#              f'{height:.4f}',
#              ha='center', va='bottom' if height > 0 else 'top',
#              fontsize=12, fontweight='bold')

# # Formatting
# plt.title("Average Difference in Faithfulness (ce_df - scl_df)\nLower is Better (Negative = Improvement)",
#           fontsize=16, fontweight='bold', pad=20)
# plt.ylabel("Mean AUC Difference", fontsize=14)
# plt.grid(axis='y', linestyle='--', alpha=0.3)

# # Add a slight buffer to the y-axis so labels don't get cut off
# y_max = max(abs(min(mean_diffs)), abs(max(mean_diffs))) * 1.3
# plt.ylim(-y_max, y_max)

# plt.tight_layout()
# plt.show()

## sparseness

In [None]:
# Sparseness scores for both models. Despite the *_complex_df names, these frames
# hold the sparse CSVs.
ce_complex_df, scl_complex_df = (
    pd.read_csv(f'/content/drive/MyDrive/clxai/results_complex/test2/{tag}_sparse_xai_scores.csv')
    for tag in ('ce', 'scl')
)

In [None]:
# Keep only samples both models classify correctly (assumes both CSVs list the
# samples in the same order, so row i is the same test image in each frame).
mask1, mask2 = (
    ce_complex_df['true'] == ce_complex_df['pred'],
    scl_complex_df['true'] == scl_complex_df['pred'],
)
final_mask = mask1 & mask2
ce_filtered, scl_filtered = (
    frame[final_mask].copy().reset_index(drop=True)
    for frame in (ce_complex_df, scl_complex_df)
)

In [None]:
# CE: column-wise means of the sparseness frame (idx/true/pred bookkeeping columns included).
ce_filtered.mean(axis=0)

In [None]:
# SCL: column-wise means of the sparseness frame (idx/true/pred bookkeeping columns included).
scl_filtered.mean(axis=0)

## complexity

In [None]:
# Complexity scores for both models (CE first, then SCL).
ce_complex_df, scl_complex_df = (
    pd.read_csv(f'/content/drive/MyDrive/clxai/results_complex/test2/{tag}_complex_xai_scores.csv')
    for tag in ('ce', 'scl')
)

In [None]:
# Keep only samples both models classify correctly (assumes both CSVs list the
# samples in the same order, so row i is the same test image in each frame).
mask1, mask2 = (
    ce_complex_df['true'] == ce_complex_df['pred'],
    scl_complex_df['true'] == scl_complex_df['pred'],
)
final_mask = mask1 & mask2
ce_filtered, scl_filtered = (
    frame[final_mask].copy().reset_index(drop=True)
    for frame in (ce_complex_df, scl_complex_df)
)

In [None]:
# CE: column-wise means of the complexity frame (idx/true/pred bookkeeping columns included).
ce_filtered.mean(axis=0)

In [None]:
# SCL: column-wise means of the complexity frame (idx/true/pred bookkeeping columns included).
scl_filtered.mean(axis=0)

## contrastivity

In [None]:
# Contrastivity scores for both models (CE first, then SCL).
ce_contrast_df, scl_contrast_df = (
    pd.read_csv(f'/content/drive/MyDrive/clxai/results_contrast/test2/{tag}_contrastivity_xai_scores.csv')
    for tag in ('ce', 'scl')
)

In [None]:
# Keep only samples both models classify correctly (assumes both CSVs list the
# samples in the same order, so row i is the same test image in each frame).
mask1, mask2 = (
    ce_contrast_df['true'] == ce_contrast_df['pred'],
    scl_contrast_df['true'] == scl_contrast_df['pred'],
)
final_mask = mask1 & mask2
ce_filtered, scl_filtered = (
    frame[final_mask].copy().reset_index(drop=True)
    for frame in (ce_contrast_df, scl_contrast_df)
)

In [None]:
# CE: mean contrastivity for AblationCAM and GradCAM. EigenCAM columns are not
# included (presumably because its map does not depend on the target class).
ce_filtered.loc[:, ['AblationCAM_L2_Contrastivity', 'AblationCAM_SSIM_Contrastivity',
                    'GradCAM_L2_Contrastivity', 'GradCAM_SSIM_Contrastivity']].mean()

Unnamed: 0,0
AblationCAM_L2_Contrastivity,6.803687
AblationCAM_SSIM_Contrastivity,0.594501
GradCAM_L2_Contrastivity,18.148861
GradCAM_SSIM_Contrastivity,0.008563


In [None]:
# SCL: mean contrastivity for AblationCAM and GradCAM. EigenCAM columns are not
# included (presumably because its map does not depend on the target class).
scl_filtered.loc[:, ['AblationCAM_L2_Contrastivity', 'AblationCAM_SSIM_Contrastivity',
                     'GradCAM_L2_Contrastivity', 'GradCAM_SSIM_Contrastivity']].mean()

Unnamed: 0,0
AblationCAM_L2_Contrastivity,2.901425
AblationCAM_SSIM_Contrastivity,0.858255
GradCAM_L2_Contrastivity,12.921122
GradCAM_SSIM_Contrastivity,-0.033921


## robustness

In [None]:
# SSIM robustness scores for both models (CE first, then SCL).
ce_robust_df, scl_robust_df = (
    pd.read_csv(f'/content/drive/MyDrive/clxai/results_contrast/test2/{tag}_robust_xai_scores.csv')
    for tag in ('ce', 'scl')
)

In [None]:
# Keep only samples both models classify correctly (assumes both CSVs list the
# samples in the same order, so row i is the same test image in each frame).
mask1, mask2 = (
    ce_robust_df['true'] == ce_robust_df['pred'],
    scl_robust_df['true'] == scl_robust_df['pred'],
)
final_mask = mask1 & mask2
ce_filtered, scl_filtered = (
    frame[final_mask].copy().reset_index(drop=True)
    for frame in (ce_robust_df, scl_robust_df)
)

In [None]:
# CE: column-wise means of the robustness frame (idx/true/pred bookkeeping columns included).
ce_filtered.mean(axis=0)

Unnamed: 0,0
idx,63.408696
true,4.895652
pred,4.895652
GradCAM_SSIM_Robustness,0.513066
EigenCAM_SSIM_Robustness,0.566275
AblationCAM_SSIM_Robustness,0.647728


In [None]:
# SCL: column-wise means of the robustness frame (idx/true/pred bookkeeping columns included).
scl_filtered.mean(axis=0)

Unnamed: 0,0
idx,63.408696
true,4.895652
pred,4.895652
GradCAM_SSIM_Robustness,0.591083
EigenCAM_SSIM_Robustness,0.613452
AblationCAM_SSIM_Robustness,0.599957
