## 3D U-Net
---
**Import Statements**

In [None]:
import os
import random

import SimpleITK as sitk
import numpy as np
from numpy import load
from scipy.ndimage import zoom
import matplotlib.pyplot as plt
from skimage.util import montage

from keras.models import Model, load_model
from keras.layers import Input, BatchNormalization, Activation, Dense, Dropout, LeakyReLU
from keras.layers.core import Lambda, RepeatVector, Reshape
from keras.layers.convolutional import Conv3D, Conv3DTranspose
from keras.layers.pooling import MaxPooling3D, GlobalMaxPool3D
MaxPool3D = MaxPooling3D
from keras.layers.merge import concatenate, add
from keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau
from keras.optimizers import Adam

from keras.utils import to_categorical
from keras.utils import Sequence
from keras import backend as K

from instancenormalization import *

from batchgenerators.dataloading.data_loader import SlimDataLoaderBase
from batchgenerators.transforms.spatial_transforms import MirrorTransform, SpatialTransform
from batchgenerators.transforms.noise_transforms import GaussianNoiseTransform, GaussianBlurTransform
from batchgenerators.transforms.color_transforms import ContrastAugmentationTransform, BrightnessMultiplicativeTransform, GammaTransform
from batchgenerators.transforms.resample_transforms import SimulateLowResolutionTransform
from batchgenerators.transforms.abstract_transforms import Compose
from batchgenerators.dataloading import SingleThreadedAugmenter

In [None]:
from datagen import *

**Model Architecture**

In [None]:
def conv3d_block(input_tensor, n_filters, kernel_size=3, instancenorm=True, leakyrelu=True):
    """Apply two (Conv3D -> normalisation -> activation) stages to ``input_tensor``.

    Args:
        input_tensor: Keras tensor to convolve.
        n_filters: Number of filters in both Conv3D layers.
        kernel_size: Kernel size for both Conv3D layers ("same" padding,
            He-normal initialisation).
        instancenorm: Use InstanceNormalization if True, else BatchNormalization.
        leakyrelu: Use LeakyReLU(alpha=0.01) if True, else ReLU.

    Returns:
        The output Keras tensor of the second stage.
    """
    out = input_tensor
    for _stage in range(2):
        out = Conv3D(filters=n_filters, kernel_size=kernel_size,
                     kernel_initializer="he_normal", padding="same")(out)
        norm_layer = InstanceNormalization() if instancenorm else BatchNormalization()
        out = norm_layer(out)
        # NOTE(review): wrapping a LeakyReLU *layer* in Activation may not
        # serialise with every Keras version (model.save / ModelCheckpoint) —
        # confirm against the Keras version in use.
        act = LeakyReLU(alpha=0.01) if leakyrelu else 'relu'
        out = Activation(act)(out)
    return out

In [None]:
def unet(input_vol, n_filters=32, kernel_size=3, instancenorm=True, leakyrelu=True):
    """Build a 6-level 3D U-Net with a single-channel sigmoid output.

    The encoder applies ``conv3d_block`` at widths ``n_filters * 2**k`` for
    k = 0..4. Each block is followed by max-pooling: (2,2,2) four times, then
    (2,2,1), so the last downsampling leaves the third spatial axis unchanged.
    A bottleneck block of width ``n_filters * 2**5`` follows. The decoder mirrors
    this with Conv3DTranspose upsampling (strides matching the pooling),
    skip-concatenation and ``conv3d_block`` at each level.

    ``kernel_size``, ``instancenorm`` and ``leakyrelu`` now apply to the decoder
    blocks as well. Previously they were silently dropped there, so the decoder
    always used the defaults.

    For the skip concatenations to line up, the first two spatial dims of
    ``input_vol`` must be divisible by 32 and the third by 16.

    Args:
        input_vol: Keras input tensor (presumably channels-last, e.g.
            ``Input((x, y, z, 1))`` — confirm against caller).
        n_filters: Width of the first level; doubled at each deeper level.
        kernel_size: Kernel size of every conv block and transposed conv.
        instancenorm: InstanceNormalization if True, else BatchNormalization.
        leakyrelu: LeakyReLU if True, else ReLU.

    Returns:
        A ``keras.models.Model`` mapping ``input_vol`` to per-voxel
        probabilities.
    """
    block_kwargs = dict(kernel_size=kernel_size, instancenorm=instancenorm, leakyrelu=leakyrelu)
    # Pool size after each encoder level; the last one keeps the z axis.
    pool_sizes = [(2, 2, 2), (2, 2, 2), (2, 2, 2), (2, 2, 2), (2, 2, 1)]

    skips = []
    x = input_vol
    for level, pool in enumerate(pool_sizes):
        x = conv3d_block(x, n_filters=n_filters * (2 ** level), **block_kwargs)
        skips.append(x)
        x = MaxPool3D(pool)(x)

    # Bottleneck.
    x = conv3d_block(x, n_filters=n_filters * (2 ** len(pool_sizes)), **block_kwargs)

    # Decoder: undo each pooling step in reverse order, using the same stride.
    for level in reversed(range(len(pool_sizes))):
        x = Conv3DTranspose(filters=n_filters * (2 ** level), kernel_size=kernel_size,
                            strides=pool_sizes[level], padding="same")(x)
        x = concatenate([x, skips[level]])
        x = conv3d_block(x, n_filters=n_filters * (2 ** level), **block_kwargs)

    outputs = Conv3D(1, (1, 1, 1), activation="sigmoid")(x)
    return Model(inputs=[input_vol], outputs=[outputs])

**Data Generator**

In [None]:
class DataGenerator(Sequence):
    """Keras ``Sequence`` of random 3D CT/mask patches, one volume per batch.

    Each batch is drawn from a single volume. The CT and mask ``.npz`` files
    named ``_id`` (array stored under ``'arr_0'``) are loaded, zero-padded up to
    ``patch_size`` where smaller, and intensity-normalised. Then ``batch_size``
    random patches are cropped at matching locations, optionally augmented with
    batchgenerators transforms. Each batch is an ``(X, y)`` pair of float32
    arrays of shape ``(batch_size, *patch_size, 1)``.

    Args:
        ids: Filenames present in both ``ct_path`` and ``mask_path``.
        ct_path: Directory holding the CT ``.npz`` files.
        mask_path: Directory holding the mask ``.npz`` files.
        patch_size: Spatial patch shape ``(x, y, z)``.
        batch_size: Number of patches cropped from each volume.
        min_max_norm: Clip the CT to [0, 1500] and scale it to [0, 1] instead
            of z-score normalisation (the default).
        seed: Value passed to ``random.seed`` at every epoch end when
            ``validation`` is True.
        augment: Apply the augmentation pipeline to every patch.
        validation: Reseed Python's ``random`` each epoch, in an attempt to
            make crop locations repeat across epochs.
        shuffle: Shuffle the volume order each epoch.
        isabr: Optional iterable of ids whose patches are NOT required to
            contain tumour (the "iSABR" patients). Defaults to none.
    """

    # Ids exempt from the "every patch must contain tumour" rule. The original
    # code read ``self.isabr`` without ever defining it (AttributeError). This
    # class-level default keeps external assignments (``gen.isabr = ...``) working.
    isabr = frozenset()

    def __init__(self, ids, ct_path, mask_path, patch_size, batch_size=2, min_max_norm=False, seed=2020, augment=True, validation=False, shuffle=True, isabr=None):
        self.ids = ids
        self.ct_path = ct_path
        self.mask_path = mask_path
        self.batch_size = batch_size
        self.min_max_norm = min_max_norm
        self.shuffle = shuffle
        self.patch_size = patch_size
        self.seed = seed
        self.augment = augment
        self.validation = validation
        if isabr is not None:
            self.isabr = set(isabr)
        if self.augment:
            self.transforms = self._build_transforms()
        self.on_epoch_end()

    def _build_transforms(self):
        """Build the composed batchgenerators augmentation pipeline.

        Also stores the spatial transform on ``self.spatial_transforms``.
        """
        max_angle = 15 / 360. * 2 * np.pi  # 15 degrees in radians
        self.spatial_transforms = SpatialTransform(
            self.patch_size,
            [i // 2 for i in self.patch_size],
            do_elastic_deform=False,
            do_rotation=True,
            angle_x=(-max_angle, max_angle),
            angle_y=(-max_angle, max_angle),
            angle_z=(-max_angle, max_angle),
            do_scale=True,
            scale=(0.7, 1.4),
            border_mode_data='constant', border_cval_data=0,
            border_mode_seg='constant', border_cval_seg=0,
            order_seg=1, order_data=3,
            p_el_per_sample=0, p_rot_per_sample=0.2, p_scale_per_sample=0.2)

        return Compose([
            self.spatial_transforms,
            GaussianNoiseTransform(noise_variance=(0.0, 0.1), p_per_sample=0.15),
            GaussianBlurTransform(blur_sigma=(0.5, 1.5), p_per_sample=0.1),
            BrightnessMultiplicativeTransform((0.7, 1.3), per_channel=True, p_per_sample=0.15),
            ContrastAugmentationTransform((0.65, 1.5), preserve_range=True, p_per_sample=0.15),
            SimulateLowResolutionTransform((1, 2), order_downsample=0, order_upsample=3, p_per_sample=0.25),
            GammaTransform(gamma_range=(0.7, 1.5), invert_image=False, p_per_sample=0.15),
            # NOTE(review): this inverted gamma (~0.7) runs on EVERY augmented
            # patch (p_per_sample=1), but never on un-augmented (e.g. validation)
            # patches. Confirm this train/val intensity mismatch is intended.
            GammaTransform(gamma_range=(0.7, 0.71), invert_image=True, p_per_sample=1),
            MirrorTransform(axes=(0, 1, 2)),
        ])

    def __len__(self):
        'Number of batches per epoch'
        # One batch per volume id.
        return len(self.ids)

    def __getitem__(self, index):
        'Generates one batch of data'
        # Each batch comes from exactly one volume id.
        ids_batch = [self.ids[k] for k in self.indexes[index:index + 1]]
        return self.__data_generation(ids_batch)

    def on_epoch_end(self):
        'Updates indexes after each epoch'
        self.indexes = np.arange(len(self.ids))
        if self.shuffle:
            np.random.shuffle(self.indexes)
        if self.validation:
            # Reseeds the *global* `random` module, which drives patch cropping,
            # to try to reuse the same validation crops every epoch. NumPy's RNG
            # (used by the augmentations) is not reseeded here.
            # TODO: review — validation may only reflect this fixed crop subset.
            random.seed(self.seed)

    def _load_volume(self, _id):
        """Load the CT and mask for ``_id``, check shapes, and pad to ``patch_size``."""
        ct_file = os.path.join(self.ct_path, _id)
        mask_file = os.path.join(self.mask_path, _id)
        # `with` closes the NpzFile handle, which a bare `load(...)[key]` leaks.
        with load(mask_file) as npz:
            mask = npz['arr_0']
        with load(ct_file) as npz:
            ct = npz['arr_0']
        # Can't be too careful
        if ct.shape != mask.shape:
            raise Exception(str(ct.shape) + " " + str(mask.shape) + " " + str(ct_file))
        return self._pad_to_patch_size(ct), self._pad_to_patch_size(mask)

    def _pad_to_patch_size(self, arr):
        """Zero-pad ``arr`` so that every dim is >= ``patch_size``, centring the data.

        The input is returned unchanged when it is already large enough.
        Otherwise a new float64 array is returned.
        """
        pads = [max(p - s, 0) for p, s in zip(self.patch_size, arr.shape)]
        if not any(pads):
            return arr
        padded = np.zeros([s + p for s, p in zip(arr.shape, pads)])
        padded[tuple(slice(p // 2, p // 2 + s) for p, s in zip(pads, arr.shape))] = arr
        return padded

    def _normalize(self, ct):
        """Return a normalised copy of ``ct`` (min/max or z-score)."""
        if self.min_max_norm:
            # Out-of-place division: np.clip keeps integer dtypes, and an
            # in-place `/=` on an integer array raises.
            return np.clip(ct, 0, 1500) / 1500
        # Z Score Normalization (DEFAULT)
        ct = ct - np.mean(ct)
        std = np.std(ct)
        if std > 0:  # a constant volume would otherwise become all-NaN
            ct /= std
        return ct

    def _sample_patch(self, ct, mask, require_tumor):
        """Crop one random ``patch_size`` window at the same location from both arrays.

        With ``require_tumor``, windows are redrawn until the mask patch holds a
        voxel equal to 1 (rejection sampling). Callers must only request this
        when the full mask contains such a voxel.
        """
        while True:
            corner = [random.choice(range(dim - size + 1))
                      for dim, size in zip(ct.shape, self.patch_size)]
            window = tuple(slice(c, c + size) for c, size in zip(corner, self.patch_size))
            ct_patch, mask_patch = ct[window], mask[window]
            if not require_tumor or mask_patch.max() == 1:
                return ct_patch, mask_patch

    def _augment_patch(self, ct_patch, mask_patch):
        """Run one patch pair through ``self.transforms`` and return the results."""
        # batchgenerators expects (batch, channel, x, y, z).
        ct_batch = ct_patch[np.newaxis, np.newaxis].astype('float32')
        mask_batch = mask_patch[np.newaxis, np.newaxis]
        loader = AugmentationDataLoader({'vol': ct_batch, 'seg': mask_batch})
        augmented = next(SingleThreadedAugmenter(loader, self.transforms))
        return augmented['data'][0, 0, :, :, :], augmented['seg'][0, 0, :, :, :]

    def __data_generation(self, ids_batch):
        'Generates data containing batch samples'
        X = []
        y = []
        for _id in ids_batch:  # There should just be one _id in ids_batch
            ct, mask = self._load_volume(_id)
            ct = self._normalize(ct)
            # Pick `batch_size` patches; for non-iSABR patients whose mask has
            # tumour, every patch must contain tumour.
            require_tumor = _id not in self.isabr and mask.max() == 1
            for _ in range(self.batch_size):
                ct_patch, mask_patch = self._sample_patch(ct, mask, require_tumor)
                if self.augment:
                    ct_patch, mask_patch = self._augment_patch(ct_patch, mask_patch)
                X.append(ct_patch)
                y.append(mask_patch)

        X = np.array(X, dtype='float32')
        y = np.array(y, dtype='float32')
        # Add a trailing channel axis: (n_patches, x, y, z, 1).
        X = X.reshape(-1, X.shape[1], X.shape[2], X.shape[3], 1)
        y = y.reshape(-1, y.shape[1], y.shape[2], y.shape[3], 1)
        return X, y
    
    
class AugmentationDataLoader(SlimDataLoaderBase):
    """Minimal batchgenerators loader that always serves one stored sample.

    ``data`` is a dict with keys ``'vol'`` (image array) and ``'seg'`` (label
    array). Each ``generate_train_batch`` call returns a dict keyed ``'data'``
    and ``'seg'``. ``'data'`` is a float32 copy of ``'vol'``, and ``'seg'`` is
    the stored label array itself.
    """

    def __init__(self, data):
        super().__init__(data, batch_size=1)

    def generate_train_batch(self):
        sample = self._data
        image = sample['vol']
        labels = sample['seg']
        return {'data': image.astype(np.float32), 'seg': labels}