# 1. Import Packages for the Environment

In [301]:
# Import basic packages for later use
import os
import shutil
from collections import OrderedDict

import json
import matplotlib.pyplot as plt
import nibabel as nib

import numpy as np
import torch

In [302]:
# Select the compute device: CUDA when it is available, CPU otherwise.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

cuda


In [303]:
# Detect the hosting platform from well-known directories.
# Kaggle wins over Colab; when neither is detected the IN_DEIB flag is set.
IN_KAGGLE = os.path.exists('/kaggle/input')
IN_COLAB = os.path.exists('/content') and not IN_KAGGLE
IN_DEIB = not (IN_KAGGLE or IN_COLAB)

In [304]:
if not IN_DEIB:
    !pip install nnunetv2
    !pip install captum

# 2. Mount the dataset

In [305]:
from batchgenerators.utilities.file_and_folder_operations import join

# Resolve, per platform, where the dataset lives (mount_dir) and where the
# notebook works from (base_dir).
if IN_COLAB:
    # Google Colab
    # for colab users only - mounting the drive

    from google.colab import drive
    drive.mount('/content/drive',force_remount = True)

    # Dataset folder is <My Drive>/tesi/automi; base_dir is the current working directory
    drive_dir = "/content/drive/My Drive"
    mount_dir = join(drive_dir, "tesi", "automi")
    base_dir = os.getcwd()
elif IN_KAGGLE:
    # Kaggle
    # Dataset folder is the "automi-seg" input; base_dir is the current working directory
    mount_dir = "/kaggle/input/automi-seg"
    base_dir = os.getcwd()
    print(base_dir)
    # Shell commands: list the available inputs and the content of the dataset folder
    !ls '/kaggle/input'
    !cd "/kaggle/input/automi-seg" ; ls
else:
    # Neither Colab nor Kaggle: fixed /workspace layout
    mount_dir = "/workspace/data"
    base_dir = "/workspace/output"
    # Make base_dir the working directory (raises if the directory does not exist)
    os.chdir(base_dir)
    print("base_dir:", base_dir)

base_dir: /workspace/output


# 3. Setting up nnU-Net's folder structure and environment variables
nnU-Net expects a specific folder structure and a set of environment variables.

Roughly, they tell nnU-Net:
1. Where to look for stuff
2. Where to put stuff

For more information about this please check: https://github.com/MIC-DKFZ/nnUNet/blob/master/documentation/setting_up_paths.md

## 3.1 Set environment Variables and creating folders

In [306]:
# ===========================
# 📦 SETUP nnUNet ENVIRONMENT
# ===========================

# Paths to export, keyed by the environment variable name nnU-Net looks up
path_dict = {
    "nnUNet_raw": join(mount_dir, "nnunet_raw"),
    "nnUNet_preprocessed": join(mount_dir, "preprocessed_files"),#"nnUNet_preprocessed"),
    "nnUNet_results": join(mount_dir, "results"),#"nnUNet_results"),
    # "RAW_DATA_PATH": join(mount_dir, "RawData"),  # Optional, only needed to store zips
}

# Write the paths into the environment; they are read by the nnunetv2 `paths` module.
# This must happen BEFORE the import below, which is why the import is not at the top.
for env_var, path in path_dict.items():
    os.environ[env_var] = path

from nnunetv2.paths import nnUNet_results, nnUNet_raw

if IN_KAGGLE:
    # Fall back to fixed locations when nnunetv2 did not pick the variables up.
    # `is None` is the identity test for the None singleton (PEP 8), unlike `== None`.
    if nnUNet_raw is None:
        nnUNet_raw = "/kaggle/input/nnunet_raw"
    if nnUNet_results is None:
        nnUNet_results = "/kaggle/input/results"
    # Kaggle has some very inconsistent behaviors in dataset mounting...
    #nnUNet_raw = "/kaggle/input/automi-seg/nnunet_raw"
    #nnUNet_results = "/kaggle/input/automi-seg/results"

print("nnUNet_raw:", nnUNet_raw)
print("nnUNet_results:", nnUNet_results)

nnUNet_raw: /workspace/data/nnunet_raw
nnUNet_results: /workspace/data/results


In [307]:
# <mount_dir>/experiment_results — presumably where experiment outputs are stored
# (TODO confirm against the cells that use it)
exp_results_path = join(mount_dir, "experiment_results")

### Some tests

In [308]:
# Five-digit, zero-padded codes of all volumes in the fold 0 test set
volume_codes = [f"{case_number:05d}" for case_number in (4, 5, 24, 27, 29, 34, 39, 44)]

In [309]:
# Per-volume paths of the CT image and of its organ map
ct_img_paths = {}
organ_map_paths = {}

# Kaggle paths use the uncompressed ".nii" extension, the other platforms ".nii.gz"
_nifti_ext = ".nii" if IN_KAGGLE else ".nii.gz"

for volume_code in volume_codes:
    _case_name = f"AUTOMI_{volume_code}_0000"
    ct_img_paths[volume_code] = join(nnUNet_raw, "imagesTr", _case_name + _nifti_ext)
    organ_map_paths[volume_code] = join(
        nnUNet_raw,
        "total_segmentator_structures",
        _case_name,
        "mask_mask_add_input_20_total_segmentator" + _nifti_ext,
    )

    ct_img = nib.load(ct_img_paths[volume_code])
    organ_map = nib.load(organ_map_paths[volume_code])

    # Report geometry of both images
    print(f"Volume {volume_code}:")
    for _label, _value in (
        ("CT shape:", ct_img.shape),
        ("Organ shape:", organ_map.shape),
        ("Spacing:", ct_img.header.get_zooms()),
        ("Organ spacing:", organ_map.header.get_zooms()),
    ):
        print(_label, _value)

    # Both images must share exactly the same affine matrix
    assert (ct_img.affine == organ_map.affine).all(), "CT and organ mask affine matrices do not match!"

Volume 00004:
CT shape: (512, 512, 221)
Organ shape: (512, 512, 221)
Spacing: (1.3671875, 1.3671875, 5.0)
Organ spacing: (1.3671875, 1.3671875, 5.0)
Volume 00005:
CT shape: (512, 512, 219)
Organ shape: (512, 512, 219)
Spacing: (1.171875, 1.171875, 5.0)
Organ spacing: (1.171875, 1.171875, 5.0)
Volume 00024:
CT shape: (512, 512, 321)
Organ shape: (512, 512, 321)
Spacing: (0.976562, 0.976562, 3.0)
Organ spacing: (0.976562, 0.976562, 3.0)
Volume 00027:
CT shape: (512, 512, 236)
Organ shape: (512, 512, 236)
Spacing: (1.3671875, 1.3671875, 5.0)
Organ spacing: (1.3671875, 1.3671875, 5.0)
Volume 00029:
CT shape: (512, 512, 259)
Organ shape: (512, 512, 259)
Spacing: (1.367188, 1.367188, 7.5)
Organ spacing: (1.367188, 1.367188, 7.5)
Volume 00034:
CT shape: (512, 512, 249)
Organ shape: (512, 512, 249)
Spacing: (1.171875, 1.171875, 5.0)
Organ spacing: (1.171875, 1.171875, 5.0)
Volume 00039:
CT shape: (512, 512, 283)
Organ shape: (512, 512, 283)
Spacing: (1.171875, 1.171875, 5.0)
Organ spacing: (1.

In [310]:
# Load the second volume of the list (volume_codes[1]) and print its affine matrix
ct_img = nib.load(ct_img_paths[volume_codes[1]])
print("affine:", ct_img.affine)

affine: [[-1.17187500e+00  0.00000000e+00  0.00000000e+00  3.00000000e+02]
 [ 0.00000000e+00 -1.17187500e+00  0.00000000e+00  1.85300003e+02]
 [ 0.00000000e+00  0.00000000e+00  5.00000000e+00 -1.43419995e+03]
 [ 0.00000000e+00  0.00000000e+00  0.00000000e+00  1.00000000e+00]]


In [311]:
from typing import Union

# define an utility for Nifty file saving
def save_nifty(data: Union[np.ndarray, torch.Tensor], affine, path, dtype=np.float32):
    """
    Save an array or tensor as a NIfTI-1 file.

    Args:
        data: NumPy array or torch tensor; tensors are detached and moved to CPU first.
        affine: Affine matrix passed to `nib.Nifti1Image`.
        path: Destination file path.
        dtype: NumPy dtype the data is cast to before saving (default: float32).
    """
    array = data.detach().cpu().numpy() if isinstance(data, torch.Tensor) else data
    # cast to the requested dtype (float32 unless the caller asks otherwise)
    image = nib.Nifti1Image(array.astype(dtype), affine)
    nib.save(image, path)

def save_nifty_binary(data: Union[np.ndarray, torch.Tensor], affine, path):
    """Save `data` as a NIfTI file cast to uint8, by delegating to `save_nifty`."""
    return save_nifty(data=data, affine=affine, path=path, dtype=np.uint8)


# define an utility for annoying nnunetv2 preprocessing
def nnunetv2_default_preprocessing(ct_img_path, 
                                   predictor, 
                                   dataset_json_path,
                                   other_volumes: Union[np.ndarray, None] = None
                                   ) -> Union[np.ndarray, tuple[np.ndarray, np.ndarray]]:
    """
    Preprocess a CT image (and optionally companion volumes) with the preprocessor
    configured in the given nnunetv2 predictor.

    Args:
        ct_img_path: Path of the CT image; read with the plans' image reader/writer.
        predictor: Object exposing `plans_manager` and `configuration_manager`.
        dataset_json_path: Forwarded unchanged as `dataset_json` to `run_case_npy`.
        other_volumes: Optional array forwarded as `seg` to `run_case_npy`, so it
            goes through the same pipeline as the image.

    Returns:
        The preprocessed image when `other_volumes` is None, otherwise the tuple
        `(preprocessed_image, preprocessed_other_volumes)`.
    """
    plans_manager = predictor.plans_manager
    configuration_manager = predictor.configuration_manager

    preprocessor = configuration_manager.preprocessor_class(verbose=False)
    rw = plans_manager.image_reader_writer_class()
    # If the call above returned something that is still a factory, instantiate it
    if callable(rw) and not hasattr(rw, "read_images"):
        rw = rw()
    img_np, img_props = rw.read_images([str(ct_img_path)])

    preprocessed, other_volumes_preprocessed, _ = preprocessor.run_case_npy(
        img_np, seg=other_volumes, properties=img_props,
        plans_manager=plans_manager,
        configuration_manager=configuration_manager,
        dataset_json=dataset_json_path
    )
    # Explicit None test: the truth value of a multi-element ndarray is ambiguous
    # and `if other_volumes:` raises ValueError for it.
    if other_volumes is not None:
        return preprocessed, other_volumes_preprocessed
    return preprocessed

In [312]:
# Directory of the trained model (read-only on Kaggle).
# On Kaggle the trainer folder is nested inside the dataset results folder.
model_dir = (
    join(nnUNet_results, 'Dataset003_AUTOMI_CTVLNF_NEWGL_results/nnUNetTrainer__nnUNetPlans__3d_fullres')
    if IN_KAGGLE
    else join(nnUNet_results, 'nnUNetTrainer__nnUNetPlans__3d_fullres')
)

## Utility to export logits to a visualizable segmentation

In [313]:
import numpy as np
import torch
import os
from typing import Union
from pathlib import Path
from nnunetv2.configuration import default_num_processes
from nnunetv2.inference.export_prediction import export_prediction_from_logits

def export_logits_to_nifty_segmentation(
    predictor,
    volume_file: Path,
    model_dir: str,
    logits: Union[str, np.ndarray, torch.Tensor],
    npz_dir: str | None,
    output_dir: str = "",
    fold: int = 0,
    save_probs: bool = False,
    from_file: bool = True
):
    """
    Converts a saved .npz logits file into a native-space NIfTI segmentation using nnU-Net's helper.

    Args:
        predictor: An instantiated nnU-Net predictor object with loaded plans/configs.
        volume_file: An Path object pointing to the raw image file.
        model_dir (str): Path to the nnU-Net mode1Introductionl directory containing dataset.json.
        logits (str or tensor): Base name of the .npz logits file (no extension), when from_file=True
        npz_dir (str): Directory where the .npz file is stored.
        output_dir (str): Directory where the .nii.gz segmentation will be saved.
        fold (int): The fold number used for prediction (default is 0).
        save_probs (bool): Whether to save softmax probabilities as a .npz file.
        from_file (bool). Whether to convert from a file instead of from the logits (default true)
    """
    if from_file:
        npz_logits = Path(npz_dir) / f"{logits}.npz"
        output_nii = Path(output_dir) / f"{logits}_seg.nii.gz"
        logits = np.load(npz_logits)["logits"]
    else:
        output_nii = Path(output_dir) / "exported_seg.nii.gz"

    plans_manager = predictor.plans_manager
    configuration_manager = predictor.configuration_manager
    dataset_json = Path(model_dir) / "dataset.json"

    preprocessor = configuration_manager.preprocessor_class(verbose=False)
    rw = plans_manager.image_reader_writer_class()
    if callable(rw) and not hasattr(rw, "read_images"):
        rw = rw()
    img_np, img_props = rw.read_images([str(volume_file)])

    _, _, data_props = preprocessor.run_case_npy(
        img_np, seg=None, properties=img_props,
        plans_manager=plans_manager,
        configuration_manager=configuration_manager,
        dataset_json=dataset_json
    )


    export_prediction_from_logits(
        predicted_array_or_file=logits,
        properties_dict=data_props,
        configuration_manager=configuration_manager,
        plans_manager=plans_manager,
        dataset_json_dict_or_file=str(dataset_json),
        output_file_truncated=os.path.splitext(str(output_nii))[0],
        save_probabilities=save_probs,
        num_threads_torch=default_num_processes
    )

    print(f"✅  NIfTI segmentation written → {output_nii}")

In [314]:
# Disabled usage example: wrapped in a bare string so that nothing is executed.
# NOTE(review): the keyword names `plan` and `logits_filename` used here are not
# parameters of export_logits_to_nifty_segmentation as defined in this notebook,
# and `volume_file` is missing — update the call before re-enabling it.
"""export_logits_to_nifty_segmentation(
    predictor=predictor,
    plan=plan,
    model_dir=Path(model_dir),
    logits_filename="pred_00007",
    npz_dir="SHAP/shap_run",
    output_dir="SHAP/shap_run",
    fold=0,
    save_probs=False
)"""

'export_logits_to_nifty_segmentation(\n    predictor=predictor,\n    plan=plan,\n    model_dir=Path(model_dir),\n    logits_filename="pred_00007",\n    npz_dir="SHAP/shap_run",\n    output_dir="SHAP/shap_run",\n    fold=0,\n    save_probs=False\n)'

## Sliding-window caching for faster multi-inference scenarios, such as SHAP

### Try to override the sliding_window_function

In [315]:
from typing import Union
import time
import numpy as np
import torch
from tqdm import tqdm
from queue import Queue
from threading import Thread
from acvl_utils.cropping_and_padding.padding import pad_nd_image
from nnunetv2.utilities.helpers import empty_cache, dummy_context
from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor
from nnunetv2.inference.sliding_window_prediction import compute_gaussian, compute_steps_for_sliding_window

class CustomNNUNetPredictor(nnUNetPredictor):
    """
    nnUNetPredictor whose sliding-window inference can reuse cached per-patch outputs.

    `get_output_dictionary_sliding_window` builds a dictionary that maps every
    sliding-window slicer to the network output for that patch. During a perturbed
    inference (`predict_sliding_window_return_logits_with_caching`) the patches that
    contain no perturbed voxel are taken from that dictionary instead of being run
    through the network again.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Defaults, so the caching path can read these attributes even when
        # set_wandb_logging() was never called.
        self.wandb_logging = False
        self.wandb_label = ""
        self.wandb_commit = True

    def set_wandb_logging(self, wandb_logging: bool, wandb_label: str = "", wandb_commit: bool = True):
        """Enable/disable W&B logging of cache statistics and set the metric-name prefix and commit flag."""
        self.wandb_logging = wandb_logging
        self.wandb_label = wandb_label
        self.wandb_commit = wandb_commit

    @torch.inference_mode()
    def predict_sliding_window_return_logits_with_caching(self, input_image: torch.Tensor,
                                                          perturbation_mask: torch.BoolTensor | None,
                                                          baseline_prediction_dict: dict) \
            -> Union[np.ndarray, torch.Tensor]:
        """
        Method predict_sliding_window_return_logits taken from official nnunetv2 documentation:
        https://github.com/MIC-DKFZ/nnUNet/blob/58a3b121a6d1846a978306f6c79a7c005b7d669b/nnunetv2/inference/predict_from_raw_data.py
        We add a perturbation_mask parameter to check each patch for the actual presence of a perturbation

        Args:
            input_image: 4D tensor (c, x, y, z).
            perturbation_mask: Boolean mask, True where the input was perturbed. When None,
                the regular (uncached) `predict_sliding_window_return_logits` is used.
            baseline_prediction_dict: Per-patch outputs keyed by `_slice_key(slicer)`.

        Returns:
            The predicted logits, with any padding reverted.
        """
        # fallback to original method if perturbation_mask is None
        if perturbation_mask is None:
            return self.predict_sliding_window_return_logits(input_image)

        assert isinstance(input_image, torch.Tensor)
        self.network = self.network.to(self.device)
        self.network.eval()

        # Autocast can be annoying
        # If the device_type is 'cpu' then it's slow as heck on some CPUs (no auto bfloat16 support detection)
        # and needs to be disabled.
        # If the device_type is 'mps' then it will complain that mps is not implemented, even if enabled=False
        # is set. Whyyyyyyy. (this is why we don't make use of enabled=False)
        # So autocast will only be active if we have a cuda device.
        with torch.autocast(self.device.type, enabled=True) if self.device.type == 'cuda' else dummy_context():
            assert input_image.ndim == 4, 'input_image must be a 4D np.ndarray or torch.Tensor (c, x, y, z)'

            if self.verbose:
                print(f'Input shape: {input_image.shape}')
                print("step_size:", self.tile_step_size)
                print("mirror_axes:", self.allowed_mirroring_axes if self.use_mirroring else None)
                print(f'Perturbation mask shape: {perturbation_mask.shape}')

            # DEBUG (disabled): save the input image and the perturbation map as NIfTI
            # affine = nib.load("supervoxel_map.nii.gz").affine
            # save_nifty(torch.permute(input_image[0], (2,1,0)), affine, "input_image_debug.nii.gz")
            # save_nifty(torch.permute(perturbation_mask[0], (2,1,0)), affine, "perturbation_mask_debug.nii.gz")

            # if input_image is smaller than tile_size we need to pad it to tile_size.
            # NOTE(review): perturbation_mask is NOT padded here; if padding is actually
            # applied, the slicers below refer to the padded grid and no longer line up
            # with the mask — confirm inputs are never smaller than the patch size.
            data, slicer_revert_padding = pad_nd_image(input_image, self.configuration_manager.patch_size,
                                                       'constant', {'value': 0}, True,
                                                       None)

            # the same slicers index both the (padded) volume and the perturbation mask
            slicers = self._internal_get_sliding_window_slicers(data.shape[1:])

            if self.perform_everything_on_device and self.device != 'cpu':
                # behavior changed with respect to nnU-Net: no CPU fallback
                try:
                    predicted_logits = self._internal_predict_sliding_window_return_logits(
                        data, slicers, True, perturbation_mask, baseline_prediction_dict, caching=True
                    )
                except RuntimeError as e:
                    # Patch-level failures are re-raised wrapped in a new RuntimeError,
                    # so the OOM text may only be present on the original cause.
                    if "CUDA out of memory" in f"{e} {e.__cause__}":
                        print("⚠️  CUDA OOM, cambiare batch size o patch size!")
                    # Always surface the real error and abort
                    raise
            else:
                # Uncached path: every patch is predicted
                predicted_logits = self._internal_predict_sliding_window_return_logits(data, slicers,
                                                                                       self.perform_everything_on_device)

            empty_cache(self.device)
            # revert padding
            predicted_logits = predicted_logits[(slice(None), *slicer_revert_padding[1:])]
        return predicted_logits

    def _slice_key(self, slicer_tuple):
        """Turn a tuple of slices into a hashable key, usable for the cache lookup."""
        return tuple((s.start, s.stop, s.step) for s in slicer_tuple)

    @torch.inference_mode()
    def _internal_predict_sliding_window_return_logits(self,
                                                       data: torch.Tensor,
                                                       slicers,
                                                       do_on_device: bool = True,
                                                       perturbation_mask: torch.BoolTensor | None = None,
                                                       baseline_prediction_dict: dict | None = None,
                                                       caching: bool = False,
                                                       ):
        """
        Sliding-window prediction, modified to manage the caching of patches.

        When `caching` is True, a patch for which `check_overlapping` is False is not
        predicted: its output is looked up in `baseline_prediction_dict` under
        `_slice_key(slicer)`. All other patches go through the network.

        Returns:
            Logits accumulated over all patches and divided by the accumulated weights.

        Raises:
            RuntimeError: if a patch fails (original exception attached as the cause)
                or if the result contains inf values.
        """
        predicted_logits = n_predictions = prediction = gaussian = workon = None
        results_device = self.device if do_on_device else torch.device('cpu')
        if next(self.network.parameters()).device != results_device:
            self.network = self.network.to(results_device)

        def producer(d, slh, q):
            # Feeds (patch, slicer) pairs to the consumer loop, then the 'end' sentinel
            for s in slh:
                q.put((torch.clone(d[s][None], memory_format=torch.contiguous_format).to(results_device), s))
            q.put('end')

        try:
            empty_cache(self.device)

            # move data to device
            if self.verbose:
                print(f'move image to device {results_device}')
            data = data.to(results_device)
            queue = Queue(maxsize=2)
            t = Thread(target=producer, args=(data, slicers, queue))
            t.start()

            # preallocate arrays
            if self.verbose:
                print(f'preallocating results arrays on device {results_device}')
            predicted_logits = torch.zeros((self.label_manager.num_segmentation_heads, *data.shape[1:]),
                                           dtype=torch.half,
                                           device=results_device)
            n_predictions = torch.zeros(data.shape[1:], dtype=torch.half, device=results_device)

            if self.use_gaussian:
                gaussian = compute_gaussian(tuple(self.configuration_manager.patch_size), sigma_scale=1. / 8,
                                            value_scaling_factor=10,
                                            device=results_device)
            else:
                gaussian = 1

            if not self.allow_tqdm and self.verbose:
                print(f'running prediction: {len(slicers)} steps')

            # Before starting queue processing:
            start_time = time.time()

            with tqdm(desc=None, total=len(slicers), disable=not self.allow_tqdm) as pbar:
                cache_hits = 0
                cache_misses = 0
                while True:
                    item = queue.get()
                    if item == 'end':
                        queue.task_done()
                        break
                    workon, sl = item
                    try:
                        if caching and not self.check_overlapping(sl, perturbation_mask):
                            # copy=True: when the cached tensor already lives on
                            # results_device, .to() would return the cached tensor
                            # itself and the in-place weighting below would corrupt
                            # the cache entry.
                            prediction = baseline_prediction_dict[self._slice_key(sl)].to(results_device, copy=True)
                            cache_hits += 1
                        else:
                            prediction = self._internal_maybe_mirror_and_predict(workon)[0].to(results_device)
                            cache_misses += 1
                    except Exception as e:
                        raise RuntimeError("Errore nella predizione del patch") from e

                    assert prediction.device == predicted_logits.device

                    if self.use_gaussian:
                        prediction *= gaussian
                    predicted_logits[sl] += prediction
                    n_predictions[sl[1:]] += gaussian

                    # release the per-patch tensors before fetching the next patch
                    del prediction, workon
                    queue.task_done()
                    pbar.set_postfix(
                        cache=f"{cache_hits}",
                        mem=f"{torch.cuda.memory_allocated()/1e9:.2f} GB"
                    )
                    pbar.update(1)

            queue.join()

            if caching:
                # Final metrics
                hit_ratio = (cache_hits / len(slicers) * 100) if len(slicers) > 0 else 0.0
                elapsed_time = time.time() - start_time

                if self.verbose and not self.allow_tqdm:
                    print(f"Cache hits: {cache_hits}\\{len(slicers)}")
                    print(f"Inference time: {elapsed_time:.2f} sec")

                # Log to W&B (only if you are logged and there is a run)
                # NOTE(review): `wandb` is not imported in this cell; it has to be
                # imported at module level before logging is enabled.
                if self.wandb_logging:
                    wandb.log({
                        f"{self.wandb_label}cache_hit_ratio_percent": hit_ratio,
                        f"{self.wandb_label}inference_time_sec": elapsed_time
                    }, commit=self.wandb_commit)

            # predicted_logits /= n_predictions
            torch.div(predicted_logits, n_predictions, out=predicted_logits)
            # check for infs
            if torch.any(torch.isinf(predicted_logits)):
                raise RuntimeError('Encountered inf in predicted array. Aborting... If this problem persists, '
                                   'reduce value_scaling_factor in compute_gaussian or increase the dtype of '
                                   'predicted_logits to fp32')
        except Exception as e:
            # Drop the references by rebinding to None. A `del` here would raise
            # UnboundLocalError (and hide the real error) whenever `prediction` and
            # `workon` were already deleted inside the loop.
            predicted_logits = n_predictions = prediction = gaussian = workon = None
            empty_cache(self.device)
            empty_cache(results_device)
            raise e
        return predicted_logits

    def get_output_dictionary_sliding_window(self, data: torch.Tensor, slicers,
                                            do_on_device: bool = True,
                                            ) -> dict:
        """
        Create a dictionary that associates the output of the inference to each slicer
        of the sliding window, so that the outputs of untouched patches are ready to be
        served from cache.

        Args:
            data: Volume indexed by the slicers.
            slicers: Sliding-window slicers.
            do_on_device: Run on `self.device` when True, on CPU otherwise.

        Returns:
            dict mapping `_slice_key(slicer)` to the patch prediction, stored on CPU.
        """
        dictionary = dict()
        workon = None
        results_device = self.device if do_on_device else torch.device('cpu')
        if next(self.network.parameters()).device != results_device:
            self.network = self.network.to(results_device)

        def producer(d, slh, q):
            for s in slh:
                # Patches must be on the same device as the network (results_device),
                # otherwise the forward pass fails when do_on_device is False.
                q.put((torch.clone(d[s][None], memory_format=torch.contiguous_format).to(results_device), s))
            q.put('end')

        try:
            empty_cache(self.device)

            # move data to device (the network was moved above)
            if self.verbose:
                print(f'move image and model to device {results_device}')

            data = data.to(results_device)
            queue = Queue(maxsize=2)
            t = Thread(target=producer, args=(data, slicers, queue))
            t.start()

            if not self.allow_tqdm and self.verbose:
                print(f'running prediction: {len(slicers)} steps')

            with tqdm(desc=None, total=len(slicers), disable=not self.allow_tqdm) as pbar:
                while True:
                    item = queue.get()
                    if item == 'end':
                        queue.task_done()
                        break
                    workon, sl = item
                    pred_gpu = self._internal_maybe_mirror_and_predict(workon)[0].to(results_device)

                    pred_cpu = pred_gpu.cpu()
                    # save prediction in the dictionary
                    dictionary[self._slice_key(sl)] = pred_cpu
                    # immediately free gpu memory
                    del pred_gpu

                    queue.task_done()
                    pbar.update()
            queue.join()

        except Exception as e:
            workon = None
            empty_cache(self.device)
            empty_cache(results_device)
            raise e
        return dictionary

    def check_overlapping(self, slicer, perturbation_mask: torch.BoolTensor) -> bool:
        """
        Return True if the patch defined by `slicer` contains at least one perturbed voxel.

        Parameters
        ----------
        slicer : tuple
            As produced by `_internal_get_sliding_window_slicers`,
            i.e. (slice(None), slice(x0,x1), slice(y0,y1), slice(z0,z1)).
        perturbation_mask : torch.BoolTensor
            Mask (C, X, Y, Z) with True on the perturbed voxels
            (usually C==1 or replicated over the channels).

        Returns
        -------
        bool
            True ↔ at least one True voxel inside the patch.
        """
        # NB: the first element of the tuple is always slice(None) (channels).
        #     It is kept: it has no overhead and keeps the indexing simple.
        return perturbation_mask[slicer].any().item()

    @staticmethod
    def log_cache_stats(log_file: str,
                        total_patches: int,
                        cache_hits: int,
                        cache_misses: int,
                        extra_info: dict | None = None):
        """
        Append caching statistics to a JSON log file.

        Declared as a static method: it takes no `self`, so calling it on an instance
        would otherwise bind the instance to `log_file`.

        Args:
            log_file (str): Path to the JSON file.
            total_patches (int): Total number of patches processed.
            cache_hits (int): Number of cache hits.
            cache_misses (int): Number of cache misses.
            extra_info (dict, optional): Additional run-specific information to log.
        """
        # Imported here: a name imported in the class body is not visible from
        # inside a method body.
        from datetime import datetime

        hit_ratio = (cache_hits / total_patches * 100) if total_patches > 0 else 0.0
        log_entry = {
            "timestamp": datetime.now().isoformat(timespec="seconds"),
            "total_patches": total_patches,
            "cache_hits": cache_hits,
            "cache_misses": cache_misses,
            "hit_ratio_percent": round(hit_ratio, 2)
        }

        if extra_info:
            log_entry.update(extra_info)

        # Load existing logs; anything unreadable or not a list is restarted as an empty list
        if os.path.exists(log_file):
            try:
                with open(log_file, "r", encoding="utf-8") as f:
                    logs = json.load(f)
                if not isinstance(logs, list):
                    logs = []
            except json.JSONDecodeError:
                logs = []
        else:
            logs = []

        logs.append(log_entry)

        with open(log_file, "w", encoding="utf-8") as f:
            json.dump(logs, f, indent=2)

        print(f"[CACHE] Logged stats to '{log_file}'")

# Next step: define regular, fixed-size superpixels and compute the attribution of each of them
We also need to define a metric for the comparison because segmentation explanations, unlike classification explanations, are intrinsically ambiguous. For example, we can select a priori a single region of the segmentation output and use the average of its pixels to measure the impact of the perturbations.

### 4. Face-centered cubic (FCC) lattice induced supervoxel assignment
-> more *isotropic* than simple cubes

### 4.1 affine transformation to translate isotropy from voxel space into the original geometrical space (measured in mm)

### 4.2 Apply the original FCC algorithm to an affine-transformed volume that has the same proportions as the .nii in physical space: transform -> apply the algorithm to derive the map -> transform the map back onto the voxel space



In [316]:
import numpy as np
import nibabel as nib
from scipy.spatial import cKDTree

def generate_FCC_supervoxel_map(img, S=200.0):
    """
    Generate a supervoxel map using FCC tessellation in physical space (original version).

    Lattice centers are laid out in physical (affine-transformed) coordinates and
    every voxel is labelled with the index of its nearest center.

    Args:
        img: Nibabel NIfTI image object (only `shape` and `affine` are read)
        S (float): Desired supervoxel size in millimeters (default: 200.0)

    Returns:
        supervoxel_map: 3D NumPy int32 array with integer labels for supervoxels

    Raises:
        ValueError: if S is not strictly positive, or if no lattice center falls
            inside the physical bounding box of the volume.
    """
    if not S > 0:
        raise ValueError(f"S must be a positive size in millimeters, got {S!r}")

    affine = img.affine
    # Only the grid shape is needed: reading it from the image avoids loading
    # the whole volume into memory as float64.
    W, H, D = img.shape

    def to_physical(ijk):
        # Affine map without building an extra homogeneous-coordinate copy
        return ijk @ affine[:3, :3].T + affine[:3, 3]

    # Compute the physical bounding box of the volume from its 8 corners
    corners_voxel = np.array([
        [0, 0, 0],
        [W-1, 0, 0],
        [0, H-1, 0],
        [0, 0, D-1],
        [W-1, H-1, 0],
        [W-1, 0, D-1],
        [0, H-1, D-1],
        [W-1, H-1, D-1]
    ])
    corners_physical = to_physical(corners_voxel)
    min_xyz = corners_physical.min(axis=0)
    max_xyz = corners_physical.max(axis=0)

    # Generate FCC lattice centers in physical space
    a = S * np.sqrt(2)
    factor = 2 / a
    # Integer index range per axis, with one index of margin on each side
    index_ranges = [
        np.arange(int(np.floor(factor * lo)) - 1, int(np.ceil(factor * hi)) + 2)
        for lo, hi in zip(min_xyz, max_xyz)
    ]
    P, Q, R = (grid.ravel() for grid in np.meshgrid(*index_ranges, indexing='ij'))

    # Filter for FCC lattice points (sum of indices is even)
    is_fcc = (P + Q + R) % 2 == 0

    # Compute physical coordinates of centers
    centers = np.column_stack((P[is_fcc], Q[is_fcc], R[is_fcc])) * (a / 2)

    # Keep only centers within the bounding box
    inside = np.all((centers >= min_xyz) & (centers <= max_xyz), axis=1)
    centers = centers[inside]

    # Check if any centers were generated
    print(f"Number of supervoxel centers: {len(centers)}")
    if len(centers) == 0:
        raise ValueError("No supervoxel centers generated. Try reducing S.")

    # Generate voxel indices and transform to physical coordinates
    voxel_indices = np.indices((W, H, D)).reshape(3, -1).T  # shape (W*H*D, 3)
    physical_coords = to_physical(voxel_indices)  # shape (W*H*D, 3)

    # Assign each voxel to the nearest supervoxel center
    tree = cKDTree(centers)
    _, labels = tree.query(physical_coords)

    # Create the supervoxel map
    return labels.reshape((W, H, D)).astype(np.int32)

## Combine FCC regularity with organ context: Organ-aware FCC supervoxels

In [317]:
import numpy as np
import nibabel as nib
from scipy.spatial import cKDTree
from collections import defaultdict

def _supervoxel_statistics(volume, mask, organ_label, voxel_volume_mm3):
    """Build the statistics record for the voxels of ``volume`` selected by ``mask``."""
    voxels = volume[mask]
    voxel_count = int(np.sum(mask))
    return {
        'organ_label': organ_label,
        'voxel_count': voxel_count,
        'volume_mm3': float(voxel_count * voxel_volume_mm3),
        'mean_intensity': float(np.mean(voxels)),
        'std_intensity': float(np.std(voxels)),
        'min_intensity': float(np.min(voxels)),
        'max_intensity': float(np.max(voxels)),
        # Centroid is expressed in voxel indices, not physical coordinates.
        'centroid_voxel': [float(x) for x in np.array(np.where(mask)).mean(axis=1)]
    }


def generate_FCC_organs_supervoxel_map(volume_img, organ_img, S, compute_statistics=False):
    """
    Generate organ-aware supervoxel map using FCC tessellation.

    Every voxel is assigned to its nearest FCC lattice centre (in physical
    space); each distinct (FCC centre, organ label) pair that actually occurs
    among non-background voxels becomes one supervoxel. IDs are assigned in
    ascending (FCC index, organ label) order, starting from 1.

    Args:
        volume_img: image object exposing ``get_fdata()`` and ``affine``
            (intensity volume, e.g. a Nibabel NIfTI image)
        organ_img: image object exposing ``get_fdata()`` (organ labels,
            0=background); must have the same shape as the volume
        S: Desired supervoxel size in millimeters (distance between
            neighbouring FCC centres)
        compute_statistics: when True, also compute per-supervoxel statistics
            and print a summary.

    Returns:
        tuple: (supervoxel_map, organ_table, statistics)
            - supervoxel_map: 3D int32 array with supervoxel IDs (0=background)
            - organ_table: dict {supervoxel_id: organ_label}, or None when
              ``compute_statistics`` is False
            - statistics: dict with per-supervoxel statistics (key 0 holds the
              background), or None when ``compute_statistics`` is False

    Raises:
        ValueError: if the two images differ in shape, or if no FCC centre
            falls inside the volume's physical bounding box.
    """
    # Load data
    volume = volume_img.get_fdata()
    organs = organ_img.get_fdata().astype(np.int32)
    affine = volume_img.affine
    W, H, D = volume.shape

    # Verify dimensions match
    if volume.shape != organs.shape:
        raise ValueError("Volume and organ images must have the same dimensions")

    # Compute physical bounding box from the eight corner voxels
    corners_voxel = np.array([
        [0, 0, 0], [W-1, 0, 0], [0, H-1, 0], [0, 0, D-1],
        [W-1, H-1, 0], [W-1, 0, D-1], [0, H-1, D-1], [W-1, H-1, D-1]
    ])
    corners_hom = np.hstack((corners_voxel, np.ones((8, 1))))
    corners_physical = (affine @ corners_hom.T).T[:, :3]
    min_xyz = corners_physical.min(axis=0)
    max_xyz = corners_physical.max(axis=0)

    # FCC lattice: centres at (p, q, r) * a / 2 with p + q + r even.
    # With a = S * sqrt(2) the nearest-neighbour distance a / sqrt(2) equals S.
    a = S * np.sqrt(2)
    factor = 2 / a
    index_ranges = [
        np.arange(int(np.floor(factor * lo)) - 1, int(np.ceil(factor * hi)) + 2)
        for lo, hi in zip(min_xyz, max_xyz)
    ]
    # 'ij' indexing + C-order ravel reproduces the p-outer / r-inner loop order,
    # so centre indices (and therefore supervoxel IDs) are unchanged.
    P, Q, R = np.meshgrid(*index_ranges, indexing='ij')
    lattice = np.column_stack((P.ravel(), Q.ravel(), R.ravel()))
    lattice = lattice[lattice.sum(axis=1) % 2 == 0]  # FCC constraint
    centers = lattice * a / 2
    inside_box = np.all((centers >= min_xyz) & (centers <= max_xyz), axis=1)
    centers = centers[inside_box]
    print(f"Generated {len(centers)} FCC centers")

    if len(centers) == 0:
        raise ValueError("No FCC centers generated. Try reducing S.")

    # Nearest-centre lookup, processed slice by slice to bound peak memory
    fcc_map = np.empty((W, H, D), dtype=np.int32)
    tree = cKDTree(centers)

    # Homogeneous voxel coordinates of one slice; only column 0 changes per slice.
    y_coords, z_coords = np.meshgrid(np.arange(H), np.arange(D), indexing='ij')
    slice_hom = np.column_stack([
        np.zeros(H * D),
        y_coords.ravel(),
        z_coords.ravel(),
        np.ones(H * D)
    ])

    print("Assigning voxels to FCC centers...")
    for w in range(W):
        slice_hom[:, 0] = w
        slice_physical = (affine @ slice_hom.T).T[:, :3]
        _, fcc_labels = tree.query(slice_physical)
        fcc_map[w] = fcc_labels.reshape(H, D)

    # Create supervoxel IDs: each (FCC_center, organ) pair gets unique ID
    print("Creating organ-aware supervoxels...")
    final_supervoxel_map = np.zeros((W, H, D), dtype=np.int32)
    supervoxel_to_organ = {}

    foreground = organs != 0
    if foreground.any():
        fg_fcc = fcc_map[foreground].astype(np.int64)
        fg_organ = organs[foreground].astype(np.int64)
        # Encode each pair as one integer whose ordering matches the
        # lexicographic (fcc_idx, organ_label) ordering.
        organ_min = fg_organ.min()
        organ_span = fg_organ.max() - organ_min + 1
        pair_keys = fg_fcc * organ_span + (fg_organ - organ_min)
        unique_keys, inverse = np.unique(pair_keys, return_inverse=True)
        # IDs start from 1 (0 reserved for background)
        final_supervoxel_map[foreground] = inverse.reshape(-1) + 1
        for supervoxel_id, key in enumerate(unique_keys, start=1):
            supervoxel_to_organ[supervoxel_id] = int(key % organ_span + organ_min)

    print(f"Created {len(supervoxel_to_organ)} organ-aware supervoxels")

    if compute_statistics:
        # Compute statistics
        print("Computing supervoxel statistics...")
        statistics = {}
        voxel_volume_mm3 = np.abs(np.linalg.det(affine[:3, :3]))

        for sv_id, organ_label in supervoxel_to_organ.items():
            mask = final_supervoxel_map == sv_id
            if np.any(mask):
                statistics[sv_id] = _supervoxel_statistics(
                    volume, mask, organ_label, voxel_volume_mm3)

        # Add background statistics
        background_mask = final_supervoxel_map == 0
        if np.any(background_mask):
            statistics[0] = _supervoxel_statistics(
                volume, background_mask, 0, voxel_volume_mm3)

        # Print summary
        organ_counts = defaultdict(int)
        for organ_label in supervoxel_to_organ.values():
            organ_counts[organ_label] += 1

        print("\nSupervoxel summary:")
        print(f"Background supervoxels: 1 (ID: 0)")
        for organ_label in sorted(organ_counts.keys()):
            print(f"Organ {organ_label}: {organ_counts[organ_label]} supervoxels")
    else:
        # Existing contract: table and statistics are only returned on request.
        supervoxel_to_organ, statistics = None, None

    return final_supervoxel_map, supervoxel_to_organ, statistics


def save_organ_supervoxel_results(supervoxel_map, organ_table, statistics, 
                                volume_img, output_prefix="organ_supervoxels"):
    """
    Write the supervoxel map, the organ table and the statistics to disk.

    Three files are produced, all named after ``output_prefix``:
    ``<prefix>_map.nii.gz``, ``<prefix>_table.csv`` and ``<prefix>_stats.csv``.

    Args:
        supervoxel_map: 3D array with supervoxel IDs
        organ_table: dict mapping supervoxel_id -> organ_label
        statistics: dict mapping supervoxel_id -> statistics record
        volume_img: original volume image (its affine and header are reused)
        output_prefix: prefix for output files
    """
    import csv

    map_path = f"{output_prefix}_map.nii.gz"
    table_path = f"{output_prefix}_table.csv"
    stats_path = f"{output_prefix}_stats.csv"

    # --- supervoxel map (NIfTI, stored as int32) ---
    label_img = nib.Nifti1Image(
        supervoxel_map.astype(np.int32),
        volume_img.affine,
        volume_img.header,
    )
    label_img.header.set_data_dtype(np.int32)
    nib.save(label_img, map_path)
    print(f"Saved supervoxel map: {map_path}")

    # --- organ table (CSV); row for the background comes first ---
    with open(table_path, 'w', newline='') as table_file:
        table_writer = csv.writer(table_file)
        table_writer.writerow(['supervoxel_id', 'organ_label'])
        table_writer.writerow([0, 0])  # Background
        for supervoxel_id in sorted(organ_table.keys()):
            table_writer.writerow([supervoxel_id, organ_table[supervoxel_id]])
    print(f"Saved organ table: {table_path}")

    # --- statistics (CSV), one row per supervoxel ---
    scalar_columns = ['organ_label', 'voxel_count', 'volume_mm3',
                      'mean_intensity', 'std_intensity', 'min_intensity', 'max_intensity']
    with open(stats_path, 'w', newline='') as stats_file:
        stats_writer = csv.writer(stats_file)
        stats_writer.writerow(
            ['supervoxel_id'] + scalar_columns + ['centroid_x', 'centroid_y', 'centroid_z']
        )
        for supervoxel_id in sorted(statistics.keys()):
            record = statistics[supervoxel_id]
            scalars = [record[column] for column in scalar_columns]
            stats_writer.writerow([supervoxel_id] + scalars + record['centroid_voxel'])
    print(f"Saved statistics: {stats_path}")

## Try SLIC for visual context aware supervoxels

### define a preprocessing routine to enhance SLIC results

In [318]:
from skimage import exposure
from skimage.restoration import denoise_nl_means, estimate_sigma
from scipy.ndimage import gaussian_filter
from skimage.exposure import equalize_adapthist

def preprocessing_for_SLIC(data: np.ndarray, spacing=None):
    """
    Enhance contrast and smooth a volume before running SLIC on it.

    Steps, in order: clip to the 0.5%-99.5% intensity quantiles, rescale to
    [0, 1], Gaussian blur of 1 mm along every axis, global histogram
    equalisation, then CLAHE applied to each slice along axis 0.

    Parameters
    ----------
    data : np.ndarray
        Input volume.
    spacing : sequence of float, optional
        Voxel size along each axis of ``data`` (same unit as the 1 mm blur).
        Earlier versions silently read a notebook-level ``spacing`` variable;
        when this argument is omitted that variable is still used if it
        exists, otherwise unit spacing is assumed (i.e. a 1-voxel blur).

    Returns
    -------
    np.ndarray
        The preprocessed volume.
    """
    if spacing is None:
        # Backward compatibility with the former hidden-state behaviour.
        spacing = globals().get("spacing")
    if spacing is None:
        spacing = (1.0,) * data.ndim

    # Clip extreme intensities (0.5%-99.5% quantiles) for contrast enhancement
    vmin, vmax = np.quantile(data, (0.005, 0.995))
    data = np.clip(data, vmin, vmax)
    data = exposure.rescale_intensity(data, in_range=(vmin, vmax), out_range=(0, 1))

    # Convert a 1 mm sigma into per-axis sigmas expressed in voxels
    sigma_vox = np.array([1.0 / s for s in spacing])
    data = gaussian_filter(data, sigma=sigma_vox)

    data = exposure.equalize_hist(data)

    # Apply slice-wise CLAHE for 3D volume
    data = np.stack([equalize_adapthist(slice_, clip_limit=0.03)
                       for slice_ in data], axis=0)

    return data

In [319]:
from skimage.segmentation import slic

In [320]:
# SLIC parameters: n_segments sets (approximately) how many supervoxels are wanted
def apply_SLIC(data, spacing, n_supervoxels):
    """Run SLIC on the preprocessed volume and return the label map (labels start at 0)."""
    enhanced = preprocessing_for_SLIC(data)
    return slic(
        enhanced,
        n_segments=n_supervoxels,
        compactness=0.2,
        spacing=spacing,
        start_label=0,
        max_num_iter=10,
        channel_axis=None,
    )

### Initialize predictor

In [321]:
# 2) Initialise predictors ------------------------
# Both predictors are built with the same settings and load the same checkpoint.
_PREDICTOR_KWARGS = dict(
    tile_step_size=0.5,
    use_gaussian=True,
    use_mirroring=False,  # == test time augmentation
    perform_everything_on_device=True,
    device=torch.device('cuda', 0),
    verbose=False,
    verbose_preprocessing=False,
    allow_tqdm=False,  # it interferes with the SHAP loading bar
)

morf_predictor = CustomNNUNetPredictor(**_PREDICTOR_KWARGS)
lerf_predictor = CustomNNUNetPredictor(**_PREDICTOR_KWARGS)

# initializes the network architecture, loads the checkpoint
for _each_predictor in (morf_predictor, lerf_predictor):
    _each_predictor.initialize_from_trained_model_folder(
        model_dir,
        use_folds=(0,),
        checkpoint_name='checkpoint_final.pth',
    )

# generic predictor for some general operations (baseline computation etc)
predictor = morf_predictor



### some utils for Nifti geometry

In [322]:
def get_spacing(ct_img_data: nib.nifti1.Nifti1Image):
    """
    Return the voxel spacing of the image as reported by its header zooms.

    The original author notes this is equivalent to taking the column norms
    of the affine, ``np.sqrt(np.sum(affine[:3, :3] ** 2, axis=0))``.
    """
    # get_zooms returns (x, y, z) spacing
    return ct_img_data.header.get_zooms()

def get_origin(ct_img_data: nib.nifti1.Nifti1Image):
    """Return the translation part of the image affine (first three rows of its last column)."""
    return ct_img_data.affine[:3, 3]

# Define a ROI in which to explain the segmentation
Restricting the explanation to a ROI may provide a more useful attribution map, highlighting nearby organs.

In [323]:
"""# if true, we will use a json file containing the bounding box for the Region of Interest (ROI), otherwise we will use a masked segmentation
ROI_TYPE = "BoundingBox"  # "BoundingBox" or "MaskedSegmentation"""

'# if true, we will use a json file containing the bounding box for the Region of Interest (ROI), otherwise we will use a masked segmentation\nROI_TYPE = "BoundingBox"  # "BoundingBox" or "MaskedSegmentation'

In [324]:
#BB_ROI = json.load(open(BB_ROI_paths[volume_codes[1]], "r")) if ROI_TYPE == "BoundingBox" else None

### ROI mask is a binary mask highlighting the lymphnodes of interest. We need a bounding box to crop the volume accordingly

In [325]:
import nibabel as nib
import numpy as np
from typing import Tuple

def get_mask_bbox_slices(mask_nii_path):
    """
    Compute the tightest 3D bounding box around the positive voxels of a
    binary ROI mask stored as NIfTI.

    Parameters
    ----------
    mask_nii_path : str or Path
        Path to the input binary ROI mask NIfTI (.nii or .nii.gz).

    Returns
    -------
    bbox_slices : tuple of slice
        A 3-tuple of Python slice objects (x_slice, y_slice, z_slice)
        defining the minimal bounding box.

    Raises
    ------
    ValueError
        If the mask is not 3D or contains no positive voxel.
    """
    mask_data = nib.load(str(mask_nii_path)).get_fdata()
    if mask_data.ndim != 3:
        raise ValueError("Input NIfTI must be a 3D volume")

    # One row of (x, y, z) indices per positive voxel
    positive_coords = np.argwhere(mask_data > 0)
    if positive_coords.size == 0:
        raise ValueError("No positive voxels found in mask")

    lower_corner = positive_coords.min(axis=0)
    upper_corner = positive_coords.max(axis=0)

    # Slice stops are exclusive, hence the +1 on the upper corner
    x_slice, y_slice, z_slice = (
        slice(int(lo), int(hi) + 1)
        for lo, hi in zip(lower_corner, upper_corner)
    )
    return (x_slice, y_slice, z_slice)

def get_slices_from_BB_ROI(BB_ROI: dict) -> Tuple[slice, slice, slice]:
    """
    Build (x, y, z) cropping slices from the first ROI of a bounding-box dictionary.

    Parameters
    ----------
    BB_ROI : dict
        Dictionary describing the bounding box. Only ``BB_ROI['ROIs'][0]['Min']``
        and ``BB_ROI['ROIs'][0]['Max']`` are read, e.g.
        {'ROIs': [{'ID': 0, 'Max': [370.5, 291.5, 114.5],
                   'Min': [276.5, 193.49999999999994, 92.5], ...}], ...}

    Returns
    -------
    x_slice, y_slice, z_slice : slice
        Slices whose start and stop are the ceil of Min and Max respectively.
    """
    first_roi = BB_ROI['ROIs'][0]
    lower, upper = first_roi['Min'], first_roi['Max']

    def _axis_slice(axis: int) -> slice:
        # Both bounds are rounded up (the MITK ROI is apparently shifted by 0.5 voxels)
        return slice(int(np.ceil(lower[axis])), int(np.ceil(upper[axis])))

    return _axis_slice(0), _axis_slice(1), _axis_slice(2)
   

In [326]:
import json
from pathlib import Path
from typing import Tuple, Optional
import nibabel as nib


def load_roi_slices(
    ROI_type: str,
    ROI_BB_path: Optional[str] = None,
    ROI_segmentation_mask_path: Optional[str] = None
) -> Tuple[slice, slice, slice]:
    """
    Load ROI definition and return bounding-box slices in x, y, z order.

    Parameters
    ----------
    ROI_type : {"BoundingBox", "MaskedSegmentation"}
        Type of ROI to load.
    ROI_BB_path : str, optional
        Path to JSON file defining the bounding box ROI.
    ROI_segmentation_mask_path : str, optional
        Path to NIfTI file defining the ROI segmentation mask.

    Returns
    -------
    x_slice, y_slice, z_slice : slice
        Slices along each axis for cropping operations.

    Raises
    ------
    ValueError
        If ``ROI_type`` is not supported, or the path required by the chosen
        ``ROI_type`` is missing.
    """
    if ROI_type == "BoundingBox":
        if ROI_BB_path is None:
            raise ValueError("ROI_BB_path must be provided when ROI_type is 'BoundingBox'.")
        # Context manager so the JSON file handle is always closed
        with open(ROI_BB_path, "r") as roi_file:
            BB_ROI = json.load(roi_file)
        x_slice, y_slice, z_slice = get_slices_from_BB_ROI(BB_ROI)

    elif ROI_type == "MaskedSegmentation":
        if ROI_segmentation_mask_path is None:
            raise ValueError("ROI_segmentation_mask_path must be provided when ROI_type is 'MaskedSegmentation'.")
        # Report the mask geometry before deriving the bounding box
        mask_img = nib.load(ROI_segmentation_mask_path)
        print(f"Loaded mask shape: {mask_img.shape}, affine: {mask_img.affine}")
        x_slice, y_slice, z_slice = get_mask_bbox_slices(ROI_segmentation_mask_path)

    else:
        raise ValueError(f"Unsupported ROI_type: {ROI_type}")

    return x_slice, y_slice, z_slice


### the identified region is our ROI bounding box in case we use the masked segmentation as manually derived ROI
### otherwise we just use the ROI bounding box

In [327]:
import numpy as np

def slices_to_binary_mask(volume_shape, bbox_slices, dtype=np.uint8):
    """
    Build a binary mask that is 1 inside a bounding box and 0 everywhere else.

    Parameters
    ----------
    volume_shape : tuple of int
        Shape of the full volume, e.g. (X, Y, Z).
    bbox_slices : tuple of slice
        One slice per axis, (x_slice, y_slice, z_slice), delimiting the box.
    dtype : data-type, optional
        Data type of the returned mask (default: np.uint8).

    Returns
    -------
    mask : np.ndarray
        Array of shape `volume_shape` holding ones in the region selected by
        `bbox_slices` and zeros elsewhere.

    Raises
    ------
    ValueError
        If the number of slices differs from the number of dimensions.
    """
    dims_match = len(volume_shape) == len(bbox_slices)
    if not dims_match:
        raise ValueError(f"volume_shape has {len(volume_shape)} dimensions, "
                         f"but bbox_slices has {len(bbox_slices)} slices")

    # Start from an all-zero volume, then switch on the boxed region
    roi_mask = np.zeros(volume_shape, dtype=dtype)
    roi_mask[bbox_slices] = 1
    return roi_mask

In [328]:
def create_roi_mask(
    volume_path: str,
    bbox_slices: Tuple[slice, slice, slice],
    output_path: Optional[str] = "ROI_binary_mask.nii.gz"
) -> str:
    """
    Create a binary ROI mask based on bounding-box slices and save as NIfTI.

    The mask has the same shape and affine as the input volume.

    Parameters
    ----------
    volume_path : str
        Path to the input volume NIfTI file.
    bbox_slices : tuple of slice
        Slices (x_slice, y_slice, z_slice) defining the ROI bounding box.
    output_path : str, optional
        Path where the binary mask NIfTI will be saved.

    Returns
    -------
    output_path : str
        Path to the saved binary mask file.
    """
    img = nib.load(volume_path)
    # The shape comes from the image metadata: no need to read the voxel
    # data into memory just to know its dimensions.
    volume_shape = img.shape
    affine = img.affine

    # Generate binary mask array
    mask_array = slices_to_binary_mask(
        volume_shape=volume_shape,
        bbox_slices=bbox_slices
    )

    # Create and save NIfTI mask
    save_nifty_binary(mask_array, affine, output_path)
    print(f"Saved binary ROI mask to {output_path}")

    return output_path

### to **crop** correctly the volume around the *ROI*, we need to derive the **receptive field** of the sliding window inference, that depends on the *patch size*.

In [329]:
"""patch_size = np.array(predictor.configuration_manager.patch_size)
print("Patch size: ", patch_size)

# Receptive field is twice the patch size-1
RF = 2*(patch_size-1)
print("Receptive field: ", RF)"""

'patch_size = np.array(predictor.configuration_manager.patch_size)\nprint("Patch size: ", patch_size)\n\n# Receptive field is twice the patch size-1\nRF = 2*(patch_size-1)\nprint("Receptive field: ", RF)'

### Consider the receptive field to compute the final slices for cropping
Remember that the model metadata (e.g. the patch size) refer to the transposed volume, since nnunetv2 works with arrays of shape (D, H, W).

In [330]:
def compute_rf_slices(
    bbox_slices: Tuple[slice, slice, slice],
    patch_size: np.ndarray,
    volume_shape: Tuple[int, int, int]
) -> Tuple[slice, slice, slice]:
    """
    Grow a bounding box by half the receptive field on every side, clamped
    to the volume extent.

    The receptive field is taken as RF = 2*(patch_size - 1); each slice is
    extended by RF // 2 voxels before its start and after its stop.
    ``patch_size`` is expected in [D, H, W] order, i.e. reversed with respect
    to the (x, y, z) order of ``bbox_slices`` and ``volume_shape``.
    """
    x_slice, y_slice, z_slice = bbox_slices
    W, H, D = volume_shape

    # Per-axis padding; unpacked in the model's axis order (z, y, x)
    receptive_field = 2 * (patch_size - 1)
    pad_z, pad_y, pad_x = receptive_field // 2

    def _expand(axis_slice: slice, pad, axis_len) -> slice:
        lower = max(axis_slice.start - pad, 0)
        upper = min(axis_slice.stop + pad, axis_len)
        return slice(lower, upper)

    return (
        _expand(x_slice, pad_x, W),
        _expand(y_slice, pad_y, H),
        _expand(z_slice, pad_z, D),
    )

In [331]:
import nibabel as nib
import numpy as np

def crop_volume_and_affine(nii_path, bbox_slices, save_cropped_nii_path=None):
    """
    Cut a bounding box out of a 3D NIfTI volume and shift the affine origin
    accordingly, so the crop keeps its world coordinates.

    Parameters
    ----------
    nii_path : str or Path
        Path to the input NIfTI volume (.nii or .nii.gz).
    bbox_slices : tuple of slice
        A 3-tuple (x_slice, y_slice, z_slice) as returned by get_mask_bbox_slices().
    save_cropped_nii_path : str or Path, optional
        When given, the crop is also written to this path as a new NIfTI.

    Returns
    -------
    cropped_data : np.ndarray
        The volume data restricted to the bounding box.
    new_affine : np.ndarray
        The 4x4 affine of the cropped volume.
    """
    source_img = nib.load(str(nii_path))
    source_affine = source_img.affine
    cropped_data = source_img.get_fdata()[bbox_slices]

    # Voxel (i, j, k) = (x, y, z) index of the crop's first voxel in the source grid
    x_slice, y_slice, z_slice = bbox_slices
    first_voxel = np.array([x_slice.start, y_slice.start, z_slice.start])

    # Only the translation changes: move the origin onto the crop's first voxel
    new_affine = source_affine.copy()
    new_affine[:3, 3] += source_affine[:3, :3].dot(first_voxel)

    if save_cropped_nii_path is not None:
        nib.save(nib.Nifti1Image(cropped_data, new_affine), str(save_cropped_nii_path))

    return cropped_data, new_affine


### We need a way to check original mask overlapping in the new cropped volume

### We also need a mask to correctly ignore out-of-ROI context in our aggregation metrics

In [332]:
from typing import Tuple, Optional, Dict

def crop_volume_with_rf(
    volume_path: str,
    bbox_slices: Tuple[slice, slice, slice],
    patch_size: np.ndarray,
    output_dir: Optional[str] = "."
) -> Dict[str, Tuple[np.ndarray, np.ndarray]]:
    """
    Crop the volume and binary ROI mask to the ROI padded by receptive field.

    The ROI mask is read from ``<output_dir>/ROI_binary_mask.nii.gz``, which
    must already exist. Two files are written into ``output_dir``:
    ``cropped_volume.nii.gz`` and ``cropped_mask_with_RF.nii.gz``.

    Returns a dict with keys:
      - 'cropped_volume': tuple (cropped volume array, its affine)
      - 'cropped_roi_mask': tuple (cropped ROI mask array, its affine)
    """
    volume_img = nib.load(volume_path)
    # Shape is available from the image metadata; avoids loading the voxel data here
    volume_shape = volume_img.shape

    # compute RF-expanded slices
    padded_slices = compute_rf_slices(bbox_slices, patch_size, volume_shape)

    # ensure output directory exists
    out_dir = Path(output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    # crop volume
    cropped_vol_path = out_dir / "cropped_volume.nii.gz"
    cropped_volume, affine_cropped_volume = crop_volume_and_affine(
        nii_path=volume_path,
        bbox_slices=padded_slices,
        save_cropped_nii_path=cropped_vol_path
    )

    # crop ROI mask
    roi_mask_path = out_dir / "ROI_binary_mask.nii.gz"
    cropped_mask, affine_cropped_mask = crop_volume_and_affine(
        nii_path=str(roi_mask_path),
        bbox_slices=padded_slices,
        save_cropped_nii_path=out_dir / "cropped_mask_with_RF.nii.gz"
    )

    return {
        "cropped_volume": (cropped_volume, affine_cropped_volume),
        "cropped_roi_mask": (cropped_mask, affine_cropped_mask)
    }

## We execute SHAP on this cropped image, and we only consider our ROI

### set device

In [333]:
# Prefer the GPU when one is visible, otherwise fall back to the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)

cuda


### Preprocessing (skipped for now)

In [334]:
#NNUNET_PREPROCESSING = False

In [335]:
def preprocess_volume(
    cropped_volume_path: str,
    predictor,
    dataset_json_path: str,
    use_nnunet: bool = True
) -> Tuple[np.ndarray, Optional[nib.Nifti1Image]]:
    """
    Preprocess the cropped volume for nnU-Net inference.

    If `use_nnunet` is True, applies nnU-Net's default preprocessing and also
    saves the result next to the input as ``<name>_preproc.nii.gz``.
    Otherwise, loads the volume directly and reformats axes.

    Returns
    -------
    volume_np : np.ndarray
        Channel-first array: (C, D, H, W) from the nnU-Net branch (as expected
        from `nnunetv2_default_preprocessing`), (1, D, H, W) otherwise.
    volume_nii : nib.Nifti1Image or None
        NIfTI image of the preprocessed volume in (W, H, D) order when
        `use_nnunet` is True (for debugging or saving); None otherwise.
    """
    if use_nnunet:
        # nnU-Net preprocessing utility, assumed imported
        volume_np = nnunetv2_default_preprocessing(
            cropped_volume_path,
            predictor,
            dataset_json_path
        )  # expects (C, D, H, W)
        # Save NIfTI for debugging
        # NOTE(review): squeeze() yields a 3D array only for single-channel input - confirm
        volume_transposed = volume_np.squeeze().transpose(2, 1, 0)  # back to (W,H,D)
        affine = nib.load(cropped_volume_path).affine

        # Build "<name>_preproc.nii.gz" beside the input. Path.replace() renames
        # files and has no `suffix` argument, so the name is derived by hand.
        source_path = Path(cropped_volume_path)
        base_name = source_path.name
        for extension in (".nii.gz", ".nii"):
            if base_name.endswith(extension):
                base_name = base_name[: -len(extension)]
                break
        preproc_path = source_path.with_name(base_name + "_preproc.nii.gz")
        save_nifty(volume_transposed, affine, str(preproc_path))

        # Previously left undefined on this branch, which made the return fail
        volume_nii = nib.Nifti1Image(volume_transposed, affine)

    else:
        cropped = nib.load(cropped_volume_path)
        data = cropped.get_fdata()
        data = np.transpose(data, (2, 1, 0))  # (D,H,W)
        data = np.expand_dims(data, axis=0)   # (1, D,H,W)
        volume_np = data
        volume_nii = None

    return volume_np, volume_nii

### Supervoxels subdivision

In [336]:
def generate_supervoxel_map(
    cropped_volume_path: str,
    supervoxel_type: str = "FCC",
    fcc_cube_side: float = 100.0,
    slic_n_supervoxels: int = 380,
    organ_map_path: Optional[str] = None,
    just_load: Optional[bool] = False
) -> np.ndarray:
    """
    Generate a supervoxel map for the cropped volume.
    Supports:
      - "full-organs": 1 organ = 1 supervoxel, simple as that
      - "FCC": standard face-centered cubic
      - "SLIC": simple linear iterative clustering
      - "FCC-organs": FCC constrained by organ presence

    When `just_load` is True nothing is computed: the map previously saved as
    "supervoxel_map.nii.gz" (in the current working directory) is read back.
    Otherwise the freshly computed map is relabelled to consecutive integers
    and saved to that same file.

    Returns an array of shape (D, H, W) with consecutive integer labels.

    Raises:
        ValueError: for an unknown `supervoxel_type`, or when `organ_map_path`
            is missing for a type that needs it.
    """
    if just_load:
        # get_fdata() always yields floats; restore the integer labels
        stored_labels = nib.load("supervoxel_map.nii.gz").get_fdata()
        sv_map = np.rint(stored_labels).astype(np.int64)
    else:
        img = nib.load(cropped_volume_path)

        if supervoxel_type == "full-organs":
            if organ_map_path is None:
                raise ValueError("organ_map_path must be provided for 'full-organs' type")
            # load organ segmentation map
            organ_img = nib.load(organ_map_path)
            # the supervoxel map is just the organ map
            sv_map = organ_img.get_fdata().astype(np.float32)
        elif supervoxel_type == "FCC":
            sv_map = generate_FCC_supervoxel_map(img, S=fcc_cube_side)
        elif supervoxel_type == "SLIC":
            # Only SLIC needs the intensities, so the voxel data is loaded here
            data = img.get_fdata().astype(np.float32)
            # Spacing lives in the image header, not in the data array
            spacing = get_spacing(img)
            sv_map = apply_SLIC(data, spacing, slic_n_supervoxels)
        elif supervoxel_type == "FCC-organs":
            if organ_map_path is None:
                raise ValueError("organ_map_path must be provided for 'FCC-organs' type")
            # load organ segmentation map
            organ_img = nib.load(organ_map_path)
            sv_map, organ_table, statistics = generate_FCC_organs_supervoxel_map(
                volume_img=img, organ_img=organ_img, S=fcc_cube_side
            )
        else:
            raise ValueError(f"Unsupported supervoxel_type: {supervoxel_type}")

        # remap labels to consecutive ints
        _, inverse = np.unique(sv_map, return_inverse=True)
        sv_map = inverse.reshape(sv_map.shape)

        # debug: persist the map so it can be inspected or reloaded via just_load
        image = nib.Nifti1Image(sv_map.astype(np.float32), img.affine)
        nib.save(image, "supervoxel_map.nii.gz")
        print("supervoxel map saved!")

    # reorder axes (W,H,D) -> (D,H,W)
    sv_map = np.transpose(sv_map, (2, 1, 0))
    return sv_map

### derive baseline cached dictionary

In [337]:
def get_cached_output_dictionary(volume_file: Path,
                                 predictor: CustomNNUNetPredictor,
                                 preprocess_before_run: bool = True,
                                verbose: bool = False) -> dict:
    """
    Run the predictor on every sliding-window patch of the volume and return
    the dictionary produced by `predictor.get_output_dictionary_sliding_window`
    (per the original description: indexed by the sliding-window slices, holding
    the inference output of each patch).

    `verbose` is accepted but not used by this function.
    """
    reader = predictor.plans_manager.image_reader_writer_class()

    # If nnU-Net returns a class instead of an instance, instantiate it
    if callable(reader) and not hasattr(reader, "read_images"):
        reader = reader()

    # presumably (C, Z, Y, X), as noted by the original author
    raw_image, raw_props = reader.read_images([str(volume_file)])

    if not preprocess_before_run:
        network_input = torch.from_numpy(raw_image)
    else:
        preprocessor = predictor.configuration_manager.preprocessor_class()
        # NOTE (original author): this call caused the kernel to die at the first notebook run
        preprocessed, _, _ = preprocessor.run_case_npy(
            raw_image,
            seg=None,
            properties=raw_props,
            plans_manager=predictor.plans_manager,
            configuration_manager=predictor.configuration_manager,
            dataset_json=predictor.dataset_json
        )
        # to torch, channel-first is already true
        network_input = torch.from_numpy(preprocessed)

    # Slicers are computed on the spatial dimensions only (channel axis dropped)
    window_slicers = predictor._internal_get_sliding_window_slicers(network_input.shape[1:])

    return predictor.get_output_dictionary_sliding_window(network_input, window_slicers)

In [338]:
#USE_STORED_DICTIONARY = False

In [339]:
def compute_baseline_prediction(
    volume_tensor: "torch.Tensor",
    predictor,
    cropped_volume_path: str,
    use_nnunet: bool = True
) -> Tuple["torch.Tensor", Dict]:
    """
    Produce the unperturbed segmentation mask and the cached per-patch outputs.

    Parameters:
    - volume_tensor: batched input; only its first element is fed to the predictor
    - predictor: object exposing ``predict_sliding_window_return_logits``
    - cropped_volume_path: path of the volume file used to build the cache
    - use_nnunet: forwarded as ``preprocess_before_run`` when building the cache

    Returns:
    - (seg_mask, cache_dict): boolean mask of the voxels whose arg-max class is 1,
      and the dictionary returned by ``get_cached_output_dictionary``
    """
    # Sliding-window logits for the single volume in the batch
    baseline_logits = predictor.predict_sliding_window_return_logits(volume_tensor[0])

    # Voxels where class 1 wins the arg-max over the class axis
    winning_class = torch.argmax(baseline_logits, dim=0)
    foreground_mask = winning_class == 1

    # The per-patch cache is computed from the file on disk, not from volume_tensor
    return foreground_mask, get_cached_output_dictionary(
        volume_file=cropped_volume_path,
        predictor=predictor,
        preprocess_before_run=use_nnunet,
        verbose=True,
    )

### Include masking in the forward function

### **Chrabaszcz aggregation**  

Let

* $z_1^{(i)}(x)$ – class-1 logit at voxel $x$ after perturbation *i*
* $P_i(x)=\mathbf 1\!\left[\arg\max_c z_c^{(i)}(x)=1\right]$ – binary mask of voxels currently predicted as lymph-node
* no ROI, total volume considered

$$
S_{\text{Chr}}^{(i)} \;=\;
\frac{1}{\alpha}\sum_{x} P_i(x)\;z_1^{(i)}(x)
$$

where $\alpha$ is the `scaling_factor`.

* **Counts evidence only from voxels the model *currently* labels as class 1.**
* *False positives (FP):* contribute **positively** (they are in $P_i$).
* *False negatives (FN):* contribute **zero** (their logit is absent).

Source: Chrabaszcz et al., *Aggregated Attributions for Explanatory Analysis of 3-D Segmentation Models*, 2024.


In [340]:
def chrabaszcz_aggregation(logits: torch.Tensor,
                           scaling_factor: float = 1.0,
                          ) -> torch.Tensor:
    """
    Sum the class-1 logits over the voxels currently predicted as class 1, as proposed in
    "Chrabaszcz et al. - 2024 - Aggregated Attributions for Explanatory Analysis of 3D
    Segmentation Models".

    Parameters:
    - logits: tensor with the class axis first; index 1 is the class of interest
    - scaling_factor: divisor applied to the sum

    Returns:
    - 0-dim float64 tensor holding the scaled sum
    """
    # voxels where class 1 wins the arg-max over the class axis
    predicted_class1 = logits.argmax(dim=0).eq(1)

    # accumulate in double precision; the boolean mask zeroes every other voxel
    class1_logits = logits[1].double()
    masked_total = (class1_logits * predicted_class1).sum()

    # scaled down to avoid overflows in SHAP
    return masked_total / scaling_factor


### **True positive aggregation** (Chrabaszcz aggregation + baseline-mask filtering)

Introduce the unperturbed prediction $P_0$, and, optionally, the ROI mask $R(x)$.
Keep only voxels that are **still** class 1 *and* were class 1 before:

$$
S_{\text{Chr\,keep}}^{(i)} \;=\;
\frac{1}{\alpha}\sum_{x} \bigl[P_i(x)\land P_0(x)\land R(x)\bigr]\;z_1^{(i)}(x)
$$

* **True positives preserved** (TP core) add positive evidence.
* **FP created by the perturbation** are **ignored** (masked out).
* **FN** lower the score indirectly because their logits disappear from the sum.

Conceptually this is the **positive part** of a signed logit-difference metric.

In [341]:
def true_positive_aggregation(logits: torch.Tensor,
                          unperturbed_binary_mask: torch.Tensor,
                          ROI_mask: torch.Tensor,
                           scaling_factor: float = 1.0,
                          ) -> torch.Tensor:
    """
    Chrabaszcz-style logit sum (see "Chrabaszcz et al. - 2024 - Aggregated Attributions for
    Explanatory Analysis of 3D Segmentation Models") restricted to voxels that are class 1
    now, were class 1 in the unperturbed segmentation, and lie inside the ROI.

    Newly created positives are therefore ignored, so this is conceptually the positive
    part of a logit-difference metric.

    Parameters:
    - logits: tensor with the class axis first; index 1 is the class of interest
    - unperturbed_binary_mask: baseline segmentation, interpreted as boolean
    - ROI_mask: region of interest, interpreted as boolean
    - scaling_factor: divisor applied to the sum

    Returns:
    - 0-dim float64 tensor holding the scaled sum
    """
    currently_positive = logits.argmax(dim=0).eq(1)
    preserved = currently_positive.bool() & ROI_mask.bool() & unperturbed_binary_mask.bool()

    # boolean indexing is preferred here because the selection is relatively sparse
    selected_logits = logits[1].double()[preserved]

    # scaled down to avoid overflows in SHAP
    return selected_logits.sum() / scaling_factor

---

### **False-positive aggregation**

Directly sum class-1 evidence from **new** positives inside ROI:

$$
S_{\text{FP}}^{(i)} \;=\;
-\frac{1}{\alpha}\sum_{x}
\bigl[P_i(x)\land\neg P_0(x)\land R(x)\bigr]\;z_1^{(i)}(x)
$$

* Measures **only** the spurious lymph-node evidence a perturbation introduces.
* Because of the leading minus sign, a more negative score (larger magnitude) ⇒ stronger tendency to hallucinate extra nodes.

In [342]:
def false_positive_aggregation(logits: torch.Tensor,
                              unperturbed_binary_mask: torch.Tensor,
                              ROI_mask: torch.Tensor,
                              scaling_factor: float = 1.0,
                              ) -> torch.Tensor:
    """
    Negated sum of the class-1 logits over false-positive voxels (spurious lymph nodes):
    voxels inside the ROI that are class 1 now but were not class 1 in the unperturbed
    segmentation. This is the negative part of the signed logit-difference objective.

    Parameters:
    - logits: tensor with the class axis first; index 1 is the class of interest
    - unperturbed_binary_mask: baseline segmentation, interpreted as boolean
    - ROI_mask: region of interest, used as a multiplicative weight
    - scaling_factor: divisor applied to the result

    Returns:
    - 0-dim tensor: minus the summed class-1 logits of the false positives, scaled
    """
    # prevailing class of the current prediction
    currently_positive = torch.argmax(logits, dim=0) == 1

    # 1.0 where the baseline did NOT predict class 1
    absent_in_baseline = (~unperturbed_binary_mask.bool()).float()

    # float multiplication is preferred here because the tensors are dense
    fp_weight = currently_positive * ROI_mask * absent_in_baseline
    fp_evidence = (logits[1].double() * fp_weight).sum()

    # sign flipped so that false positives yield negative scores
    return -fp_evidence / scaling_factor

---

### **Dice aggregation (prediction-vs-baseline, ROI-restricted)**

Let $R(x)$ be the ROI mask.

$$
P_i' = P_i \odot R, \qquad
P_0' = P_0 \odot R
$$

$$
S_{\text{Dice}}^{(i)} \;=\;
\frac{1}{\alpha}\;
\frac{2\,\langle P_i',\,P_0'\rangle}{\lVert P_i'\rVert_1 + \lVert P_0'\rVert_1 + \varepsilon}
$$

* Drops when either **FP** ($P_i'=1,\,P_0'=0$) or **FN** ($P_i'=0,\,P_0'=1$) appear → penalises both error types symmetrically.

Based on the “self-consistency Dice” used in MiSuRe (Hasany et al., 2024).





In [343]:
def dice_aggregation(logits: torch.Tensor,
                    unperturbed_binary_mask: torch.Tensor,
                    ROI_mask: torch.Tensor,
                    scaling_factor: float = 1.0,
                    eps: float = 1e-9,  # small value to avoid division by zero
                    ) -> torch.Tensor:
    """
    Dice overlap between the current class-1 prediction and the unperturbed segmentation,
    both restricted to the ROI. This is the aggregation measure used in "Hasany et al. -
    2024 - MiSuRe is all you need to explain your image segmentation".

    A single score penalises both false negatives and false positives, so supervoxels
    that contribute most to reproducing the baseline segmentation receive a higher score.

    Parameters:
    - logits: tensor with the class axis first; index 1 is the class of interest
    - unperturbed_binary_mask: baseline segmentation
    - ROI_mask: region of interest, applied to both masks
    - scaling_factor: divisor applied to the Dice value
    - eps: added to the denominator to avoid division by zero

    Returns:
    - 0-dim tensor holding Dice / scaling_factor
    """
    roi = ROI_mask.float()

    # both masks as floats, zeroed outside the ROI
    prediction = (logits.argmax(dim=0) == 1).float() * roi
    baseline = unperturbed_binary_mask.float() * roi

    overlap = (prediction * baseline).sum()
    total_mass = prediction.sum() + baseline.sum() + eps       # |P| + |B|

    return ((2.0 * overlap) / total_mass) / scaling_factor

---

### **Signed logit-difference (masked)**

Define a signed weight

$$
w(x)=
\begin{cases}
+1 & \text{if } P_0(x)=1\\
-1 & \text{otherwise}
\end{cases},
\qquad w(x)\leftarrow w(x)\,R(x)
$$

$$
S_{\text{LD}}^{(i)} \;=\;
\frac{1}{\alpha}\sum_{x} P_i(x)\;w(x)\;z_1^{(i)}(x)
$$

* **Positive attribution:** voxels that *keep* the baseline TP (support segmentation).
* **Negative attribution:** voxels that become class 1 **only** after perturbation (generate FP inside ROI).
* FN reduce the positive term (logits disappear) but do **not** add negative mass.

In [344]:


def logit_difference_aggregation(
        logits: torch.Tensor,
        unperturbed_binary_mask: torch.Tensor,
        ROI_mask: torch.Tensor,
        scaling_factor: float = 1.0,
) -> torch.Tensor:
    """
    Signed logit-difference objective *masked by the prevailing class*.

    Class-1 logits of the voxels currently predicted as class 1 are summed with weight
    +1 where the unperturbed segmentation was also class 1 and -1 elsewhere; the weight
    is multiplied by the ROI mask, so voxels outside the ROI contribute nothing.

    Parameters:
    - logits: tensor with the class axis first; index 1 is the class of interest
    - unperturbed_binary_mask: baseline segmentation, interpreted as boolean
    - ROI_mask: region of interest, used as a multiplicative weight
    - scaling_factor: divisor applied to the sum

    Returns:
    - 0-dim float64 tensor holding the scaled signed sum
    """
    # current segmentation (prevailing class)
    seg_mask = (torch.argmax(logits, dim=0) == 1)

    # +1 inside baseline positives, -1 elsewhere...
    signed_weight = torch.where(unperturbed_binary_mask.bool(),
                                torch.tensor(1.0, device=logits.device),
                                torch.tensor(-1.0, device=logits.device))

    # ... but only false positives inside the ROI matter (there is no segmentation
    # mask available outside the ROI)
    signed_weight = signed_weight * ROI_mask

    # Aggregate the signed class-1 evidence over the voxels *currently* predicted as
    # class 1. The logits are promoted to float64 first, as the other logit-sum
    # aggregations in this notebook do, so the sum is accumulated in double precision.
    aggregate = torch.sum(logits[1].double() * seg_mask * signed_weight)
    return aggregate / scaling_factor

### Define a function to prepare all the steps for SHAP

In [345]:
from pathlib import Path
from typing import Tuple, Optional, Dict

import numpy as np
import nibabel as nib
import torch

def prepare_data_for_shap(
    volume_path: str,
    patch_size: np.ndarray,
    predictor,
    dataset_json_path: str,
    ROI_BB_path: Optional[str] = None,
    ROI_segmentation_mask_path: Optional[str] = None,
    ROI_type: str = "BoundingBox",
    nnunet_preprocessing: bool = True,
    supervoxel_type: str = "FCC",
    fcc_supervoxel_size: Optional[float] = None,
    slic_n_supervoxels: Optional[int] = None,
    organ_map_path: Optional[str] = None,
    load_stored_sv_map: Optional[bool] = False, # Forbidden for multiple (different) volumes, only for debugging single volume
    device: str = 'cuda:0'
) -> Tuple[torch.Tensor, np.ndarray, torch.Tensor, torch.Tensor, torch.Tensor, Dict]:
    """
    Orchestrate data preparation by invoking modular utilities.

    Parameters:
    - volume_path: path of the input volume
    - patch_size: patch size forwarded to the receptive-field cropping helpers
    - predictor: predictor used for preprocessing and for the baseline prediction
    - dataset_json_path: forwarded to ``preprocess_volume``
    - ROI_BB_path / ROI_segmentation_mask_path / ROI_type: forwarded to ``load_roi_slices``
    - nnunet_preprocessing: forwarded as ``use_nnunet`` to preprocessing and baseline prediction
    - supervoxel_type / fcc_supervoxel_size / slic_n_supervoxels: forwarded to
      ``generate_supervoxel_map``
    - organ_map_path: optional organ map, cropped with the same padded slices as the volume
    - load_stored_sv_map: forwarded as ``just_load``; only meant for debugging a single volume
    - device: target device of the returned tensors (this parameter, not the notebook-level
      ``device`` variable, is what is used inside the function)

    Returns:
      - volume_tensor (1,C,D,H,W)
      - affine of cropped volume
      - supervoxel_map_tensor (D,H,W)
      - segmentation_mask_tensor (D,H,W)
      - ROI_mask_tensor (D,H,W)
      - cache_dict

    Side effects:
      The helpers exchange data through fixed file names in the current working directory
      ('cropped_volume.nii.gz' and, when an organ map is given,
      'cropped_organ_map_with_RF.nii.gz').
    """
    # 1. Load ROI slices
    x_slice, y_slice, z_slice = load_roi_slices(
        ROI_type,
        ROI_BB_path,
        ROI_segmentation_mask_path
    )
    roi_slices = (x_slice, y_slice, z_slice)

    # 2. Create binary ROI mask.
    # The returned path is not needed here; the call is kept because it presumably
    # writes the mask that the cropping step reads -- TODO confirm.
    create_roi_mask(volume_path, roi_slices)

    # 3. Compute receptive-field-padded slices.
    # The shape is read from the image object so the voxel data is not loaded into
    # memory just to obtain its dimensions.
    vol_img = nib.load(volume_path)
    padded_slices = compute_rf_slices(
        roi_slices,
        patch_size,
        vol_img.shape
    )

    # 4. Crop volume and ROI mask using RF slices; each entry is a (data, affine) pair
    cropped = crop_volume_with_rf(
        volume_path,
        roi_slices,
        patch_size
    )
    _, affine_cropped = cropped['cropped_volume']
    cropped_mask, _ = cropped['cropped_roi_mask']

    # File the following steps read the cropped volume from
    # (presumably written by crop_volume_with_rf -- TODO confirm)
    cropped_vol_path = Path('cropped_volume.nii.gz')

    # 5. If provided, also crop the organs map with the same RF slices
    cropped_organs_path = None
    if organ_map_path is not None:
        cropped_organs_path = 'cropped_organ_map_with_RF.nii.gz'
        crop_volume_and_affine(
            organ_map_path,
            padded_slices,
            Path(cropped_organs_path)
        )

    # 6. Preprocess volume
    volume_np, _ = preprocess_volume(
        str(cropped_vol_path),
        predictor,
        dataset_json_path,
        use_nnunet=nnunet_preprocessing
    )

    # 7. Convert to torch tensor (float32, with a leading batch axis)
    volume_tensor = torch.from_numpy(
        volume_np.astype(np.float32)
    ).unsqueeze(0).to(device)

    # 8. Generate supervoxel map, passing cropped_organs_path for FCC-organs
    sv_array = generate_supervoxel_map(
        str(cropped_vol_path),
        supervoxel_type=supervoxel_type,
        fcc_cube_side = fcc_supervoxel_size,
        slic_n_supervoxels = slic_n_supervoxels,
        organ_map_path=cropped_organs_path,
        just_load=load_stored_sv_map
    )
    sv_tensor = torch.from_numpy(sv_array).long().to(device)

    # 9. Compute baseline segmentation and cache
    seg_mask, cache_dict = compute_baseline_prediction(
        volume_tensor,
        predictor,
        str(cropped_vol_path),
        use_nnunet=nnunet_preprocessing
    )

    # 10. Prepare ROI mask tensor: axes are reversed so the mask matches the
    #     axis order of the other tensors
    roi_mask_tensor = torch.from_numpy(
        np.transpose(cropped_mask, (2, 1, 0))
    ).to(device)

    return (
        volume_tensor,
        affine_cropped,
        sv_tensor,
        seg_mask.to(device),
        roi_mask_tensor,
        cache_dict
    )

## Iterate over volumes and run SHAP

In [346]:
# Master switch for Weights & Biases logging: the cells below check it before calling wandb.
WANDB = True

In [347]:
import os
import wandb

# avoid console spam (only sets the variable when it is not already defined)
os.environ.setdefault("WANDB_SILENT", "true")

def wandb_login(project=None, entity=None):
    """
    Log in to Weights & Biases and optionally start a run.

    The API key is taken from the ``WANDB_API_KEY`` environment variable; when that is
    missing or empty, the Kaggle secret named ``wandb_key`` is used instead.

    Parameters:
    - project: W&B project name (optional)
    - entity: W&B entity name (optional)

    Returns:
    - the object returned by ``wandb.init`` when both ``project`` and ``entity`` are
      given, otherwise None

    Raises:
    - RuntimeError: when no key is set in the environment and the Kaggle lookup fails
    """
    # 1) try the environment variable / .env first
    api_key = os.environ.get("WANDB_API_KEY")

    # 2) fall back to kaggle_secrets when nothing was found
    if not api_key:
        try:
            from kaggle_secrets import UserSecretsClient
            api_key = UserSecretsClient().get_secret("wandb_key")
        except Exception:
            raise RuntimeError("WANDB_API_KEY non trovata né in .env né in Kaggle secrets")

    # login
    wandb.login(key=api_key, relogin=True)

    # optional: start a run, only when both identifiers were supplied
    if project is None or entity is None:
        return None
    return wandb.init(project=project, entity=entity, settings=wandb.Settings(silent=True))


In [348]:
# Configure W&B only when logging is enabled.
if WANDB:
    os.environ["WANDB_SILENT"] = "true"
    # NOTE(review): `settings` is not passed to anything in this cell -- confirm whether it is used elsewhere
    settings=wandb.Settings(silent=True)  # no console spam
    
    # called without project/entity, so this only logs in and does not start a run
    wandb_login()

    # enable logging on both predictors, with a distinct label prefix for each curve
    morf_predictor.set_wandb_logging(wandb_logging=WANDB, wandb_label="MoRF_", wandb_commit=False)
    lerf_predictor.set_wandb_logging(wandb_logging=WANDB, wandb_label="LeRF_", wandb_commit=False)

In [349]:
# Path of the dataset.json file inside the model directory
dataset_json_path = Path(model_dir) / "dataset.json"

In [350]:
# Patch size taken from the MoRF predictor's configuration, as a NumPy array
patch_size = np.array(morf_predictor.configuration_manager.patch_size)

In [351]:
from typing import Callable

@torch.inference_mode()
def forward_segmentation_output_to_explain(
        input_image:         torch.Tensor,
        perturbation_mask:   torch.BoolTensor | None,
        segmentation_mask:      torch.Tensor,   # remember that must be cropped to the same size of the other tensors
        ROI_bounding_box_mask:      torch.Tensor,
        baseline_prediction_dict: dict,
        predictor: CustomNNUNetPredictor,
        aggregation_fn: Callable = true_positive_aggregation,
) -> torch.Tensor:           # returns a scalar per sample
    """
    Run the cached sliding-window prediction on a (perturbed) input and collapse the
    resulting logits into one scalar with ``aggregation_fn``.

    Parameters:
    - input_image: volume handed to the predictor
    - perturbation_mask: mask of the perturbed voxels (may be None), forwarded to the predictor
    - segmentation_mask: unperturbed segmentation, passed as ``unperturbed_binary_mask``
    - ROI_bounding_box_mask: ROI mask, passed as ``ROI_mask``
    - baseline_prediction_dict: cached baseline outputs forwarded to the predictor
    - predictor: object exposing ``predict_sliding_window_return_logits_with_caching``
    - aggregation_fn: called with the keyword arguments ``logits``,
      ``unperturbed_binary_mask``, ``ROI_mask`` and ``scaling_factor``, so it must
      accept all four

    Returns:
    - whatever ``aggregation_fn`` returns (a scalar score for this sample)
    """
    perturbed_logits = predictor.predict_sliding_window_return_logits_with_caching(
        input_image, perturbation_mask, baseline_prediction_dict,
    )                              # (C, D, H, W)

    # Logit sums are divided by the number of voxels; Dice is already a ratio,
    # so it is left unscaled.
    depth, height, width = perturbed_logits.shape[1:]
    if aggregation_fn != dice_aggregation:
        scale = depth * height * width
    else:
        scale = 1.0

    # masking by the prevailing class and by the ROI happens inside the aggregation
    return aggregation_fn(
        logits=perturbed_logits,
        unperturbed_binary_mask=segmentation_mask,
        ROI_mask=ROI_bounding_box_mask,
        scaling_factor=scale,
    )

## Evaluate attribution maps using ABPC (Area Between Perturbation Curves)

In [352]:
def compute_voxel_volume_mm3(nifti_volume) -> float:
    """
    Compute the volume of a single voxel in mm^3 from a NIfTI image.

    Parameters:
    - nifti_volume: nib.Nifti1Image, the input NIfTI image

    Returns:
    - voxel_volume_mm3: float, volume of a single voxel in mm^3

    Raises:
    - ValueError: if the header's spatial unit is not one of the supported labels
    """
    # Conversion factors from each supported spatial unit to millimetres.
    # nibabel labels the NIfTI spatial units 'meter', 'mm' and 'micron'; the short
    # labels 'cm', 'm' and 'um' are kept because earlier versions accepted them.
    unit_to_mm = {
        'mm': 1.0,
        'cm': 10.0,
        'm': 1000.0,
        'meter': 1000.0,
        'um': 1e-3,
        'micron': 1e-3,
    }

    header = nifti_volume.header
    spacing = header.get_zooms()        # voxel spacing along each axis
    unit = header.get_xyzt_units()[0]   # spatial unit

    if unit not in unit_to_mm:
        raise ValueError(f"Unsupported spatial unit: {unit}")

    factor = unit_to_mm[unit]
    if factor != 1.0:
        # already-in-mm spacings are left untouched
        spacing = tuple(s * factor for s in spacing)

    # only the first three (spatial) spacings enter the volume
    voxel_volume_mm3 = spacing[0] * spacing[1] * spacing[2]
    return voxel_volume_mm3


def get_total_volume(feature_map: np.ndarray, voxel_volume_mm3: float) -> tuple[int, float]:
    """
    Compute the total volume of a tensor in voxels and in mm^3.

    Parameters:
    - feature_map: np.ndarray with shape (Z, Y, X), integer-labeled regions
      (only its shape is used)
    - voxel_volume_mm3: float, volume of a single voxel in mm^3

    Returns:
    - volume_voxels: int, total volume in voxels
    - total_volume_mm: float, total volume in mm^3
    """
    # number of voxels = product of the array dimensions
    volume_voxels = np.prod(feature_map.shape)
    total_volume_mm = volume_voxels * voxel_volume_mm3
    return volume_voxels, total_volume_mm

def get_supervoxel_volumes(feature_map: np.ndarray, voxel_volume_mm3: float, include_background: bool = True) -> tuple[np.ndarray, np.ndarray]:
    """
    Compute the volume of each supervoxel in the feature map.

    Parameters:
    - feature_map: np.ndarray with shape (Z, Y, X), integer-labeled regions
    - voxel_volume_mm3: float, volume of a single voxel in mm^3
    - include_background: bool, if False, excludes region ID 0 from the result

    Returns:
    - supervoxel_volumes: tuple of two arrays of shape (num_features,):
      the first holds the number of voxels per region, the second the volume in mm^3.
      Entries follow the sorted region IDs found in ``feature_map``.
    """
    # A single pass yields both the sorted region IDs and their voxel counts,
    # instead of scanning the whole map once per region.
    region_ids, voxel_counts = np.unique(feature_map, return_counts=True)
    if not include_background:
        voxel_counts = voxel_counts[region_ids != 0]

    return voxel_counts, voxel_counts * voxel_volume_mm3

In [353]:
# first we need to retrieve the interpretable feature attribution vector from feature map + attribution map
def get_attribution_vector(feature_map: np.ndarray, attribution_map: np.ndarray, include_background: bool = True):
    """
    Recover the original interpretable feature attribution vector from a dense attribution map,
    assuming the attribution is constant within each labeled region of the feature map.

    Parameters:
    - feature_map: np.ndarray with shape (Z, Y, X), integer-labeled regions
    - attribution_map: np.ndarray with same shape, values are repeated per region
    - include_background: bool, if False, excludes region ID 0 from the result

    Returns:
    - attribution_vector: np.ndarray of shape (num_features,), one attribution per region,
      ordered by sorted region ID
    """
    if attribution_map.shape != feature_map.shape:
        # Shapes are documented to match; when they do not, keep the per-region
        # boolean-indexing path so behaviour for such inputs is unchanged.
        features = np.unique(feature_map)
        if not include_background:
            features = features[features != 0]
        return np.array([
            attribution_map[feature_map == f][0]
            for f in features
        ])

    # One pass gives, for every region ID, the flat (C-order) index of its first voxel;
    # any voxel would do because the attribution is constant within a region.
    region_ids, first_flat_index = np.unique(feature_map, return_index=True)
    if not include_background:
        first_flat_index = first_flat_index[region_ids != 0]

    return attribution_map.ravel()[first_flat_index]


# define the method that given an attribution map in the interpretable space, return
# two generators:
# the MoRF (Most Relevant First) generator, returning the interpretable feature vectors sorted by relevance order
# the LeRF (Least Relevant First) generator, returning the interpretable feature vectors sorted by inverse relevance order

def generate_perturbations_ABPC(attribution_vector):
    """
    Build the MoRF and LeRF perturbation sequences for an attribution vector.

    Parameters:
    - attribution_vector: one attribution value per interpretable feature

    Returns:
    - (morf_generator, lerf_generator): each yields ``len(attribution_vector) + 1``
      float32 tensors. The first has every feature on (1); every following one has
      one more feature switched off (0) -- most relevant first for MoRF, least
      relevant first for LeRF.
      Every yielded tensor owns its data, so it stays valid after the generator advances.
    """
    # feature indices by decreasing attribution (ties keep their original order)
    sorted_features = sorted(range(len(attribution_vector)), key=lambda i: attribution_vector[i], reverse=True)

    def _removal_sequence(removal_order):
        """Yield the all-on vector, then one vector per feature removed in `removal_order`."""
        vector = np.ones_like(attribution_vector)
        # The buffer is copied before wrapping: for float32 input `.float()` would
        # otherwise return a view of `vector`, and later in-place edits would
        # silently change tensors that were already yielded.
        yield torch.from_numpy(vector.copy()).float()
        for i in removal_order:
            vector[i] = 0
            yield torch.from_numpy(vector.copy()).float()

    # MoRF removes the most relevant features first, LeRF the least relevant first
    return _removal_sequence(sorted_features), _removal_sequence(reversed(sorted_features))


def generate_perturbations_ABPC_with_volumes(attribution_vector, volume_vector):
    """
    Generate perturbation sequences with volume tracking.

    Parameters:
    - attribution_vector: one attribution value per interpretable feature
    - volume_vector: pair of per-feature sequences, indexed as ``volume_vector[0][i]``
      (voxel count) and ``volume_vector[1][i]`` (volume in mm^3)

    Returns:
    - morf_generator: generator yielding (binary_vector, cumulative_volume_removed)
    - lerf_generator: generator yielding (binary_vector, cumulative_volume_removed)

    ``cumulative_volume_removed`` is a float array of shape (2,): (voxel count, mm^3).
    Every yielded tensor and array is an independent copy, so items can be collected
    in a list without being overwritten by later steps.
    """
    # feature indices by decreasing attribution (ties keep their original order)
    sorted_indices = sorted(range(len(attribution_vector)), key=lambda i: attribution_vector[i], reverse=True)

    def _removal_sequence(removal_order):
        """Yield the all-on state, then one state per feature removed in `removal_order`."""
        vector = np.ones_like(attribution_vector)
        cumulative_volume = np.zeros((2,))  # (voxel count, mm^3)
        # Copies are yielded because both buffers are updated in place below.
        yield torch.from_numpy(vector.copy()).float(), cumulative_volume.copy()

        for i in removal_order:
            vector[i] = 0
            cumulative_volume[0] += volume_vector[0][i]
            cumulative_volume[1] += volume_vector[1][i]
            yield torch.from_numpy(vector.copy()).float(), cumulative_volume.copy()

    # MoRF removes the most relevant features first, LeRF the least relevant first
    return _removal_sequence(sorted_indices), _removal_sequence(reversed(sorted_indices))

In [354]:
# define the perturbator function that given the current binary input, the original input volume and the feature map
# returns the perturbed volume and the perturbation map

# (from Custom Captum Lime/KernelSHAP)
def default_from_interp_rep_transform(curr_sample, original_input, feature_map, baseline, attribute_background=True):
    """
    Turn a binary interpretable sample into a perturbed input
    (adapted from the custom Captum Lime/KernelSHAP code).

    Parameters:
    - curr_sample: 1-D tensor with one entry per interpretable feature; non-zero keeps
      the feature, zero replaces it with the baseline
    - original_input: tensor to perturb; the feature map must broadcast against it
    - feature_map: integer-labeled tensor assigning each voxel to a region
    - baseline: value (or broadcastable tensor) written where a feature is switched off
    - attribute_background: if True, label ``k`` maps to ``curr_sample[k]``; if False,
      label ``k`` maps to ``curr_sample[k - 1]`` and label 0 is never perturbed

    Returns:
    - (perturbed_input, perturbation_mask): the perturbed tensor, and the mask of
      replaced voxels (1 where the baseline was used) in ``original_input``'s dtype
      with an extra leading axis
    """

    def _build_keep_mask(labels: torch.Tensor, sample_vec: torch.Tensor, attribute_background: bool) -> torch.BoolTensor:
        """
        Returns a boolean mask 'keep':
            True  -> take from original_inputs
            False -> take from baselines
        Labels that do not correspond to any entry of `sample_vec` are kept.
        """
        M = sample_vec.shape[0]
        sample_vec = sample_vec.bool()

        if attribute_background:
            # interpretable idx = input label directly
            valid = (labels >= 0) & (labels < M)
            idx   = torch.clamp(labels, 0, M-1)
        else:
            # interpretable idx i -> input label i+1
            # background (0) is never perturbed
            valid = (labels >= 1) & (labels <= M)
            idx   = (labels - 1).clamp(min=0, max=M-1)

        # indices are clamped only so the lookup below is always in range;
        # `valid` decides which voxels actually receive a looked-up value
        keep = torch.ones_like(labels, dtype=torch.bool)
        keep[valid] = sample_vec[idx[valid]]
        return keep

    # The sample is moved to the feature map's own device, so the function does not
    # depend on a notebook-level `device` variable.
    keep_mask = _build_keep_mask(feature_map, curr_sample.to(feature_map.device), attribute_background)

    kept = keep_mask.to(original_input.dtype)
    replaced = (~keep_mask).to(original_input.dtype)

    perturbed_input = kept * original_input + replaced * baseline
    return perturbed_input, replaced.unsqueeze(0)  # perturbation mask

In [355]:
"""test_attribution_vector = np.random.normal(size=(5,))
print("attribution vector: ", test_attribution_vector)
test_feature_map = np.array([[0,2,3],[2,1,0], [4,4,4]])
print("feature map: ", test_feature_map)
# place each attribution in the corresponding place in the map
test_attribution_map = np.zeros_like(test_feature_map, dtype=float)
for i in range(test_feature_map.shape[0]):
    for j in range(test_feature_map.shape[1]):
        test_attribution_map[i, j] = test_attribution_vector[test_feature_map[i, j]]
print("attribution map: ", test_attribution_map)

resulting_attribution_vector = get_attribution_vector(test_feature_map, test_attribution_map, False)
print("attribution vector: ", resulting_attribution_vector)"""


'test_attribution_vector = np.random.normal(size=(5,))\nprint("attribution vector: ", test_attribution_vector)\ntest_feature_map = np.array([[0,2,3],[2,1,0], [4,4,4]])\nprint("feature map: ", test_feature_map)\n# place each attribution in the corresponding place in the map\ntest_attribution_map = np.zeros_like(test_feature_map, dtype=float)\nfor i in range(test_feature_map.shape[0]):\n    for j in range(test_feature_map.shape[1]):\n        test_attribution_map[i, j] = test_attribution_vector[test_feature_map[i, j]]\nprint("attribution map: ", test_attribution_map)\n\nresulting_attribution_vector = get_attribution_vector(test_feature_map, test_attribution_map, False)\nprint("attribution vector: ", resulting_attribution_vector)'

In [356]:
"""test_original_input = np.random.normal(size=(3,3))
test_original_input = torch.from_numpy(test_original_input).float().unsqueeze(0).to(device)
test_original_input"""

'test_original_input = np.random.normal(size=(3,3))\ntest_original_input = torch.from_numpy(test_original_input).float().unsqueeze(0).to(device)\ntest_original_input'

In [357]:
"""print("sorted attribution vector: ", np.sort(resulting_attribution_vector)[::-1])
print("sorted indices of attribution vector: ", np.argsort(resulting_attribution_vector)[::-1])
test_morf_generator, test_lerf_generator = generate_perturbations_ABPC(resulting_attribution_vector)
print("MoRF samples - perturbed input:")
for sample in test_morf_generator:
    print(sample)
    perturbed_input, perturbation_map = default_from_interp_rep_transform(
        sample, test_original_input, torch.from_numpy(test_feature_map), baseline=0.0, attribute_background=False)
    print("perturbed input: ", perturbed_input)
    print("perturbation map: ", perturbation_map)

print("LeRF samples - perturbed input:")
for sample in test_lerf_generator:
    print(sample)
    perturbed_input, perturbation_map = default_from_interp_rep_transform(
        sample, test_original_input, torch.from_numpy(test_feature_map), baseline=0.0, attribute_background=False)
    print("perturbed input: ", perturbed_input)
    print("perturbation map: ", perturbation_map)"""

'print("sorted attribution vector: ", np.sort(resulting_attribution_vector)[::-1])\nprint("sorted indices of attribution vector: ", np.argsort(resulting_attribution_vector)[::-1])\ntest_morf_generator, test_lerf_generator = generate_perturbations_ABPC(resulting_attribution_vector)\nprint("MoRF samples - perturbed input:")\nfor sample in test_morf_generator:\n    print(sample)\n    perturbed_input, perturbation_map = default_from_interp_rep_transform(\n        sample, test_original_input, torch.from_numpy(test_feature_map), baseline=0.0, attribute_background=False)\n    print("perturbed input: ", perturbed_input)\n    print("perturbation map: ", perturbation_map)\n\nprint("LeRF samples - perturbed input:")\nfor sample in test_lerf_generator:\n    print(sample)\n    perturbed_input, perturbation_map = default_from_interp_rep_transform(\n        sample, test_original_input, torch.from_numpy(test_feature_map), baseline=0.0, attribute_background=False)\n    print("perturbed input: ", pertur

In [358]:
def compute_area_between_curves(morf_curve, lerf_curve):
    """
    Area between the LeRF and MoRF curves (LeRF minus MoRF), integrated with the
    trapezoidal rule over unit-spaced steps and divided by the number of curve points,
    which is the normalisation used in the literature.
    """
    # Ensure the curves are of the same length
    assert len(morf_curve) == len(lerf_curve), "Curves must be of the same length"

    # vertical gap between the two curves at every step
    gaps = [lerf_value - morf_value for lerf_value, morf_value in zip(lerf_curve, morf_curve)]

    # each trapezoid averages the gap at a step with the gap at the previous step
    area = 0.0
    for current_gap, previous_gap in zip(gaps[1:], gaps[:-1]):
        area += 0.5 * (current_gap + previous_gap)

    return area / len(morf_curve)


def compute_aopc(morf_curve):
    """
    Area Over the Perturbation Curve (AOPC): the drop of the MoRF curve below its first
    value, integrated with the trapezoidal rule over unit-spaced steps and divided by
    the number of curve points.
    """
    reference = morf_curve[0]

    # how far each point lies below the unperturbed reference
    drops = [reference - value for value in morf_curve]

    area = 0.0
    for current_drop, previous_drop in zip(drops[1:], drops[:-1]):
        area += 0.5 * (current_drop + previous_drop)

    return area / len(morf_curve)

def normalized_abpc(morf_curve, lerf_curve):
    """
    ABPC divided by the value range spanned by the two curves
    (maximum of the LeRF curve minus minimum of the MoRF curve).
    """
    area_between = compute_area_between_curves(morf_curve, lerf_curve)
    value_span = max(lerf_curve) - min(morf_curve)
    return area_between / value_span

def normalized_aopc(morf_curve):
    """
    AOPC divided by the value range of the MoRF curve (its maximum minus its minimum).
    """
    area_over = compute_aopc(morf_curve)
    value_span = max(morf_curve) - min(morf_curve)
    return area_over / value_span


In [359]:
def plot_curves(morf_curve, lerf_curve):
    """
    Plot the MoRF and LeRF curves on one figure and shade the region between them.

    The element-wise comparison used for the shading expects array-like inputs that
    support ``>`` per element (e.g. NumPy arrays).
    """
    import matplotlib.pyplot as plt

    _, ax = plt.subplots(figsize=(10, 6))
    ax.plot(morf_curve, label='MORF Curve', color='blue')
    ax.plot(lerf_curve, label='LERF Curve', color='red')

    # NOTE(review): only the steps where MoRF lies above LeRF are shaded -- confirm intended
    steps = range(len(morf_curve))
    ax.fill_between(steps, morf_curve, lerf_curve, where=(morf_curve > lerf_curve), color='lightblue', alpha=0.5)

    ax.set_title('MORF vs LERF Curves')
    ax.set_xlabel('Perturbation Steps')
    ax.set_ylabel('Attribution Score')
    ax.legend()
    ax.grid()
    plt.show()

In [360]:
from typing import Callable

# let's define the main block to compute one evaluation step, given the
# original volume, the interpretable input vector, and the forward function
def evaluate_step(original_volume: torch.Tensor,
                  interpretable_input: torch.Tensor,
                  feature_map: torch.Tensor,
                  forward_function: Callable,
                  predictor: CustomNNUNetPredictor,
                  attribute_background: bool):
    """
    Evaluate the model on one perturbation of the input volume.

    Parameters:
    - original_volume: unperturbed input; features that are switched off are replaced by 0.0
    - interpretable_input: binary vector saying which features stay on
    - feature_map: integer-labeled regions defining the features
    - forward_function: called as ``forward_function(volume, perturbation_mask, predictor)``
    - predictor: forwarded unchanged to ``forward_function``
    - attribute_background: whether label 0 is treated as a feature

    Returns:
    - the value returned by ``forward_function``
    """
    # build the perturbed volume together with the mask of replaced voxels
    perturbed_volume, replaced_mask = default_from_interp_rep_transform(
        interpretable_input,
        original_volume,
        feature_map,
        baseline=0.0,
        attribute_background=attribute_background,
    )

    # the forward function receives the first element along the leading axis
    return forward_function(perturbed_volume[0], replaced_mask, predictor)


def compute_ABPC_curves(original_volume: torch.Tensor,
                       feature_map: torch.Tensor,
                       attribution_map: np.ndarray,
                       forward_function: Callable,
                       morf_predictor: CustomNNUNetPredictor,
                       lerf_predictor: CustomNNUNetPredictor,
                       attribute_background: bool):
    
    """
    Compute the MoRF and LeRF perturbation curves and the area between them (ABPC).

    Parameters:
    - original_volume: torch.Tensor of shape (1, C, D, H, W)
    - feature_map: torch.Tensor of shape (D, H, W) with integer-labeled regions
    - attribution_map: dense attribution values of shape (D, H, W); although annotated
      as np.ndarray, ``.cpu().numpy()`` is called on it, so a torch tensor is expected
    - forward_function: Callable taking (input_volume, perturbation_mask, predictor) and
      returning a scalar tensor
    - morf_predictor: CustomNNUNetPredictor used for the MoRF evaluations
    - lerf_predictor: CustomNNUNetPredictor used for the LeRF evaluations
    - attribute_background: bool, whether the background region is part of the attribution vector

    Returns:
    - morf_curve: np.ndarray of MoRF evaluation outputs
    - lerf_curve: np.ndarray of LeRF evaluation outputs
    - ABPC_area: area between the MoRF and LeRF curves

    Side effects:
    - prints progress and metrics; when the notebook-level ``WANDB`` flag is set, logs
      every step and the final metrics to wandb
    """
    # 1) one attribution value per region
    region_labels = feature_map.cpu().numpy()
    dense_attributions = attribution_map.cpu().numpy()
    attribution_vector = get_attribution_vector(region_labels, dense_attributions, attribute_background)
    print("Number of steps for the ABPC computation: ", len(attribution_vector))

    # 2) binary-input sequences for both removal orders
    morf_generator, lerf_generator = generate_perturbations_ABPC(attribution_vector)
    morf_scores = []
    lerf_scores = []

    # 3) evaluate both perturbations at every step, MoRF first
    for morf_sample, lerf_sample in zip(morf_generator, lerf_generator):
        morf_output = evaluate_step(
            original_volume, morf_sample, feature_map, forward_function, morf_predictor, attribute_background
        ).cpu()
        lerf_output = evaluate_step(
            original_volume, lerf_sample, feature_map, forward_function, lerf_predictor, attribute_background
        ).cpu()

        morf_scores.append(morf_output)
        lerf_scores.append(lerf_output)
        if WANDB:
            wandb.log({"MoRF": morf_output.item(), "LeRF": lerf_output.item()})

    morf_curve = np.array(morf_scores)
    lerf_curve = np.array(lerf_scores)

    # 4) area between the two curves, and the other metrics
    ABPC_area = compute_area_between_curves(morf_curve, lerf_curve)
    AOPC_area = compute_aopc(morf_curve)
    norm_ABPC_area = normalized_abpc(morf_curve, lerf_curve)
    norm_AOPC_area = normalized_aopc(morf_curve)

    for caption, value in (
        ("ABPC area: ", ABPC_area),
        ("AOPC area: ", AOPC_area),
        ("Normalized ABPC area: ", norm_ABPC_area),
        ("Normalized AOPC area: ", norm_AOPC_area),
    ):
        print(caption, value)

    if WANDB:
        summary = {
            "ABPC_area": ABPC_area.item(),
            "AOPC": AOPC_area.item(),
            "norm_ABPC": norm_ABPC_area.item(),
            "norm_AOPC": norm_AOPC_area.item(),
        }
        wandb.log(summary)

    # 5) plotting and storing the curves is left to the caller

    # 6) return the computed values
    return morf_curve, lerf_curve, ABPC_area


In [361]:
# ...existing code...

def compute_ABPC_curves_with_volumes(original_volume: torch.Tensor,
                       feature_map: torch.Tensor,
                       attribution_map: torch.Tensor,
                       forward_function: Callable,
                       morf_predictor: CustomNNUNetPredictor,
                       lerf_predictor: CustomNNUNetPredictor,
                       attribute_background: bool,
                       voxel_volume_mm3: float = 1.0):
    """
    Compute the MoRF / LeRF perturbation curves and the Area Between Perturbation
    Curves (ABPC), while tracking how much volume has been removed at every step.

    Parameters:
    - original_volume: torch.Tensor, input volume forwarded to `evaluate_step`.
    - feature_map: torch.Tensor, supervoxel label map; converted to NumPy for the
      volume / attribution helpers and passed as a tensor to `evaluate_step`.
    - attribution_map: torch.Tensor (a NumPy array is accepted as well), attribution
      values that `get_attribution_vector` reduces to one entry per supervoxel.
    - forward_function: Callable handed to `evaluate_step`.
    - morf_predictor / lerf_predictor: predictors used for the MoRF and LeRF
      evaluations respectively.
    - attribute_background: bool, forwarded to the helpers as the
      "include background" switch.
    - voxel_volume_mm3: float, volume of a single voxel in mm³ for proper volume calculation.

    Returns:
    - (morf_curve, lerf_curve, ABPC_area): the two curves as NumPy arrays and the
      value returned by `compute_area_between_curves`.

    Side effects:
    - Prints progress to stdout and, when the global `WANDB` flag is truthy, logs
      per-step and summary metrics with `wandb.log`.
    """
    # Move the maps to host memory once instead of once per helper call.
    # NOTE(review): assumes the helpers below do not modify their input in place.
    feature_map_np = feature_map.cpu().numpy()
    if isinstance(attribution_map, torch.Tensor):
        attribution_map_np = attribution_map.cpu().numpy()
    else:
        # Honour the historical `np.ndarray` annotation: plain arrays used to crash on `.cpu()`.
        attribution_map_np = np.asarray(attribution_map)

    # 1) get supervoxel volumes
    supervoxel_volumes = get_supervoxel_volumes(
        feature_map_np,
        voxel_volume_mm3=voxel_volume_mm3,
        include_background=attribute_background
    )

    total_voxels, total_volume = get_total_volume(feature_map_np, voxel_volume_mm3)
    print(f"Total volume: {total_volume:.2f} mm³ ({total_voxels} voxels of {voxel_volume_mm3:.2f} mm³ each)")

    # 2) get the attribution vector (one entry per perturbation step)
    attribution_vector = get_attribution_vector(
        feature_map_np,
        attribution_map_np,
        include_background=attribute_background
    )

    print("Number of steps for the ABPC computation: ", len(attribution_vector))
    # Index 1 of `supervoxel_volumes` is summed and reported as mm³.
    print(f"Total volume of supervoxels: {np.sum(supervoxel_volumes[1]):.2f} mm³")

    # 3) get the two generators with volume tracking
    morf_generator, lerf_generator = generate_perturbations_ABPC_with_volumes(
        attribution_vector, supervoxel_volumes
    )

    morf_curve, lerf_curve = [], []

    # Log initial total volume (commit=False: attached to the next committed log call)
    if WANDB:
        wandb.log({
            "total_volume_voxels": total_voxels,
            "total_volume_mm3": total_volume,
            "num_supervoxels": len(attribution_vector)
        }, commit=False)

    # Each generator item is (sample, cumulative_volume), where cumulative_volume[0]
    # is reported as voxels and cumulative_volume[1] as mm³.
    for step, ((morf_sample, morf_cumul_vol), (lerf_sample, lerf_cumul_vol)) in enumerate(
            zip(morf_generator, lerf_generator)):
        print(f"Step {step}: MoRF cumul. volume removed: {morf_cumul_vol[1]:.2f} mm³ ({morf_cumul_vol[0]} voxels), "
              f"LeRF cumul. volume removed: {lerf_cumul_vol[1]:.2f} mm³ ({lerf_cumul_vol[0]} voxels)")
        # 4) compute the evaluation step for both samples
        morf_output = evaluate_step(original_volume, morf_sample, feature_map, forward_function, morf_predictor, attribute_background).cpu()
        lerf_output = evaluate_step(original_volume, lerf_sample, feature_map, forward_function, lerf_predictor, attribute_background).cpu()

        morf_curve.append(morf_output)
        lerf_curve.append(lerf_output)

        if WANDB:
            # Percentages are relative to the voxel count returned by `get_total_volume`.
            morf_vol_pct = (morf_cumul_vol[0] / total_voxels) * 100.0
            lerf_vol_pct = (lerf_cumul_vol[0] / total_voxels) * 100.0
            wandb.log({
                "supervoxels_perturbed": step,
                "MoRF": morf_output.item(),
                "LeRF": lerf_output.item(),
                "MoRF_volume_removed_voxels": morf_cumul_vol[0],
                "LeRF_volume_removed_voxels": lerf_cumul_vol[0],
                "MoRF_volume_removed_mm3": morf_cumul_vol[1],
                "LeRF_volume_removed_mm3": lerf_cumul_vol[1],
                "MoRF_volume_removed_pct": morf_vol_pct,
                "LeRF_volume_removed_pct": lerf_vol_pct
            }, commit=True)

    morf_curve = np.array(morf_curve)
    lerf_curve = np.array(lerf_curve)

    # 5) compute the area between the two curves, and other metrics
    ABPC_area = compute_area_between_curves(morf_curve, lerf_curve)
    AOPC_area = compute_aopc(morf_curve)
    norm_ABPC_area = normalized_abpc(morf_curve, lerf_curve)
    norm_AOPC_area = normalized_aopc(morf_curve)

    if WANDB:
        wandb.log({
            "ABPC_area": ABPC_area.item(),
            "AOPC": AOPC_area.item(),
            "norm_ABPC": norm_ABPC_area.item(),
            "norm_AOPC": norm_AOPC_area.item()
        }, commit=True)

    # 6) return the computed values
    return morf_curve, lerf_curve, ABPC_area

In [362]:
# Experiment parameters.

# How the Region of Interest (ROI) is supplied:
#   "BoundingBox"        -> a JSON file containing the ROI bounding box
#   "MaskedSegmentation" -> a masked segmentation volume
ROI_TYPE = "BoundingBox"

NNUNET_PREPROCESSING = False
SUPERVOXEL_TYPE = "full-organs"
SLIC_N_SUPERVOXELS = 380

# Sample budget depends on the supervoxel strategy.
if SUPERVOXEL_TYPE == "full-organs":
    N_SAMPLES = 1000
else:
    N_SAMPLES = 2000

# FCC supervoxel size, in mm.
if SUPERVOXEL_TYPE == "FCC":
    FCC_SUPERVOXEL_SIZE = 100
else:
    FCC_SUPERVOXEL_SIZE = 50

In [363]:
# Resolve where the ROI description lives for the selected ROI_TYPE.
# NOTE(review): this cell reads `volume_codes`, so it relies on that name being
# defined before this point in the notebook — confirm on a fresh kernel.
# Any other ROI_TYPE value defines neither variable.
if ROI_TYPE == "BoundingBox":
    # One bounding-box JSON per volume; only the containing folder differs by platform.
    if IN_KAGGLE:
        bb_roi_prefix = f"{mount_dir}/BB-ROI/"
    else:
        bb_roi_prefix = nnUNet_results + "/nnUNetTrainer__nnUNetPlans__3d_fullres/fold_0/BB-ROI/"
    BB_ROI_paths = {
        volume_code: bb_roi_prefix + f"AUTOMI_{volume_code}.json"
        for volume_code in volume_codes
    }
elif ROI_TYPE == "MaskedSegmentation":
    # Manually derived ROI mask that was added to the dataset by hand.
    ROI_segmentation_mask_path = (
        "/kaggle/input/segmentation-masked-ROI.nii"
        if IN_KAGGLE
        else nnUNet_raw + "/segmentation-masked-ROI.nii"
    )

In [364]:
# Case identifiers to process, as zero-padded 5-digit strings.
# For a quick single-volume run, shrink the tuple below (e.g. to `(4,)` or `(39,)`).
volume_codes = [f"{case_number:05d}" for case_number in (4, 5, 24, 27, 29, 34, 39, 44)]

In [365]:
# Aggregation functions to evaluate, in execution order.
# Trim this list (e.g. to just `dice_aggregation`) for a shorter run.
aggregation_functions = [
    true_positive_aggregation,
    false_positive_aggregation,
    dice_aggregation,
    logit_difference_aggregation,
]

In [366]:
from pathlib import Path

# Lookup table: aggregation function name -> volume code -> attribution map file.
attribution_map_paths = {}
for volume_code in volume_codes:
    for agg_type in aggregation_functions:
        agg_name = agg_type.__name__
        # For false positives the sign-inverted map is used for better interpretation of the metric
        # (the original FP aggregation had a positive sign for logits, i.e. a negative sign
        # in influential supervoxels).
        if agg_name == "false_positive_aggregation":
            suffix = "_signed"
        else:
            suffix = ""
        map_file = join(
            exp_results_path,
            SUPERVOXEL_TYPE,
            agg_name,
            volume_code,
            "attribution_map" + suffix + ".nii.gz",
        )
        attribution_map_paths.setdefault(agg_name, {})[volume_code] = Path(map_file)

In [367]:
# Fail early if any pre-computed attribution map is missing
# (checked volume by volume, aggregation functions in list order).
for volume_code, agg_type in ((code, fn) for code in volume_codes for fn in aggregation_functions):
    map_path = attribution_map_paths[agg_type.__name__][volume_code]
    assert map_path.exists(), f"Attribution map not found: {map_path}"

In [368]:
# Experiments that are already done and should be skipped.
# The main loop tests `(volume_code, agg_fn.__name__) in skip_exp`, so each entry
# must be a tuple of that form, e.g. ("00004", dice_aggregation.__name__).
skip_exp = []

In [None]:
# Run the ABPC evaluation for every (volume, aggregation function) pair.
# Per volume: prepare the data once; per aggregation function: load the pre-computed
# attribution map, compute the MoRF/LeRF curves and (optionally) log them to W&B.

# The device label only feeds the W&B run config, so resolve it once up front.
if WANDB:
    GPU_NAME = torch.cuda.get_device_name(device) if device.type == "cuda" else "CPU"

for i, volume_code in enumerate(volume_codes):
    print(f"Working on volume code {i}: {volume_code}")
    # Load the image object once; it is reused below for the voxel volume.
    ct_img = nib.load(ct_img_paths[volume_code])
    print(f"  Shape: {ct_img.shape}")

    # Per-volume folder, relative to the current working directory.
    os.makedirs(volume_code, exist_ok=True)

    volume, affine, supervoxel_map, segmentation_mask, ROI_mask, cache_dict = prepare_data_for_shap(
        volume_path=ct_img_paths[volume_code],
        patch_size=patch_size,
        predictor=predictor,
        dataset_json_path=dataset_json_path,
        ROI_BB_path=BB_ROI_paths[volume_code] if ROI_TYPE == "BoundingBox" else None,
        ROI_segmentation_mask_path=ROI_segmentation_mask_path if ROI_TYPE == "MaskedSegmentation" else None,
        ROI_type=ROI_TYPE,
        nnunet_preprocessing=NNUNET_PREPROCESSING,
        supervoxel_type=SUPERVOXEL_TYPE,
        fcc_supervoxel_size=FCC_SUPERVOXEL_SIZE,
        slic_n_supervoxels=SLIC_N_SUPERVOXELS,
        organ_map_path=organ_map_paths[volume_code],
        device=device,
    )

    # Compute voxel volume for this specific volume
    voxel_volume_mm3 = compute_voxel_volume_mm3(ct_img)
    print(f"  Voxel volume: {voxel_volume_mm3:.4f} mm³")

    for agg_fn in aggregation_functions:
        print(f"  Using aggregation function: {agg_fn.__name__}")
        if (volume_code, agg_fn.__name__) in skip_exp:
            print("  Skipping this experiment as already done")
            continue

        wandb_run = None
        if WANDB:
            wandb_run = wandb.init(
                project="automi",  # your W&B project name
                config={
                        "device_name": GPU_NAME,
                        "volume_code": volume_code,
                        "n_samples": N_SAMPLES,
                        "supervoxel_type": SUPERVOXEL_TYPE,
                        "fcc_supervoxel_size": FCC_SUPERVOXEL_SIZE,
                        "slic_n_supervoxels": SLIC_N_SUPERVOXELS,
                        "nnunet_preprocessing": NNUNET_PREPROCESSING,
                        "aggregation_function": agg_fn.__name__,
                        "group": "ABPC-volumes",
                        "voxel_volume_mm3": voxel_volume_mm3,
                    }
            )

        run_completed = False
        try:
            # 0) get the pre-computed attribution map (axes reversed with transpose(2, 1, 0))
            attribution_map = torch.from_numpy(nib.load(attribution_map_paths[agg_fn.__name__][volume_code]).get_fdata().transpose(2,1,0)).float().to(device)

            # a) wrap the aggregation function together with the per-volume context
            def forward_func(vol, mask, predictor):
                return forward_segmentation_output_to_explain(
                    input_image=vol,
                    perturbation_mask=mask,
                    segmentation_mask=segmentation_mask,
                    ROI_bounding_box_mask=ROI_mask,
                    baseline_prediction_dict=cache_dict,
                    predictor=predictor,
                    aggregation_fn=agg_fn
                )

            # get the two curves and the area between them
            # NOTE(review): the callee returns (morf_curve, lerf_curve, ABPC_area); the
            # target names below are kept as-is in case later cells read them.
            aopc_curve, abpc_curve, aopc_area = compute_ABPC_curves_with_volumes(
                original_volume=volume,
                feature_map=supervoxel_map,
                attribution_map=attribution_map,
                forward_function=forward_func,
                morf_predictor=morf_predictor,
                lerf_predictor=lerf_predictor,
                attribute_background=(SUPERVOXEL_TYPE not in ["FCC-organs", "full-organs"]),
                voxel_volume_mm3=voxel_volume_mm3
            )
            run_completed = True
        finally:
            # Always close the W&B run, even when the computation raised, so the next
            # iteration starts from a clean run; flag failures with a non-zero exit code.
            if wandb_run is not None:
                if run_completed:
                    wandb_run.finish()
                else:
                    wandb_run.finish(exit_code=1)


Working on volume code 0: 00004
  Shape: (512, 512, 221)
Saved binary ROI mask to ROI_binary_mask.nii.gz
supervoxel map saved!
  Voxel volume: 9.3460 mm³
  Using aggregation function: true_positive_aggregation
Total volume: 251223095.70 mm³ (26880256 voxels of 9.35 mm³ each)
Number of steps for the ABPC computation:  9
Total volume of supervoxels: 6718742.56 mm³
Step 0: MoRF cumul. volume removed: 0.00 mm³ (0.0 voxels), LeRF cumul. volume removed: 0.00 mm³ (0.0 voxels)
Step 1: MoRF cumul. volume removed: 891945.65 mm³ (95436.0 voxels), LeRF cumul. volume removed: 510058.40 mm³ (54575.0 voxels)
Step 2: MoRF cumul. volume removed: 1071519.85 mm³ (114650.0 voxels), LeRF cumul. volume removed: 2460103.03 mm³ (263225.0 voxels)
Step 3: MoRF cumul. volume removed: 1145727.16 mm³ (122590.0 voxels), LeRF cumul. volume removed: 4971711.92 mm³ (531961.0 voxels)
Step 4: MoRF cumul. volume removed: 1474678.61 mm³ (157787.0 voxels), LeRF cumul. volume removed: 5166090.20 mm³ (552759.0 voxels)
Step 5