In [63]:
# Import basic packages for later use
import os
import shutil
from collections import OrderedDict

import json
import matplotlib.pyplot as plt
import nibabel as nib

import numpy as np
import torch

In [64]:
# Use the GPU when PyTorch can see one; otherwise run everything on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

cuda


In [65]:
# Detect where the notebook runs from well-known mount points.
# The flags are mutually exclusive: Kaggle wins, then Colab, and anything else is DEIB.
IN_KAGGLE = os.path.exists('/kaggle/input')
IN_COLAB = os.path.exists('/content') and not IN_KAGGLE
IN_DEIB = not (IN_KAGGLE or IN_COLAB)

In [66]:
if not IN_DEIB:
    !pip install nnunetv2
    !pip install captum

In [67]:
from batchgenerators.utilities.file_and_folder_operations import join

if IN_COLAB:
    # Google Colab
    # for colab users only - mounting the drive

    from google.colab import drive
    drive.mount('/content/drive',force_remount = True)

    drive_dir = "/content/drive/My Drive"
    mount_dir = join(drive_dir, "tesi", "automi")
    base_dir = os.getcwd()
elif IN_KAGGLE:
    # Kaggle
    mount_dir = "/kaggle/input/"
    base_dir = os.getcwd()
    print(base_dir)
    !ls '/kaggle/input'
    !cd "/kaggle/input/automi-seg" ; ls
else:
    mount_dir = "/workspace/data"
    base_dir = "/workspace/output"
    os.chdir(base_dir)
    print("base_dir:", base_dir)

base_dir: /workspace/output


# 3. Setting up nnU-Net's folder structure and environment variables
nnU-Net expects a specific folder structure and a set of environment variables.

Roughly, they tell nnU-Net:
1. Where to look for stuff
2. Where to put stuff

For more information about this please check: https://github.com/MIC-DKFZ/nnUNet/blob/master/documentation/setting_up_paths.md

## 3.1 Setting environment variables and creating folders

In [68]:
# ===========================
# 📦 SETUP nnUNet ENVIRONMENT
# ===========================

# Paths nnU-Net needs (raw data, preprocessed data, trained models).
path_dict = {
    "nnUNet_raw": join(mount_dir, "nnunet_raw"),
    "nnUNet_preprocessed": join(mount_dir, "preprocessed_files"),  # formerly "nnUNet_preprocessed"
    "nnUNet_results": join(mount_dir, "results"),  # formerly "nnUNet_results"
    # "RAW_DATA_PATH": join(mount_dir, "RawData"),  # optional, only needed to store zips
}

# Export them as environment variables, which nnunetv2's `paths` module reads;
# they must therefore be set before that module is imported below.
for env_var, path in path_dict.items():
    os.environ[env_var] = path

from nnunetv2.paths import nnUNet_results, nnUNet_raw

if IN_KAGGLE:
    # Fallbacks in case nnunetv2 did not pick the variables up.
    if nnUNet_raw is None:
        nnUNet_raw = "/kaggle/input/nnunet_raw"
    if nnUNet_results is None:
        nnUNet_results = "/kaggle/input/results"
    # Kaggle has some very inconsistent behaviors in dataset mounting...
    #nnUNet_raw = "/kaggle/input/automi-seg/nnunet_raw"
    #nnUNet_results = "/kaggle/input/automi-seg/results"

print("nnUNet_raw:", nnUNet_raw)
print("nnUNet_results:", nnUNet_results)

nnUNet_raw: /workspace/data/nnunet_raw
nnUNet_results: /workspace/data/results


### Some tests

In [69]:
# Fold-0 test-set volumes to analyse (00039 was already examined separately),
# as zero-padded 5-digit case identifiers.
volume_codes = [f"{case_number:05d}" for case_number in (4, 5, 24, 27, 29, 34, 44)]

In [70]:
# Build the CT / organ-mask paths for every volume and check that each pair shares
# the same voxel grid: same shape and, up to float round-off, the same affine.
# On Kaggle the files are stored uncompressed (.nii); elsewhere they are gzipped (.nii.gz).
nifti_ext = ".nii" if IN_KAGGLE else ".nii.gz"

ct_img_paths = {}
organ_mask_paths = {}

for volume_code in volume_codes:
    ct_img_paths[volume_code] = join(nnUNet_raw, "imagesTr", f"AUTOMI_{volume_code}_0000{nifti_ext}")
    organ_mask_paths[volume_code] = join(
        nnUNet_raw, "total_segmentator_structures", f"AUTOMI_{volume_code}_0000",
        f"mask_mask_add_input_20_total_segmentator{nifti_ext}",
    )
    ct_img = nib.load(ct_img_paths[volume_code])
    organ_mask = nib.load(organ_mask_paths[volume_code])
    print(f"Volume {volume_code}:")
    print("CT shape:", ct_img.shape)
    print("Organ shape:", organ_mask.shape)
    print("Spacing:", ct_img.header.get_zooms())
    print("Organ spacing:", organ_mask.header.get_zooms())
    assert ct_img.shape[:3] == organ_mask.shape[:3], \
        f"Volume {volume_code}: CT and organ mask shapes do not match!"
    # Exact float equality is too strict: NIfTI-1 stores the sform rows as float32,
    # so a mask written by another tool can differ from the CT by round-off only.
    assert np.allclose(ct_img.affine, organ_mask.affine, atol=1e-4), \
        f"Volume {volume_code}: CT and organ mask affine matrices do not match!"

Volume 00004:
CT shape: (512, 512, 221)
Organ shape: (512, 512, 221)
Spacing: (1.3671875, 1.3671875, 5.0)
Organ spacing: (1.3671875, 1.3671875, 5.0)
Volume 00005:
CT shape: (512, 512, 219)
Organ shape: (512, 512, 219)
Spacing: (1.171875, 1.171875, 5.0)
Organ spacing: (1.171875, 1.171875, 5.0)
Volume 00024:
CT shape: (512, 512, 321)
Organ shape: (512, 512, 321)
Spacing: (0.976562, 0.976562, 3.0)
Organ spacing: (0.976562, 0.976562, 3.0)
Volume 00027:
CT shape: (512, 512, 236)
Organ shape: (512, 512, 236)
Spacing: (1.3671875, 1.3671875, 5.0)
Organ spacing: (1.3671875, 1.3671875, 5.0)
Volume 00029:
CT shape: (512, 512, 259)
Organ shape: (512, 512, 259)
Spacing: (1.367188, 1.367188, 7.5)
Organ spacing: (1.367188, 1.367188, 7.5)
Volume 00034:
CT shape: (512, 512, 249)
Organ shape: (512, 512, 249)
Spacing: (1.171875, 1.171875, 5.0)
Organ spacing: (1.171875, 1.171875, 5.0)
Volume 00044:
CT shape: (512, 512, 232)
Organ shape: (512, 512, 232)
Spacing: (1.1712891, 1.1712891, 5.0)
Organ spacing: (

In [71]:
# Inspect the voxel-to-world affine of the first volume in the list.
first_code = volume_codes[0]
ct_img = nib.load(ct_img_paths[first_code])
print("affine:", ct_img.affine)

affine: [[  -1.3671875     0.            0.          350.        ]
 [   0.           -1.3671875     0.          278.6000061 ]
 [   0.            0.            5.         -432.23999023]
 [   0.            0.            0.            1.        ]]


## Re-align CT scan with its own organ segmentation mask

### Apparently, this misalignment has already been fixed in the remaining volumes

In [72]:
# Disabled code, kept as a string literal (running the cell only echoes it).
# It resamples the organ mask onto the CT grid with SimpleITK, using nearest-neighbour
# interpolation, and saves the result.
# NOTE(review): it references `ct_img_path` / `organ_mask_path`, which are no longer defined;
# the loading loop now fills the `ct_img_paths` / `organ_mask_paths` dicts instead.
# Update these names before re-enabling.
"""import SimpleITK as sitk

# Load CT and misaligned organ mask
ct = sitk.ReadImage(ct_img_path, sitk.sitkFloat32)
organ_mask = sitk.ReadImage(organ_mask_path, sitk.sitkUInt8)

# Resample organ mask to match CT space
resampler = sitk.ResampleImageFilter()
resampler.SetReferenceImage(ct)
resampler.SetInterpolator(sitk.sitkNearestNeighbor)
organ_resampled = resampler.Execute(organ_mask)

# Save aligned output
sitk.WriteImage(organ_resampled, "organ_mask_resampled_to_ct.nii.gz")"""

'import SimpleITK as sitk\n\n# Load CT and misaligned organ mask\nct = sitk.ReadImage(ct_img_path, sitk.sitkFloat32)\norgan_mask = sitk.ReadImage(organ_mask_path, sitk.sitkUInt8)\n\n# Resample organ mask to match CT space\nresampler = sitk.ResampleImageFilter()\nresampler.SetReferenceImage(ct)\nresampler.SetInterpolator(sitk.sitkNearestNeighbor)\norgan_resampled = resampler.Execute(organ_mask)\n\n# Save aligned output\nsitk.WriteImage(organ_resampled, "organ_mask_resampled_to_ct.nii.gz")'

In [73]:
# Disabled code, kept as a string literal: it re-loads the resampled mask written by the
# previous (also disabled) cell and prints shapes/spacings for comparison with the CT.
# NOTE(review): `ct_img_path` is no longer defined (see the `ct_img_paths` dict).
"""#organ_mask_path = join(nnUNet_raw, "organ_mask_resampled_to_ct.nii.gz")
organ_mask_path = "organ_mask_resampled_to_ct.nii.gz"
ct_img = nib.load(ct_img_path)
organ_mask = nib.load(organ_mask_path)
print("CT shape:", ct_img.shape)
print("Organ shape:", organ_mask.shape)
print("Spacing:", ct_img.header.get_zooms())
print("Organ spacing:", organ_mask.header.get_zooms())"""

'#organ_mask_path = join(nnUNet_raw, "organ_mask_resampled_to_ct.nii.gz")\norgan_mask_path = "organ_mask_resampled_to_ct.nii.gz"\nct_img = nib.load(ct_img_path)\norgan_mask = nib.load(organ_mask_path)\nprint("CT shape:", ct_img.shape)\nprint("Organ shape:", organ_mask.shape)\nprint("Spacing:", ct_img.header.get_zooms())\nprint("Organ spacing:", organ_mask.header.get_zooms())'

In [81]:
# Trained-model folder (trainer / plans / configuration). On Kaggle it is nested
# inside the dataset folder and is read-only.
model_dir = (
    join(nnUNet_results, 'Dataset003_AUTOMI_CTVLNF_NEWGL_results/nnUNetTrainer__nnUNetPlans__3d_fullres')
    if IN_KAGGLE
    else join(nnUNet_results, 'nnUNetTrainer__nnUNetPlans__3d_fullres')
)

print("Model directory:", model_dir)

Model directory: /workspace/data/results/nnUNetTrainer__nnUNetPlans__3d_fullres


## Utility to export logits to a visualizable segmentation

In [75]:
import numpy as np
import torch
import os
from typing import Union
from pathlib import Path
from nnunetv2.configuration import default_num_processes
from nnunetv2.inference.export_prediction import export_prediction_from_logits

def export_logits_to_nifty_segmentation(
    predictor,
    volume_file: Path,
    model_dir: str,
    logits: Union[str, np.ndarray, torch.Tensor],
    npz_dir: str | None,
    output_dir: str = "",
    fold: int = 0,
    save_probs: bool = False,
    from_file: bool = True
):
    """
    Convert network logits into a native-space NIfTI segmentation using nnU-Net's
    `export_prediction_from_logits` helper.

    The raw image is re-read and re-preprocessed only to obtain the properties
    (cropping / resampling information) that the export step needs to map the logits
    back onto the original image grid.

    Args:
        predictor: An initialised nnU-Net predictor exposing `plans_manager` and
            `configuration_manager`.
        volume_file: Path to the raw image file the logits were computed from.
        model_dir (str): Path to the nnU-Net model directory containing dataset.json.
        logits (str, np.ndarray or torch.Tensor): When `from_file=True`, the base name of
            the .npz logits file (no extension) inside `npz_dir`, whose array is stored
            under the key "logits". Otherwise, the logits themselves.
        npz_dir (str | None): Directory holding the .npz file; required when `from_file=True`.
        output_dir (str): Directory where the segmentation will be saved.
        fold (int): The fold used for prediction. Currently unused; kept for API compatibility.
        save_probs (bool): Whether to also save softmax probabilities as a .npz file.
        from_file (bool): Whether to load the logits from `npz_dir` (default True) instead
            of using `logits` directly.

    Raises:
        ValueError: If `from_file=True` and `npz_dir` is None.
    """
    # Output path WITHOUT extension: nnU-Net appends dataset.json's `file_ending` itself.
    # (os.path.splitext on "x_seg.nii.gz" strips only ".gz" and yields "x_seg.nii.nii.gz".)
    if from_file:
        if npz_dir is None:
            raise ValueError("npz_dir is required when from_file=True")
        npz_logits = Path(npz_dir) / f"{logits}.npz"
        output_stem = Path(output_dir) / f"{logits}_seg"
        logits = np.load(npz_logits)["logits"]
    else:
        output_stem = Path(output_dir) / "exported_seg"

    plans_manager = predictor.plans_manager
    configuration_manager = predictor.configuration_manager

    # Load dataset.json once and hand the same dict to both nnU-Net calls.
    dataset_json_file = Path(model_dir) / "dataset.json"
    with dataset_json_file.open("r", encoding="utf-8") as f:
        dataset_json = json.load(f)

    preprocessor = configuration_manager.preprocessor_class(verbose=False)
    rw = plans_manager.image_reader_writer_class()
    # If a class (rather than an instance) comes back, instantiate it.
    if callable(rw) and not hasattr(rw, "read_images"):
        rw = rw()
    img_np, img_props = rw.read_images([str(volume_file)])

    # Only the properties are needed; the preprocessed data itself is discarded.
    _, _, data_props = preprocessor.run_case_npy(
        img_np, seg=None, properties=img_props,
        plans_manager=plans_manager,
        configuration_manager=configuration_manager,
        dataset_json=dataset_json
    )

    export_prediction_from_logits(
        predicted_array_or_file=logits,
        properties_dict=data_props,
        configuration_manager=configuration_manager,
        plans_manager=plans_manager,
        dataset_json_dict_or_file=dataset_json,
        output_file_truncated=str(output_stem),
        save_probabilities=save_probs,
        num_threads_torch=default_num_processes
    )

    output_nii = f"{output_stem}{dataset_json.get('file_ending', '.nii.gz')}"
    print(f"✅  NIfTI segmentation written → {output_nii}")

In [76]:
# Disabled example call, kept as a string literal (running the cell only echoes it).
# NOTE(review): this call is stale. `export_logits_to_nifty_segmentation` takes `volume_file`
# and `logits`, not `plan` / `logits_filename`, and `predictor` is only created further
# down the notebook. Update the arguments before re-enabling.
"""export_logits_to_nifty_segmentation(
    predictor=predictor,
    plan=plan,
    model_dir=Path(model_dir),
    logits_filename="pred_00007",
    npz_dir="SHAP/shap_run",
    output_dir="SHAP/shap_run",
    fold=0,
    save_probs=False
)"""

'export_logits_to_nifty_segmentation(\n    predictor=predictor,\n    plan=plan,\n    model_dir=Path(model_dir),\n    logits_filename="pred_00007",\n    npz_dir="SHAP/shap_run",\n    output_dir="SHAP/shap_run",\n    fold=0,\n    save_probs=False\n)'

## Create test perturbations, each removing one full organ

In [77]:
# Disabled code, kept as a string literal. For every organ label in the mask, it writes
# a copy of the CT with that organ's voxels set to 0, skipping files that already exist.
# NOTE(review) before re-enabling:
#   * `make_if_dont_exist` is not defined in any visible cell.
#   * `ct_img` is the first volume (loaded by the affine-inspection cell), while
#     `organ_mask` is whatever the loading loop loaded last (the last volume code).
#     Run as-is, this would pair a CT with another volume's mask.
"""#output_dir = join(nnUNet_raw, "perturbed_images")
output_dir = "perturbed_image"
make_if_dont_exist(output_dir, overwrite=False)

ct_data = ct_img.get_fdata()
organ_mask_data = organ_mask.get_fdata().astype(np.int32)

# --- IDENTIFY ORGANS PRESENT ---
organ_ids = np.unique(organ_mask_data)
organ_ids = organ_ids[organ_ids != 0]  # Exclude background

print(f"Found {len(organ_ids)} organs in the mask: {organ_ids}")

# --- CREATE PERTURBED CTs IF NEEDED ---
perturbed_paths = []

for organ_id in organ_ids:
    out_path = join(output_dir, f"ct_perturbed_without_organ_{organ_id}.nii.gz")

    if os.path.exists(out_path):
        print(f"✔ Organ {organ_id} already processed — skipping.")
        perturbed_paths.append((organ_id, out_path))
        continue

    print(f"Generating perturbation for organ ID: {organ_id}")

    # Mask out voxels belonging to this organ
    ct_perturbed = np.where(organ_mask_data == organ_id, 0, ct_data)

    # Create and save NIfTI image
    perturbed_img = nib.Nifti1Image(ct_perturbed, affine=ct_img.affine)
    nib.save(perturbed_img, out_path)

    perturbed_paths.append((organ_id, out_path))

print("\n✅ Perturbed volume generation completed!")"""

'#output_dir = join(nnUNet_raw, "perturbed_images")\noutput_dir = "perturbed_image"\nmake_if_dont_exist(output_dir, overwrite=False)\n\nct_data = ct_img.get_fdata()\norgan_mask_data = organ_mask.get_fdata().astype(np.int32)\n\n# --- IDENTIFY ORGANS PRESENT ---\norgan_ids = np.unique(organ_mask_data)\norgan_ids = organ_ids[organ_ids != 0]  # Exclude background\n\nprint(f"Found {len(organ_ids)} organs in the mask: {organ_ids}")\n\n# --- CREATE PERTURBED CTs IF NEEDED ---\nperturbed_paths = []\n\nfor organ_id in organ_ids:\n    out_path = join(output_dir, f"ct_perturbed_without_organ_{organ_id}.nii.gz")\n\n    if os.path.exists(out_path):\n        print(f"✔ Organ {organ_id} already processed — skipping.")\n        perturbed_paths.append((organ_id, out_path))\n        continue\n\n    print(f"Generating perturbation for organ ID: {organ_id}")\n\n    # Mask out voxels belonging to this organ\n    ct_perturbed = np.where(organ_mask_data == organ_id, 0, ct_data)\n\n    # Create and save NIfTI 

In [78]:
# NO LONGER NEEDED: the model is read directly from `model_dir` (set in the model-directory cell).
# Kept for reference: because nnUNet_results is read-only on Kaggle, the model folder used to be
# copied into the working directory with:
#!cp -r /kaggle/input/automi-seg/results/Dataset003_AUTOMI_CTVLNF_NEWGL_results ./

model_dir_readonly = join(nnUNet_results, 'Dataset003_AUTOMI_CTVLNF_NEWGL_results/nnUNetTrainer__nnUNetPlans__3d_fullres')
# Location the (disabled) copy above would create. It is deliberately NOT assigned to
# `model_dir`: the copy no longer happens, so overwriting `model_dir` with this relative
# path breaks the inference cells below on a fresh Restart & Run All.
model_dir_local_copy = join('Dataset003_AUTOMI_CTVLNF_NEWGL_results', 'nnUNetTrainer__nnUNetPlans__3d_fullres')

In [79]:

# Disabled code, kept as a string literal. It gzip-compresses every plain `.nii` file in
# `<model_dir>/fold_0/test_img` to `.nii.gz` and then DELETES the original `.nii`
# (destructive). Its Italian comments/messages are part of the string and are left as-is.
"""
import gzip
import shutil
import os

# Indica qui la cartella dove sono i .nii
folder = join(model_dir, 'fold_0', 'test_img')  # esempio: '/kaggle/input/your_dataset'

for filename in os.listdir(folder):
    if filename.endswith('.nii') and not filename.endswith('.nii.gz'):
        nii_path = os.path.join(folder, filename)
        gz_path = nii_path + '.gz'

        # Comprime il file .nii in .nii.gz
        with open(nii_path, 'rb') as f_in:
            with gzip.open(gz_path, 'wb') as f_out:
                shutil.copyfileobj(f_in, f_out)

        # Dopo la compressione, elimina il file .nii originale
        os.remove(nii_path)

print("Compressione completata! Tutti i file sono ora in formato .nii.gz")"""

'\nimport gzip\nimport shutil\nimport os\n\n# Indica qui la cartella dove sono i .nii\nfolder = join(model_dir, \'fold_0\', \'test_img\')  # esempio: \'/kaggle/input/your_dataset\'\n\nfor filename in os.listdir(folder):\n    if filename.endswith(\'.nii\') and not filename.endswith(\'.nii.gz\'):\n        nii_path = os.path.join(folder, filename)\n        gz_path = nii_path + \'.gz\'\n\n        # Comprime il file .nii in .nii.gz\n        with open(nii_path, \'rb\') as f_in:\n            with gzip.open(gz_path, \'wb\') as f_out:\n                shutil.copyfileobj(f_in, f_out)\n\n        # Dopo la compressione, elimina il file .nii originale\n        os.remove(nii_path)\n\nprint("Compressione completata! Tutti i file sono ora in formato .nii.gz")'

## Do some test inference on the perturbed images

In [83]:
# Sanity check: the model folder exists and contains nnU-Net's dataset.json.
# Uses `model_dir` instead of a hardcoded DEIB path, so it also works on Kaggle/Colab.
print(os.listdir(model_dir))
print("Model directory:", model_dir)
assert os.path.isfile(join(model_dir, "dataset.json")), f"dataset.json not found in {model_dir}"

['plans.json', 'fold_0', 'fold3_evaluation.xlsx', 'fold0_evaluation.xlsx', 'fold2_evaluation.xlsx', 'fold4_evaluation.xlsx', 'crossval_results_folds_0_1_2_3_4', 'dataset_fingerprint.json', '.DS_Store', 'fold_2', 'fold_3', 'dataset.json', 'fold_1', 'fold_4', 'fold1_evaluation.xlsx']
Model directory: /workspace/data/results/nnUNetTrainer__nnUNetPlans__3d_fullres


In [84]:
from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor

# Worker processes for preprocessing and for segmentation export. With 3 + 3 workers the run
# died with "Background workers died" (RAM exhausted); lower further if that happens again.
N_WORKERS = 2

# instantiate the nnUNetPredictor on the device selected at the top of the notebook
predictor = nnUNetPredictor(
    tile_step_size=0.5,
    use_gaussian=True,
    use_mirroring=True,  # == test time augmentation
    perform_everything_on_device=device.type == "cuda",
    device=device,
    verbose=False,
    verbose_preprocessing=False,
    allow_tqdm=True
)
# initializes the network architecture, loads the fold-0 checkpoint
predictor.initialize_from_trained_model_folder(
    model_dir,
    use_folds=(0,),
    checkpoint_name='checkpoint_final.pth',
)

# variant 1: give input and output folders
imgs_dir = join(model_dir, 'fold_0', 'test_img')

predictor.predict_from_files(
    imgs_dir,
    "predTs_v3",
    save_probabilities=False, overwrite=False,
    num_processes_preprocessing=N_WORKERS, num_processes_segmentation_export=N_WORKERS,
    folder_with_segs_from_prev_stage=None, num_parts=1, part_id=0,
)




There are 8 cases in the source folder
I am processing 0 out of 1 (max process ID is 0, we start counting with 0!)
There are 8 cases that I would like to predict
overwrite was set to False, so I am only working on cases that haven't been predicted yet. That's 8 cases.


RuntimeError: Background workers died. Look for the error message further up! If there is none then your RAM was full and the worker was killed by the OS. Use fewer workers or get more RAM in that case!

# Full-organ SHAP: first test
## Produce a list of perturbations and apply them one by one to a single volume, and store predictions

In [None]:
"""
Produce and save a *perturbation schedule* for organ-wise SHAP analysis.

Each entry in the schedule is a lightweight dict:
{
    "id"          : 7,                       # running integer
    "organs_off"  : [1, 4, 5],               # organ labels to zero-out
    "pert_type"   : "zero",                  # could be "blur", "noise", …
    "seed"        : 12345                    # for stochastic perturbations
}

Nothing is pre-computed or stored in RAM apart from this list.
"""

from __future__ import annotations
from dataclasses import dataclass, asdict
from pathlib import Path
import json
import itertools
import numpy as np
import nibabel as nib
from typing import List, Sequence, Dict, Union, Optional, Tuple
import hashlib
import random

__all__ = ["OrganMaskPerturbationPlanner", "PerturbationDescriptor"]


@dataclass(frozen=True)
class PerturbationDescriptor:
    """
    An immutable description of *one* perturbation.

    Attributes
    ----------
    id : running integer (the planner uses 0 for the baseline).
    organs_off : organ labels to perturb; a tuple so the dataclass stays hashable.
    pert_type : perturbation kind, e.g. "zero" or "identity".
    seed : seed for stochastic perturbations (deterministic default 0).
    """
    id: int
    organs_off: Tuple[int, ...]          # tuple so dataclass is hashable
    pert_type: str = "zero"
    seed: int = 0                        # default deterministic

    # convenient JSON-serialisation
    def to_json_dict(self) -> Dict:
        """Return a JSON-serialisable dict (``organs_off`` as a list)."""
        return {"id": self.id,
                "organs_off": list(self.organs_off),
                "pert_type": self.pert_type,
                "seed": self.seed}

    @staticmethod
    def from_json_dict(d: Dict) -> "PerturbationDescriptor":
        """Inverse of ``to_json_dict``."""
        return PerturbationDescriptor(id=int(d["id"]),
                                      organs_off=tuple(int(o) for o in d["organs_off"]),
                                      pert_type=d["pert_type"],
                                      seed=int(d["seed"]))


class OrganMaskPerturbationPlanner:
    """
    Reads a volume + organ mask, enumerates organs present, and
    *builds a perturbation plan* for KernelSHAP.

    Parameters
    ----------
    volume_file, organ_mask_file
        Paths to the image and its integer organ-label mask (label 0 = background).
    strategy
        "single-off" (baseline + every single organ off) or
        "single-off+random" (additionally `n_random` random organ subsets).
    n_random
        Number of extra random subsets (only used with "single-off+random").
    perturbation_type
        Stored in every non-baseline descriptor (e.g. "zero").
    seed
        Seeds the planner's RNG (subset sampling and per-subset seeds).
    out_json
        If given, the plan is written there right after construction.

    Notes
    -----
    * The *baseline* (all organs ON) is id==0 by convention.
    * For classic KernelSHAP you usually need **all single-organ OFF**
      plus a random sample of joint subsets – we implement that.
    * Random subsets are drawn uniformly, without replacement, from all non-empty
      organ subsets, including single organs and the full set.
      - Up to `max_enumerated_organs` organs, the full power set is shuffled. This keeps
        schedules identical to earlier runs with the same seed.
      - Above that, the power set (2**n - 1 subsets) is far too large to materialise, so
        subsets are drawn by rejection sampling instead.
    * The JSON size grows with the number of organs plus `n_random`.
    """

    #: version string – bump if JSON schema ever changes
    schema_version = "1.0"

    #: above this many organs the power set is sampled instead of enumerated
    max_enumerated_organs = 20

    def __init__(self,
                 volume_file: Union[str, Path],
                 organ_mask_file: Union[str, Path],
                 strategy: str = "single-off+random",     # default
                 n_random: int = 0,                       # extra subsets
                 perturbation_type: str = "zero",
                 seed: int = 12345,
                 out_json: Optional[Union[str, Path]] = None):

        self.volume_file = Path(volume_file)
        self.organ_mask_file = Path(organ_mask_file)
        self.strategy = strategy
        self.n_random = int(n_random)
        self.perturbation_type = perturbation_type
        self.seed = int(seed)
        self._rng = random.Random(seed)

        # ------------------------------------------------------------------
        # 1) parse mask → sorted tuple of organ labels (background 0 dropped)
        # ------------------------------------------------------------------
        mask = nib.load(str(self.organ_mask_file)).get_fdata().astype(np.int32)
        organ_ids = np.unique(mask)
        organ_ids = organ_ids[organ_ids != 0]          # drop background
        self.organs: Tuple[int, ...] = tuple(int(x) for x in organ_ids)

        # ------------------------------------------------------------------
        # 2) build schedule
        # ------------------------------------------------------------------
        self.schedule: List[PerturbationDescriptor] = self._build_schedule()

        # ------------------------------------------------------------------
        # 3) optionally dump to JSON
        # ------------------------------------------------------------------
        if out_json is not None:
            self.save_json(out_json)

    # ------------------------------------------------------------------
    # public helpers
    # ------------------------------------------------------------------
    def save_json(self, path: Union[str, Path]) -> Path:
        """Write the plan (settings + schedule) to `path`, creating parent dirs; return the path."""
        path = Path(path)
        obj = {"schema": self.schema_version,
               "volume_file": str(self.volume_file),
               "organ_mask_file": str(self.organ_mask_file),
               "strategy": self.strategy,
               "n_random": self.n_random,
               "perturbation_type": self.perturbation_type,
               "seed": self.seed,
               "organs": list(self.organs),
               "schedule": [p.to_json_dict() for p in self.schedule]}
        path.parent.mkdir(parents=True, exist_ok=True)
        with path.open("w", encoding="utf-8") as f:
            json.dump(obj, f, indent=2)
        return path

    @classmethod
    def load_json(cls, path: Union[str, Path]) -> "OrganMaskPerturbationPlanner":
        """Rebuild a planner from `save_json` output; the stored schedule replaces the rebuilt one."""
        with Path(path).open("r", encoding="utf-8") as f:
            obj = json.load(f)
        planner = cls(volume_file=obj["volume_file"],
                      organ_mask_file=obj["organ_mask_file"],
                      strategy=obj["strategy"],
                      n_random=obj["n_random"],
                      perturbation_type=obj["perturbation_type"],
                      seed=obj["seed"])
        # overwrite schedule that __init__ just built
        planner.schedule = [PerturbationDescriptor.from_json_dict(d)
                            for d in obj["schedule"]]
        return planner

    # ------------------------------------------------------------------
    # implementation details
    # ------------------------------------------------------------------
    def _build_schedule(self) -> List[PerturbationDescriptor]:
        """
        Build a list of PerturbationDescriptor according to the chosen strategy.
        *id==0* is always the **baseline** (no perturbation); ids are consecutive.

        Raises
        ------
        ValueError
            If `self.strategy` is not "single-off" or "single-off+random".
        """
        if self.strategy not in {"single-off", "single-off+random"}:
            raise ValueError(f"Unknown strategy '{self.strategy}'")

        sched: List[PerturbationDescriptor] = [
            PerturbationDescriptor(id=0,
                                   organs_off=tuple(),       # none
                                   pert_type="identity",
                                   seed=self.seed)
        ]

        # --- (a) single-organ OFF ---
        for i, organ in enumerate(self.organs, start=1):
            sched.append(PerturbationDescriptor(id=i,
                                                organs_off=(organ,),
                                                pert_type=self.perturbation_type,
                                                seed=self.seed))

        # --- (b) plus random K subsets if requested ---
        if self.strategy.endswith("+random") and self.n_random > 0:
            next_id = len(sched)                                # keep running id
            for subset in self._sample_random_subsets(self.n_random):
                sched.append(PerturbationDescriptor(
                    id=next_id,
                    organs_off=tuple(sorted(subset)),
                    pert_type=self.perturbation_type,
                    seed=self._rng.randint(0, 2 ** 31 - 1)))
                next_id += 1

        return sched

    def _sample_random_subsets(self, k: int) -> List[Tuple[int, ...]]:
        """
        Draw up to `k` distinct non-empty subsets of `self.organs`, uniformly at random
        and without replacement (fewer only if fewer than `k` subsets exist).
        """
        n = len(self.organs)
        if n <= self.max_enumerated_organs:
            # Small case: shuffle the whole power set (same RNG use as earlier versions,
            # so existing seeds reproduce the same schedules).
            power_set = list(itertools.chain.from_iterable(
                itertools.combinations(self.organs, r)
                for r in range(1, n + 1)))
            self._rng.shuffle(power_set)
            return power_set[:k]

        # Large case: a uniform random bitmask is a uniform random subset; rejecting
        # the empty set and repeats gives uniform sampling without replacement.
        k = min(k, 2 ** n - 1)
        seen_masks = set()
        subsets: List[Tuple[int, ...]] = []
        while len(subsets) < k:
            bits = self._rng.getrandbits(n)
            if bits == 0 or bits in seen_masks:
                continue
            seen_masks.add(bits)
            subsets.append(tuple(o for i, o in enumerate(self.organs) if (bits >> i) & 1))
        return subsets

    # ------------------------------------------------------------------
    # magic methods
    # ------------------------------------------------------------------
    def __len__(self):
        return len(self.schedule)

    def __iter__(self):
        return iter(self.schedule)

    def __getitem__(self, idx: int) -> PerturbationDescriptor:
        return self.schedule[idx]

In [None]:
"""
SHAPPredictionIterator
======================

A memory-light wrapper that drives

    1.   *Perturbation application*   (zeroing-out selected organs)
    2.   *nnU-Net preprocessing*      (DefaultPreprocessor.run_case_npy)
    3.   *nnU-Net inference*          (predict_sliding_window_return_logits)

and yields (descriptor, prediction_logits).

Only one full-resolution volume (the **original**) plus one *preprocessed* patch
sit in RAM/GPU at any time.

This iterator is **stateless** w.r.t. SHAP bookkeeping – that will be handled
by the next module.
"""

from __future__ import annotations
from pathlib import Path
from typing import Iterator, Tuple, Optional, List
import itertools
import numpy as np
import nibabel as nib
import torch

from nnunetv2.preprocessing.preprocessors.default_preprocessor import DefaultPreprocessor

__all__ = ["SHAPPredictionIterator"]


class SHAPPredictionIterator:
    """
    Parameters
    ----------
    planner
        A previously built (or loaded) OrganMaskPerturbationPlanner.
    predictor
        An *already initialised* `nnUNetPredictor` with
        `.plans_manager`, `.configuration_manager`, `.device`, `.network`, etc.
        The iterator never changes those attributes.
    skip_completed_ids
        Optional set/list of integer perturbation ids that should be skipped
        (for resume functionality).
    cache_sw_inference
        Cache the sliding window inference patches, running only the perturbed ones
    """

    def __init__(self,
                 planner: OrganMaskPerturbationPlanner,
                 predictor,
                 skip_completed_ids: Optional[List[int]] = None,
                 cache_sw_inference: bool = True,
                 pre_cached_output: dict | None = None,
                 verbose: bool = False):
        """
        Read the image and organ mask once and prepare the nnU-Net preprocessor.

        Parameters
        ----------
        planner, predictor, skip_completed_ids, cache_sw_inference, verbose
            See the class docstring.
        pre_cached_output
            Optional baseline sliding-window output dictionary to reuse when
            `cache_sw_inference` is True; when None it is computed here via
            `get_orig_image_output_dictionary()`. Ignored when caching is off.

        Raises
        ------
        ValueError
            If the organ mask and the image do not have the same spatial shape.
        """
        self.planner = planner
        self.predictor = predictor
        self.verbose = verbose
        self.skip_completed_ids = set(skip_completed_ids or [])
        self.cache_sw_inference = cache_sw_inference

        # ------------------------------------------------------------------
        # 1) Read image and organ mask ONCE, with the same reader nnU-Net uses for
        #    this plan; both stay in CPU RAM for the iterator's lifetime.
        # ------------------------------------------------------------------
        rw = self.predictor.plans_manager.image_reader_writer_class()
        # If nnU-Net returns a class instead of an instance, instantiate it
        if callable(rw) and not hasattr(rw, "read_images"):
            rw = rw()

        # ---- image: channel-first, (C, *spatial) ------------------------------
        self._orig_image, self._orig_props = rw.read_images(
            [str(planner.volume_file)]
        )

        # ---- organ mask: (1, *spatial) → drop the channel axis -------------------
        seg_arr, _ = rw.read_seg(str(planner.organ_mask_file))
        self._organ_mask = seg_arr[0].astype(np.int32, copy=False)

        # The mask is used to index image voxels directly, so the grids must match;
        # fail here with a clear message rather than with an IndexError later.
        if self._organ_mask.shape != self._orig_image.shape[1:]:
            raise ValueError(
                f"Organ mask shape {self._organ_mask.shape} does not match the image "
                f"spatial shape {self._orig_image.shape[1:]}"
            )

        # convenience
        self._preprocessor = self.predictor.configuration_manager.preprocessor_class(
            verbose=self.verbose
        )

        # ----- baseline output dictionary (only used when caching) -------------
        self._baseline_output_dictionary = None
        if self.cache_sw_inference:
            if pre_cached_output is None:
                self._baseline_output_dictionary = self.get_orig_image_output_dictionary()
            else:
                if self.verbose:
                    print("pre-cached output dictionary found")
                self._baseline_output_dictionary = pre_cached_output

    # ------------------------------------------------------------------ #
    # Python iterator protocol                                            #
    # ------------------------------------------------------------------ #
    def __iter__(self) -> Iterator[Tuple[PerturbationDescriptor, np.ndarray]]:
        """
        Run perturbation → nnU-Net preprocessing → sliding-window inference for every
        scheduled perturbation whose id is not in `skip_completed_ids`.

        Yields
        ------
        (descriptor, logits)  where
            descriptor : PerturbationDescriptor
            logits     : np.ndarray on **CPU**, channel-first, as returned by the
                         predictor for the *preprocessed* input (i.e. in the network's
                         space, not necessarily the original image grid).
        """
        for desc in self.planner.schedule:
            # resume support: skip ids that were already computed
            if desc.id in self.skip_completed_ids:
                if self.verbose:
                    print(f"[SHAP] skipping id={desc.id} (already done)")
                continue

            # a) perturbed copy of the original volume (baseline: no organs off)
            perturbed = self._apply_perturbation(desc)

            # b) preprocess → network input (already channel-first)
            # NOTE(review): the same `self._orig_props` dict is passed on every iteration;
            # if run_case_npy writes into `properties`, it keeps the last run's values.
            preprocessed, _, _ = self._preprocessor.run_case_npy(
                perturbed,
                seg=None,
                properties=self._orig_props,
                plans_manager=self.predictor.plans_manager,
                configuration_manager=self.predictor.configuration_manager,
                dataset_json=self.predictor.dataset_json
            )
            network_input = torch.from_numpy(preprocessed)

            # c) + d) sliding-window prediction, optionally reusing cached baseline tiles
            if self.cache_sw_inference:
                perturbation_mask = self._get_perturbation_mask(desc)
                prediction = self.predictor.predict_sliding_window_return_logits_with_caching(
                    network_input,
                    perturbation_mask,
                    self._baseline_output_dictionary
                )
            else:
                prediction = self.predictor.predict_sliding_window_return_logits(
                    network_input
                )
            logits = prediction.cpu().numpy()

            # release cached GPU memory before the next perturbation
            torch.cuda.empty_cache()

            yield desc, logits

    def __len__(self):
        return len(self.planner) - len(self.skip_completed_ids)

    # ------------------------------------------------------------------ #
    # internal helpers                                                   #
    # ------------------------------------------------------------------ #
    def _apply_perturbation(self, desc: PerturbationDescriptor) -> np.ndarray:
        """
        Return a **new** NumPy array `data_pert` with the specified organs
        perturbed (currently only 'zero' implemented).
        """
        if desc.pert_type not in {"identity", "zero"}:
            raise NotImplementedError(f"Perturbation '{desc.pert_type}' not yet implemented")

        # Shallow copy when no perturbation needed (baseline) –
        # avoids wasting RAM while staying side-effect-free.
        if not desc.organs_off:
            return self._orig_image

        # Materialise a *copy* because we will mutate values
        data = self._orig_image.copy()

        # Build a Boolean mask once; broadcast over channels
        # mask shape (X,Y,Z) – True where voxel belongs to any organ in organs_off
        mask = np.isin(self._organ_mask, desc.organs_off)

        if desc.pert_type == "zero":
            # data shape (C, X, Y, Z)
            data[:, mask] = 0

        return data

    def _get_perturbation_mask(self,
                               desc: PerturbationDescriptor,
                               num_channels: int = 1) -> torch.BoolTensor:
        """
        Boolean mask of the voxels switched off by ``desc``, shaped (C, X, Y, Z).

        ``True`` marks voxels whose label in ``self._organ_mask`` is one of
        ``desc.organs_off``. The organ mask is treated as (Z, Y, X) and its three axes
        are reversed to (X, Y, Z). A leading channel axis is then added and broadcast
        to ``num_channels`` with ``expand``, which returns a view and does not copy.

        NOTE(review): ``_apply_perturbation`` indexes the image (``data[:, mask]``) with
        the same ``_organ_mask`` *without* transposing it, so only one of the two axis
        conventions can be right. Confirm which.
        NOTE(review): this mask is built in the raw image space. The caching predictor
        slices it with patch slicers of the *preprocessed* network input, so if
        preprocessing crops or resamples the volume the two spaces do not line up.
        TODO confirm.
        """
        # Boolean hit map in the organ mask's own axis order.
        hit_zyx = torch.from_numpy(np.isin(self._organ_mask, desc.organs_off))

        # Reverse the spatial axes: (Z, Y, X) -> (X, Y, Z).
        hit_xyz = hit_zyx.permute(2, 1, 0).contiguous()

        # (X, Y, Z) -> (1, X, Y, Z) -> (C, X, Y, Z) as a broadcast view.
        return hit_xyz[None].expand(num_channels, -1, -1, -1)

    def get_orig_image_output_dictionary(self) -> dict:
        """
        Run sliding-window inference once on the *unperturbed* volume and return the
        raw output of every patch, keyed by patch slicer.

        The result is the cache passed to
        ``predictor.predict_sliding_window_return_logits_with_caching``. Patches that do
        not touch any perturbed voxel can reuse these outputs instead of re-running the
        network.

        The preprocessed tensor is padded to the patch size exactly like the caching
        predictor pads its input before computing slicers. Without this step, volumes
        smaller than the patch size would produce different slicers (and cache keys)
        here than in the cached prediction.

        Returns
        -------
        dict
            ``predictor._slice_key(slicer) -> patch output``, as built by
            ``predictor.get_output_dictionary_sliding_window``.

        Raises
        ------
        RuntimeError
            If the iterator was created with ``cache_sw_inference=False``.
        """
        if not self.cache_sw_inference:
            raise RuntimeError("get_orig_image_output_dictionary non è compatibile con cache_sw_inference = False")

        # Local import: in this notebook the padding helper is otherwise only imported in a later cell.
        from acvl_utils.cropping_and_padding.padding import pad_nd_image

        data_pp, _, _ = self._preprocessor.run_case_npy(
            self._orig_image,
            seg=None,
            properties=self._orig_props,
            plans_manager=self.predictor.plans_manager,
            configuration_manager=self.predictor.configuration_manager,
            dataset_json=self.predictor.dataset_json
        )
        # preprocessed data is already channel-first
        inp_tensor = torch.from_numpy(data_pp)

        # Same padding call (and arguments) as the caching predictor uses.
        inp_tensor, _ = pad_nd_image(inp_tensor, self.predictor.configuration_manager.patch_size,
                                     'constant', {'value': 0}, True, None)
        if self.verbose:
            print("tensor shape before _internal_get_sliding_window_slicers called in Iterator: ", inp_tensor.shape)

        slicers = self.predictor._internal_get_sliding_window_slicers(inp_tensor.shape[1:])
        if self.verbose:
            print("first 3 slicers of Iterator object: ", slicers[:3])

        return self.predictor.get_output_dictionary_sliding_window(inp_tensor, slicers)
        

In [None]:
"""
SHAPAccumulator
===============

Responsibilities
----------------
1.  Accept (descriptor, logits) tuples from `SHAPPredictionIterator`.
2.  Write the logits to a compressed .npz file (one per perturbation).
3.  Append a tiny JSON-Lines record so we can resume reliably.
4.  Expose `.completed_ids` so the iterator can skip work that is done.
5.  Offer `finalize()` which – for now – just returns file paths, but gives
    you one place to plug in an actual KernelSHAP post-processor later.
"""

from __future__ import annotations
from pathlib import Path
import json
from typing import Dict, List, Set, Tuple, Optional
import numpy as np

# `__all__` only affects `from <module> import *`. Inside the notebook it has no effect;
# it only matters if this cell is exported to a standalone module.
__all__ = ["SHAPAccumulator"]


class SHAPAccumulator:
    """
    Persist the logits of every finished perturbation and track progress on disk, so
    an interrupted SHAP run can be resumed.

    Parameters
    ----------
    planner
        The same planner passed to the iterator. Only ``planner.schedule`` is read
        (by :meth:`finalize`, to list files in schedule order).
    out_dir
        Directory (created if missing) that will contain:

        1.  ``pred_00012.npz``   logits for perturbation id 12, stored under the key
            ``"logits"`` with ``np.savez_compressed`` (dtype kept as given);
        2.  ``progress.jsonl``   one JSON record per *finished* perturbation.
    mode
        ``"a"`` (append / create; :meth:`update` allowed) or ``"r"`` (read-only).
    npz_compression_kwargs
        Extra keyword arguments forwarded to ``np.savez_compressed``. NOTE:
        ``savez_compressed`` stores *every* keyword argument as an additional named
        array in the archive; these are not compression options.

    Blank lines and unreadable records in ``progress.jsonl`` (e.g. a line truncated
    by a crash during a write) are skipped with a warning. The corresponding
    perturbation is then simply recomputed.
    """

    PROGRESS_FILE = "progress.jsonl"          # newline-delimited JSON

    def __init__(self,
                 planner: OrganMaskPerturbationPlanner,
                 out_dir: str | Path,
                 mode: str = "a",
                 npz_compression_kwargs: Optional[Dict] = None):

        if mode not in ("a", "r"):
            raise ValueError("mode must be 'a' or 'r'")

        self.planner = planner
        self.out_dir = Path(out_dir)
        self.out_dir.mkdir(parents=True, exist_ok=True)
        self.progress_path = self.out_dir / self.PROGRESS_FILE

        self._compression_kwargs = npz_compression_kwargs or {}
        self._mode = mode

        # 1) read existing progress (if any) -> set of completed ids
        self.completed_ids: Set[int] = set()
        self._progress_fp = None
        ends_with_newline = self._load_progress()

        if mode == "a":
            # Append mode; every update() flushes its record immediately.
            self._progress_fp = self.progress_path.open("a", encoding="utf-8")
            if not ends_with_newline:
                # Terminate a truncated last line so the next record starts on its own line.
                self._progress_fp.write("\n")
                self._progress_fp.flush()

    # ------------------------------------------------------------------ #
    # public API                                                          #
    # ------------------------------------------------------------------ #
    def update(self,
               desc: PerturbationDescriptor,
               logits: np.ndarray) -> Path:
        """
        Persist *one* prediction and mark it as completed.

        Returns
        -------
        Path to the .npz file of this perturbation. If ``desc.id`` was already
        completed, nothing is written and the existing path is returned.

        Raises
        ------
        RuntimeError
            If the accumulator is read-only, or has been closed.
        """
        if self._mode != "a":
            raise RuntimeError("Accumulator opened read-only")

        if desc.id in self.completed_ids:
            # defensive: should not happen if the iterator skips completed ids
            return self._logit_path(desc.id)

        if self._progress_fp is None:
            # Fail before writing the npz, so no orphan file without a progress record is left.
            raise RuntimeError("Accumulator has been closed")

        # 1) save logits on disk (compressed; dtype kept as given)
        npz_path = self._logit_path(desc.id)
        np.savez_compressed(npz_path, logits=logits, **self._compression_kwargs)

        # 2) append a *tiny* progress record, written only after the npz exists
        rec = {
            "id":        int(desc.id),
            "organs_off": list(desc.organs_off),
            "pert_type": desc.pert_type,
            "logit_file": str(npz_path.name)
        }
        self._progress_fp.write(json.dumps(rec) + "\n")
        self._progress_fp.flush()

        # 3) keep only the id in RAM
        self.completed_ids.add(int(desc.id))

        return npz_path

    def finalize(self):
        """
        Placeholder for KernelSHAP post-processing.

        At present it simply returns the paths of the completed .npz files, in the
        order of ``planner.schedule``. Replace it with the actual SHAP fitting
        (e.g. weighted linear regression).

        Notes
        -----
        * All logits live on disk, so memory stays small if the files are loaded
          one at a time.
        * SHAP values may also be computed per-organ *metric* (Dice, volume, ...)
          instead of on the raw voxel map.
        """
        return [
            self._logit_path(desc.id)
            for desc in self.planner.schedule
            if desc.id in self.completed_ids
        ]

    def close(self):
        """Close the progress file (safe to call more than once)."""
        if self._progress_fp is not None:
            self._progress_fp.close()
            self._progress_fp = None

    # ------------------------------------------------------------------ #
    # helpers                                                            #
    # ------------------------------------------------------------------ #
    def _load_progress(self) -> bool:
        """
        Fill ``self.completed_ids`` from the progress file.

        Returns whether the file is empty/missing or ends with a newline (False
        means the last record was cut off mid-write).
        """
        import warnings

        ends_with_newline = True
        if not self.progress_path.exists():
            return ends_with_newline

        with self.progress_path.open("r", encoding="utf-8") as f:
            for lineno, raw in enumerate(f, start=1):
                ends_with_newline = raw.endswith("\n")
                line = raw.strip()
                if not line:
                    continue
                try:
                    self.completed_ids.add(int(json.loads(line)["id"]))
                except (ValueError, KeyError, TypeError) as exc:
                    # json.JSONDecodeError is a ValueError subclass
                    warnings.warn(
                        f"{self.progress_path}:{lineno}: ignoring unreadable progress record "
                        f"({exc}); that perturbation will be recomputed",
                        RuntimeWarning,
                    )
        return ends_with_newline

    def _logit_path(self, perturbation_id: int) -> Path:
        return self.out_dir / f"pred_{perturbation_id:05d}.npz"

    # make the object a context manager for convenience
    def __enter__(self):   return self
    def __exit__(self, exc_type, exc_val, exc_tb): self.close()

In [None]:
# Build the perturbation plan for one case and write it to JSON ---------------
plan = OrganMaskPerturbationPlanner(
    volume_file=join(imgs_dir, 'AUTOMI_00039_0000.nii.gz'),
    organ_mask_file=join(
        nnUNet_raw,
        "total_segmentator_structures",
        "AUTOMI_00039_0000",
        "mask_mask_add_input_20_total_segmentator.nii",
    ),
    strategy="single-off+random",
    n_random=1,                 # extra random organ subsets on top of the single-organ-off ones
    perturbation_type="zero",
    out_json="SHAP/output_plan.json",
)

n_perturbations = len(plan)
print(f"Schedule contains {n_perturbations} perturbations "
      f"(baseline + {n_perturbations - 1} variants)")

In [None]:

from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor

# Reference run: plain nnU-Net sliding-window inference (no patch caching) for every
# perturbation in the plan. Logits are stored on disk, so the run can be resumed.

# 1) Load plan -----------------------------------------------------------
plan = OrganMaskPerturbationPlanner.load_json("SHAP/output_plan.json")

# 2) Initialise predictor (standard nnU-Net code) ------------------------
# Use the device selected at the top of the notebook instead of hard-coding CUDA, so
# the cell also runs on CPU-only machines (GPU 0 is used when CUDA is available).
predictor_device = torch.device('cuda', 0) if device.type == 'cuda' else device
predictor = nnUNetPredictor(
    tile_step_size=0.5,
    use_gaussian=True,
    use_mirroring=False,  # no mirroring test-time augmentation
    perform_everything_on_device=True,
    device=predictor_device,
    verbose=False,
    verbose_preprocessing=False,
    allow_tqdm=True
)
# build the network architecture and load the fold-0 checkpoint
predictor.initialize_from_trained_model_folder(
    model_dir_readonly,
    use_folds=(0,),
    checkpoint_name='checkpoint_final.pth',
)


# 3) Run iterator + accumulator -----------------------------------------
# Perturbations already recorded in progress.jsonl are skipped.
with SHAPAccumulator(plan, "SHAP/shap_run", mode="a") as acc:
    for desc, logits in SHAPPredictionIterator(plan, predictor,
                                               skip_completed_ids=acc.completed_ids,
                                               cache_sw_inference=False,
                                               verbose=True):
        acc.update(desc, logits)

# 4) After all perturbations are done ------------------------------------
# finalize() only reads in-memory state, so it works after the `with` block closed the file.
paths = acc.finalize()   # .npz files on disk, in schedule order
print("All predictions saved, ready for KernelSHAP fitting.")

## Optionally export a predicted volume for visualization

In [None]:
import numpy as np
import torch
import os
from pathlib import Path
from nnunetv2.configuration import default_num_processes
from nnunetv2.inference.export_prediction import export_prediction_from_logits

def export_logits_to_nifty_segmentation(
    predictor,
    volume_file: Path,
    model_dir_readonly: str,
    logits: str | np.ndarray | torch.Tensor,
    npz_dir: str | None,
    output_dir: str = "",
    fold: int = 0,
    save_probs: bool = False,
    from_file: bool = True
):
    """
    Convert network logits into a segmentation file in the native space of the raw image.

    The raw image is re-read and re-preprocessed only to recover the preprocessing
    properties that nnU-Net's ``export_prediction_from_logits`` needs to map the
    logits back to native space.

    Args:
        predictor: An initialized nnU-Net predictor (plans/configuration loaded).
        volume_file: Path to the raw image the logits were predicted from.
        model_dir_readonly (str): nnU-Net model directory containing ``dataset.json``.
        logits: If ``from_file`` is True, the base name (no extension) of a ``.npz``
            file in ``npz_dir`` that stores the array under key ``"logits"``.
            Otherwise the logits themselves (NumPy array or torch tensor).
        npz_dir (str): Directory of the ``.npz`` file (required when ``from_file``).
        output_dir (str): Directory where the segmentation is written.
        fold (int): Currently unused; kept for API compatibility.
        save_probs (bool): Also save the softmax probabilities (as a .npz file).
        from_file (bool): See ``logits`` (default True).

    Returns:
        str: Path of the written segmentation. Its name is ``<logits>_seg`` (from file)
        or ``exported_seg``, followed by the ``file_ending`` of ``dataset.json``
        (e.g. ``.nii.gz``).

    Raises:
        ValueError: If ``from_file`` is True and ``npz_dir`` is None.
    """
    import json

    if from_file:
        if npz_dir is None:
            raise ValueError("npz_dir is required when from_file=True")
        npz_logits = Path(npz_dir) / f"{logits}.npz"
        output_truncated = Path(output_dir) / f"{logits}_seg"
        with np.load(npz_logits) as npz:
            logits = npz["logits"]
    else:
        output_truncated = Path(output_dir) / "exported_seg"

    plans_manager = predictor.plans_manager
    configuration_manager = predictor.configuration_manager
    dataset_json = Path(model_dir_readonly) / "dataset.json"
    # nnU-Net appends this ending to `output_file_truncated` when writing.
    with open(dataset_json, encoding="utf-8") as f:
        file_ending = json.load(f)["file_ending"]

    preprocessor = configuration_manager.preprocessor_class(verbose=False)
    rw = plans_manager.image_reader_writer_class()
    if callable(rw) and not hasattr(rw, "read_images"):
        rw = rw()
    # Read the image given by the caller (not a global `plan`).
    img_np, img_props = rw.read_images([str(volume_file)])

    # Only the properties (crop bbox, spacing, ...) are needed for the export.
    _, _, data_props = preprocessor.run_case_npy(
        img_np, seg=None, properties=img_props,
        plans_manager=plans_manager,
        configuration_manager=configuration_manager,
        dataset_json=str(dataset_json)
    )

    # Pass the name WITHOUT extension: os.path.splitext("x_seg.nii.gz") would leave
    # ".nii" behind and produce "x_seg.nii.nii.gz".
    export_prediction_from_logits(
        predicted_array_or_file=logits,
        properties_dict=data_props,
        configuration_manager=configuration_manager,
        plans_manager=plans_manager,
        dataset_json_dict_or_file=str(dataset_json),
        output_file_truncated=str(output_truncated),
        save_probabilities=save_probs,
        num_threads_torch=default_num_processes
    )

    output_nii = f"{output_truncated}{file_ending}"
    print(f"✅  NIfTI segmentation written → {output_nii}")
    return output_nii

In [None]:
"""export_logits_to_nifty_segmentation(
    predictor=predictor,
    volume_file=Path(plan.volume_file),
    model_dir_readonly=str(model_dir_readonly),
    logits="pred_00007",
    npz_dir="SHAP/shap_run",
    output_dir="SHAP/shap_run",
    fold=0,
    save_probs=False
)"""

## Testing sliding-window caching to speed up multi-inference scenarios such as SHAP

### Overriding the sliding-window prediction function

In [None]:
from typing import Union
import numpy as np
import torch
from tqdm import tqdm
from queue import Queue
from threading import Thread
from acvl_utils.cropping_and_padding.padding import pad_nd_image
from nnunetv2.utilities.helpers import empty_cache, dummy_context
from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor
from nnunetv2.inference.sliding_window_prediction import compute_gaussian, compute_steps_for_sliding_window

class CustomNNUNetPredictor(nnUNetPredictor):
    """
    nnUNetPredictor with patch-level caching for sliding-window inference.

    When many slightly different versions of one volume are predicted (e.g. SHAP
    perturbations that switch off single organs), most sliding-window patches are
    unchanged. The caller passes a boolean ``perturbation_mask`` and the per-patch
    outputs of the unperturbed volume (see ``get_output_dictionary_sliding_window``).
    Patches without any perturbed voxel then reuse the cached output instead of
    running the network.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    @torch.inference_mode()
    def predict_sliding_window_return_logits_with_caching(self, input_image: torch.Tensor,
                                                          perturbation_mask: torch.BoolTensor,
                                                          baseline_prediction_dict: dict) \
            -> Union[np.ndarray, torch.Tensor]:
        """
        Sliding-window prediction that reuses cached baseline outputs for unperturbed patches.

        Adapted from ``predict_sliding_window_return_logits`` of the official nnunetv2 code:
        https://github.com/MIC-DKFZ/nnUNet/blob/58a3b121a6d1846a978306f6c79a7c005b7d669b/nnunetv2/inference/predict_from_raw_data.py

        Args:
            input_image: preprocessed network input, 4D (c, x, y, z).
            perturbation_mask: boolean mask (C', x, y, z) with the SAME spatial shape as
                ``input_image``; True marks perturbed voxels.
            baseline_prediction_dict: per-patch outputs of the unperturbed volume, keyed by
                ``self._slice_key(slicer)``. It must come from a volume of the same shape
                (same padding), so that the keys match.

        Returns:
            Logits (num_segmentation_heads, x, y, z) with the padding removed.

        Raises:
            ValueError: if the mask's spatial shape differs from the input's. A misaligned
                mask would be sliced at the wrong positions, and perturbed patches could
                silently be served from the cache.
        """
        assert isinstance(input_image, torch.Tensor)
        self.network = self.network.to(self.device)
        self.network.eval()

        empty_cache(self.device)

        # Autocast can be annoying
        # If the device_type is 'cpu' then it's slow as heck on some CPUs (no auto bfloat16 support detection)
        # and needs to be disabled.
        # If the device_type is 'mps' then it will complain that mps is not implemented, even if enabled=False
        # is set. Whyyyyyyy. (this is why we don't make use of enabled=False)
        # So autocast will only be active if we have a cuda device.
        with torch.autocast(self.device.type, enabled=True) if self.device.type == 'cuda' else dummy_context():
            assert input_image.ndim == 4, 'input_image must be a 4D np.ndarray or torch.Tensor (c, x, y, z)'

            if self.verbose:
                print(f'Input shape: {input_image.shape}')
                print("step_size:", self.tile_step_size)
                print("mirror_axes:", self.allowed_mirroring_axes if self.use_mirroring else None)
                print(f'Perturbation mask shape: {perturbation_mask.shape}')

            if tuple(perturbation_mask.shape[1:]) != tuple(input_image.shape[1:]):
                raise ValueError(
                    f'perturbation_mask spatial shape {tuple(perturbation_mask.shape[1:])} does not match '
                    f'input_image spatial shape {tuple(input_image.shape[1:])}; the mask must be in the '
                    f'same (preprocessed) space as the network input')

            # if input_image is smaller than tile_size we need to pad it to tile_size.
            data, slicer_revert_padding = pad_nd_image(input_image, self.configuration_manager.patch_size,
                                                       'constant', {'value': 0}, True,
                                                       None)
            # Pad the mask the same way (padding is never perturbed), so one set of
            # slicers indexes both the data and the mask.
            perturbation_mask = self._pad_mask_like(perturbation_mask, data.shape[1:], slicer_revert_padding)

            slicers = self._internal_get_sliding_window_slicers(data.shape[1:])

            on_device = self.perform_everything_on_device and self.device != 'cpu'
            try:
                predicted_logits = self._internal_predict_sliding_window_return_logits(
                    data, slicers, on_device, perturbation_mask, baseline_prediction_dict, caching=True
                )
            except RuntimeError as e:
                if on_device and "CUDA out of memory" in str(e):
                    print("⚠️  CUDA OOM, cambiare batch size o patch size!")
                # Deliberately no CPU fallback: surface the real error.
                raise

            empty_cache(self.device)
            # revert padding
            predicted_logits = predicted_logits[(slice(None), *slicer_revert_padding[1:])]
        return predicted_logits

    @staticmethod
    def _pad_mask_like(mask: torch.Tensor, padded_spatial_shape, slicer_revert_padding) -> torch.Tensor:
        """Embed ``mask`` into an all-False tensor of the padded spatial shape (no-op if already that shape)."""
        if tuple(mask.shape[1:]) == tuple(padded_spatial_shape):
            return mask
        padded = torch.zeros((mask.shape[0], *padded_spatial_shape), dtype=torch.bool, device=mask.device)
        padded[(slice(None), *slicer_revert_padding[1:])] = mask
        return padded

    def _slice_key(self, slicer_tuple):
        """Hashable (start, stop, step) representation of a slicer tuple, used as cache key."""
        return tuple((s.start, s.stop, s.step) for s in slicer_tuple)

    @torch.inference_mode()
    def _internal_predict_sliding_window_return_logits(self,
                                                       data: torch.Tensor,
                                                       slicers,
                                                       do_on_device: bool = True,
                                                       perturbation_mask: torch.BoolTensor | None = None,
                                                       baseline_prediction_dict: dict | None = None,
                                                       caching: bool = False,
                                                       ):
        """
        Override of nnU-Net's sliding-window loop with optional patch caching.

        With ``caching=False``, every patch is predicted by the network, as in the upstream
        method, so the standard ``predict_sliding_window_return_logits`` keeps working.
        With ``caching=True``, a patch containing no True voxel of ``perturbation_mask``
        takes its output from ``baseline_prediction_dict`` instead.

        Returns:
            Gaussian-weighted average of the patch outputs,
            (num_segmentation_heads, *data.shape[1:]), float16, on ``self.device`` if
            ``do_on_device`` else on CPU.

        Raises:
            ValueError: if ``caching`` is True but the mask or the cache is missing.
            RuntimeError: if a patch prediction fails or the result contains inf.
        """
        if caching and (perturbation_mask is None or baseline_prediction_dict is None):
            raise ValueError('caching=True requires both perturbation_mask and baseline_prediction_dict')

        predicted_logits = n_predictions = prediction = gaussian = workon = None
        results_device = self.device if do_on_device else torch.device('cpu')
        if next(self.network.parameters()).device != results_device:
            self.network = self.network.to(results_device)

        def producer(d, slh, q):
            for s in slh:
                q.put((torch.clone(d[s][None], memory_format=torch.contiguous_format).to(results_device), s))
            q.put('end')

        try:
            empty_cache(self.device)

            # move data to device
            if self.verbose:
                print(f'move image to device {results_device}')
            data = data.to(results_device)
            queue = Queue(maxsize=2)
            # daemon: a producer blocked on a full queue after an error cannot keep the process alive
            t = Thread(target=producer, args=(data, slicers, queue), daemon=True)
            t.start()

            # preallocate arrays
            if self.verbose:
                print(f'preallocating results arrays on device {results_device}')
            predicted_logits = torch.zeros((self.label_manager.num_segmentation_heads, *data.shape[1:]),
                                           dtype=torch.half,
                                           device=results_device)
            n_predictions = torch.zeros(data.shape[1:], dtype=torch.half, device=results_device)

            if self.use_gaussian:
                gaussian = compute_gaussian(tuple(self.configuration_manager.patch_size), sigma_scale=1. / 8,
                                            value_scaling_factor=10,
                                            device=results_device)
            else:
                gaussian = 1

            if not self.allow_tqdm and self.verbose:
                print(f'running prediction: {len(slicers)} steps')

            with tqdm(desc=None, total=len(slicers), disable=not self.allow_tqdm) as pbar:
                cache_hits = 0
                while True:
                    item = queue.get()
                    if item == 'end':
                        queue.task_done()
                        break
                    workon, sl = item
                    try:
                        if caching and not self.check_overlapping(sl, perturbation_mask):
                            # copy=True: the Gaussian weighting below is in place and must
                            # never modify the cached baseline tensor itself.
                            prediction = baseline_prediction_dict[self._slice_key(sl)].to(results_device, copy=True)
                            cache_hits += 1
                        else:
                            prediction = self._internal_maybe_mirror_and_predict(workon)[0].to(results_device)
                    except Exception as e:
                        raise RuntimeError("Errore nella predizione del patch") from e

                    # sanity-check device
                    assert prediction.device == predicted_logits.device

                    if self.use_gaussian:
                        prediction *= gaussian
                    predicted_logits[sl] += prediction
                    n_predictions[sl[1:]] += gaussian

                    # Drop references so GPU memory can be reused. Rebinding (not `del`)
                    # keeps the names defined for the cleanup in the except clause.
                    prediction = workon = None

                    queue.task_done()
                    pbar.set_postfix(
                        cache=f"{cache_hits}",
                        mem=f"{torch.cuda.memory_allocated()/1e9:.2f} GB"
                    )
                    pbar.update(1)
            queue.join()
            if self.verbose and not self.allow_tqdm:
                print(f"Cache hits: {cache_hits}/{len(slicers)}")

            # predicted_logits /= n_predictions
            torch.div(predicted_logits, n_predictions, out=predicted_logits)
            # check for infs
            if torch.any(torch.isinf(predicted_logits)):
                raise RuntimeError('Encountered inf in predicted array. Aborting... If this problem persists, '
                                   'reduce value_scaling_factor in compute_gaussian or increase the dtype of '
                                   'predicted_logits to fp32')
        except Exception:
            del predicted_logits, n_predictions, prediction, gaussian, workon
            empty_cache(self.device)
            empty_cache(results_device)
            raise
        return predicted_logits

    @torch.inference_mode()
    def get_output_dictionary_sliding_window(self, data: torch.Tensor, slicers,
                                             do_on_device: bool = True,
                                             ) -> dict:
        """
        Predict every sliding-window patch of ``data`` and return the raw patch outputs.

        The result is the cache consumed by
        ``predict_sliding_window_return_logits_with_caching``. ``data`` and ``slicers``
        must match what that method sees for the perturbed volumes (same shape and
        padding), so the keys match.

        The network runs in eval mode and under the same autocast setting as the
        cached prediction path, so a cache hit returns what the network would have
        produced for that patch there.

        Returns:
            dict mapping ``self._slice_key(slicer)`` to the patch output (CPU tensor,
            no Gaussian weighting applied).
        """
        dictionary = dict()
        workon = None
        results_device = self.device if do_on_device else torch.device('cpu')
        if next(self.network.parameters()).device != results_device:
            self.network = self.network.to(results_device)
        self.network.eval()

        def producer(d, slh, q):
            for s in slh:
                # patches go to the same device as the network
                q.put((torch.clone(d[s][None], memory_format=torch.contiguous_format).to(results_device), s))
            q.put('end')

        try:
            empty_cache(self.device)

            if self.verbose:
                print(f'move image and model to device {results_device}')
            data = data.to(results_device)
            queue = Queue(maxsize=2)
            t = Thread(target=producer, args=(data, slicers, queue), daemon=True)
            t.start()

            if not self.allow_tqdm and self.verbose:
                print(f'running prediction: {len(slicers)} steps')

            with torch.autocast(self.device.type, enabled=True) if self.device.type == 'cuda' else dummy_context():
                with tqdm(desc=None, total=len(slicers), disable=not self.allow_tqdm) as pbar:
                    while True:
                        item = queue.get()
                        if item == 'end':
                            queue.task_done()
                            break
                        workon, sl = item
                        patch_out = self._internal_maybe_mirror_and_predict(workon)[0]
                        # keep the cache on CPU so it does not occupy GPU memory
                        dictionary[self._slice_key(sl)] = patch_out.cpu()
                        del patch_out
                        workon = None

                        queue.task_done()
                        pbar.update()
            queue.join()

        except Exception:
            workon = None
            empty_cache(self.device)
            empty_cache(results_device)
            raise
        return dictionary

    def check_overlapping(self, slicer, perturbation_mask: torch.BoolTensor) -> bool:
        """
        Return True if the patch selected by ``slicer`` contains at least one perturbed voxel.

        Parameters
        ----------
        slicer : tuple
            As produced by ``_internal_get_sliding_window_slicers``, i.e.
            (slice(None), slice(x0, x1), slice(y0, y1), slice(z0, z1)). The leading
            channel slice is kept as is, so every channel of the mask is checked.
        perturbation_mask : torch.BoolTensor
            Mask (C, X, Y, Z), True on the voxels to perturb (usually C == 1 or
            replicated over channels).

        Returns
        -------
        bool
            True if at least one voxel in the patch is True.
        """
        return bool(perturbation_mask[slicer].any().item())

In [None]:
# Create plan ------------------------------------------------------------
plan = OrganMaskPerturbationPlanner(
    volume_file=join(imgs_dir, 'AUTOMI_00039_0000.nii.gz'),
    organ_mask_file=join(
        nnUNet_raw,
        "total_segmentator_structures",
        "AUTOMI_00039_0000",
        "mask_mask_add_input_20_total_segmentator.nii",
    ),
    strategy="single-off+random",
    n_random=1,                 # extra random organ subsets on top of the single-organ-off ones
    perturbation_type="zero",
    out_json="SHAP-test/output_plan.json",
)

n_perturbations = len(plan)
print(f"Schedule contains {n_perturbations} perturbations "
      f"(baseline + {n_perturbations - 1} variants)")

# 1) Reload the plan from the JSON just written, as a later session would --
plan = OrganMaskPerturbationPlanner.load_json("SHAP-test/output_plan.json")

In [None]:

from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor

# Release the network of a predictor created in an earlier cell (if that cell was run),
# so its weights do not stay on the GPU next to the new model.
if "predictor" in globals() and hasattr(predictor, "network"):
    del predictor.network
    torch.cuda.empty_cache()

# 2) Initialise predictor ------------------------
# Use the device selected at the top of the notebook (GPU 0 if CUDA is available).
predictor_device = torch.device('cuda', 0) if device.type == 'cuda' else device
predictor = CustomNNUNetPredictor(
    tile_step_size=0.5,
    use_gaussian=True,
    use_mirroring=False,  # no mirroring test-time augmentation
    perform_everything_on_device=True,
    device=predictor_device,
    verbose=True,
    verbose_preprocessing=False,
    allow_tqdm=True
)
# build the network architecture and load the fold-0 checkpoint
predictor.initialize_from_trained_model_folder(
    model_dir_readonly,
    use_folds=(0,),
    checkpoint_name='checkpoint_final.pth',
)
"""
torch.cuda.empty_cache()

# Optional: pre-compute the baseline per-patch outputs once (nothing is skipped here)
iterator = SHAPPredictionIterator(plan, predictor,
                                  skip_completed_ids=set(),
                                  cache_sw_inference=True,
                                  verbose=True)
pre_cached = iterator.get_orig_image_output_dictionary()


# 3) Run iterator + accumulator -----------------------------------------
with SHAPAccumulator(plan, "SHAP-test/shap_run", mode="a") as acc:
    for desc, logits in SHAPPredictionIterator(plan, predictor,
                                               skip_completed_ids=acc.completed_ids,
                                               cache_sw_inference = True,
                                               #pre_cached_output = pre_cached,
                                               verbose=False):
        acc.update(desc, logits)

# 4) After all perturbations are done ------------------------------------
paths = acc.finalize()   # list of .npz files on disk, in schedule order
print("All predictions saved, ready for KernelSHAP fitting.")"""

In [None]:
"""export_logits_to_nifty_segmentation(
    predictor=predictor,
    volume_file=Path(plan.volume_file),
    model_dir_readonly=str(model_dir_readonly),
    logits="pred_00006",
    npz_dir="SHAP-test/shap_run",
    output_dir="SHAP-test/shap_run",
    fold=0,
    save_probs=False
)"""

## Check that the two methods produce the same results

In [None]:
# Disabled check (kept inside a string, so running this cell does nothing). When
# enabled, it loads the exported segmentations of perturbation 6 from both runs and
# tests them for voxel-wise equality. Both `pred_00006_seg.nii.gz` files must exist first.
"""# Load the segmentations
seg1 = nib.load("SHAP/shap_run/pred_00006_seg.nii.gz").get_fdata()
seg2 = nib.load("SHAP-test/shap_run/pred_00006_seg.nii.gz").get_fdata()

# Check for voxel-wise identity
if np.array_equal(seg1, seg2):
    print("✅ The segmentations are exactly the same.")
else:
    print("❌ Differences found between the segmentations.")"""

In [None]:
def dice_coefficient(seg1, seg2) -> float:
    """
    Binary Dice overlap between two segmentations.

    Any non-zero label counts as foreground, so for multi-label maps this measures
    foreground-vs-background agreement, not per-label agreement. Voxels are counted,
    not label values summed, so label ids do not inflate the denominator.

    Returns 1.0 when both segmentations are empty (identical, and avoids 0/0).
    """
    fg1 = np.asarray(seg1) != 0
    fg2 = np.asarray(seg2) != 0
    denominator = int(fg1.sum()) + int(fg2.sum())
    if denominator == 0:
        return 1.0
    return 2.0 * int(np.logical_and(fg1, fg2).sum()) / denominator


# Enable once `seg1` / `seg2` have been loaded by the comparison cell above:
# print("Dice:", dice_coefficient(seg1, seg2))

### A Dice score of nearly 1 means the only differences are most likely floating-point precision errors accumulating in the regions where patches overlap, so there is nothing to worry about.

## Trying Captum's KernelSHAP on the organ mask

### First, derive a customized class from the Captum library so that sliding-window caching can be used

## Customize KernelShap as a "sibling" class, i.e. inherit from its parent, LimeBase
This is needed because we must override `to_interp_rep_transform` / `from_interp_rep_transform`. These functions map the (1, M) binary mask vector both to the perturbed volume AND to the perturbation mask that caching needs.

In [None]:
#!/usr/bin/env python3

# pyre-strict
import inspect
import math
import typing
import warnings
from collections.abc import Iterator
from typing import Any, Callable, cast, List, Literal, Optional, Tuple, Union

import torch
from captum._utils.common import (
    _expand_additional_forward_args,
    _expand_target,
    _flatten_tensor_or_tuple,
    _format_output,
    _format_tensor_into_tuples,
    _get_max_feature_index,
    _is_tuple,
    _reduce_list,
    _run_forward,
)
from captum._utils.models.linear_model import SkLearnLasso
from captum._utils.models.model import Model
from captum._utils.progress import progress
from captum._utils.typing import BaselineType, TargetType, TensorOrTupleOfTensorsGeneric
from captum.attr._utils.attribution import PerturbationAttribution
from captum.attr._utils.batching import _batch_example_iterator
from captum.attr._utils.common import (
    _construct_default_feature_mask,
    _format_input_baseline,
)
from captum.log import log_usage
from torch import Tensor, BoolTensor
from torch.nn import CosineSimilarity
from torch.utils.data import DataLoader, TensorDataset


class LimeBaseWithCustomArgumentToForwardFunc(PerturbationAttribution):
    r"""
    Modified copy of Captum's ``LimeBase``
    (https://captum.ai/api/_modules/captum/attr/_core/lime.html).

    Difference from upstream: for every sampled perturbation, ``attribute``
    builds a Boolean *perturbation mask* from the interpretable sample and the
    ``feature_mask`` keyword argument. It passes that mask to ``forward_func``
    as the LAST additional forward argument. A custom forward function (e.g. a
    sliding-window predictor with caching) can use it to know which input
    elements were perturbed.

    NOTE: ``_evaluate_batch`` only forwards ``_reduce_list(inputs)[0]``. For a
    single-tensor input, that is the first sample without its batch dimension,
    so ``perturbations_per_eval`` should be kept at 1.
    """

    def __init__(
        self,
        forward_func: Callable[..., Tensor],
        interpretable_model: Model,
        similarity_func: Callable[
            ...,
            Union[float, Tensor],
        ],
        perturb_func: Callable[..., object],
        perturb_interpretable_space: bool,
        from_interp_rep_transform: Optional[
            Callable[..., Union[Tensor, Tuple[Tensor, ...]]]
        ],
        to_interp_rep_transform: Optional[Callable[..., Tensor]],
    ) -> None:
        r"""
        Args:

            forward_func (Callable): The forward function of the model or any
                    modification of it. After the inputs and any user-supplied
                    additional_forward_args, it receives the perturbation mask
                    (see ``_get_perturbation_mask``) as its last positional
                    argument. If a batch is provided as input for attribution,
                    forward_func is expected to return a scalar representing
                    the entire batch.
            interpretable_model (Model): Model object used to train the
                    interpretable surrogate. It must provide a ``fit`` method
                    taking a DataLoader whose batches contain three tensors:

                    - interpretable_inputs: Tensor
                      [2D num_samples x num_interp_features],
                    - expected_outputs: Tensor [1D num_samples],
                    - weights: Tensor [1D num_samples]

                    It must also provide a ``representation`` method returning
                    the coefficients/representation after fitting. Captum ships
                    wrappers in captum._utils.models.linear_model (SkLearn
                    linear models and SGD-based PyTorch linear models).
                    Calling fit multiple times should retrain the model; every
                    attribution call reuses the same model object.
            similarity_func (Callable): Returns the weight of a sample for
                    training the surrogate, generally based on similarity to
                    the original input (the "similarity kernel"). Signature:

                    >>> similarity_func(
                    >>>    original_input: Tensor or tuple[Tensor, ...],
                    >>>    perturbed_input: Tensor or tuple[Tensor, ...],
                    >>>    perturbed_interpretable_input:
                    >>>        Tensor [2D 1 x num_interp_features],
                    >>>    **kwargs: Any
                    >>> ) -> float or Tensor containing float scalar

                    perturbed_input and original_input have the same type and
                    tensor shapes. All kwargs passed to ``attribute`` are
                    forwarded to this callable.
            perturb_func (Callable): Returns one sampled input, either in the
                    original input space (matching the type and shapes of the
                    original input) or in the interpretable space (a tensor of
                    shape 1 x num_interp_features). It may instead be a
                    generator function, in which case n_samples perturbations
                    are drawn from the generator. Signature:

                    >>> perturb_func(
                    >>>    original_input: Tensor or tuple[Tensor, ...],
                    >>>    **kwargs: Any
                    >>> ) -> Tensor, tuple[Tensor, ...], or
                    >>>    generator yielding tensor or tuple[Tensor, ...]

                    All kwargs passed to ``attribute`` are forwarded to this
                    callable.
            perturb_interpretable_space (bool): True if perturb_func returns
                    samples in the interpretable space, False if it returns
                    samples in the original input space.
            from_interp_rep_transform (Callable): Maps an interpretable sample
                    (1 x num_interp_features) to the original input space.
                    Required when perturb_interpretable_space is True.
                    Signature:

                    >>> from_interp_rep_transform(
                    >>>    curr_sample: Tensor [2D 1 x num_interp_features]
                    >>>    original_input: Tensor or Tuple of Tensors,
                    >>>    **kwargs: Any
                    >>> ) -> Tensor or tuple[Tensor, ...]
            to_interp_rep_transform (Callable): Maps a sample in the original
                    input space to its interpretable representation
                    (1 x num_interp_features). Required when
                    perturb_interpretable_space is False. Signature:

                    >>> to_interp_rep_transform(
                    >>>    curr_sample: Tensor or Tuple of Tensors,
                    >>>    original_input: Tensor or Tuple of Tensors,
                    >>>    **kwargs: Any
                    >>> ) -> Tensor [2D 1 x num_interp_features]
        """
        PerturbationAttribution.__init__(self, forward_func)
        self.interpretable_model = interpretable_model
        self.similarity_func = similarity_func
        self.perturb_func = perturb_func
        self.perturb_interpretable_space = perturb_interpretable_space
        self.from_interp_rep_transform = from_interp_rep_transform
        self.to_interp_rep_transform = to_interp_rep_transform

        if self.perturb_interpretable_space:
            # The message is parenthesised so both halves form ONE string.
            # Upstream, the second half was a separate, dead expression.
            assert self.from_interp_rep_transform is not None, (
                "Must provide transform from interpretable space to original input space"
                " when sampling from interpretable space."
            )
        else:
            assert (
                self.to_interp_rep_transform is not None
            ), "Must provide transform from original input space to interpretable space"

    @log_usage(part_of_slo=True)
    @torch.no_grad()
    def attribute(
        self,
        inputs: TensorOrTupleOfTensorsGeneric,
        target: TargetType = None,
        additional_forward_args: Optional[Tuple[object, ...]] = None,
        n_samples: int = 50,
        perturbations_per_eval: int = 1,
        show_progress: bool = False,
        **kwargs: object,
    ) -> Tensor:
        r"""
        Train the interpretable surrogate on perturbed samples and return its
        representation (e.g. linear coefficients).

        For each sample, a perturbation mask is built from the interpretable
        sample and ``kwargs["feature_mask"]``. It is passed to forward_func as
        the last additional argument, paired with that sample.

        Args:

            inputs (Tensor or tuple[Tensor, ...]): Input to explain. Dimension 0
                        is the example dimension; a single example is
                        recommended.
            target (int, tuple, Tensor, or list, optional): Output indices for
                        which the surrogate model is trained.
                        For 2D outputs: a single integer (applied to all
                        examples) or a list/1D tensor with one integer per
                        example.
                        For outputs with > 2 dimensions: a single tuple of
                        #output_dims - 1 elements, or a list of such tuples
                        with one per example.
                        Default: None
            additional_forward_args (Any, optional): Extra arguments passed to
                        forward_func after the inputs (a single value or a
                        tuple). Tensors are expanded along dim 0 to the number
                        of samples per evaluation. The perturbation mask is
                        appended after them.
                        Default: None
            n_samples (int, optional): Number of perturbed samples used to
                        train the surrogate. Default: 50
            perturbations_per_eval (int, optional): Samples grouped per call to
                        forward_func. Keep at 1: ``_evaluate_batch`` only
                        forwards the first sample of a group. Default: 1
            show_progress (bool, optional): Display computation progress.
                        Default: False
            **kwargs (Any, optional): Forwarded to perturb_func,
                        similarity_func and the transforms. MUST contain
                        ``feature_mask`` (Tensor or tuple of Tensors mapping
                        input elements to interpretable feature indices).
                        The Lime subclass passes it via ``super().attribute``.

        Returns:
            Whatever ``interpretable_model.representation()`` returns, e.g. the
            coefficients of a linear surrogate model.

        Raises:
            KeyError: if ``feature_mask`` is not provided in kwargs.
            RuntimeError: if no perturbed sample could be generated.
        """
        # feature_mask is not in this method's signature; it arrives via
        # kwargs (Lime passes it through super().attribute()).
        if "feature_mask" not in kwargs:
            raise KeyError(
                "feature_mask must be passed as a keyword argument; it is needed "
                "to build the perturbation mask handed to forward_func."
            )
        feature_mask = kwargs["feature_mask"]

        inp_tensor = cast(Tensor, inputs) if isinstance(inputs, Tensor) else inputs[0]
        device = inp_tensor.device

        interpretable_inps = []
        similarities = []
        outputs = []

        # Model inputs and their perturbation masks for the group being filled.
        # They are kept in lock-step so each sample is forwarded with its own mask.
        curr_model_inputs = []
        curr_masks = []
        expanded_target = None
        gen_perturb_func = self._get_perturb_generator_func(inputs, **kwargs)

        if show_progress:
            attr_progress = progress(
                total=math.ceil(n_samples / perturbations_per_eval),
                desc=f"{self.get_name()} attribution",
            )
            attr_progress.update(0)

        batch_count = 0
        for _ in range(n_samples):
            try:
                interpretable_inp, curr_model_input = gen_perturb_func()
            except StopIteration:
                warnings.warn(
                    "Generator completed prior to given n_samples iterations!",
                    stacklevel=1,
                )
                break
            # Errors here propagate with their own traceback (no bare except).
            perturbation_mask = self._get_perturbation_mask(
                interpretable_inp, curr_model_input, feature_mask
            )

            batch_count += 1
            interpretable_inps.append(interpretable_inp)
            curr_model_inputs.append(curr_model_input)
            curr_masks.append(perturbation_mask)

            curr_sim = self.similarity_func(
                inputs, curr_model_input, interpretable_inp, **kwargs
            )
            similarities.append(
                curr_sim.flatten()
                if isinstance(curr_sim, Tensor)
                else torch.tensor([curr_sim], device=device)
            )

            if len(curr_model_inputs) == perturbations_per_eval:
                if expanded_target is None:
                    expanded_target = _expand_target(target, len(curr_model_inputs))
                # Rebuilt every group: the masks differ per sample.
                model_out = self._evaluate_batch(
                    curr_model_inputs,
                    expanded_target,
                    self._forward_args_with_masks(additional_forward_args, curr_masks),
                    device,
                )
                if show_progress:
                    attr_progress.update()
                outputs.append(model_out)
                curr_model_inputs = []
                curr_masks = []

        # Flush a trailing partial group.
        if len(curr_model_inputs) > 0:
            expanded_target = _expand_target(target, len(curr_model_inputs))
            model_out = self._evaluate_batch(
                curr_model_inputs,
                expanded_target,
                self._forward_args_with_masks(additional_forward_args, curr_masks),
                device,
            )
            if show_progress:
                attr_progress.update()
            outputs.append(model_out)

        if show_progress:
            attr_progress.close()

        if not interpretable_inps:
            raise RuntimeError(
                "No perturbed samples were generated: n_samples must be > 0 and "
                "the perturbation generator must yield at least one sample."
            )

        # Argument 1 to "cat" has incompatible type
        # "list[Tensor | tuple[Tensor, ...]]";
        # expected "tuple[Tensor, ...] | list[Tensor]"  [arg-type]
        combined_interp_inps = torch.cat(interpretable_inps).float()  # type: ignore
        combined_outputs = (
            torch.cat(outputs) if len(outputs[0].shape) > 0 else torch.stack(outputs)
        ).float()
        combined_sim = (
            torch.cat(similarities)
            if len(similarities[0].shape) > 0
            else torch.stack(similarities)
        ).float()
        self.dataset = TensorDataset(combined_interp_inps, combined_outputs, combined_sim)
        self.interpretable_model.fit(DataLoader(self.dataset, batch_size=batch_count))
        return self.interpretable_model.representation()

    @staticmethod
    def _forward_args_with_masks(
        additional_forward_args: object,
        masks: List[Union[Tensor, Tuple[Tensor, ...]]],
    ) -> Tuple[object, ...]:
        """Build forward_func's additional args for a group of samples.

        The user's args are normalised to a tuple and expanded to ``len(masks)``
        samples. The per-sample masks, concatenated along dim 0 in sample order,
        are appended as the last argument. Each sample therefore keeps its own
        mask; before this helper, the last mask was repeated for the whole group.
        """
        n_group = len(masks)
        if additional_forward_args is None:
            base_args: Tuple[object, ...] = ()
        elif isinstance(additional_forward_args, tuple):
            base_args = additional_forward_args
        else:
            base_args = (additional_forward_args,)
        expanded = (
            tuple(_expand_additional_forward_args(base_args, n_group))
            if base_args
            else ()
        )
        if isinstance(masks[0], Tensor):
            batched_mask: object = torch.cat(masks, dim=0)
        else:
            # Multi-input: concatenate each input's masks separately.
            batched_mask = tuple(torch.cat(parts, dim=0) for parts in zip(*masks))
        return expanded + (batched_mask,)

    def _get_perturbation_mask(
        self,
        interpretable_input: Tensor,
        original_inputs: TensorOrTupleOfTensorsGeneric,
        feature_mask: Union[Tensor, Tuple[Tensor, ...]],
    ) -> Union[Tensor, Tuple[Tensor, ...]]:
        """Build the Boolean perturbation mask for one interpretable sample.

        ``interpretable_input[:, feature_mask]`` gathers each element's
        interpretable feature value, so the result has shape
        ``(B, *feature_mask.shape)``, where B is interpretable_input's first
        dimension. The mask is True where that feature is 0. With
        ``default_from_interp_rep_transform``, those are the elements replaced
        by the baseline.

        ``original_inputs`` is currently unused (kept for interface symmetry).
        A tuple ``feature_mask`` yields a tuple of masks, one per input.
        """
        if isinstance(feature_mask, torch.Tensor):
            return ~interpretable_input[:, feature_mask].bool()
        return tuple(~interpretable_input[:, fm_i].bool() for fm_i in feature_mask)

    def _get_perturb_generator_func(
        self, inputs: TensorOrTupleOfTensorsGeneric, **kwargs: Any
    ) -> Callable[
        [], Tuple[TensorOrTupleOfTensorsGeneric, TensorOrTupleOfTensorsGeneric]
    ]:
        """Return a zero-arg callable yielding (interpretable_inp, model_input).

        If perturb_func is a generator function, it is instantiated once, and
        exhausting it raises StopIteration from the returned callable.
        """
        perturb_generator: Optional[Iterator[TensorOrTupleOfTensorsGeneric]]
        perturb_generator = None
        if inspect.isgeneratorfunction(self.perturb_func):
            perturb_generator = self.perturb_func(inputs, **kwargs)

        def generate_perturbation() -> (
            Tuple[TensorOrTupleOfTensorsGeneric, TensorOrTupleOfTensorsGeneric]
        ):
            if perturb_generator:
                curr_sample = next(perturb_generator)
            else:
                curr_sample = self.perturb_func(inputs, **kwargs)

            if self.perturb_interpretable_space:
                interpretable_inp = curr_sample
                curr_model_input = self.from_interp_rep_transform(  # type: ignore
                    curr_sample, inputs, **kwargs
                )
            else:
                curr_model_input = curr_sample
                interpretable_inp = self.to_interp_rep_transform(  # type: ignore
                    curr_sample, inputs, **kwargs
                )

            return interpretable_inp, curr_model_input  # type: ignore

        return generate_perturbation

    # pyre-fixme[24] Generic type `Callable` expects 2 type parameters.)
    def attribute_future(self) -> Callable:
        r"""
        This method is not implemented for LimeBase.
        """
        raise NotImplementedError(
            "LimeBase does not support attribution of future samples."
        )

    def _evaluate_batch(
        self,
        curr_model_inputs: List[TensorOrTupleOfTensorsGeneric],
        expanded_target: TargetType,
        expanded_additional_args: object,
        device: torch.device,
    ) -> Tensor:
        """Run forward_func on a group of perturbed inputs; return a 1D tensor.

        TEMPORARY: the sliding-window forward function only handles a single
        item (no batch), so only ``_reduce_list(curr_model_inputs)[0]`` is
        forwarded. Upstream Captum passes the whole reduced list.
        """
        model_out = _run_forward(
            self.forward_func,
            _reduce_list(curr_model_inputs)[0],
            expanded_target,
            expanded_additional_args,
        )
        if isinstance(model_out, Tensor):
            assert model_out.numel() == len(curr_model_inputs), (
                "Number of outputs is not appropriate, must return "
                "one output per perturbed input"
            )
            return model_out.flatten()
        return torch.tensor([model_out], device=device)

    def has_convergence_delta(self) -> bool:
        return False

    @property
    def multiplies_by_inputs(self) -> bool:
        return False


# Default transformations and methods
# for Lime child implementation.


# pyre-fixme[3]: Return type must be annotated.
# pyre-fixme[2]: Parameter must be annotated.
def default_from_interp_rep_transform(curr_sample, original_inputs, **kwargs):
    """Map a (1, num_interp_features) interpretable sample back to input space.

    Each input element looks up its interpretable feature through
    ``feature_mask``. If that feature is 1 in ``curr_sample`` the original value
    is kept, otherwise the baseline value is used. A tuple of inputs needs
    matching tuples of feature masks and baselines.

    Required kwargs: ``feature_mask`` and ``baselines``.
    """
    assert (
        "feature_mask" in kwargs
    ), "Must provide feature_mask to use default interpretable representation transform"
    assert (
        "baselines" in kwargs
    ), "Must provide baselines to use default interpretable representation transform"
    masks = kwargs["feature_mask"]
    baselines = kwargs["baselines"]
    sample_row = curr_sample[0]

    def _blend(inp, feat_mask, baseline):
        # keep == True -> original value, False -> baseline value
        keep = sample_row[feat_mask].bool()
        dtype = inp.dtype
        return keep.to(dtype) * inp + (~keep).to(dtype) * baseline

    if not isinstance(masks, Tensor):
        return tuple(
            _blend(original_inputs[idx], masks[idx], baselines[idx])
            for idx in range(len(masks))
        )
    return _blend(original_inputs, masks, baselines)


def get_exp_kernel_similarity_function(
    distance_mode: str = "cosine",
    kernel_width: float = 1.0,
) -> Callable[..., float]:
    r"""
    Build a LIME similarity function based on an exponential kernel.

    The returned callable flattens the original and perturbed inputs
    (concatenating tuples of inputs if necessary) and measures their distance.
    It then returns ``exp(-d**2 / (2 * kernel_width**2))``, a weight between
    0 and 1. It can be used as the similarity_fn of Lime or LimeBase.

    Args:

        distance_mode (str, optional): "cosine" (1 - cosine similarity) or
                    "euclidean" (L2 norm of the difference). Any other value
                    raises ValueError when the returned function is called.
                    Default: "cosine"
        kernel_width (float, optional): Width of the exponential kernel.
                    Default: 1.0

    Returns:

        *Callable*:
        - **similarity_fn** (*Callable*): called as
          ``similarity_fn(original_inp, perturbed_inp, interp_sample, **kwargs)``
          and returning a Python float.
    """

    def _distance(flat_a, flat_b):
        if distance_mode == "cosine":
            return 1 - CosineSimilarity(dim=0)(flat_a, flat_b)
        if distance_mode == "euclidean":
            return torch.norm(flat_a - flat_b)
        raise ValueError("distance_mode must be either cosine or euclidean.")

    # pyre-fixme[3]: Return type must be annotated.
    # pyre-fixme[2]: Parameter must be annotated.
    def default_exp_kernel(original_inp, perturbed_inp, __, **kwargs):
        dist = _distance(
            _flatten_tensor_or_tuple(original_inp).float(),
            _flatten_tensor_or_tuple(perturbed_inp).float(),
        )
        denominator = 2 * (kernel_width**2)
        return math.exp(-1 * (dist**2) / denominator)

    return default_exp_kernel


def default_perturb_func(
    original_inp: TensorOrTupleOfTensorsGeneric, **kwargs: object
) -> Tensor:
    """Draw one random binary interpretable sample for LIME.

    Each of the ``num_interp_features`` entries is independently set to 1
    with probability 0.5 (and to 0 otherwise).

    Args:
        original_inp: The original input (a tensor or a tuple of tensors).
            Only its device is used, so the sample lands on the same device.
        **kwargs: Must contain ``num_interp_features`` (int).

    Returns:
        A ``long`` tensor of shape ``(1, num_interp_features)`` with 0/1 values.

    Raises:
        AssertionError: if ``num_interp_features`` is missing from kwargs.
    """
    assert (
        "num_interp_features" in kwargs
    ), "Must provide num_interp_features to use default interpretable sampling function"
    # For tuple inputs the first tensor decides the target device.
    reference_tensor = (
        original_inp if isinstance(original_inp, Tensor) else original_inp[0]
    )
    feature_count = cast(int, kwargs["num_interp_features"])

    # Sampling happens on CPU and the result is moved afterwards.
    keep_probabilities = torch.full((1, feature_count), 0.5)
    binary_sample = torch.bernoulli(keep_probabilities)
    return binary_sample.to(device=reference_tensor.device).long()


def construct_feature_mask(
    feature_mask: Union[None, Tensor, Tuple[Tensor, ...]],
    formatted_inputs: Tuple[Tensor, ...],
) -> Tuple[Tuple[Tensor, ...], int]:
    """Normalise a feature mask and count the interpretable features.

    Args:
        feature_mask: ``None`` (a default mask is built from the inputs), a
            single tensor, or a tuple of tensors grouping input elements into
            interpretable features.
        formatted_inputs: The inputs, already formatted as a tuple of tensors.

    Returns:
        A pair ``(mask_tuple, num_interp_features)``. When a mask is given and
        its smallest value (over all non-empty tensors) is not 0, every tensor
        is shifted so that indexing starts at 0 and a warning is emitted.
        ``num_interp_features`` is the largest feature index plus one.
    """
    if feature_mask is None:
        default_masks, default_count = _construct_default_feature_mask(
            formatted_inputs
        )
        return default_masks, default_count

    mask_tuple: Tuple[Tensor, ...] = _format_tensor_into_tuples(feature_mask)

    # Empty tensors carry no indices and are ignored when finding the minimum.
    per_mask_minima = [
        torch.min(single_mask).item()
        for single_mask in mask_tuple
        if single_mask.numel()
    ]
    lowest_index = int(min(per_mask_minima))

    if lowest_index != 0:
        warnings.warn(
            "Minimum element in feature mask is not 0, shifting indices to"
            " start at 0.",
            stacklevel=2,
        )
        mask_tuple = tuple(single_mask - lowest_index for single_mask in mask_tuple)

    return mask_tuple, _get_max_feature_index(mask_tuple) + 1




In [None]:
class LimeWithCustomArgumentToForwardFunc(LimeBaseWithCustomArgumentToForwardFunc):
    r"""
    Lime attribution built on top of ``LimeBaseWithCustomArgumentToForwardFunc``.

    This is a modification of the Lime class from the Captum library
    (https://captum.ai/api/_modules/captum/attr/_core/lime.html). The
    interpretable model, similarity kernel and perturbation function default
    to the same choices as Captum's ``Lime``. The actual sampling / surrogate
    training is delegated to the modified LimeBase class this inherits from
    (via ``super().attribute``).
    """

    def __init__(
        self,
        forward_func: Callable[..., Tensor],
        interpretable_model: Optional[Model] = None,
        # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
        similarity_func: Optional[Callable] = None,
        # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
        perturb_func: Optional[Callable] = None,
    ) -> None:
        r"""

        Args:


            forward_func (Callable): The forward function of the model or any
                    modification of it
            interpretable_model (Model, optional): Model object to train
                    interpretable model.

                    This argument is optional and defaults to SkLearnLasso(alpha=0.01),
                    which is a wrapper around the Lasso linear model in SkLearn.
                    This requires having sklearn version >= 0.23 available.

                    Other predefined interpretable linear models are provided in
                    captum._utils.models.linear_model.

                    Alternatively, a custom model object must provide a `fit` method to
                    train the model, given a dataloader, with batches containing
                    three tensors:

                    - interpretable_inputs: Tensor
                      [2D num_samples x num_interp_features],
                    - expected_outputs: Tensor [1D num_samples],
                    - weights: Tensor [1D num_samples]

                    The model object must also provide a `representation` method to
                    access the appropriate coefficients or representation of the
                    interpretable model after fitting.

                    Note that calling fit multiple times should retrain the
                    interpretable model, each attribution call reuses
                    the same given interpretable model object.
            similarity_func (Callable, optional): Function which takes a single sample
                    along with its corresponding interpretable representation
                    and returns the weight of the interpretable sample for
                    training the interpretable model.
                    This is often referred to as a similarity kernel.

                    This argument is optional and defaults to a function which
                    applies an exponential kernel to the cosine distance between
                    the original input and perturbed input, with a kernel width
                    of 1.0.

                    A similarity function applying an exponential
                    kernel to cosine / euclidean distances can be constructed
                    using the provided get_exp_kernel_similarity_function in
                    captum.attr._core.lime.

                    Alternately, a custom callable can also be provided.
                    The expected signature of this callable is:

                    >>> def similarity_func(
                    >>>    original_input: Tensor or tuple[Tensor, ...],
                    >>>    perturbed_input: Tensor or tuple[Tensor, ...],
                    >>>    perturbed_interpretable_input:
                    >>>        Tensor [2D 1 x num_interp_features],
                    >>>    **kwargs: Any
                    >>> ) -> float or Tensor containing float scalar

                    perturbed_input and original_input will be the same type and
                    contain tensors of the same shape, with original_input
                    being the same as the input provided when calling attribute.

                    kwargs includes baselines, feature_mask, num_interp_features
                    (integer, determined from feature mask).
            perturb_func (Callable, optional): Function which returns a single
                    sampled input, which is a binary vector of length
                    num_interp_features, or a generator of such tensors.

                    This function is optional, the default function returns
                    a binary vector where each element is selected
                    independently and uniformly at random. Custom
                    logic for selecting sampled binary vectors can
                    be implemented by providing a function with the
                    following expected signature:

                    >>> perturb_func(
                    >>>    original_input: Tensor or tuple[Tensor, ...],
                    >>>    **kwargs: Any
                    >>> ) -> Tensor [Binary 2D Tensor 1 x num_interp_features]
                    >>>  or generator yielding such tensors

                    kwargs includes baselines, feature_mask, num_interp_features
                    (integer, determined from feature mask).

        """
        if interpretable_model is None:
            interpretable_model = SkLearnLasso(alpha=0.01)

        if similarity_func is None:
            similarity_func = get_exp_kernel_similarity_function()

        if perturb_func is None:
            perturb_func = default_perturb_func

        # Initialise through the MRO so that the custom LimeBase class this
        # class inherits from performs the instance setup. Calling captum's
        # ``LimeBase.__init__`` directly would bypass it.
        super().__init__(
            forward_func,
            interpretable_model,
            similarity_func,
            perturb_func,
            True,  # perturb_interpretable_space: perturb_func samples binary vectors
            default_from_interp_rep_transform,  # from_interp_rep_transform
            None,  # to_interp_rep_transform
        )

    @log_usage(part_of_slo=True)
    def attribute(  # type: ignore
        self,
        inputs: TensorOrTupleOfTensorsGeneric,
        baselines: BaselineType = None,
        target: TargetType = None,
        additional_forward_args: Optional[object] = None,
        feature_mask: Union[None, Tensor, Tuple[Tensor, ...]] = None,
        n_samples: int = 25,
        perturbations_per_eval: int = 1,
        return_input_shape: bool = True,
        show_progress: bool = False,
    ) -> TensorOrTupleOfTensorsGeneric:
        r"""
        This method attributes the output of the model with given target index
        (in case it is provided, otherwise it assumes that output is a
        scalar) to the inputs of the model using the approach described above,
        training an interpretable model and returning a representation of the
        interpretable model.

        It is recommended to only provide a single example as input (tensors
        with first dimension or batch size = 1). This is because LIME is generally
        used for sample-based interpretability, training a separate interpretable
        model to explain a model's prediction on each individual example.

        A batch of inputs can also be provided as inputs, similar to
        other perturbation-based attribution methods. In this case, if forward_fn
        returns a scalar per example, attributions will be computed for each
        example independently, with a separate interpretable model trained for each
        example. Note that provided similarity and perturbation functions will be
        provided each example separately (first dimension = 1) in this case.
        If forward_fn returns a scalar per batch (e.g. loss), attributions will
        still be computed using a single interpretable model for the full batch.
        In this case, similarity and perturbation functions will be provided the
        same original input containing the full batch.

        The number of interpretable features is determined from the provided
        feature mask, or if none is provided, from the default feature mask,
        which considers each scalar input as a separate feature. It is
        generally recommended to provide a feature mask which groups features
        into a small number of interpretable features / components (e.g.
        superpixels in images).

        Args:

            inputs (Tensor or tuple[Tensor, ...]): Input for which LIME
                        is computed. If forward_func takes a single
                        tensor as input, a single input tensor should be provided.
                        If forward_func takes multiple tensors as input, a tuple
                        of the input tensors should be provided. It is assumed
                        that for all given input tensors, dimension 0 corresponds
                        to the number of examples, and if multiple input tensors
                        are provided, the examples must be aligned appropriately.
            baselines (scalar, Tensor, tuple of scalar, or Tensor, optional):
                        Baselines define reference value which replaces each
                        feature when the corresponding interpretable feature
                        is set to 0.
                        Baselines can be provided as:

                        - a single tensor, if inputs is a single tensor, with
                          exactly the same dimensions as inputs or the first
                          dimension is one and the remaining dimensions match
                          with inputs.

                        - a single scalar, if inputs is a single tensor, which will
                          be broadcasted for each input value in input tensor.

                        - a tuple of tensors or scalars, the baseline corresponding
                          to each tensor in the inputs' tuple can be:

                          - either a tensor with matching dimensions to
                            corresponding tensor in the inputs' tuple
                            or the first dimension is one and the remaining
                            dimensions match with the corresponding
                            input tensor.

                          - or a scalar, corresponding to a tensor in the
                            inputs' tuple. This scalar value is broadcasted
                            for corresponding input tensor.

                        In the cases when `baselines` is not provided, we internally
                        use zero scalar corresponding to each input tensor.
                        Default: None
            target (int, tuple, Tensor, or list, optional): Output indices for
                        which surrogate model is trained
                        (for classification cases,
                        this is usually the target class).
                        If the network returns a scalar value per example,
                        no target index is necessary.
                        For general 2D outputs, targets can be either:

                        - a single integer or a tensor containing a single
                          integer, which is applied to all input examples

                        - a list of integers or a 1D tensor, with length matching
                          the number of examples in inputs (dim 0). Each integer
                          is applied as the target for the corresponding example.

                        For outputs with > 2 dimensions, targets can be either:

                        - A single tuple, which contains #output_dims - 1
                          elements. This target index is applied to all examples.

                        - A list of tuples with length equal to the number of
                          examples in inputs (dim 0), and each tuple containing
                          #output_dims - 1 elements. Each tuple is applied as the
                          target for the corresponding example.

                        Default: None
            additional_forward_args (Any, optional): If the forward function
                        requires additional arguments other than the inputs for
                        which attributions should not be computed, this argument
                        can be provided. It must be either a single additional
                        argument of a Tensor or arbitrary (non-tuple) type or a
                        tuple containing multiple additional arguments including
                        tensors or any arbitrary python types. These arguments
                        are provided to forward_func in order following the
                        arguments in inputs.
                        For a tensor, the first dimension of the tensor must
                        correspond to the number of examples. It may be
                        repeated to match the number of perturbed samples
                        evaluated in each forward pass. For all other types,
                        the given argument is used for all forward evaluations.
                        Note that attributions are not computed with respect
                        to these arguments.
                        Default: None
            feature_mask (Tensor or tuple[Tensor, ...], optional):
                        feature_mask defines a mask for the input, grouping
                        features which correspond to the same
                        interpretable feature. feature_mask
                        should contain the same number of tensors as inputs.
                        Each tensor should
                        be the same size as the corresponding input or
                        broadcastable to match the input tensor. Values across
                        all tensors should be integers in the range 0 to
                        num_interp_features - 1, and indices corresponding to the
                        same feature should have the same value.
                        Note that features are grouped across tensors
                        (unlike feature ablation and occlusion), so
                        if the same index is used in different tensors, those
                        features are still grouped and added simultaneously.
                        If None, then a feature mask is constructed which assigns
                        each scalar within a tensor as a separate feature.
                        Default: None
            n_samples (int, optional): The number of samples of the original
                        model used to train the surrogate interpretable model.
                        Default: `25` if `n_samples` is not provided.
            perturbations_per_eval (int, optional): Allows multiple samples
                        to be processed simultaneously in one call to forward_fn.
                        Each forward pass will contain a maximum of
                        perturbations_per_eval * #examples samples.
                        For DataParallel models, each batch is split among the
                        available devices, so evaluations on each available
                        device contain at most
                        (perturbations_per_eval * #examples) / num_devices
                        samples.
                        If the forward function returns a single scalar per batch,
                        perturbations_per_eval must be set to 1.
                        Default: 1
            return_input_shape (bool, optional): Determines whether the returned
                        tensor(s) only contain the coefficients for each
                        interpretable feature from the trained surrogate model, or
                        whether the returned attributions match the input shape.
                        When return_input_shape is True, the return type of attribute
                        matches the input shape, with each element containing the
                        coefficient of the corresponding interpretable feature.
                        All elements with the same value in the feature mask
                        will contain the same coefficient in the returned
                        attributions.
                        If forward_func returns a single element per batch, then the
                        first dimension of each tensor will be 1, and the remaining
                        dimensions will have the same shape as the original input
                        tensor.
                        If return_input_shape is False, a 1D
                        tensor is returned, containing only the coefficients
                        of the trained interpretable models, with length
                        num_interp_features.
            show_progress (bool, optional): Displays the progress of computation.
                        It will try to use tqdm if available for advanced features
                        (e.g. time estimation). Otherwise, it will fallback to
                        a simple output of progress.
                        Default: False

        Returns:
            *Tensor* or *tuple[Tensor, ...]* of **attributions**:
            - **attributions** (*Tensor* or *tuple[Tensor, ...]*):
                        The attributions with respect to each input feature.
                        If return_input_shape = True, attributions will be
                        the same size as the provided inputs, with each value
                        providing the coefficient of the corresponding
                        interpretable feature.
                        If return_input_shape is False, a 1D
                        tensor is returned, containing only the coefficients
                        of the trained interpretable models, with length
                        num_interp_features.
        Examples::

            >>> # SimpleClassifier takes a single input tensor of size Nx4x4,
            >>> # and returns an Nx3 tensor of class probabilities.
            >>> net = SimpleClassifier()

            >>> # Generating random input with size 1 x 4 x 4
            >>> input = torch.randn(1, 4, 4)

            >>> # Defining Lime interpreter
            >>> lime = LimeWithCustomArgumentToForwardFunc(net)
            >>> # Computes attribution, with each of the 4 x 4 = 16
            >>> # features as a separate interpretable feature
            >>> attr = lime.attribute(input, target=1, n_samples=200)

            >>> # Alternatively, we can group each 2x2 square of the inputs
            >>> # as one 'interpretable' feature and perturb them together.
            >>> # This can be done by creating a feature mask as follows, which
            >>> # defines the feature groups, e.g.:
            >>> # +---+---+---+---+
            >>> # | 0 | 0 | 1 | 1 |
            >>> # +---+---+---+---+
            >>> # | 0 | 0 | 1 | 1 |
            >>> # +---+---+---+---+
            >>> # | 2 | 2 | 3 | 3 |
            >>> # +---+---+---+---+
            >>> # | 2 | 2 | 3 | 3 |
            >>> # +---+---+---+---+
            >>> # With this mask, all inputs with the same value are set to their
            >>> # baseline value, when the corresponding binary interpretable
            >>> # feature is set to 0.
            >>> # The attributions can be calculated as follows:
            >>> # feature mask has dimensions 1 x 4 x 4
            >>> feature_mask = torch.tensor([[[0,0,1,1],[0,0,1,1],
            >>>                             [2,2,3,3],[2,2,3,3]]])

            >>> # Computes interpretable model and returning attributions
            >>> # matching input shape.
            >>> attr = lime.attribute(input, target=1, feature_mask=feature_mask)
        """
        return self._attribute_kwargs(
            inputs=inputs,
            baselines=baselines,
            target=target,
            additional_forward_args=additional_forward_args,
            feature_mask=feature_mask,
            n_samples=n_samples,
            perturbations_per_eval=perturbations_per_eval,
            return_input_shape=return_input_shape,
            show_progress=show_progress,
        )

    # pyre-fixme[24] Generic type `Callable` expects 2 type parameters.
    def attribute_future(self) -> Callable:
        """Delegate to the parent class's ``attribute_future``."""
        return super().attribute_future()

    def _attribute_kwargs(  # type: ignore
        self,
        inputs: TensorOrTupleOfTensorsGeneric,
        baselines: BaselineType = None,
        target: TargetType = None,
        additional_forward_args: Optional[object] = None,
        feature_mask: Union[None, Tensor, Tuple[Tensor, ...]] = None,
        n_samples: int = 25,
        perturbations_per_eval: int = 1,
        return_input_shape: bool = True,
        show_progress: bool = False,
        **kwargs: object,
    ) -> TensorOrTupleOfTensorsGeneric:
        """Shared implementation behind ``attribute``.

        Builds the feature mask and decides between two paths:

        * If the batch has more than one example and forward_func returns one
          value per example, a separate surrogate is fitted per example. The
          results are stacked with ``_reduce_list``.
        * Otherwise, a single surrogate is fitted for the whole input.

        Extra ``kwargs`` are forwarded to the base class ``attribute``.

        Raises:
            AssertionError: if forward_func returns neither one value per
                example nor one value per batch, or if
                ``perturbations_per_eval != 1`` with a per-batch output.
        """
        is_inputs_tuple = _is_tuple(inputs)
        formatted_inputs, baselines = _format_input_baseline(inputs, baselines)
        bsz = formatted_inputs[0].shape[0]

        feature_mask, num_interp_features = construct_feature_mask(
            feature_mask, formatted_inputs
        )

        if num_interp_features > 10000:
            warnings.warn(
                "Attempting to construct interpretable model with > 10000 features. "
                "This can be very slow or lead to OOM issues. Please provide a feature "
                "mask which groups input features to reduce the number of interpretable "
                "features. ",
                stacklevel=1,
            )

        coefs: Tensor
        if bsz > 1:
            # Probe the model once to see whether it returns one value per
            # example or a single value for the whole batch.
            test_output = _run_forward(
                self.forward_func, inputs, target, additional_forward_args
            )
            if isinstance(test_output, Tensor) and torch.numel(test_output) > 1:
                if torch.numel(test_output) == bsz:
                    warnings.warn(
                        "You are providing multiple inputs for Lime / Kernel SHAP "
                        "attributions. This trains a separate interpretable model "
                        "for each example, which can be time consuming. It is "
                        "recommended to compute attributions for one example at a "
                        "time.",
                        stacklevel=1,
                    )
                    output_list = []
                    for (
                        curr_inps,
                        curr_target,
                        curr_additional_args,
                        curr_baselines,
                        curr_feature_mask,
                    ) in _batch_example_iterator(
                        bsz,
                        formatted_inputs,
                        target,
                        # NOTE: tensor additional args can also be split per
                        # example automatically by the library iterator.
                        additional_forward_args,
                        baselines,
                        feature_mask,
                    ):
                        # ``__wrapped__`` skips the base class's log_usage
                        # decorator, so this inner call is not logged twice.
                        coefs = super().attribute.__wrapped__(
                            self,
                            inputs=curr_inps if is_inputs_tuple else curr_inps[0],
                            target=curr_target,
                            additional_forward_args=curr_additional_args,
                            n_samples=n_samples,
                            perturbations_per_eval=perturbations_per_eval,
                            baselines=(
                                curr_baselines if is_inputs_tuple else curr_baselines[0]
                            ),
                            feature_mask=(
                                curr_feature_mask
                                if is_inputs_tuple
                                else curr_feature_mask[0]
                            ),
                            num_interp_features=num_interp_features,
                            show_progress=show_progress,
                            **kwargs,
                        )
                        if return_input_shape:
                            output_list.append(
                                self._convert_output_shape(
                                    curr_inps,
                                    curr_feature_mask,
                                    coefs,
                                    num_interp_features,
                                    is_inputs_tuple,
                                )
                            )
                        else:
                            output_list.append(coefs.reshape(1, -1))  # type: ignore

                    return _reduce_list(output_list)
                else:
                    raise AssertionError(
                        "Invalid number of outputs, forward function should return a "
                        "scalar per example or a scalar per input batch."
                    )
            else:
                assert perturbations_per_eval == 1, (
                    "Perturbations per eval must be 1 when forward function "
                    "returns single value per batch!"
                )

        # Single surrogate model for the whole input. This covers bsz == 1,
        # or a scalar-per-batch forward function.
        coefs = super().attribute.__wrapped__(
            self,
            inputs=inputs,
            target=target,
            additional_forward_args=additional_forward_args,
            n_samples=n_samples,
            perturbations_per_eval=perturbations_per_eval,
            baselines=baselines if is_inputs_tuple else baselines[0],
            feature_mask=feature_mask if is_inputs_tuple else feature_mask[0],
            num_interp_features=num_interp_features,
            show_progress=show_progress,
            **kwargs,
        )
        if return_input_shape:
            # pyre-fixme[7]: Expected `TensorOrTupleOfTensorsGeneric` but got
            #  `Tuple[Tensor, ...]`.
            return self._convert_output_shape(
                formatted_inputs,
                feature_mask,
                coefs,
                num_interp_features,
                is_inputs_tuple,
                # A per-batch scalar output yields one attribution map, so
                # keep only a leading dimension of 1.
                leading_dim_one=(bsz > 1),
            )
        else:
            return coefs

    @typing.overload
    def _convert_output_shape(
        self,
        formatted_inp: Tuple[Tensor, ...],
        feature_mask: Tuple[Tensor, ...],
        coefs: Tensor,
        num_interp_features: int,
        is_inputs_tuple: Literal[True],
        leading_dim_one: bool = False,
    ) -> Tuple[Tensor, ...]: ...

    @typing.overload
    def _convert_output_shape(  # type: ignore
        self,
        formatted_inp: Tuple[Tensor, ...],
        feature_mask: Tuple[Tensor, ...],
        coefs: Tensor,
        num_interp_features: int,
        is_inputs_tuple: Literal[False],
        leading_dim_one: bool = False,
    ) -> Tensor: ...

    @typing.overload
    def _convert_output_shape(
        self,
        formatted_inp: Tuple[Tensor, ...],
        feature_mask: Tuple[Tensor, ...],
        coefs: Tensor,
        num_interp_features: int,
        is_inputs_tuple: bool,
        leading_dim_one: bool = False,
    ) -> Union[Tensor, Tuple[Tensor, ...]]: ...

    def _convert_output_shape(
        self,
        formatted_inp: Tuple[Tensor, ...],
        feature_mask: Tuple[Tensor, ...],
        coefs: Tensor,
        num_interp_features: int,
        is_inputs_tuple: bool,
        leading_dim_one: bool = False,
    ) -> Union[Tensor, Tuple[Tensor, ...]]:
        """Scatter surrogate coefficients back onto the input shape.

        Each element of input tensor ``i`` receives the coefficient of the
        interpretable feature that ``feature_mask[i]`` assigns to it. Elements
        whose mask value is outside ``range(num_interp_features)`` stay 0.

        Args:
            formatted_inp: Inputs as a tuple of tensors (defines output shapes).
            feature_mask: One mask per input, matching or broadcastable to it.
            coefs: Surrogate coefficients, flattened to length >= num_interp_features.
            num_interp_features: Number of interpretable features.
            is_inputs_tuple: Whether to return a tuple (True) or a single tensor.
            leading_dim_one: If True, keep only the first slice along dim 0.

        Returns:
            Float attribution tensor(s) shaped like ``formatted_inp``.
        """
        # Convert once to Python floats. This avoids a per-feature ``.item()``
        # call, which synchronises with the device each time on GPU. The
        # values are identical.
        coef_values = coefs.flatten().tolist()
        attr = [
            torch.zeros_like(single_inp, dtype=torch.float)
            for single_inp in formatted_inp
        ]
        for tensor_ind, single_mask in enumerate(feature_mask[: len(formatted_inp)]):
            for single_feature in range(num_interp_features):
                attr[tensor_ind] += (
                    coef_values[single_feature]
                    * (single_mask == single_feature).float()
                )
        if leading_dim_one:
            for i in range(len(attr)):
                attr[i] = attr[i][0:1]
        return _format_output(is_inputs_tuple, tuple(attr))


In [None]:
#!/usr/bin/env python3

# pyre-strict

from typing import Callable, cast, Generator, Optional, Tuple, Union

import torch
from captum._utils.models.linear_model import SkLearnLinearRegression
from captum._utils.typing import BaselineType, TargetType, TensorOrTupleOfTensorsGeneric
from captum.attr._core.lime import construct_feature_mask, Lime
from captum.attr._utils.common import _format_input_baseline
from captum.log import log_usage
from torch import Tensor
from torch.distributions.categorical import Categorical


class KernelShapWithMask(LimeWithCustomArgumentToForwardFunc):
    r"""
    Kernel SHAP is a method that uses the LIME framework to compute
    Shapley Values. Setting the loss function, weighting kernel and
    regularization terms appropriately in the LIME framework allows
    theoretically obtaining Shapley Values more efficiently than
    directly computing Shapley Values.

    This class is adapted from Captum's ``KernelShap`` but derives from
    ``LimeWithCustomArgumentToForwardFunc``. Judging from how it is used in
    this notebook, the wrapped forward function also receives the
    perturbation mask (``forward_func(inputs, mask)``). Confirm against that
    base class.

    More information regarding this method and proof of equivalence
    can be found in the original paper here:
    https://arxiv.org/abs/1705.07874
    """

    def __init__(self, forward_func: Callable[..., Tensor]) -> None:
        r"""
        Args:

            forward_func (Callable): The forward function of the model or
                        any modification of it.
        """
        # NOTE(review): like upstream Captum's KernelShap, this calls
        # ``Lime.__init__`` directly. Any ``__init__`` defined on
        # ``LimeWithCustomArgumentToForwardFunc`` itself is therefore bypassed.
        # Confirm the base class needs no extra initialisation.
        Lime.__init__(
            self,
            forward_func,
            interpretable_model=SkLearnLinearRegression(),
            similarity_func=self.kernel_shap_similarity_kernel,
            perturb_func=self.kernel_shap_perturb_generator,
        )
        # Sample weight of the all-ones / all-zeros coalitions. It stands in for
        # the theoretically infinite Shapley-kernel weight of those end points.
        self.inf_weight = 1000000.0

    @log_usage(part_of_slo=True)
    def attribute(  # type: ignore
        self,
        inputs: TensorOrTupleOfTensorsGeneric,
        baselines: BaselineType = None,
        target: TargetType = None,
        additional_forward_args: Optional[object] = None,
        feature_mask: Union[None, Tensor, Tuple[Tensor, ...]] = None,
        n_samples: int = 25,
        perturbations_per_eval: int = 1,
        return_input_shape: bool = True,
        show_progress: bool = False,
    ) -> TensorOrTupleOfTensorsGeneric:
        r"""
        This method attributes the output of the model with given target index
        (in case it is provided, otherwise it assumes that output is a
        scalar) to the inputs of the model using the approach described above,
        training an interpretable model based on KernelSHAP and returning a
        representation of the interpretable model.

        It is recommended to only provide a single example as input (tensors
        with first dimension or batch size = 1). This is because LIME / KernelShap
        is generally used for sample-based interpretability, training a separate
        interpretable model to explain a model's prediction on each individual example.

        A batch of inputs can also be provided as inputs, similar to
        other perturbation-based attribution methods. In this case, if forward_fn
        returns a scalar per example, attributions will be computed for each
        example independently, with a separate interpretable model trained for each
        example. Note that provided similarity and perturbation functions will be
        provided each example separately (first dimension = 1) in this case.
        If forward_fn returns a scalar per batch (e.g. loss), attributions will
        still be computed using a single interpretable model for the full batch.
        In this case, similarity and perturbation functions will be provided the
        same original input containing the full batch.

        The number of interpretable features is determined from the provided
        feature mask, or if none is provided, from the default feature mask,
        which considers each scalar input as a separate feature. It is
        generally recommended to provide a feature mask which groups features
        into a small number of interpretable features / components (e.g.
        superpixels in images).

        Args:

            inputs (Tensor or tuple[Tensor, ...]): Input for which KernelShap
                        is computed. If forward_func takes a single
                        tensor as input, a single input tensor should be provided.
                        If forward_func takes multiple tensors as input, a tuple
                        of the input tensors should be provided. It is assumed
                        that for all given input tensors, dimension 0 corresponds
                        to the number of examples, and if multiple input tensors
                        are provided, the examples must be aligned appropriately.
            baselines (scalar, Tensor, tuple of scalar, or Tensor, optional):
                        Baselines define the reference value which replaces each
                        feature when the corresponding interpretable feature
                        is set to 0.
                        Baselines can be provided as:

                        - a single tensor, if inputs is a single tensor, with
                          exactly the same dimensions as inputs or the first
                          dimension is one and the remaining dimensions match
                          with inputs.

                        - a single scalar, if inputs is a single tensor, which will
                          be broadcasted for each input value in input tensor.

                        - a tuple of tensors or scalars, the baseline corresponding
                          to each tensor in the inputs' tuple can be:

                          - either a tensor with matching dimensions to
                            corresponding tensor in the inputs' tuple
                            or the first dimension is one and the remaining
                            dimensions match with the corresponding
                            input tensor.

                          - or a scalar, corresponding to a tensor in the
                            inputs' tuple. This scalar value is broadcasted
                            for corresponding input tensor.

                        In the cases when `baselines` is not provided, we internally
                        use zero scalar corresponding to each input tensor.
                        Default: None
            target (int, tuple, Tensor, or list, optional): Output indices for
                        which surrogate model is trained
                        (for classification cases,
                        this is usually the target class).
                        If the network returns a scalar value per example,
                        no target index is necessary.
                        For general 2D outputs, targets can be either:

                        - a single integer or a tensor containing a single
                          integer, which is applied to all input examples

                        - a list of integers or a 1D tensor, with length matching
                          the number of examples in inputs (dim 0). Each integer
                          is applied as the target for the corresponding example.

                        For outputs with > 2 dimensions, targets can be either:

                        - A single tuple, which contains #output_dims - 1
                          elements. This target index is applied to all examples.

                        - A list of tuples with length equal to the number of
                          examples in inputs (dim 0), and each tuple containing
                          #output_dims - 1 elements. Each tuple is applied as the
                          target for the corresponding example.

                        Default: None
            additional_forward_args (Any, optional): If the forward function
                        requires additional arguments other than the inputs for
                        which attributions should not be computed, this argument
                        can be provided. It must be either a single additional
                        argument of a Tensor or arbitrary (non-tuple) type or a
                        tuple containing multiple additional arguments including
                        tensors or any arbitrary python types. These arguments
                        are provided to forward_func in order following the
                        arguments in inputs.
                        For a tensor, the first dimension of the tensor must
                        correspond to the number of examples. For all other
                        types, the given argument is used for all forward
                        evaluations.
                        Note that attributions are not computed with respect
                        to these arguments.
                        Default: None
            feature_mask (Tensor or tuple[Tensor, ...], optional):
                        feature_mask defines a mask for the input, grouping
                        features which correspond to the same
                        interpretable feature. feature_mask
                        should contain the same number of tensors as inputs.
                        Each tensor should
                        be the same size as the corresponding input or
                        broadcastable to match the input tensor. Values across
                        all tensors should be integers in the range 0 to
                        num_interp_features - 1, and indices corresponding to the
                        same feature should have the same value.
                        Note that features are grouped across tensors
                        (unlike feature ablation and occlusion), so
                        if the same index is used in different tensors, those
                        features are still grouped and added simultaneously.
                        At least 2 interpretable features are required.
                        If None, then a feature mask is constructed which assigns
                        each scalar within a tensor as a separate feature.
                        Default: None
            n_samples (int, optional): The number of samples of the original
                        model used to train the surrogate interpretable model.
                        The first two samples are always the full coalition
                        (all features kept) and the empty coalition (all
                        features replaced by the baseline).
                        Default: 25
            perturbations_per_eval (int, optional): Allows multiple samples
                        to be processed simultaneously in one call to forward_fn.
                        Each forward pass will contain a maximum of
                        perturbations_per_eval * #examples samples.
                        For DataParallel models, each batch is split among the
                        available devices, so evaluations on each available
                        device contain at most
                        (perturbations_per_eval * #examples) / num_devices
                        samples.
                        If the forward function returns a single scalar per batch,
                        perturbations_per_eval must be set to 1.
                        Default: 1
            return_input_shape (bool, optional): Determines whether the returned
                        tensor(s) only contain the coefficients for each interp-
                        retable feature from the trained surrogate model, or
                        whether the returned attributions match the input shape.
                        When return_input_shape is True, the return type of attribute
                        matches the input shape, with each element containing the
                        coefficient of the corresponding interpretable feature.
                        All elements with the same value in the feature mask
                        will contain the same coefficient in the returned
                        attributions. If return_input_shape is False, a 1D
                        tensor is returned, containing only the coefficients
                        of the trained interpretable model, with length
                        num_interp_features.
            show_progress (bool, optional): Displays the progress of computation.
                        It will try to use tqdm if available for advanced features
                        (e.g. time estimation). Otherwise, it will fallback to
                        a simple output of progress.
                        Default: False

        Returns:
            *Tensor* or *tuple[Tensor, ...]* of **attributions**:
            - **attributions** (*Tensor* or *tuple[Tensor, ...]*):
                        The attributions with respect to each input feature.
                        If return_input_shape = True, attributions will be
                        the same size as the provided inputs, with each value
                        providing the coefficient of the corresponding
                        interpretable feature.
                        If return_input_shape is False, a 1D
                        tensor is returned, containing only the coefficients
                        of the trained interpretable models, with length
                        num_interp_features.

        Raises:
            ValueError: if the feature mask yields fewer than 2 interpretable
                        features (no coalition sizes to sample).

        Examples::
            >>> # SimpleClassifier takes a single input tensor of size Nx4x4,
            >>> # and returns an Nx3 tensor of class probabilities.
            >>> net = SimpleClassifier()

            >>> # Generating random input with size 1 x 4 x 4
            >>> input = torch.randn(1, 4, 4)

            >>> # Defining the KernelShapWithMask interpreter.
            >>> # NOTE: judging from its use in this notebook, forward_func is
            >>> # also passed the perturbation mask, hence the wrapper (confirm
            >>> # against LimeWithCustomArgumentToForwardFunc).
            >>> ks = KernelShapWithMask(lambda inp, mask: net(inp))
            >>> # Computes attribution, with each of the 4 x 4 = 16
            >>> # features as a separate interpretable feature
            >>> attr = ks.attribute(input, target=1, n_samples=200)

            >>> # Alternatively, we can group each 2x2 square of the inputs
            >>> # as one 'interpretable' feature and perturb them together.
            >>> # This can be done by creating a feature mask as follows, which
            >>> # defines the feature groups, e.g.:
            >>> # +---+---+---+---+
            >>> # | 0 | 0 | 1 | 1 |
            >>> # +---+---+---+---+
            >>> # | 0 | 0 | 1 | 1 |
            >>> # +---+---+---+---+
            >>> # | 2 | 2 | 3 | 3 |
            >>> # +---+---+---+---+
            >>> # | 2 | 2 | 3 | 3 |
            >>> # +---+---+---+---+
            >>> # With this mask, all inputs with the same value are set to their
            >>> # baseline value, when the corresponding binary interpretable
            >>> # feature is set to 0.
            >>> # The attributions can be calculated as follows:
            >>> # feature mask has dimensions 1 x 4 x 4
            >>> feature_mask = torch.tensor([[[0,0,1,1],[0,0,1,1],
            >>>                             [2,2,3,3],[2,2,3,3]]])

            >>> # Computes KernelSHAP attributions with feature mask.
            >>> attr = ks.attribute(input, target=1, feature_mask=feature_mask)
        """
        formatted_inputs, baselines = _format_input_baseline(inputs, baselines)
        feature_mask, num_interp_features = construct_feature_mask(
            feature_mask, formatted_inputs
        )
        if num_interp_features < 2:
            # With M < 2 the coalition-size distribution below has no valid
            # support (all-zero or NaN probabilities), so fail with a clear message.
            raise ValueError(
                "KernelShapWithMask requires at least 2 interpretable features, "
                f"got {num_interp_features}."
            )
        # Unnormalized p(k) = (M - 1) / (k * (M - k)) for k = 1..M-1 (see
        # kernel_shap_perturb_generator). k = 0 has a zero denominator and gets
        # probability 0, because the empty and full coalitions are yielded
        # explicitly by the generator.
        num_features_list = torch.arange(num_interp_features, dtype=torch.float)
        denom = num_features_list * (num_interp_features - num_features_list)
        probs = torch.tensor((num_interp_features - 1)) / denom
        probs[0] = 0.0
        return self._attribute_kwargs(
            inputs=inputs,
            baselines=baselines,
            target=target,
            additional_forward_args=additional_forward_args,
            feature_mask=feature_mask,
            n_samples=n_samples,
            perturbations_per_eval=perturbations_per_eval,
            return_input_shape=return_input_shape,
            num_select_distribution=Categorical(probs),
            show_progress=show_progress,
        )

    # pyre-fixme[24] Generic type `Callable` expects 2 type parameters.
    def attribute_future(self) -> Callable:
        r"""
        This method is not implemented for KernelShap.
        """
        raise NotImplementedError("attribute_future is not implemented for KernelShap")

    def kernel_shap_similarity_kernel(
        self,
        _,
        __,
        interpretable_sample: Tensor,
        **kwargs: object,
    ) -> Tensor:
        r"""
        Sample weight given to one coalition by the surrogate linear model.

        Coalitions are already drawn with probability proportional to the
        Shapley kernel (see ``kernel_shap_perturb_generator``), so every
        intermediate coalition gets weight 1. The empty and the full coalition
        get ``self.inf_weight``, which approximately enforces the end-point
        constraints.

        Args:
            _, __: unused positional arguments of the LIME similarity-function
                signature (presumably the original and perturbed inputs).
            interpretable_sample: binary coalition vector. It is summed over
                dim 1 and must contain a single row (assumed shape (1, M)).
            **kwargs: must contain ``num_interp_features``.

        Returns:
            One-element tensor holding the weight.
        """
        assert (
            "num_interp_features" in kwargs
        ), "Must provide num_interp_features to use default similarity kernel"
        num_selected_features = int(interpretable_sample.sum(dim=1).item())
        num_features = kwargs["num_interp_features"]
        if num_selected_features == 0 or num_selected_features == num_features:
            # The weight should theoretically be infinite when
            # num_selected_features = 0 or num_features, which enforces that the
            # trained linear model satisfies the end-point criteria. In practice
            # a substantially larger weight is sufficient, so these samples get
            # self.inf_weight (all other weights are 1).
            similarities = self.inf_weight
        else:
            similarities = 1.0
        return torch.tensor([similarities])

    def kernel_shap_perturb_generator(
        self,
        original_inp: Union[Tensor, Tuple[Tensor, ...]],
        **kwargs: object,
    ) -> Generator[Tensor, None, None]:
        r"""
        Perturbations are sampled by the following process:
         - First yield the full coalition (all ones) and then the empty
           coalition (all zeros).
         - Then, repeatedly, choose k (number of selected features) based on the
           (unnormalized) distribution
                p(k) = (M - 1) / (k * (M - k))

            where M is the total number of features in the interpretable space

         - Randomly select a binary vector with k ones, each sample is equally
            likely. This is done by generating a random vector of normal
            values and thresholding based on the top k elements.

         Since there are M choose k vectors with k ones, this weighted sampling
         is equivalent to applying the Shapley kernel for the sample weight,
         defined as:
         k(M, k) = (M - 1) / (k * (M - k) * (M choose k))

         Each yielded vector is a long tensor of shape (1, M) on the device of
         ``original_inp`` (of its first tensor when a tuple is given).
        """
        assert (
            "num_select_distribution" in kwargs and "num_interp_features" in kwargs
        ), (
            "num_select_distribution and num_interp_features are necessary"
            " to use kernel_shap_perturb_func"
        )
        if isinstance(original_inp, Tensor):
            device = original_inp.device
        else:
            device = original_inp[0].device
        num_features = cast(int, kwargs["num_interp_features"])
        num_select_distribution = cast(Categorical, kwargs["num_select_distribution"])
        yield torch.ones(1, num_features, device=device, dtype=torch.long)
        yield torch.zeros(1, num_features, device=device, dtype=torch.long)
        while True:
            num_selected_features = int(num_select_distribution.sample())
            rand_vals = torch.randn(1, num_features)
            # kthvalue(k) is the k-th smallest value, so exactly
            # num_selected_features entries are strictly greater than it.
            threshold = torch.kthvalue(
                rand_vals, num_features - num_selected_features
            ).values.item()
            yield (rand_vals > threshold).to(device=device).long()

   

In [None]:
# Disabled toy sanity check for KernelShapWithMask on an 8x8x8 volume. It is
# kept inside a string literal, so nothing below is executed.
# NOTE(review): before re-enabling, note that `forward_func=DummyPredictor`
# passes the class, not an instance (it should be `DummyPredictor()`). Also,
# with n_samples=2 only the two end-point coalitions (all ones, then all zeros)
# yielded first by kernel_shap_perturb_generator are presumably drawn.
"""    # Part 3: Quick sanity-check of KernelShapWithMask on 8³ toy data
    
    import torch
    from torch import Tensor
    
    # 1) Dummy predictor that just counts perturbed voxels (so SHAP ≈ region size)
    class DummyPredictor(torch.nn.Module):
        def forward(self, volume: Tensor, pert_mask: Tensor) -> Tensor:
            # volume: (B,C,X,Y,Z), pert_mask: (B,C,X,Y,Z)
            # Our “score” is negative #perturbed voxels per example
            return -(pert_mask.sum(dim=[1,2,3,4]).float())  # shape (B,)
    
    # 2) Synthesise an 8×8×8 volume and a 4-supervoxel map
    vol = torch.randn(1, 1, 8, 8, 8)                  # (B=1, C=1, X=8,Y=8,Z=8)
    sv_map = torch.randint(0, 4, (8, 8, 8))           # labels 0–3
    
    
    # 4) Instantiate explainer (assumes KernelShapWithMask is already in scope)
    explainer = KernelShapWithMask(
        forward_func=DummyPredictor
    )
    attrs = explainer.attribute(
            inputs = vol,
            baselines = 0.0,
            feature_mask= sv_map,
            n_samples = 2,
            show_progress = True
    )
    
    print("Attributions shape:", attrs.shape)  
    # should be (1,1,8,8,8), and values ≈ size of each supervoxel region
"""

### Try it in our pipeline

In [None]:
# Utility that runs nnU-Net v2's default preprocessing on a single image file.
def nnunetv2_default_preprocessing(
    ct_img_path,
    predictor,
    dataset_json_path,
    return_properties: bool = False,
) -> "Union[np.ndarray, Tuple[np.ndarray, dict]]":
    """Read one image and preprocess it with the predictor's nnU-Net plans.

    The image reader, the preprocessor class and the plans/configuration all
    come from an already-initialised nnU-Net predictor.

    Args:
        ct_img_path: path (str or path-like) of the image to read. It is
            converted with ``str()`` before being passed to the reader.
        predictor: initialised predictor exposing ``plans_manager`` and
            ``configuration_manager``.
        dataset_json_path: forwarded unchanged as ``dataset_json`` to
            ``run_case_npy``.
        return_properties: if True, also return the properties dict produced
            by ``run_case_npy`` (the reader's metadata as passed through the
            preprocessor). Default False keeps the original return value.

    Returns:
        The preprocessed array, or ``(preprocessed, properties)`` when
        ``return_properties`` is True.
    """
    plans_manager = predictor.plans_manager
    configuration_manager = predictor.configuration_manager

    preprocessor = configuration_manager.preprocessor_class(verbose=False)
    rw = plans_manager.image_reader_writer_class()
    # Defensive: if calling the attribute produced a class rather than an
    # instance, instantiate it once more so ``read_images`` is available.
    if callable(rw) and not hasattr(rw, "read_images"):
        rw = rw()
    img_np, img_props = rw.read_images([str(ct_img_path)])

    preprocessed, _, properties = preprocessor.run_case_npy(
        img_np, seg=None, properties=img_props,
        plans_manager=plans_manager,
        configuration_manager=configuration_manager,
        dataset_json=dataset_json_path
    )
    if return_properties:
        return preprocessed, properties
    return preprocessed

In [None]:
# Pick the GPU whenever PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)

In [None]:
# 2) Initialise predictor ------------------------
# Use the first GPU when one is available; otherwise fall back to the CPU
# instead of crashing on CPU-only machines. nnU-Net only supports
# perform_everything_on_device on CUDA devices, so it follows the same flag.
_use_cuda = torch.cuda.is_available()
predictor = CustomNNUNetPredictor(
    tile_step_size=0.5,
    use_gaussian=True,
    use_mirroring=False,  # mirroring == test-time augmentation; disabled
    perform_everything_on_device=_use_cuda,
    device=torch.device('cuda', 0) if _use_cuda else torch.device('cpu'),
    verbose=False,
    verbose_preprocessing=False,
    allow_tqdm=False,  # nnU-Net's tqdm bars interfere with the SHAP progress bar
)
# Initialise the network architecture and load the fold-0 final checkpoint.
predictor.initialize_from_trained_model_folder(
    model_dir_readonly,
    use_folds=(0,),
    checkpoint_name='checkpoint_final.pth',
)

In [None]:
import pickle as pkl

# Load the cached baseline sliding-window predictions (generated by the
# disabled snippet below).
# NOTE: pickle.load can execute arbitrary code. Only load cache files that
# this notebook wrote itself.
BASELINE_CACHE_FILE = "baseline_output_dictionary_cache.pkl"
with open(BASELINE_CACHE_FILE, mode="rb") as cache_file:
    baseline_pred_cache = pkl.load(cache_file)

"""# code to get our baseline prediction patch dictionary
iterator = SHAPPredictionIterator(plan, predictor,
                                               skip_completed_ids=acc.completed_ids,
                                               cache_sw_inference=True,
                                               verbose=True)
baseline_pred_cache = iterator._baseline_output_dictionary
import pickle as pkl
# Write to file
with open("baseline_output_dictionary_cache.pkl", "wb") as f:
    pkl.dump(baseline_pred_cache, f)

"""

In [None]:
# Disabled debug snippet (string literal, not executed): inspects one entry of
# the baseline prediction cache and prints its shape and spatial voxel count.
"""example = baseline_pred_cache[((None, None, None), (0, 72, None), (0, 160, None), (0, 160, None))]
print(example.shape)
C,D,H,W = example.shape
print(D*H*W)"""

In [None]:
# Disabled (string literal): preprocesses input_volume_39 with
# nnunetv2_default_preprocessing and moves it to `device` with an added
# leading batch axis.
# NOTE(review): it relies on `Path` and `model_dir_readonly` being defined in
# earlier cells.
"""# (a) load + preprocess the volume  (1, C, D, H, W) – nnU-Net order
nii_path = "input_volume_39.nii.gz"
dataset_json_path = Path(model_dir_readonly) / "dataset.json"

volume_np = nnunetv2_default_preprocessing(nii_path, predictor, dataset_json_path)

volume = torch.from_numpy(volume_np).unsqueeze(0).to(device)        # torch (1,C,D,H,W)

print("Volume shape:", volume.shape)                # (1, C, D, H, W)"""

In [None]:
# Disabled (string literal): builds the organ-id feature mask from a
# TotalSegmentator mask. It reverses the axes and remaps the labels to
# consecutive ids 0..M-1 via np.unique(return_inverse=True).
# NOTE(review): the "(X, Y, Z)" remark inside the string contradicts the
# "(D,H,W)" comments next to it. The mask is used in the transposed (D, H, W)
# order.
"""# (b) super-voxel / organ-id map  (W, H, D)
organ_mask_path = join(
    nnUNet_raw, "total_segmentator_structures",
    "AUTOMI_00039_0000", "mask_mask_add_input_20_total_segmentator.nii",
)
sv_np = nib.load(organ_mask_path).get_fdata()
sv_np = np.transpose(sv_np, (2, 1, 0))                 # match (D,H,W)
# we need features of feature mask ordered from 0 (or 1) to M-1 (M)
sv_values, indexes = np.unique(sv_np, return_inverse=True)

sv_np = indexes.reshape(sv_np.shape)
print(np.unique(sv_np))

# IMPORTANT 🔸: KernelShapWithMask expects **(X, Y, Z)** without channel axis
supervox = torch.from_numpy(sv_np).long().to(device)   # (D,H,W)

print("Mask shape:", supervox.shape)                      # (D, H, W)"""

In [None]:
# Disabled sanity check (string literal): counts of voxels labelled 0 and of
# all other voxels must add up to the total mask size.
"""print(sv_np.shape)
print(np.prod(sv_np.shape))
print(np.sum(sv_np != 0))
print(np.sum(sv_np == 0))
assert np.sum(sv_np != 0) + np.sum(sv_np == 0) == np.prod(sv_np.shape)"""

In [None]:
# ------------------------------------------------------------
# ❷  Forward wrapper that nnU-Net expects
# ------------------------------------------------------------

@torch.inference_mode()
def forward_segmentation_output_to_explain(
        input_image:         torch.Tensor,
        perturbation_mask:   torch.BoolTensor,
        baseline_prediction_dict: dict,
        target_class: int = 1,
) -> torch.Tensor:           # returns a 0-d tensor (one scalar per call)
    """
    Reduce the segmentation output to one scalar for Kernel SHAP.

    Runs the predictor's cached sliding-window inference on the (perturbed)
    input. It then sums the ``target_class`` logits over the voxels whose argmax
    prediction is ``target_class`` (every other voxel contributes 0). The sum is
    divided by the total number of voxels to keep the value small and avoid
    overflows in SHAP. Example aggregate (default class 1 = lymph nodes); adapt
    it to your real metric as needed.

    Args:
        input_image: input volume forwarded to the predictor.
        perturbation_mask: perturbation mask forwarded to the predictor.
        baseline_prediction_dict: cache of baseline patch predictions forwarded
            to the predictor.
        target_class: index of the class whose logits are aggregated.

    Returns:
        0-d float64 tensor.
    """
    logits = predictor.predict_sliding_window_return_logits_with_caching(
        input_image, perturbation_mask, baseline_prediction_dict,
    )                              # (C, D, H, W)
    seg_mask = torch.argmax(logits, dim=0)          # (D,H,W) predicted class id
    # Compare with the class id: multiplying by the raw argmax index would
    # weight voxels predicted as class k by k whenever there are > 2 classes.
    target_region = seg_mask == target_class
    aggregate = torch.sum(logits[target_class].double() * target_region) / seg_mask.numel()

    return aggregate

# c) Adapt the cached forward method to the (volume, mask) call made by
#    KernelShapWithMask. baseline_pred_cache is looked up at call time.
def _forward_with_baseline_cache(vol, mask):
    """Forward a (volume, perturbation mask) pair together with the cached baseline predictions."""
    return forward_segmentation_output_to_explain(
        input_image=vol,
        perturbation_mask=mask,
        baseline_prediction_dict=baseline_pred_cache,
    )


explainer = KernelShapWithMask(forward_func=_forward_with_baseline_cache)
"""
# d) compute SHAP
attr = explainer.attribute(
    inputs=volume,       # (1,C,D,H,W)
    baselines=0.0,
    feature_mask=supervox,
    n_samples=50,    
    return_input_shape=True,
    show_progress=True,
)
print("Attributions:", attr.shape)  # → (1,C,D,H,W)
"""

In [None]:
# Disabled debugging snippet (string literal): prints the number of
# interpretable features that construct_feature_mask derives from `supervox`.
# NOTE(review): it relies on `explainer.dataset`, `_format_input_baseline` and
# `construct_feature_mask`, none of which is defined in this cell. Confirm they
# exist before re-enabling.
"""for x,y,_ in explainer.dataset:
    print(x)
    
formatted_inputs, baselines = _format_input_baseline(inputs=volume, baselines=0.0)
feature_mask, num_interp_features = construct_feature_mask(
            feature_mask=supervox, formatted_inputs=formatted_inputs
        )
print(num_interp_features)"""

### Map the attributions back to physical space and save them as NIfTI

In [None]:
# Disabled (string literal): takes the first example/channel of the
# attributions, reverses the axes back to the NIfTI order and saves them with
# the input image's affine.
# NOTE(review): reusing the raw affine assumes the preprocessed grid equals the
# raw image grid (no resampling or cropping). Confirm.
"""affine = nib.load(nii_path).affine
attr_postprocessed = attr[0][0].detach().cpu().numpy().transpose(2,1,0) # (W, H, D)
attr_img = nib.Nifti1Image(attr_postprocessed, affine)
nib.save(attr_img, 'attribution_map-n_samples=50.nii.gz')"""

# Next step: define regular, fixed-size superpixels (supervoxels) and compute the attribution of each of them
This also requires defining a metric to explain, because segmentation explanations, unlike classification ones, are intrinsically ambiguous. For example, we can select a priori a single region of the segmentation output and use the average over its voxels to measure the impact of perturbations.

### Animation plot utility (interactive plots won't work in a notebook with matplotlib's inline mode)

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation, FFMpegWriter, PillowWriter
from IPython.display import HTML, display

def visualize_volume(
    volume: np.ndarray,
    mode: str = 'interactive',
    plane: str = 'axial',
    step: int = 1,
    cmap = 'gray',
    gif_filename: str = 'volume_animation.gif',
    mp4_filename: str = 'volume_animation.mp4',
    interval: int = 100
):
    """
    Visualize a 3D volume either as an inline JavaScript animation or as a
    saved GIF/MP4 file, with optional subsampling of slices (step).

    Each frame is contrast-stretched independently to the 1st-99th percentile
    of that slice's intensities. The matplotlib figure is always closed before
    returning, so no extra static image is left behind in the notebook.

    Parameters
    ----------
    volume : np.ndarray
        Non-empty 3D NumPy array (shape: X × Y × Z).
    mode : str, optional
        - 'interactive': render an inline HTML/JavaScript player
          (``FuncAnimation.to_jshtml``). All frames are embedded in the
          notebook, so long volumes can exceed matplotlib's
          ``animation.embed_limit`` rcParam; increase ``step`` to reduce frames.
        - 'gif': save as GIF to `gif_filename` and display it.
        - 'mp4': save as MP4 to `mp4_filename` (requires ffmpeg) and display it.
      Default is 'interactive'.
    plane : str, optional
        One of {'axial', 'coronal', 'sagittal'}. Determines slicing orientation:
        - 'axial':   slices = volume[:, :, idx]
        - 'coronal': slices = volume[:, idx, :]
        - 'sagittal': slices = volume[idx, :, :]
        Each slice is rotated by 90° for display.
      Default is 'axial'.
    step : int, optional
        Take every `step`-th slice along the chosen axis. Must be ≥ 1.
        E.g. step=2 shows slices 0,2,4,… instead of 0,1,2,3,…
        (reduces #frames if your volume has many slices).
      Default is 1.
    cmap : str or matplotlib Colormap, optional
        Colormap passed to ``imshow``.
      Default is 'gray'.
    gif_filename : str, optional
        If mode='gif', save the animation to this file.
      Default is 'volume_animation.gif'.
    mp4_filename : str, optional
        If mode='mp4', save the animation to this file.
      Default is 'volume_animation.mp4'.
    interval : int, optional
        Milliseconds between frames in the animation. Lower → faster playback.
      Default is 100 (i.e. 10 FPS).

    Raises
    ------
    ValueError
        If inputs are invalid (e.g. volume not 3D or empty, unknown mode/plane,
        or step < 1).
    """

    # 1) Validate inputs (before any figure is created)
    if not isinstance(volume, np.ndarray) or volume.ndim != 3:
        raise ValueError("`volume` must be a 3D NumPy array.")
    if volume.size == 0:
        raise ValueError("`volume` must not be empty.")
    if mode not in {'interactive', 'gif', 'mp4'}:
        raise ValueError("`mode` must be one of {'interactive', 'gif', 'mp4'}.")
    if plane not in {'axial', 'coronal', 'sagittal'}:
        raise ValueError("`plane` must be one of {'axial', 'coronal', 'sagittal'}.")
    if not (isinstance(step, int) and step >= 1):
        raise ValueError("`step` must be an integer ≥ 1.")

    # 2) Decide how to slice + total number of (subsampled) frames
    if plane == 'axial':
        get_slice = lambda vol, idx: np.rot90(vol[:, :, idx])
        full_num = volume.shape[2]
    elif plane == 'coronal':
        get_slice = lambda vol, idx: np.rot90(vol[:, idx, :])
        full_num = volume.shape[1]
    else:  # sagittal
        get_slice = lambda vol, idx: np.rot90(vol[idx, :, :])
        full_num = volume.shape[0]

    # Indices [0, step, 2*step, …] up to full_num - 1 (non-empty: full_num ≥ 1 here)
    indices = list(range(0, full_num, step))
    num_frames = len(indices)

    # 3) Set up the matplotlib figure once
    fig, ax = plt.subplots(figsize=(6, 6))
    try:
        ax.axis('off')

        # Show the first frame
        first_img = get_slice(volume, indices[0])
        vmin, vmax = np.percentile(first_img, (1, 99))
        im = ax.imshow(first_img, cmap=cmap, vmin=vmin, vmax=vmax)
        title = ax.set_title(f"{plane.capitalize()} Slice 1/{num_frames}", fontsize=14)

        def _update(frame_idx):
            """Draw frame `frame_idx` (0..num_frames-1), mapped to its voxel index."""
            actual_idx = indices[frame_idx]
            slice_img = get_slice(volume, actual_idx)
            vmin, vmax = np.percentile(slice_img, (1, 99))
            im.set_data(slice_img)
            im.set_clim(vmin, vmax)
            title.set_text(f"{plane.capitalize()} Slice {frame_idx+1}/{num_frames}")
            return (im, title)

        # 4) Build the FuncAnimation
        anim = FuncAnimation(
            fig,
            _update,
            frames=range(num_frames),
            interval=interval,
            blit=True
        )

        # 5) Depending on mode, either render inline or save to file
        if mode == 'interactive':
            # Embed a JavaScript player with all frames inline.
            display(HTML(anim.to_jshtml()))

        elif mode == 'gif':
            gif_writer = PillowWriter(fps=max(1, 1000 // interval))
            anim.save(gif_filename, writer=gif_writer)
            display(HTML(f'<img src="{gif_filename}" />'))

        else:  # mode == 'mp4'
            # FFMpegWriter usually produces a more compact file than GIF;
            # requires ffmpeg on the PATH.
            fps = max(1, 1000 // interval)
            mp4_writer = FFMpegWriter(fps=fps, codec='libx264')
            anim.save(mp4_filename, writer=mp4_writer)
            display(HTML(f'<video controls src="{mp4_filename}" width="512"></video>'))
    finally:
        # Close in every mode (and on errors). An open figure would otherwise be
        # rendered again as a static image by the inline backend at the end of
        # the cell, and its memory would be kept alive.
        plt.close(fig)



In [None]:
import nibabel as nib

# Load the raw (not nnU-Net-preprocessed) training CT of case 39 for visual
# inspection with visualize_volume.
ct_img_path = join(nnUNet_raw, "imagesTr", "AUTOMI_00039_0000.nii")
nii = nib.load(ct_img_path)
volume_data = nii.get_fdata()  # shape: (512, 512, 283)


In [None]:
# Example usage (disabled string literal): save a coronal GIF using every 2nd slice.
"""visualize_volume(
    volume=volume_data,
    mode='gif',
    plane='coronal',
    step=2, 
    gif_filename='fast.gif',
    interval=80     # faster playback (~12.5 FPS)
)"""
# Example usage (disabled string literal): save an axial MP4 of every slice
# (requires ffmpeg).
"""visualize_volume(
    volume=volume_data,
    mode='mp4',
    plane='axial',
    step=1, 
    mp4_filename='axial_full.mp4',
    interval=80
)"""


### We observe out-of-ROI regions (CT machinery, empty sections) ⇒ we need to tessellate the preprocessed volume; otherwise the network input and our attribution mapping won't match.

In [None]:
ct_img_path = "input_volume_39.nii.gz"  # alternative: join(nnUNet_raw, "imagesTr", "AUTOMI_00000_0000.nii")
nii = nib.load(ct_img_path)
# volume_data = nii.get_fdata()  # raw array, shape (512, 512, 283); not needed here

plans_manager = predictor.plans_manager
configuration_manager = predictor.configuration_manager
dataset_json = Path(model_dir_readonly) / "dataset.json"

preprocessor = configuration_manager.preprocessor_class(verbose=False)
rw = plans_manager.image_reader_writer_class()
# Defensive: instantiate once more if the attribute yielded a class.
if callable(rw) and not hasattr(rw, "read_images"):
    rw = rw()
img_np, img_props = rw.read_images([str(ct_img_path)])

preprocessed, _, _ = preprocessor.run_case_npy(
    img_np, seg=None, properties=img_props,
    plans_manager=plans_manager,
    configuration_manager=configuration_manager,
    dataset_json=dataset_json
)
# The preprocessed array is channel-first (C, ...), not batched, and this CT
# has a single channel. squeeze(0) drops exactly that axis; it raises if C != 1
# instead of also silently squeezing any other size-1 axis. The three spatial
# axes are then reversed.
# NOTE(review): the reversal presumably undoes the reader's (z, y, x) array
# order so the result matches nibabel's get_fdata() layout. Confirm for the
# reader used.
preprocessed = preprocessed.squeeze(0).transpose([2, 1, 0])

"""visualize_volume(
    volume=preprocessed,
    mode='gif',
    plane='coronal',
    step=2, 
    gif_filename='fast.gif',
    interval=80     # faster playback (~12.5 FPS)
)"""

## In this case, preprocessing only seems to apply a filter that hides some artefact noise at the beginning of the volume. In principle, nnU-Net v2 preprocessing can crop away empty background, but only for completely zero-valued voxels, or when a segmentation is passed in cascade mode (not our case). Either way, we apply the tessellation to the preprocessed volume to be safe.

In [None]:
# Inspect the image properties dict returned by the nnU-Net reader (passed to run_case_npy above).
print(img_props)

## Indeed, no cropping is applied

## Now create a supervoxel mask (several algorithms tested)

### 1. Cubic

In [None]:
import numpy as np
import nibabel as nib
import math

def create_supervoxel_mask(volume_shape, side_length):
    """
    Partition a 3D volume into cubic supervoxels of given side length, and
    return a mask array where each voxel is labeled by an integer supervoxel ID.

    Blocks are laid out on a regular grid starting at voxel (0, 0, 0). Blocks at
    the far end of an axis are truncated when the axis length is not a multiple
    of ``side_length``.

    Parameters
    ----------
    volume_shape : tuple of ints (D, H, W)
        Shape of the 3D volume in voxels (depth, height, width).
    side_length : int
        Desired side length (in voxels) of each cubic supervoxel. Must be a
        positive integer.

    Returns
    -------
    supervoxel_mask : np.ndarray, shape (D, H, W), dtype=np.int32
        Integer mask such that all voxels belonging to the same L×L×L block
        have the same label. Labels start at 1 and increase to the total number
        of supervoxels. Blocks are numbered in C order over the block grid
        (depth-block slowest, width-block fastest).

    Raises
    ------
    ValueError
        If ``side_length`` is not a positive integer. Previously a negative
        value silently produced an all-zero mask.
    """
    if side_length < 1 or side_length != int(side_length):
        raise ValueError(f"side_length must be a positive integer, got {side_length!r}")
    side_length = int(side_length)

    D, H, W = volume_shape
    # Number of blocks along H and W (ceiling, so remainders form a truncated block).
    n_blocks_h = math.ceil(H / side_length)
    n_blocks_w = math.ceil(W / side_length)

    # Block index of every voxel along each axis, shaped for broadcasting to (D, H, W).
    block_d = (np.arange(D) // side_length)[:, None, None]
    block_h = (np.arange(H) // side_length)[None, :, None]
    block_w = (np.arange(W) // side_length)[None, None, :]

    # Same numbering as a nested loop over (block_d, block_h, block_w) with a
    # counter starting at 1.
    supervoxel_mask = block_d * (n_blocks_h * n_blocks_w) + block_h * n_blocks_w + block_w + 1
    return supervoxel_mask.astype(np.int32)

# 1. Take the preprocessed CT volume
# NOTE: `preprocessed` was already axis-reversed when it was created, so reversing it
# again presumably restores the nnU-Net (D, H, W) order. .squeeze() only drops
# singleton axes, if any remain.
ct_data = preprocessed.squeeze().transpose([2,1,0])  # shape: (D, H, W)
print("Shape: ", ct_data.shape)

# 2. Decide on a supervoxel side length (in voxels)
side_length = 32

# 3. Create the supervoxel mask
#    The mask will have the same spatial shape as the CT volume.
#    (Disabled: uncomment to build the cubic tessellation.)
#supervoxel_mask = create_supervoxel_mask(ct_data.shape, side_length)

### 2. Hexagonal prism

In [None]:
import numpy as np, math
from scipy.spatial import cKDTree
import matplotlib.pyplot as plt

def hex_prism_mask(shape, r=24, z_thick=None, orientation='flat',
                   return_edges=False):
    """
    Tessellate a (D, H, W) volume into hexagonal prisms.

    Each pixel of the (H, W) plane is assigned to its nearest centre on a
    hexagonal lattice. The resulting Voronoi cells are regular hexagons of side
    ``r``. The 2-D pattern is repeated along D. With ``z_thick``, the D axis is
    split into slabs, and each slab gets its own label range.

    shape        : (D,H,W) of volume
    r            : hex side length (voxels), must be > 0
    z_thick      : slab height; None -> whole D (prisms)
    orientation  : 'flat' (flat-top)  or  'pointy' (pointy-top)
    return_edges : if True also returns a 2-D bool array with cell borders

    Returns
    -------
    mask3d : np.ndarray, shape (D, H, W), int32
        Labels start at 1. Each slab's labels are offset by the running sum of
        the maximum 2-D label of the previous slabs.
    edge : np.ndarray, shape (H, W), bool  (only when return_edges=True)
        True where a pixel differs from at least one 4-neighbour.

    Raises
    ------
    ValueError
        For an unknown orientation, r <= 0, or a slab height < 1.
    """
    D, H, W = shape
    z_thick  = D if z_thick is None else int(z_thick)
    if orientation not in ('flat', 'pointy'):
        raise ValueError(f"orientation must be 'flat' or 'pointy', got {orientation!r}")
    if r <= 0:
        raise ValueError("r must be positive")
    if z_thick < 1:
        raise ValueError("z_thick (and the volume depth when z_thick is None) must be >= 1")

    # Regular hexagons of side r have neighbouring centres sqrt(3)*r apart:
    #   flat-top  : columns 1.5*r apart, rows sqrt(3)*r apart, odd columns shifted by half a row;
    #   pointy-top: rows 1.5*r apart, columns sqrt(3)*r apart, odd rows shifted by half a column.
    if orientation == 'flat':
        dx, dy = 1.5 * r, math.sqrt(3) * r
        n_cols = int(W / dx) + 3; n_rows = int(H / dy) + 3
        centres = [(col * dx, row * dy + (dy / 2 if col & 1 else 0))
                   for col in range(-1, n_cols)
                   for row in range(-1, n_rows)]
    else:  # pointy-top
        dx, dy = math.sqrt(3) * r, 1.5 * r
        n_cols = int(W / dx) + 3; n_rows = int(H / dy) + 3
        centres = [(col * dx + (dx / 2 if row & 1 else 0), row * dy)
                   for row in range(-1, n_rows)
                   for col in range(-1, n_cols)]

    centres = np.array(centres, float)
    # Every in-image pixel lies within r (the circumradius) of its nearest
    # centre, so centres more than r outside the image can never win.
    m = r
    keep = ((centres[:, 0] >= -m) & (centres[:, 0] <= W - 1 + m) &
            (centres[:, 1] >= -m) & (centres[:, 1] <= H - 1 + m))
    centres = centres[keep]

    Y, X = np.mgrid[:H, :W]
    lab = cKDTree(centres).query(np.c_[X.ravel(), Y.ravel()], 1)[1]
    hex2d = lab.reshape(H, W).astype(np.int32) + 1      # labels start at 1

    mask3d = np.zeros((D, H, W), np.int32)
    off = 0
    for z0 in range(0, D, z_thick):
        mask3d[z0:z0 + z_thick] = hex2d + off
        off += hex2d.max()

    if not return_edges:
        return mask3d

    # simple edge map: a pixel differs from any 4-neighbour
    edge = np.zeros_like(hex2d, bool)
    edge[:-1]   |= hex2d[:-1]   != hex2d[1:]
    edge[1:]    |= hex2d[1:]    != hex2d[:-1]
    edge[:, :-1] |= hex2d[:, :-1] != hex2d[:, 1:]
    edge[:, 1:]  |= hex2d[:, 1:]  != hex2d[:, :-1]
    return mask3d, edge


### 3. SLIC algorithm

In [None]:
from skimage.segmentation import slic

volume = preprocessed.squeeze().transpose([2,1,0])  # shape: (D, H, W)
print("Shape volume:", volume.shape)  # e.g. (D, H, W)

# SLIC parameters: n_segments sets the approximate number of supervoxels wanted.
# channel_axis=None -> single-channel 3-D input; start_label=0 -> labels start at 0.
# compactness trades intensity similarity against spatial proximity, so its effect
# depends on the intensity range of `volume`.
segments = slic(volume, n_segments=100, compactness=1, start_label=0, max_num_iter=100, channel_axis=None)

print("Mappa supervoxel shape:", segments.shape)
n_supervoxels = len(np.unique(segments))
print("Numero di supervoxels:", n_supervoxels)
# Per-supervoxel voxel counts (disabled). To count voxels use np.sum(segments == sup);
# np.sum(np.where(...)) sums coordinate indices instead of counting.
#for sup in range(n_supervoxels):
#   print("sup_",sup, " number of voxels: ", np.sum(np.where(segments == sup)))

In [None]:
# Optional: Save the result
# Save the SLIC label map as NIfTI so it can be overlaid on the source CT.

ct_img_path = "input_volume_39.nii.gz"
nii = nib.load(ct_img_path)
affine = nii.affine

# `segments` follows the axis order of the array SLIC ran on (the doubly reversed
# `preprocessed`, presumably (D, H, W)); reverse it to nibabel's (x, y, z) voxel order.
segments_xyz = np.transpose(segments, (2, 1, 0)).astype(np.int32)
if segments_xyz.shape == nii.shape[:3]:
    # Same voxel grid as the CT: reuse its affine so viewers align the two images.
    output_img = nib.Nifti1Image(segments_xyz, affine=affine)
else:
    # Preprocessing changed the grid (e.g. resampling), so the CT affine does not
    # apply; save in array order with an identity affine instead.
    print(f"Warning: SLIC map shape {segments_xyz.shape} != CT shape {nii.shape[:3]}; "
          "saving with identity affine.")
    output_img = nib.Nifti1Image(segments.astype(np.int32), affine=np.eye(4))
nib.save(output_img, 'slic_map.nii.gz')

In [None]:
"""visualize_volume(
    volume=segments,
    mode='gif',
    plane='sagittal',
    step=2, 
    cmap="tab20",
    gif_filename='fast.gif',
    interval=80     # faster playback (~12.5 FPS)
)"""

## Install and configure 3D-SEEDS 
https://github.com/Zch0414/3d-seeds/blob/master/README.md

In [None]:
"""# resetta l'ambiente di lavoro
%cd /kaggle/working
!rm -rf 3d-seeds-compile

# aggiorna repository di sistema e installa gli header di OpenCV
!apt-get update -qq
!apt-get install -y -q libopencv-dev libopencv-contrib-dev ninja-build
"""

In [None]:
"""!git clone https://github.com/Zch0414/3d-seeds.git 3d-seeds-compile   # clona in /kaggle/working/3d-seeds-compile
%cd 3d-seeds-compile
"""

In [None]:
"""import pathlib, re
setup_py = pathlib.Path("setup.py").read_text()

# puntiamo agli header/librerie appena installati da apt-get
setup_py = re.sub(r'OPENCV_INCLUDE_DIRS\s*=.*',
                  'OPENCV_INCLUDE_DIRS = "/usr/include/opencv4"', setup_py)
setup_py = re.sub(r'OPENCV_LIBRARY_DIRS\s*=.*',
                  'OPENCV_LIBRARY_DIRS = "/usr/lib/x86_64-linux-gnu"', setup_py)

pathlib.Path("setup.py").write_text(setup_py)
"""

In [None]:
"""!python setup.py bdist_wheel -q
!pip install dist/*.whl
"""

In [None]:
"""%cd /kaggle/working/3d-seeds-compile   # ← root della repo

# 1) esporta il path degli header in CPLUS_INCLUDE_PATH (vale per g++)
import os, sys, subprocess, shlex
os.environ["CPLUS_INCLUDE_PATH"] = "/usr/include/opencv4"
os.environ["CPATH"]              = "/usr/include/opencv4"    # (copertura per clang)

# 2) idem per i .so (non strettamente necessario, ma sicuro)
os.environ["LIBRARY_PATH"] = "/usr/lib/x86_64-linux-gnu"

# 3) build + install
!python setup.py bdist_wheel   # stavolta senza -q per vedere il comando completo
!pip install dist/*.whl
"""

In [None]:
"""%cd /kaggle/working"""

### Quick test

In [None]:
"""import numpy as np, python_3d_seeds as sv, time

vol = np.random.rand(64, 64, 64).astype(np.float32)
D, H, W = vol.shape

# factory: W, H, D, channels, n_segments, hist_bins, λ, block_iter, pixel_iter
seeds = sv.createSupervoxelSEEDS(W, H, D, 1,
                                 200,    # n_segments
                                 15, 2,  # hist_bins, lambda_boundary
                                 2, 4)   # block_iters, pixel_iters

t0 = time.time()
seeds.iterate(vol, num_iterations=4)      # ← senza spacing
labels = seeds.getLabels()

print("supervoxels:", np.unique(labels).size,
      "| tempo:", round(time.time() - t0, 3), "s")
"""

In [None]:
"""import numpy as np, python_3d_seeds as sv, nibabel as nib
from scipy.ndimage import zoom

# ---------- carica volume ----------
ct_img_path = "input_volume_39.nii.gz"#join(nnUNet_raw, "imagesTr", "AUTOMI_00039_0000.nii")
nii = nib.load(ct_img_path)
vol = preprocessed.astype(np.float32)          # già normalizzato (-1..1 o 0..1)
print("Shape iniziale volume preprocessed: ", vol.shape)

# (facoltativo) resampling “morbido” se dz >> dy,dx
dz,dy,dx = nii.header.get_zooms()
if dz > 3*dy:                                  # esempio: slice spesse 5 mm
    factor = (dy/dz, 1.0, 1.0)                 # porta dz≈dy
    vol = zoom(vol, factor, order=1)

# ---------- cast a uint8 ----------
vmin, vmax = vol.min(), vol.max()
vol8 = np.round(255*(vol - vmin)/(vmax - vmin)).astype(np.uint8, copy=False)
vol8 = np.ascontiguousarray(vol8)              # C-order garantito

# ---------- SEEDS ----------
D,H,W = vol8.shape
seeds = sv.createSupervoxelSEEDS(
        W, H, D, 1,          # width, height, depth, channels
        200,                 # num_superpixels
        2,                   # num_levels
        2,                   # prior (λ bordo)
        15,                  # histogram_bins
        True)                # double_step

seeds.iterate(vol8, num_iterations=4)
labels = seeds.getLabels()
print("Supervoxels:", np.unique(labels).size)
print("type: ", type(labels))
print("shape: ", labels.shape)"""

In [None]:
"""output_img = nib.Nifti1Image(labels, affine=nii.affine)
nib.save(output_img, "3d-seeds_supervoxel_map.nii.gz")"""

### 4. Face-centered cubic (FCC) lattice induced supervoxel assignment
-> more *isotropic* than simple cubes

In [None]:
import numpy as np
import nibabel as nib
from scipy.spatial import cKDTree

def generate_fcc_centers(W, H, D, S):
    """
    Generate supervoxel centers on a face-centred cubic (FCC) lattice within the volume.

    The FCC lattice is built as the subset of a cubic grid with step ``a/2``
    whose integer indices (i, j, k) have an even sum, where ``a = S*sqrt(2)`` is
    the conventional FCC cube edge. Every center then has its nearest
    neighbours at distance exactly ``S``. This is the same construction used by
    the physical-space FCC tessellation, here in voxel units.

    Parameters:
    - W, H, D: Volume dimensions (width, height, depth), in voxels.
    - S: Spacing parameter (nearest-neighbour distance between centers, in voxels). Must be > 0.

    Returns:
    - centers: float array of shape (N, 3) with [x, y, z] coordinates satisfying
      0 <= x < W, 0 <= y < H, 0 <= z < D. N may be 0 for an empty volume.

    Raises:
    - ValueError: if S is not positive.
    """
    if S <= 0:
        raise ValueError(f"S must be positive, got {S!r}")

    a = S * np.sqrt(2)   # FCC conventional cube edge; nearest-neighbour distance is a*sqrt(2)/2 = S
    half = a / 2         # step of the underlying cubic grid

    # One extra index per axis; out-of-range points are removed by the explicit bounds check below.
    ii, jj, kk = np.meshgrid(np.arange(int(np.ceil(W / half)) + 1),
                             np.arange(int(np.ceil(H / half)) + 1),
                             np.arange(int(np.ceil(D / half)) + 1),
                             indexing='ij')
    ii, jj, kk = ii.ravel(), jj.ravel(), kk.ravel()

    # FCC = grid points whose index sum is even (corners + face centres of each cube).
    fcc = (ii + jj + kk) % 2 == 0
    centers = np.column_stack((ii[fcc], jj[fcc], kk[fcc])) * half

    inside = (centers[:, 0] < W) & (centers[:, 1] < H) & (centers[:, 2] < D)
    return centers[inside]

def assign_voxels_to_centers(volume_shape, centers):
    """
    Label each voxel of a (W, H, D) grid with the index of its closest center.

    Parameters:
    - volume_shape: Tuple (W, H, D) of volume dimensions.
    - centers: Array of supervoxel center coordinates, in the same voxel units.

    Returns:
    - labels: int32 array of shape (W, H, D); entry [x, y, z] is the row index in
      `centers` of the nearest center (Euclidean distance).
    """
    W, H, D = volume_shape
    # All voxel coordinates as an (W*H*D, 3) array in C order (x slowest, z fastest).
    coords = np.indices((W, H, D)).reshape(3, -1).T

    # KD-tree nearest-neighbour lookup; index [1] of the result holds the center ids.
    nearest_ids = cKDTree(centers).query(coords)[1]

    return nearest_ids.reshape(W, H, D).astype(np.int32)

def create_supervoxel_map(nifti_file, S=10):
    """
    Produce a supervoxel map from a NIfTI CT scan volume.

    Centers are generated on an FCC lattice in voxel units (generate_fcc_centers).
    Every voxel is then labelled with its nearest center (assign_voxels_to_centers).

    Parameters:
    - nifti_file: Path to the NIfTI file.
    - S: Spacing parameter for supervoxel size, in voxels (default=10).

    Returns:
    - supervoxel_map: (W, H, D) array with integer labels for each supervoxel (int32).
    - affine: Affine matrix from the input NIfTI file.

    Raises:
    - ValueError: if no centers are generated.
    """
    img = nib.load(nifti_file)
    affine = img.affine  # Get affine matrix
    # Only the grid size is needed. img.shape comes from the header, so the voxel
    # data are never converted (get_fdata() would build a full float64 copy).
    W, H, D = img.shape

    # Generate supervoxel centers
    centers = generate_fcc_centers(W, H, D, S)
    if len(centers) == 0:
        raise ValueError("No centers generated. Reduce S or check volume dimensions.")

    # Assign voxels to supervoxels
    supervoxel_map = assign_voxels_to_centers((W, H, D), centers)

    return supervoxel_map, affine


In [None]:
"""nifti_file = "input_volume_39.nii.gz"
supervoxel_map, affine = create_supervoxel_map(nifti_file, S=30)
print(f"Data type: {supervoxel_map.dtype}")  # Should print int32
print(f"Number of supervoxels: {len(np.unique(supervoxel_map))}")
output_img = nib.Nifti1Image(supervoxel_map, affine=affine)
nib.save(output_img, "supervoxel_map.nii.gz")"""

### 4.1 Affine transformation to carry isotropy from voxel space into the original geometric space (measured in mm)

In [None]:
"""import numpy as np
import nibabel as nib
from scipy.spatial import cKDTree

# Load the NIfTI image
img = nib.load('input_volume_39.nii.gz')
volume = img.get_fdata()
affine = img.affine
W, H, D = volume.shape

# Compute correct voxel spacings from affine matrix
voxel_spacings = np.array([np.linalg.norm(affine[:3, 0]),
                           np.linalg.norm(affine[:3, 1]),
                           np.linalg.norm(affine[:3, 2])])

# Set desired supervoxel size in mm
S = 300

# Compute grid steps in voxel space
step_x = max(1, int(np.round(S / voxel_spacings[0])))
step_y = max(1, int(np.round(S / voxel_spacings[1])))
step_z = max(1, int(np.round(S / voxel_spacings[2])))

# Generate centers in voxel space
centers_voxel = []
for i in range(0, W, step_x):
    for j in range(0, H, step_y):
        for k in range(0, D, step_z):
            centers_voxel.append([i, j, k])
centers_voxel = np.array(centers_voxel)

# Transform to physical space
centers_physical = np.dot(affine, np.hstack((centers_voxel, np.ones((len(centers_voxel), 1)))).T).T[:, :3]

# Verify centers were generated
if len(centers_physical) == 0:
    raise ValueError("No supervoxel centers generated. Check volume dimensions and affine matrix.")

# Assign voxels to centers using KD-tree
def assign_voxels_to_centers(volume_shape, centers_physical, affine):
    W, H, D = volume_shape
    # Generate all voxel coordinates
    grid_x, grid_y, grid_z = np.mgrid[0:W, 0:H, 0:D]
    voxels_voxel = np.vstack([grid_x.ravel(), grid_y.ravel(), grid_z.ravel(), np.ones(grid_x.size)]).T
    # Transform to physical space
    voxels_physical = np.dot(affine, voxels_voxel.T).T[:, :3]
    # Find nearest center for each voxel
    tree = cKDTree(centers_physical)
    _, labels = tree.query(voxels_physical)
    return labels.reshape(W, H, D).astype(np.int32)

# Generate supervoxel map
supervoxel_map = assign_voxels_to_centers((W, H, D), centers_physical, affine)

# Optional: Save the result
output_img = nib.Nifti1Image(supervoxel_map, affine)
nib.save(output_img, 'supervoxel_map.nii.gz')"""

In [None]:
# Debug helper (disabled): inspect the affine of the last loaded image.
#print(affine)

### 4.2 Apply the original FCC algorithm in physical space: map the volume through its affine into mm, derive the supervoxel map there, then bring the labels back onto the voxel grid



In [None]:
import numpy as np
import nibabel as nib
from scipy.spatial import cKDTree

def generate_supervoxel_map(img, S=200.0):
    """
    Generate a supervoxel map using FCC tessellation in physical space.

    Centres are placed on a face-centred cubic (FCC) lattice in world
    coordinates, with nearest-neighbour spacing ``S``. Every voxel is labelled
    with the index of the closest centre, measured in world coordinates (via
    ``img.affine``). Supervoxels are therefore roughly isotropic in world units
    even for anisotropic voxel spacing.

    Args:
        img: NIfTI image (nibabel); only ``img.shape`` and ``img.affine`` are used.
        S (float): Desired supervoxel size in millimeters (default: 200.0)

    Returns:
        supervoxel_map: int32 array of shape (W, H, D), in the voxel order of
        ``img``. Labels 0..N-1 index the FCC centres inside the volume's
        physical bounding box.

    Raises:
        ValueError: if no lattice centre falls inside the bounding box.
    """
    affine = np.asarray(img.affine, dtype=float)
    # Only the grid size is needed; img.shape avoids loading the voxel data.
    W, H, D = img.shape
    lin, offset = affine[:3, :3], affine[:3, 3]

    # Compute the physical bounding box of the volume (images of the 8 corner voxels)
    corners_voxel = np.array([
        [0, 0, 0],
        [W-1, 0, 0],
        [0, H-1, 0],
        [0, 0, D-1],
        [W-1, H-1, 0],
        [W-1, 0, D-1],
        [0, H-1, D-1],
        [W-1, H-1, D-1]
    ], dtype=float)
    corners_physical = corners_voxel @ lin.T + offset
    min_xyz = corners_physical.min(axis=0)
    max_xyz = corners_physical.max(axis=0)

    # Generate FCC lattice centers in physical space: points (P, Q, R) * a/2 with
    # P + Q + R even, where a = S*sqrt(2) gives nearest-neighbour distance S.
    a = S * np.sqrt(2)
    factor = 2 / a
    p_min = int(np.floor(factor * min_xyz[0])) - 1
    p_max = int(np.ceil(factor * max_xyz[0])) + 1
    q_min = int(np.floor(factor * min_xyz[1])) - 1
    q_max = int(np.ceil(factor * max_xyz[1])) + 1
    r_min = int(np.floor(factor * min_xyz[2])) - 1
    r_max = int(np.ceil(factor * max_xyz[2])) + 1

    P, Q, R = np.meshgrid(np.arange(p_min, p_max + 1),
                          np.arange(q_min, q_max + 1),
                          np.arange(r_min, r_max + 1), indexing='ij')
    P, Q, R = P.ravel(), Q.ravel(), R.ravel()

    # Filter for FCC lattice points (sum of indices is even)
    fcc = (P + Q + R) % 2 == 0
    centers = np.column_stack((P[fcc] * a / 2, Q[fcc] * a / 2, R[fcc] * a / 2))

    # Keep only centers within the bounding box
    inside = ((centers[:, 0] >= min_xyz[0]) & (centers[:, 0] <= max_xyz[0]) &
              (centers[:, 1] >= min_xyz[1]) & (centers[:, 1] <= max_xyz[1]) &
              (centers[:, 2] >= min_xyz[2]) & (centers[:, 2] <= max_xyz[2]))
    centers = centers[inside]

    print(f"Number of supervoxel centers: {len(centers)}")
    if len(centers) == 0:
        raise ValueError("No supervoxel centers generated. Try reducing S.")

    # Assign each voxel to the nearest centre. Physical coordinates are affine in
    # the voxel index: p = lin @ (i, j, k) + offset. The (j, k) part is precomputed
    # once and one i-slab is processed at a time. Memory stays O(H*D) instead of
    # several (W*H*D, 4) float64 buffers.
    tree = cKDTree(centers)
    supervoxel_map = np.empty((W, H, D), dtype=np.int32)
    jk = np.indices((H, D)).reshape(2, -1).T
    slab_base = jk @ lin[:, 1:].T + offset
    for i in range(W):
        _, labels = tree.query(slab_base + i * lin[:, 0])
        supervoxel_map[i] = labels.reshape(H, D)

    return supervoxel_map

In [None]:
"""import numpy as np
import nibabel as nib
import time

# Load the image
img = nib.load('input_volume_39.nii.gz')

# Generate and save supervoxel map using the original method with timing
start_time = time.time()
supervoxel_map = generate_supervoxel_map(img, S=100.0)
end_time = time.time()
original_time = end_time - start_time
print(f"Execution time: {original_time:.2f} seconds")

output_img = nib.Nifti1Image(supervoxel_map, img.affine)
nib.save(output_img, 'supervoxel_map.nii.gz')"""

## Apply modified KernelShap to this supervoxel tessellation

In [None]:
# 2) Initialise predictor ------------------------
predictor = CustomNNUNetPredictor(
    tile_step_size=0.5,
    use_gaussian=True,
    use_mirroring=False, # == test time augmentation
    perform_everything_on_device=True,
    # Use the notebook-wide `device` (the one input tensors are moved to with
    # .to(device)) instead of hard-coding cuda:0, so model and inputs agree.
    device=device,
    verbose=False,
    verbose_preprocessing=False,
    allow_tqdm=False #it interfere with SHAP loading bar
)
# initializes the network architecture, loads the checkpoint
predictor.initialize_from_trained_model_folder(
    model_dir_readonly,
    use_folds=(0,),
    checkpoint_name='checkpoint_final.pth',
)

In [None]:
# (a) load + preprocess the volume  (1, C, D, H, W) – nnU-Net order
nii_path = "input_volume_39.nii.gz"
dataset_json_path = Path(model_dir_readonly) / "dataset.json"

# NOTE(review): assumes nnunetv2_default_preprocessing returns a (C, D, H, W) numpy array — confirm.
volume_np = nnunetv2_default_preprocessing(nii_path, predictor, dataset_json_path)

# Add a leading batch axis and move to the notebook device.
volume = torch.from_numpy(volume_np).unsqueeze(0).to(device)        # torch (1,C,D,H,W)

print("Volume shape:", volume.shape)                # (1, C, D, H, W)

In [None]:
"""# (b) super-voxel / organ-id map  (W, H, D)
# Load the image
img = nib.load('input_volume_39.nii.gz')

# Generate and save supervoxel map using the 
supervoxel_map = generate_supervoxel_map(img, S=100.0)
                    # (W, H, D)"""

In [None]:
"""supervoxel_map = np.transpose(supervoxel_map, (2, 1, 0))                 # match (D,H,W)
# we need features of feature mask ordered from 0 (or 1) to M-1 (M)
sv_values, indexes = np.unique(supervoxel_map, return_inverse=True)

supervoxel_map = indexes.reshape(supervoxel_map.shape)
print(np.unique(supervoxel_map))

# IMPORTANT 🔸: KernelShapWithMask expects **(X, Y, Z)** without channel axis
supervoxel_map = torch.from_numpy(supervoxel_map).long().to(device)   # (D,H,W)

print("Mask shape:", supervoxel_map.shape)  """

In [None]:
"""print(supervoxel_map.shape)
print(np.prod(supervoxel_map.shape))
print(torch.sum(supervoxel_map != 0))
print(torch.sum(supervoxel_map == 0))"""

In [None]:
# ------------------------------------------------------------
# ❷  Forward wrapper that nnU-Net expects
# ------------------------------------------------------------

@torch.inference_mode()
def forward_segmentation_output_to_explain(
        input_image:         torch.Tensor,
        perturbation_mask:   torch.BoolTensor,
        baseline_prediction_dict: dict
) -> torch.Tensor:           # returns a scalar per sample
    """
    Scalar SHAP target: sum of class-1 logits over the voxels predicted as class 1,
    divided by the total voxel count.

    Parameters
    ----------
    input_image : torch.Tensor
        (Perturbed) network input forwarded to the cached sliding-window predictor.
    perturbation_mask : torch.BoolTensor
        Mask of perturbed voxels, forwarded to the predictor (presumably used to
        decide which cached patches must be recomputed — confirm in the predictor).
    baseline_prediction_dict : dict
        Cache of unperturbed patch predictions, forwarded to the predictor.

    Returns
    -------
    torch.Tensor
        0-dim float64 tensor.
    """
    logits = predictor.predict_sliding_window_return_logits_with_caching(
        input_image, perturbation_mask, baseline_prediction_dict,
    )                              # (C, D, H, W)
    seg_mask = torch.argmax(logits, dim=0)          # (D,H,W)
    # Explicit boolean mask. Multiplying by the raw argmax labels would weight
    # voxels of class k by k whenever more than two classes are present.
    # Identical result for binary segmentation.
    foreground = seg_mask == 1
    # Divide by the voxel count (D*H*W) to keep the value small (avoids overflows in SHAP).
    aggregate = torch.sum(logits[1].double() * foreground) / seg_mask.numel()

    return aggregate

# c) wrap your cached‐forward method:
# NOTE(review): the lambda resolves `forward_segmentation_output_to_explain` and
# `baseline_pred_cache` when it is called, not when it is created. A later cell
# redefines that function with an extra required `ROI_mask` argument, so calling
# this explainer after that cell has run would raise TypeError.
explainer = KernelShapWithMask(
    forward_func=lambda vol, mask: forward_segmentation_output_to_explain(
        input_image=vol,
        perturbation_mask=mask,
        baseline_prediction_dict=baseline_pred_cache)
)

# Disabled: compute the SHAP attributions on the full volume.
"""# d) compute SHAP
attr = explainer.attribute(
    inputs=volume,       # (1,C,D,H,W)
    baselines=0.0, 
    feature_mask=supervoxel_map,
    n_samples=1000,    
    return_input_shape=True,
    show_progress=True,
)
print("Attributions:", attr.shape)  # → (1,C,D,H,W)"""


In [None]:
"""for x,y,_ in explainer.dataset:
    print(x)
    break
    
formatted_inputs, baselines = _format_input_baseline(inputs=volume, baselines=0.0)
feature_mask, num_interp_features = construct_feature_mask(
            feature_mask=supervoxel_map, formatted_inputs=formatted_inputs
        )
print(num_interp_features)"""

### Move the attributions back to physical space and save them as NIfTI

In [None]:
"""affine = nib.load(nii_path).affine
attr_postprocessed = attr[0][0].detach().cpu().numpy().transpose(2,1,0) # (W, H, D)
attr_img = nib.Nifti1Image(attr_postprocessed, affine)
nib.save(attr_img, 'attribution_map-FCC.nii.gz')"""

Derive the attribution magnitude (absolute value), ignoring the sign

In [None]:
"""attr_postprocessed_module = np.abs(attr_postprocessed)
attr_img = nib.Nifti1Image(attr_postprocessed_module, affine)
nib.save(attr_img, 'attribution_map-FCC_module.nii.gz')"""

# Define an ROI in which to explain the segmentation
This may provide a more useful attribution map, highlighting nearby organs.

In [None]:
# get the manually derived ROI mask from the dataset, where we manually added it
# NOTE(review): hard-coded Kaggle input path; on Colab/DEIB the file must exist at the
# same location or the path must be adapted (e.g. built from `mount_dir`).
ROI_mask_path = "/kaggle/input/automi-seg/segmentation-masked-ROI.nii"
ROI_mask = nib.load(ROI_mask_path)

# Shape and affine of the ROI mask, for comparison with the CT.
print(ROI_mask.get_fdata().shape)
print(ROI_mask.affine)

### The ROI mask is a binary mask highlighting the lymph nodes of interest. We need a bounding box to crop the volume accordingly

In [None]:
import nibabel as nib
import numpy as np

def get_mask_bbox_slices(mask_nii_path):
    """
    Tightest 3D bounding box (as slices) around the positive voxels of a NIfTI mask.

    Parameters
    ----------
    mask_nii_path : str or Path
        Path to the input binary ROI mask NIfTI (.nii or .nii.gz).

    Returns
    -------
    bbox_slices : tuple of slice
        (x_slice, y_slice, z_slice), in the mask's voxel-axis order, with
        exclusive stops, so ``data[bbox_slices]`` contains every voxel > 0.

    Raises
    ------
    ValueError
        If the image is not 3D, or if it has no voxel > 0.
    """
    mask = nib.load(str(mask_nii_path)).get_fdata()
    if mask.ndim != 3:
        raise ValueError("Input NIfTI must be a 3D volume")

    positive = mask > 0
    if not positive.any():
        raise ValueError("No positive voxels found in mask")

    bbox = []
    for axis in range(3):
        # Project out the two other axes: which indices along `axis` contain a hit?
        other_axes = tuple(ax for ax in range(3) if ax != axis)
        hit_indices = np.flatnonzero(positive.any(axis=other_axes))
        bbox.append(slice(int(hit_indices[0]), int(hit_indices[-1]) + 1))

    return tuple(bbox)

In [None]:
# Bounding box of the ROI, in NIfTI voxel order (x, y, z).
x_slice, y_slice, z_slice = get_mask_bbox_slices(ROI_mask_path)
print(
    "Bounding box slices:",
    f"  x: {x_slice}",
    f"  y: {y_slice}",
    f"  z: {z_slice}",
    sep="\n",
)

### To correctly **crop** the volume around the *ROI*, we need to derive the **receptive field** of the sliding-window inference, which depends on the *patch size*.

In [None]:
# Patch size in model axis order (presumably (D, H, W) = (z, y, x) for a 3D configuration).
patch_size = np.array(predictor.configuration_manager.patch_size)
print("Patch size: ", patch_size)

# Receptive field is twice the patch size-1
# i.e. RF = 2 * (patch_size - 1). Assumption used here: with sliding windows, a
# voxel's output can depend on any window containing it, so context reaches up to
# patch_size - 1 voxels on each side.
RF = 2*(patch_size-1)
print("Receptive field: ", RF)

### Consider the receptive field to compute the final slices for cropping
Remember that model metadata are related to transposed volume (nnunetv2 takes (D, H, W) shape)

In [None]:
# CT shape in NIfTI voxel order (W, H, D) = (x, y, z); RF is in model order (z, y, x).
volume_path = "input_volume_39.nii.gz"
volume_shape = nib.load(volume_path).shape
print("Original volume shape:", volume_shape)

W, H, D = volume_shape
RF_z, RF_y, RF_x = RF[0], RF[1], RF[2]


def _grow_slice(bbox_slice, receptive_field, axis_len):
    """Enlarge `bbox_slice` by half the receptive field on each side, clipped to [0, axis_len]."""
    half_rf = receptive_field / 2
    lower = max(bbox_slice.start - half_rf, 0)
    upper = min(bbox_slice.stop + half_rf, axis_len)
    return slice(int(lower), int(upper))


x_slice_RF = _grow_slice(x_slice, RF_x, W)
y_slice_RF = _grow_slice(y_slice, RF_y, H)
z_slice_RF = _grow_slice(z_slice, RF_z, D)

print(
    f"new  x: {x_slice_RF}",
    f"new  y: {y_slice_RF}",
    f"new  z: {z_slice_RF}",
    sep="\n",
)

In [None]:
import nibabel as nib
import numpy as np

def crop_volume_and_affine(nii_path, bbox_slices, save_cropped_nii_path=None):
    """
    Crop a 3D NIfTI volume using the given bounding-box slices and
    recompute the affine so the cropped volume retains correct world coordinates.

    The crop goes through nibabel's ``img.slicer``, which:
    - slices the image's array proxy, instead of first converting the whole
      volume to float64 as get_fdata() would;
    - shifts the affine origin to the first cropped voxel, i.e.
      new_affine[:3, 3] = affine[:3, :3] @ (x0, y0, z0) + affine[:3, 3];
    - builds the cropped image from the source header, so the saved file keeps
      the source header fields (e.g. on-disk data type).

    Parameters
    ----------
    nii_path : str or Path
        Path to the input NIfTI volume (.nii or .nii.gz).
    bbox_slices : tuple of slice
        A 3-tuple (x_slice, y_slice, z_slice) as returned by get_mask_bbox_slices().
    save_cropped_nii_path : str or Path, optional
        If provided, the cropped volume will be saved here as a new NIfTI.

    Returns
    -------
    cropped_data : np.ndarray
        The volume data cropped to the bounding box (float64, as from get_fdata()).
    new_affine : np.ndarray
        The updated 4×4 affine transform for the cropped volume.
    """
    img = nib.load(str(nii_path))

    # Crop in voxel (i, j, k) = (x, y, z) order; the slicer recomputes the affine.
    cropped_img = img.slicer[tuple(bbox_slices)]
    cropped_data = cropped_img.get_fdata()
    new_affine = cropped_img.affine.copy()

    # Optionally save the cropped volume
    if save_cropped_nii_path is not None:
        nib.save(cropped_img, str(save_cropped_nii_path))

    return cropped_data, new_affine


In [None]:
from pathlib import Path

# ROI bounding box enlarged by half the receptive field on each side (NIfTI x, y, z order).
slices = (x_slice_RF, y_slice_RF, z_slice_RF)

nii_path = "input_volume_39.nii.gz"

# Crop the CT and save it; the saved file is the input to the SHAP run below.
cropped_volume, affine_cropped_volume = crop_volume_and_affine(
    nii_path=nii_path,
    bbox_slices=slices,
    save_cropped_nii_path=Path("cropped_volume_with_RF.nii.gz")
)

print("Cropped data shape:", cropped_volume.shape)
print("New affine:\n", affine_cropped_volume)

### We need a way to check the overlap of the original mask with the new cropped volume

In [None]:
# Crop the ROI mask with the same slices, so it stays voxel-aligned with the cropped CT
# (assuming the mask and the CT share the same voxel grid — TODO confirm).
cropped_mask, affine_cropped_mask = crop_volume_and_affine(
    nii_path=ROI_mask_path,
    bbox_slices=slices,
    save_cropped_nii_path=Path("cropped_mask_with_RF.nii.gz")
)

print("Cropped mask shape:", cropped_mask.shape)
print("New mask affine:\n", affine_cropped_mask)

## We execute SHAP on this cropped image, and we only consider our ROI

In [None]:
# 2) Initialise predictor ------------------------
predictor = CustomNNUNetPredictor(
    tile_step_size=0.5,
    use_gaussian=True,
    use_mirroring=False, # == test time augmentation
    perform_everything_on_device=True,
    # Use the notebook-wide `device` (the one input tensors are moved to with
    # .to(device)) instead of hard-coding cuda:0, so model and inputs agree.
    device=device,
    verbose=False,
    verbose_preprocessing=False,
    allow_tqdm=False #it interfere with SHAP loading bar
)
# initializes the network architecture, loads the checkpoint
predictor.initialize_from_trained_model_folder(
    model_dir_readonly,
    use_folds=(0,),
    checkpoint_name='checkpoint_final.pth',
)

In [None]:
# (a) load + cropped volume  (1, C, D, H, W) – nnU-Net order
nii_path_cropped = "cropped_volume_with_RF.nii.gz"
dataset_json_path = Path(model_dir_readonly) / "dataset.json"

# NOTE(review): assumes nnunetv2_default_preprocessing returns a (C, D, H, W) numpy array — confirm.
volume_np = nnunetv2_default_preprocessing(nii_path_cropped, predictor, dataset_json_path)

# Add a leading batch axis and move to the notebook device.
volume = torch.from_numpy(volume_np).unsqueeze(0).to(device)        # torch (1,C,D,H,W)

print("Volume shape:", volume.shape)                # (1, C, D, H, W)

In [None]:
# (b) super-voxel / organ-id map  (W, H, D)
# Load the image
img = nib.load(nii_path_cropped)

# Generate the FCC supervoxel map (S = 100 mm) on the raw voxel grid of the cropped CT.
# NOTE(review): the network input `volume` is the *preprocessed* array; the map only
# lines up with it if preprocessing did not resample — confirm the shapes match.
supervoxel_map = generate_supervoxel_map(img, S=100.0)
                    # -> (W, H, D), NIfTI voxel order (transposed to (D, H, W) further below)

In [None]:
# Save the FCC map for inspection. affine_cropped_mask comes from cropping the ROI mask
# with the same slices as the CT (presumably equal to affine_cropped_volume).
supervoxel_map_path = 'FCC-supervoxel_map.nii.gz'
supervoxel_map_img = nib.Nifti1Image(supervoxel_map, affine_cropped_mask)
nib.save(supervoxel_map_img, supervoxel_map_path)

In [None]:
supervoxel_map = np.transpose(supervoxel_map, (2, 1, 0))                 # match (D,H,W)
# we need features of feature mask ordered from 0 (or 1) to M-1 (M)
# np.unique(..., return_inverse=True) relabels the ids to the dense range 0..M-1;
# the reshape restores the 3-D shape regardless of the inverse's returned shape.
sv_values, indexes = np.unique(supervoxel_map, return_inverse=True)

supervoxel_map = indexes.reshape(supervoxel_map.shape)
print(np.unique(supervoxel_map))

# IMPORTANT 🔸: the feature mask is passed without a channel axis, in the same
# spatial (D, H, W) order as `volume` — i.e. NOT NIfTI (X, Y, Z) order.
supervoxel_map = torch.from_numpy(supervoxel_map).long().to(device)   # (D,H,W)

print("Mask shape:", supervoxel_map.shape)

### Derive the cached baseline dictionary
using the planner, iterator and predictor (temporary solution)

In [None]:
import pickle as pkl

# Cached baseline (unperturbed) sliding-window predictions for the CROPPED volume.
# The generation code kept below writes "cropped_baseline_output_dictionary_cache.pkl",
# so read from that path. Fall back to the older file name (previously hard-coded
# here) only if the cropped cache is missing, and warn, because that file may have
# been computed on a different (uncropped) volume.
# NOTE: pickle.load can execute arbitrary code — only load caches you produced yourself.
BASELINE_CACHE_PATH = "cropped_baseline_output_dictionary_cache.pkl"
LEGACY_BASELINE_CACHE_PATH = "baseline_output_dictionary_cache.pkl"

if os.path.exists(BASELINE_CACHE_PATH):
    baseline_cache_path = BASELINE_CACHE_PATH
else:
    baseline_cache_path = LEGACY_BASELINE_CACHE_PATH
    print(f"WARNING: {BASELINE_CACHE_PATH} not found, loading {LEGACY_BASELINE_CACHE_PATH} "
          "instead - make sure it was computed on the cropped volume.")

with open(baseline_cache_path, "rb") as f:
    cropped_baseline_pred_cache = pkl.load(f)

# To regenerate the cache, run the (temporary) code below as a cell.
# NOTE(review): it references `acc.completed_ids`, which must be defined first.
"""# temporarily use a new planner and iterator

plan_tmp = OrganMaskPerturbationPlanner(
    volume_file=nii_path_cropped,
    organ_mask_file=supervoxel_map_path,
)

# code to get our baseline prediction patch dictionary
iterator_tmp = SHAPPredictionIterator(plan_tmp, predictor,
                                               skip_completed_ids=acc.completed_ids,
                                               cache_sw_inference=True,
                                               verbose=True)
cropped_baseline_pred_cache = iterator_tmp.get_orig_image_output_dictionary()
import pickle as pkl
# Write to file
with open("cropped_baseline_output_dictionary_cache.pkl", "wb") as f:
    pkl.dump(cropped_baseline_pred_cache, f)"""


In [None]:
# Sanity check: shape of one cached baseline patch. The key is a tuple of per-axis
# triples, presumably (start, stop, step) for the leading axis and D, H, W — confirm.
_sample_patch_key = ((None, None, None), (0, 72, None), (0, 160, None), (0, 160, None))
print(cropped_baseline_pred_cache[_sample_patch_key].shape)

### Get the cropped ROI mask

In [None]:
# Bring the cropped ROI mask into the same (D, H, W) order as the supervoxel map
# and move it to the working device.
ROI_mask = torch.from_numpy(np.transpose(cropped_mask, (2, 1, 0))).to(device)

print(ROI_mask.shape)

### Include ROI and segmentation masking in the forward function

In [None]:
# ------------------------------------------------------------
# ❷  Forward wrapper that nnU-Net expects
# ------------------------------------------------------------

@torch.inference_mode()
def forward_segmentation_output_to_explain(
        input_image:         torch.Tensor,
        perturbation_mask:   torch.BoolTensor,
        ROI_mask:            torch.Tensor,   # must be cropped to the same spatial size as the logits
        baseline_prediction_dict: dict,
        target_class:        int = 1,
) -> torch.Tensor:           # returns a scalar per sample
    """Scalar score of the network output that SHAP explains.

    Runs the cached sliding-window prediction on ``input_image`` and returns the
    mean, over all D*H*W voxels, of the ``target_class`` logit. Only voxels that
    are both predicted as ``target_class`` (argmax over classes) and inside the
    ROI contribute; all other voxels count as zero.

    Args:
        input_image: (possibly perturbed) input forwarded to the predictor.
        perturbation_mask: forwarded to the predictor (presumably used to decide
            which cached baseline tiles can be reused).
        ROI_mask: (D, H, W) mask; any non-zero voxel counts as inside the ROI.
        baseline_prediction_dict: cached baseline predictions for
            ``predict_sliding_window_return_logits_with_caching``.
        target_class: class whose logits are aggregated (default 1, the
            lymph-node class).

    Returns:
        0-dim float64 tensor.
    """
    logits = predictor.predict_sliding_window_return_logits_with_caching(
        input_image, perturbation_mask, baseline_prediction_dict,
    )                              # (C, D, H, W)
    # Binary mask: predicted as the target class AND inside the ROI. Comparing
    # with the class index (rather than multiplying by the argmax labels) keeps
    # every selected voxel's weight at exactly 1, also with more than 2 classes
    # or a non-0/1 ROI mask.
    seg_mask = torch.argmax(logits, dim=0) == target_class          # (D, H, W) bool
    seg_mask = seg_mask & (ROI_mask.to(logits.device) != 0)
    D, H, W = seg_mask.shape
    # Divide by the total voxel count to keep the value small (avoids overflows in SHAP).
    aggregate = torch.sum(logits[target_class].double() * seg_mask) / (D * H * W)

    return aggregate

# c) Wrap the cached forward pass so the explainer only supplies the (perturbed)
#    volume and the perturbation mask. ROI_mask and the cached baseline
#    predictions are looked up from the notebook globals each time it is called.
def _shap_forward(vol, mask):
    return forward_segmentation_output_to_explain(
        input_image=vol,
        perturbation_mask=mask,
        ROI_mask=ROI_mask,
        baseline_prediction_dict=cropped_baseline_pred_cache,
    )


explainer = KernelShapWithMask(forward_func=_shap_forward)

"""# d) compute SHAP
attr = explainer.attribute(
    inputs=volume,       # (1,C,D,H,W)
    baselines=0.0, 
    feature_mask=supervoxel_map,
    n_samples=1000,    
    return_input_shape=True,
    show_progress=True,
)
print("Attributions:", attr.shape)  # → (1,C,D,H,W)"""


In [None]:
"""attr_postprocessed = attr[0][0].detach().cpu().numpy().transpose(2,1,0) # (W, H, D)
attr_img = nib.Nifti1Image(attr_postprocessed, affine_cropped_volume)
nib.save(attr_img, 'attribution_map-FCC-ROI.nii.gz')"""