# 1. Import Packages for the Environment

In [1]:
# Import basic packages for later use
import os
import shutil
from collections import OrderedDict

import json
import matplotlib.pyplot as plt
import nibabel as nib

import numpy as np
import torch

In [2]:
# Prefer the GPU whenever PyTorch can see one; otherwise run everything on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

cuda


In [3]:
# Detect the runtime from well-known directories. Kaggle is checked first and
# takes precedence; anything that is neither Kaggle nor Colab is treated as the
# local "DEIB" environment.
IN_KAGGLE = os.path.exists('/kaggle/input')
if IN_KAGGLE:
    IN_COLAB = False
else:
    IN_COLAB = os.path.exists('/content')
IN_DEIB = not (IN_KAGGLE or IN_COLAB)

In [4]:
if not IN_DEIB:
    !pip install nnunetv2
    !pip install captum

# 2. Mount the dataset

In [5]:
from batchgenerators.utilities.file_and_folder_operations import join

if IN_COLAB:
    # Google Colab
    # for colab users only - mounting the drive

    from google.colab import drive
    drive.mount('/content/drive',force_remount = True)

    drive_dir = "/content/drive/My Drive"
    mount_dir = join(drive_dir, "tesi", "automi")
    base_dir = os.getcwd()
elif IN_KAGGLE:
    # Kaggle
    mount_dir = "/kaggle/input/"
    base_dir = os.getcwd()
    print(base_dir)
    !ls '/kaggle/input'
    !cd "/kaggle/input/automi-seg" ; ls
else:
    mount_dir = "/workspace/data"
    base_dir = "/workspace/output"
    os.chdir(base_dir)
    print("base_dir:", base_dir)

base_dir: /workspace/output


# 3. Setting up nnU-Net's folder structure and environment variables
nnU-Net expects a certain folder structure and a set of environment variables.

Roughly they tell nnUnet:
1. Where to look for stuff
2. Where to put stuff

For more information about this please check: https://github.com/MIC-DKFZ/nnUNet/blob/master/documentation/setting_up_paths.md

## 3.1 Set environment Variables and creating folders

In [6]:
# ===========================
# 📦 SETUP nnUNet ENVIRONMENT
# ===========================

# Paths to export; each key is the name of the environment variable to set.
path_dict = {
    "nnUNet_raw": join(mount_dir, "nnunet_raw"),
    "nnUNet_preprocessed": join(mount_dir, "preprocessed_files"),  # alternative folder name: "nnUNet_preprocessed"
    "nnUNet_results": join(mount_dir, "results"),  # alternative folder name: "nnUNet_results"
    # "RAW_DATA_PATH": join(mount_dir, "RawData"),  # optional, only needed to store zip archives
}

# Write the paths into the environment variables, which are read by the
# `paths` module of nnunetv2 (imported right below, so the order matters).
for env_var, path in path_dict.items():
    os.environ[env_var] = path

from nnunetv2.paths import nnUNet_results, nnUNet_raw

if IN_KAGGLE:
    # Fall back to the default Kaggle locations when nnunetv2 did not pick up a path.
    if nnUNet_raw is None:
        nnUNet_raw = "/kaggle/input/nnunet_raw"
    if nnUNet_results is None:
        nnUNet_results = "/kaggle/input/results"
    # Kaggle has some very inconsistent behaviors in dataset mounting...
    #nnUNet_raw = "/kaggle/input/automi-seg/nnunet_raw"
    #nnUNet_results = "/kaggle/input/automi-seg/results"

print("nnUNet_raw:", nnUNet_raw)
print("nnUNet_results:", nnUNet_results)

nnUNet_raw: /workspace/data/nnunet_raw
nnUNet_results: /workspace/data/results


### Some tests

In [7]:
# The Kaggle copy of the dataset stores uncompressed `.nii` files, every other
# environment uses `.nii.gz`; apart from the extension the paths are the same.
nifti_ext = ".nii" if IN_KAGGLE else ".nii.gz"
ct_img_path = join(nnUNet_raw, "imagesTr", "AUTOMI_00039_0000" + nifti_ext)
organ_mask_path = join(
    nnUNet_raw,
    "total_segmentator_structures",
    "AUTOMI_00039_0000",
    "mask_mask_add_input_20_total_segmentator" + nifti_ext,
)
ct_img = nib.load(ct_img_path)
organ_mask = nib.load(organ_mask_path)

In [8]:
# Compare the geometry (array shape and voxel spacing) of the CT and of its organ mask.
print(f"CT shape: {ct_img.shape}")
print(f"Organ shape: {organ_mask.shape}")
print(f"Spacing: {ct_img.header.get_zooms()}")
print(f"Organ spacing: {organ_mask.header.get_zooms()}")

CT shape: (512, 512, 283)
Organ shape: (512, 512, 283)
Spacing: (1.171875, 1.171875, 5.0)
Organ spacing: (1.171875, 1.171875, 5.0)


## Re-align CT scan with its own organ segmentation mask

In [9]:
import SimpleITK as sitk

# Load CT and misaligned organ mask
# Load the CT (as float32) and the misaligned organ mask (as uint8 labels).
# NOTE: `organ_mask` is rebound here to a SimpleITK image; it was a nibabel image before.
ct = sitk.ReadImage(ct_img_path, sitk.sitkFloat32)
organ_mask = sitk.ReadImage(organ_mask_path, sitk.sitkUInt8)

# Resample the organ mask onto the CT's reference grid.
# Nearest-neighbour interpolation is used so that no new label values are
# introduced by interpolating between labels.
resampler = sitk.ResampleImageFilter()
resampler.SetReferenceImage(ct)
resampler.SetInterpolator(sitk.sitkNearestNeighbor)
organ_resampled = resampler.Execute(organ_mask)

# Save the aligned mask; the relative path resolves against the current working directory.
sitk.WriteImage(organ_resampled, "organ_mask_resampled_to_ct.nii.gz")

In [10]:
# Reload the CT together with the resampled mask written by the previous cell
# (relative to the current working directory) and print both geometries again.
#organ_mask_path = join(nnUNet_raw, "organ_mask_resampled_to_ct.nii.gz")
organ_mask_path = "organ_mask_resampled_to_ct.nii.gz"
ct_img = nib.load(ct_img_path)
organ_mask = nib.load(organ_mask_path)
print(f"CT shape: {ct_img.shape}")
print(f"Organ shape: {organ_mask.shape}")
print(f"Spacing: {ct_img.header.get_zooms()}")
print(f"Organ spacing: {organ_mask.header.get_zooms()}")

CT shape: (512, 512, 283)
Organ shape: (512, 512, 283)
Spacing: (1.171875, 1.171875, 5.0)
Organ spacing: (1.171875, 1.171875, 5.0)


In [11]:
# Trained-model directory; note that it is read-only in the Kaggle environment.
# On Kaggle the trainer folder sits one level deeper, inside the dataset results folder.
trainer_subdir = 'nnUNetTrainer__nnUNetPlans__3d_fullres'
if IN_KAGGLE:
    trainer_subdir = 'Dataset003_AUTOMI_CTVLNF_NEWGL_results/' + trainer_subdir
model_dir = join(nnUNet_results, trainer_subdir)

## Utility to export logits to a visualizable segmentation

In [12]:
import numpy as np
import torch
import os
from typing import Union
from pathlib import Path
from nnunetv2.configuration import default_num_processes
from nnunetv2.inference.export_prediction import export_prediction_from_logits

def export_logits_to_nifty_segmentation(
    predictor,
    volume_file: Path,
    model_dir: str,
    logits: Union[str, np.ndarray, torch.Tensor],
    npz_dir: str | None,
    output_dir: str = "",
    fold: int = 0,
    save_probs: bool = False,
    from_file: bool = True
):
    """
    Convert logits into a native-space segmentation file using nnU-Net's export helper.

    The raw image is read and preprocessed again only to recover the properties
    dictionary that `export_prediction_from_logits` needs to map the logits back
    to the original image space.

    Args:
        predictor: An instantiated nnU-Net predictor; its `plans_manager` and
            `configuration_manager` attributes are used.
        volume_file (Path): Path to the raw image file the logits belong to.
        model_dir (str): Path to the nnU-Net model directory containing dataset.json.
        logits (str or array/tensor): When `from_file=True`, the base name of the
            .npz logits file (no extension); the array is read from its "logits"
            entry. Otherwise the logits themselves.
        npz_dir (str | None): Directory where the .npz file is stored. Required
            when `from_file=True`.
        output_dir (str): Directory where the segmentation will be saved.
        fold (int): The fold number used for prediction (currently unused here).
        save_probs (bool): Forwarded as `save_probabilities` to the export helper.
        from_file (bool): Whether to load the logits from a file instead of using
            the `logits` argument directly (default True).

    Raises:
        ValueError: If `from_file=True` and `npz_dir` is None.
    """
    if from_file:
        if npz_dir is None:
            raise ValueError("npz_dir must be provided when from_file=True")
        npz_logits = Path(npz_dir) / f"{logits}.npz"
        output_stem = f"{logits}_seg"
        # Close the archive as soon as the array has been read.
        with np.load(npz_logits) as npz_file:
            logits = npz_file["logits"]
    else:
        output_stem = "exported_seg"

    plans_manager = predictor.plans_manager
    configuration_manager = predictor.configuration_manager
    dataset_json = Path(model_dir) / "dataset.json"

    # The export helper appends the dataset's file ending to the truncated output
    # path, so the path handed over must not contain any extension. Stripping with
    # os.path.splitext only removed ".gz" and produced names like "*_seg.nii.nii.gz".
    with open(dataset_json) as dataset_json_file:
        file_ending = json.load(dataset_json_file)["file_ending"]
    output_file_truncated = str(Path(output_dir) / output_stem)
    output_nii = output_file_truncated + file_ending

    preprocessor = configuration_manager.preprocessor_class(verbose=False)
    rw = plans_manager.image_reader_writer_class()
    # Instantiate once more only if the call above did not yield a reader/writer instance.
    if callable(rw) and not hasattr(rw, "read_images"):
        rw = rw()
    img_np, img_props = rw.read_images([str(volume_file)])

    # Only the third return value (the properties) is needed for the export.
    _, _, data_props = preprocessor.run_case_npy(
        img_np, seg=None, properties=img_props,
        plans_manager=plans_manager,
        configuration_manager=configuration_manager,
        dataset_json=dataset_json
    )

    export_prediction_from_logits(
        predicted_array_or_file=logits,
        properties_dict=data_props,
        configuration_manager=configuration_manager,
        plans_manager=plans_manager,
        dataset_json_dict_or_file=str(dataset_json),
        output_file_truncated=output_file_truncated,
        save_probabilities=save_probs,
        num_threads_torch=default_num_processes
    )

    print(f"✅  NIfTI segmentation written → {output_nii}")

In [13]:
# Example usage of `export_logits_to_nifty_segmentation`, kept commented out
# because it needs a configured `predictor` and previously saved logits:
#
# export_logits_to_nifty_segmentation(
#     predictor=predictor,
#     volume_file=Path(ct_img_path),
#     model_dir=model_dir,
#     logits="pred_00007",
#     npz_dir="SHAP/shap_run",
#     output_dir="SHAP/shap_run",
#     fold=0,
#     save_probs=False,
# )

'export_logits_to_nifty_segmentation(\n    predictor=predictor,\n    plan=plan,\n    model_dir=Path(model_dir),\n    logits_filename="pred_00007",\n    npz_dir="SHAP/shap_run",\n    output_dir="SHAP/shap_run",\n    fold=0,\n    save_probs=False\n)'

## We define sliding-window caching to speed up multi-inference scenarios such as SHAP

### Try to override the sliding_window_function

In [14]:
from typing import Union
import numpy as np
import torch
from tqdm import tqdm
from queue import Queue
from threading import Thread
from acvl_utils.cropping_and_padding.padding import pad_nd_image
from nnunetv2.utilities.helpers import empty_cache, dummy_context
from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor
from nnunetv2.inference.sliding_window_prediction import compute_gaussian, compute_steps_for_sliding_window

class CustomNNUNetPredictor(nnUNetPredictor):
    def __init__(self, **kwargs):
        """Create the predictor, forwarding every keyword argument unchanged to the parent class."""
        super().__init__(**kwargs)
    
    @torch.inference_mode()
    def predict_sliding_window_return_logits_with_caching(self, input_image: torch.Tensor,
                                                          perturbation_mask: torch.BoolTensor | None,
                                                          baseline_prediction_dict: dict) \
            -> Union[np.ndarray, torch.Tensor]:
        """
        Sliding-window prediction that can reuse cached baseline patch predictions.

        Adapted from `predict_sliding_window_return_logits` of nnunetv2:
        https://github.com/MIC-DKFZ/nnUNet/blob/58a3b121a6d1846a978306f6c79a7c005b7d669b/nnunetv2/inference/predict_from_raw_data.py
        The added `perturbation_mask` lets each patch be checked for the actual
        presence of a perturbation.

        Args:
            input_image: 4D tensor (c, x, y, z).
            perturbation_mask: Boolean mask of the perturbed voxels. When None the
                stock `predict_sliding_window_return_logits` is used instead.
            baseline_prediction_dict: Cached patch predictions, forwarded to the
                internal sliding-window routine.

        Returns:
            The predicted logits with the padding removed.
        """
        # Without a mask there is nothing to cache against: use the original method.
        if perturbation_mask is None:
            return self.predict_sliding_window_return_logits(input_image)

        assert isinstance(input_image, torch.Tensor)
        self.network = self.network.to(self.device)
        self.network.eval()

        empty_cache(self.device)

        # Autocast is only activated on CUDA devices:
        # - on 'cpu' it can be very slow on some CPUs (no automatic bfloat16 support detection);
        # - on 'mps' it complains about not being implemented even with enabled=False,
        #   which is why enabled=False is not used to switch it off.
        if self.device.type == 'cuda':
            precision_context = torch.autocast(self.device.type, enabled=True)
        else:
            precision_context = dummy_context()

        with precision_context:
            assert input_image.ndim == 4, 'input_image must be a 4D np.ndarray or torch.Tensor (c, x, y, z)'

            if self.verbose:
                print(f'Input shape: {input_image.shape}')
                print("step_size:", self.tile_step_size)
                print("mirror_axes:", self.allowed_mirroring_axes if self.use_mirroring else None)
                print(f'Perturbation mask shape: {perturbation_mask.shape}')

            # An input smaller than the patch size is padded up to the patch size.
            # NOTE(review): `perturbation_mask` is not padded here; presumably it is
            # expected to already match the padded shape — confirm against callers.
            data, slicer_revert_padding = pad_nd_image(input_image, self.configuration_manager.patch_size,
                                                       'constant', {'value': 0}, True,
                                                       None)

            # The same slicers index the padded volume and the perturbation mask.
            slicers = self._internal_get_sliding_window_slicers(data.shape[1:])

            # NOTE(review): this compares `self.device` with a string — confirm it
            # evaluates as intended (`self.device.type` may be what was meant).
            run_with_cache = self.perform_everything_on_device and self.device != 'cpu'
            if run_with_cache:
                # Unlike the upstream method there is no CPU fallback: errors abort the run.
                try:
                    predicted_logits = self._internal_predict_sliding_window_return_logits(
                        data, slicers, True, perturbation_mask, baseline_prediction_dict, caching=True
                    )
                except RuntimeError as e:
                    if "CUDA out of memory" in str(e):
                        print("⚠️  CUDA OOM, cambiare batch size o patch size!")
                    # Always surface the real error.
                    raise
            else:
                # This path does not pass the mask or the cache to the internal routine.
                predicted_logits = self._internal_predict_sliding_window_return_logits(data, slicers,
                                                                                       self.perform_everything_on_device)

            empty_cache(self.device)
            # Undo the padding on the spatial axes, keeping every channel.
            predicted_logits = predicted_logits[(slice(None), *slicer_revert_padding[1:])]
        return predicted_logits
                

    def _slice_key(self, slicer_tuple):
        # make slicer object hashable to use it for cache lookup
        return tuple((s.start, s.stop, s.step) for s in slicer_tuple)


    @torch.inference_mode()
    def _internal_predict_sliding_window_return_logits(self,
                                                       data: torch.Tensor,
                                                       slicers,
                                                       do_on_device: bool = True,
                                                       perturbation_mask: torch.BoolTensor | None = None,
                                                       baseline_prediction_dict: dict | None = None,
                                                       caching: bool = False,
                                                       ):
        """
        Run the sliding-window prediction, optionally reusing cached patch predictions.

        With `caching=True`, a patch whose slicer selects no True voxel in
        `perturbation_mask` is not predicted again: its prediction is looked up in
        `baseline_prediction_dict` (keyed by `_slice_key`). All other patches are
        predicted by the network.

        Args:
            data: Image tensor; `data.shape[1:]` are the spatial dimensions.
            slicers: Sliding-window slicers, one tuple of slices per patch.
            do_on_device: Aggregate on `self.device` when True, on the CPU otherwise.
            perturbation_mask: Boolean mask of perturbed voxels (needed when caching).
            baseline_prediction_dict: Cached per-patch predictions (needed when caching).
            caching: Enable the cache lookup described above.

        Returns:
            Half-precision tensor of shape (num_segmentation_heads, *data.shape[1:])
            holding the aggregated logits.

        Raises:
            RuntimeError: If a patch prediction fails or the result contains infs.
        """
        predicted_logits = n_predictions = prediction = gaussian = workon = None
        results_device = self.device if do_on_device else torch.device('cpu')
        if next(self.network.parameters()).device != results_device:
            self.network = self.network.to(results_device)

        def producer(d, slh, q):
            # Feed one (patch, slicer) pair at a time, then the 'end' sentinel.
            for s in slh:
                q.put((torch.clone(d[s][None], memory_format=torch.contiguous_format).to(results_device), s))
            q.put('end')

        try:
            empty_cache(self.device)

            # move data to device
            if self.verbose:
                print(f'move image to device {results_device}')
            data = data.to(results_device)
            queue = Queue(maxsize=2)
            t = Thread(target=producer, args=(data, slicers, queue))
            t.start()

            # preallocate arrays
            if self.verbose:
                print(f'preallocating results arrays on device {results_device}')
            predicted_logits = torch.zeros((self.label_manager.num_segmentation_heads, *data.shape[1:]),
                                           dtype=torch.half,
                                           device=results_device)
            n_predictions = torch.zeros(data.shape[1:], dtype=torch.half, device=results_device)

            if self.use_gaussian:
                gaussian = compute_gaussian(tuple(self.configuration_manager.patch_size), sigma_scale=1. / 8,
                                            value_scaling_factor=10,
                                            device=results_device)
            else:
                gaussian = 1

            if not self.allow_tqdm and self.verbose:
                print(f'running prediction: {len(slicers)} steps')

            with tqdm(desc=None, total=len(slicers), disable=not self.allow_tqdm) as pbar:
                cache_hits = 0
                while True:
                    item = queue.get()
                    if item == 'end':
                        queue.task_done()
                        break
                    workon, sl = item
                    try:
                        if caching and not self.check_overlapping(sl, perturbation_mask):
                            # copy=True: when the cached tensor already lives on
                            # results_device, a plain .to() would return the cached
                            # object itself and the in-place gaussian weighting below
                            # would corrupt the cache for later runs.
                            prediction = baseline_prediction_dict[self._slice_key(sl)].to(results_device, copy=True)
                            cache_hits += 1
                        else:
                            prediction = self._internal_maybe_mirror_and_predict(workon)[0].to(results_device)
                    except Exception as e:
                        # Keep the original message (e.g. "CUDA out of memory") visible
                        # to callers that inspect the error text.
                        raise RuntimeError(f"Errore nella predizione del patch: {e}") from e

                    # sanity-check device
                    assert prediction.device == predicted_logits.device

                    if self.use_gaussian:
                        prediction *= gaussian
                    predicted_logits[sl] += prediction
                    n_predictions[sl[1:]] += gaussian

                    # Release the patch tensors but keep the names bound: the cleanup
                    # in the outer `except` deletes them and would otherwise fail with
                    # UnboundLocalError, hiding the real error.
                    prediction = workon = None

                    queue.task_done()
                    pbar.set_postfix(
                        cache=f"{cache_hits}",
                        mem=f"{torch.cuda.memory_allocated()/1e9:.2f} GB"
                    )
                    pbar.update(1)
            queue.join()
            if self.verbose and not self.allow_tqdm:
                print(f"Cache hits: {cache_hits}\\{len(slicers)}")

            # predicted_logits /= n_predictions
            torch.div(predicted_logits, n_predictions, out=predicted_logits)
            # check for infs
            if torch.any(torch.isinf(predicted_logits)):
                raise RuntimeError('Encountered inf in predicted array. Aborting... If this problem persists, '
                                   'reduce value_scaling_factor in compute_gaussian or increase the dtype of '
                                   'predicted_logits to fp32')
        except Exception:
            # Drop every reference to the large tensors before freeing the caches.
            del predicted_logits, n_predictions, prediction, gaussian, workon
            empty_cache(self.device)
            empty_cache(results_device)
            raise
        return predicted_logits
  


    def get_output_dictionary_sliding_window(self, data: torch.Tensor, slicers,
                                            do_on_device: bool = True,
                                            ) -> dict:
        """
        Predict every sliding-window patch once and store the outputs per slicer.

        The returned dictionary associates the network output of each patch with
        its slicer (keyed through `_slice_key`), so that the outputs of untouched
        patches can later be served from cache.

        Args:
            data: Image tensor to run the sliding window over.
            slicers: Sliding-window slicers, one tuple of slices per patch.
            do_on_device: Run on `self.device` when True, on the CPU otherwise.

        Returns:
            dict mapping `_slice_key(slicer)` to the patch prediction moved to the CPU.
        """
        dictionary = dict()
        workon = None
        results_device = self.device if do_on_device else torch.device('cpu')

        def producer(d, slh, q):
            # Patches go to the device the network runs on (results_device), so
            # that do_on_device=False does not feed GPU patches to a CPU network.
            for s in slh:
                q.put((torch.clone(d[s][None], memory_format=torch.contiguous_format).to(results_device), s))
            q.put('end')

        try:
            empty_cache(self.device)

            # move data and network to device
            if self.verbose:
                print(f'move image and model to device {results_device}')

            self.network = self.network.to(results_device)
            data = data.to(results_device)
            queue = Queue(maxsize=2)
            t = Thread(target=producer, args=(data, slicers, queue))
            t.start()

            if not self.allow_tqdm and self.verbose:
                print(f'running prediction: {len(slicers)} steps')

            with tqdm(desc=None, total=len(slicers), disable=not self.allow_tqdm) as pbar:
                while True:
                    item = queue.get()
                    if item == 'end':
                        queue.task_done()
                        break
                    workon, sl = item
                    pred_gpu = self._internal_maybe_mirror_and_predict(workon)[0].to(results_device)

                    # save a CPU copy of the prediction in the dictionary
                    dictionary[self._slice_key(sl)] = pred_gpu.cpu()
                    # immediately drop the device-side reference
                    del pred_gpu

                    queue.task_done()
                    pbar.update()
            queue.join()

        except Exception:
            del workon
            empty_cache(self.device)
            empty_cache(results_device)
            raise
        return dictionary

    def check_overlapping(self, slicer, perturbation_mask: torch.BoolTensor) -> bool:
        """
        Tell whether the patch selected by `slicer` contains at least one perturbed voxel.

        Parameters
        ----------
        slicer : tuple
            A slicer as produced by `_internal_get_sliding_window_slicers`,
            i.e. (slice(None), slice(x0,x1), slice(y0,y1), slice(z0,z1)).
        perturbation_mask : torch.BoolTensor
            Mask (C, X, Y, Z) that is True on the voxels to perturb
            (usually C==1, or replicated over the channels).

        Returns
        -------
        bool
            True if and only if at least one voxel of the patch is True.
        """
        # The first element of the slicer is always slice(None) (the channels);
        # it is kept as is: it adds no overhead and keeps the indexing simple.
        patch_mask = perturbation_mask[slicer]
        any_perturbed = patch_mask.any()
        return any_perturbed.item()

## Try Captum's kernel SHAP on the organ mask

### First, derive a customized class from the Captum library to use sliding-window caching

## Customize KernelShap as a "sibling" by inheriting from its parent, LimeBase
This is needed because we must override the `from_interp_rep_transform` / `to_interp_rep_transform` methods, which map the (1, M) binary mask vector to the perturbed volume AND to the perturbation mask we need for caching.

In [15]:
#!/usr/bin/env python3

# pyre-strict
import inspect
import math
import typing
import warnings
from collections.abc import Iterator
from typing import Any, Callable, cast, List, Literal, Optional, Tuple, Union

import torch
from captum._utils.common import (
    _expand_additional_forward_args,
    _expand_target,
    _flatten_tensor_or_tuple,
    _format_output,
    _format_tensor_into_tuples,
    _get_max_feature_index,
    _is_tuple,
    _reduce_list,
    _run_forward,
)
from captum._utils.models.linear_model import SkLearnLasso
from captum._utils.models.model import Model
from captum._utils.progress import progress
from captum._utils.typing import BaselineType, TargetType, TensorOrTupleOfTensorsGeneric
from captum.attr._utils.attribution import PerturbationAttribution
from captum.attr._utils.batching import _batch_example_iterator
from captum.attr._utils.common import (
    _construct_default_feature_mask,
    _format_input_baseline,
)
from captum.log import log_usage
from torch import Tensor, BoolTensor
from torch.nn import CosineSimilarity
from torch.utils.data import DataLoader, TensorDataset


class LimeBaseWithCustomArgumentToForwardFunc(PerturbationAttribution):
    r"""
    Here we create a modification of Lime class from Captum Library (https://captum.ai/api/_modules/captum/attr/_core/lime.html)
    """

    def __init__(
        self,
        forward_func: Callable[..., Tensor],
        interpretable_model: Model,
        similarity_func: Callable[
            ...,
            Union[float, Tensor],
        ],
        perturb_func: Callable[..., object],
        perturb_interpretable_space: bool,
        from_interp_rep_transform: Optional[
            Callable[..., Union[Tensor, Tuple[Tensor, ...]]]
        ],
        to_interp_rep_transform: Optional[Callable[..., Tensor]],
    ) -> None:
        r"""
        Store the building blocks of the LIME procedure and validate that the
        transform required by the chosen sampling space was supplied.

        Args:

            forward_func (Callable): The forward function of the model or any
                    modification of it. If a batch is provided as input for
                    attribution, it is expected that forward_func returns a scalar
                    representing the entire batch.
            interpretable_model (Model): Model object to train interpretable model.
                    A Model object provides a `fit` method to train the model,
                    given a dataloader, with batches containing three tensors:

                    - interpretable_inputs: Tensor
                      [2D num_samples x num_interp_features],
                    - expected_outputs: Tensor [1D num_samples],
                    - weights: Tensor [1D num_samples]

                    The model object must also provide a `representation` method to
                    access the appropriate coefficients or representation of the
                    interpretable model after fitting.
                    Some predefined interpretable linear models are provided in
                    captum._utils.models.linear_model including wrappers around
                    SkLearn linear models as well as SGD-based PyTorch linear
                    models.

                    Note that calling fit multiple times should retrain the
                    interpretable model, each attribution call reuses
                    the same given interpretable model object.
            similarity_func (Callable): Function which takes a single sample
                    along with its corresponding interpretable representation
                    and returns the weight of the interpretable sample for
                    training interpretable model. Weight is generally
                    determined based on similarity to the original input.
                    The original paper refers to this as a similarity kernel.

                    The expected signature of this callable is:

                    >>> similarity_func(
                    >>>    original_input: Tensor or tuple[Tensor, ...],
                    >>>    perturbed_input: Tensor or tuple[Tensor, ...],
                    >>>    perturbed_interpretable_input:
                    >>>        Tensor [2D 1 x num_interp_features],
                    >>>    **kwargs: Any
                    >>> ) -> float or Tensor containing float scalar

                    perturbed_input and original_input will be the same type and
                    contain tensors of the same shape (regardless of whether or not
                    the sampling function returns inputs in the interpretable
                    space). original_input is the same as the input provided
                    when calling attribute.

                    All kwargs passed to the attribute method are
                    provided as keyword arguments (kwargs) to this callable.
            perturb_func (Callable): Function which returns a single
                    sampled input, generally a perturbation of the original
                    input, which is used to train the interpretable surrogate
                    model. Function can return samples in either
                    the original input space (matching type and tensor shapes
                    of original input) or in the interpretable input space,
                    which is a vector containing the interpretable features.
                    Alternatively, this function can return a generator
                    yielding samples to train the interpretable surrogate
                    model, and n_samples perturbations will be sampled
                    from this generator.

                    The expected signature of this callable is:

                    >>> perturb_func(
                    >>>    original_input: Tensor or tuple[Tensor, ...],
                    >>>    **kwargs: Any
                    >>> ) -> Tensor, tuple[Tensor, ...], or
                    >>>    generator yielding tensor or tuple[Tensor, ...]

                    Returned sampled input should match the input type (Tensor
                    or Tuple of Tensor and corresponding shapes) if
                    perturb_interpretable_space = False. If
                    perturb_interpretable_space = True, the return type should
                    be a single tensor of shape 1 x num_interp_features,
                    corresponding to the representation of the
                    sample to train the interpretable model.

                    All kwargs passed to the attribute method are
                    provided as keyword arguments (kwargs) to this callable.
            perturb_interpretable_space (bool): Indicates whether
                    perturb_func returns a sample in the interpretable space
                    (tensor of shape 1 x num_interp_features) or a sample
                    in the original space, matching the format of the original
                    input. Once sampled, inputs can be converted to / from
                    the interpretable representation with either
                    to_interp_rep_transform or from_interp_rep_transform.
            from_interp_rep_transform (Callable): Function which takes a
                    single sampled interpretable representation (tensor
                    of shape 1 x num_interp_features) and returns
                    the corresponding representation in the input space
                    (matching shapes of original input to attribute).

                    This argument is necessary if perturb_interpretable_space
                    is True, otherwise None can be provided for this argument.

                    The expected signature of this callable is:

                    >>> from_interp_rep_transform(
                    >>>    curr_sample: Tensor [2D 1 x num_interp_features]
                    >>>    original_input: Tensor or Tuple of Tensors,
                    >>>    **kwargs: Any
                    >>> ) -> Tensor or tuple[Tensor, ...]

                    Returned sampled input should match the type of original_input
                    and corresponding tensor shapes.

                    All kwargs passed to the attribute method are
                    provided as keyword arguments (kwargs) to this callable.

            to_interp_rep_transform (Callable): Function which takes a
                    sample in the original input space and converts to
                    its interpretable representation (tensor
                    of shape 1 x num_interp_features).

                    This argument is necessary if perturb_interpretable_space
                    is False, otherwise None can be provided for this argument.

                    The expected signature of this callable is:

                    >>> to_interp_rep_transform(
                    >>>    curr_sample: Tensor or Tuple of Tensors,
                    >>>    original_input: Tensor or Tuple of Tensors,
                    >>>    **kwargs: Any
                    >>> ) -> Tensor [2D 1 x num_interp_features]

                    curr_sample will match the type of original_input
                    and corresponding tensor shapes.

                    All kwargs passed to the attribute method are
                    provided as keyword arguments (kwargs) to this callable.

        Raises:
            AssertionError: If the transform required by
                    ``perturb_interpretable_space`` is None.
        """
        PerturbationAttribution.__init__(self, forward_func)
        self.interpretable_model = interpretable_model
        self.similarity_func = similarity_func
        self.perturb_func = perturb_func
        self.perturb_interpretable_space = perturb_interpretable_space
        self.from_interp_rep_transform = from_interp_rep_transform
        self.to_interp_rep_transform = to_interp_rep_transform

        # The messages are parenthesized so that both string fragments belong to
        # the assert; a second string on its own line would be a separate no-op
        # statement and never reach the user.
        if self.perturb_interpretable_space:
            assert self.from_interp_rep_transform is not None, (
                "Must provide transform from interpretable space to original input space"
                " when sampling from interpretable space."
            )
        else:
            assert self.to_interp_rep_transform is not None, (
                "Must provide transform from original input space to interpretable space"
            )

    @log_usage(part_of_slo=True)
    @torch.no_grad()
    def attribute(
        self,
        inputs: TensorOrTupleOfTensorsGeneric,
        target: TargetType = None,
        additional_forward_args: Optional[Tuple[object, ...]] = None,
        n_samples: int = 50,
        perturbations_per_eval: int = 1,
        show_progress: bool = False,
        # --- MONITOR CONVERGENCE QUALITY
        monitor_log_path: Optional[str] = None,
        monitor_convergence_step: Optional[int] = 20,
        monitor_local_accuracy_step: Optional[int] = 50,
        **kwargs: object,
    ) -> Tensor:
        r"""
        This method attributes the output of the model with given target index
        (in case it is provided, otherwise it assumes that output is a
        scalar) to the inputs of the model. It samples perturbations, evaluates
        the forward function on them, trains an interpretable model on the
        (interpretable input, output, similarity weight) triples and returns a
        representation of the interpretable model.

        For every sample a Boolean perturbation mask is derived from the
        interpretable input and ``kwargs["feature_mask"]`` and appended as the
        LAST element of the additional forward arguments, so ``forward_func``
        receives it after any user-supplied ``additional_forward_args``.

        It is recommended to only provide a single example as input (tensors
        with first dimension or batch size = 1). This is because LIME is generally
        used for sample-based interpretability, training a separate interpretable
        model to explain a model's prediction on each individual example.

        A batch of inputs can be provided as inputs only if forward_func
        returns a single value per batch (e.g. loss).
        The interpretable feature representation should still have shape
        1 x num_interp_features, corresponding to the interpretable
        representation for the full batch, and perturbations_per_eval
        must be set to 1.

        Args:

            inputs (Tensor or tuple[Tensor, ...]): Input for which LIME
                        is computed. If forward_func takes a single
                        tensor as input, a single input tensor should be provided.
                        If forward_func takes multiple tensors as input, a tuple
                        of the input tensors should be provided. It is assumed
                        that for all given input tensors, dimension 0 corresponds
                        to the number of examples, and if multiple input tensors
                        are provided, the examples must be aligned appropriately.
            target (int, tuple, Tensor, or list, optional): Output indices for
                        which surrogate model is trained
                        (for classification cases,
                        this is usually the target class).
                        If the network returns a scalar value per example,
                        no target index is necessary.
                        For general 2D outputs, targets can be either:

                        - a single integer or a tensor containing a single
                          integer, which is applied to all input examples

                        - a list of integers or a 1D tensor, with length matching
                          the number of examples in inputs (dim 0). Each integer
                          is applied as the target for the corresponding example.

                        For outputs with > 2 dimensions, targets can be either:

                        - A single tuple, which contains #output_dims - 1
                          elements. This target index is applied to all examples.

                        - A list of tuples with length equal to the number of
                          examples in inputs (dim 0), and each tuple containing
                          #output_dims - 1 elements. Each tuple is applied as the
                          target for the corresponding example.

                        Default: None
            additional_forward_args (Any, optional): If the forward function
                        requires additional arguments other than the inputs for
                        which attributions should not be computed, this argument
                        can be provided. It must be either a single additional
                        argument of a Tensor or arbitrary (non-tuple) type or a
                        tuple containing multiple additional arguments including
                        tensors or any arbitrary python types. These arguments
                        are provided to forward_func in order following the
                        arguments in inputs, followed by the perturbation mask.
                        For a tensor, the first dimension of the tensor must
                        correspond to the number of examples. For all other types,
                        the given argument is used for all forward evaluations.
                        Note that attributions are not computed with respect
                        to these arguments.
                        Default: None
            n_samples (int, optional): The number of samples of the original
                        model used to train the surrogate interpretable model.
                        Default: `50` if `n_samples` is not provided.
            perturbations_per_eval (int, optional): Allows multiple samples
                        to be processed simultaneously in one call to forward_fn.
                        Each forward pass will contain a maximum of
                        perturbations_per_eval * #examples samples.
                        For DataParallel models, each batch is split among the
                        available devices, so evaluations on each available
                        device contain at most
                        (perturbations_per_eval * #examples) / num_devices
                        samples.
                        If the forward function returns a single scalar per batch,
                        perturbations_per_eval must be set to 1.
                        Default: 1
            show_progress (bool, optional): Displays the progress of computation.
                        It will try to use tqdm if available for advanced features
                        (e.g. time estimation). Otherwise, it will fallback to
                        a simple output of progress.
                        Default: False
            monitor_log_path (str, optional): Path to the log file for monitoring
                        convergence; JSON records are appended one per line.
                        If None, no monitoring is performed.
                        Default: None
            monitor_convergence_step (int, optional): Number of samples between
                        two convergence checkpoints. At each checkpoint the
                        surrogate is re-fitted and the L1 distance to the
                        coefficients of the previous convergence checkpoint is
                        logged as ``conv_dist_L1``. None or 0 disables this check.
                        Default: 20
            monitor_local_accuracy_step (int, optional): Number of samples between
                        two local-accuracy checkpoints, logged as ``delta_shap``.
                        None or 0 disables this check.
                        Default: 50
            **kwargs (Any, optional): Any additional arguments necessary for
                        sampling and transformation functions (provided to
                        constructor). Must contain ``feature_mask``.
                        Default: None

        Returns:
            **interpretable model representation**:
            - **interpretable model representation** (*Any*):
                    A representation of the interpretable model trained, as
                    returned by ``interpretable_model.representation()``.
                    For example, this could contain coefficients of a
                    linear surrogate model.

        Raises:
            KeyError: If ``feature_mask`` is not passed as a keyword argument.
            RuntimeError: If no perturbation sample could be drawn
                    (``n_samples <= 0`` or an exhausted generator).

        Side effects:
            Stores the training set of the surrogate in ``self.dataset``. When
            monitoring is enabled, also appends to ``monitor_log_path`` and
            periodically writes ``dataset-autosave.pkl`` in the working directory.

        Examples::

            >>> # SimpleClassifier takes a single input tensor of
            >>> # float features with size N x 5,
            >>> # and returns an Nx3 tensor of class probabilities.
            >>> net = SimpleClassifier()
            >>>
            >>> # We will train an interpretable model with the same
            >>> # features by simply sampling with added Gaussian noise
            >>> # to the inputs and training a model to predict the
            >>> # score of the target class.
            >>>
            >>> # For interpretable model training, we will use sklearn
            >>> # linear model in this example. We have provided wrappers
            >>> # around sklearn linear models to fit the Model interface.
            >>> # Any arguments provided to the sklearn constructor can also
            >>> # be provided to the wrapper, e.g.:
            >>> # SkLearnLinearModel("linear_model.Ridge", alpha=2.0)
            >>> from captum._utils.models.linear_model import SkLearnLinearModel
            >>>
            >>>
            >>> # Define similarity kernel (exponential kernel based on L2 norm)
            >>> def similarity_kernel(
            >>>     original_input: Tensor,
            >>>     perturbed_input: Tensor,
            >>>     perturbed_interpretable_input: Tensor,
            >>>     **kwargs)->Tensor:
            >>>         # kernel_width will be provided to attribute as a kwarg
            >>>         kernel_width = kwargs["kernel_width"]
            >>>         l2_dist = torch.norm(original_input - perturbed_input)
            >>>         return torch.exp(- (l2_dist**2) / (kernel_width**2))
            >>>
            >>>
            >>> # Define sampling function
            >>> # This function samples in original input space
            >>> def perturb_func(
            >>>     original_input: Tensor,
            >>>     **kwargs)->Tensor:
            >>>         return original_input + torch.randn_like(original_input)
            >>>
            >>> # For this example, we are setting the interpretable input to
            >>> # match the model input, so the to_interp_rep_transform
            >>> # function simply returns the input. In most cases, the interpretable
            >>> # input will be different and may have a smaller feature set, so
            >>> # an appropriate transformation function should be provided.
            >>>
            >>> def to_interp_transform(curr_sample, original_inp,
            >>>                                      **kwargs):
            >>>     return curr_sample
            >>>
            >>> # Generating random input with size 1 x 5
            >>> input = torch.randn(1, 5)
            >>> # Defining LimeBase interpreter
            >>> lime_attr = LimeBase(net,
                                     SkLearnLinearModel("linear_model.Ridge"),
                                     similarity_func=similarity_kernel,
                                     perturb_func=perturb_func,
                                     perturb_interpretable_space=False,
                                     from_interp_rep_transform=None,
                                     to_interp_rep_transform=to_interp_transform)
            >>> # Computes interpretable model, returning coefficients of linear
            >>> # model.
            >>> attr_coefs = lime_attr.attribute(input, target=1, kernel_width=1.1)
        """
        inp_tensor = cast(Tensor, inputs) if isinstance(inputs, Tensor) else inputs[0]
        device = inp_tensor.device

        # --------------------------------------------------------------------- #
        # 1.  Lists that grow while we sample                                   #
        # --------------------------------------------------------------------- #
        interpretable_inps, similarities, outputs = [], [], []

        # Perturbed inputs waiting for the next forward pass.
        curr_model_inputs = []
        expanded_additional_args = None
        expanded_target = None
        additional_forward_args_with_mask = None
        gen_perturb_func = self._get_perturb_generator_func(inputs, **kwargs)

        # Looked up before any resource is opened, so that a missing mask
        # cannot leave the log file handle dangling.
        feature_mask = kwargs["feature_mask"]
        batch_count = 0

        # --------------------------------------------------------------------- #
        # 2.  Monitoring / progress initialisation                              #
        # --------------------------------------------------------------------- #
        monitoring = monitor_log_path is not None
        beta_prev = None
        logf = None
        attr_progress = None

        try:
            if monitoring:
                logf = open(monitor_log_path, "a")

            if show_progress:
                attr_progress = progress(
                    total=math.ceil(n_samples / perturbations_per_eval),
                    desc=f"{self.get_name()} attribution",
                )
                attr_progress.update(0)

            # ----------------------------------------------------------------- #
            # 3.  Main sampling loop                                            #
            # ----------------------------------------------------------------- #
            for _ in range(n_samples):
                try:
                    interpretable_inp, curr_model_input = gen_perturb_func()
                    perturbation_mask = self._get_perturbation_mask(
                        interpretable_inp, curr_model_input, feature_mask
                    )
                except StopIteration:
                    warnings.warn(
                        "Generator completed prior to given n_samples iterations!",
                        stacklevel=1,
                    )
                    break
                except Exception:
                    print("error in the perturbation mask generation")
                    raise

                # ------------ Build forward args with mask -------------------- #
                # The mask always goes last, after the user-supplied arguments.
                # NOTE(review): only the mask of the most recent sample is kept,
                # which matches the batch only when perturbations_per_eval == 1.
                if additional_forward_args is None:
                    additional_forward_args_with_mask = (perturbation_mask,)
                elif isinstance(additional_forward_args, tuple):
                    additional_forward_args_with_mask = additional_forward_args + (
                        perturbation_mask,
                    )
                else:
                    additional_forward_args_with_mask = (
                        additional_forward_args,
                        perturbation_mask,
                    )

                # ------------ Book-keeping per sample ------------------------- #
                batch_count += 1
                interpretable_inps.append(interpretable_inp)
                curr_model_inputs.append(curr_model_input)

                curr_sim = self.similarity_func(
                    inputs, curr_model_input, interpretable_inp, **kwargs
                )
                similarities.append(
                    curr_sim.flatten()
                    if isinstance(curr_sim, Tensor)
                    else torch.tensor([curr_sim], device=device)
                )

                # ------------ When we have one evaluation batch ready --------- #
                if len(curr_model_inputs) == perturbations_per_eval:
                    expanded_additional_args = _expand_additional_forward_args(
                        additional_forward_args_with_mask, len(curr_model_inputs)
                    )
                    if expanded_target is None:
                        expanded_target = _expand_target(
                            target, len(curr_model_inputs)
                        )

                    model_out = self._evaluate_batch(
                        curr_model_inputs,
                        expanded_target,
                        expanded_additional_args,
                        device,
                    )
                    if attr_progress is not None:
                        attr_progress.update()
                    outputs.append(model_out)

                    curr_model_inputs = []

                    # Re-fit and log only right after a forward pass, so every
                    # interpretable input seen so far has a matching output.
                    if monitoring:
                        beta_prev = self._log_monitor_checkpoint(
                            logf,
                            interpretable_inps,
                            outputs,
                            similarities,
                            monitor_convergence_step,
                            monitor_local_accuracy_step,
                            beta_prev,
                        )

            # ----------------------------------------------------------------- #
            # 4.  Flush any leftover mini-batch                                 #
            # ----------------------------------------------------------------- #
            if len(curr_model_inputs) > 0:
                expanded_additional_args = _expand_additional_forward_args(
                    additional_forward_args_with_mask, len(curr_model_inputs)
                )
                expanded_target = _expand_target(target, len(curr_model_inputs))

                model_out = self._evaluate_batch(
                    curr_model_inputs,
                    expanded_target,
                    expanded_additional_args,
                    device,
                )
                if attr_progress is not None:
                    attr_progress.update()
                outputs.append(model_out)

            # ----------------------------------------------------------------- #
            # 5.  Final fit on *all* samples                                    #
            # ----------------------------------------------------------------- #
            if not interpretable_inps:
                # Same exception type torch.cat would raise on an empty list,
                # but with a message that points at the actual cause.
                raise RuntimeError(
                    "No perturbation samples were generated; cannot fit the "
                    "interpretable model (check n_samples and perturb_func)."
                )

            combined_interp_inps = torch.cat(interpretable_inps).float()
            combined_outputs = (
                torch.cat(outputs)
                if len(outputs[0].shape) > 0
                else torch.stack(outputs)
            ).float()
            combined_sim = (
                torch.cat(similarities)
                if len(similarities[0].shape) > 0
                else torch.stack(similarities)
            ).float()

            self.dataset = TensorDataset(
                combined_interp_inps, combined_outputs, combined_sim
            )
            self.interpretable_model.fit(
                DataLoader(self.dataset, batch_size=batch_count)
            )

            return self.interpretable_model.representation()
        finally:
            # Release the progress bar and the log file even when sampling or
            # the forward pass raises.
            if attr_progress is not None:
                attr_progress.close()
            if logf is not None:
                logf.close()

    def _log_monitor_checkpoint(
        self,
        logf,
        interpretable_inps,
        outputs,
        similarities,
        convergence_step: Optional[int],
        local_accuracy_step: Optional[int],
        beta_prev: Optional[Tensor],
    ) -> Optional[Tensor]:
        """
        Re-fit the surrogate on all samples so far and append monitoring
        records to ``logf`` when the sample count hits a checkpoint.

        Args:
            logf: Open, writable text file receiving one JSON record per line.
            interpretable_inps: Interpretable inputs collected so far.
            outputs: Forward outputs collected so far.
            similarities: Similarity weights collected so far.
            convergence_step: Period (in samples) of convergence checkpoints;
                None or 0 disables them.
            local_accuracy_step: Period (in samples) of local-accuracy
                checkpoints; None or 0 disables them.
            beta_prev: Coefficients stored at the previous convergence
                checkpoint, or None if there was none yet.

        Returns:
            The coefficients to compare against at the next convergence
            checkpoint (unchanged unless this call was a convergence checkpoint).
        """
        import json
        import pickle

        n_seen = len(interpretable_inps)
        # A step of None/0 means "disabled"; this also avoids `% None` / `% 0`.
        at_convergence = bool(convergence_step) and n_seen % convergence_step == 0
        at_local_accuracy = (
            bool(local_accuracy_step) and n_seen % local_accuracy_step == 0
        )
        if not (at_convergence or at_local_accuracy):
            return beta_prev

        # Build a DataLoader with *all* samples so far and fit in one batch.
        X = torch.cat(interpretable_inps).float()
        y = (
            torch.cat(outputs) if len(outputs[0].shape) > 0 else torch.stack(outputs)
        ).float()
        w = torch.cat(similarities).float()
        dl_mon = DataLoader(TensorDataset(X, y, w), batch_size=len(X))
        self.interpretable_model.fit(dl_mon)

        # Coefficients as a flat CPU tensor (representation may be Tensor/np/list).
        rep = self.interpretable_model.representation()
        beta_cur = (
            rep.flatten().to("cpu")
            if isinstance(rep, torch.Tensor)
            else torch.as_tensor(rep, dtype=torch.float32, device="cpu").flatten()
        ).detach()

        # ---- Convergence distance ------------------------------------------ #
        if at_convergence:
            if beta_prev is not None:
                dist = torch.norm(beta_cur - beta_prev, p=1).item()
                logf.write(
                    json.dumps({"iter": n_seen, "conv_dist_L1": dist}) + "\n"
                )
            # Only convergence checkpoints move the reference, so the distance
            # is always measured over exactly `convergence_step` samples.
            beta_prev = beta_cur

        # ---- Local-accuracy residual --------------------------------------- #
        if at_local_accuracy:
            # NOTE(review): treats element 0 of the representation as the
            # intercept and the rest as per-feature attributions - confirm that
            # the interpretable model in use really returns them in this layout.
            phi0, phis = beta_cur[0].item(), beta_cur[1:]
            # ASSUMES the first sample is the unperturbed input (only true for
            # KernelSHAP) and that each forward pass yields a single value.
            model_fwd_original = outputs[0]
            fx = (
                model_fwd_original.item()
                if torch.is_tensor(model_fwd_original)
                else model_fwd_original
            )
            # `.sum().item()` keeps everything a Python float, including the
            # case of an empty attribution vector.
            delta = abs(fx - (phi0 + phis.sum().item()))
            logf.write(json.dumps({"iter": n_seen, "delta_shap": delta}) + "\n")

        logf.flush()

        # DEBUG: store the whole dataset each 10 iterations.
        if n_seen % 10 == 0:
            with open('dataset-autosave.pkl', 'wb') as file:
                pickle.dump(dl_mon, file)

        return beta_prev



    def _get_perturbation_mask(
        self,
        interpretable_input: torch.Tensor,       # shape = (B, M)
        original_inputs: TensorOrTupleOfTensorsGeneric, # shape = (B, C, D, W, H) or tuple thereof
        feature_mask,
    ) -> Union[torch.BoolTensor, Tuple[torch.BoolTensor, ...]]:
        """
        Expand the interpretable on/off vector into input-shaped Boolean masks.

        Each entry of ``feature_mask`` is used as a column index into
        ``interpretable_input``; the looked-up value is cast to bool and
        inverted, so True marks elements whose feature is switched OFF
        (i.e. elements to perturb) and False marks elements left untouched.

        ``original_inputs`` is accepted for signature symmetry but not read.

        Returns a single mask of shape (B, *feature_mask.shape) when
        ``feature_mask`` is a Tensor, otherwise a tuple with one such mask per
        entry of ``feature_mask``.
        """

        def _switched_off(index_map):
            # Advanced indexing over the feature axis -> (B, *index_map.shape).
            return ~interpretable_input[:, index_map].bool()

        # Multi-input model: one mask per feature-mask entry.
        if not isinstance(feature_mask, torch.Tensor):
            return tuple(_switched_off(fm_i) for fm_i in feature_mask)

        # Single-tensor input.
        return _switched_off(feature_mask)

    

    def _get_perturb_generator_func(
        self, inputs: TensorOrTupleOfTensorsGeneric, **kwargs: Any
    ) -> Callable[
        [], Tuple[TensorOrTupleOfTensorsGeneric, TensorOrTupleOfTensorsGeneric]
    ]:
        """
        Build a zero-argument callable that yields one perturbation per call.

        If ``self.perturb_func`` is a generator function it is instantiated
        once here and advanced on every call (so the callable raises
        StopIteration when the generator is exhausted); otherwise
        ``self.perturb_func`` is invoked anew for each sample.

        The callable returns ``(interpretable_input, model_input)``, using
        ``from_interp_rep_transform`` or ``to_interp_rep_transform`` to obtain
        whichever of the two representations the sampler did not produce.
        """
        sample_stream: Optional[Iterator[TensorOrTupleOfTensorsGeneric]] = (
            self.perturb_func(inputs, **kwargs)
            if inspect.isgeneratorfunction(self.perturb_func)
            else None
        )

        def generate_perturbation() -> (
            Tuple[TensorOrTupleOfTensorsGeneric, TensorOrTupleOfTensorsGeneric]
        ):
            drawn = (
                next(sample_stream)
                if sample_stream is not None
                else self.perturb_func(inputs, **kwargs)
            )

            if not self.perturb_interpretable_space:
                # Sample lives in the input space: derive its interpretable form.
                as_interpretable = self.to_interp_rep_transform(  # type: ignore
                    drawn, inputs, **kwargs
                )
                return as_interpretable, drawn  # type: ignore

            # Sample lives in the interpretable space: map it back to the input.
            as_model_input = self.from_interp_rep_transform(  # type: ignore
                drawn, inputs, **kwargs
            )
            return drawn, as_model_input  # type: ignore

        return generate_perturbation

    # pyre-fixme[24] Generic type `Callable` expects 2 type parameters.)
    def attribute_future(self) -> Callable:
        r"""
        Asynchronous attribution is unavailable for this class; calling this
        method always raises NotImplementedError.
        """
        reason = "LimeBase does not support attribution of future samples."
        raise NotImplementedError(reason)

    def _evaluate_batch(
        self,
        curr_model_inputs: List[TensorOrTupleOfTensorsGeneric],
        expanded_target: TargetType,
        expanded_additional_args: object,
        device: torch.device,
    ) -> Tensor:
        """
        Run the forward function on the pending perturbed inputs and return
        its result as a 1-D tensor.

        A Tensor result must hold exactly one value per pending input and is
        returned flattened; any other result is wrapped into a one-element
        tensor placed on ``device``.
        """
        # TEMPORARY: the sliding-window forward function only works with a
        # single item (no batch), so only element 0 of the reduced inputs is
        # forwarded instead of the full `_reduce_list(curr_model_inputs)`.
        single_input = _reduce_list(curr_model_inputs)[0]
        raw_out = _run_forward(
            self.forward_func,
            single_input,
            expanded_target,
            expanded_additional_args,
        )

        if not isinstance(raw_out, Tensor):
            return torch.tensor([raw_out], device=device)

        assert raw_out.numel() == len(curr_model_inputs), (
            "Number of outputs is not appropriate, must return "
            "one output per perturbed input"
        )
        return raw_out.flatten()

    def has_convergence_delta(self) -> bool:
        """Report whether a convergence delta is available; always False here."""
        return False

    @property
    def multiplies_by_inputs(self) -> bool:
        """Read-only flag exposed as a property; always False for this class."""
        return False


# Default transformations and methods
# for Lime child implementation.


# pyre-fixme[3]: Return type must be annotated.
# pyre-fixme[2]: Parameter must be annotated.
def default_from_interp_rep_transform(curr_sample, original_inputs, **kwargs):
    r"""
    Map a binary interpretable sample back to the model input space.

    Every input element whose feature-mask index is switched on in
    ``curr_sample[0]`` keeps its original value; every other element is
    replaced by the corresponding baseline value.

    Args:
        curr_sample: Interpretable sample; ``curr_sample[0]`` is indexed with
            the feature mask to obtain a per-element on/off flag.
        original_inputs: A tensor, or a tuple of tensors when the feature
            mask is a tuple.
        **kwargs: Must contain ``feature_mask`` (tensor or tuple of tensors)
            and ``baselines`` (indexed per input when the mask is a tuple).

    Returns:
        A tensor when ``feature_mask`` is a tensor, otherwise a tuple with one
        transformed tensor per feature mask.

    Raises:
        AssertionError: If ``feature_mask`` or ``baselines`` is missing.
    """
    assert (
        "feature_mask" in kwargs
    ), "Must provide feature_mask to use default interpretable representation transform"
    assert (
        "baselines" in kwargs
    ), "Must provide baselines to use default interpretable representation transform"
    feature_mask = kwargs["feature_mask"]

    if not isinstance(feature_mask, Tensor):
        # Tuple case: first resolve one on/off mask per input tensor ...
        keep_flags = tuple(
            curr_sample[0][single_mask].bool() for single_mask in feature_mask
        )
        # ... then blend each original input with its own baseline.
        return tuple(
            keep.to(original_inputs[idx].dtype) * original_inputs[idx]
            + (~keep).to(original_inputs[idx].dtype) * kwargs["baselines"][idx]
            for idx, keep in enumerate(keep_flags)
        )

    # Single-tensor case.
    keep = curr_sample[0][feature_mask].bool()
    input_dtype = original_inputs.dtype
    kept_part = keep.to(input_dtype) * original_inputs
    baseline_part = (~keep).to(input_dtype) * kwargs["baselines"]
    return kept_part + baseline_part


def get_exp_kernel_similarity_function(
    distance_mode: str = "cosine",
    kernel_width: float = 1.0,
) -> Callable[..., float]:
    r"""
    Build a similarity function that weights perturbed samples in LIME.

    The returned callable flattens the original and the perturbed inputs
    (concatenating tuples of inputs if necessary), measures the distance
    between the two resulting vectors with the chosen distance mode, and
    passes that distance through an exponential kernel,
    ``exp(-distance**2 / (2 * kernel_width**2))``, to obtain a weight
    between 0 and 1.

    The callable returned can be provided as the similarity_fn for
    Lime or LimeBase.

    Args:

        distance_mode (str, optional): Either "cosine" (cosine distance) or
                    "euclidean" (Euclidean distance). The value is only
                    checked when the returned callable is invoked.
                    Default: "cosine"
        kernel_width (float, optional):
                    Kernel width for the exponential kernel applied to the
                    distance.
                    Default: 1.0

    Returns:

        *Callable*:
        - **similarity_fn** (*Callable*):
            Similarity function taking the original input, the perturbed
            input, and the interpretable sample (ignored), plus arbitrary
            keyword arguments (ignored). Calling it raises ``ValueError``
            when ``distance_mode`` is neither "cosine" nor "euclidean".
    """

    # pyre-fixme[3]: Return type must be annotated.
    # pyre-fixme[2]: Parameter must be annotated.
    def default_exp_kernel(original_inp, perturbed_inp, __, **kwargs):
        original_vec = _flatten_tensor_or_tuple(original_inp).float()
        perturbed_vec = _flatten_tensor_or_tuple(perturbed_inp).float()

        if distance_mode not in ("cosine", "euclidean"):
            raise ValueError("distance_mode must be either cosine or euclidean.")

        if distance_mode == "euclidean":
            distance = torch.norm(original_vec - perturbed_vec)
        else:
            similarity = CosineSimilarity(dim=0)
            distance = 1 - similarity(original_vec, perturbed_vec)

        squared_distance = distance**2
        return math.exp(-1 * squared_distance / (2 * (kernel_width**2)))

    return default_exp_kernel


def default_perturb_func(
    original_inp: TensorOrTupleOfTensorsGeneric, **kwargs: object
) -> Tensor:
    r"""
    Draw one random binary interpretable sample.

    Each of the ``num_interp_features`` entries is sampled independently from
    a Bernoulli distribution with probability 0.5.

    Args:
        original_inp: Original input tensor, or a tuple of tensors; only used
            to pick the device of the result (the first tensor's device for
            tuples).
        **kwargs: Must contain ``num_interp_features``.

    Returns:
        Tensor: Integer (``long``) tensor of shape
        ``1 x num_interp_features`` holding zeros and ones, placed on the
        device of the input.

    Raises:
        AssertionError: If ``num_interp_features`` is not provided.
    """
    assert (
        "num_interp_features" in kwargs
    ), "Must provide num_interp_features to use default interpretable sampling function"

    reference_tensor = (
        original_inp if isinstance(original_inp, Tensor) else original_inp[0]
    )
    target_device = reference_tensor.device

    feature_count = cast(int, kwargs["num_interp_features"])
    keep_probabilities = 0.5 * torch.ones(1, feature_count)
    # Sampling happens before the move to the target device.
    sampled = torch.bernoulli(keep_probabilities)
    return sampled.to(device=target_device).long()


def construct_feature_mask(
    feature_mask: Union[None, Tensor, Tuple[Tensor, ...]],
    formatted_inputs: Tuple[Tensor, ...],
) -> Tuple[Tuple[Tensor, ...], int]:
    r"""
    Normalise the feature mask and count the interpretable features.

    Args:
        feature_mask: ``None`` to build a default mask from
            ``formatted_inputs``, or a tensor / tuple of tensors of feature
            indices.
        formatted_inputs: Tuple of input tensors; only used when
            ``feature_mask`` is ``None``.

    Returns:
        A pair ``(feature_mask_tuple, num_interp_features)``. When a mask is
        given and its smallest index (over non-empty masks) is not 0, every
        mask is shifted so that indices start at 0, and a warning is emitted.
        ``num_interp_features`` is then the largest feature index plus one.
    """
    if feature_mask is None:
        # No mask supplied: delegate entirely to the default mask builder.
        default_masks, default_count = _construct_default_feature_mask(
            formatted_inputs
        )
        return default_masks, default_count

    masks: Tuple[Tensor, ...] = _format_tensor_into_tuples(feature_mask)

    # Empty masks are ignored when searching for the smallest index.
    per_mask_minima = [
        torch.min(single_mask).item() for single_mask in masks if single_mask.numel()
    ]
    lowest_index = int(min(per_mask_minima))

    if lowest_index != 0:
        warnings.warn(
            "Minimum element in feature mask is not 0, shifting indices to"
            " start at 0.",
            stacklevel=2,
        )
        masks = tuple(single_mask - lowest_index for single_mask in masks)

    feature_total = _get_max_feature_index(masks) + 1
    return masks, feature_total




  from .autonotebook import tqdm as notebook_tqdm


In [16]:
class LimeWithCustomArgumentToForwardFunc(LimeBaseWithCustomArgumentToForwardFunc):
    r"""
    Here we create a modification of Lime class from Captum Library (https://captum.ai/api/_modules/captum/attr/_core/lime.html)
    This will just inherit our modified LimeBase class
    """

    def __init__(
        self,
        forward_func: Callable[..., Tensor],
        interpretable_model: Optional[Model] = None,
        # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
        similarity_func: Optional[Callable] = None,
        # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
        perturb_func: Optional[Callable] = None,
    ) -> None:
        r"""
        Fill in defaults for every optional component and initialise the
        modified LimeBase parent class.

        Args:

            forward_func (Callable): The forward function of the model or any
                    modification of it.
            interpretable_model (Model, optional): Model object used to train
                    the interpretable (surrogate) model.

                    Defaults to SkLearnLasso(alpha=0.01), a wrapper around
                    the Lasso linear model in SkLearn. This requires having
                    sklearn version >= 0.23 available.

                    Other predefined interpretable linear models are provided
                    in captum._utils.models.linear_model.

                    A custom model object must provide a `fit` method which
                    trains the model given a dataloader whose batches contain
                    three tensors:

                    - interpretable_inputs: Tensor
                      [2D num_samples x num_interp_features],
                    - expected_outputs: Tensor [1D num_samples],
                    - weights: Tensor [1D num_samples]

                    It must also provide a `representation` method to access
                    the coefficients or representation of the interpretable
                    model after fitting.

                    Calling fit multiple times should retrain the model,
                    because each attribution call reuses the same given
                    interpretable model object.
            similarity_func (Callable, optional): Similarity kernel. Takes a
                    single sample along with its interpretable representation
                    and returns the weight of that sample for training the
                    interpretable model.

                    Defaults to an exponential kernel applied to the cosine
                    distance between the original and the perturbed input,
                    with a kernel width of 1.0, as built by
                    get_exp_kernel_similarity_function (which also supports
                    euclidean distance).

                    Expected signature of a custom callable:

                    >>> def similarity_func(
                    >>>    original_input: Tensor or tuple[Tensor, ...],
                    >>>    perturbed_input: Tensor or tuple[Tensor, ...],
                    >>>    perturbed_interpretable_input:
                    >>>        Tensor [2D 1 x num_interp_features],
                    >>>    **kwargs: Any
                    >>> ) -> float or Tensor containing float scalar

                    perturbed_input and original_input have the same type and
                    contain tensors of the same shape; original_input is the
                    input provided when calling attribute.

                    kwargs includes baselines, feature_mask,
                    num_interp_features (integer, determined from the feature
                    mask).
            perturb_func (Callable, optional): Function which returns a single
                    sampled input, i.e. a binary vector of length
                    num_interp_features, or a generator of such tensors.

                    Defaults to a function returning a binary vector whose
                    elements are selected independently and uniformly at
                    random. Expected signature of a custom callable:

                    >>> perturb_func(
                    >>>    original_input: Tensor or tuple[Tensor, ...],
                    >>>    **kwargs: Any
                    >>> ) -> Tensor [Binary 2D Tensor 1 x num_interp_features]
                    >>>  or generator yielding such tensors

                    kwargs includes baselines, feature_mask,
                    num_interp_features (integer, determined from the feature
                    mask).

        """
        surrogate_model = (
            SkLearnLasso(alpha=0.01)
            if interpretable_model is None
            else interpretable_model
        )
        similarity_kernel = (
            get_exp_kernel_similarity_function()
            if similarity_func is None
            else similarity_func
        )
        sample_generator = default_perturb_func if perturb_func is None else perturb_func

        # NOTE(review): the last three positional arguments presumably map to
        # perturb_interpretable_space, from_interp_rep_transform and
        # to_interp_rep_transform, as in Captum's LimeBase -- confirm against
        # the parent signature.
        LimeBaseWithCustomArgumentToForwardFunc.__init__(
            self,
            forward_func,
            surrogate_model,
            similarity_kernel,
            sample_generator,
            True,
            default_from_interp_rep_transform,
            None,
        )

    @log_usage(part_of_slo=True)
    def attribute(  # type: ignore
        self,
        inputs: TensorOrTupleOfTensorsGeneric,
        baselines: BaselineType = None,
        target: TargetType = None,
        additional_forward_args: Optional[object] = None,
        feature_mask: Union[None, Tensor, Tuple[Tensor, ...]] = None,
        n_samples: int = 25,
        perturbations_per_eval: int = 1,
        return_input_shape: bool = True,
        show_progress: bool = False,
    ) -> TensorOrTupleOfTensorsGeneric:
        r"""
        Attribute the model output (at the given target index if provided,
        otherwise the output is assumed to be a scalar) to the model inputs
        by training an interpretable model and returning a representation of
        that interpretable model.

        It is recommended to provide a single example as input (first
        dimension / batch size = 1), because LIME is generally used for
        sample-based interpretability: a separate interpretable model
        explains the prediction on each individual example.

        A batch of inputs can also be provided:

        - If forward_func returns a scalar per example, attributions are
          computed for each example independently, with a separate
          interpretable model trained per example. The similarity and
          perturbation functions are then given each example separately
          (first dimension = 1).
        - If forward_func returns a scalar per batch (e.g. a loss),
          attributions are computed with a single interpretable model for the
          full batch, and the similarity and perturbation functions are given
          the original input containing the full batch.

        The number of interpretable features is determined from the provided
        feature mask or, if none is provided, from the default feature mask,
        which treats each scalar input as a separate feature. It is generally
        recommended to provide a feature mask which groups features into a
        small number of interpretable components (e.g. superpixels in images).

        Args:

            inputs (Tensor or tuple[Tensor, ...]): Input for which LIME is
                        computed. Provide a single tensor if forward_func
                        takes one tensor, or a tuple of tensors if it takes
                        several. Dimension 0 of every tensor corresponds to
                        the number of examples, and multiple tensors must be
                        aligned on that dimension.
            baselines (scalar, Tensor, tuple of scalar, or Tensor, optional):
                        Reference value which replaces each feature when the
                        corresponding interpretable feature is set to 0.
                        Accepted forms:

                        - a single tensor (inputs is a single tensor) with
                          exactly the same dimensions as inputs, or with
                          first dimension one and the remaining dimensions
                          matching inputs;

                        - a single scalar (inputs is a single tensor),
                          broadcast to every value of the input tensor;

                        - a tuple of tensors or scalars, one per tensor in
                          the inputs' tuple, each being either a tensor
                          shaped as described above for its input tensor, or
                          a scalar broadcast over its input tensor.

                        When `baselines` is not provided, a zero scalar is
                        used internally for each input tensor.
                        Default: None
            target (int, tuple, Tensor, or list, optional): Output indices
                        for which the surrogate model is trained (for
                        classification this is usually the target class).
                        No target is necessary if the network returns a
                        scalar value per example.
                        For general 2D outputs, targets can be either:

                        - a single integer or a tensor containing a single
                          integer, applied to all input examples;

                        - a list of integers or a 1D tensor with one entry
                          per example in inputs (dim 0).

                        For outputs with > 2 dimensions, targets can be
                        either:

                        - a single tuple with #output_dims - 1 elements,
                          applied to all examples;

                        - a list of such tuples, one per example in inputs
                          (dim 0).

                        Default: None
            additional_forward_args (Any, optional): Extra arguments required
                        by the forward function for which attributions
                        should not be computed. Either a single argument (a
                        Tensor or any non-tuple type) or a tuple of several
                        arguments. They are passed to forward_func in order,
                        following the arguments in inputs.
                        For a tensor, the first dimension must correspond to
                        the number of examples. Arguments of any other type
                        are used as-is for all forward evaluations.
                        Default: None
            feature_mask (Tensor or tuple[Tensor, ...], optional):
                        Mask grouping input features which belong to the same
                        interpretable feature. It should contain the same
                        number of tensors as inputs, each the same size as
                        the corresponding input or broadcastable to it.
                        Values across all tensors should be integers in the
                        range 0 to num_interp_features - 1; elements with the
                        same value form one feature.
                        Features are grouped across tensors (unlike feature
                        ablation and occlusion): the same index used in
                        different tensors denotes one feature which is added
                        simultaneously.
                        If None, each scalar within a tensor becomes a
                        separate feature.
                        Default: None
            n_samples (int, optional): The number of samples of the original
                        model used to train the surrogate interpretable
                        model.
                        Default: 25
            perturbations_per_eval (int, optional): Allows multiple samples
                        to be processed in one call to forward_func. Each
                        forward pass contains at most
                        perturbations_per_eval * #examples samples. For
                        DataParallel models each batch is split among the
                        available devices, so each device evaluates at most
                        (perturbations_per_eval * #examples) / num_devices
                        samples.
                        Must be 1 if the forward function returns a single
                        scalar per batch.
                        Default: 1
            return_input_shape (bool, optional): If True, the returned
                        attributions match the input shape: each element
                        holds the coefficient of its interpretable feature,
                        so all elements sharing a feature-mask value share
                        one coefficient. If forward_func returns a single
                        element per batch, the first dimension of each
                        returned tensor is 1 and the remaining dimensions
                        match the original input tensor.
                        If False, only the coefficients of the trained
                        interpretable model are returned (see Returns).
                        Default: True
            show_progress (bool, optional): Displays the progress of the
                        computation. tqdm is used if available (e.g. for time
                        estimation); otherwise a simple progress output is
                        used.
                        Default: False

        Returns:
            *Tensor* or *tuple[Tensor, ...]* of **attributions**:
            - **attributions** (*Tensor* or *tuple[Tensor, ...]*):
                        The attributions with respect to each input feature.
                        If return_input_shape is True, attributions have the
                        same size as the provided inputs, each value being
                        the coefficient of the corresponding interpretable
                        feature.
                        If return_input_shape is False, a 1D tensor of
                        length num_interp_features is returned, containing
                        only the coefficients of the trained interpretable
                        model.

        Examples::

            >>> # SimpleClassifier takes a single input tensor of size Nx4x4,
            >>> # and returns an Nx3 tensor of class probabilities.
            >>> net = SimpleClassifier()

            >>> # Generating random input with size 1 x 4 x 4
            >>> input = torch.randn(1, 4, 4)

            >>> # Defining Lime interpreter
            >>> lime = Lime(net)
            >>> # Computes attribution, with each of the 4 x 4 = 16
            >>> # features as a separate interpretable feature
            >>> attr = lime.attribute(input, target=1, n_samples=200)

            >>> # Alternatively, each 2x2 square of the inputs can be grouped
            >>> # into one 'interpretable' feature and perturbed together,
            >>> # using a feature mask which defines the groups, e.g.:
            >>> # +---+---+---+---+
            >>> # | 0 | 0 | 1 | 1 |
            >>> # +---+---+---+---+
            >>> # | 0 | 0 | 1 | 1 |
            >>> # +---+---+---+---+
            >>> # | 2 | 2 | 3 | 3 |
            >>> # +---+---+---+---+
            >>> # | 2 | 2 | 3 | 3 |
            >>> # +---+---+---+---+
            >>> # All inputs with the same mask value are set to their
            >>> # baseline value when the corresponding binary interpretable
            >>> # feature is set to 0.
            >>> # feature mask has dimensions 1 x 4 x 4
            >>> feature_mask = torch.tensor([[[0,0,1,1],[0,0,1,1],
            >>>                             [2,2,3,3],[2,2,3,3]]])

            >>> # Computes interpretable model and returns attributions
            >>> # matching input shape.
            >>> attr = lime.attribute(input, target=1, feature_mask=feature_mask)
        """
        # Every public argument is handed to the shared implementation by
        # keyword, so the call does not depend on its parameter order.
        forwarded_arguments = dict(
            inputs=inputs,
            baselines=baselines,
            target=target,
            additional_forward_args=additional_forward_args,
            feature_mask=feature_mask,
            n_samples=n_samples,
            perturbations_per_eval=perturbations_per_eval,
            return_input_shape=return_input_shape,
            show_progress=show_progress,
        )
        return self._attribute_kwargs(**forwarded_arguments)

    # pyre-fixme[24] Generic type `Callable` expects 2 type parameters.
    def attribute_future(self) -> Callable:
        r"""
        Delegate to the parent class's ``attribute_future`` and return its
        result (or propagate whatever it raises).
        """
        parent_attribute_future = super().attribute_future
        return parent_attribute_future()

    def _attribute_kwargs(  # type: ignore
        self,
        inputs: TensorOrTupleOfTensorsGeneric,
        baselines: BaselineType = None,
        target: TargetType = None,
        additional_forward_args: Optional[object] = None,
        feature_mask: Union[None, Tensor, Tuple[Tensor, ...]] = None,
        n_samples: int = 25,
        perturbations_per_eval: int = 1,
        return_input_shape: bool = True,
        monitor_log_path: str | None = None,
        monitor_convergence_step: int | None = 20,
        monitor_local_accuracy_step: int | None = 50,
        show_progress: bool = False,
        **kwargs: object,
    ) -> TensorOrTupleOfTensorsGeneric:
        r"""
        Shared implementation behind ``attribute``: trains the surrogate
        model(s) through the parent ``attribute`` and optionally maps the
        coefficients back onto the input shape.

        Control flow, with ``bsz`` the first dimension of the first input:

        - ``bsz == 1``: a single surrogate model is trained.
        - ``bsz > 1``: the forward function is first run once on the full
          batch to inspect its output.

          - Tensor output with exactly ``bsz`` elements: one surrogate model
            is trained per example and the per-example results are combined
            with ``_reduce_list``.
          - Tensor output with more than one element but not ``bsz``
            elements: an ``AssertionError`` is raised.
          - Anything else (one value per batch): a single surrogate model is
            trained for the whole batch; ``perturbations_per_eval`` must
            be 1.

        Args:
            inputs, baselines, target, additional_forward_args, feature_mask,
            n_samples, perturbations_per_eval, return_input_shape,
            show_progress: Same meaning as in ``attribute``.
            monitor_log_path, monitor_convergence_step,
            monitor_local_accuracy_step: Forwarded unchanged to the parent
                ``attribute`` in the single-model path; their semantics are
                defined there.
            **kwargs: Extra keyword arguments forwarded to the parent
                ``attribute``.

        Returns:
            Attributions shaped like the inputs when ``return_input_shape``
            is True, otherwise the surrogate coefficients (one row per
            example in the per-example path).

        Raises:
            AssertionError: If the forward function returns an unsupported
                number of outputs, or if ``perturbations_per_eval != 1``
                while it returns a single value per batch.
        """
        is_inputs_tuple = _is_tuple(inputs)
        formatted_inputs, baselines = _format_input_baseline(inputs, baselines)
        bsz = formatted_inputs[0].shape[0]

        feature_mask, num_interp_features = construct_feature_mask(
            feature_mask, formatted_inputs
        )

        if num_interp_features > 10000:
            # Adjacent string literals are concatenated verbatim, so each
            # fragment carries its own trailing space.
            warnings.warn(
                "Attempting to construct interpretable model with > 10000 features. "
                "This can be very slow or lead to OOM issues. Please provide a feature "
                "mask which groups input features to reduce the number of interpretable "
                "features.",
                stacklevel=1,
            )

        coefs: Tensor
        if bsz > 1:
            # Probe the forward function once on the whole batch to find out
            # whether it yields one value per example or one value per batch.
            test_output = _run_forward(
                self.forward_func, inputs, target, additional_forward_args
            )
            if isinstance(test_output, Tensor) and torch.numel(test_output) > 1:
                if torch.numel(test_output) == bsz:
                    warnings.warn(
                        "You are providing multiple inputs for Lime / Kernel SHAP "
                        "attributions. This trains a separate interpretable model "
                        "for each example, which can be time consuming. It is "
                        "recommended to compute attributions for one example at a "
                        "time.",
                        stacklevel=1,
                    )
                    output_list = []
                    for (
                        curr_inps,
                        curr_target,
                        curr_additional_args,
                        curr_baselines,
                        curr_feature_mask,
                    ) in _batch_example_iterator(
                        bsz,
                        formatted_inputs,
                        target,
                        # Additional forward args can also be batched
                        # automatically by the library iterator.
                        additional_forward_args,
                        baselines,
                        feature_mask,
                    ):
                        # NOTE(review): the monitor_* arguments are not
                        # forwarded in this per-example path, unlike the
                        # single-model call below -- confirm this is intended.
                        coefs = super().attribute.__wrapped__(
                            self,
                            inputs=curr_inps if is_inputs_tuple else curr_inps[0],
                            target=curr_target,
                            additional_forward_args=curr_additional_args,
                            n_samples=n_samples,
                            perturbations_per_eval=perturbations_per_eval,
                            baselines=(
                                curr_baselines if is_inputs_tuple else curr_baselines[0]
                            ),
                            feature_mask=(
                                curr_feature_mask
                                if is_inputs_tuple
                                else curr_feature_mask[0]
                            ),
                            num_interp_features=num_interp_features,
                            show_progress=show_progress,
                            **kwargs,
                        )
                        if return_input_shape:
                            output_list.append(
                                self._convert_output_shape(
                                    curr_inps,
                                    curr_feature_mask,
                                    coefs,
                                    num_interp_features,
                                    is_inputs_tuple,
                                )
                            )
                        else:
                            output_list.append(coefs.reshape(1, -1))  # type: ignore

                    return _reduce_list(output_list)
                else:
                    raise AssertionError(
                        "Invalid number of outputs, forward function should return a "
                        "scalar per example or a scalar per input batch."
                    )
            else:
                assert perturbations_per_eval == 1, (
                    "Perturbations per eval must be 1 when forward function "
                    "returns single value per batch!"
                )

        # Single surrogate model: either one example, or one value per batch.
        coefs = super().attribute.__wrapped__(
            self,
            inputs=inputs,
            target=target,
            additional_forward_args=additional_forward_args,
            n_samples=n_samples,
            perturbations_per_eval=perturbations_per_eval,
            baselines=baselines if is_inputs_tuple else baselines[0],
            feature_mask=feature_mask if is_inputs_tuple else feature_mask[0],
            num_interp_features=num_interp_features,
            monitor_log_path=monitor_log_path,
            monitor_convergence_step=monitor_convergence_step,
            monitor_local_accuracy_step=monitor_local_accuracy_step,
            show_progress=show_progress,
            **kwargs,
        )
        if return_input_shape:
            # pyre-fixme[7]: Expected `TensorOrTupleOfTensorsGeneric` but got
            #  `Tuple[Tensor, ...]`.
            return self._convert_output_shape(
                formatted_inputs,
                feature_mask,
                coefs,
                num_interp_features,
                is_inputs_tuple,
                # One value per batch: keep only a leading dimension of 1.
                leading_dim_one=(bsz > 1),
            )
        else:
            return coefs

    # Typing-only overloads for `_convert_output_shape`: the declared return
    # type follows `is_inputs_tuple`.
    #   Literal[True]  -> Tuple[Tensor, ...]
    #   Literal[False] -> Tensor
    #   plain bool     -> Union[Tensor, Tuple[Tensor, ...]]
    # The stubs carry no behaviour; the implementation follows them.
    @typing.overload
    def _convert_output_shape(
        self,
        formatted_inp: Tuple[Tensor, ...],
        feature_mask: Tuple[Tensor, ...],
        coefs: Tensor,
        num_interp_features: int,
        is_inputs_tuple: Literal[True],
        leading_dim_one: bool = False,
    ) -> Tuple[Tensor, ...]: ...

    @typing.overload
    def _convert_output_shape(  # type: ignore
        self,
        formatted_inp: Tuple[Tensor, ...],
        feature_mask: Tuple[Tensor, ...],
        coefs: Tensor,
        num_interp_features: int,
        is_inputs_tuple: Literal[False],
        leading_dim_one: bool = False,
    ) -> Tensor: ...

    @typing.overload
    def _convert_output_shape(
        self,
        formatted_inp: Tuple[Tensor, ...],
        feature_mask: Tuple[Tensor, ...],
        coefs: Tensor,
        num_interp_features: int,
        is_inputs_tuple: bool,
        leading_dim_one: bool = False,
    ) -> Union[Tensor, Tuple[Tensor, ...]]: ...

    def _convert_output_shape(
        self,
        formatted_inp: Tuple[Tensor, ...],
        feature_mask: Tuple[Tensor, ...],
        coefs: Tensor,
        num_interp_features: int,
        is_inputs_tuple: bool,
        leading_dim_one: bool = False,
    ) -> Union[Tensor, Tuple[Tensor, ...]]:
        r"""
        Spread the surrogate coefficients over tensors shaped like the inputs.

        For every input tensor a float tensor of zeros with the same shape is
        created; each element whose feature-mask value equals feature index
        ``k`` receives ``coefs[k]``.

        Args:
            formatted_inp: Tuple of input tensors defining the output shapes.
            feature_mask: Tuple of feature-index tensors, one per input; each
                is compared element-wise against every feature index and
                accumulated in place, so it must broadcast to its input.
            coefs: Surrogate coefficients; flattened before use and indexed
                by feature index.
            num_interp_features: Number of feature indices to iterate over.
            is_inputs_tuple: Passed to ``_format_output`` to select the
                return container.
            leading_dim_one: If True, only the first slice along dimension 0
                (``[0:1]``) of every attribution tensor is kept.

        Returns:
            The result of ``_format_output(is_inputs_tuple, attributions)``.
        """
        # Read all coefficients in one `tolist()` call instead of calling
        # `.item()` once per (tensor, feature) pair inside the nested loop.
        coef_values = coefs.flatten().tolist()

        attr = [
            torch.zeros_like(single_inp, dtype=torch.float)
            for single_inp in formatted_inp
        ]
        for tensor_ind in range(len(formatted_inp)):
            for single_feature in range(num_interp_features):
                # In-place accumulation keeps the attribution in the input's
                # shape while the (possibly smaller) mask is broadcast.
                attr[tensor_ind] += (
                    coef_values[single_feature]
                    * (feature_mask[tensor_ind] == single_feature).float()
                )
        if leading_dim_one:
            attr = [single_attr[0:1] for single_attr in attr]
        return _format_output(is_inputs_tuple, tuple(attr))


In [17]:
#!/usr/bin/env python3

# pyre-strict

from typing import Callable, cast, Generator, Optional, Tuple, Union

import torch
from captum._utils.models.linear_model import SkLearnLinearRegression
from captum._utils.typing import BaselineType, TargetType, TensorOrTupleOfTensorsGeneric
from captum.attr._core.lime import construct_feature_mask, Lime
from captum.attr._utils.common import _format_input_baseline
from captum.log import log_usage
from torch import Tensor
from torch.distributions.categorical import Categorical


class KernelShapWithMask(LimeWithCustomArgumentToForwardFunc):
    r"""
    Kernel SHAP is a method that uses the LIME framework to compute
    Shapley Values. Setting the loss function, weighting kernel and
    regularization terms appropriately in the LIME framework allows
    theoretically obtaining Shapley Values more efficiently than
    directly computing Shapley Values.

    More information regarding this method and proof of equivalence
    can be found in the original paper here:
    https://arxiv.org/abs/1705.07874
    """

    def __init__(
        self,
        forward_func: Callable[..., Tensor],
        surrogate_model: str = "linear regression",
        alpha_surrogate: float = 0.01,
        max_iter_surrogate: int = 1000,
    ) -> None:
        r"""
        Args:

            forward_func (Callable): The forward function of the model or
                        any modification of it.
            surrogate_model (str, optional): Interpretable model fitted on the
                        perturbed samples. Either ``"linear regression"`` or
                        ``"lasso"``.
                        Default: ``"linear regression"``
            alpha_surrogate (float, optional): ``alpha`` passed to the lasso
                        surrogate. Only used when ``surrogate_model="lasso"``.
                        Default: 0.01
            max_iter_surrogate (int, optional): ``max_iter`` passed to the
                        lasso surrogate. Only used when
                        ``surrogate_model="lasso"``.
                        Default: 1000

        Raises:
            ValueError: If ``surrogate_model`` is not one of the supported
                        names.
        """
        if surrogate_model == "linear regression":
            interpretable_model = SkLearnLinearRegression()
        elif surrogate_model == "lasso":
            # Imported locally: this cell's import list only brings in
            # SkLearnLinearRegression, so do not rely on another cell for it.
            from captum._utils.models.linear_model import SkLearnLasso

            interpretable_model = SkLearnLasso(
                alpha=alpha_surrogate, max_iter=max_iter_surrogate
            )
        else:
            # Without this branch `interpretable_model` would be unbound and
            # the constructor would fail with an opaque UnboundLocalError.
            raise ValueError(
                f"Unsupported surrogate_model {surrogate_model!r}; "
                "expected 'linear regression' or 'lasso'."
            )

        LimeWithCustomArgumentToForwardFunc.__init__(
            self,
            forward_func,
            interpretable_model=interpretable_model,
            similarity_func=self.kernel_shap_similarity_kernel,
            perturb_func=self.kernel_shap_perturb_generator,
        )
        # Finite stand-in for the theoretically infinite weight of the
        # all-zeros / all-ones samples (see kernel_shap_similarity_kernel).
        self.inf_weight = 1000000.0

    @log_usage(part_of_slo=True)
    def attribute(  # type: ignore
        self,
        inputs: TensorOrTupleOfTensorsGeneric,
        baselines: BaselineType = None,
        target: TargetType = None,
        additional_forward_args: Optional[object] = None,
        feature_mask: Union[None, Tensor, Tuple[Tensor, ...]] = None,
        n_samples: int = 25,
        perturbations_per_eval: int = 1,
        return_input_shape: bool = True,
        monitor_log_path: str | None = None,
        monitor_convergence_step: int | None = 20,
        monitor_local_accuracy_step: int | None = 50,
        show_progress: bool = False,
    ) -> TensorOrTupleOfTensorsGeneric:
        r"""
        This method attributes the output of the model with given target index
        (in case it is provided, otherwise it assumes that output is a
        scalar) to the inputs of the model using the approach described above,
        training an interpretable model based on KernelSHAP and returning a
        representation of the interpretable model.

        It is recommended to only provide a single example as input (tensors
        with first dimension or batch size = 1). This is because LIME / KernelShap
        is generally used for sample-based interpretability, training a separate
        interpretable model to explain a model's prediction on each individual example.

        A batch of inputs can also be provided as inputs, similar to
        other perturbation-based attribution methods. In this case, if forward_fn
        returns a scalar per example, attributions will be computed for each
        example independently, with a separate interpretable model trained for each
        example. Note that provided similarity and perturbation functions will be
        provided each example separately (first dimension = 1) in this case.
        If forward_fn returns a scalar per batch (e.g. loss), attributions will
        still be computed using a single interpretable model for the full batch.
        In this case, similarity and perturbation functions will be provided the
        same original input containing the full batch.

        The number of interpretable features is determined from the provided
        feature mask, or if none is provided, from the default feature mask,
        which considers each scalar input as a separate feature. It is
        generally recommended to provide a feature mask which groups features
        into a small number of interpretable features / components (e.g.
        superpixels in images).

        Args:

            inputs (Tensor or tuple[Tensor, ...]): Input for which KernelShap
                        is computed. If forward_func takes a single
                        tensor as input, a single input tensor should be provided.
                        If forward_func takes multiple tensors as input, a tuple
                        of the input tensors should be provided. It is assumed
                        that for all given input tensors, dimension 0 corresponds
                        to the number of examples, and if multiple input tensors
                        are provided, the examples must be aligned appropriately.
            baselines (scalar, Tensor, tuple of scalar, or Tensor, optional):
                        Baselines define the reference value which replaces each
                        feature when the corresponding interpretable feature
                        is set to 0.
                        Baselines can be provided as:

                        - a single tensor, if inputs is a single tensor, with
                          exactly the same dimensions as inputs or the first
                          dimension is one and the remaining dimensions match
                          with inputs.

                        - a single scalar, if inputs is a single tensor, which will
                          be broadcasted for each input value in input tensor.

                        - a tuple of tensors or scalars, the baseline corresponding
                          to each tensor in the inputs' tuple can be:

                          - either a tensor with matching dimensions to
                            corresponding tensor in the inputs' tuple
                            or the first dimension is one and the remaining
                            dimensions match with the corresponding
                            input tensor.

                          - or a scalar, corresponding to a tensor in the
                            inputs' tuple. This scalar value is broadcasted
                            for corresponding input tensor.

                        In the cases when `baselines` is not provided, we internally
                        use zero scalar corresponding to each input tensor.
                        Default: None
            target (int, tuple, Tensor, or list, optional): Output indices for
                        which surrogate model is trained
                        (for classification cases,
                        this is usually the target class).
                        If the network returns a scalar value per example,
                        no target index is necessary.
                        For general 2D outputs, targets can be either:

                        - a single integer or a tensor containing a single
                          integer, which is applied to all input examples

                        - a list of integers or a 1D tensor, with length matching
                          the number of examples in inputs (dim 0). Each integer
                          is applied as the target for the corresponding example.

                        For outputs with > 2 dimensions, targets can be either:

                        - A single tuple, which contains #output_dims - 1
                          elements. This target index is applied to all examples.

                        - A list of tuples with length equal to the number of
                          examples in inputs (dim 0), and each tuple containing
                          #output_dims - 1 elements. Each tuple is applied as the
                          target for the corresponding example.

                        Default: None
            additional_forward_args (Any, optional): If the forward function
                        requires additional arguments other than the inputs for
                        which attributions should not be computed, this argument
                        can be provided. It must be either a single additional
                        argument of a Tensor or arbitrary (non-tuple) type or a
                        tuple containing multiple additional arguments including
                        tensors or any arbitrary python types. These arguments
                        are provided to forward_func in order following the
                        arguments in inputs.
                        For a tensor, the first dimension of the tensor must
                        correspond to the number of examples. It will be
                        repeated for each of `n_steps` along the integrated
                        path. For all other types, the given argument is used
                        for all forward evaluations.
                        Note that attributions are not computed with respect
                        to these arguments.
                        Default: None
            feature_mask (Tensor or tuple[Tensor, ...], optional):
                        feature_mask defines a mask for the input, grouping
                        features which correspond to the same
                        interpretable feature. feature_mask
                        should contain the same number of tensors as inputs.
                        Each tensor should
                        be the same size as the corresponding input or
                        broadcastable to match the input tensor. Values across
                        all tensors should be integers in the range 0 to
                        num_interp_features - 1, and indices corresponding to the
                        same feature should have the same value.
                        Note that features are grouped across tensors
                        (unlike feature ablation and occlusion), so
                        if the same index is used in different tensors, those
                        features are still grouped and added simultaneously.
                        If None, then a feature mask is constructed which assigns
                        each scalar within a tensor as a separate feature.
                        Default: None
            n_samples (int, optional): The number of samples of the original
                        model used to train the surrogate interpretable model.
                        Default: 25
            perturbations_per_eval (int, optional): Allows multiple samples
                        to be processed simultaneously in one call to forward_fn.
                        Each forward pass will contain a maximum of
                        perturbations_per_eval * #examples samples.
                        For DataParallel models, each batch is split among the
                        available devices, so evaluations on each available
                        device contain at most
                        (perturbations_per_eval * #examples) / num_devices
                        samples.
                        If the forward function returns a single scalar per batch,
                        perturbations_per_eval must be set to 1.
                        Default: 1
            return_input_shape (bool, optional): Determines whether the returned
                        tensor(s) only contain the coefficients for each interp-
                        retable feature from the trained surrogate model, or
                        whether the returned attributions match the input shape.
                        When return_input_shape is True, the return type of attribute
                        matches the input shape, with each element containing the
                        coefficient of the corresponding interpretable feature.
                        All elements with the same value in the feature mask
                        will contain the same coefficient in the returned
                        attributions. If return_input_shape is False, a 1D
                        tensor is returned, containing only the coefficients
                        of the trained interpretable model, with length
                        num_interp_features.
                        Default: True
            monitor_log_path (str, optional): Path to the log file for monitoring
                        convergence. If None, no monitoring is performed.
                        Default: None
            monitor_convergence_step (int, optional): Number of iterations over
                        which the difference between two attributions is
                        computed.
                        Default: 20
            monitor_local_accuracy_step (int, optional): Number of iterations
                        over which the local accuracy of an attribution is
                        computed.
                        Default: 50
            show_progress (bool, optional): Displays the progress of computation.
                        It will try to use tqdm if available for advanced features
                        (e.g. time estimation). Otherwise, it will fallback to
                        a simple output of progress.
                        Default: False

        Returns:
            *Tensor* or *tuple[Tensor, ...]* of **attributions**:
            - **attributions** (*Tensor* or *tuple[Tensor, ...]*):
                        The attributions with respect to each input feature.
                        If return_input_shape = True, attributions will be
                        the same size as the provided inputs, with each value
                        providing the coefficient of the corresponding
                        interpretable feature.
                        If return_input_shape is False, a 1D
                        tensor is returned, containing only the coefficients
                        of the trained interpretable models, with length
                        num_interp_features.
        Examples::
            >>> # SimpleClassifier takes a single input tensor of size Nx4x4,
            >>> # and returns an Nx3 tensor of class probabilities.
            >>> net = SimpleClassifier()

            >>> # Generating random input with size 1 x 4 x 4
            >>> input = torch.randn(1, 4, 4)

            >>> # Defining KernelShap interpreter
            >>> ks = KernelShapWithMask(net)
            >>> # Computes attribution, with each of the 4 x 4 = 16
            >>> # features as a separate interpretable feature
            >>> attr = ks.attribute(input, target=1, n_samples=200)

            >>> # Alternatively, we can group each 2x2 square of the inputs
            >>> # as one 'interpretable' feature and perturb them together.
            >>> # This can be done by creating a feature mask as follows, which
            >>> # defines the feature groups, e.g.:
            >>> # +---+---+---+---+
            >>> # | 0 | 0 | 1 | 1 |
            >>> # +---+---+---+---+
            >>> # | 0 | 0 | 1 | 1 |
            >>> # +---+---+---+---+
            >>> # | 2 | 2 | 3 | 3 |
            >>> # +---+---+---+---+
            >>> # | 2 | 2 | 3 | 3 |
            >>> # +---+---+---+---+
            >>> # With this mask, all inputs with the same value are set to their
            >>> # baseline value, when the corresponding binary interpretable
            >>> # feature is set to 0.
            >>> # The attributions can be calculated as follows:
            >>> # feature mask has dimensions 1 x 4 x 4
            >>> feature_mask = torch.tensor([[[0,0,1,1],[0,0,1,1],
            >>>                             [2,2,3,3],[2,2,3,3]]])

            >>> # Computes KernelSHAP attributions with feature mask.
            >>> attr = ks.attribute(input, target=1, feature_mask=feature_mask)
        """
        formatted_inputs, baselines = _format_input_baseline(inputs, baselines)
        feature_mask, num_interp_features = construct_feature_mask(
            feature_mask, formatted_inputs
        )
        # Distribution over k = number of selected features:
        # p(k) proportional to (M - 1) / (k * (M - k)) for k = 1..M-1.
        num_features_list = torch.arange(num_interp_features, dtype=torch.float)
        denom = num_features_list * (num_interp_features - num_features_list)
        probs = torch.tensor((num_interp_features - 1)) / denom
        # k = 0 divides by zero above; the all-zeros sample is emitted
        # explicitly by the perturbation generator, so it is never drawn here.
        probs[0] = 0.0
        return self._attribute_kwargs(
            inputs=inputs,
            baselines=baselines,
            target=target,
            additional_forward_args=additional_forward_args,
            feature_mask=feature_mask,
            n_samples=n_samples,
            perturbations_per_eval=perturbations_per_eval,
            return_input_shape=return_input_shape,
            num_select_distribution=Categorical(probs),
            monitor_log_path=monitor_log_path,
            monitor_convergence_step=monitor_convergence_step,
            monitor_local_accuracy_step=monitor_local_accuracy_step,
            show_progress=show_progress,
        )

    # pyre-fixme[24] Generic type `Callable` expects 2 type parameters.
    def attribute_future(self) -> Callable:
        r"""
        This method is not implemented for KernelShap.
        """
        raise NotImplementedError("attribute_future is not implemented for KernelShap")

    def kernel_shap_similarity_kernel(
        self,
        _,
        __,
        interpretable_sample: Tensor,
        **kwargs: object,
    ) -> Tensor:
        r"""
        Sample weight used when fitting the surrogate model.

        The first two positional arguments are ignored. ``interpretable_sample``
        must hold a single binary row (the row sum is read with ``.item()``),
        and ``num_interp_features`` must be supplied as a keyword argument.

        Returns:
            A 1-element tensor: ``self.inf_weight`` when no feature or every
            feature is selected, ``1.0`` otherwise.
        """
        assert (
            "num_interp_features" in kwargs
        ), "Must provide num_interp_features to use default similarity kernel"
        num_selected_features = int(interpretable_sample.sum(dim=1).item())
        num_features = kwargs["num_interp_features"]
        if num_selected_features == 0 or num_selected_features == num_features:
            # weight should be theoretically infinite when
            # num_selected_features = 0 or num_features
            # enforcing that trained linear model must satisfy
            # end-point criteria. In practice, it is sufficient to
            # make this weight substantially larger so setting this
            # weight to 1000000 (all other weights are 1).
            similarities = self.inf_weight
        else:
            similarities = 1.0
        return torch.tensor([similarities])

    def kernel_shap_perturb_generator(
        self,
        original_inp: Union[Tensor, Tuple[Tensor, ...]],
        **kwargs: object,
    ) -> Generator[Tensor, None, None]:
        r"""
        Perturbations are sampled by the following process:
         - Choose k (number of selected features), based on the distribution
                p(k) = (M - 1) / (k * (M - k))

            where M is the total number of features in the interpretable space

         - Randomly select a binary vector with k ones, each sample is equally
            likely. This is done by generating a random vector of normal
            values and thresholding based on the top k elements.

         Since there are M choose k vectors with k ones, this weighted sampling
         is equivalent to applying the Shapley kernel for the sample weight,
         defined as:
         k(M, k) = (M - 1) / (k * (M - k) * (M choose k))

         The first two samples yielded are always the all-ones and the
         all-zeros vectors; random samples follow indefinitely.
        """
        assert (
            "num_select_distribution" in kwargs and "num_interp_features" in kwargs
        ), (
            "num_select_distribution and num_interp_features are necessary"
            " to use kernel_shap_perturb_func"
        )
        if isinstance(original_inp, Tensor):
            device = original_inp.device
        else:
            device = original_inp[0].device
        num_features = cast(int, kwargs["num_interp_features"])
        # End points first, so the surrogate always sees both of them.
        yield torch.ones(1, num_features, device=device, dtype=torch.long)
        yield torch.zeros(1, num_features, device=device, dtype=torch.long)
        while True:
            num_selected_features = cast(
                Categorical, kwargs["num_select_distribution"]
            ).sample()
            rand_vals = torch.randn(1, num_features)
            # Values strictly above the (M - k)-th smallest are the top k.
            threshold = torch.kthvalue(
                rand_vals, num_features - num_selected_features
            ).values.item()
            yield (rand_vals > threshold).to(device=device).long()

   

In [18]:
# define an utility for annoying nnunetv2 preprocessing
def nnunetv2_default_preprocessing(ct_img_path, predictor, dataset_json_path) -> np.ndarray:
    """
    Read one image and run it through the preprocessor configured on `predictor`.

    Parameters
    ----------
    ct_img_path : str or Path
        Path of the image to read; passed to the reader as a one-element list.
    predictor : object
        Must expose `plans_manager` and `configuration_manager`.
    dataset_json_path : object
        Forwarded unchanged as the `dataset_json` argument of `run_case_npy`.

    Returns
    -------
    np.ndarray
        First element of the 3-tuple returned by `run_case_npy`.
    """
    plans = predictor.plans_manager
    config = predictor.configuration_manager

    case_preprocessor = config.preprocessor_class(verbose=False)

    reader = plans.image_reader_writer_class()
    # If the call above handed back something that is still callable and has
    # no `read_images`, call it once more to obtain the actual reader.
    needs_instantiation = callable(reader) and not hasattr(reader, "read_images")
    if needs_instantiation:
        reader = reader()

    image_array, image_properties = reader.read_images([str(ct_img_path)])

    run_kwargs = dict(
        seg=None,
        properties=image_properties,
        plans_manager=plans,
        configuration_manager=config,
        dataset_json=dataset_json_path,
    )
    preprocessed, _, _ = case_preprocessor.run_case_npy(image_array, **run_kwargs)
    return preprocessed

# Next step: define regular, fixed size superpixels and try to compute the attributions of each of them
We also need to define a scalar metric to compare perturbations, because explaining a segmentation, unlike explaining a classification, is intrinsically ambiguous: there is no single output score. For example, we can select a priori a single region of the segmentation output and use the average over its pixels to compute the impact of perturbations.

### 4. Face-centered cubic (FCC) lattice induced supervoxel assignment
-> more *isotropic* than simple cubes

### 4.1 affine transformation to translate isotropy from voxel space into the original geometrical space (measured in mm)

### 4.2 Apply the original FCC algorithm to an affine-transformed volume that has the same proportions as the .nii in physical space: transform -> apply the algorithm to derive the map -> transform the map back onto voxel space



In [19]:
import numpy as np
import nibabel as nib
from scipy.spatial import cKDTree

def generate_supervoxel_map(img, S=200.0):
    """
    Generate a supervoxel map using FCC tessellation in physical space.

    Lattice centers are placed on a face-centered cubic grid (cell edge
    ``a = S * sqrt(2)``) anchored at the world origin; only centers that fall
    inside the axis-aligned physical bounding box of the volume are kept, and
    every voxel is labelled with the index of its nearest kept center.

    Only ``img.shape`` and ``img.affine`` are read: the voxel data itself is
    never loaded, which avoids materialising the whole volume in memory.

    Args:
        img: Nibabel NIfTI image object (3D) exposing ``shape`` and ``affine``.
        S (float): Desired supervoxel size in millimeters (default: 200.0).
            Must be positive and finite.

    Returns:
        supervoxel_map: 3D NumPy int32 array with the same shape as the image,
        holding one integer label per voxel.

    Raises:
        ValueError: If ``S`` is not a positive finite number, if the image is
            not 3D, or if no lattice center falls inside the volume.
    """
    if not np.isfinite(S) or S <= 0:
        # S = 0 would divide by zero below; a negative S flips the lattice.
        raise ValueError(f"S must be a positive, finite size in millimeters, got {S!r}.")

    grid_shape = tuple(int(n) for n in img.shape)
    if len(grid_shape) != 3:
        raise ValueError(f"Expected a 3D image, got shape {grid_shape}.")
    W, H, D = grid_shape

    affine = np.asarray(img.affine, dtype=float)
    linear = affine[:3, :3]   # rotation / zoom part
    origin = affine[:3, 3]    # translation part

    def to_physical(voxel_ijk):
        # Same mapping as affine @ [i, j, k, 1], without building the
        # homogeneous (N, 4) array.
        return voxel_ijk @ linear.T + origin

    # Compute the physical bounding box of the volume from its 8 corners
    corners_voxel = np.array(
        [[i, j, k] for i in (0, W - 1) for j in (0, H - 1) for k in (0, D - 1)]
    )
    corners_physical = to_physical(corners_voxel)
    min_xyz = corners_physical.min(axis=0)
    max_xyz = corners_physical.max(axis=0)

    # Generate FCC lattice centers in physical space (one index of slack on
    # each side; out-of-box centers are discarded below)
    a = S * np.sqrt(2)
    factor = 2 / a
    index_ranges = [
        np.arange(int(np.floor(factor * lo)) - 1, int(np.ceil(factor * hi)) + 2)
        for lo, hi in zip(min_xyz, max_xyz)
    ]
    P, Q, R = (grid.ravel() for grid in np.meshgrid(*index_ranges, indexing='ij'))

    # Filter for FCC lattice points (sum of indices is even)
    on_lattice = (P + Q + R) % 2 == 0
    centers = np.column_stack((P[on_lattice], Q[on_lattice], R[on_lattice])) * (a / 2)

    # Keep only centers within the bounding box
    inside = np.all((centers >= min_xyz) & (centers <= max_xyz), axis=1)
    centers = centers[inside]

    # Check if any centers were generated
    print(f"Number of supervoxel centers: {len(centers)}")
    if len(centers) == 0:
        raise ValueError("No supervoxel centers generated. Try reducing S.")

    # Generate voxel indices and transform to physical coordinates
    voxel_indices = np.indices((W, H, D)).reshape(3, -1).T  # shape (W*H*D, 3)
    physical_coords = to_physical(voxel_indices)            # shape (W*H*D, 3)

    # Assign each voxel to the nearest supervoxel center
    tree = cKDTree(centers)
    _, labels = tree.query(physical_coords)

    # Create the supervoxel map
    supervoxel_map = labels.reshape((W, H, D)).astype(np.int32)

    return supervoxel_map

## Try SLIC for visual context aware supervoxels

### define a preprocessing routine to enhance SLIC results

In [20]:
from skimage import exposure
from skimage.restoration import denoise_nl_means, estimate_sigma
from scipy.ndimage import gaussian_filter
from skimage.exposure import equalize_adapthist

def preprocessing_for_SLIC(data: np.ndarray, spacing=None):
    """
    Enhance a volume before running SLIC on it.

    Steps, in order: clip to the 0.5%-99.5% intensity quantiles and rescale to
    [0, 1]; Gaussian blur with a sigma of 1 / spacing per axis; global
    histogram equalization; CLAHE applied independently to every slice along
    axis 0.

    Parameters
    ----------
    data : np.ndarray
        Volume to enhance.
    spacing : sequence of float, optional
        Voxel spacing, one value per axis of `data`. When omitted, the
        notebook-level variable `spacing` is used if it exists (this is what
        the function silently relied on before the parameter was added).

    Returns
    -------
    np.ndarray
        The enhanced volume, stacked back along axis 0.

    Raises
    ------
    ValueError
        If no spacing is available, or if any spacing value is not positive.
    """
    if spacing is None:
        # Legacy fallback: look for a notebook-level `spacing` variable.
        spacing = globals().get("spacing")
        if spacing is None:
            raise ValueError(
                "Voxel spacing is required: pass `spacing=` explicitly "
                "(no notebook-level `spacing` variable is defined)."
            )
    spacing = np.asarray(spacing, dtype=float)
    if np.any(spacing <= 0):
        raise ValueError(f"spacing values must be positive, got {spacing.tolist()}")

    # Clip extreme intensities (0.5%-99.5% quantiles) for contrast enhancement
    vmin, vmax = np.quantile(data, (0.005, 0.995))
    data = np.clip(data, vmin, vmax)
    data = exposure.rescale_intensity(data, in_range=(vmin, vmax), out_range=(0, 1))

    sigma_vox = 1.0 / spacing  # blur by 1 mm across axes
    data = gaussian_filter(data, sigma=sigma_vox)

    data = exposure.equalize_hist(data)

    # Apply slice-wise CLAHE for 3D volume
    data = np.stack([equalize_adapthist(slice_, clip_limit=0.03)
                       for slice_ in data], axis=0)

    return data

In [21]:
from skimage.segmentation import slic

In [22]:
# SLIC parameters: n_segments sets the approximate number of supervoxels requested.
def apply_SLIC(data, spacing, n_supervoxels):
    """
    Run SLIC on the enhanced version of `data`.

    Parameters
    ----------
    data : np.ndarray
        Volume to segment; it is first passed through `preprocessing_for_SLIC`.
    spacing : sequence of float
        Voxel spacing handed to `slic`.
    n_supervoxels : int
        Approximate number of supervoxels (`n_segments` of `slic`).

    Returns
    -------
    The label array returned by `slic` (labels start at 0).
    """
    # NOTE: `spacing` is only given to `slic`; it is not forwarded to
    # `preprocessing_for_SLIC`, which is called with `data` alone.
    enhanced = preprocessing_for_SLIC(data)
    return slic(
        enhanced,
        n_segments=n_supervoxels,
        compactness=0.2,
        spacing=spacing,
        start_label=0,
        max_num_iter=10,
        channel_axis=None,
    )

### Initialize predictor

In [23]:
# 2) Initialise predictor ------------------------
# Use the first CUDA device when one is available and fall back to the CPU
# otherwise, instead of hard-coding a GPU that may not exist on this machine.
predictor_device = (
    torch.device('cuda', 0) if torch.cuda.is_available() else torch.device('cpu')
)
predictor = CustomNNUNetPredictor(
    tile_step_size=0.5,
    use_gaussian=True,
    use_mirroring=False,  # == test time augmentation
    # Only keep everything on the device when that device is a GPU.
    perform_everything_on_device=(predictor_device.type == 'cuda'),
    device=predictor_device,
    verbose=False,
    verbose_preprocessing=False,
    allow_tqdm=False  # it interferes with the SHAP progress bar
)
# initializes the network architecture, loads the checkpoint
predictor.initialize_from_trained_model_folder(
    model_dir,
    use_folds=(0,),
    checkpoint_name='checkpoint_final.pth',
)



# Define a ROI to explain segmentation in. 
Maybe this will provide a more useful attribution map, highlighting nearby organs

In [24]:
# Get the manually derived ROI mask, which was added by hand to the dataset
if IN_KAGGLE:
    ROI_segmentation_mask_path = "/kaggle/input/segmentation-masked-ROI.nii"
else:
    ROI_segmentation_mask_path = nnUNet_raw + "/segmentation-masked-ROI.nii"

ROI_segmentation_mask = nib.load(ROI_segmentation_mask_path)

# The image object exposes its shape directly; calling get_fdata() here would
# load the whole volume as floating point just to print its dimensions.
print(ROI_segmentation_mask.shape)
print(ROI_segmentation_mask.affine)

(512, 512, 283)
[[-1.17187500e+00  0.00000000e+00  0.00000000e+00  3.00000000e+02]
 [ 0.00000000e+00 -1.17187500e+00  0.00000000e+00  1.86100006e+02]
 [ 0.00000000e+00  0.00000000e+00  5.00000000e+00 -1.73950000e+03]
 [ 0.00000000e+00  0.00000000e+00  0.00000000e+00  1.00000000e+00]]


### The ROI mask is a binary mask highlighting the lymph nodes of interest. We need a bounding box to crop the volume accordingly

In [25]:
import nibabel as nib
import numpy as np

def get_mask_bbox_slices(mask_nii_path):
    """
    Compute the tightest 3D bounding box around the positive voxels of a
    NIfTI mask.

    Parameters
    ----------
    mask_nii_path : str or Path
        Location of the ROI mask NIfTI file (.nii or .nii.gz).

    Returns
    -------
    bbox_slices : tuple of slice
        Three slice objects (x_slice, y_slice, z_slice); each stop is the
        largest positive index along that axis plus one.

    Raises
    ------
    ValueError
        If the volume is not 3D or contains no voxel greater than zero.
    """
    mask_volume = nib.load(str(mask_nii_path)).get_fdata()
    if mask_volume.ndim != 3:
        raise ValueError("Input NIfTI must be a 3D volume")

    # One row of (x, y, z) indices per voxel that is > 0
    positive_idx = np.argwhere(mask_volume > 0)
    if positive_idx.size == 0:
        raise ValueError("No positive voxels found in mask")

    lower_corner = positive_idx.min(axis=0)
    upper_corner = positive_idx.max(axis=0)

    # Slice stops are exclusive, so extend each upper bound by one
    return tuple(
        slice(int(lo), int(hi) + 1)
        for lo, hi in zip(lower_corner, upper_corner)
    )

In [26]:
x_slice, y_slice, z_slice = get_mask_bbox_slices(ROI_segmentation_mask_path)
# Report the bounding box, one axis per line
print("Bounding box slices:")
for axis_name, axis_slice in zip("xyz", (x_slice, y_slice, z_slice)):
    print(f"  {axis_name}: {axis_slice}")

Bounding box slices:
  x: slice(219, 351, None)
  y: slice(189, 316, None)
  z: slice(156, 182, None)


### the identified region is our ROI bounding box

In [27]:
import numpy as np

def slices_to_binary_mask(volume_shape, bbox_slices, dtype=np.uint8):
    """
    Build a box-shaped binary mask.

    Parameters
    ----------
    volume_shape : tuple of int
        Shape of the mask to create, e.g. (X, Y, Z).
    bbox_slices : tuple of slice
        One slice per axis of `volume_shape`; the region they select is
        filled with ones.
    dtype : data-type, optional
        dtype of the returned array (default: np.uint8).

    Returns
    -------
    mask : np.ndarray
        Array of shape `volume_shape` that is 1 inside the selected region
        and 0 everywhere else.

    Raises
    ------
    ValueError
        If the number of slices differs from the number of dimensions.
    """
    n_dims = len(volume_shape)
    n_slices = len(bbox_slices)
    if n_dims != n_slices:
        raise ValueError(
            f"volume_shape has {n_dims} dimensions, "
            f"but bbox_slices has {n_slices} slices"
        )

    # Start from an all-zero volume, then switch on the box
    box_mask = np.zeros(volume_shape, dtype=dtype)
    box_mask[bbox_slices] = 1
    return box_mask

In [28]:
ROI_mask_path = "ROI_binary_mask.nii.gz"
# Read the shape from the image object instead of get_fdata(), which would
# load the whole volume as floating point just to obtain its dimensions.
shape = ROI_segmentation_mask.shape
print("shape", shape)
ROI_mask = slices_to_binary_mask(
    volume_shape=shape,
    bbox_slices=(x_slice, y_slice, z_slice),
)
# Save the box mask with the same affine as the segmentation mask it was derived from
nib.save(
    nib.Nifti1Image(ROI_mask, affine=ROI_segmentation_mask.affine),
    ROI_mask_path,
)

shape (512, 512, 283)


### to **crop** correctly the volume around the *ROI*, we need to derive the **receptive field** of the sliding window inference, that depends on the *patch size*.

In [29]:
patch_size = np.array(predictor.configuration_manager.patch_size)
print(f"Patch size:  {patch_size}")

# The receptive field is taken here as twice (patch size - 1), per axis
RF = (patch_size - 1) * 2
print(f"Receptive field:  {RF}")

Patch size:  [ 72 160 160]
Receptive field:  [142 318 318]


### Consider the receptive field to compute the final slices for cropping
Remember that the model metadata refer to the transposed volume (nnunetv2 takes inputs of shape (D, H, W))

In [30]:
volume_path = ct_img_path
volume_shape = nib.load(volume_path).shape  # (W, H, D) -> x, y, z
print("Original volume shape:", volume_shape)

W, H, D = volume_shape

# RF is stored in model order (D, H, W) -> z, y, x; read it backwards to get x, y, z
RF_x, RF_y, RF_z = RF[2], RF[1], RF[0]


def _expand_slice_by_rf(axis_slice, rf_extent, axis_len):
    """Grow `axis_slice` by half of `rf_extent` on each side, clipped to [0, axis_len]."""
    lower = int(max(axis_slice.start - rf_extent/2, 0))
    upper = int(min(axis_slice.stop + rf_extent/2, axis_len))
    return slice(lower, upper)


x_slice_RF = _expand_slice_by_rf(x_slice, RF_x, W)
y_slice_RF = _expand_slice_by_rf(y_slice, RF_y, H)
z_slice_RF = _expand_slice_by_rf(z_slice, RF_z, D)

for axis_name, padded_slice in zip("xyz", (x_slice_RF, y_slice_RF, z_slice_RF)):
    print(f"new  {axis_name}: {padded_slice}")

Original volume shape: (512, 512, 283)
new  x: slice(60, 510, None)
new  y: slice(30, 475, None)
new  z: slice(85, 253, None)


In [31]:
import nibabel as nib
import numpy as np

def crop_volume_and_affine(nii_path, bbox_slices, save_cropped_nii_path=None):
    """
    Crop a 3D NIfTI volume using the given bounding-box slices and
    recompute the affine so the cropped volume retains correct world coordinates.

    Parameters
    ----------
    nii_path : str or Path
        Path to the input NIfTI volume (.nii or .nii.gz).
    bbox_slices : tuple of slice
        A 3-tuple (x_slice, y_slice, z_slice) as returned by get_mask_bbox_slices().
        Open-ended slices (``start=None``), negative starts and non-unit steps
        are resolved against the volume shape before the affine is updated.
    save_cropped_nii_path : str or Path, optional
        If provided, the cropped volume will be saved here as a new NIfTI.

    Returns
    -------
    cropped_data : np.ndarray
        The volume data cropped to the bounding box.
    new_affine : np.ndarray
        The updated 4×4 affine transform for the cropped volume.

    Raises
    ------
    ValueError
        If ``bbox_slices`` does not contain exactly three slices.
    """
    # 1) Load the original image
    img = nib.load(str(nii_path))
    data = img.get_fdata()
    affine = img.affine

    bbox_slices = tuple(bbox_slices)
    if len(bbox_slices) != 3:
        raise ValueError(
            f"bbox_slices must contain exactly 3 slices (x, y, z), got {len(bbox_slices)}"
        )

    # 2) Crop the data array
    cropped_data = data[bbox_slices]

    # 3) Resolve every slice against its axis length so that `None` / negative
    #    starts become concrete voxel indices. Voxel coordinates are
    #    (i, j, k) = (x, y, z).
    resolved = [s.indices(dim) for s, dim in zip(bbox_slices, data.shape[:3])]
    offset_vox = np.array([start for start, _, _ in resolved])
    steps = np.array([step for _, _, step in resolved])

    # 4) Shift the origin by the voxel offsets expressed in world coordinates.
    new_affine = affine.copy()
    new_affine[:3, 3] += affine[:3, :3].dot(offset_vox)

    # A strided crop changes the effective voxel size: scale column j of the
    # rotation/zoom part by the step taken along voxel axis j.
    if np.any(steps != 1):
        new_affine[:3, :3] = affine[:3, :3] * steps

    # 5) Optionally save the cropped volume
    if save_cropped_nii_path is not None:
        cropped_img = nib.Nifti1Image(cropped_data, new_affine)
        nib.save(cropped_img, str(save_cropped_nii_path))

    return cropped_data, new_affine


In [32]:
from pathlib import Path

# RF-extended bounding box, in NIfTI (x, y, z) order.
slices = (x_slice_RF, y_slice_RF, z_slice_RF)

nii_path = ct_img_path
cropped_ct_path = Path("cropped_volume_with_RF.nii.gz")

# Crop the CT volume and write it to disk with an affine that keeps its
# world coordinates.
cropped_volume, affine_cropped_volume = crop_volume_and_affine(
    nii_path,
    slices,
    cropped_ct_path,
)

print("Cropped data shape:", cropped_volume.shape)
print("New affine:\n", affine_cropped_volume)

Cropped data shape: (450, 445, 168)
New affine:
 [[-1.17187500e+00  0.00000000e+00  0.00000000e+00  2.29687500e+02]
 [ 0.00000000e+00 -1.17187500e+00  0.00000000e+00  1.50943756e+02]
 [ 0.00000000e+00  0.00000000e+00  5.00000000e+00 -1.31450000e+03]
 [ 0.00000000e+00  0.00000000e+00  0.00000000e+00  1.00000000e+00]]


### We need a way to check that the original segmentation mask still overlaps correctly with the new cropped volume

In [33]:
# Crop the ROI segmentation mask with the same slices as the CT so that the
# two files can be overlaid.
cropped_seg_mask_path = Path("cropped_mask_with_RF.nii.gz")
cropped_ROI_segmentation_mask, affine_ROI_segmentation_cropped_mask = crop_volume_and_affine(
    ROI_segmentation_mask_path,
    slices,
    cropped_seg_mask_path,
)

print("Cropped mask shape:", cropped_ROI_segmentation_mask.shape)
print("New mask affine:\n", affine_ROI_segmentation_cropped_mask)

Cropped mask shape: (450, 445, 168)
New mask affine:
 [[-1.17187500e+00  0.00000000e+00  0.00000000e+00  2.29687500e+02]
 [ 0.00000000e+00 -1.17187500e+00  0.00000000e+00  1.50943756e+02]
 [ 0.00000000e+00  0.00000000e+00  5.00000000e+00 -1.31450000e+03]
 [ 0.00000000e+00  0.00000000e+00  0.00000000e+00  1.00000000e+00]]


### We also need a mask to correctly ignore out-of-ROI context in our aggregation metrics

In [34]:
# Crop the binary ROI bounding-box mask with the same RF-extended slices.
# It is written to its own file: reusing "cropped_mask_with_RF.nii.gz" would
# overwrite the cropped segmentation mask saved for the overlap check.
cropped_ROI_mask, affine_ROI_cropped_mask = crop_volume_and_affine(
    nii_path=ROI_mask_path,
    bbox_slices=slices,
    save_cropped_nii_path=Path("cropped_ROI_mask_with_RF.nii.gz")
)

print("Cropped mask shape:", cropped_ROI_mask.shape)
# Report the affine of *this* crop, not the one of the segmentation mask.
print("New mask affine:\n", affine_ROI_cropped_mask)

Cropped mask shape: (450, 445, 168)
New mask affine:
 [[-1.17187500e+00  0.00000000e+00  0.00000000e+00  2.29687500e+02]
 [ 0.00000000e+00 -1.17187500e+00  0.00000000e+00  1.50943756e+02]
 [ 0.00000000e+00  0.00000000e+00  5.00000000e+00 -1.31450000e+03]
 [ 0.00000000e+00  0.00000000e+00  0.00000000e+00  1.00000000e+00]]


## We execute SHAP on this cropped image, and we only consider our ROI

### set device

In [35]:
# Prefer the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)

cuda


In [36]:
# (a) Preprocess the RF-cropped CT; the result is moved to `device` with a
#     leading batch axis, i.e. (1, C, D, H, W) in nnU-Net order.
nii_path_cropped = "cropped_volume_with_RF.nii.gz"
dataset_json_path = Path(model_dir, "dataset.json")

volume_np = nnunetv2_default_preprocessing(nii_path_cropped, predictor, dataset_json_path)

# Indexing with None prepends the batch dimension.
volume = torch.from_numpy(volume_np)[None].to(device)

print("Volume shape:", volume.shape)

Volume shape: torch.Size([1, 1, 168, 445, 450])


In [37]:
# Which supervoxel generator the next cell uses: "FCC" or "SLIC"
# (any other value makes that cell raise ValueError).
SUPERVOXEL_TYPE = "FCC"
# When True, the next cell loads a previously saved supervoxel map from disk
# if the corresponding file exists instead of recomputing it.
USE_SAVED_MAP = True

In [38]:
# (b) super-voxel / organ-id map, in NIfTI axis order (W, H, D).
#     The next cell transposes it to the model's (D, H, W) order.
img = nib.load(nii_path_cropped)

# One cache file per generator.
if SUPERVOXEL_TYPE == "FCC":
    supervoxel_map_path = 'FCC-supervoxel_map.nii.gz'
elif SUPERVOXEL_TYPE == "SLIC":
    supervoxel_map_path = 'SLIC_supervoxel_map.nii.gz'
else:
    raise ValueError(
        f"Unknown SUPERVOXEL_TYPE {SUPERVOXEL_TYPE!r}: expected 'FCC' or 'SLIC'"
    )

# Try the cached map first, but only trust it if it matches the current crop:
# a map saved for a different crop would silently misalign with the volume.
supervoxel_map = None
if USE_SAVED_MAP and os.path.exists(supervoxel_map_path):
    cached_map = nib.load(supervoxel_map_path).get_fdata()
    if cached_map.shape == img.shape:
        supervoxel_map = cached_map
    else:
        print(f"Ignoring cached supervoxel map {supervoxel_map_path}: "
              f"shape {cached_map.shape} does not match volume shape {img.shape}")

if supervoxel_map is None:
    if SUPERVOXEL_TYPE == "FCC":
        # Generate the supervoxel map using the FCC generator
        cube_side = 100.00  # [mm]
        supervoxel_map = generate_supervoxel_map(img, S=cube_side)
    else:
        # SLIC works on the raw intensities and needs the voxel spacing
        data = np.array(img.get_fdata(), dtype=np.float32)
        affine = img.affine

        # voxel spacing = norm of each column of the affine's 3x3 part
        spacing = tuple(np.sqrt(np.sum(affine[:3, :3] ** 2, axis=0)))

        n_supervoxels_requested = 380
        supervoxel_map = apply_SLIC(data, spacing, n_supervoxels_requested)

    # Save the freshly generated map so that USE_SAVED_MAP can reuse it next
    # time (previously only the SLIC branch wrote its result to disk).
    supervoxel_map_img = nib.Nifti1Image(
        np.asarray(supervoxel_map).astype(np.int32), affine=img.affine
    )
    nib.save(supervoxel_map_img, supervoxel_map_path)

print("Mappa supervoxel shape:", supervoxel_map.shape)
n_supervoxels = len(np.unique(supervoxel_map))
print("Numero di supervoxels:", n_supervoxels)

Number of supervoxel centers: 384
Mappa supervoxel shape: (450, 445, 168)
Numero di supervoxels: 384


In [39]:
# Reverse the axes so the map matches the model's (D, H, W) layout.
supervoxel_map = np.transpose(supervoxel_map, axes=(2, 1, 0))

# The feature mask needs consecutive ids starting at 0: relabel every
# supervoxel with the rank of its original value.
sv_values, indexes = np.unique(supervoxel_map, return_inverse=True)
supervoxel_map = indexes.reshape(supervoxel_map.shape)
print("number of supervoxels: ", sv_values.size)

# IMPORTANT 🔸: the feature mask is passed without a channel axis, with the same
# spatial layout (D, H, W) as the input volume.
# NOTE(review): this cell rebinds `supervoxel_map` (ndarray -> tensor), so it
# cannot be re-run without first re-running the previous cell.
supervoxel_map = torch.from_numpy(supervoxel_map).to(device=device, dtype=torch.long)

print("Mask shape:", supervoxel_map.shape)

number of supervoxels:  384
Mask shape: torch.Size([168, 445, 450])


### Derive the cached baseline output dictionary
Built using the planner, iterator, and predictor (temporary solution).

In [40]:
# this cause the death, try to isolate the problem

import math
import multiprocessing
import shutil
from time import sleep
from typing import Tuple

import SimpleITK
import numpy as np
import pandas as pd
from batchgenerators.utilities.file_and_folder_operations import *
from tqdm import tqdm

import nnunetv2
from nnunetv2.paths import nnUNet_preprocessed, nnUNet_raw
from nnunetv2.preprocessing.cropping.cropping import crop_to_nonzero
from nnunetv2.preprocessing.resampling.default_resampling import compute_new_shape
from nnunetv2.training.dataloading.nnunet_dataset import nnUNetDatasetBlosc2
from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name
from nnunetv2.utilities.find_class_by_name import recursive_find_python_class
from nnunetv2.utilities.plans_handling.plans_handler import PlansManager, ConfigurationManager
from nnunetv2.utilities.utils import get_filenames_of_train_images_and_targets

def run_case_npy(preprocessor, data: np.ndarray, seg: Union[np.ndarray, None], properties: dict,
                 plans_manager: PlansManager, configuration_manager: ConfigurationManager,
                 dataset_json: Union[dict, str]):
    """
    Instrumented, free-function version of the preprocessing pipeline
    (transpose -> crop to nonzero -> normalize -> resample), used to locate the
    step at which the kernel dies.

    NOTE(review): this looks like a copy of the preprocessor's own
    ``run_case_npy`` method with numbered progress markers added — confirm
    against the installed nnunetv2 version before relying on it.

    Parameters
    ----------
    preprocessor :
        Object providing ``_normalize``, ``_sample_foreground_locations``,
        ``modify_seg_fn`` and a ``verbose`` attribute.
    data : np.ndarray
        Image array with a leading channel axis; it is copied (cast to float32).
    seg : np.ndarray or None
        Optional segmentation whose spatial shape must match ``data``.
    properties : dict
        Must contain ``'spacing'``. It is updated **in place** with the
        cropping information (and ``'class_locations'`` when ``seg`` is given).
    plans_manager, configuration_manager :
        nnU-Net plan/configuration handlers providing the transpose order,
        target spacing and resampling functions.
    dataset_json : dict or str
        Passed through to the label manager and ``modify_seg_fn``.

    Returns
    -------
    tuple
        ``(data, seg, properties)`` after preprocessing; ``seg`` is cast to
        int8, or int16 when its maximum exceeds 127.
    """

    def _checkpoint(step: int) -> None:
        # The markers exist to find the last step reached before a kernel
        # crash, so flush immediately: a buffered marker may never be shown.
        print(step, flush=True)

    # let's not mess up the inputs!
    _checkpoint(1)
    data = data.astype(np.float32)  # this creates a copy
    if seg is not None:
        assert data.shape[1:] == seg.shape[1:], "Shape mismatch between image and segmentation. Please fix your dataset and make use of the --verify_dataset_integrity flag to ensure everything is correct"
        seg = np.copy(seg)

    has_seg = seg is not None

    # apply transpose_forward, this also needs to be applied to the spacing!
    _checkpoint(2)
    data = data.transpose([0, *[i + 1 for i in plans_manager.transpose_forward]])
    _checkpoint(3)
    if seg is not None:
        seg = seg.transpose([0, *[i + 1 for i in plans_manager.transpose_forward]])
    _checkpoint(4)
    original_spacing = [properties['spacing'][i] for i in plans_manager.transpose_forward]

    _checkpoint(5)
    # crop, remember to store size before cropping!
    shape_before_cropping = data.shape[1:]
    properties['shape_before_cropping'] = shape_before_cropping
    # this command will generate a segmentation. This is important because of the nonzero mask which we may need
    _checkpoint(6)
    data, seg, bbox = crop_to_nonzero(data, seg)
    properties['bbox_used_for_cropping'] = bbox
    properties['shape_after_cropping_and_before_resampling'] = data.shape[1:]

    # resample
    target_spacing = configuration_manager.spacing  # this should already be transposed

    if len(target_spacing) < len(data.shape[1:]):
        # target spacing for 2d has 2 entries but the data and original_spacing have three because everything is 3d
        # in 2d configuration we do not change the spacing between slices
        target_spacing = [original_spacing[0]] + target_spacing

    _checkpoint(7)
    new_shape = compute_new_shape(data.shape[1:], original_spacing, target_spacing)

    # normalize
    # normalization MUST happen before resampling or we get huge problems with resampled nonzero masks no
    # longer fitting the images perfectly!
    _checkpoint(8)
    data = preprocessor._normalize(data, seg, configuration_manager,
                                   plans_manager.foreground_intensity_properties_per_channel)

    _checkpoint(9)
    old_shape = data.shape[1:]
    data = configuration_manager.resampling_fn_data(data, new_shape, original_spacing, target_spacing)
    _checkpoint(10)
    seg = configuration_manager.resampling_fn_seg(seg, new_shape, original_spacing, target_spacing)
    if preprocessor.verbose:
        print(f'old shape: {old_shape}, new_shape: {new_shape}, old_spacing: {original_spacing}, '
              f'new_spacing: {target_spacing}, fn_data: {configuration_manager.resampling_fn_data}')

    # if we have a segmentation, sample foreground locations for oversampling and add those to properties
    if has_seg:
        # reinstantiating LabelManager for each case is not ideal. We could replace the dataset_json argument
        # with a LabelManager Instance in this function because that's all its used for. Dunno what's better.
        # LabelManager is pretty light computation-wise.
        _checkpoint(11)
        label_manager = plans_manager.get_label_manager(dataset_json)
        _checkpoint(12)
        collect_for_this = label_manager.foreground_regions if label_manager.has_regions \
            else label_manager.foreground_labels

        # when using the ignore label we want to sample only from annotated regions. Therefore we also need to
        # collect samples uniformly from all classes (incl background)
        if label_manager.has_ignore_label:
            collect_for_this.append([-1] + label_manager.all_labels)

        # no need to filter background in regions because it is already filtered in handle_labels
        _checkpoint(13)
        properties['class_locations'] = preprocessor._sample_foreground_locations(seg, collect_for_this,
                                                                                  verbose=preprocessor.verbose)
        _checkpoint(14)
        seg = preprocessor.modify_seg_fn(seg, plans_manager, dataset_json, configuration_manager)
    _checkpoint(15)
    # pick the smallest signed integer type that can hold every label
    if np.max(seg) > 127:
        _checkpoint(16)
        seg = seg.astype(np.int16)
    else:
        _checkpoint(17)
        seg = seg.astype(np.int8)
    _checkpoint(18)
    return data, seg, properties

In [41]:
def get_cached_output_dictionary(volume_file: Path,
                                 predictor: CustomNNUNetPredictor,
                                 verbose: bool = False) -> dict:
    """
    Run the unperturbed volume through preprocessing and sliding-window
    inference, and return the predictor's per-patch output dictionary
    (indexed by the sliding-window slices).

    Parameters
    ----------
    volume_file : Path
        Volume to read; it is converted with ``str()`` before being passed to
        the reader.
    predictor : CustomNNUNetPredictor
        Supplies the plans/configuration managers, the sliding-window slicers
        and ``get_output_dictionary_sliding_window``.
    verbose : bool
        When True, print the first three slicers.

    Returns
    -------
    dict
        Whatever ``predictor.get_output_dictionary_sliding_window`` returns for
        the preprocessed volume.
    """
    reader_writer = predictor.plans_manager.image_reader_writer_class()

    # If nnU-Net handed back a class rather than an instance, instantiate it.
    is_uninstantiated = callable(reader_writer) and not hasattr(reader_writer, "read_images")
    if is_uninstantiated:
        reader_writer = reader_writer()

    image_paths = [str(volume_file)]
    orig_image, orig_props = reader_writer.read_images(image_paths)  # (C, Z, Y, X)

    preprocessor = predictor.configuration_manager.preprocessor_class()
    # NOTE: this call was observed to kill the kernel on the first notebook run.
    data_pp, _unused_seg, _unused_props = preprocessor.run_case_npy(
        orig_image,
        seg=None,
        properties=orig_props,
        plans_manager=predictor.plans_manager,
        configuration_manager=predictor.configuration_manager,
        dataset_json=predictor.dataset_json,
    )

    # The preprocessed array is already channel-first.
    inp_tensor = torch.from_numpy(data_pp)
    spatial_shape = inp_tensor.shape[1:]
    slicers = predictor._internal_get_sliding_window_slicers(spatial_shape)

    if verbose:
        print("first 3 slicers of Iterator object: ", slicers[:3])

    return predictor.get_output_dictionary_sliding_window(inp_tensor, slicers)

In [42]:
# When True, the next cell loads the pickled baseline output dictionary from
# disk if the cache file exists instead of recomputing it.
USE_STORED_DICTIONARY = True

In [43]:
import pickle as pkl

baseline_cache_file = "cropped_baseline_output_dictionary_cache.pkl"
cache_is_usable = USE_STORED_DICTIONARY and os.path.exists(baseline_cache_file)

if not cache_is_usable:
    # Compute the per-patch baseline outputs once, then store them for later runs.
    cropped_baseline_pred_cache = get_cached_output_dictionary(
        volume_file=nii_path_cropped,
        predictor=predictor,
        verbose=True,
    )
    with open(baseline_cache_file, "wb") as cache_out:
        pkl.dump(cropped_baseline_pred_cache, cache_out)
else:
    # NOTE: unpickling executes arbitrary code; only load cache files that
    # were produced by this notebook.
    with open(baseline_cache_file, "rb") as cache_in:
        cropped_baseline_pred_cache = pkl.load(cache_in)


In [44]:
# Sanity check: look up one cached patch output and print its shape.
# NOTE(review): the key looks like the sliding-window slicer for the first
# patch, with each slice stored as a (start, stop, step) tuple — confirm
# against the predictor's caching code.
print(cropped_baseline_pred_cache[(None, None, None), (0, 72, None), (0, 160, None), (0, 160, None)].shape)

torch.Size([2, 72, 160, 160])


### Get our cropped ROI segmentation mask

In [45]:
# segmentation mask cropped to ROI, with background extended to rf
ROI_segmentation_mask = np.transpose(cropped_ROI_segmentation_mask, (2, 1, 0))
ROI_segmentation_mask = torch.from_numpy(ROI_segmentation_mask).to(device)

print(ROI_segmentation_mask.shape)

# ROI bounding box with background extended to RF
ROI_mask = np.transpose(cropped_ROI_mask, (2, 1, 0))
ROI_mask = torch.from_numpy(ROI_mask).to(device)

print(ROI_mask.shape)

torch.Size([168, 445, 450])
torch.Size([168, 445, 450])


### Include masking in the forward function

### **Chrabaszcz aggregation**  

Let

* $z_1^{(i)}(x)$ – class-1 logit at voxel $x$ after perturbation *i*
* $P_i(x)=\mathbf 1\!\left[\arg\max_c z_c^{(i)}(x)=1\right]$ – binary mask of voxels currently predicted as lymph-node
* No ROI restriction: the whole volume is considered.

$$
S_{\text{Chr}}^{(i)} \;=\;
\frac{1}{\alpha}\sum_{x} P_i(x)\;z_1^{(i)}(x)
$$

where $\alpha$ is the `scaling_factor`.

* **Counts evidence only from voxels the model *currently* labels as class 1.**
* *False positives (FP):* contribute **positively** (they are in $P_i$).
* *False negatives (FN):* contribute **zero** (their logit is absent).

Source: Chrabaszcz et al., *Aggregated Attributions for Explanatory Analysis of 3-D Segmentation Models*, 2024.


In [46]:
def chrabaszcz_aggregation(logits: torch.Tensor,
                           scaling_factor: float = 1.0,
                          ) -> torch.Tensor:
    """
    Sum the class-1 logits over the voxels whose arg-max class is 1, as
    proposed in "Chrabaszcz et al. - 2024 - Aggregated Attributions for
    Explanatory Analysis of 3D Segmentation Models".

    `logits` has the class axis first; the sum is accumulated in float64 and
    divided by `scaling_factor`.
    """
    predicted_positive = logits.argmax(dim=0).eq(1)
    class1_logits = logits[1].double()
    total = (class1_logits * predicted_positive).sum()

    # Normalize to avoid overflows in SHAP.
    return total / scaling_factor


### **True positive aggregation** (Chrabaszcz aggregation + baseline-mask filtering)

Introduce the unperturbed prediction $P_0$.
Keep only voxels that are **still** class 1 *and* were class 1 before:

$$
S_{\text{Chr-keep}}^{(i)} \;=\;
\frac{1}{\alpha}\sum_{x} \bigl[P_i(x)\land P_0(x)\bigr]\;z_1^{(i)}(x)
$$

* **True positives preserved** (TP core) add positive evidence.
* **FP created by the perturbation** are **ignored** (masked out).
* **FN** lower the score indirectly because their logits disappear from the sum.

Conceptually this is the **positive part** of a signed logit-difference metric.

In [47]:
def true_positive_aggregation(logits: torch.Tensor,
                          unperturbed_binary_mask: torch.Tensor,
                           scaling_factor: float = 1.0,
                          ) -> torch.Tensor:
    """
    Chrabaszcz-style aggregation ("Aggregated Attributions for Explanatory
    Analysis of 3D Segmentation Models", 2024) restricted to the unperturbed
    segmentation: only voxels that are predicted as class 1 now *and* are set
    in `unperturbed_binary_mask` contribute their class-1 logit.

    Newly created positives are therefore ignored, which conceptually makes
    this the positive part of a logit-difference metric.
    """
    currently_positive = logits.argmax(dim=0).eq(1)          # (D,H,W)
    kept_positive = currently_positive.logical_and(unperturbed_binary_mask.bool())

    # Boolean indexing is preferred here because the kept region is relatively sparse.
    total = logits[1].double()[kept_positive].sum()

    # Normalize to avoid overflows in SHAP.
    return total / scaling_factor

---

### **False-positive aggregation**

Directly sum class-1 evidence from **new** positives inside ROI:

$$
S_{\text{FP}}^{(i)} \;=\;
\frac{1}{\alpha}\sum_{x}
\bigl[P_i(x)\land\neg P_0(x)\land R(x)\bigr]\;z_1^{(i)}(x)
$$

* Measures **only** the spurious lymph-node evidence a perturbation introduces.
* Higher value ⇒ stronger tendency to hallucinate extra nodes.

In [48]:
def false_positive_aggregation(logits: torch.Tensor,
                              unperturbed_binary_mask: torch.Tensor,
                              ROI_mask: torch.Tensor,
                              scaling_factor: float = 1.0,
                              ) -> torch.Tensor:
    """
    Sum the class-1 logits of voxels that are predicted as class 1 now, lie
    inside `ROI_mask`, and were *not* set in `unperturbed_binary_mask`.

    This is the negative part of the signed logit-difference objective,
    returned with a positive sign (evidence for spurious lymph nodes).
    """
    # Voxels whose prevailing class is currently 1.
    currently_positive = logits.argmax(dim=0).eq(1)     # (D,H,W)

    # Float multiplication is preferred here because the masks are dense.
    absent_in_baseline = (~unperturbed_binary_mask.bool()).float()
    fp_mask = currently_positive * ROI_mask * absent_in_baseline

    total = (logits[1].double() * fp_mask).sum()
    return total / scaling_factor

---

### **Dice aggregation (prediction-vs-baseline, ROI-restricted)**

Let $R(x)$ be the ROI mask.

$$
P_i' = P_i \odot R, \qquad
P_0' = P_0 \odot R
$$

$$
S_{\text{Dice}}^{(i)} \;=\;
\frac{1}{\alpha}\;
\frac{2\,\langle P_i',\,P_0'\rangle}{\lVert P_i'\rVert_1 + \lVert P_0'\rVert_1 + \varepsilon}
$$

* Drops when either **FP** ($P_i'=1,\,P_0'=0$) or **FN** ($P_i'=0,\,P_0'=1$) appear → penalises both error types symmetrically.

Based on the “self-consistency Dice” used in MiSuRe (Hasany et al., 2024).





In [49]:
def dice_aggregation(logits: torch.Tensor,
                    unperturbed_binary_mask: torch.Tensor,
                    ROI_mask: torch.Tensor,
                    scaling_factor: float = 1.0,
                    eps: float = 1e-9,  # small value to avoid division by zero
                    ) -> torch.Tensor:
    """
    Dice overlap between the current class-1 prediction and the unperturbed
    mask, both restricted to `ROI_mask`, divided by `scaling_factor`.

    Same aggregation measure as in "Hasany et al. - 2024 - MiSuRe is all you
    need to explain your image segmentation": a single score that penalises
    both false positives and false negatives, so perturbations that reproduce
    the baseline segmentation more closely get a higher score.
    """
    roi = ROI_mask.float()

    # Masks restricted to the ROI.
    pred = logits.argmax(dim=0).eq(1).float() * roi
    base = unperturbed_binary_mask.float() * roi

    # Dice = 2 |P ∩ B| / (|P| + |B| + eps)
    overlap = torch.sum(pred * base)
    total = torch.sum(pred) + torch.sum(base) + eps
    dice = (2.0 * overlap) / total

    return dice / scaling_factor

---

### **Signed logit-difference (masked)**

Define a signed weight

$$
w(x)=
\begin{cases}
+1 & \text{if } P_0(x)=1\\
-1 & \text{otherwise}
\end{cases},
\qquad w(x)\leftarrow w(x)\,R(x)
$$

$$
S_{\text{LD}}^{(i)} \;=\;
\frac{1}{\alpha}\sum_{x} P_i(x)\;w(x)\;z_1^{(i)}(x)
$$

* **Positive attribution:** voxels that *keep* the baseline TP (support segmentation).
* **Negative attribution:** voxels that become class 1 **only** after perturbation (generate FP inside ROI).
* FN reduce the positive term (logits disappear) but do **not** add negative mass.

In [50]:


def logit_difference_aggregation(
        logits: torch.Tensor,
        unperturbed_binary_mask: torch.Tensor,
        ROI_mask: torch.Tensor,
        scaling_factor: float = 1.0,
) -> torch.Tensor:
    """
    Signed logit-difference objective *masked by the prevailing class*.

    Class-1 logits of the voxels currently predicted as class 1 are summed
    with weight +1 where `unperturbed_binary_mask` is set and -1 elsewhere;
    the weights are multiplied by `ROI_mask`, so voxels outside the ROI do
    not contribute. The sum is accumulated in float64, like the other
    logit-based aggregations, and divided by `scaling_factor`.
    """
    # current segmentation (prevailing class)
    seg_mask = (torch.argmax(logits, dim=0) == 1)     # (D,H,W)

    # +1 inside baseline positives, −1 elsewhere...
    signed_weight = torch.where(unperturbed_binary_mask.bool(),
                                torch.tensor(1.0, device=logits.device),
                                torch.tensor(-1.0, device=logits.device))

    # ... but we only care about false positives inside the ROI (the
    # segmentation mask is not even available outside the ROI)
    signed_weight = signed_weight * ROI_mask

    # Aggregate signed class-1 evidence over the voxels that are *currently*
    # predicted as class 1. Cast to double first so that the result does not
    # depend on the dtype of the masks and low-precision logits do not lose
    # accuracy when summed over the whole volume.
    aggregate = torch.sum(logits[1].double() * seg_mask * signed_weight)
    return aggregate / scaling_factor

In [None]:
# ------------------------------------------------------------
# ❷  Forward wrapper that nnU-Net expects
# ------------------------------------------------------------

@torch.inference_mode()
def forward_segmentation_output_to_explain(
        input_image:         torch.Tensor,
        perturbation_mask:   torch.BoolTensor | None,
        ROI_segmentation_mask:      torch.Tensor,   # remember that must be cropped to the same size of the other tensors
        ROI_bounding_box_mask:      torch.Tensor,
        baseline_prediction_dict: dict
) -> torch.Tensor:           # returns a scalar per sample
    """
    Forward function explained by SHAP: run the cached sliding-window
    inference on the (perturbed) input and collapse the logits to one scalar.

    The scalar is `true_positive_aggregation` of the logits against
    `ROI_segmentation_mask`, scaled by the number of voxels.
    `ROI_bounding_box_mask` is accepted but not used by this aggregation.
    """
    logits = predictor.predict_sliding_window_return_logits_with_caching(
        input_image,
        perturbation_mask,
        baseline_prediction_dict,
    )                              # (C, D, H, W)

    # Scale by the voxel count of the spatial volume.
    depth, height, width = logits.shape[1:]
    n_voxels = depth * width * height

    # Alternative objective (disabled): logit_difference_aggregation with
    # unperturbed_binary_mask=ROI_segmentation_mask,
    # ROI_mask=ROI_bounding_box_mask and the same scaling factor.
    return true_positive_aggregation(
        logits,
        ROI_segmentation_mask,
        scaling_factor=n_voxels,
    )

# c) Wrap the cached forward method: bind the masks and the baseline cache so
#    that the explainer only has to supply the volume and the perturbation mask.
def _cached_forward(vol, _perturbation_mask):
    return forward_segmentation_output_to_explain(
        input_image=vol,
        perturbation_mask=_perturbation_mask,
        ROI_segmentation_mask=ROI_segmentation_mask,
        ROI_bounding_box_mask=ROI_mask,
        baseline_prediction_dict=cropped_baseline_pred_cache,
    )


explainer = KernelShapWithMask(
    forward_func=_cached_forward,
    # Optional surrogate settings, currently disabled:
    #surrogate_model = "lasso",
    #alpha_surrogate = 0.003,
    #max_iter_surrogate = 10000
)

# d) Compute SHAP attributions, one feature per supervoxel.
attr = explainer.attribute(
    inputs=volume,                      # (1,C,D,H,W)
    baselines=0.0,
    feature_mask=supervoxel_map,
    n_samples=4000,
    return_input_shape=True,
    show_progress=True,
    # progress / convergence monitoring
    monitor_log_path="monitor.jsonl",
    monitor_convergence_step=10,
    monitor_local_accuracy_step=20,
)
print("Attributions:", attr.shape)  # → (1,C,D,H,W)


Kernel Shap With Mask attribution:  93%|█████████████████████████████████████   | 3703/4000 [3:50:05<15:41,  3.17s/it]

Kernel Shap With Mask attribution:   0%|▏                                         | 14/4000 [00:42<2:55:07,  2.64s/it]

In [59]:
import pickle

# Persist `explainer.dataset` so the SHAP run can be inspected later without
# repeating it.
with open('dataset-tp.pkl', 'wb') as dataset_file:
    pickle.dump(explainer.dataset, dataset_file)

In [60]:
# Take the first sample and channel, move it to the CPU and reverse the axes
# back to NIfTI order (W, H, D).
first_channel_attr = attr[0][0].detach().cpu().numpy()
attr_postprocessed = np.transpose(first_channel_attr, (2, 1, 0))

# Save with the cropped-volume affine so the map overlays the cropped CT.
attr_img = nib.Nifti1Image(attr_postprocessed, affine_cropped_volume)
nib.save(attr_img, 'attribution_map-TP.nii.gz')