### Use this for testing MONAI transforms

### Libraries and auxiliar functions

In [7]:
import sys
import os

current_directory = os.getcwd()
path = os.path.dirname(current_directory)
sys.path.append(path)
from Utils import *

%matplotlib widget
from ipywidgets import interact, interactive, widgets
from matplotlib.patches import Rectangle, Circle, Arrow

import glob
import os

import tempfile
import nibabel as nib
import matplotlib.pyplot as plt
from typing import Optional, Any, Mapping, Hashable

import monai
from monai.config import print_config
from monai.utils import first
from monai.config import KeysCollection
from monai.data import Dataset, ArrayDataset, create_test_image_3d, DataLoader
from monai.transforms import (
    AdjustContrastd,
    Transform,
    Compose,
    LoadImage,
    Orientation,
    ConcatItemsd,
    LoadImaged,
    Orientationd,
    EnsureChannelFirstd,
    EnsureChannelFirst,
    ToTensord,
    Spacingd,
    ScaleIntensityd,
    CropForegroundd,
    RandCropByLabelClassesd,
    RandCropByPosNegLabeld,
    RandFlipd,
    RandZoomd,
)

In [8]:
class Param():
    def __init__(self, data_dir, pixel_dim, window_size, orientation, in_channels, out_channels, input_type, label_type):
        self.data_dir = data_dir
        self.pixel_dim = pixel_dim
        self.window_size = window_size
        self.axcodes = orientation
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.input_type = input_type
        self.label_type = label_type

def generateLabeledFileList(param, prefix):
    print('Reading labeled images from: ' + param.data_dir)
    images_m = sorted(glob.glob(os.path.join(param.data_dir, prefix + "_images", "*_M.nii.gz")))
    images_p = sorted(glob.glob(os.path.join(param.data_dir, prefix + "_images", "*_P.nii.gz")))
    images_r = sorted(glob.glob(os.path.join(param.data_dir, prefix + "_images", "*_R.nii.gz")))
    images_i = sorted(glob.glob(os.path.join(param.data_dir, prefix + "_images", "*_I.nii.gz")))
    labels = sorted(glob.glob(os.path.join(param.data_dir, prefix + "_labels", "*_"+param.label_type+"_label.nii.gz")))
    
    # Use two types of images combined
    if param.in_channels==2:
        # Use real and imaginary images
        if param.input_type=='R' or param.input_type=='I':
            print('Use real/imaginary images')
            data_dicts = [
                {"image_1": image_r_name, "image_2": image_i_name, "label":label_name}
                for image_r_name, image_i_name, label_name in zip(images_r, images_i, labels)
            ]
        # Use magnitude and phase images
        else:
            print('Use magnitude/phase images')
            data_dicts = [
                {"image_1": image_m_name, "image_2": image_p_name, "label":label_name}
                for image_m_name, image_p_name, label_name in zip(images_m, images_p, labels)
            ]
    # Use only one type of image        
    else:
        # Use real images
        if param.input_type=='R':
            print('Use real images')
            data_dicts = [
                {"image": image_name, "label": label_name}
                for image_name, label_name in zip(images_r, labels)
            ]
        # Use imaginary images
        elif param.input_type=='I':
            print('Use imaginary images')
            data_dicts = [
                {"image": image_name, "label": label_name}
                for image_name, label_name in zip(images_i, labels)
            ]
        # Use phase images
        elif param.input_type=='P':
            print('Use phase images')
            data_dicts = [
                {"image": image_name, "label": label_name}
                for image_name, label_name in zip(images_p, labels)
            ]
        # Use magnitude images
        else:
            print('Use magnitude images')
            data_dicts = [
                {"image": image_name, "label": label_name}
                for image_name, label_name in zip(images_m, labels)
            ]
    return data_dicts    

In [19]:
# Build param (info from config.ini)
data_dir = os.getcwd()
pixel_dim = (3.6, 1.171875, 1.171875)
window_size = (3, 48, 48)
orientation = 'PIL'
in_channels = 2
out_channels = 3
input_type = 'MP'
label_type = 'multi'
param = Param(data_dir, pixel_dim, window_size, orientation, in_channels, out_channels, input_type, label_type)

# Create dictionary
print('Create dictionary')
train_files = generateLabeledFileList(param, 'test')
print(train_files[0])

Create dictionary
Reading labeled images from: /Users/pl771/Devel/MRINeedleSegmentation/TestingNotebook
Use magnitude/phase images
{'image_1': '/Users/pl771/Devel/MRINeedleSegmentation/TestingNotebook/test_images/SyntheticImage_018_M.nii.gz', 'image_2': '/Users/pl771/Devel/MRINeedleSegmentation/TestingNotebook/test_images/SyntheticImage_018_P.nii.gz', 'label': '/Users/pl771/Devel/MRINeedleSegmentation/TestingNotebook/test_labels/SyntheticImage_018_multi_label.nii.gz'}


In [140]:
# CREATE NEW CUSTOM ITK OBJECT LOADER

import itk
from monai.data import ITKReader, MetaTensor
import torch
from monai.config import DtypeLike
from monai.utils import convert_to_dst_type
from monai.utils import ImageMetaKey as Key

class LoadITKObject(Transform):
    def __init__(self,
            image_only: bool = False,
            dtype: DtypeLike or None = np.float32,
            ensure_channel_first: bool = False,
            simple_keys: bool = False,
            prune_meta_pattern: str or None = None,
            prune_meta_sep: str = ".",
            expanduser: bool = True,         
        ) -> None:
        self.itkReader = ITKReader()
        self.image_only = image_only
        self.ensure_channel_first = ensure_channel_first
        self.dtype = dtype
        self.simple_keys = simple_keys
        self.pattern = prune_meta_pattern
        self.sep = prune_meta_sep
        self.expanduser = expanduser

    def __call__(self, img:itk.itkImagePython.itkImageF3):
        if isinstance(img, itk.itkImagePython.itkImageF3) == 0:
            raise RuntimeError(f"{self.__class__.__name__} The input image is not an ITK object.\n")
        img_array, meta_data = self.itkReader.get_data(img)
        img_array = convert_to_dst_type(img_array, dst=img_array, dtype=self.dtype)[0]
        if not isinstance(meta_data, dict):
            raise ValueError(f"`meta_data` must be a dict, got type {type(meta_data)}.")
        # Here I changed from original LoadImage to use tensor instead of numpy array (img_array) 
        # so the result is similar to loading the nifti file with LoadImage
        img = MetaTensor.ensure_torch_and_prune_meta(
            torch.from_numpy(img_array), meta_data, self.simple_keys, pattern=self.pattern, sep=self.sep
        )
        if self.ensure_channel_first:
            img = EnsureChannelFirst()(img)
        if self.image_only:
            return img
        return img, img.meta if isinstance(img, MetaTensor) else meta_data

In [123]:
# LOAD WITH STANDARD IMAGE LOADER TO COMPARE
data_list = [train_files[0]['image_1'], train_files[0]['image_2']]

load_file = Compose([LoadImage(ensure_channel_first=False)])
output_original = load_file(data_list)
metatensor_1 = output_original[0][0]
metatensor_2 = output_original[1][0]
print(metatensor_1.data)


metatensor([[[0.3549, 0.3380, 0.3352],
         [0.3519, 0.3691, 0.3125],
         [0.3634, 0.3661, 0.3520],
         ...,
         [0.2452, 0.2340, 0.1959],
         [0.2745, 0.2437, 0.2493],
         [0.3192, 0.2761, 0.2805]],

        [[0.3495, 0.3437, 0.3300],
         [0.3317, 0.3535, 0.3084],
         [0.3516, 0.3400, 0.3551],
         ...,
         [0.2688, 0.2568, 0.1992],
         [0.2424, 0.2224, 0.1931],
         [0.2454, 0.2236, 0.2136]],

        [[0.3308, 0.3461, 0.3214],
         [0.3149, 0.3304, 0.3110],
         [0.3439, 0.3199, 0.3542],
         ...,
         [0.3304, 0.3056, 0.2455],
         [0.2756, 0.2674, 0.2069],
         [0.2273, 0.2261, 0.2214]],

        ...,

        [[0.0157, 0.0208, 0.0273],
         [0.0123, 0.0225, 0.0241],
         [0.0125, 0.0160, 0.0182],
         ...,
         [0.0106, 0.0113, 0.0094],
         [0.0064, 0.0089, 0.0059],
         [0.0040, 0.0076, 0.0041]],

        [[0.0150, 0.0149, 0.0265],
         [0.0136, 0.0189, 0.0179],
        

monai.transforms.io.array LoadImage.__init__:image_only: Current default value of argument `image_only=False` has been deprecated since version 1.1. It will be changed to `image_only=True` in version 1.3.


In [141]:
# TEST NEW CUSTOM LOADER

itk_image_1 = itk.imread(train_files[0]['image_1']).astype(itk.F)
itk_image_2 = itk.imread(train_files[0]['image_2']).astype(itk.F)
data_list = [itk_image_1, itk_image_2]

load_obj = Compose([LoadITKObject(ensure_channel_first=False)])
output_original = load_obj(data_list)
metatensor_1 = output_original[0][0]
metatensor_2 = output_original[1][0]
print(metatensor_1.data)


metatensor([[[0.3549, 0.3380, 0.3352],
         [0.3519, 0.3691, 0.3125],
         [0.3634, 0.3661, 0.3520],
         ...,
         [0.2452, 0.2340, 0.1959],
         [0.2745, 0.2437, 0.2493],
         [0.3192, 0.2761, 0.2805]],

        [[0.3495, 0.3437, 0.3300],
         [0.3317, 0.3535, 0.3084],
         [0.3516, 0.3400, 0.3551],
         ...,
         [0.2688, 0.2568, 0.1992],
         [0.2424, 0.2224, 0.1931],
         [0.2454, 0.2236, 0.2136]],

        [[0.3308, 0.3461, 0.3214],
         [0.3149, 0.3304, 0.3110],
         [0.3439, 0.3199, 0.3542],
         ...,
         [0.3304, 0.3056, 0.2455],
         [0.2756, 0.2674, 0.2069],
         [0.2273, 0.2261, 0.2214]],

        ...,

        [[0.0157, 0.0208, 0.0273],
         [0.0123, 0.0225, 0.0241],
         [0.0125, 0.0160, 0.0182],
         ...,
         [0.0106, 0.0113, 0.0094],
         [0.0064, 0.0089, 0.0059],
         [0.0040, 0.0076, 0.0041]],

        [[0.0150, 0.0149, 0.0265],
         [0.0136, 0.0189, 0.0179],
        

In [142]:
inference_transform = Compose([LoadITKObject(), EnsureChannelFirstd(keys=["image_1", "image_2"])])
output_original = inference_transform(data_list)


RuntimeError: applying transform <monai.transforms.utility.dictionary.EnsureChannelFirstd object at 0x2a741cd90>

In [11]:
# Define preprocessing transforms
# Load images
if param.in_channels==2:
    # Two channels input
    transform_array = [
        LoadImaged(keys=["image_1", "image_2", "label"]),
        EnsureChannelFirstd(keys=["image_1", "image_2", "label"]), # Mariana: AddChanneld(keys=["image", "label"]) deprecated, use EnsureChannelFirst instead
        ConcatItemsd(keys=["image_1", "image_2"], name="image")
    ]
else:
    # One channel input
    transform_array = [            
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys=["image", "label"], channel_dim='no_channel'), # Mariana: AddChanneld(keys=["image", "label"]) deprecated, use EnsureChannelFirst instead
    ]
    
# Intensity adjustment
if (param.input_type == 'R') or (param.input_type == 'I'):
    transform_array.append(AdjustContrastd(keys=["image"], gamma=2.5))
    transform_array.append(ScaleIntensityd(keys=["image", "label"], minv=0, maxv=1, channel_wise=True))

# Spatial adjustments
transform_array.append(Orientationd(keys=["image", "label"], axcodes=param.axcodes))
transform_array.append(Spacingd(keys=["image", "label"], pixdim=param.pixel_dim, mode=("bilinear", "nearest")))

# Plot original
loadTest = Compose(transform_array)
output_original = loadTest(train_files)

for i in range(len(output_original)):
    if len(output_original)==1:
        output_dict = output_original[0]
    else:
        output_dict = output_original[0][i]
    # output_dict = output[0]
    image = output_dict['image']
    label = output_dict['label']
    # image = output_dict[0]['image']
    # label = output_dict[0]['label']
    image_array = np.array(image)
    label_array = np.array(label)
    print('Image shape: '+ str(image_array.shape))
    print('Label shape: '+ str(label_array.shape))
    show_array(image_array[0,:,:,:], title='Ch1 Original')
    if image_array.shape[0]==2:
        show_array(image_array[1,:,:,:], title='Ch2 Original')
    show_array(label_array[0,:,:,:], title='Label Original')

# Data augmentation
transform_array.append(RandZoomd(
    keys=['image', 'label'],
    prob=0.1,
    min_zoom=1.0,
    max_zoom=1.3,
    mode=['area', 'nearest'],
))
transform_array.append(RandFlipd(
    keys=['image', 'label'],
    prob=0.5,
    spatial_axis=2,
))
# Balance background/foreground
transform_array.append(RandCropByPosNegLabeld(
    keys=["image", "label"],
    label_key="label",
    spatial_size=param.window_size,
    pos=5, 
    neg=1,
    num_samples=5,
    image_key="image",
    image_threshold=0, 
))

transfTest = Compose(transform_array)
output = transfTest(train_files)
N = len(output[0])
for i in range(N):
    output_dict = output[0][i]
    image = output_dict['image']
    label = output_dict['label']
    image_array = np.array(image)
    label_array = np.array(label)
    print('Image shape: '+ str(image_array.shape))
    print('Label shape: '+ str(label_array.shape))
    show_array(image_array[0,:,:,:], title='Ch1 '+ str(i+1))
    if image_array.shape[0]==2:
        show_array(image_array[1,:,:,:], title='Ch2 '+ str(i+1))
    show_array(label_array[0,:,:,:], title='Label '+ str(i+1))




Image shape: (2, 3, 192, 192)
Label shape: (1, 3, 192, 192)


interactive(children=(IntSlider(value=1, description='z', max=2), Output()), _dom_classes=('widget-interact',)…

interactive(children=(IntSlider(value=1, description='z', max=2), Output()), _dom_classes=('widget-interact',)…

interactive(children=(IntSlider(value=1, description='z', max=2), Output()), _dom_classes=('widget-interact',)…

Image shape: (2, 3, 48, 48)
Label shape: (1, 3, 48, 48)


interactive(children=(IntSlider(value=1, description='z', max=2), Output()), _dom_classes=('widget-interact',)…

interactive(children=(IntSlider(value=1, description='z', max=2), Output()), _dom_classes=('widget-interact',)…

interactive(children=(IntSlider(value=1, description='z', max=2), Output()), _dom_classes=('widget-interact',)…

Image shape: (2, 3, 48, 48)
Label shape: (1, 3, 48, 48)


interactive(children=(IntSlider(value=1, description='z', max=2), Output()), _dom_classes=('widget-interact',)…

interactive(children=(IntSlider(value=1, description='z', max=2), Output()), _dom_classes=('widget-interact',)…

interactive(children=(IntSlider(value=1, description='z', max=2), Output()), _dom_classes=('widget-interact',)…

Image shape: (2, 3, 48, 48)
Label shape: (1, 3, 48, 48)


interactive(children=(IntSlider(value=1, description='z', max=2), Output()), _dom_classes=('widget-interact',)…

interactive(children=(IntSlider(value=1, description='z', max=2), Output()), _dom_classes=('widget-interact',)…

interactive(children=(IntSlider(value=1, description='z', max=2), Output()), _dom_classes=('widget-interact',)…

Image shape: (2, 3, 48, 48)
Label shape: (1, 3, 48, 48)


interactive(children=(IntSlider(value=1, description='z', max=2), Output()), _dom_classes=('widget-interact',)…

interactive(children=(IntSlider(value=1, description='z', max=2), Output()), _dom_classes=('widget-interact',)…

interactive(children=(IntSlider(value=1, description='z', max=2), Output()), _dom_classes=('widget-interact',)…

Image shape: (2, 3, 48, 48)
Label shape: (1, 3, 48, 48)


interactive(children=(IntSlider(value=1, description='z', max=2), Output()), _dom_classes=('widget-interact',)…

interactive(children=(IntSlider(value=1, description='z', max=2), Output()), _dom_classes=('widget-interact',)…

interactive(children=(IntSlider(value=1, description='z', max=2), Output()), _dom_classes=('widget-interact',)…

### Run code to be tested

In [10]:
# import torch
# from monai.networks.nets import UNet

# # Define the UNet architecture
# model_unet = UNet(
#     spatial_dims=3,
#     in_channels=1,
#     out_channels=2,
#     channels=[16, 32, 64, 128],
#     strides=[(1, 2, 2), (1, 2, 2), (1, 1, 1)],
#     num_res_units=2,
# )

# # Create an example input tensor
# input_tensor = torch.randn(1, 1, 3, 192, 192)

# # Pass the input tensor through the UNet model
# output_tensor = model_unet(input_tensor)

# # Print the size of the output tensor
# print(output_tensor.size())  # Output: torch.Size([1, 2, 3, 192, 192])
