In [11]:
# This notebook augments a given dataset.

In [12]:
import config as conf
import importlib
importlib.reload(conf)

<module 'config' from '/home/msc_student/vox2vox/config.py'>

In [13]:
# Installs

#!pip install elasticdeform
#!pip install tensorflow_addons
#!pip install nibabel
#!pip install matplotlib
#!pip install sklearn

In [14]:
# Imports

# System imports
import os
import glob
import concurrent.futures

# Tensorflow
import tensorflow as tf
from tensorflow.keras.preprocessing.image import apply_affine_transform
from tensorflow.keras.utils import to_categorical

# Numerical calculations
import numpy as np
from scipy.ndimage.interpolation import affine_transform
import elasticdeform as ed

# Own imports
import scan_loader
from data_generator import DataGenerator

In [15]:
# GPU setup

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = conf.gpu

allow_multi_gpu = True
tf_version = 2

if tf_version == 2 and allow_multi_gpu:
    gpus = tf.config.experimental.list_physical_devices('GPU')
    print("Num GPUs Available: ", len(gpus))
    if gpus:
        for gpu in gpus:
          tf.config.experimental.set_memory_growth(gpu, True)

Num GPUs Available:  1


In [16]:
# Finds paths to all MRT-scans in the dataset folder.

t1_list    = sorted(glob.glob(conf.dataset_mask_train + '*t1.nii.gz'))
t2_list    = sorted(glob.glob(conf.dataset_mask_train + '*t2.nii.gz'))
t1ce_list  = sorted(glob.glob(conf.dataset_mask_train + '*t1ce.nii.gz'))
flair_list = sorted(glob.glob(conf.dataset_mask_train + '*flair.nii.gz'))
seg_list   = sorted(glob.glob(conf.dataset_mask_train + '*seg.nii.gz'))

data = []
for i in range(len(t1_list)):
    data.append((t1_list[i], t2_list[i], t1ce_list[i], flair_list[i], seg_list[i]))

In [17]:
# Defines data augmentation functions

# Flips a MRT-image on a random plane
def flip3D(X, y):
    choice = np.random.randint(3)
    print(f"Flip: {choice}")
    
    if choice == 0: # flip on x
        X_flip, y_flip = X[::-1, :, :, :], y[::-1, :, :]
    if choice == 1: # flip on y
        X_flip, y_flip = X[:, ::-1, :, :], y[:, ::-1, :]
    if choice == 2: # flip on z
        X_flip, y_flip = X[:, :, ::-1, :], y[:, :, ::-1]
        
    return X_flip, y_flip

# Rotates a MRT-image randomly by up to 31°
def rotation3D(X, y):
    # TODO: Doesn't this only turn into positive direction? 
    # Wouldn't np.random.randint(-31,31,size=3) make more sense?
    alpha, beta, gamma = np.random.randint(0, 31, size=3)/180*np.pi
    print(f"Rotation: {alpha}, {beta}, {gamma}")
    
    # Calcualtes rotation matrices
    Rx = np.array([[1, 0, 0],
                   [0, np.cos(alpha), -np.sin(alpha)],
                   [0, np.sin(alpha), np.cos(alpha)]])
    
    Ry = np.array([[np.cos(beta), 0, np.sin(beta)],
                   [0, 1, 0],
                   [-np.sin(beta), 0, np.cos(beta)]])
    
    Rz = np.array([[np.cos(gamma), -np.sin(gamma), 0],
                   [np.sin(gamma), np.cos(gamma), 0],
                   [0, 0, 1]])
    
    # Combines rotation matrices into one rotation matrix
    # TODO: Take center of image into account.
    R = np.dot(np.dot(Rx, Ry), Rz)
    #t = np.array([[120,120,77]]).T
    #Rt = np.hstack((R,t))
    #Rt = np.linalg.inv(Rt)
    
    # Performs rotation
    X_rot = np.empty_like(X)
    backgrounds = X.min(axis=(0,1,2))
    
    for channel in range(X.shape[-1]):
        X_rot[:,:,:,channel] = affine_transform(X[:,:,:,channel], R, order=0, cval=backgrounds[channel])
        
    y_rot = affine_transform(y, R, order=0, mode="nearest")
    
    return X_rot, y_rot

### NEW AUGMENTATION TECHNIQUES
def brightness(X, y):
    """
    Changing the brighness of a image using power-law gamma transformation.
    Gain and gamma are chosen randomly for each image channel.
    
    Gain chosen between [0.9 - 1.1]
    Gamma chosen between [0.9 - 1.1]
    
    new_im = gain * im^gamma
    """
    print("Brightness")
    
    factor = 0.2
    
    X_new = np.zeros(X.shape)
    for c in range(X.shape[-1]):
        im = X[:,:,:,c]        
        gain, gamma = (1-factor) + np.random.random_sample(2,) * factor * 2.0
        im_new = np.sign(im)*gain*(np.abs(im)**gamma)
        X_new[:,:,:,c] = im_new 
    
    return X_new, y

def elastic(X, y):
    """
    Elastic deformation on a image and its target
    """  
    
    factor = np.random.uniform(0.0, 4.0)
    print(f"Elastic: {factor}")
  
    # Randomly transforms the image elastically.
    # Parts of the image that would be empty are set to +1000 so that they can be filled with the appropriate background later.
    [Xel, yel] = ed.deform_random_grid([X, y], sigma=factor, axis=[(0, 1, 2), (0, 1, 2)], order=[0, 0], cval=1000)
    
    # Empty parts of the segmentation contain no tumor.
    yel[yel == 1000] = 0
    
    for channel in range(Xel.shape[-1]):
        
        Xchannel = Xel[:,:,:,channel]
        
        # Sets all empty areas of the channel to the background intensity (minimum).
        Xchannel[Xchannel == 1000] = Xchannel.min()
        
        Xel[:,:,:,channel] = Xchannel
    
    return Xel, yel

def noise(X, y):
    """
    Adds random noise to the image.
    The noise has 1% magnitude of the average intensity of the scan.
    """
    
    brain = X[X > X.min()]
    noise_intensity = brain.mean() * 0.05
    noise_intensity *= np.random.random()
    
    print(f"Noise: {noise_intensity}")
    
    noise = (np.random.random(X.shape)-0.5) * 2.0 * noise_intensity
    
    Xnoise = X + noise
    return Xnoise, y

def contrast(X, y):
    """
    Changes the contrast of the image. Based on tf.image.adjust_contrast(X, contrast_factor)
    See: https://www.tensorflow.org/api_docs/python/tf/image/adjust_contrast
    Formula: contrasted_img = (img - mean) * contrast_factor + mean
    """
    
    contrast_factor = np.random.uniform(0.8, 1.2)
    print(f"Contrast: {contrast_factor}")
    
    for c in range(X.shape[-1]):
        im = X[:,:,:,c]
        
        brain_im = im[im!=0]
        brain_mean = brain_im.mean()
        
        im = (im - brain_mean) * contrast_factor + brain_mean
        
        X[:,:,:,c] = im
    
    return X, y

def translate(X, y):
    """
    Randomly moves the image by translating it.
    """
    
    # Calculates intensity of translation in each dimension.
    magnitude = 0.1
    x_t = int(magnitude * (X.shape[0]))
    y_t = int(magnitude * (X.shape[1]))
    z_t = int(magnitude * (X.shape[2]))
    
    # Generates random image translation.
    x_t = np.random.randint(-x_t, x_t+1)
    y_t = np.random.randint(-y_t, y_t+1)
    z_t = np.random.randint(-z_t, z_t+1)
    
    print(f"Translate: {x_t}, {y_t}, {z_t}")
    
    # Translates image.
    X = np.roll(X, (x_t, y_t, z_t), axis=(0,1,2))
    y = np.roll(y, (x_t, y_t, z_t), axis=(0,1,2))
    
    # Removes overflow
    backgrounds = X.min(axis=(0,1,2))
    
    for modality in range(X.shape[3]):
        if x_t > 0:
            X[:x_t,:,:,modality] = backgrounds[modality]
            y[:x_t,:,:] = 0
        elif x_t < 0:
            X[x_t:,:,:,modality] = backgrounds[modality]
            y[x_t:,:,:] = 0

        if y_t > 0:
            X[:,:y_t,:,modality] = backgrounds[modality]
            y[:,:y_t,:] = 0
        elif y_t < 0:
            X[:,y_t:,:,modality] = backgrounds[modality]
            y[:,y_t:,:] = 0 

        if z_t > 0:
            X[:,:,:z_t,modality] = backgrounds[modality]
            y[:,:,:z_t] = 0
        elif z_t < 0:
            X[:,:,z_t:,modality] = backgrounds[modality]
            y[:,:,z_t:] = 0
        
    return X,y


def zoom(X, y):
    scaling = np.random.uniform(0.8,1.2)
    print(f"Zoom: {scaling}")
    
    R = np.eye(3) * scaling
    
    X_rot = np.zeros_like(X)
    for channel in range(X.shape[-1]):
        X_rot[:,:,:,channel] = affine_transform(X[:,:,:,channel], R, order=0)
    y_rot = affine_transform(y, R, order=0)
    
    return X_rot, y_rot
    

def shear(X, y):
    
    factor = 0.05
    x_shear = np.random.uniform(-factor, factor)
    y_shear = np.random.uniform(-factor, factor)
    z_shear = np.random.uniform(-factor, factor)
    print(f"Shear: {x_shear},{y_shear},{z_shear}")
    
    shear_mat = np.array([
        [1, y_shear, z_shear],
        [x_shear, 1, z_shear],
        [x_shear, y_shear, 1]
    ])
    
    X_shear = np.zeros_like(X)
    
    backgrounds = X.min(axis=(0,1,2))
    
    for channel in range(X.shape[-1]):
        X_shear[:,:,:,channel] = affine_transform(X[:,:,:,channel], shear_mat, order=0, cval=backgrounds[channel])
        
    y_shear = affine_transform(y, shear_mat, order=0, mode="nearest")
    return X_shear, y_shear

In [18]:
# Defines different data augmentation approaches.

def extreme_augmentation(im, gt):
    "Calculates the most extreme data augmentation for the scan."
    
    im, gt = elastic(im, gt)
    im, gt = flip3D(im, gt)
    im, gt = rotation3D(im, gt)
    im, gt = shear(im, gt)
    im, gt = translate(im, gt)
    im, gt = zoom(im, gt)
    
    im, gt = brightness(im, gt)
    im, gt = contrast(im, gt)
    im, gt = noise(im, gt)
    return im, gt


def balanced_augmentation(im, gt):
    """
    Augments the given datapoint with less extreme transformations than extreme_augmentation(...).
    
    The following brightness-altering augmentations are always applied at the very end:
    - brightness adjust
    - contrast adjust
    - addititive noise
    
    Because of its good performance, every datapoint is augmented with elastic transformation.
    
    Out of the pool of geometric transformation, two are pulled and applied:
    - image flipping
    - rotation
    - translation
    - zoom in/out
    - 3D shearing
    """
    
    # Always applies elastic transformation.
    im, gt = elastic(im, gt)
    
    # Pulls 2 unique random numbers out of range(5).
    choices = np.random.choice(range(5), (conf.num_augmentations,), replace = False)
    
    for choice in choices:
        if choice == 0:
            im, gt = flip3D(im, gt)
        elif choice == 1:
            im, gt = rotation3D(im, gt)
        elif choice == 2:
            im, gt = translate(im, gt)
        elif choice == 3:
            im, gt = zoom(im, gt)
        elif choice == 4:
            im, gt = shear(im, gt)
            
    # Applies all brightness augmentations.
    im, gt = brightness(im, gt)
    im, gt = contrast(im, gt)
    im, gt = noise(im, gt)
    
    return im, gt 

In [19]:
# Optional: Use this to test specific augmentations on a single image.
ENABLED = False

if ENABLED:
    # Export settings:
    folder = conf.aug_export_path + "/examples/" # Specify the export folder here!
    aug_title = "noise" # Name the augmentation you are testing here!
    
    # Loads a single image.
    X, y, patient_id = scan_loader.load_img(data[0], normalize = False)
    
    # Apply the augmentations you want to test here!
    Xaug, yaug = noise(X, y)
    
    # Saves the result.
    scan_name = patient_id + "_" + aug_title
    scan_loader.save_full_scan(Xaug, yaug, folder, scan_name)
    print(f"Done! Exported to: {folder}{scan_name}.")

In [20]:
# Creates data generator
aug_gen = DataGenerator(data,
                        shuffle    = False,
                        input_dim  = conf.dataset_dim,
                        output_dim = conf.augmented_dim,
                        batch_size = conf.batch_size,
                        n_channels = conf.num_channels,
                        n_classes  = conf.num_classes,
                        categorical_classes = False,
                        preprocessed = False
)

# Xbatch.shape: num_batches x 128 x 128 x 128 x num_classes
# Ybatch.shape: num_batches x 128 x 128 x 128
for Xbatch, Ybatch, IDbatch in aug_gen:

    # Iterate over all scans in this batch.
    for b in range(Xbatch.shape[0]):

        # Gets one scan from batch.
        im = Xbatch[b,:,:,:,:]
        gt = Ybatch[b,:,:,:]
        patient_id = IDbatch[b]
        
        
        for i in range(conf.num_augmentations):
        
            # Augments scan.
            im_aug, gt_aug = balanced_augmentation(im, gt)

            # Saves augmented scan
            folder = f"{conf.aug_export_path}/{patient_id}_aug_{i}"
            print(f"Saving {patient_id}_aug_{i}...")
            scan_loader.save_full_scan(im_aug, gt_aug, folder, patient_id + "_aug")
        
        # Saves un-augmented scan.
        folder = f"{conf.aug_export_path}/{patient_id}"
        print(f"Saving {patient_id}...")
        scan_loader.save_full_scan(im, gt, folder, patient_id)

Elastic: 0.40365954709004637
Rotation: 0.017453292519943295, 0.41887902047863906, 0.06981317007977318
Brightness
Contrast: 1.0886894007391874
Noise: -0.05676754313135
Saving BraTS20_Training_001_aug_0...
Saving BraTS20_Training_001...
Elastic: 3.452556821828727
Flip: 0
Brightness
Contrast: 0.8886741599072806
Noise: -0.09177652204090722
Saving BraTS20_Training_002_aug_0...
Saving BraTS20_Training_002...
Elastic: 0.4977502250160204
Rotation: 0.40142572795869574, 0.296705972839036, 0.3141592653589793
Brightness
Contrast: 1.058719125719855
Noise: -0.10898741054686016
Saving BraTS20_Training_003_aug_0...
Saving BraTS20_Training_003...
Elastic: 3.1500690686085266


KeyboardInterrupt: 