In [None]:
import os
# Enumerate GPUs in PCI bus order and expose only device 0; set here, before
# TensorFlow is imported in the next cell, so CUDA picks it up.
os.environ.update(CUDA_DEVICE_ORDER="PCI_BUS_ID", CUDA_VISIBLE_DEVICES="0")

In [None]:
import numpy as np
import tensorflow as tf
import nibabel as nib
import glob
import time
from tensorflow.keras.utils import to_categorical
from sys import stdout
import matplotlib.pyplot as plt
import matplotlib.image as mpim
import elasticdeform as ed
from scipy.ndimage.interpolation import affine_transform
import concurrent.futures

In [None]:
from tensorflow.keras.preprocessing.image import apply_affine_transform

Nclasses = 4
classes = np.arange(Nclasses)

# Root directory scanned for the NIfTI files of every case.
DATA_ROOT = '/nobackup/data/marci30'

# Per-modality file lists. Each list is sorted independently, so position i must
# refer to the same case in every list; differing lengths would silently
# misalign modalities and segmentations, hence the check below.
t1_list = sorted(glob.glob(DATA_ROOT + '/*/*/*t1.nii'))
t2_list = sorted(glob.glob(DATA_ROOT + '/*/*/*t2.nii'))
t1ce_list = sorted(glob.glob(DATA_ROOT + '/*/*/*t1ce.nii'))
flair_list = sorted(glob.glob(DATA_ROOT + '/*/*/*flair.nii'))
seg_list = sorted(glob.glob(DATA_ROOT + '/*/*/*seg.nii'))

file_counts = {'t1': len(t1_list), 't2': len(t2_list), 't1ce': len(t1ce_list),
               'flair': len(flair_list), 'seg': len(seg_list)}
if len(set(file_counts.values())) != 1:
    raise ValueError('Modality file counts differ, cases would be misaligned: {}'.format(file_counts))

# Case indices of the cross-validation fold (train/valid) and the held-out test set.
idxTrain, idxValid, idxTest = np.load('idxTrain_cv5.npy'), np.load('idxValid_cv5.npy'), np.load('idxTest.npy')
print('Training, validation and testing set have lengths: {}, {} and {} respectively.'.format(len(idxTrain), len(idxValid), len(idxTest)))

# One entry per case: [t1, t2, t1ce, flair, seg] -- image files first and the
# segmentation last, which is the layout load_img reads.
cases = [list(files) for files in zip(t1_list, t2_list, t1ce_list, flair_list, seg_list)]
sets = {
    'train': [cases[i] for i in idxTrain],
    'valid': [cases[i] for i in idxValid],
    'test': [cases[i] for i in idxTest],
}
    
def load_img(img_files, crop=(slice(40, 200), slice(34, 226), slice(8, 136))):
    '''Load one multi-channel case and its segmentation target from files.

    Parameters
    ----------
    img_files : sequence of str
        Paths of the image channels followed by the segmentation file as the
        LAST element (e.g. [t1, t2, t1ce, flair, seg]).
    crop : tuple of slice, optional
        Spatial crop applied to every volume. Normalisation statistics are
        computed on the full volume first. The default is the original
        hard-coded crop.

    Returns
    -------
    X_norm : np.ndarray, float32, shape (*cropped_shape, len(img_files) - 1)
        Each channel z-scored over its non-zero voxels; zero voxels stay 0.
    y : np.ndarray, float32, shape cropped_shape
        Cropped label volume with label 4 remapped to 3.

    Raises
    ------
    ValueError
        If ``img_files`` holds no image file before the segmentation file.
    '''
    *channel_files, seg_file = img_files
    if not channel_files:
        raise ValueError('img_files must list at least one image file before the segmentation file')

    # Target: crop, then merge label 4 into 3 so every label is < 4.
    y = nib.load(seg_file).get_fdata(dtype='float32')[crop]
    y[y == 4] = 3

    channels = []
    for path in channel_files:
        X = nib.load(path).get_fdata(dtype='float32')
        mask = X != 0
        brain = X[mask]
        # Background (exactly-zero) voxels stay at 0; only non-zero voxels are z-scored.
        X_norm = np.zeros_like(X)
        if brain.size:
            std = brain.std()
            # A constant channel would otherwise divide by zero and yield NaN/inf.
            X_norm[mask] = (brain - brain.mean()) / (std if std > 0 else 1)
        # Crop only after normalising so statistics match the full volume; copy
        # so the full-size array can be freed right away.
        channels.append(X_norm[crop].copy())

    return np.stack(channels, axis=-1), y

def flip3D(X, y):
    '''Mirror an image patch and its label map along one random spatial axis.

    ``X`` is indexed (x, y, z, channel) and ``y`` is indexed (x, y, z). One axis
    among 0, 1 and 2 is drawn uniformly with ``np.random.randint(3)`` and reversed
    in both arrays. The results are reversed-stride views, not copies.
    '''
    axis = np.random.randint(3)
    return np.flip(X, axis=axis), np.flip(y, axis=axis)

In [None]:
from scipy.ndimage.interpolation import affine_transform

def rotation3D(X, y, max_angle=30):
    '''Rotate an image patch and its label map by the same random 3-D rotation.

    Three angles are drawn independently as whole degrees in ``0..max_angle``.
    The composed rotation ``Rx @ Ry @ Rz`` is applied about the CENTRE of the
    volume, so the patch content stays in frame.

    Parameters
    ----------
    X : np.ndarray, shape (nx, ny, nz, C)
        Image patch; every channel is resampled with cubic splines (order 3).
    y : np.ndarray, shape (nx, ny, nz)
        Label patch; resampled with nearest neighbour (order 0) so no new label
        values are created.
    max_angle : int, optional
        Largest rotation about each axis, in degrees (default 30).

    Returns
    -------
    X_rot, y_rot : np.ndarray
        Rotated arrays with the input shapes and dtypes. Voxels sampled from
        outside the volume are filled with 0.
    '''
    alpha, beta, gamma = np.random.randint(0, max_angle + 1, size=3) / 180 * np.pi
    Rx = np.array([[1, 0, 0],
                   [0, np.cos(alpha), -np.sin(alpha)],
                   [0, np.sin(alpha), np.cos(alpha)]])

    Ry = np.array([[np.cos(beta), 0, np.sin(beta)],
                   [0, 1, 0],
                   [-np.sin(beta), 0, np.cos(beta)]])

    Rz = np.array([[np.cos(gamma), -np.sin(gamma), 0],
                   [np.sin(gamma), np.cos(gamma), 0],
                   [0, 0, 1]])

    R = np.dot(np.dot(Rx, Ry), Rz)

    # affine_transform samples input[R @ o + offset] for each output coordinate o.
    # offset = c - R @ c pivots the rotation on the volume centre c. With offset=0
    # the rotation pivots on voxel (0, 0, 0) and most of the patch leaves the frame.
    centre = (np.asarray(y.shape[:3], dtype=float) - 1) / 2
    offset = centre - np.dot(R, centre)

    X_rot = np.empty_like(X)
    for channel in range(X.shape[-1]):
        X_rot[:, :, :, channel] = affine_transform(X[:, :, :, channel], R, offset=offset, order=3, mode='constant')
    y_rot = affine_transform(y, R, offset=offset, order=0, mode='constant')

    return X_rot, y_rot

## Data Generator for Keras

In [None]:
class DataGenerator(tf.keras.utils.Sequence):
    '''Keras ``Sequence`` that loads cases with ``load_img`` and yields (X, one-hot y) batches.

    Parameters
    ----------
    list_IDs : list
        One entry per case; each entry is the list of file paths passed to
        ``load_img``.
    batch_size : int
        Cases loaded per batch. The last batch of an epoch holds fewer cases
        when ``len(list_IDs)`` is not a multiple of ``batch_size``.
    dim : tuple of int
        Spatial shape of the volumes returned by ``load_img``.
    n_channels, n_classes : int
        Number of image channels and of segmentation classes (one-hot depth).
    shuffle : bool
        Reshuffle the case order at the end of every epoch.
    augmentation : bool
        If True, each loaded case is replaced by ``n_patches`` random cubic
        patches of side ``patch_size``. Each patch is randomly left unchanged,
        flipped, rotated, or flipped then rotated.
    '''
    def __init__(self, list_IDs, batch_size=4, dim=(160,192,128), n_channels=4, n_classes=4, shuffle=True, augmentation=False, patch_size=64, n_patches=8):
        'Initialization'
        self.list_IDs = list_IDs
        self.batch_size = batch_size
        self.dim = dim
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.shuffle = shuffle
        self.augmentation = augmentation
        self.patch_size = patch_size
        self.n_patches = n_patches
        if augmentation and any(patch_size > d for d in dim):
            raise ValueError('patch_size={} does not fit inside dim={}'.format(patch_size, dim))
        self.on_epoch_end()

    def __len__(self):
        'Number of batches per epoch (the last one may be partial).'
        return int(np.ceil(len(self.list_IDs) / self.batch_size))

    def __getitem__(self, index):
        'Load (and optionally augment) batch number ``index``.'
        batch_indexes = self.indexes[index*self.batch_size:(index+1)*self.batch_size]
        list_IDs_temp = [self.list_IDs[k] for k in batch_indexes]

        X, y = self.__data_generation(list_IDs_temp)
        if self.augmentation:
            X, y = self.__data_augmentation(X, y)

        # Reshuffle after the final batch. Plain ``for batch in gen`` loops never
        # call on_epoch_end, so this keeps the order changing between epochs.
        if index == len(self) - 1:
            self.on_epoch_end()

        return X, y

    def on_epoch_end(self):
        'Reset (and optionally shuffle) the case order.'
        self.indexes = np.arange(len(self.list_IDs))
        if self.shuffle:
            np.random.shuffle(self.indexes)

    def __data_generation(self, list_IDs_temp):
        '''Load the cases of one batch.

        Returns X with shape (n, *dim, n_channels), float32, where n is the
        ACTUAL number of cases in this batch. Sizing by n (not batch_size) keeps
        a partial last batch free of uninitialised entries. y holds raw labels
        when augmenting (one-hot encoding happens after patching); otherwise y
        is already one-hot.
        '''
        n_samples = len(list_IDs_temp)
        X = np.empty((n_samples, *self.dim, self.n_channels), dtype='float32')
        y = np.empty((n_samples, *self.dim), dtype='float32')

        for i, IDs in enumerate(list_IDs_temp):
            X[i], y[i] = load_img(IDs)

        if self.augmentation:
            return X, y
        return X, to_categorical(y, self.n_classes)

    def __data_augmentation(self, X, Y):
        '''Cut ``n_patches`` random cubic patches per case and randomly flip/rotate them.

        Returns (patches, one-hot labels) with leading size ``len(X) * n_patches``.
        '''
        n_samples = X.shape[0]
        p = self.patch_size
        x_aug = np.empty((n_samples*self.n_patches, p, p, p, self.n_channels), dtype='float32')
        y_aug = np.empty((n_samples*self.n_patches, p, p, p), dtype='float32')

        i = 0
        for b in range(n_samples):
            for _ in range(self.n_patches):
                # Random patch corner, chosen so the patch lies fully inside the volume.
                x0 = np.random.randint(self.dim[0]-p+1)
                y0 = np.random.randint(self.dim[1]-p+1)
                z0 = np.random.randint(self.dim[2]-p+1)

                im = X[b, x0:x0+p, y0:y0+p, z0:z0+p, :]
                gt = Y[b, x0:x0+p, y0:y0+p, z0:z0+p]

                # 0: none, 1: flip, 2: rotation, 3: flip then rotation
                aug_choice = np.random.randint(4)
                if aug_choice in (1, 3):
                    im, gt = flip3D(im, gt)
                if aug_choice in (2, 3):
                    im, gt = rotation3D(im, gt)

                x_aug[i], y_aug[i] = im, gt
                i += 1

        return x_aug, to_categorical(y_aug, self.n_classes)

# Both generators cut one random 128^3 patch per case; augmentation=True is
# what enables patch extraction in DataGenerator.
# NOTE(review): augmentation=True also applies random flips/rotations to the
# validation patches, so validation losses vary from epoch to epoch.
patch_cfg = dict(augmentation=True, patch_size=128, n_patches=1)
train_gen = DataGenerator(sets['train'], **patch_cfg)
valid_gen = DataGenerator(sets['valid'], **patch_cfg)

### Class weights

In [None]:
# Per-class weights of the generator's weighted Dice loss, as float32.
class_weights = np.load('class_weights2.npy').astype(np.float32)
print(class_weights)

## GAN: Vox2Vox

In [None]:
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Conv3D, Conv3DTranspose, Dropout, ReLU, LeakyReLU, Concatenate
from tensorflow.keras.optimizers import Adam

import tensorflow_addons as tfa
from tensorflow_addons.layers import InstanceNormalization

class vox2vox():
    '''Vox2Vox: a 3-D conditional GAN for multi-class volumetric segmentation.

    The generator is a 3-D U-Net-style encoder/decoder that maps an image patch
    (``img_shape``) to a softmax segmentation (``seg_shape``). The discriminator
    is fully convolutional and outputs a spatial map of real/fake scores for a
    (segmentation, image) pair. It is trained with MSE against 1/0 targets.
    ``combined`` trains the generator with
    ``mse(D(G(x), x), 1) + LAMBDA * weighted_dice(G(x), y)``.

    Parameters
    ----------
    img_shape : tuple
        Input patch shape, e.g. (128, 128, 128, 4).
    seg_shape : tuple
        One-hot target shape; its last entry is the number of classes.
    class_weights : array-like
        Per-class weights of the Dice loss.
    Nfilter_start : int
        Filters of the first encoder level, doubled at each deeper level.
    depth : int
        Number of stride-2 down-sampling levels.
    batch_size : int
        Stored but not read by the methods of this class.
    LAMBDA : float
        Weight of the Dice term relative to the adversarial term.
    '''
    def __init__(self, img_shape, seg_shape, class_weights, Nfilter_start=64, depth=4, batch_size=3, LAMBDA=5):
        self.img_shape = img_shape
        self.seg_shape = seg_shape
        self.class_weights = class_weights
        self.Nfilter_start = Nfilter_start
        self.depth = depth
        self.batch_size = batch_size
        self.LAMBDA = LAMBDA

        def diceLoss(y_true, y_pred, w=self.class_weights):
            '''Weighted soft Dice loss.

            Overlap and volume are summed over axes 0-3 (batch and three spatial
            axes), weighted per class by ``w``, then summed:
            1 - 2 * sum_c w_c*sum(T*P) / (sum_c w_c*sum(T+P) + 1e-5).
            '''
            y_true = tf.convert_to_tensor(y_true, 'float32')
            y_pred = tf.convert_to_tensor(y_pred, y_true.dtype)

            num = tf.math.reduce_sum(tf.math.multiply(w, tf.math.reduce_sum(tf.math.multiply(y_true, y_pred), axis=[0,1,2,3])))
            # epsilon avoids division by zero
            den = tf.math.reduce_sum(tf.math.multiply(w, tf.math.reduce_sum(tf.math.add(y_true, y_pred), axis=[0,1,2,3])))+1e-5

            return 1-2*num/den

        # Build and compile the discriminator on its own (trainable) weights.
        self.discriminator = self.Discriminator()
        self.discriminator.compile(loss='mse', optimizer=Adam(2e-4, beta_1=0.5), metrics=['accuracy'])

        self.generator = self.Generator()

        seg = Input(shape=self.seg_shape)
        img = Input(shape=self.img_shape)

        # Segmentation predicted from the image
        seg_pred = self.generator(img)

        # Freeze the discriminator for the combined model only. In the usual
        # tf.keras GAN pattern the trainable flag is captured at compile time, so
        # the discriminator compiled above still updates in train_step.
        self.discriminator.trainable = False

        valid = self.discriminator([seg_pred, img])

        # NOTE(review): ``seg`` is an input of ``combined`` but is not connected to
        # its outputs; callers still pass [Ybatch, Xbatch].
        self.combined = Model(inputs=[seg, img], outputs=[valid, seg_pred])
        self.combined.compile(loss=['mse', diceLoss], loss_weights=[1, self.LAMBDA], optimizer=Adam(2e-4, beta_1=0.5))

    def Generator(self):
        '''Build the 3-D U-Net generator.

        Layout:
        - ``depth - 1`` stride-2 encoder levels (no instance norm on the first).
        - A stride-2 bottleneck followed by four conv/IN/dropout/LeakyReLU blocks.
          Each block concatenates its activation with its raw conv output.
        - Mirrored stride-2 transposed-conv decoder levels with skip connections.
        - A final stride-2 transposed conv with softmax over ``seg_shape[-1]`` classes.
        '''
        inputs = Input(self.img_shape, name='input_image')
        n_classes = self.seg_shape[-1]

        def encoder_step(layer, Nf, inorm=True):
            x = Conv3D(Nf, kernel_size=4, strides=2, kernel_initializer='he_normal', padding='same')(layer)
            if inorm:
                x = InstanceNormalization()(x)
            x = LeakyReLU()(x)
            return x

        def bottleneck(layer, Nf):
            x = Conv3D(Nf, kernel_size=4, strides=2, kernel_initializer='he_normal', padding='same')(layer)
            x = InstanceNormalization()(x)
            x = LeakyReLU()(x)
            for _ in range(4):
                y = Conv3D(Nf, kernel_size=4, strides=1, kernel_initializer='he_normal', padding='same')(x)
                x = InstanceNormalization()(y)
                x = Dropout(0.2)(x)
                x = LeakyReLU()(x)
                x = Concatenate()([x, y])
            return x

        def decoder_step(layer, layer_to_concatenate, Nf):
            x = Conv3DTranspose(Nf, kernel_size=4, strides=2, padding='same', kernel_initializer='he_normal')(layer)
            x = InstanceNormalization()(x)
            x = ReLU()(x)
            x = Concatenate()([x, layer_to_concatenate])
            return x

        layers_to_concatenate = []
        x = inputs

        # encoder: keep each level's output for the skip connections
        for d in range(self.depth-1):
            if d == 0:
                x = encoder_step(x, self.Nfilter_start*np.power(2, d), False)
            else:
                x = encoder_step(x, self.Nfilter_start*np.power(2, d))
            layers_to_concatenate.append(x)

        x = bottleneck(x, self.Nfilter_start*np.power(2, self.depth-1))

        # decoder: deepest skip first
        for d in range(self.depth-2, -1, -1):
            x = decoder_step(x, layers_to_concatenate.pop(), self.Nfilter_start*np.power(2, d))

        # classifier: back to input resolution, one softmax channel per class
        last = Conv3DTranspose(n_classes, kernel_size=4, strides=2, padding='same', kernel_initializer='he_normal', activation='softmax', name='output_generator')(x)

        return Model(inputs=inputs, outputs=last, name='Generator')

    def Discriminator(self):
        '''Build the discriminator.

        It takes inputs ``[segmentation, image]``, concatenates them as
        (image, segmentation) along channels, and applies ``depth`` stride-2 conv
        levels. A final 1-filter conv (no activation) produces a score map.
        '''
        inputs = Input(self.img_shape, name='input_image')
        targets = Input(self.seg_shape, name='target_image')

        def encoder_step(layer, Nf, inorm=True):
            x = Conv3D(Nf, kernel_size=4, strides=2, kernel_initializer='he_normal', padding='same')(layer)
            if inorm:
                x = InstanceNormalization()(x)
            x = LeakyReLU()(x)
            return x

        x = Concatenate()([inputs, targets])

        for d in range(self.depth):
            if d == 0:
                x = encoder_step(x, self.Nfilter_start*np.power(2, d), False)
            else:
                x = encoder_step(x, self.Nfilter_start*np.power(2, d))

        last = tf.keras.layers.Conv3D(1, 4, strides=1, padding='same', kernel_initializer='he_normal', name='output_discriminator')(x)

        return Model(inputs=[targets, inputs], outputs=last, name='Discriminator')

    def train_step(self, Xbatch, Ybatch, mp=True, n_workers=16):
        '''One adversarial update on a batch.

        First the discriminator is fit on (real, 1) and (generated, 0) pairs.
        Then the generator is fit through ``combined``. ``mp`` and ``n_workers``
        are forwarded to every Keras call.

        Returns the ``History`` of the ``combined.fit`` call.
        '''
        gen_output = self.generator.predict(Xbatch, use_multiprocessing=mp, workers=n_workers)

        # One score map per sample, shaped like the discriminator output.
        disc_output_shape = (gen_output.shape[0], *self.discriminator.output_shape[1:])

        # Discriminator update: real pairs -> 1, generated pairs -> 0.
        self.discriminator.fit([Ybatch, Xbatch], tf.ones(disc_output_shape), verbose=0, use_multiprocessing=mp, workers=n_workers)
        self.discriminator.fit([gen_output, Xbatch], tf.zeros(disc_output_shape), verbose=0, use_multiprocessing=mp, workers=n_workers)

        # Generator update: fool the (frozen) discriminator + match the target.
        return self.combined.fit([Ybatch, Xbatch], [tf.ones(disc_output_shape), Ybatch], verbose=0, use_multiprocessing=mp, workers=n_workers)

    def valid_step(self, Xbatch, Ybatch, mp=True, n_workers=16):
        '''Evaluate ``combined`` on one batch without updating any weights.

        Returns the ``combined.evaluate`` list:
        [total loss, Discriminator (adversarial) loss, Generator (Dice) loss].
        '''
        gen_output = self.generator.predict(Xbatch, use_multiprocessing=mp, workers=n_workers)

        disc_output_shape = (gen_output.shape[0], *self.discriminator.output_shape[1:])

        return self.combined.evaluate([Ybatch, Xbatch], [tf.ones(disc_output_shape), Ybatch], verbose=0, use_multiprocessing=mp, workers=n_workers)

    def _save_preview(self, Xbatch, Ybatch, fname, channel=2):
        '''Save a grey strip [image slice | ground truth | prediction] for the first sample.

        Uses the middle slice along the third spatial axis and input channel
        ``channel``.
        '''
        y_pred = np.argmax(self.generator.predict(Xbatch), axis=-1)
        y_true = np.argmax(Ybatch, axis=-1)

        h, w, d = Xbatch.shape[1:4]
        mid = d // 2
        s = Xbatch[0, :, :, mid, channel]
        span = np.max(s) - np.min(s)

        canvas = np.zeros((h, 3 * w))
        # Min-max scale the image slice; a constant slice is shown as 0 instead of NaN.
        canvas[:, :w] = (s - np.min(s)) / span if span > 0 else 0.0
        # Label indices are divided by 6 (as originally) to dim them relative to the image.
        canvas[:, w:2*w] = y_true[0, :, :, mid] / 6
        canvas[:, 2*w:] = y_pred[0, :, :, mid] / 6

        mpim.imsave(fname, canvas, cmap='gray')

    def train(self, train_generator, valid_generator, nEpochs, path='/home/marci30/Desktop/RESULTS/cv5'):
        '''Adversarial training loop that checkpoints on the best mean validation loss.

        Parameters
        ----------
        train_generator, valid_generator : iterable of (X, Y) batches
            Must support ``len()``. ``Y`` is one-hot, shaped like ``seg_shape``.
        nEpochs : int
            Number of passes over ``train_generator``.
        path : str, optional
            Output directory, created if missing. It receives per-epoch
            prediction previews, the best weights and the loss histories.

        Returns
        -------
        trends_train, trends_valid : tf.keras.callbacks.History
            Per-epoch MEANS over all batches of 'loss', 'Discriminator_loss' and
            'Generator_loss'. The last one is already multiplied by LAMBDA.

        Raises
        ------
        ValueError
            If ``valid_generator`` yields no batches.
        '''
        print('Training process:')
        print('Training on {} and validating on {} batches.\n'.format(len(train_generator), len(valid_generator)))

        keys = ('loss', 'Discriminator_loss', 'Generator_loss')

        trends_train = tf.keras.callbacks.History()
        trends_train.epoch = []
        trends_train.history = {k: [] for k in keys}

        trends_valid = tf.keras.callbacks.History()
        trends_valid.epoch = []
        trends_valid.history = {k: [] for k in keys}

        os.makedirs(path, exist_ok=True)

        prev_loss = np.inf

        for e in range(nEpochs):
            print('Epoch {}/{}'.format(e+1, nEpochs))
            start_time = time.time()

            # ---- training ----
            train_losses = []
            for b, (Xbatch, Ybatch) in enumerate(train_generator, start=1):
                hist = self.train_step(Xbatch, Ybatch).history
                # Scale the Dice term by LAMBDA so that loss = disc + gen.
                batch_losses = (hist['loss'][0], hist['Discriminator_loss'][0], hist['Generator_loss'][0] * self.LAMBDA)
                train_losses.append(batch_losses)
                stdout.write('\rBatch: {}/{} - v2v_loss: {:.4f} - disc_loss: {:.4f} - gen_loss: {:.4f}'.format(b, len(train_generator), *batch_losses))
                stdout.flush()
            # Drop the last training batch before validation loads new volumes.
            Xbatch = Ybatch = None

            # ---- validation: averaged over ALL batches, not only the last ----
            valid_losses = []
            for Xbatch, Ybatch in valid_generator:
                valid_losses.append(self.valid_step(Xbatch, Ybatch))
            if not valid_losses:
                raise ValueError('valid_generator yielded no batches')
            loss_val, disc_val, gen_val = (float(v) for v in np.mean(valid_losses, axis=0))
            gen_val *= self.LAMBDA

            stdout.write(' - v2v_loss_val: {:.4f} - disc_loss_val: {:.4f} - gen_loss_val: {:.4f}'.format(loss_val, disc_val, gen_val))
            elapsed_time = time.time() - start_time
            stdout.write('\nElapsed time: {}:{:02d} mm:ss'.format(int(elapsed_time//60), int(elapsed_time%60)))
            stdout.flush()

            # Record epoch means (NaN for training if its generator was empty).
            train_mean = np.mean(train_losses, axis=0) if train_losses else np.full(len(keys), np.nan)
            trends_train.on_epoch_end(e, {k: float(v) for k, v in zip(keys, train_mean)})
            trends_valid.on_epoch_end(e, dict(zip(keys, (loss_val, disc_val, gen_val))))
            print('\n ')

            # Preview on the last validation batch.
            self._save_preview(Xbatch, Ybatch, os.path.join(path, 'pred@epoch_{}.png'.format(e+1)))
            del Xbatch, Ybatch

            if loss_val < prev_loss:
                print("Validation loss decreased from {:.4f} to {:.4f}. Hence models' weights are now saved.".format(prev_loss, loss_val))
                prev_loss = loss_val
                self.generator.save_weights(os.path.join(path, 'Generator.h5'))
                self.discriminator.save_weights(os.path.join(path, 'Discriminator.h5'))
                self.combined.save_weights(os.path.join(path, 'Vox2Vox.h5'))

        np.save(os.path.join(path, 'history_train'), trends_train.history)
        np.save(os.path.join(path, 'history_valid'), trends_valid.history)

        return trends_train, trends_valid

# Input: 128^3 patches with 4 channels; target: 128^3 one-hot maps with 4 classes.
imShape = (128,) * 3 + (4,)
gtShape = (128,) * 3 + (4,)
gan = vox2vox(imShape, gtShape, class_weights, depth=4, batch_size=4, LAMBDA=5)

In [None]:
# Train for 200 epochs; returns the training and validation History objects.
trends_train, trends_valid = gan.train(train_gen, valid_gen, nEpochs=200)