# Resample a 3D Image

In [1]:
import os
import tqdm

import SimpleITK as sitk
import numpy as np
from matplotlib import pyplot as plt
from ipywidgets import interact, fixed

In [2]:
 def read_dicom_series(fn:str):
    "load a DICOM series as `sitk.Image`"
    SeriesReader = sitk.ImageSeriesReader()
    dicom_names = SeriesReader.GetGDCMSeriesFileNames(str(fn))
    SeriesReader.SetFileNames(dicom_names)
    im = SeriesReader.Execute()
    return sitk.Cast(im, sitk.sitkInt16)

In [3]:
def resize_image(im, new_size=(128, 128, 64)):
    "Resample a sitk.Image to `new_size` maintaining oriantation and direction"
    # get old image meta data
    old_size = im.GetSize()
    old_spacing = im.GetSpacing()
    old_orig = im.GetOrigin()
    old_dir = im.GetDirection()
    old_pixel_id = im.GetPixelIDValue()
    
    # calculate new spacing
    new_spacing = [spc * (old_sz/new_sz) for spc, old_sz, new_sz in zip(old_spacing, old_size, new_size)]
    
    # create reference plane to resample image
    ref_image = sitk.Image(new_size, old_pixel_id)
    ref_image.SetSpacing(new_spacing)
    ref_image.SetOrigin(old_orig)
    ref_image.SetDirection(old_dir)
    
    # resample image to `new_size`
    return sitk.Resample(im, ref_image)

In [4]:
def resample_to_axial(image, spacing=None):
    "resample an image to direction (1,0,0,0,1,0,0,0,1)"
    
    # define the Euler Transformation for Image Rotation
    euler3d = sitk.Euler3DTransform() # define transform for rotation of the image
    image_center = (np.array(image.GetSize())/2.0)  # set rotation center to image center
    image_center_as_sitk_point = image.TransformContinuousIndexToPhysicalPoint(image_center)
    euler3d.SetCenter(image_center_as_sitk_point)
    
    # get index of volume edges
    w,h,d = image.GetSize()
    extreme_points = [(0, 0, 0), (w, 0, 0), (0, h, 0), (0, 0, d), 
                      (w, h, 0), (w, 0, d), (0, h, d), (w, h, d)]
    # transform edges to physical points in the global coordinate system
    extreme_points = [image.TransformIndexToPhysicalPoint(pnt) for pnt in extreme_points]
    inv_euler3d = euler3d.GetInverse()
    extreme_points_transformed = [inv_euler3d.TransformPoint(pnt) for pnt in extreme_points]
    
    # get new min and max coordinates of image edges
    min_x = min(extreme_points_transformed)[0]
    min_y = min(extreme_points_transformed, key=lambda p: p[1])[1]
    min_z = min(extreme_points_transformed, key=lambda p: p[2])[2]
        
    max_x = max(extreme_points_transformed)[0]
    max_y = max(extreme_points_transformed, key=lambda p: p[1])[1]
    max_z = max(extreme_points_transformed, key=lambda p: p[2])[2]
        
    # define new direction
    # take spacing and size from original image
    # calculate new origin from extree points
    output_spacing = spacing if spacing else image.GetSpacing()
    output_direction = tuple(np.identity(3).flatten())
    output_origin = [min_x, min_y, min_z]
    output_size = [int((max_x-min_x)/output_spacing[0]), 
                   int((max_y-min_y)/output_spacing[1]), 
                   int((max_z-min_z)/output_spacing[2])]
        
    resampled_image = sitk.Resample(image, output_size, euler3d, sitk.sitkLinear, output_origin, output_spacing, output_direction)
    return resampled_image 

In [5]:
def show_image(image, slice_id):
    img_slice = sitk.GetArrayViewFromImage(image)[slice_id, :, :]
    plt.imshow(img_slice, cmap='bone')

In [6]:
def read_and_resample(fn, plot=False):
    im = read_dicom_series(fn)
    im_resampled = resample_to_axial(im)
    if plot:
        interact(show_image, image=fixed(im_resampled), slice_id = (0, im_resampled.GetDepth()-1));
    else: return im_resampled

In [7]:
parent = 'images_train'
sub_dirs = next(os.walk(parent))[1]

StopIteration: 

In [None]:
for sub_dir in tqdm.tqdm(sub_dirs): 
    for sequence in ['FLAIR', 'T1w', 'T2w', 'T1wCE']:
        im = read_and_resample(f'{parent}/{sub_dir}/{sequence}')
        new_dir = f'train/{sub_dir}/'
        os.makedirs(new_dir, exist_ok=True)
        sitk.WriteImage(im, f'{new_dir}/{sequence}.nii.gz')


In [None]:
im = sitk.ReadImage('test/00335/FLAIR.nii.gz')
sitk.ImageViewer_SetGlobalDefaultApplication('/home/bressekk/Fiji.app/ImageJ-linux64')
sitk.Show(im)