### Use this for testing MONAI transforms

### Libraries and auxiliar functions

In [1]:
import sys
import os
import random

current_directory = os.getcwd()
path = os.path.dirname(current_directory)
sys.path.append(path)
from Utils import *

%matplotlib widget
from ipywidgets import interact, interactive, widgets
from matplotlib.patches import Rectangle, Circle, Arrow

import glob
import os

import tempfile
import nibabel as nib
import matplotlib.pyplot as plt
from typing import Optional, Any, Mapping, Hashable

import monai
from monai.config import print_config
from monai.utils import first, ensure_tuple
from monai.config import KeysCollection
from monai.data import Dataset, ArrayDataset, create_test_image_3d, DataLoader
from monai.transforms import (
    AdjustContrastd,
    Transform,
    Compose,
    LoadImage,
    Orientation,
    ConcatItemsd,
    LoadImaged,
    Orientationd,
    EnsureChannelFirstd,
    EnsureChannelFirst,
    ToTensord,
    Spacingd,
    ScaleIntensityd,
    CropForegroundd,
    RandGaussianNoised,
    RandRicianNoised,
    RandCropByLabelClassesd,
    RandCropByPosNegLabeld,
    RandFlipd,
    RandZoomd,
    RandKSpaceSpikeNoised,
    KSpaceSpikeNoised
)

from sitkIO import PushSitkImage

In [2]:
class Param():
    def __init__(self, data_dir, pixel_dim, window_size, orientation, in_channels, out_channels, input_type, label_type, rand_noise, spike_noise, rand_flip, rand_zoom):
        self.data_dir = data_dir
        self.pixel_dim = pixel_dim
        self.window_size = window_size
        self.axcodes = orientation
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.input_type = input_type
        self.label_type = label_type
        self.training_rand_noise = rand_noise
        self.training_spike_noise = spike_noise
        self.training_rand_flip = rand_flip
        self.training_rand_zoom = rand_zoom

def generateLabeledFileList(param, prefix):
    print('Reading labeled images from: ' + param.data_dir)
    images_m = sorted(glob.glob(os.path.join(param.data_dir, prefix + "_images", "*_M.nii.gz")))
    images_p = sorted(glob.glob(os.path.join(param.data_dir, prefix + "_images", "*_P.nii.gz")))
    images_r = sorted(glob.glob(os.path.join(param.data_dir, prefix + "_images", "*_R.nii.gz")))
    images_i = sorted(glob.glob(os.path.join(param.data_dir, prefix + "_images", "*_I.nii.gz")))
    labels = sorted(glob.glob(os.path.join(param.data_dir, prefix + "_labels", "*_"+param.label_type+"_label.nii.gz")))
    # Use two types of images combined
    if param.in_channels==2:
        # Use real and imaginary images
        if param.input_type=='R' or param.input_type=='I':
            print('Use real/imaginary images')
            data_dicts = [
                {"image_1": image_r_name, "image_2": image_i_name, "label":label_name}
                for image_r_name, image_i_name, label_name in zip(images_r, images_i, labels)
            ]
        # Use magnitude and phase images
        else:
            print('Use magnitude/phase images')
            data_dicts = [
                {"image_1": image_m_name, "image_2": image_p_name, "label":label_name}
                for image_m_name, image_p_name, label_name in zip(images_m, images_p, labels)
            ]
    # Use only one type of image        
    else:
        # Use real images
        if param.input_type=='R':
            print('Use real images')
            data_dicts = [
                {"image": image_name, "label": label_name}
                for image_name, label_name in zip(images_r, labels)
            ]
        # Use imaginary images
        elif param.input_type=='I':
            print('Use imaginary images')
            data_dicts = [
                {"image": image_name, "label": label_name}
                for image_name, label_name in zip(images_i, labels)
            ]
        # Use phase images
        elif param.input_type=='P':
            print('Use phase images')
            data_dicts = [
                {"image": image_name, "label": label_name}
                for image_name, label_name in zip(images_p, labels)
            ]
        # Use magnitude images
        else:
            print('Use magnitude images')
            data_dicts = [
                {"image": image_name, "label": label_name}
                for image_name, label_name in zip(images_m, labels)
            ]
    return data_dicts    

In [3]:
# Build param (info from config.ini)
data_dir = os.getcwd()+''
pixel_dim = (3.6, 1.171875, 1.171875)
window_size = (3, 64, 64)
orientation = 'PIL'
in_channels = 2
out_channels = 3
input_type = 'MP'
label_type = 'multi'
rand_noise = 0.8
spike_noise = 0.0
rand_flip = 0.5
rand_zoom = 0.5
param = Param(data_dir, pixel_dim, window_size, orientation, in_channels, out_channels, input_type, label_type, rand_noise, spike_noise, rand_flip, rand_zoom)

# Create dictionary
print('Create dictionary')
train_files = generateLabeledFileList(param, 'test')
print(train_files)
print(train_files[0]['image_1'])
print(train_files[0]['image_2'])


Create dictionary
Reading labeled images from: /Users/pl771/Devel/MRINeedleSegmentation-LIVER/TestingNotebook
Use magnitude/phase images
[{'image_1': '/Users/pl771/Devel/MRINeedleSegmentation-LIVER/TestingNotebook/test_images/SyntheticImage_001_M.nii.gz', 'image_2': '/Users/pl771/Devel/MRINeedleSegmentation-LIVER/TestingNotebook/test_images/SyntheticImage_001_P.nii.gz', 'label': '/Users/pl771/Devel/MRINeedleSegmentation-LIVER/TestingNotebook/test_labels/SyntheticImage_001_multi_label.nii.gz'}, {'image_1': '/Users/pl771/Devel/MRINeedleSegmentation-LIVER/TestingNotebook/test_images/SyntheticImage_002_M.nii.gz', 'image_2': '/Users/pl771/Devel/MRINeedleSegmentation-LIVER/TestingNotebook/test_images/SyntheticImage_002_P.nii.gz', 'label': '/Users/pl771/Devel/MRINeedleSegmentation-LIVER/TestingNotebook/test_labels/SyntheticImage_002_multi_label.nii.gz'}]
/Users/pl771/Devel/MRINeedleSegmentation-LIVER/TestingNotebook/test_images/SyntheticImage_001_M.nii.gz
/Users/pl771/Devel/MRINeedleSegmentat

# Load originals

In [4]:
# Plot original
print('ORIGINAL')

if param.in_channels==2:
    # Two channels input
    load_array = [
        LoadImaged(keys=["image_1", "image_2", "label"], image_only=False),
        EnsureChannelFirstd(keys=["image_1", "image_2", "label"]), 
        ConcatItemsd(keys=["image_1", "image_2"], name="image"),
        Orientationd(keys=["image", "label"], axcodes=param.axcodes)
    ]
else:
    # One channel input
    load_array = [            
        LoadImaged(keys=["image", "label"], image_only=False),
        EnsureChannelFirstd(keys=["image", "label"], channel_dim='no_channel'),
        Orientationd(keys=["image", "label"], axcodes=param.axcodes)
    ]

sitkTransform = PushSitkImage(resample=False, output_dtype=np.float32, print_log=False)
loadTest = Compose(load_array)
original = loadTest(train_files)
N = len(original)
for i in range(N):
    original_dict = original[i]
    image_m = original_dict['image'][0] #ch1
    label = original_dict['label'][0]   #ch1
    sitk_image_m = sitkTransform(image_m)
    sitk_label = sitkTransform(label)
    print('Input shape: ', sitk_image_m.GetSize())
    print('Label shape: ', sitk_label.GetSize())
    
    if param.in_channels==2:
        image_p = original_dict['image'][1] #ch2
        sitk_image_p = sitkTransform(image_p)
        show_mag_phase_images(sitk_image_m, sitk_image_p, title = 'Input images')
    else:
        show_image(sitk_image_m, title='Input image')
    show_image(sitk_label, title = 'Label')


ORIGINAL
Input shape:  (3, 192, 192)
Label shape:  (3, 192, 192)


interactive(children=(IntSlider(value=95, description='z', max=191), Output()), _dom_classes=('widget-interact…

interactive(children=(IntSlider(value=95, description='z', max=191), Output()), _dom_classes=('widget-interact…

Input shape:  (3, 192, 192)
Label shape:  (3, 192, 192)


interactive(children=(IntSlider(value=95, description='z', max=191), Output()), _dom_classes=('widget-interact…

interactive(children=(IntSlider(value=95, description='z', max=191), Output()), _dom_classes=('widget-interact…

In [5]:
# Define preprocessing transforms    
# Load images
if param.in_channels==2:
    # Two channels input
    transform_array = [
        LoadImaged(keys=["image_1", "image_2", "label"], image_only=False),             # Load Magnitude, Phase and labelmap
        EnsureChannelFirstd(keys=["image_1", "image_2", "label"]),                      # Ensure channel first
        ScaleIntensityd(keys=["image_1", "image_2"], minv=0, maxv=1, channel_wise=True) # Scale intensity to 0-1
    ]
    # Noise addition
    print(param.training_rand_noise)
    if param.training_rand_noise != 0:
        if random.random() < param.training_rand_noise: # Probability of adding noise
            print('ADD NOISE')
            transform_array.append(RandRicianNoised(keys=["image_1"], prob=1, mean=0, std=0.1))     # Add Rician noise to Magnitude -  mean=0, std=0.1
            transform_array.append(RandGaussianNoised(keys=["image_2"], prob=1, mean=0, std=0.08))  # Add small Gaussian noise to Phase - mean=0, std=0.08
    transform_array.append(ConcatItemsd(keys=["image_1", "image_2"], name="image"))     # Concatenate Magnitude and Phase to 2-channels       
else:
    # One channel input
    transform_array = [            
        LoadImaged(keys=["image", "label"], image_only=False),                          # Load Magnitude and labelmap
        EnsureChannelFirstd(keys=["image", "label"], channel_dim='no_channel'),         # Ensure channel first
        ScaleIntensityd(keys=["image"], minv=0, maxv=1, channel_wise=True)              # Scale intensity to 0-1
    ]
    # Noise addition
    if param.training_rand_noise != 0:
        transform_array.append(RandRicianNoised(keys=["image"], prob=param.training_rand_noise, mean=0, std=0.1))           # Add Rician noise to Magnitude 

# Intensity adjustment
if (param.input_type == 'R') or (param.input_type == 'I'):
    transform_array.append(AdjustContrastd(keys=["image"], gamma=2.5))                  # Increase contrast for real/imaginary
ScaleIntensityd(keys=["image"], minv=0, maxv=1, channel_wise=True) # Re-scale intensity after noise addition

# Spatial adjustments
transform_array.append(Orientationd(keys=["image", "label"], axcodes=param.axcodes))                            # Adjust image orientation
transform_array.append(Spacingd(keys=["image", "label"], pixdim=param.pixel_dim, mode=("bilinear", "nearest"))) # Adjust image spacing

# Spike noise addition
if param.training_spike_noise != 0:
    transform_array.append(RandKSpaceSpikeNoised(keys=['image'], prob=param.training_spike_noise, channel_wise=False, intensity_range=(0.95*8.6,1.10*8.6)))
    ScaleIntensityd(keys=["image"], minv=0, maxv=1, channel_wise=True) # Re-scale intensity after noise addition
        
# Intensity adjustment and noise addition
print('INTENSITY ADJUST')
sitkTransform = PushSitkImage(resample=False, output_dtype=np.float32, print_log=False)
intensityTest = Compose(transform_array)
output = intensityTest(train_files)
N = len(output)
N=1
for i in range(N):
    output_dict = output[i]
    image_m = output_dict['image'][0] #ch1
    label = output_dict['label'][0]   #ch1
    sitk_image_m = sitkTransform(image_m)
    sitk_label = sitkTransform(label)
    print('Input shape: ', sitk_image_m.GetSize())
    print('Label shape: ', sitk_label.GetSize())

    if param.in_channels==2:
        image_p = output_dict['image'][1] #ch2
        sitk_image_p = sitkTransform(image_p)
        show_mag_phase_images(sitk_image_m, sitk_image_p, title = 'Input images')
    else:
        show_image(sitk_image_m, title='Input image')
    show_image(sitk_label, title = 'Label')

0.8
ADD NOISE
INTENSITY ADJUST
Input shape:  (3, 192, 192)
Label shape:  (3, 192, 192)


interactive(children=(IntSlider(value=95, description='z', max=191), Output()), _dom_classes=('widget-interact…

interactive(children=(IntSlider(value=95, description='z', max=191), Output()), _dom_classes=('widget-interact…

In [6]:
transform_array.append(RandKSpaceSpikeNoised(
    keys=['image'],
    prob=1.0,
    channel_wise=False,
    intensity_range=(6.5,6.5),
))

sitkTransform = PushSitkImage(resample=False, output_dtype=np.float32, print_log=False)
transfTest = Compose(transform_array)
output = transfTest(train_files)
N = len(output)
N=1
for i in range(N):
    output_dict = output[i]
    image_m = output_dict['image'][0] #ch1
    label = output_dict['label'][0]   #ch1
    sitk_image_m = sitkTransform(image_m)
    sitk_label = sitkTransform(label)
    print('Input shape: ', sitk_image_m.GetSize())
    print('Label shape: ', sitk_label.GetSize())

    if param.in_channels==2:
        image_p = output_dict['image'][1] #ch2
        sitk_image_p = sitkTransform(image_p)
        show_mag_phase_images(sitk_image_m, sitk_image_p, title = 'Input images')
    else:
        show_image(sitk_image_m, title='Input image')
    show_image(sitk_label, title = 'Label')

Input shape:  (3, 192, 192)
Label shape:  (3, 192, 192)


interactive(children=(IntSlider(value=95, description='z', max=191), Output()), _dom_classes=('widget-interact…

interactive(children=(IntSlider(value=95, description='z', max=191), Output()), _dom_classes=('widget-interact…

# Data Augmentation

In [7]:
# # Data augmentation
transform_array.append(RandZoomd(
    keys=['image', 'label'],
    prob=param.training_rand_zoom,
    min_zoom=1.0,
    max_zoom=1.3,
    mode=['area', 'nearest'],
))

transform_array.append(RandFlipd(
    keys=['image', 'label'],
    prob=param.training_rand_flip,
    spatial_axis=2,
))

# Balance background/foreground
transform_array.append(RandCropByPosNegLabeld(
    keys=["image", "label"],
    label_key="label",
    spatial_size=param.window_size,
    pos=5, 
    neg=1,
    num_samples=3,
    image_key="image",
    image_threshold=0, 
))

print('== DATA AUGMENTATION ==')
sitkTransform = PushSitkImage(resample=False, output_dtype=np.float32, print_log=False)
transfTest = Compose(transform_array)
output = transfTest(train_files)
N = len(output[0])
N=1
for i in range(N):
    output_dict = output[0][i]
    image_m = output_dict['image'][0] #ch1
    label = output_dict['label'][0]   #ch1
    sitk_image_m = sitkTransform(image_m)
    sitk_label = sitkTransform(label)
    print('Input shape: ', sitk_image_m.GetSize())
    print('Label shape: ', sitk_label.GetSize())

    if param.in_channels==2:
        image_p = output_dict['image'][1] #ch2
        sitk_image_p = sitkTransform(image_p)
        show_mag_phase_images(sitk_image_m, sitk_image_p, title = 'Input images')
    else:
        show_image(sitk_image_m, title='Input image')
    show_image(sitk_label, title = 'Label')



== DATA AUGMENTATION ==
Input shape:  (3, 64, 64)
Label shape:  (3, 64, 64)


Num foregrounds 0, Num backgrounds 110592, unable to generate class balanced samples, setting `pos_ratio` to 0.


interactive(children=(IntSlider(value=31, description='z', max=63), Output()), _dom_classes=('widget-interact'…

interactive(children=(IntSlider(value=31, description='z', max=63), Output()), _dom_classes=('widget-interact'…

### Check dimensions

In [20]:
import torch
from monai.networks.nets import UNet
if param.axcodes == 'PIL':
    print('PIL')
    strides = [(1, 2, 2), (1, 2, 2), (1, 1, 1)]   # PIL
else:   
    print('LIP') 
    strides = [(2, 2, 1), (2, 2, 1), (1, 1, 1)]   # LIP

# Define the UNet architecture
model_unet = UNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=2,
    channels=[16, 32, 64, 128],
    strides=strides,
    num_res_units=2,
)

# Create an example input tensor
input_tensor = torch.randn(1, 1, 3, 192, 192)

# Pass the input tensor through the UNet model
output_tensor = model_unet(input_tensor)

# Print the size of the output tensor
print(output_tensor.size())  # Output: torch.Size([1, 2, 3, 192, 192])


PIL
torch.Size([1, 2, 3, 192, 192])
