### Notebook for testing MONAI transforms

### Libraries and auxiliary functions

In [11]:
import sys
import os

current_directory = os.getcwd()
path = os.path.dirname(current_directory)
sys.path.append(path)
from Utils import *

%matplotlib widget
from ipywidgets import interact, interactive, widgets
from matplotlib.patches import Rectangle, Circle, Arrow

import glob
import os
import tempfile
from typing import Optional, Any, Mapping, Hashable

import matplotlib.pyplot as plt
import nibabel as nib
import numpy as np

import monai
from monai.config import print_config
from monai.utils import first
from monai.config import KeysCollection
from monai.data import Dataset, ArrayDataset, create_test_image_3d, DataLoader
from monai.transforms import (
    AdjustContrastd,
    Transform,
    Compose,
    LoadImage,
    Orientation,
    ConcatItemsd,
    LoadImaged,
    Orientationd,
    EnsureChannelFirstd,
    EnsureChannelFirst,
    ToTensord,
    Spacingd,
    ScaleIntensityd,
    CropForegroundd,
    RandCropByLabelClassesd,
    RandCropByPosNegLabeld,
    RandFlipd,
    RandZoomd,
)

In [12]:
class Param:
    """Settings bundle for this notebook (values normally taken from config.ini).

    Every constructor argument is stored under the same attribute name, except
    ``orientation``, which is stored as ``axcodes`` -- the keyword that
    ``Orientationd`` is called with later in the notebook.
    """

    def __init__(self, data_dir, pixel_dim, window_size, orientation, in_channels, out_channels, input_type, num_samples):
        # Data location and geometry
        self.data_dir, self.pixel_dim, self.window_size = data_dir, pixel_dim, window_size
        self.axcodes = orientation
        # Channel layout and input modality
        self.in_channels, self.out_channels = in_channels, out_channels
        self.input_type = input_type
        # Number of random crops drawn per volume
        self.num_samples = num_samples


def _zip_paired_files(**file_lists):
    """Zip equally long file lists; raise ValueError (naming each count) if lengths differ."""
    counts = {name: len(files) for name, files in file_lists.items()}
    if len(set(counts.values())) > 1:
        raise ValueError('Cannot pair files, counts differ: ' + str(counts))
    return zip(*file_lists.values())


def generateLabeledFileList(param, prefix):
    """Build the list of MONAI data dictionaries for a labeled dataset.

    Images are globbed from ``<data_dir>/<prefix>_images`` by suffix
    (``*_M.nii.gz``, ``*_P.nii.gz``, ``*_R.nii.gz``, ``*_I.nii.gz``) and labels
    from ``<data_dir>/<prefix>_labels/*_both_label.nii.gz``. Each list is sorted
    by path and the lists are paired positionally.

    Args:
        param: settings object; reads ``data_dir``, ``in_channels`` and ``input_type``.
        prefix: dataset prefix, e.g. ``'test'`` -> ``test_images`` / ``test_labels``.

    Returns:
        list[dict]: if ``in_channels == 2``, entries have keys ``"image_1"``,
        ``"image_2"`` and ``"label"``. They hold real/imaginary images when
        ``input_type`` is 'R' or 'I', and magnitude/phase images otherwise.
        For any other ``in_channels``, entries have keys ``"image"`` and
        ``"label"``. They use the R, I or P images, or the M images for any
        other ``input_type``.

    Raises:
        ValueError: if the lists being paired have different lengths. Pairing
            is positional, so a single missing file would otherwise silently
            attach every later image to the wrong label (``zip`` truncates).
    """
    print('Reading labeled images from: ' + param.data_dir)
    image_dir = os.path.join(param.data_dir, prefix + "_images")
    label_dir = os.path.join(param.data_dir, prefix + "_labels")
    images = {
        kind: sorted(glob.glob(os.path.join(image_dir, "*_" + kind + ".nii.gz")))
        for kind in ("M", "P", "R", "I")
    }
    labels = sorted(glob.glob(os.path.join(label_dir, "*_both_label.nii.gz")))

    # Use two types of images combined
    if param.in_channels == 2:
        if param.input_type in ('R', 'I'):
            print('Use real/imaginary images')
            first, second = images["R"], images["I"]
        else:
            print('Use magnitude/phase images')
            first, second = images["M"], images["P"]
        return [
            {"image_1": image_1, "image_2": image_2, "label": label}
            for image_1, image_2, label in _zip_paired_files(image_1=first, image_2=second, label=labels)
        ]

    # Use only one type of image; any input_type other than R/I/P means magnitude.
    kind, message = {
        'R': ('R', 'Use real images'),
        'I': ('I', 'Use imaginary images'),
        'P': ('P', 'Use phase images'),
    }.get(param.input_type, ('M', 'Use magnitude images'))
    print(message)
    return [
        {"image": image, "label": label}
        for image, label in _zip_paired_files(image=images[kind], label=labels)
    ]

In [13]:
# Build param (info from config.ini)
data_dir = os.getcwd()
pixel_dim = (3.6, 1.171875, 1.171875)  # target spacing, passed to Spacingd as pixdim
window_size = (3, 160, 160)            # crop size, passed to RandCropByLabelClassesd as spatial_size
orientation = 'PIL'                    # axcodes for Orientationd
in_channels = 2                        # 2 -> pair of images concatenated into "image"
out_channels = 2
input_type = 'M'                       # 'M', 'P', 'R' or 'I' (see generateLabeledFileList)
num_samples = 2                        # random crops drawn per volume
param = Param(data_dir, pixel_dim, window_size, orientation, in_channels, out_channels, input_type, num_samples)

# Create dictionary
print('Create dictionary')
train_files = generateLabeledFileList(param, 'test')
# Fail fast: an empty list would otherwise only surface later as an IndexError
# when the transformed outputs are indexed.
if not train_files:
    raise FileNotFoundError(
        'No labeled files found under ' + os.path.join(data_dir, 'test_images')
        + ' and ' + os.path.join(data_dir, 'test_labels')
    )



Create dictionary
Reading labeled images from: /Users/pl771/Devel/MRINeedleSegmentation/TestingNotebook
Use magnitude/phase images


In [31]:
# Define preprocessing transforms, display the preprocessed volumes, then append
# random augmentation + cropping and display the resulting crops.
# NOTE(review): the saved output of this cell ends with
# "RuntimeError: applying transform <...RandZoomd...>"; only the wrapper message
# was kept, so the root cause is not visible here -- rerun and read the full
# traceback. The fixed seed below makes such failures reproducible.


def show_sample_volumes(sample, tag):
    """Print image/label shapes of one transformed sample and display its volumes.

    Displays image channel 0 ('Ch1 <tag>'), image channel 1 when the image has
    exactly two channels ('Ch2 <tag>'), and label channel 0 ('Label <tag>').
    ``show_array`` is presumably provided by ``from Utils import *``.
    """
    image_array = np.array(sample['image'])
    label_array = np.array(sample['label'])
    print('Image shape: ' + str(image_array.shape))
    print('Label shape: ' + str(label_array.shape))
    show_array(image_array[0, :, :, :], title='Ch1 ' + tag)
    if image_array.shape[0] == 2:
        show_array(image_array[1, :, :, :], title='Ch2 ' + tag)
    show_array(label_array[0, :, :, :], title='Label ' + tag)


# Load images
if param.in_channels == 2:
    # Two channels input: load both images, add channel dims, stack them into "image"
    transform_array = [
        LoadImaged(keys=["image_1", "image_2", "label"]),
        EnsureChannelFirstd(keys=["image_1", "image_2", "label"]),  # replaces the deprecated AddChanneld
        ConcatItemsd(keys=["image_1", "image_2"], name="image"),
    ]
else:
    # One channel input
    transform_array = [
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys=["image", "label"], channel_dim='no_channel'),  # replaces the deprecated AddChanneld
    ]

# Intensity adjustment (gamma correction only for real/imaginary inputs)
if param.input_type in ('R', 'I'):
    transform_array.append(AdjustContrastd(keys=["image"], gamma=2.5))
transform_array.append(ScaleIntensityd(keys=["image", "label"], minv=0, maxv=1, channel_wise=True))

# Spatial adjustments
transform_array.append(Orientationd(keys=["image", "label"], axcodes=param.axcodes))
transform_array.append(Spacingd(keys=["image", "label"], pixdim=param.pixel_dim, mode=("bilinear", "nearest")))

# Plot original (preprocessing only). Compose is applied to the list of file
# dicts, so the result holds one transformed dict per file.
loadTest = Compose(transform_array)
output_original = loadTest(train_files)

for output_dict in output_original:
    show_sample_volumes(output_dict, 'Original')

# Data augmentation
transform_array.append(RandZoomd(
    keys=['image', 'label'],
    prob=0.5,
    min_zoom=1.1,
    max_zoom=1.5,
    mode=['area', 'nearest'],
))
transform_array.append(RandFlipd(
    keys=['image', 'label'],
    prob=0.5,
    spatial_axis=2,
))
# Balance background/foreground
# (alternative cropping strategy, kept for experimentation)
# transform_array.append(RandCropByPosNegLabeld(
#     keys=["image", "label"],
#     label_key="label",
#     spatial_size=param.window_size,
#     pos=1, #0.8
#     neg=1, #0.2
#     num_samples=2,
#     image_key="image",
#     image_threshold=0, #0.05
# ))
transform_array.append(RandCropByLabelClassesd(
    keys=["image", "label"],
    label_key="label",
    spatial_size=param.window_size,
    ratios=[1, 3],
    num_classes=2,
    num_samples=param.num_samples,
    image_key="image",
    image_threshold=0,
))

# Fixed seed so the random zoom/flip/crop draws are identical between runs.
AUGMENTATION_SEED = 0
transfTest = Compose(transform_array)
transfTest.set_random_state(seed=AUGMENTATION_SEED)
output = transfTest(train_files)

# The crop transform turns each file dict into a list of crops, so `output`
# holds one list of crops per file; display the crops of every file.
for file_crops in output:
    for i, output_dict in enumerate(file_crops):
        show_sample_volumes(output_dict, str(i))




Image shape: (2, 3, 192, 192)
Label shape: (1, 3, 192, 192)


interactive(children=(IntSlider(value=1, description='z', max=2), Output()), _dom_classes=('widget-interact',)…

interactive(children=(IntSlider(value=1, description='z', max=2), Output()), _dom_classes=('widget-interact',)…

interactive(children=(IntSlider(value=1, description='z', max=2), Output()), _dom_classes=('widget-interact',)…

RuntimeError: applying transform <monai.transforms.spatial.dictionary.RandZoomd object at 0x2a21022b0>

### Run code to be tested

In [15]:
# import torch
# from monai.networks.nets import UNet

# # Define the UNet architecture
# model_unet = UNet(
#     spatial_dims=3,
#     in_channels=1,
#     out_channels=2,
#     channels=[16, 32, 64, 128],
#     strides=[(1, 2, 2), (1, 2, 2), (1, 1, 1)],
#     num_res_units=2,
# )

# # Create an example input tensor
# input_tensor = torch.randn(1, 1, 3, 192, 192)

# # Pass the input tensor through the UNet model
# output_tensor = model_unet(input_tensor)

# # Print the size of the output tensor
# print(output_tensor.size())  # Output: torch.Size([1, 2, 3, 192, 192])
