## **Initialization**


**Google Drive Mount**

In [None]:
# Mount Google Drive at /content/drive so the Drive-hosted paths used below are reachable.
# force_remount=True is passed so the mount is redone even if Drive was already mounted.
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

**Libraries Import**

In [None]:
import numpy as np
import os
from PIL import Image
import skimage.util as util
import skimage.io as img
from skimage import exposure
from tqdm import tqdm
from zipfile import ZipFile
import random
import shutil

from IPython.display import clear_output

**Paths Definition**

In [None]:
# Work folder (must end with '/': later cells build file paths by string concatenation)
main_folder = "/content/drive/MyDrive/Colab Notebooks/Progetto EIM/"
script_folder = main_folder

# Path of the images dataset (already divided into train, val and test set and organized specifically for the architecture of MMSegmentation)
# The zip is created in VSCode to avoid the need to reorganize the dataset every time the notebook is restarted and to get a faster execution
# NOTE(review): this archive lives under "MyDrive/Progetto EIM/" while main_folder points to
# "MyDrive/Colab Notebooks/Progetto EIM/" - confirm that both Drive folders really exist.
zip_file_path_sets = "/content/drive/MyDrive/Progetto EIM/ConstrAndTestSet_ForMMSegmentation.zip"

# Local folder where the dataset archive is extracted; it is also used as the dataset root
extract_folder_sets = "/content/ProstateMRI/"
os.makedirs(extract_folder_sets, exist_ok=True)
ProstateMRI_dataset_folder = extract_folder_sets

# Folder for saving MMSegmentation results
# NOTE(review): this path is outside the mounted Drive, so checkpoints and logs presumably
# do not survive a Colab runtime reset - confirm this is intended.
results_folder = "/content/results_mmseg/"


**Zip Extraction**

In [None]:
# Unpack the dataset archive (zip_file_path_sets) into extract_folder_sets,
# one member at a time so that tqdm can report per-file progress.
with ZipFile(zip_file_path_sets, 'r') as zip_ref:
    archive_members = zip_ref.infolist()
    total_files = len(archive_members)

    # tqdm wraps the member list directly and advances by one for every extracted file
    for member in tqdm(archive_members, total=total_files, unit="file"):
        zip_ref.extract(member, extract_folder_sets)


**MMSegmentation Setup**

In [None]:
# NOTE: these are the basic installs needed for mmsegmentation; add the ones you need to run your script, being careful about the dependencies
# NOTE(review): torch/torchvision/torchaudio and mmcv are pinned, while mmsegmentation, mmengine and ftfy are not,
# so a fresh runtime may resolve different versions over time - consider pinning them as well.

# Check nvcc version
!nvcc -V
# Check GCC version
!gcc --version

# Install PyTorch
!pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 torchaudio==0.11.0+cu113 -f https://download.pytorch.org/whl/cu113/torch_stable.html
# Install mmcv
!pip install mmcv==2.0.0rc4 -f https://download.openmmlab.com/mmcv/dist/cu113/torch1.11.0/index.html
# Install mmsegmentation
!pip install mmsegmentation
# Install mmengine
!pip install mmengine

# Other installs
!pip install ftfy

# Hide the (long) installation logs from the cell output
clear_output()

In [None]:
# Check Pytorch installation: prints the installed torch version and whether CUDA is available
import torch, torchvision
print(torch.__version__, torch.cuda.is_available())

# Check MMSegmentation installation: prints the installed mmseg version
import mmseg
print(mmseg.__version__)

**Libraries Import**

In [None]:
# Required imports definition (put all imports you need in this cell)
import os
import numpy as np
import torch

from mmseg.apis import init_model, inference_model
import matplotlib.pyplot as plt

from PIL import Image
from tqdm import tqdm

from mmengine import Config

from mmseg.registry import DATASETS
from mmseg.datasets import BaseSegDataset

from skimage import morphology

**ProstateMRI Dataset Class Creation**

In [None]:
# Example of Dataset configuration, can be edited based on your strategy, make sure it's coherent with your config file

# Names of the segmentation classes
# NOTE(review): presumably the position in this tuple is the pixel value used in the annotation PNGs - confirm
classes = ('background',
        'tumor')

# One RGB colour per class, in the same order as `classes`
paletteProstate = [
    (0, 0, 0), # background - black
    (255, 255, 255), # tumor - white
]

@DATASETS.register_module()
class ProstateMRI(BaseSegDataset):
    """Prostate MRI segmentation dataset registered in the MMSegmentation DATASETS registry.

    Images and segmentation maps are both looked up as '.png' files and
    `reduce_zero_label` is fixed to False; every other option is forwarded
    unchanged to `BaseSegDataset`.
    """
    METAINFO = dict(classes = classes, palette = paletteProstate)
    def __init__(self, **kwargs):
        """Forward `kwargs` to `BaseSegDataset` together with the fixed suffix/label options."""
        super().__init__(img_suffix='.png',
                        seg_map_suffix='.png',
                        reduce_zero_label = False,
                        **kwargs)

**Load of config file for MMSegmentation**

In [None]:
# Path of the config file.
# os.path.join is used instead of string concatenation so the path stays valid
# whether or not script_folder ends with a '/'.
cfg_path = os.path.join(script_folder, 'config_FUNZIONANTE.py')

# Load of the config file
cfg = Config.fromfile(cfg_path)

**Selection of folders containing images of train, val and test sets and manual masks**

In [None]:
# Every split reads its images from stacked_dir/<split>/ and its manual masks from ann_dir/<split>/
for split in ('train', 'val', 'test'):
    dataloader_cfg = getattr(cfg, f'{split}_dataloader')
    dataloader_cfg['dataset']['data_prefix'] = dict(
        img_path=f'stacked_dir/{split}/',
        seg_map_path=f'ann_dir/{split}/',
    )

**Default Hyperparameters Modification**

In [None]:
cfg.compile=False

# Dataset root and output folders
cfg.data_root = ProstateMRI_dataset_folder
cfg.save_dir = results_folder
cfg.work_dir = results_folder

# All three dataloaders read from the same extracted dataset folder
for dataloader_cfg in (cfg.train_dataloader, cfg.val_dataloader, cfg.test_dataloader):
    dataloader_cfg['dataset']['data_root'] = ProstateMRI_dataset_folder

# Visualizer: write to the results folder, both as local files and as TensorBoard logs.
# (This assignment used to be repeated twice verbatim; once is enough.)
cfg.visualizer['save_dir'] = results_folder
cfg.visualizer['vis_backends'] = [
    dict(type='LocalVisBackend'),
    dict(type='TensorboardVisBackend')
]

In [None]:
# PreProcessing Transform Class Creation

from skimage import util
import skimage.exposure as exposure
import numpy as np
from scipy import ndimage

from mmcv.transforms import BaseTransform, TRANSFORMS
import mmseg.datasets.transforms as mmsegTransforms

@TRANSFORMS.register_module()
class PreProcessing(BaseTransform):
    """Channel-wise preprocessing of the image stored in ``results['img']``.

    Channel 0 is treated as ADC, channel 1 as HBV and every other channel as
    T2W. For each channel the enabled steps run in a fixed order: Gaussian
    filtering, adaptive histogram equalization, contrast modification,
    min-max scaling; the channel is then always converted to uint8.

    Constructor arguments (one per channel suffix ``ADC`` / ``HBV`` / ``T2W``):
        EqualizzazioneIstogrammaFlag*: enable adaptive histogram equalization.
        FiltraggioGaussianoFlag*, sigma*: enable Gaussian filtering and its sigma.
        ModificaContrastoFlag*, percContr*: enable contrast modification and its
            percentage.
        MinMaxScalingFlag*: enable min-max scaling.
    """

    def __init__(self,
                EqualizzazioneIstogrammaFlagADC, EqualizzazioneIstogrammaFlagHBV, EqualizzazioneIstogrammaFlagT2W,
                FiltraggioGaussianoFlagADC, FiltraggioGaussianoFlagHBV, FiltraggioGaussianoFlagT2W, sigmaADC, sigmaHBV, sigmaT2W,
                ModificaContrastoFlagADC, ModificaContrastoFlagHBV, ModificaContrastoFlagT2W, percContrADC, percContrHBV, percContrT2W,
                MinMaxScalingFlagADC, MinMaxScalingFlagHBV, MinMaxScalingFlagT2W
                ):
        super().__init__()
        # ADC channel settings
        self.FiltraggioGaussianoFlagADC, self.sigmaADC = FiltraggioGaussianoFlagADC, sigmaADC
        self.EqualizzazioneIstogrammaFlagADC = EqualizzazioneIstogrammaFlagADC
        self.ModificaContrastoFlagADC, self.percContrADC = ModificaContrastoFlagADC, percContrADC
        self.MinMaxScalingFlagADC = MinMaxScalingFlagADC
        # HBV channel settings
        self.FiltraggioGaussianoFlagHBV, self.sigmaHBV = FiltraggioGaussianoFlagHBV, sigmaHBV
        self.EqualizzazioneIstogrammaFlagHBV = EqualizzazioneIstogrammaFlagHBV
        self.ModificaContrastoFlagHBV, self.percContrHBV = ModificaContrastoFlagHBV, percContrHBV
        self.MinMaxScalingFlagHBV = MinMaxScalingFlagHBV
        # T2W channel settings
        self.FiltraggioGaussianoFlagT2W, self.sigmaT2W = FiltraggioGaussianoFlagT2W, sigmaT2W
        self.EqualizzazioneIstogrammaFlagT2W = EqualizzazioneIstogrammaFlagT2W
        self.ModificaContrastoFlagT2W, self.percContrT2W = ModificaContrastoFlagT2W, percContrT2W
        self.MinMaxScalingFlagT2W = MinMaxScalingFlagT2W

    def EqualizzazioneIstogramma(self, img):
        """Return `img` after adaptive histogram equalization (`exposure.equalize_adapthist`)."""
        return exposure.equalize_adapthist(img)

    def FiltraggioGaussiano(self, img, sigma):
        """Return `img` smoothed by `ndimage.gaussian_filter` with the given `sigma`."""
        return ndimage.gaussian_filter(img, sigma)

    def ModificaContrasto(self, img, percContr):
        """Push every pixel away from the mean intensity by `percContr` percent.

        The image is first converted with `util.img_as_float`; the result is
        clipped to [0, 1].
        """
        img = util.img_as_float(img)
        # Mean intensity of the whole channel
        luminanza = np.sum(img)/img.size
        img = img + (img - luminanza)*percContr/100
        return np.clip(img,0,1)

    def MinMaxScaling(self, img):
        """Rescale `img` to [0, 1]; a constant image is returned unchanged."""
        if img.max() != img.min():
            return (img - img.min()) / (img.max() - img.min())
        return img

    def uint8Conversion(self, img):
        """Convert `img` to uint8 with `util.img_as_ubyte`."""
        return util.img_as_ubyte(img)

    def _preprocessChannel(self, channel_img, tag):
        """Run the steps enabled for the channel suffix `tag` ('ADC', 'HBV' or 'T2W')."""
        if getattr(self, f'FiltraggioGaussianoFlag{tag}'):
            channel_img = self.FiltraggioGaussiano(channel_img, getattr(self, f'sigma{tag}'))
        if getattr(self, f'EqualizzazioneIstogrammaFlag{tag}'):
            channel_img = self.EqualizzazioneIstogramma(channel_img)
        if getattr(self, f'ModificaContrastoFlag{tag}'):
            channel_img = self.ModificaContrasto(channel_img, getattr(self, f'percContr{tag}'))
        if getattr(self, f'MinMaxScalingFlag{tag}'):
            channel_img = self.MinMaxScaling(channel_img)
        # The uint8 conversion is unconditional
        return self.uint8Conversion(channel_img)

    def transform(self, results: dict) -> dict:
        """Preprocess every channel of ``results['img']`` in place and return `results`."""
        img = results['img']
        first_channel_tags = ('ADC', 'HBV')

        for i in range(img.shape[2]):
            # Channels beyond the first two all use the T2W settings
            tag = first_channel_tags[i] if i < len(first_channel_tags) else 'T2W'
            img[:,:,i] = self._preprocessChannel(img[:,:,i], tag)

        results['img'] = img
        return results

In [None]:
# Per-channel switches for the PreProcessing transform.
# Each *Flag* enables one step; sigma* and percContr* are only read when the matching flag is True.

# Preprocessing for ADC channel
FiltraggioGaussianoFlagADC, sigmaADC = False, 1.5
EqualizzazioneIstogrammaFlagADC = False
ModificaContrastoFlagADC, percContrADC = False, 50
MinMaxScalingFlagADC = True

# Preprocessing for HBV channel
FiltraggioGaussianoFlagHBV, sigmaHBV = False, 1.5
EqualizzazioneIstogrammaFlagHBV = False
ModificaContrastoFlagHBV, percContrHBV = False, 50
MinMaxScalingFlagHBV = True

# Preprocessing for T2W channel
FiltraggioGaussianoFlagT2W, sigmaT2W = False, 1.5
EqualizzazioneIstogrammaFlagT2W = False
ModificaContrastoFlagT2W, percContrT2W = False, 50
MinMaxScalingFlagT2W = True

In [None]:
def _preprocessingStep():
    """Return a fresh PreProcessing pipeline entry built from the notebook-level flags."""
    return dict(type='PreProcessing',
                FiltraggioGaussianoFlagADC=FiltraggioGaussianoFlagADC, sigmaADC=sigmaADC,
                EqualizzazioneIstogrammaFlagADC=EqualizzazioneIstogrammaFlagADC,
                ModificaContrastoFlagADC=ModificaContrastoFlagADC, percContrADC=percContrADC,
                MinMaxScalingFlagADC=MinMaxScalingFlagADC,
                FiltraggioGaussianoFlagHBV=FiltraggioGaussianoFlagHBV, sigmaHBV=sigmaHBV,
                EqualizzazioneIstogrammaFlagHBV=EqualizzazioneIstogrammaFlagHBV,
                ModificaContrastoFlagHBV=ModificaContrastoFlagHBV, percContrHBV=percContrHBV,
                MinMaxScalingFlagHBV=MinMaxScalingFlagHBV,
                FiltraggioGaussianoFlagT2W=FiltraggioGaussianoFlagT2W, sigmaT2W=sigmaT2W,
                EqualizzazioneIstogrammaFlagT2W=EqualizzazioneIstogrammaFlagT2W,
                ModificaContrastoFlagT2W=ModificaContrastoFlagT2W, percContrT2W=percContrT2W,
                MinMaxScalingFlagT2W=MinMaxScalingFlagT2W)

# Training: load -> preprocess -> load mask -> random flip / rotation -> pack
cfg.train_pipeline = [
    dict(type='LoadImageFromFile'),
    _preprocessingStep(),
    dict(type='LoadAnnotations'),
    dict(prob=0.5, type='RandomFlip'),
    dict(degree=(-15.0,15.0), prob=0.5, type='RandomRotate'),
    dict(type='PackSegInputs'),
]
cfg.train_dataloader.dataset.pipeline= cfg.train_pipeline

# Validation: same preprocessing as training, without the random augmentations
cfg.val_dataloader.dataset.pipeline =[
    dict(type='LoadImageFromFile'),
    _preprocessingStep(),
    dict(type='LoadAnnotations'),
    dict(type='PackSegInputs'),
]

# Test: no PreProcessing step.
# NOTE(review): the volumetric inference cell min-max scales each slice by hand before calling the model,
# but a plain runner.test() would feed images that are not preprocessed like the training ones - confirm this is intended.
cfg.test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations'),
    dict(type='PackSegInputs'),
]
cfg.test_dataloader.dataset.pipeline = cfg.test_pipeline

In [None]:
# Epoch Runner (instead of Iteration Runner)
# Parameters of config files changed according to MMSegmentation tutorial in order to use the Epoch Runner instead of the Iteration Runner

# Learning-rate schedule counted in epochs
# NOTE(review): the milestones are at epochs 6 and 8 while training runs for 50 epochs - confirm the schedule is intended.
cfg.param_scheduler = dict(
    by_epoch=True,
    milestones=[6, 8],
    type='MultiStepLR'
)

# Log and checkpoint by epoch (a checkpoint every 5 epochs)
cfg.default_hooks.logger.log_metric_by_epoch = True
cfg.default_hooks.checkpoint=dict(type='CheckpointHook', interval=5, by_epoch=True)

cfg.train_cfg = dict(by_epoch=True, max_epochs=50, val_interval=2) # by_epoch=True or type='EpochBasedTrainLoop'instead of type='IterBasedTrainLoop'

cfg.log_processor = dict(by_epoch=True)

# Epoch-based training needs a finite sampler. Setting the sampler to None makes the
# DataLoader fall back to sequential, unshuffled sampling, so the training set would be
# visited in the same order at every epoch; DefaultSampler with shuffle=True is the
# sampler used for epoch-based loops in the MMEngine / MMSegmentation docs.
cfg.train_dataloader.sampler = dict(type='DefaultSampler', shuffle=True)

## **Hyperparameter Modification + Training**

**Net Changes**: list here the settings you want to change and test

In [None]:
# Batch Size
cfg.train_dataloader['batch_size'] = 8

# Loss Function
# NOTE(review): in MMSegmentation's TverskyLoss alpha and beta are presumably the weights of
# false positives and false negatives respectively - confirm against the installed version.
cfg.model.decode_head.loss_decode=dict(type='TverskyLoss', alpha=0.3, beta=0.7)

# Backbone Strides (disabled experiment)
# cfg.model.backbone.strides = (1, 1, 1, 1)


**Training**

In [None]:
# Build Runner from the fully customised config
from mmengine.runner import Runner
runner = Runner.from_cfg(cfg)

# Hide the build logs from the cell output
clear_output()

In [None]:
# Start training (checkpoints and logs go to cfg.work_dir, i.e. results_folder)
runner.train()

In [None]:
# Start testing
# runner.test()

## **Volumetric Inference**

**Unzip the test-set archive: it is organized like the available dataset and like the blind test set**

In [None]:
import os
import numpy as np
from tqdm import tqdm
from zipfile import ZipFile
from PIL import Image

from mmseg.apis import init_model, inference_model, show_result_pyplot
import mmcv

# Location of the volumetric test-set archive on Drive and of its local extraction folder
testSetZipPath = '/content/drive/MyDrive/Progetto EIM/volumetricTestSet.zip'
testSetPath = '/content/testSet/'

# Unzip member by member so that tqdm can show per-file progress
with ZipFile(testSetZipPath, 'r') as zip_ref:
    archive_members = zip_ref.infolist()
    total_files = len(archive_members)

    for member in tqdm(archive_members, total=total_files, unit="file"):
        zip_ref.extract(member, testSetPath)

**Definition of PreProcessing Functions**

In [None]:
import os
import numpy as np
from tqdm import tqdm
from PIL import Image

def findBorders(t2w):
    """Locate the bounding rows/columns of the non-black area of a T2W slice.

    Args:
        t2w: 2D array; pixels greater than 0 count as non-black.

    Returns:
        (rmin, rmax, cmin, cmax): index of the first and last row and of the
        first and last column containing a non-black pixel. For an entirely
        black image the fallback is (0, t2w.shape[0], 0, t2w.shape[1]).
    """
    nonBlack = t2w > 0

    # Indexes of the rows / columns holding at least one non-black pixel
    rowIndexes = np.flatnonzero(nonBlack.any(axis=1))
    colIndexes = np.flatnonzero(nonBlack.any(axis=0))

    if rowIndexes.size > 0:
        rmin, rmax = rowIndexes[0], rowIndexes[-1]
    else:
        rmin, rmax = 0, t2w.shape[0]

    if colIndexes.size > 0:
        cmin, cmax = colIndexes[0], colIndexes[-1]
    else:
        cmin, cmax = 0, t2w.shape[1]

    return rmin, rmax, cmin, cmax

def cropResizeImages(adc, hbv, t2w):
    """Crop the three channels around the non-black T2W area and resize them to 256x256.

    The crop is a square centred on the non-black region found by `findBorders(t2w)`;
    its side is the smaller extent of that region, made even. The crop is skipped
    (whole image resized) when the side is within 5 pixels of the image height, or
    when the non-black region is degenerate (side <= 0), which would otherwise
    produce an empty crop that cannot be resized.

    Args:
        adc, hbv, t2w: 2D arrays of the same shape, converted to 8-bit grayscale
            PIL images (mode 'L') for resizing.

    Returns:
        (adc, hbv, t2w, cropMetadata): the three 256x256 uint8 arrays and a dict
        that `restoreSizeMask` reads to undo the operation. It always contains
        'flagCrop', 'original_shape', 'resizeShape' and 'cropDimension'; when
        'flagCrop' is True it also contains 'start_x', 'start_y', 'end_x',
        'end_y', 'cropped_shape', 'center_x' and 'center_y'.
    """
    resizeShape = (256, 256)

    def _resize(channel):
        # LANCZOS-resample one channel to resizeShape
        return np.array(Image.fromarray(channel, mode='L').resize(resizeShape, Image.LANCZOS), dtype=np.uint8)

    # Get original image size
    original_shape = t2w.shape
    # Find black borders in t2w channel
    rmin_t2w, rmax_t2w, cmin_t2w, cmax_t2w = findBorders(t2w)
    # Side of the square crop (kept even)
    cropDimension = min(rmax_t2w - rmin_t2w, cmax_t2w - cmin_t2w)
    if cropDimension % 2 != 0:
        cropDimension -= 1

    # Do not crop if the black region is too small, or if there is nothing sensible to crop
    if cropDimension <= 0 or abs(t2w.shape[0] - cropDimension) < 5:
        cropDimension = t2w.shape[0]
        cropMetadata = {'flagCrop': False, 'original_shape': original_shape, 'resizeShape': resizeShape, 'cropDimension': cropDimension}
        return _resize(adc), _resize(hbv), _resize(t2w), cropMetadata

    # Get center of the non-black region
    center_x = (cmin_t2w + cmax_t2w) // 2
    center_y = (rmin_t2w + rmax_t2w) // 2
    # Start and end coordinates of the crop, clamped to the image
    start_x = max(center_x - cropDimension // 2, 0)
    start_y = max(center_y - cropDimension // 2, 0)
    end_x = min(start_x + cropDimension, t2w.shape[1])
    end_y = min(start_y + cropDimension, t2w.shape[0])

    # Crop images at the center of the non-black region
    adc_cropped = adc[start_y:end_y, start_x:end_x]
    hbv_cropped = hbv[start_y:end_y, start_x:end_x]
    t2w_cropped = t2w[start_y:end_y, start_x:end_x]
    cropped_shape = t2w_cropped.shape

    # Metadata needed to paste a predicted mask back into the original frame
    cropMetadata = {'start_x': start_x, 'start_y': start_y, 'end_x': end_x, 'end_y': end_y, 'cropDimension': cropDimension, 'cropped_shape': cropped_shape, 'center_x': center_x, 'center_y': center_y, 'original_shape': original_shape, 'resizeShape': resizeShape, 'flagCrop': True}

    return _resize(adc_cropped), _resize(hbv_cropped), _resize(t2w_cropped), cropMetadata

def restoreSizeMask(mask, cropMetadata):
    """Bring a predicted mask back to the size and position of the original slice.

    Args:
        mask: 2D uint8 mask produced on the resized image.
        cropMetadata: dict returned by `cropResizeImages` for the same slice.

    Returns:
        uint8 array of shape cropMetadata['original_shape']. If the slice was not
        cropped the mask is simply resized; otherwise it is resized to the crop
        size and pasted into an all-zero array at the crop position.
    """
    def _resizeNearest(array, shape):
        # `shape` is a NumPy (rows, cols) shape, while PIL's resize expects (width, height):
        # the two must be swapped, otherwise non-square targets come out transposed.
        size = (int(shape[1]), int(shape[0]))
        return np.array(Image.fromarray(array, mode='L').resize(size, Image.NEAREST), dtype=np.uint8)

    original_shape = cropMetadata['original_shape']

    # The image was not cropped: only undo the resize
    if not cropMetadata['flagCrop']:
        return _resizeNearest(mask, original_shape)

    # The image was cropped: undo the resize, then paste the mask at the crop position
    mask_restored = np.zeros(original_shape, dtype=np.uint8)
    mask_cropDim = _resizeNearest(mask, cropMetadata['cropped_shape'])
    mask_restored[cropMetadata['start_y']:cropMetadata['end_y'], cropMetadata['start_x']:cropMetadata['end_x']] = mask_cropDim

    return mask_restored

def minMaxScaling(img):
    """Linearly rescale `img` to the full uint8 range [0, 255].

    Args:
        img: numeric array.

    Returns:
        A uint8 array where the minimum maps to 0 and the maximum to 255.
        A constant image is returned unchanged (same object, same dtype).
    """
    lowest, highest = np.min(img), np.max(img)
    if highest == lowest:
        return img

    # Work in float64 so that integer inputs cannot overflow in the subtraction
    scaled = (np.asarray(img, dtype=np.float64) - lowest) / (float(highest) - float(lowest))
    # Round to nearest instead of truncating: truncation turns values such as 84.999... into 84
    # and disagrees with the rounding done by the training-time uint8 conversion (img_as_ubyte).
    return np.rint(scaled * 255).astype(np.uint8)



**Definition of PostProcessing Function (Interpolation)**

In [None]:
# Definition of Function for Interpolation between two masks
from scipy.ndimage import distance_transform_edt

def interp_mask(mask1, mask2):
    """Interpolate halfway between two binary masks using signed distance maps.

    For each mask the signed distance is positive inside the foreground and
    negative outside; a pixel belongs to the interpolated mask when the sum of
    the two signed distances is positive.

    Args:
        mask1, mask2: 2D masks of the same shape; any non-zero value is foreground
            (bool, 0/1 and 0/255 masks are all accepted).

    Returns:
        uint8 array with 1 on the interpolated foreground and 0 elsewhere.
    """
    def _signedDistance(mask):
        # Cast to bool first: `~` on a uint8 0/1 mask is a bitwise NOT (254/255),
        # which would mark every pixel as background-distance foreground.
        foreground = np.asarray(mask).astype(bool)
        return distance_transform_edt(foreground) - distance_transform_edt(~foreground)

    d1 = _signedDistance(mask1)
    d2 = _signedDistance(mask2)
    interpolated_mask = (d1 + d2) > 0

    return interpolated_mask.astype(np.uint8)

**Inference**

In [None]:
# Loop over the test images
# Load the model from the checkpoint saved at the end of training
checkpoint_path = os.path.join(results_folder, 'epoch_50.pth')
model = init_model(cfg, checkpoint_path, 'cuda:0')


patientList = os.listdir(testSetPath)
# cropDimensionList[patient][sliceName] -> cropDimension used for that slice
# (read again by the metrics cell to normalise the Hausdorff distance)
cropDimensionList = {}
contatoreMaskDaInterpolare = 0


for patient in tqdm(patientList):

  stackedFolder = os.path.join(testSetPath, patient, "stacked")
  # Ignore stray entries of the test folder that are not patient folders
  if not os.path.isdir(stackedFolder):
    continue

  # Create a dictionary key of the patient name
  cropDimensionList[patient] = {}

  automaticFolder = os.path.join(testSetPath, patient, "automatic")
  os.makedirs(automaticFolder, exist_ok=True)

  sliceList = os.listdir(stackedFolder)

  # Initialization of elements needed for post-processing
  tumorVolume = 0
  numTumoralSlices = 0
  hasTumor = False
  listTumoralSlices = []


  # NOTE: the loop variable is deliberately not called `slice`, which would
  # overwrite the Python builtin for every cell executed afterwards.
  for sliceName in sliceList:
    if not sliceName.endswith('.png'):
      continue

    ###############################
    #######      Load      ########
    ###############################

    stacked = np.array(Image.open(os.path.join(stackedFolder, sliceName)))

    # Channel order of the stacked image: ADC, HBV, T2W
    adc = stacked[:,:,0]
    hbv = stacked[:,:,1]
    t2w = stacked[:,:,2]

    ###############################
    ####### Pre-Processing ########
    ###############################

    # Crop
    adc, hbv, t2w, cropMetadata = cropResizeImages(adc, hbv, t2w)

    # Store cropDimension metadata in the dictionary list
    cropDimensionList[patient][sliceName] = cropMetadata['cropDimension']

    # Min-Max Scaling
    adc = minMaxScaling(adc)
    hbv = minMaxScaling(hbv)
    t2w = minMaxScaling(t2w)

    # Stack the images
    stacked = np.stack((adc, hbv, t2w), axis=-1)

    ###############################
    #######    Inference   ########
    ###############################

    result = inference_model(model, stacked)

    # Get data from the result
    pred_label = result.pred_sem_seg.data.squeeze()
    pred_label = pred_label.cpu().numpy().astype(np.uint8)

    ###############################
    ### Computation of elements ###
    ###       needed for        ###
    ###     post-processing     ###
    ###############################

    # Statistics on the predicted tumour (measured on the resized mask)
    if np.sum(pred_label) > 0:
      hasTumor = True
      numTumoralSlices += 1
      tumorVolume += np.sum(pred_label.astype(bool))
      listTumoralSlices.append(sliceName)

    ###############################
    #######   Save results  #######
    ###############################

    # Restore the original size
    pred_label = restoreSizeMask(pred_label, cropMetadata)
    pred_label = np.array(pred_label, dtype=np.uint8)

    # Save the result as an 8-bit grayscale PNG with the same name as the input slice
    pred_label = Image.fromarray(pred_label, mode='L')

    pred_label.save(os.path.join(automaticFolder, sliceName))


  ###############################
  ####### Post-Processing #######
  ###############################


#   if hasTumor:
#     exampleSliceStacked = np.array(Image.open(os.path.join(testSetPath, patient, "stacked", sliceList[0])))
#     exampleSlice = exampleSliceStacked[:,:,0]
#     shapeSlices = exampleSlice.shape

#     # # Check tumor volume
#     # if tumorVolume < 13:
#     #   blackMask = Image.fromarray(np.zeros((shapeSlices[0], shapeSlices[1]), dtype=np.uint8)).convert('L')
#     #   for slice in sliceList:
#     #     if slice.endswith('.png'):
#     #       blackMask.save(os.path.join(testSetPath, patient, "automatic", slice))

#     # # Check number of tumoral slices
#     # if numTumoralSlices == 1:
#     #   blackMask = Image.fromarray(np.zeros((shapeSlices[0], shapeSlices[1]), dtype=np.uint8)).convert('L')
#     #   for slice in sliceList:
#     #     if slice.endswith('.png'):
#     #       blackMask.save(os.path.join(testSetPath, patient, "automatic", slice))

#     # Check if there are black slices between two tumoral slices --> if so, interpolate all the black slices with the tumoral slices before and after
#     # Set a maximum number of black slices to look for between two tumoral slices
#     maxBlackSlices = 100
#     listTumoralSlices.sort()
#     # Extract the number of the tumoral slices from the list. They are called slice_n.png
#     listTumoralSlices = [int(slice.split('.')[0].split('_')[1]) for slice in listTumoralSlices]
#     # Check how many black slices there are between two tumoral slices using difference between consecutive elements
#     diffTumoralSlices = np.diff(listTumoralSlices) - 1
#     for i in range(len(diffTumoralSlices)):
#       if diffTumoralSlices[i] > 0:
#         if diffTumoralSlices[i] <= maxBlackSlices:
#           for j in range(1, diffTumoralSlices[i]+1):
#             contatoreMaskDaInterpolare += 1
#             # interpolate the mask of the black slice with interp_shape function
#             top = np.array(Image.open(os.path.join(testSetPath, patient, "automatic", f"slice_{listTumoralSlices[i]}.png")))
#             bottom = np.array(Image.open(os.path.join(testSetPath, patient, "automatic", f"slice_{listTumoralSlices[i+1]}.png")))
#             black = interp_shape(top, bottom, 0.5)
#             black = Image.fromarray(black, mode='L')

#             black = interp_mask(top, bottom)
#             black = Image.fromarray(black*255, mode='L')
#             black.save(os.path.join(testSetPath, patient, "automatic", f"slice_{listTumoralSlices[i]+j}.png"))

# print(f"Numero di maschere da interpolare: {contatoreMaskDaInterpolare}")

**Metrics Computation**

In [None]:
import cv2
from scipy.spatial.distance import directed_hausdorff

def _sliceOrderKey(fileName):
    """Sort key ordering 'slice_<n>.png' files by the integer <n>; other names go last, alphabetically."""
    suffix = os.path.splitext(fileName)[0].rsplit('_', 1)[-1]
    if suffix.isdecimal():
        return (0, int(suffix), fileName)
    return (1, 0, fileName)

patientsList = os.listdir(testSetPath)

# Initialization of evaluation metrics
DiceList = []

DiffVolumeList = []
AbsErrorList = []
RelDiffVolumeList = []

HausdorffDistanceList = []

tumPatientTP = 0
hltPatientTN = 0
tumPatientTotal = 0
hltPatientTotal = 0

# Loop over the patients
for patient in tqdm(patientsList):
  stackedFolder = os.path.join(testSetPath, patient, "stacked")
  # Ignore stray entries of the test folder that are not patient folders
  if not os.path.isdir(stackedFolder):
    continue

  # Get the list of slices: only PNG files (as in the inference cell), in slice-number order.
  # os.listdir returns names in arbitrary order, but the position in this list is used
  # below as the z coordinate of the contour points, so the order must be the real one.
  sliceList = sorted((name for name in os.listdir(stackedFolder) if name.endswith('.png')), key=_sliceOrderKey)
  if not sliceList:
    continue

  # Get the shape of the patient's slices
  exampleSlicePath = os.path.join(testSetPath, patient, "manual", sliceList[0])
  shapeSlice = np.array(Image.open(exampleSlicePath)).shape

  # Initialization of 3D masks (rows, cols, slice)
  ManMask3D_bool = np.zeros((shapeSlice[0], shapeSlice[1], len(sliceList)), dtype=bool)
  AutoMask3D_bool = np.zeros((shapeSlice[0], shapeSlice[1], len(sliceList)), dtype=bool)

  # Loop over the slices to create 3D masks
  for sliceIndex, actSlice in enumerate(sliceList):
    ManMask3D_bool[:,:,sliceIndex] = np.array(Image.open(os.path.join(testSetPath, patient, "manual", actSlice))).astype(bool)
    AutoMask3D_bool[:,:,sliceIndex] = np.array(Image.open(os.path.join(testSetPath, patient, "automatic", actSlice))).astype(bool)


  # CALCULATION OF EVALUATION METRICS
  # Calculation of elements necessary for the calculation of metrics
  # For Dice
  TPmask = np.sum(ManMask3D_bool & AutoMask3D_bool)
  FPmask = np.sum(~ManMask3D_bool & AutoMask3D_bool)
  FNmask = np.sum(ManMask3D_bool & ~AutoMask3D_bool)

  # For Volume-based metrics (volumes are voxel counts)
  totalVolumeMan = np.sum(ManMask3D_bool)
  totalVolumeAuto = np.sum(AutoMask3D_bool)
  diffVolume = totalVolumeMan - totalVolumeAuto

  # For Hausdorff Distance: collect the (x, y, slice index) contour points of both masks
  contourPoints3D_Man = []
  contourPoints3D_Auto = []
  ManMask3D_uint8 = ManMask3D_bool.astype(np.uint8)*255
  AutoMask3D_uint8 = AutoMask3D_bool.astype(np.uint8)*255
  for s in range(len(sliceList)):
    actManMask = np.ascontiguousarray(ManMask3D_uint8[:,:,s])
    actAutoMask = np.ascontiguousarray(AutoMask3D_uint8[:,:,s])

    # [-2] selects the contour list both with OpenCV 4 (contours, hierarchy)
    # and with OpenCV 3 (image, contours, hierarchy)
    contours_Man = cv2.findContours(actManMask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]
    contours_Auto = cv2.findContours(actAutoMask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]

    for contour in contours_Man:
      for point in contour:
        contourPoints3D_Man.append([point[0][0], point[0][1], s])
    for contour in contours_Auto:
      for point in contour:
        contourPoints3D_Auto.append([point[0][0], point[0][1], s])

  # Dice (patients where both masks are empty are left out of the Dice statistics)
  if (2*TPmask + FPmask + FNmask)!=0:
    Dice = 2*TPmask/(2*TPmask + FPmask + FNmask)
    DiceList.append(Dice)

  # Volume-based
  DiffVolumePaziente = diffVolume
  DiffVolumeList.append(DiffVolumePaziente)

  AbsErrorPaziente = abs(DiffVolumePaziente)
  AbsErrorList.append(AbsErrorPaziente)

  # Relative volume difference: undefined (skipped) when only the automatic mask
  # has tumour, 0 when both masks are empty
  if totalVolumeMan != 0:
    RelDiffVolume = diffVolume / totalVolumeMan
    RelDiffVolumeList.append(RelDiffVolume)
  elif totalVolumeAuto == 0:
    RelDiffVolumeList.append(0)

  # Hausdorff Distance (symmetric: the larger of the two directed distances)
  if contourPoints3D_Man and contourPoints3D_Auto:
    HausdorffDistance = max(directed_hausdorff(contourPoints3D_Man, contourPoints3D_Auto)[0], directed_hausdorff(contourPoints3D_Auto, contourPoints3D_Man)[0])
    HausdorffDistance = HausdorffDistance / np.mean(list(cropDimensionList[patient].values())) * 256 # normalize HD as explained in the report
    HausdorffDistanceList.append(HausdorffDistance)

  # Sensibility e Specificity (patient level)
  if totalVolumeMan > 0 and totalVolumeAuto > 0:
    tumPatientTP += 1
  elif totalVolumeMan == 0 and totalVolumeAuto == 0:
    hltPatientTN += 1
  if totalVolumeMan > 0:
    tumPatientTotal += 1
  else:
    hltPatientTotal += 1

# MEAN AND STD OF EVALUATION METRICS
MeanDice = np.mean(DiceList)
StdDice = np.std(DiceList)
print(f"\nDice: {MeanDice:.2f} +- {StdDice:.2f} (Mean +- Std)")

MeanDiffVolume = np.mean(DiffVolumeList)
MeanAbsError = np.mean(AbsErrorList)
MeanRelDiffVolume = np.mean(RelDiffVolumeList)

StdDiffVolume = np.std(DiffVolumeList)
StdAbsError = np.std(AbsErrorList)
StdRelDiffVolume = np.std(RelDiffVolumeList)

print(f"DiffVolume: {MeanDiffVolume:.2f} +- {StdDiffVolume:.2f} (Mean +- Std)")
print(f"AbsError: {MeanAbsError:.2f} +- {StdAbsError:.2f} (Mean +- Std)")
print(f"RelDiffVolume: {MeanRelDiffVolume:.2f} +- {StdRelDiffVolume:.2f} (Mean +- Std)")

MeanHausdorffDistance = np.mean(HausdorffDistanceList)
StdHausdorffDistance = np.std(HausdorffDistanceList)
print(f"Hausdorff Distance: {MeanHausdorffDistance:.2f} +- {StdHausdorffDistance:.2f} (Mean +- Std)")

# With no tumoral (or no healthy) patient in the test set the rate defaults to 100
if tumPatientTotal > 0:
  Sensibility = 100 * tumPatientTP / tumPatientTotal
else:
  Sensibility = 100

if hltPatientTotal > 0:
  Specificity = 100 * hltPatientTN / hltPatientTotal
else:
  Specificity = 100

print(f"SensibilityTum: {Sensibility:.2f}%")
print(f"SpecificitySani: {Specificity:.2f}%")


## **Plot**

In [None]:
import os
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import ipywidgets as widgets
from ipywidgets import interact

# Definition of function to load images
def load_images(patient_id, slice_id):
    """Load the three MRI channels and the two masks of one slice.

    Args:
        patient_id: name of the patient folder inside `testSetPath`.
        slice_id: slice number; the file read is 'slice_<slice_id>.png'.

    Returns:
        (adc, hbv, t2w, mask, auto_mask): channels 0, 1 and 2 of the stacked
        image, followed by the manual and the automatic mask as uint8 arrays
        holding 0 (background) or 255 (foreground).
    """
    patient_dir = os.path.join(testSetPath, patient_id)
    file_name = f"slice_{slice_id}.png"

    def _load_display_mask(folder):
        raw = np.array(Image.open(os.path.join(patient_dir, folder, file_name)))
        # Binarise before scaling: multiplying a uint8 mask that is already 0/255
        # by 255 would wrap around instead of staying white.
        return (raw > 0).astype(np.uint8) * 255

    img = np.array(Image.open(os.path.join(patient_dir, "stacked", file_name)))
    adc = img[:,:,0]
    hbv = img[:,:,1]
    t2w = img[:,:,2]
    mask = _load_display_mask("manual")
    auto_mask = _load_display_mask("automatic")

    return adc, hbv, t2w, mask, auto_mask

# Definition of function to display images in a subplot with matplotlib
def display_images(patient_id, slice_id):
    """Show a 3x2 grid for one slice: one row per channel, manual mask on the left, automatic on the right."""
    adc, hbv, t2w, mask, auto_mask = load_images(patient_id, slice_id)

    fig, axs = plt.subplots(3, 2, figsize=(12, 12))

    channel_rows = (('ADC', adc), ('HBV', hbv), ('T2W', t2w))
    mask_columns = (('manual', mask), ('automatic', auto_mask))

    for row, (channel_name, channel_img) in enumerate(channel_rows):
        for col, (mask_name, mask_img) in enumerate(mask_columns):
            ax = axs[row, col]
            ax.imshow(channel_img, cmap='gray')
            # Semi-transparent mask drawn over the channel
            ax.imshow(mask_img, cmap='gray', alpha=0.2)
            ax.set_title(f'{channel_name} with {mask_name} mask')
            ax.axis('off')

    plt.show()

# Definition of function to get the maximum number of slices for a patient
def max_slices_for_patient(patient_id):
    """Return how many regular files the patient's "stacked" folder contains."""
    stacked_path = os.path.join(testSetPath, patient_id, "stacked")
    return sum(
        1
        for entry_name in os.listdir(stacked_path)
        if os.path.isfile(os.path.join(stacked_path, entry_name))
    )

# Get list of patients of the test set
patients = sorted(os.listdir(testSetPath))

# Create widgets for patient selection and slice selection
patient_selector = widgets.Select(options=patients, description="Paziente:")
# The upper bound is a placeholder; update_slice_range() sets it from the selected patient
# NOTE(review): the slider spans 1..N where N is the number of files, which assumes the
# slices are named slice_1.png ... slice_N.png - confirm the numbering does not start at 0.
slice_slider = widgets.IntSlider(min=1, max=1, value=1, description="Slice:")

def update_slice_range(*args):
    """Set the slider maximum to the number of slice files of the selected patient."""
    patient_id = patient_selector.value
    max_slices = max_slices_for_patient(patient_id)
    slice_slider.max = max_slices

# Refresh the slider range every time another patient is selected
patient_selector.observe(update_slice_range, 'value')

# Initialise the range for the patient selected at start-up
update_slice_range()

@interact(patient_id=patient_selector, slice_id=slice_slider)
def explore_data(patient_id, slice_id):
    """Redraw the image grid for the patient and slice currently selected in the widgets."""
    display_images(patient_id, str(slice_id))