In [None]:
import sys
import os

# Make the notebook's parent directory importable so that `Utils` (imported below) and the
# other local helper modules can be found -- assumes they live one level above this folder.
current_directory = os.getcwd()
path = os.path.dirname(current_directory)
# Guard the append so re-running this cell does not keep adding duplicate sys.path entries.
if path not in sys.path:
    sys.path.append(path)
from Utils import *

%matplotlib widget
from ipywidgets import interact, interactive, widgets
from matplotlib.patches import Rectangle, Circle, Arrow

# Set paths and files

In [None]:
import glob
import os

import tempfile
import nibabel as nib
import numpy as np
import matplotlib.pyplot as plt
from typing import Optional, Any, Mapping, Hashable

import monai
from monai.config import print_config
from monai.utils import first
from monai.config import KeysCollection
from monai.data import Dataset, ArrayDataset, create_test_image_3d, DataLoader
from monai.transforms import (
    Compose,
    Orientationd,
    EnsureChannelFirstd,
    LoadImaged,
    ConcatItemsd,
    ScaleIntensityd,
    Spacingd,
    AsDiscreted,
    SaveImaged,
    SaveImage
)
from monaiUtils import MetatensorToSitk, SitkToMetatensor
from sitkIO import LoadSitkImaged, PushSitkImaged, PushSitkImage

# Load images: *_M and *_P NIfTI files (presumably magnitude / phase, cf. show_mag_phase_images)
# are paired positionally after sorting.
print('Reading images from: ' + current_directory)
dataset_dir = os.path.join(current_directory, 'test_dataset')
images_1 = sorted(glob.glob(os.path.join(dataset_dir, '*_M.nii.gz')))
images_2 = sorted(glob.glob(os.path.join(dataset_dir, '*_P.nii.gz')))

# The two lists are zipped below, so they must line up one-to-one; zip() would otherwise
# silently truncate the longer list and pair unrelated *_M / *_P files.
assert images_1, 'No *_M.nii.gz files found in ' + dataset_dir
assert len(images_1) == len(images_2), (
    f'Found {len(images_1)} *_M files but {len(images_2)} *_P files in {dataset_dir}')
for image_name_1, image_name_2 in zip(images_1, images_2):
    prefix_1 = os.path.basename(image_name_1)[:-len('_M.nii.gz')]
    prefix_2 = os.path.basename(image_name_2)[:-len('_P.nii.gz')]
    assert prefix_1 == prefix_2, f'Unpaired files: {image_name_1} / {image_name_2}'

# Create dictionary for MONAI
# Index of the (image_1, image_2) pair to process, i.e. a position in train_files.
# Earlier notes: 0 = 2D COR images, 1 = 2D SAG images, 2 = 3D COR images
# (the content of other indices is not documented here).
select_image = 3
assert 0 <= select_image < len(images_1), (
    f'select_image={select_image} is out of range for {len(images_1)} image pairs')

train_files = [
    {'image_1': image_name_1, 'image_2': image_name_2}
    for image_name_1, image_name_2 in zip(images_1, images_2)
]

print(images_1[select_image])
# Load original image as SimpleITK object
sitk_input_1 = sitk.ReadImage(images_1[select_image], sitk.sitkFloat32)
sitk_input_2 = sitk.ReadImage(images_2[select_image], sitk.sitkFloat32)
show_mag_phase_images(sitk_input_1, sitk_input_2, 'Loaded sitk image from nii')
print('SimpleITK object sizes')
print(sitk_input_1.GetSize())


# Load NIfTI files with LoadImaged

In [None]:
# Load both images of the selected pair from NIfTI files with MONAI's LoadImaged.
# image_only=True makes the loader return only the image (no separate header dict).
loadNifti = Compose([
    LoadImaged(keys=['image_1', 'image_2'], image_only=True),
])
selected_pair = train_files[select_image]
metatensor_nifti = loadNifti(selected_pair)

print('Metatensors from nifti')
print(metatensor_nifti['image_1'].shape)
print(metatensor_nifti)

# Load SimpleITK images with LoadSitkImaged

In [None]:
# Feed the already-loaded SimpleITK images through the project's LoadSitkImaged transform
# (from sitkIO), using the same keys and image_only flag as the NIfTI loader.
loadSitk = Compose([
    LoadSitkImaged(keys=['image_1', 'image_2'], image_only=True),
])
sitk_dict = dict(image_1=sitk_input_1, image_2=sitk_input_2)
metatensor_itkreader = loadSitk(sitk_dict)

print('Metatensors from sitkReader')
print(metatensor_itkreader['image_1'].shape)


# Convert SimpleITK images with SitkToMetatensor

In [None]:
# Convert each SimpleITK image with the project's SitkToMetatensor helper. It returns two
# values (a tensor and what appears to be its metadata dict); the dicts are kept but are
# not used in the cells shown here.
metatensor_1, metatensor_1_dict = SitkToMetatensor(sitk_input_1)
metatensor_2, metatensor_2_dict = SitkToMetatensor(sitk_input_2)
metatensor_sitk = {'image_1': metatensor_1, 'image_2': metatensor_2}

print('Metatensors from sitk')
print(metatensor_1.shape)

In [None]:
# Round-trip check: convert the NIfTI-loaded images back to SimpleITK images.
# At this point the dict only holds 'image_1' / 'image_2' -- 'pred' is only created by the
# UNet cell -- so those are the keys to push (pushing 'pred' here fails on a fresh kernel).
# A shallow copy is passed so metatensor_nifti, which the following cells still transform,
# keeps its original entries whatever PushSitkImaged does to the dict it receives.
pushTransf = Compose([PushSitkImaged(keys=['image_1', 'image_2'], resample=False)])
sitk_nifti = pushTransf(dict(metatensor_nifti))

print(sitk_nifti)

In [None]:
# Define transforms.
# Per-sample geometry: index 1 is the sagittal 2D set and index 2 the 3D volume (per the
# select_image notes); every other index uses the coronal 2D settings.
orientation = 'LIP' if select_image == 1 else 'PIL'  # only consumed by the disabled Orientationd below
pixdim = (6, 1.171875, 1.171875) if select_image == 2 else (3.6, 1.171875, 1.171875)

# Ensure channel first on both inputs.
channelFirst = Compose([EnsureChannelFirstd(keys=['image_1', 'image_2'])])
# Concatenate the two inputs into a single multi-channel 'image' entry.
concatImages = Compose([ConcatItemsd(keys=['image_1', 'image_2'], name='image')])
# Per-channel intensity rescaling to [0, 1], then resampling to `pixdim`.
# Orientationd is commented out, so `orientation` is currently unused.
preTransf = Compose([
    ScaleIntensityd(keys=['image'], minv=0, maxv=1, channel_wise=True),
    # Orientationd(keys=['image'], axcodes=orientation),
    Spacingd(keys=['image'], pixdim=pixdim, mode="bilinear"),
])

In [None]:
# Transforms with NIfTI-loaded images: apply channel-first, concatenation and the
# pre-transforms in turn, printing the resulting shape after each stage.
metatensor_nifti = channelFirst(metatensor_nifti)
print('Metatensors from nifti - Channel first')
print(metatensor_nifti['image_1'].shape)

metatensor_nifti = concatImages(metatensor_nifti)
print('Metatensors from nifti - Concatenate')
print(metatensor_nifti['image'].shape)

# The label no longer claims an orientation change: preTransf's Orientationd step is
# commented out, so this stage only rescales intensity and resamples spacing. The wording
# now matches the sitk / itkreader cells.
metatensor_nifti = preTransf(metatensor_nifti)
print('Metatensors from nifti - Pre-transforms')
print(metatensor_nifti['image'].shape)

In [None]:
# Transforms with Sitk (MetaTensors built by SitkToMetatensor): run each stage in turn and
# print the shape of the entry it produces.
for stage_transform, heading, shape_key in (
    (channelFirst, 'Metatensors from sitk - Channel first', 'image_1'),
    (concatImages, 'Metatensors from sitk - Concatenate', 'image'),
    (preTransf, 'Metatensors from sitk - Pre-transforms', 'image'),
):
    metatensor_sitk = stage_transform(metatensor_sitk)
    print(heading)
    print(metatensor_sitk[shape_key].shape)

In [None]:
# Transforms with itkReader (dict produced by LoadSitkImaged): run each stage in turn and
# print the shape of the entry it produces.
for stage_transform, heading, shape_key in (
    (channelFirst, 'Metatensors from itkreader - Channel first', 'image_1'),
    (concatImages, 'Metatensors from itkreader - Concatenate', 'image'),
    (preTransf, 'Metatensors from itkreader - Pre-transforms', 'image'),
):
    metatensor_itkreader = stage_transform(metatensor_itkreader)
    print(heading)
    print(metatensor_itkreader[shape_key].shape)


In [None]:
# Check sizes: print the final 'image' shape from each of the three loading routes
# (NIfTI, SitkToMetatensor, LoadSitkImaged) so they can be compared.
for loaded in (metatensor_nifti, metatensor_sitk, metatensor_itkreader):
    print(loaded['image'].shape)

# UNet inference

In [None]:
# Choose source: which preprocessed dict to run inference on (here the LoadSitkImaged route).
metatensor = metatensor_itkreader

import torch
from monai.networks.nets import UNet 
from monai.networks.layers import Norm
from monai.inferers import sliding_window_inference
from monai.data import decollate_batch
from monai.handlers.utils import from_engine

# Setup UNet model: select_image == 2 is the 3D volume; all other samples use the 2D model.
# NOTE(review): axis meaning of window_size (first entry presumably slices) -- confirm.
if select_image == 2:
    model_file = os.path.join(path, 'model_MP_multi_CEDice_3000_PIL_t3_noise_600epoch_Unet4_TEST.pth')  # 3D
    window_size = (3, 64, 64)
else:
    model_file = os.path.join(path, 'model_MP_multi_CEDice_5500_PIL_t3_t6_BG11_noise_600epoch_Unet4_2D.pth')  # 2D
    window_size = (1, 64, 64)

model_unet = UNet(
    spatial_dims=3,
    in_channels=2,   # the two concatenated inputs (image_1, image_2) in 'image'
    out_channels=3,  # class scores; argmax'ed in the post-transform cell
    channels=[16, 32, 64, 128], 
    strides=[(1, 2, 2), (1, 2, 2), (1, 1, 1)], 
    num_res_units=2,
    norm=Norm.BATCH,
)
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model = model_unet.to(device)

# Evaluate model
model.load_state_dict(torch.load(model_file, map_location=device))
model.eval()
with torch.no_grad():
    # Add a leading batch dimension of size 1.
    batch_input = metatensor['image'].unsqueeze(0)
    # Inputs must be on the same device as the model; forcing them onto the CPU raises a
    # device-mismatch error whenever the model runs on CUDA.
    val_inputs = batch_input.to(device)
    # Tile the input with windows of `window_size`, one window per forward pass (sw_batch_size=1).
    val_outputs = sliding_window_inference(val_inputs, window_size, 1, model)
    # Drop the batch dimension and move the prediction back to the CPU for post-processing
    # and conversion to SimpleITK.
    metatensor['pred'] = val_outputs[0].cpu()

print('UNet output')
print(metatensor['pred'].shape)

In [None]:
# Post Transform: argmax over the channel dimension turns the class scores into a label map,
# then PushSitkImaged (project sitkIO, resample=True) converts 'pred' to a SimpleITK image.
# `n_classes` is not passed: it only affected one-hot encoding (not requested here) and is
# a deprecated argument of MONAI's AsDiscrete(d).
postTransf = Compose([
    AsDiscreted(keys='pred', argmax=True),
    PushSitkImaged(keys=['pred'], resample=True),
])
metatensor = postTransf(metatensor)
sitk_output = metatensor['pred']

print('Post-transform output')
print(sitk_output.GetSize())
show_mag_phase_images(sitk_input_1, sitk_output, title='Needle segmentation', subtitles=['Input', 'Output'])


In [None]:
# Save output image: write both inputs and the segmentation into a per-sample debug folder
# (<path>/TestingNotebook/debug/<select_image zero-padded to 3 digits>/).
file_name = os.path.basename(images_1[select_image])
prefix_name = file_name.split('_M.nii.gz')[0]
output_dir = os.path.join(path, 'TestingNotebook', 'debug', str(select_image).zfill(3))
# sitk.WriteImage does not create missing parent directories, so make sure the folder exists.
os.makedirs(output_dir, exist_ok=True)
sitk.WriteImage(sitk_input_1, os.path.join(output_dir, prefix_name + '_M.nii.gz'))
sitk.WriteImage(sitk_input_2, os.path.join(output_dir, prefix_name + '_P.nii.gz'))
sitk.WriteImage(sitk_output, os.path.join(output_dir, prefix_name + '_seg.nii.gz'))