### Use this notebook to test MONAI transforms

### Libraries and auxiliary functions

In [None]:
import sys
import os

# Put the notebook's parent directory on the import path so that the
# `from Utils import *` that follows can resolve the shared helpers module.
current_directory = os.getcwd()
path = os.path.split(current_directory)[0]
sys.path.extend([path])
from Utils import *

%matplotlib widget
from ipywidgets import interact, interactive, widgets
from matplotlib.patches import Rectangle, Circle, Arrow

import glob
import os
import tempfile
import warnings
from typing import Optional, Any, Mapping, Hashable

import matplotlib.pyplot as plt
import nibabel as nib
import numpy as np

import monai
from monai.config import print_config
from monai.utils import first, ensure_tuple
from monai.config import KeysCollection
from monai.data import Dataset, ArrayDataset, create_test_image_3d, DataLoader
from monai.transforms import (
    AdjustContrastd,
    Transform,
    Compose,
    LoadImage,
    Orientation,
    ConcatItemsd,
    LoadImaged,
    Orientationd,
    EnsureChannelFirstd,
    EnsureChannelFirst,
    ToTensord,
    Spacingd,
    ScaleIntensityd,
    CropForegroundd,
    RandCropByLabelClassesd,
    RandCropByPosNegLabeld,
    RandFlipd,
    RandZoomd,
)

In [None]:
class Param:
    """Plain container for the experiment settings used throughout this notebook.

    The values mirror what the notebook describes as coming from config.ini.
    Note that ``orientation`` is stored under the attribute name ``axcodes``,
    which is the keyword later passed to ``Orientationd``.
    """

    def __init__(self, data_dir, pixel_dim, window_size, orientation, in_channels, out_channels, input_type, label_type):
        # Location of the data and the resampling / cropping geometry.
        self.data_dir, self.pixel_dim, self.window_size = data_dir, pixel_dim, window_size
        # Orientation code kept under MONAI's keyword name.
        self.axcodes = orientation
        # Network I/O channel counts.
        self.in_channels, self.out_channels = in_channels, out_channels
        # Which image type(s) and which label files to pick up.
        self.input_type, self.label_type = input_type, label_type

def _warn_on_count_mismatch(prefix, **file_lists):
    """Warn when the file lists about to be zipped together differ in length."""
    counts = {name: len(files) for name, files in file_lists.items()}
    if len(set(counts.values())) > 1:
        # zip() stops at the shortest list, so unmatched files are dropped and
        # the remaining ones are paired purely by sorted position.
        warnings.warn(
            f"generateLabeledFileList('{prefix}'): unequal file counts {counts}; "
            "files are paired by sorted position and unmatched ones are dropped.",
            stacklevel=3,
        )


def generateLabeledFileList(param, prefix):
    """Build MONAI-style data dicts that pair image files with label files.

    Images are looked up in ``<param.data_dir>/<prefix>_images`` by suffix
    (``*_M``, ``*_P``, ``*_R``, ``*_I`` ``.nii.gz`` = magnitude, phase, real,
    imaginary). Labels are looked up in ``<param.data_dir>/<prefix>_labels`` as
    ``*_<param.label_type>_label.nii.gz``. Every list is sorted, and entries are
    paired by their sorted position. A warning is emitted when the counts differ.

    Args:
        param: object with ``data_dir``, ``in_channels``, ``input_type`` and
            ``label_type`` attributes (e.g. a ``Param``).
        prefix: directory prefix, e.g. ``'test'``.

    Returns:
        list of dicts. When ``param.in_channels == 2`` the keys are ``"image_1"``,
        ``"image_2"`` and ``"label"`` (real/imaginary if ``input_type`` is ``'R'``
        or ``'I'``, otherwise magnitude/phase). Otherwise the keys are ``"image"``
        and ``"label"``, where ``'R'``, ``'I'`` and ``'P'`` select that image type
        and any other ``input_type`` selects magnitude.
    """
    print('Reading labeled images from: ' + param.data_dir)
    image_dir = os.path.join(param.data_dir, prefix + "_images")

    def _images(suffix):
        # Sorted paths of one image type, e.g. suffix "M" -> "*_M.nii.gz".
        return sorted(glob.glob(os.path.join(image_dir, "*_" + suffix + ".nii.gz")))

    labels = sorted(glob.glob(os.path.join(param.data_dir, prefix + "_labels", "*_" + param.label_type + "_label.nii.gz")))

    # Use two types of images combined
    if param.in_channels == 2:
        if param.input_type in ('R', 'I'):
            print('Use real/imaginary images')
            first, second = _images("R"), _images("I")
        else:
            print('Use magnitude/phase images')
            first, second = _images("M"), _images("P")
        _warn_on_count_mismatch(prefix, image_1=first, image_2=second, label=labels)
        return [
            {"image_1": image_1, "image_2": image_2, "label": label}
            for image_1, image_2, label in zip(first, second, labels)
        ]

    # Use only one type of image; anything not listed falls back to magnitude.
    single_channel = {
        'R': ("R", 'Use real images'),
        'I': ("I", 'Use imaginary images'),
        'P': ("P", 'Use phase images'),
    }
    suffix, message = single_channel.get(param.input_type, ("M", 'Use magnitude images'))
    print(message)
    images = _images(suffix)
    _warn_on_count_mismatch(prefix, image=images, label=labels)
    return [{"image": image, "label": label} for image, label in zip(images, labels)]

In [None]:
# Build param (info from config.ini)
data_dir = os.getcwd()
pixel_dim = (3.6, 1.171875, 1.171875)
window_size = (3, 48, 48)
orientation = 'PIL'
in_channels = 2
out_channels = 3
input_type = 'MP'
label_type = 'multi'
param = Param(data_dir, pixel_dim, window_size, orientation, in_channels, out_channels, input_type, label_type)

# Create dictionary
print('Create dictionary')
train_files = generateLabeledFileList(param, 'test')
if not train_files:
    # Stop with an explicit message rather than an IndexError on train_files[0] below.
    raise FileNotFoundError(
        f"No labeled 'test' cases found under {data_dir!r}: expected "
        f"test_images/*.nii.gz and test_labels/*_{label_type}_label.nii.gz"
    )
print(train_files[0]['image_1'])
print(train_files[0]['image_2'])

In [None]:
# CREATE NEW CUSTOM ITK OBJECT LOADER

from collections.abc import Sequence
import SimpleITK as sitk
from monai.data import MetaTensor, ImageReader, ITKReader
from monai.data.utils import orientation_ras_lps, is_no_channel
import torch
from monai.config import DtypeLike, PathLike, KeysCollection
from monai.utils import convert_to_dst_type, ensure_tuple, ensure_tuple_rep, MetaKeys, SpaceKeys, TraceKeys
from monai.utils import ImageMetaKey as Key
from torch.utils.data._utils.collate import np_str_obj_array_pattern

from abc import ABC, abstractmethod
from monai.transforms import Compose, Transform, MapTransform
from monai.utils.enums import PostFix
# Default metadata-key postfix; LoadSitkImaged uses it to name metadata entries
# as f"{key}_{postfix}" when no explicit meta key is given.
DEFAULT_POST_FIX = PostFix.meta()

class sitkReader(ImageReader):
    """MONAI ``ImageReader`` for images that are already in memory as SimpleITK objects.

    ``read`` does no I/O and returns its input unchanged, so ``get_data`` receives
    a ``SimpleITK.Image`` directly. The affine and metadata handling follows the
    same structure as MONAI's ``ITKReader``.

    Args:
        series_name: stored only; not used by this reader.
        reverse_indexing: if False (default), the array returned by
            ``sitk.GetArrayFromImage`` has its axes reversed (``.T``); if True it
            is returned in SimpleITK's array order.
        series_meta: stored only; not used by this reader.
        affine_lps_to_ras: convert the affine built from the image geometry from
            LPS to RAS (and record the space accordingly).
        kwargs: stored on ``self.kwargs``; not used by this reader.
    """

    def __init__(
            self,
            series_name: str = "",
            reverse_indexing: bool = False,
            series_meta: bool = False,
            affine_lps_to_ras: bool = True,
            **kwargs,
        ):
        super().__init__()
        self.kwargs = kwargs
        self.series_name = series_name
        self.reverse_indexing = reverse_indexing
        self.series_meta = series_meta
        self.affine_lps_to_ras = affine_lps_to_ras

    def read(self, img):
        """Return ``img`` unchanged: the input is already an in-memory image."""
        return img

    def verify_suffix(self, img) -> bool:
        """Accept every input; there is no file name to check."""
        return True

    def get_data(self, img) -> tuple[np.ndarray, dict]:
        """Return ``(array, metadata)`` for one SimpleITK image.

        The metadata holds the image's metadata keys plus spacing, original and
        current affine, space, spatial shape and original channel dim (NaN when
        the array has no extra channel axis, -1 otherwise).
        """
        img_array: list[np.ndarray] = []
        compatible_meta: dict = {}
        data = self._get_array_data(img)
        img_array.append(data)
        header = self._get_meta_dict(img)
        header[MetaKeys.ORIGINAL_AFFINE] = self._get_affine(img, self.affine_lps_to_ras)
        header[MetaKeys.SPACE] = SpaceKeys.RAS if self.affine_lps_to_ras else SpaceKeys.LPS
        header[MetaKeys.AFFINE] = header[MetaKeys.ORIGINAL_AFFINE].copy()
        header[MetaKeys.SPATIAL_SHAPE] = self._get_spatial_shape(img)
        # default to "no_channel" or -1
        header[MetaKeys.ORIGINAL_CHANNEL_DIM] = (float("nan") if len(data.shape) == len(header[MetaKeys.SPATIAL_SHAPE]) else -1)
        self._copy_compatible_dict(header, compatible_meta)
        return self._stack_images(img_array, compatible_meta), compatible_meta

    @staticmethod
    def _direction_matrix(img) -> np.ndarray:
        """Return the direction cosines as a (dim, dim) matrix for any image dimension."""
        dim = img.GetDimension()
        return np.asarray(img.GetDirection(), dtype=float).reshape(dim, dim)

    def _get_meta_dict(self, img) -> dict:
        """Collect the image's metadata entries (skipping ``ITK_``-prefixed keys) plus spacing."""
        meta_dict: dict = {}
        for key in img.GetMetaDataKeys():
            if key.startswith("ITK_"):
                continue
            val = img.GetMetaData(key)
            meta_dict[key] = np.asarray(val) if type(val).__name__.startswith("itk") else val
        meta_dict["spacing"] = np.asarray(img.GetSpacing())
        return meta_dict

    def _get_affine(self, img, lps_to_ras: bool = True):
        """Build the (sr+1)x(sr+1) affine from direction, spacing and origin (sr <= 3)."""
        direction = self._direction_matrix(img)
        spacing = np.asarray(img.GetSpacing())
        origin = np.asarray(img.GetOrigin())
        sr = min(max(direction.shape[0], 1), 3)
        affine: np.ndarray = np.eye(sr + 1)
        affine[:sr, :sr] = direction[:sr, :sr] @ np.diag(spacing[:sr])
        affine[:sr, -1] = origin[:sr]
        if lps_to_ras:
            affine = orientation_ras_lps(affine)
        return affine

    def _get_spatial_shape(self, img):
        """Return the first min(max(dimension, 1), 3) entries of the image size."""
        sr = max(min(img.GetDimension(), 3), 1)
        return np.asarray(list(img.GetSize())[:sr])

    def _get_array_data(self, img):
        """Return the pixel array, transposed unless ``reverse_indexing`` is set."""
        np_img = sitk.GetArrayFromImage(img)
        if img.GetNumberOfComponentsPerPixel() == 1:
            return np_img if self.reverse_indexing else np_img.T
        # Vector pixels: SimpleITK puts the component axis last, so ``.T`` would move
        # it to the front; put it back at the end so ORIGINAL_CHANNEL_DIM = -1
        # (set in get_data) points at it.
        return np_img if self.reverse_indexing else np.moveaxis(np_img.T, 0, -1)

    def _stack_images(self, image_list: list, meta_dict: dict):
        """Return a single array unchanged, or concatenate/stack several along the channel dim."""
        if len(image_list) <= 1:
            return image_list[0]
        if not is_no_channel(meta_dict.get(MetaKeys.ORIGINAL_CHANNEL_DIM, None)):
            channel_dim = int(meta_dict[MetaKeys.ORIGINAL_CHANNEL_DIM])
            return np.concatenate(image_list, axis=channel_dim)
        # stack at a new first dim as the channel dim, if `'original_channel_dim'` is unspecified
        meta_dict[MetaKeys.ORIGINAL_CHANNEL_DIM] = 0
        return np.stack(image_list, axis=0)

    def _copy_compatible_dict(self, from_dict: dict, to_dict: dict):
        """Copy collate-friendly entries into an empty ``to_dict``, or check affine/shape agreement.

        Raises:
            ValueError: ``to_dict`` is not a dict.
            RuntimeError: affine or spatial shape differ from those already in ``to_dict``.
        """
        if not isinstance(to_dict, dict):
            raise ValueError(f"to_dict must be a Dict, got {type(to_dict)}.")
        if not to_dict:
            for key in from_dict:
                datum = from_dict[key]
                # Skip string/object arrays, which default_collate cannot batch.
                if isinstance(datum, np.ndarray) and np_str_obj_array_pattern.search(datum.dtype.str) is not None:
                    continue
                to_dict[key] = str(TraceKeys.NONE) if datum is None else datum  # NoneType to string for default_collate
        else:
            affine_key, shape_key = MetaKeys.AFFINE, MetaKeys.SPATIAL_SHAPE
            if affine_key in from_dict and not np.allclose(from_dict[affine_key], to_dict[affine_key]):
                raise RuntimeError(
                    "affine matrix of all images should be the same for channel-wise concatenation. "
                    f"Got {from_dict[affine_key]} and {to_dict[affine_key]}."
                )
            if shape_key in from_dict and not np.allclose(from_dict[shape_key], to_dict[shape_key]):
                raise RuntimeError(
                    "spatial_shape of all images should be the same for channel-wise concatenation. "
                    f"Got {from_dict[shape_key]} and {to_dict[shape_key]}."
                )
                
class LoadSitkImage(Transform):
    """Convert an in-memory ``SimpleITK.Image`` into a MONAI ``MetaTensor``.

    Follows the structure of MONAI's ``LoadImage`` (reader -> array + metadata ->
    torch-backed ``MetaTensor``), but it takes an already-read image object
    (parsed by ``sitkReader``) instead of a file path.

    Args:
        image_only: if True return only the ``MetaTensor``; otherwise return
            ``(MetaTensor, metadata dict)``.
        dtype: dtype passed to ``convert_to_dst_type`` for the pixel array.
        ensure_channel_first: apply ``EnsureChannelFirst`` to the result.
        simple_keys, prune_meta_pattern, prune_meta_sep: forwarded to
            ``MetaTensor.ensure_torch_and_prune_meta``.

    Raises:
        RuntimeError: the input is not a ``SimpleITK.Image``.
        ValueError: the reader did not return a metadata dict.
    """

    def __init__(
        self,
        image_only: bool = False,
        dtype: Optional[DtypeLike] = np.float32,
        ensure_channel_first: bool = False,
        simple_keys: bool = False,
        prune_meta_pattern: Optional[str] = None,
        prune_meta_sep: str = ".",
    ) -> None:
        self.reader = sitkReader()
        self.image_only = image_only
        self.ensure_channel_first = ensure_channel_first
        self.dtype = dtype
        self.simple_keys = simple_keys
        self.pattern = prune_meta_pattern
        self.sep = prune_meta_sep

    def __call__(self, img):
        if not isinstance(img, sitk.Image):
            raise RuntimeError(f"{self.__class__.__name__} The input image is not an ITK object.\n")
        img_array, meta_data = self.reader.get_data(img)
        if not isinstance(meta_data, dict):
            raise ValueError(f"`meta_data` must be a dict, got type {type(meta_data)}.")
        img_array = convert_to_dst_type(img_array, dst=img_array, dtype=self.dtype)[0]
        # Unlike the original LoadImage, build the MetaTensor from a torch tensor
        # so the result matches what LoadImage gives when loading the NIfTI file.
        img = MetaTensor.ensure_torch_and_prune_meta(
            torch.from_numpy(img_array), meta_data, self.simple_keys, pattern=self.pattern, sep=self.sep
        )
        if self.ensure_channel_first:
            img = EnsureChannelFirst()(img)
        if self.image_only:
            return img
        return img, (img.meta if isinstance(img, MetaTensor) else meta_data)


import itk 
class LoadITKImage(Transform):
    """Convert an in-memory ITK image into a MONAI ``MetaTensor`` via MONAI's ``ITKReader``.

    Same flow as ``LoadSitkImage``, but for ``itk`` image objects. Only
    ``itkImageF3`` instances pass the type check, which is why the test cells
    cast with ``.astype(itk.F)``; other pixel types or dimensions are rejected.

    Args:
        image_only: if True return only the ``MetaTensor``; otherwise return
            ``(MetaTensor, metadata dict)``.
        dtype: dtype passed to ``convert_to_dst_type`` for the pixel array.
        ensure_channel_first: apply ``EnsureChannelFirst`` to the result.
        simple_keys, prune_meta_pattern, prune_meta_sep: forwarded to
            ``MetaTensor.ensure_torch_and_prune_meta``.

    Raises:
        RuntimeError: the input is not an ``itkImageF3``.
        ValueError: the reader did not return a metadata dict.
    """

    def __init__(
        self,
        image_only: bool = False,
        dtype: Optional[DtypeLike] = np.float32,
        ensure_channel_first: bool = False,
        simple_keys: bool = False,
        prune_meta_pattern: Optional[str] = None,
        prune_meta_sep: str = ".",
    ) -> None:
        self.reader = ITKReader()
        self.image_only = image_only
        self.ensure_channel_first = ensure_channel_first
        self.dtype = dtype
        self.simple_keys = simple_keys
        self.pattern = prune_meta_pattern
        self.sep = prune_meta_sep

    def __call__(self, img):
        if not isinstance(img, itk.itkImagePython.itkImageF3):
            raise RuntimeError(f"{self.__class__.__name__} The input image is not an ITK object.\n")
        img_array, meta_data = self.reader.get_data(img)
        if not isinstance(meta_data, dict):
            raise ValueError(f"`meta_data` must be a dict, got type {type(meta_data)}.")
        img_array = convert_to_dst_type(img_array, dst=img_array, dtype=self.dtype)[0]
        # Unlike the original LoadImage, build the MetaTensor from a torch tensor
        # so the result matches what LoadImage gives when loading the NIfTI file.
        img = MetaTensor.ensure_torch_and_prune_meta(
            torch.from_numpy(img_array), meta_data, self.simple_keys, pattern=self.pattern, sep=self.sep
        )
        if self.ensure_channel_first:
            img = EnsureChannelFirst()(img)
        if self.image_only:
            return img
        return img, (img.meta if isinstance(img, MetaTensor) else meta_data)

In [None]:
class LoadSitkImaged(MapTransform):
    """Dictionary version of ``LoadSitkImage``.

    For each key in ``keys``, the value (expected to be a ``SimpleITK.Image``) is
    replaced by a ``MetaTensor``. When ``image_only`` is False, the metadata dict is
    also stored under the matching entry of ``meta_keys``. If that entry is None,
    the key ``f"{key}_{meta_key_postfix}"`` is used instead.

    Raises:
        TypeError: ``meta_key_postfix`` is not a str.
        ValueError: ``meta_keys`` and ``keys`` differ in length, or the loader
            output is not a ``(tensor, dict)`` pair when ``image_only`` is False.
        KeyError: the metadata key already exists and ``overwriting`` is False.
    """

    def __init__(self,
            keys: KeysCollection,
            dtype: DtypeLike = np.float32,
            meta_keys: Optional[KeysCollection] = None,
            meta_key_postfix: str = DEFAULT_POST_FIX,
            overwriting: bool = False,
            image_only: bool = False,
            ensure_channel_first: bool = False,
            simple_keys: bool = False,
            prune_meta_pattern: Optional[str] = None,
            prune_meta_sep: str = ".",
            allow_missing_keys: bool = False,
        ):
        super().__init__(keys, allow_missing_keys)
        self._loader = LoadSitkImage(
            image_only=image_only,
            dtype=dtype,
            ensure_channel_first=ensure_channel_first,
            simple_keys=simple_keys,
            prune_meta_pattern=prune_meta_pattern,
            prune_meta_sep=prune_meta_sep,
        )
        if not isinstance(meta_key_postfix, str):
            raise TypeError(f"meta_key_postfix must be a str but is {type(meta_key_postfix).__name__}.")
        self.meta_keys = ensure_tuple_rep(None, len(self.keys)) if meta_keys is None else ensure_tuple(meta_keys)
        if len(self.keys) != len(self.meta_keys):
            raise ValueError(
                f"meta_keys should have the same length as keys, got {len(self.keys)} and {len(self.meta_keys)}."
            )
        self.meta_key_postfix = ensure_tuple_rep(meta_key_postfix, len(self.keys))
        self.overwriting = overwriting

    def __call__(self, img):
        d = dict(img)
        for key, meta_key, meta_key_postfix in self.key_iterator(d, self.meta_keys, self.meta_key_postfix):
            loaded = self._loader(d[key])
            if self._loader.image_only:
                d[key] = loaded
                continue
            if not isinstance(loaded, (tuple, list)):
                raise ValueError(
                    f"loader must return a tuple or list (because image_only=False was used), got {type(loaded)}."
                )
            d[key] = loaded[0]
            if not isinstance(loaded[1], dict):
                raise ValueError(f"metadata must be a dict, got {type(loaded[1])}.")
            meta_key = meta_key or f"{key}_{meta_key_postfix}"
            if meta_key in d and not self.overwriting:
                raise KeyError(f"Metadata with key {meta_key} already exists and overwriting=False.")
            d[meta_key] = loaded[1]
        return d


In [None]:
import itk

# Inspect how ITK and SimpleITK report the geometry of the same file.
itk_img = itk.imread(train_files[0]['image_1']).astype(itk.F)
sitk_img = sitk.ReadImage(train_files[0]['image_1'], sitk.sitkFloat32)

# Call the method; a bare attribute reference would only evaluate to the bound method.
print(itk_img.GetNumberOfComponentsPerPixel())

# Spatial shape via SimpleITK: the first min(max(dimension, 1), 3) entries of the size.
sr = max(min(sitk_img.GetDimension(), 3), 1)
print(np.asarray(list(sitk_img.GetSize())[:sr]))

# Same computation via ITK, taking the dimension from the direction matrix.
sr = max(min(itk.array_from_matrix(itk_img.GetDirection()).shape[0], 3), 1)
print(np.asarray(list(itk.size(itk_img))[:sr]))

# Full SimpleITK size, for reference.
print(np.asarray(sitk_img.GetSize()))

In [None]:
# Test ImageLoad: load the same two files with MONAI's file loader, then as
# in-memory ITK and SimpleITK objects, and print matching slices to compare.


def _print_pair_summary(outputs):
    """Print the first tensor's shape and the [0, 0] slice of the first two tensors.

    ``outputs`` is indexed as ``outputs[i][0]`` (the MetaTensor of the i-th
    loaded image), which matches the ``(tensor, metadata)`` pairs the loaders
    below return with ``image_only=False``. Returns the two tensors.
    """
    first, second = outputs[0][0], outputs[1][0]
    print(first.data.shape)
    print(first.data[0, 0])
    print(second.data[0, 0])
    return first, second


# LOAD WITH STANDARD IMAGE LOADER
print(train_files[0]['image_1'])
print(train_files[0]['image_2'])

# Make list with two image filenames
data_list = [train_files[0]['image_1'], train_files[0]['image_2']]

load_file = Compose([LoadImage(image_only=False)])
# Output is a list of tuples: output_original[i][0] is the MetaTensor and
# output_original[i][1] its metadata dict.
output_original = load_file(data_list)
metatensor_1, metatensor_2 = _print_pair_summary(output_original)

# LOAD WITH ITK IMAGE LOADER TO COMPARE
itk_image_1 = itk.imread(train_files[0]['image_1']).astype(itk.F)
itk_image_2 = itk.imread(train_files[0]['image_2']).astype(itk.F)

# Make list with two itk image objects
data_list = [itk_image_1, itk_image_2]

load_itk = Compose([LoadITKImage()])
output_original = load_itk(data_list)
metatensor_1, metatensor_2 = _print_pair_summary(output_original)

# TEST NEW CUSTOM SITK LOADER
sitk_image_1 = sitk.ReadImage(train_files[0]['image_1'], sitk.sitkFloat32)
sitk_image_2 = sitk.ReadImage(train_files[0]['image_2'], sitk.sitkFloat32)

# Make list with two sitk image objects
data_list = [sitk_image_1, sitk_image_2]

load_sitk = Compose([LoadSitkImage()])
output_original = load_sitk(data_list)
metatensor_1, metatensor_2 = _print_pair_summary(output_original)



In [None]:
# Test LoadImaged: load every labeled 'test' case from disk with MONAI's dictionary loader.
data_list = generateLabeledFileList(param, 'test')

# LOAD WITH STANDARD IMAGE LOADER
load_file = Compose([LoadImaged(keys=['image_1', 'image_2'], image_only=False)])
# Compose maps over the list of dicts, so the output is one dict per case with
# MetaTensors under 'image_1'/'image_2' ('label' is not in keys and stays a path).
output_original = load_file(data_list)

# Titles assume the magnitude/phase two-channel setup configured in `param` above.
for key, title in (('image_1', 'Image_1 - Mag'), ('image_2', 'Image_2 - Phase')):
    metatensor_1 = output_original[0][key]
    metatensor_2 = output_original[1][key]
    print(title)
    print(metatensor_1.data.shape)
    print(metatensor_1.data[0, 0])
    print(metatensor_2.data[0, 0])


In [None]:
# Test LoadSitkImaged: read each magnitude/phase pair into SimpleITK images,
# then convert them with the custom dictionary loader.
prefix = 'test'
images_m = sorted(glob.glob(os.path.join(param.data_dir, prefix + "_images", "*_M.nii.gz")))
images_p = sorted(glob.glob(os.path.join(param.data_dir, prefix + "_images", "*_P.nii.gz")))

# One dict of in-memory SimpleITK images per case.
data_list = [
    {
        "image_1": sitk.ReadImage(image_m_name, sitk.sitkFloat32),
        "image_2": sitk.ReadImage(image_p_name, sitk.sitkFloat32),
    }
    for image_m_name, image_p_name in zip(images_m, images_p)
]

# LOAD WITH CUSTOM SITK DICTIONARY LOADER
load_file = Compose([LoadSitkImaged(keys=['image_1', 'image_2'], image_only=False)])
# Compose maps over the list of dicts: one output dict per case, with MetaTensors
# under 'image_1'/'image_2' plus their metadata dicts.
output_original = load_file(data_list)

for key, title in (('image_1', 'Image_1 - Mag'), ('image_2', 'Image_2 - Phase')):
    metatensor_1 = output_original[0][key]
    metatensor_2 = output_original[1][key]
    print(title)
    print(metatensor_1.data.shape)
    print(metatensor_1.data[0, 0])
    print(metatensor_2.data[0, 0])

In [None]:
# Convert the in-memory SimpleITK dicts with the custom loader, then make the
# channel axis first for both images.
inference_transform = Compose(
    [
        LoadSitkImaged(keys=("image_1", "image_2")),
        EnsureChannelFirstd(keys=("image_1", "image_2")),
    ]
)
output_original = inference_transform(data_list)


In [None]:
# Define preprocessing transforms
# Load images
if param.in_channels == 2:
    # Two channels input
    transform_array = [
        LoadImaged(keys=["image_1", "image_2", "label"]),
        EnsureChannelFirstd(keys=["image_1", "image_2", "label"]),  # Mariana: AddChanneld(keys=["image", "label"]) deprecated, use EnsureChannelFirst instead
        ConcatItemsd(keys=["image_1", "image_2"], name="image"),
    ]
else:
    # One channel input
    transform_array = [
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys=["image", "label"], channel_dim='no_channel'),  # Mariana: AddChanneld(keys=["image", "label"]) deprecated, use EnsureChannelFirst instead
    ]

# Intensity adjustment
if (param.input_type == 'R') or (param.input_type == 'I'):
    transform_array.append(AdjustContrastd(keys=["image"], gamma=2.5))
    # NOTE(review): "label" is rescaled to [0, 1] too; with multi-class integer
    # labels (label_type='multi') this would change the class values — confirm intended.
    transform_array.append(ScaleIntensityd(keys=["image", "label"], minv=0, maxv=1, channel_wise=True))

# Spatial adjustments
transform_array.append(Orientationd(keys=["image", "label"], axcodes=param.axcodes))
transform_array.append(Spacingd(keys=["image", "label"], pixdim=param.pixel_dim, mode=("bilinear", "nearest")))


def _show_case(output_dict, suffix):
    """Print image/label shapes and plot channel 1, channel 2 (if present) and the label.

    Arrays are indexed as ``[c, :, :, :]``, i.e. channel-first 4-D.
    """
    image_array = np.array(output_dict['image'])
    label_array = np.array(output_dict['label'])
    print('Image shape: ' + str(image_array.shape))
    print('Label shape: ' + str(label_array.shape))
    show_array(image_array[0, :, :, :], title='Ch1 ' + suffix)
    if image_array.shape[0] == 2:
        show_array(image_array[1, :, :, :], title='Ch2 ' + suffix)
    show_array(label_array[0, :, :, :], title='Label ' + suffix)


# Plot original
loadTest = Compose(transform_array)
# train_files is a list of case dicts, so Compose returns one output dict per case.
output_original = loadTest(train_files)
for output_dict in output_original:
    _show_case(output_dict, 'Original')

# Data augmentation
transform_array.append(RandZoomd(
    keys=['image', 'label'],
    prob=0.1,
    min_zoom=1.0,
    max_zoom=1.3,
    mode=['area', 'nearest'],
))
transform_array.append(RandFlipd(
    keys=['image', 'label'],
    prob=0.5,
    spatial_axis=2,
))
# Balance background/foreground
transform_array.append(RandCropByPosNegLabeld(
    keys=["image", "label"],
    label_key="label",
    spatial_size=param.window_size,
    pos=5,
    neg=1,
    num_samples=5,
    image_key="image",
    image_threshold=0,
))

transfTest = Compose(transform_array)
# With num_samples crops per case the result is indexed output[case][sample];
# only the crops of the first case are displayed.
output = transfTest(train_files)
for sample_number, output_dict in enumerate(output[0], start=1):
    _show_case(output_dict, str(sample_number))




### Run code to be tested

In [None]:
# import torch
# from monai.networks.nets import UNet

# # Define the UNet architecture
# model_unet = UNet(
#     spatial_dims=3,
#     in_channels=1,
#     out_channels=2,
#     channels=[16, 32, 64, 128],
#     strides=[(1, 2, 2), (1, 2, 2), (1, 1, 1)],
#     num_res_units=2,
# )

# # Create an example input tensor
# input_tensor = torch.randn(1, 1, 3, 192, 192)

# # Pass the input tensor through the UNet model
# output_tensor = model_unet(input_tensor)

# # Print the size of the output tensor
# print(output_tensor.size())  # Output: torch.Size([1, 2, 3, 192, 192])
