In [1]:
import numpy as np
import SimpleITK as sitk
import os
import math

In [2]:
def get_list_of_files(base_dir):
    """
    Return the sorted paths of the training volumes found directly in base_dir.
    :param base_dir: folder to scan (not recursive)
    :return: sorted list of full paths of files whose names start with "mr_train_"
             and end with ".nii.gz" (files written back as "augmented_*" are skipped)
    """
    # os.listdir returns entries in arbitrary order; sorting makes the result
    # deterministic so that, for names like mr_train_1001_image / mr_train_1001_label,
    # each image sorts directly before its label.
    list_of_lists = sorted(
        os.path.join(base_dir, name)
        for name in os.listdir(base_dir)
        if name.startswith("mr_train_") and name.endswith(".nii.gz")
    )
    # NOTE: this counts files (image and label files alike), not distinct patients.
    print("Found %d patients" % len(list_of_lists))
    return list_of_lists


def load_img(dir_img):
    """
    Read an image file with SimpleITK and collect its geometry.
    :param dir_img: path of the image file to read
    :return: (sitk image, numpy array of its voxels, metadata dict with keys
             'spacing', 'direction', 'origin', 'original_shape')
    """
    image = sitk.ReadImage(dir_img)
    voxels = sitk.GetArrayFromImage(image)

    # SimpleITK lists spacing in the opposite axis order to the numpy array it
    # hands back, so flip it to line up with the array's axes.
    # Direction and origin are kept in SimpleITK order.
    return image, voxels, dict(
        spacing=np.array(image.GetSpacing())[::-1],
        direction=image.GetDirection(),
        origin=image.GetOrigin(),
        original_shape=voxels.shape,
    )


def get_center(img):
    """
    This function returns the physical center point of a sitk image
    (used here with 3d volumes, but any dimension works)
    :param img: The sitk image we are trying to find the center of
    :return: The physical center point of the image
    """
    # Voxel indices run from 0 to size-1, so the geometric center is the
    # continuous index (size-1)/2 on every axis. The previous integer index
    # ceil(size/2) landed 0.5 (even size) to 1 (odd size) voxel past the true
    # center, which shifted the rotation axis used by rotate_img.
    center_index = tuple((size - 1) / 2.0 for size in img.GetSize())
    return img.TransformContinuousIndexToPhysicalPoint(center_index)


def rotate_img(img_sitk, transform, is_label, theta_x, theta_y, theta_z):
    """
    Rotate a 3d sitk image about its physical center and resample it.
    :param img_sitk: the sitk image to rotate; it is also passed to Resample as
                     the reference image, i.e. the output grid
    :param transform: unused - the rotation is always built as an Euler3DTransform
                      below; kept only so existing callers keep working
    :param is_label: True -> nearest-neighbour interpolation (label maps),
                     False -> B-spline interpolation (intensity images)
    :param theta_x, theta_y, theta_z: rotation angles in degrees about x, y, z
    :return: (numpy array of the resampled image, resampled sitk image)
    """
    # The transform takes radians; the caller supplies degrees.
    rad_x, rad_y, rad_z = np.deg2rad(theta_x), np.deg2rad(theta_y), np.deg2rad(theta_z)

    # One construction sets center, angles and zero translation; re-setting the
    # center and rotation afterwards (as before) was redundant.
    rotation = sitk.Euler3DTransform(get_center(img_sitk), rad_x, rad_y, rad_z, (0, 0, 0))

    interpolator = sitk.sitkNearestNeighbor if is_label else sitk.sitkBSpline

    # Resample on the input's own grid; 0 is the default pixel value passed to
    # Resample (used for output voxels that fall outside the input).
    default_value = 0
    resampled = sitk.Resample(img_sitk, img_sitk, rotation, interpolator, default_value)
    return sitk.GetArrayFromImage(resampled), resampled


def save_as_nii(img_npy, img_name, metadata):
    """
    Write a numpy volume to disk through SimpleITK with the given geometry.
    :param img_npy: voxel array in the axis order returned by sitk.GetArrayFromImage
    :param img_name: output file path (e.g. '....nii.gz')
    :param metadata: dict with 'direction' and 'origin' in SimpleITK order and
                     'spacing' in numpy axis order, as built by load_img
    """
    sitk_image = sitk.GetImageFromArray(img_npy)
    sitk_image.SetDirection(metadata['direction'])
    sitk_image.SetOrigin(metadata['origin'])
    # load_img reversed the spacing to numpy order, so reverse it again to get
    # back to SimpleITK order. The former [1, 2, 0] permutation did not undo that
    # reversal: it swapped the first two SimpleITK axes.
    sitk_image.SetSpacing(tuple(float(s) for s in np.asarray(metadata['spacing'])[::-1]))
    sitk.WriteImage(sitk_image, img_name)

In [3]:
# Folder with the mr_train_*_image.nii.gz / mr_train_*_label.nii.gz volumes.
# Set the MMWHS_MR_DIR environment variable to use another folder without editing this cell.
base_dir = os.environ.get(
    'MMWHS_MR_DIR',
    '/usr/not-backed-up2/scsad/DL/MedicalDataAugmentationTool/bin/experiments/semantic_segmentation/mmwhs/TODO_mr')

# Augmentation settings
theta_x, theta_y, theta_z = 0, 30, 0   # rotation in degrees about x, y, z
SLICE_STEP = 10                        # keep every SLICE_STEP-th slice along numpy axis 0

IMAGE_SUFFIX = '_image.nii.gz'
LABEL_SUFFIX = '_label.nii.gz'
NII_EXT = '.nii.gz'


def _downsampled_metadata(metadata, step):
    """Copy of `metadata` whose spacing along numpy axis 0 is multiplied by `step`.

    Keeping slices 0, step, 2*step, ... leaves origin and direction unchanged
    but makes the slice spacing `step` times larger.
    """
    spacing = np.array(metadata['spacing'], dtype=float)
    spacing[0] *= step
    return dict(metadata, spacing=spacing)


list_of_files = get_list_of_files(base_dir)
available_files = set(list_of_files)

# rotate_img does not use its `transform` argument; this object only fills that slot.
affine_ro = sitk.AffineTransform(3)

for dir_img in sorted(f for f in list_of_files if f.endswith(IMAGE_SUFFIX)):
    # Pair image and label by file name rather than by list position,
    # since os.listdir gives no ordering guarantee.
    dir_label = dir_img[:-len(IMAGE_SUFFIX)] + LABEL_SUFFIX
    if dir_label not in available_files:
        raise FileNotFoundError('No label file %s for image %s' % (dir_label, dir_img))

    # Strip folder and extension; unlike split('/')[11], this works for any base_dir depth.
    name_image = os.path.basename(dir_img)[:-len(NII_EXT)]
    name_label = os.path.basename(dir_label)[:-len(NII_EXT)]

    print(name_image)
    print(name_label)

    # Loading image and label
    img_sitk, img_npy, metadata_img = load_img(dir_img)
    label_sitk, label_npy, metadata_label = load_img(dir_label)

    # Rotate image. Discussion: https://stackoverflow.com/questions/56171643/simpleitk-rotation-of-mri-image
    npy_rotated, rotated_image = rotate_img(img_sitk, affine_ro, False, theta_x, theta_y, theta_z)

    # Rotate label
    npy_rotated_label, rotated_label = rotate_img(label_sitk, affine_ro, True, theta_x, theta_y, theta_z)

    # Downsample by keeping slices 0, SLICE_STEP, 2*SLICE_STEP, ... along numpy axis 0.
    # The former floor(n / 10) count dropped the last such slice whenever n was not a
    # multiple of 10 and produced an empty array for volumes with fewer than 10 slices.
    # np.array makes a contiguous copy of the strided view.
    downsampled_img = np.array(npy_rotated[::SLICE_STEP], dtype=np.int16)
    downsampled_label = np.array(npy_rotated_label[::SLICE_STEP], dtype=np.int16)

    # Save with the coarser slice spacing so the volumes keep their physical extent
    # (with the original spacing they would be squashed SLICE_STEP times along that axis).
    save_as_nii(downsampled_img, base_dir + '/augmented_' + name_image + '.nii.gz',
                _downsampled_metadata(metadata_img, SLICE_STEP))
    save_as_nii(downsampled_label, base_dir + '/augmented_' + name_label + '.nii.gz',
                _downsampled_metadata(metadata_label, SLICE_STEP))

    # TODO: check how to also rotate the point!

Found 40 patients
mr_train_1001_image
mr_train_1001_label
mr_train_1002_image
mr_train_1002_label
mr_train_1003_image
mr_train_1003_label
mr_train_1004_image
mr_train_1004_label
mr_train_1005_image
mr_train_1005_label
mr_train_1006_image
mr_train_1006_label
mr_train_1007_image
mr_train_1007_label
mr_train_1008_image
mr_train_1008_label
mr_train_1009_image
mr_train_1009_label
mr_train_1010_image
mr_train_1010_label
mr_train_1011_image
mr_train_1011_label
mr_train_1012_image
mr_train_1012_label
mr_train_1013_image
mr_train_1013_label
mr_train_1014_image
mr_train_1014_label
mr_train_1015_image
mr_train_1015_label
mr_train_1016_image
mr_train_1016_label
mr_train_1017_image
mr_train_1017_label
mr_train_1018_image
mr_train_1018_label
mr_train_1019_image
mr_train_1019_label
mr_train_1020_image
mr_train_1020_label
