In [1]:
import sys
import time
sys.path.append('/home/kevinteng/Desktop/BrainTumourSegmentation')
import numpy as np 
import tensorflow as tf
import matplotlib.pyplot as plt
import os, random
import utils
from utils_vis import plot_comparison, plot_labels_color 
from utils import dice_coef, ss_metric, compute_metric
import nibabel as nib
from sklearn.model_selection import KFold
%matplotlib inline

  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


- Blue => Label 1 (Necrotic and Non-enhancing Tumor Core)
- Yellow => Label 2 (Peritumoral Edema)
- Green => Label 4 (GD-Enhancing Tumor), remapped to 3 during training
---
Region definitions (using the remapped labels, i.e. 4 → 3):
* Core => Labels 1 & 3
* Enhancing => Label 3
* Complete => Labels 1, 2 & 3

---

# Hyperparameters

In [2]:
SHUFFLE_BUFFER = 4000  # shuffle buffer size used when reading each tfrecord file
max_epochs = 2  # epochs run for every cross-validation fold
BATCH_SIZE = 24
lr = 1e-5  # Adam learning rate
opt = tf.keras.optimizers.Adam(lr)
ver = 'DeepSupervisedAttentionUNet03'  # save version, used in the weight file name
dropout = 0.3  # dropout rate  NOTE(review): not referenced by any visible cell
hn = 'he_normal'  # kernel initializer passed to every conv_block / up call
# Data locations. They can be overridden through environment variables so the
# notebook runs on another machine without editing; the defaults are the original paths.
tfrecords_read_dir = os.environ.get(
    'BRATS_TFRECORDS_DIR', '/home/kevinteng/Desktop/ssd02/BraTS20_tfrecords03/')
stack_npy = os.environ.get(
    'BRATS_STACK_NPY_DIR', "/home/kevinteng/Desktop/ssd02/BraTS2020_stack03/")

---

# Helper Functions

In [3]:
# Categorical cross-entropy on probabilities (the model ends in a softmax); used by custom_loss
xent = tf.keras.losses.CategoricalCrossentropy(from_logits=False)

def generalized_dice(y_true, y_pred, smooth = 1e-5):
    """
    Generalized Dice Score
    https://arxiv.org/pdf/1707.03237
    https://github.com/Mehrdad-Noori/Brain-Tumor-Segmentation/blob/master/loss.py

    Each class is weighted by the inverse of its squared ground-truth volume,
    so small structures contribute comparably to large ones.

    Args:
        y_true: one-hot ground truth; the last axis is the class axis.
        y_pred: predicted class probabilities with the same shape as y_true.
        smooth: added to the squared class volume before inversion so that a
            class absent from y_true does not cause a division by zero.

    Returns:
        Scalar tensor: the generalized Dice score over all pixels of the batch.
    """
    y_pred = tf.convert_to_tensor(y_pred)
    # one-hot labels may arrive as numpy/int; match the prediction dtype
    y_true = tf.cast(y_true, y_pred.dtype)
    # infer the class count instead of hard-coding it, so a different number
    # of classes is never silently folded into the wrong columns
    n_classes = tf.shape(y_pred)[-1]
    y_true = tf.reshape(y_true, (-1, n_classes))
    y_pred = tf.reshape(y_pred, (-1, n_classes))
    sum_p = tf.reduce_sum(y_pred, 0)            # predicted volume per class
    sum_r = tf.reduce_sum(y_true, 0)            # reference volume per class
    sum_pr = tf.reduce_sum(y_true * y_pred, 0)  # soft intersection per class
    weights = tf.math.pow(tf.math.square(sum_r) + smooth, -1)
    numerator = 2 * tf.reduce_sum(weights * sum_pr)
    denominator = tf.reduce_sum(weights * (sum_r + sum_p))
    return numerator / denominator

def generalized_dice_loss(y_true, y_pred):
    """Generalized Dice loss: one minus the generalized Dice score."""
    score = generalized_dice(y_true, y_pred)
    return 1 - score
    
def custom_loss(y_true, y_pred, ce_weight=1.25):
    """
    Combined segmentation loss: generalized Dice loss plus a weighted
    categorical cross-entropy term,

        loss = (1 - GDS(y_true, y_pred)) + ce_weight * CE(y_true, y_pred)

    Args:
        y_true: one-hot ground truth.
        y_pred: softmax probabilities from the model.
        ce_weight: multiplier of the cross-entropy term (default 1.25).

    Returns:
        Scalar loss tensor.
    """
    return generalized_dice_loss(y_true, y_pred) + ce_weight * xent(y_true, y_pred)

def data_aug(imgs, seed=8888):
    """
    Random vertical then horizontal flip of `imgs`.

    The flips act on the whole tensor, so every channel stacked in `imgs`
    is flipped together with the others.

    Args:
        imgs: image tensor accepted by tf.image.random_flip_*.
        seed: operation-level seed passed to both flips.
    """
    flipped_vertically = tf.image.random_flip_up_down(imgs, seed=seed)
    return tf.image.random_flip_left_right(flipped_vertically, seed=seed)

----

# Model

In [4]:
from utils_model import conv_block, coordconv_block, up, pool, attention_block
from tensorflow.keras import Model
from tensorflow.keras.layers import Input, Conv2D, UpSampling2D, Activation, Add, Multiply, GaussianNoise

def AttUnet_model(input_layer, attention_mode='grid', n_classes=4, noise_stddev=0.01):
    """
    Build a deeply supervised attention U-Net on top of `input_layer`.

    The encoder has four conv_block/pool stages (64-512 filters) and a
    1024-filter bottleneck. In the decoder, each skip connection is gated by
    attention_block before being merged by `up`. Deep supervision: 1x1 conv
    segmentation maps from the last three decoder stages are upsampled and
    summed, and a final 1x1 conv applies the softmax.

    Args:
        input_layer: Keras input tensor.
        attention_mode: 'grid' gates each skip connection with the preceding
            decoder feature map; any other value gates it with the next-deeper
            encoder feature map (the first gate always uses the bottleneck).
        n_classes: number of output channels / classes (default 4).
        noise_stddev: stddev of the GaussianNoise applied to the input.

    Returns:
        Output tensor with `n_classes` softmax channels.
    """
    gauss1 = GaussianNoise(noise_stddev)(input_layer)
    #downsampling path
    conv1 = conv_block(gauss1, filters=64, kernel_initializer=hn)
    pool1 = pool(conv1)
    
    conv2 = conv_block(pool1, filters=128, kernel_initializer=hn)
    pool2 = pool(conv2)
    
    conv3 = conv_block(pool2, filters=256, kernel_initializer=hn)
    pool3 = pool(conv3)
    
    conv4 = conv_block(pool3, filters=512, kernel_initializer=hn)
    pool4 = pool(conv4)
    
    conv5 = conv_block(pool4, filters=1024, kernel_initializer=hn)
    
    #upsampling path
    att01 = attention_block(conv4, conv5, 512)
    up1 = up(conv5,filters=512, merge=att01, kernel_initializer=hn)
    conv6 = conv_block(up1, filters=512, kernel_initializer=hn)
    
    if attention_mode=='grid':
        att02 = attention_block(conv3, conv6, 256)
    else:
        att02 = attention_block(conv3, conv4, 256)
    up2 = up(conv6, filters=256, merge=att02, kernel_initializer=hn)
    conv7 = conv_block(up2, filters=256, kernel_initializer=hn)
    #injection block 1: coarsest auxiliary segmentation map
    seg01 = Conv2D(n_classes,(1,1),padding='same')(conv7)
    up_seg01 = UpSampling2D()(seg01)
    
    if attention_mode=='grid':
        att03 = attention_block(conv2, conv7, 128)
    else:
        att03 = attention_block(conv2, conv3, 128)
    up3 = up(conv7, filters=128, merge=att03, kernel_initializer=hn)
    conv8 = conv_block(up3, filters=128, kernel_initializer=hn)
    #injection block 2: add the upsampled coarser map
    seg02 = Conv2D(n_classes,(1,1),padding='same')(conv8)
    add_21 = Add()([seg02, up_seg01])
    up_seg02 = UpSampling2D()(add_21)
    
    if attention_mode=='grid':
        att04 = attention_block(conv1, conv8, 64)
    else:
        att04 = attention_block(conv1, conv2, 64)
    up4 = up(conv8, filters=64, merge=att04, kernel_initializer=hn)
    conv9 = conv_block(up4, filters=64, kernel_initializer=hn)
    #injection block 3: finest map plus the accumulated coarser ones
    seg03 = Conv2D(n_classes,(1,1),padding='same')(conv9)
    add_32 = Add()([seg03, up_seg02])
    
    output_layer = Conv2D(n_classes, (1,1), activation = 'softmax')(add_32)
    
    return output_layer

In [5]:
#Build Model: wrap the attention U-Net graph in a functional Keras Model
# with 240x240 inputs carrying 4 channels
input_layer = Input(shape=(240, 240, 4))
model = Model(inputs=input_layer, outputs=AttUnet_model(input_layer))

In [6]:
@tf.function
def train_fn(image, label):
    """
    One optimisation step on a batch.

    Computes custom_loss with the model in training mode, differentiates it
    with respect to the model's trainable variables and applies the update
    with the global optimizer `opt`.

    Returns:
        (model_output, loss, gradients), returned for logging.
    """
    with tf.GradientTape() as tape:
        prediction = model(image, training=True)
        batch_loss = custom_loss(label, prediction)
    variables = model.trainable_variables
    grads = tape.gradient(batch_loss, variables)
    opt.apply_gradients(zip(grads, variables))
    return prediction, batch_loss, grads

@tf.function
def val_fn(image, label):
    """Forward pass in inference mode; returns (model_output, loss) without updating weights."""
    prediction = model(image, training=False)
    return prediction, custom_loss(label, prediction)

---

In [None]:
def _split_batch(imgs):
    """
    Split a stacked batch into the model input and a one-hot label.

    The first four channels of `imgs` are the model input and the last channel
    is the segmentation label. Label 4 is mapped to 3 so the classes form the
    contiguous range 0-3 required by the 4-class one-hot encoding.
    """
    image = imgs[:, :, :, :4]
    label = imgs[:, :, :, -1]
    # cast the fill value to the label's dtype so both tf.where branches agree
    label = tf.where(label == 4, tf.cast(3, label.dtype), label)
    label = tf.keras.utils.to_categorical(label, num_classes=4)
    return image, label


def _print_summary(title, epoch, accuracy, dc_app, sens_app, spec_app):
    """Print mean accuracy and the batch-averaged dice / sensitivity / specificity."""
    print()
    print('-----------<{} summary for Epoch:{}>------------'.format(title, epoch))
    print("Mean Accuracy: {}".format(accuracy))
    #'core','enhancing','complete'
    print("Mean Dice coefficient: {}".format(np.mean(np.array(dc_app), 0)))
    print("Mean Sensitivity: {}".format(np.mean(np.array(sens_app), 0)))
    print("Mean Specificity: {}".format(np.mean(np.array(spec_app), 0)))
    print('------------------------------------------------')
    print()


ds = sorted(os.listdir(tfrecords_read_dir))  # sorted: os.listdir order is arbitrary
kf = KFold(n_splits=len(ds), shuffle=True)
folds = 1
for train_id, val_id in kf.split(ds):
    print("Fold: {}".format(folds))
    # NOTE(review): `model` and `opt` are not re-initialised between folds, so
    # each fold keeps training the same weights and a fold's validation file
    # has already been trained on in earlier folds.
    epochs = 1
    #list for training
    loss_list = []
    acc_list = []
    #for every fold we run(<=max_epochs)
    while epochs <= max_epochs:
        print("Epochs {:2d}".format(epochs))
        start = time.time()  # restarted every epoch so the timing below is per epoch
        steps = 1
        # epoch-level accumulators: reset once per epoch (not per tfrecord file)
        # so the summary covers every training file of the epoch
        acc_inner = []
        dc_app = []
        sens_app = []
        spec_app = []
        #training fold
        for idx in train_id:
            tf_dir = os.path.join(tfrecords_read_dir, ds[idx])
            dataset = utils.parse_tfrecord(tf_dir).shuffle(SHUFFLE_BUFFER).batch(BATCH_SIZE)
            for imgs in dataset:
                #augment before splitting so image and label are flipped together
                imgs = data_aug(imgs)
                image, label = _split_batch(imgs)
                img_seg, loss, gradients = train_fn(image, label) #training function
                #map from probabilities / one-hot to label indices
                img_seg = tf.math.argmax(img_seg, -1, output_type=tf.int32)
                label = tf.math.argmax(label, -1, output_type=tf.int32)
                #accuracy of the output values for that batch
                acc = tf.reduce_mean(tf.cast(tf.equal(img_seg, label), tf.float32))
                acc_inner.append(acc)
                #accumulate dc score, sensitivity and specificity
                dc_list, sens_list, spec_list = compute_metric(label, img_seg)
                dc_app.append(dc_list)
                sens_app.append(sens_list)
                spec_app.append(spec_list)
                #periodic report on the first sample of the batch
                if steps % 1000 == 0:
                    input_img = [image[0,:,:,0], plot_labels_color(label[0]), plot_labels_color(img_seg[0])]
                    caption = ['Input Image', 'Ground Truth', 'Model Output']
                    plot_comparison(input_img, caption, n_col = 3, figsize=(10,10), captions_font = 10)
                    loss_list.append(loss)
                    acc_stp = tf.reduce_mean(tf.cast(tf.equal(img_seg[0],label[0]), tf.float32))
                    dc_list_stp, sens_list_stp, spec_list_stp = compute_metric(label[0],img_seg[0])
                    print("Steps: {}, Loss:{}".format(steps, loss))
                    print("Accuracy: {}".format(acc_stp))
                    print("Dice coefficient: {}".format(dc_list_stp))
                    print("Sensitivity: {}".format(sens_list_stp))
                    print("Specificity: {}".format(spec_list_stp))
                    print("Gradient min:{}, max:{}".format(np.min(gradients[0]), np.max(gradients[0])))
                steps += 1
        acc_list.append(np.mean(acc_inner))
        # report this epoch's accuracy, not the running mean over all epochs
        _print_summary('Training', epochs, acc_list[-1], dc_app, sens_app, spec_app)
        #validation over every held-out file of this fold
        acc_inner_val = []
        dc_app_val = []
        sens_app_val = []
        spec_app_val = []
        for v_idx in val_id:
            val_tf_dir = os.path.join(tfrecords_read_dir, ds[v_idx])
            # evaluation does not need shuffling
            val_ds = utils.parse_tfrecord(val_tf_dir).batch(BATCH_SIZE)
            for imgs in val_ds:
                image, label = _split_batch(imgs)
                val_img_seg, val_loss = val_fn(image, label)
                val_img_seg = tf.math.argmax(val_img_seg, -1, output_type=tf.int32)
                label = tf.math.argmax(label, -1, output_type=tf.int32)
                acc = tf.reduce_mean(tf.cast(tf.equal(val_img_seg, label), tf.float32))
                acc_inner_val.append(acc)
                #accumulate dc score, sensitivity and specificity
                dc_list, sens_list, spec_list = compute_metric(label, val_img_seg)
                dc_app_val.append(dc_list)
                sens_app_val.append(sens_list)
                spec_app_val.append(spec_list)
        # validation accuracy must come from the validation batches, not acc_inner
        _print_summary('Validation', epochs, np.mean(acc_inner_val),
                       dc_app_val, sens_app_val, spec_app_val)

        elapsed_time = (time.time() - start) / 60 #unit in mins
        print("Compute time per epochs: {:.2f} mins".format(elapsed_time))
        epochs += 1
    folds += 1
    print()
    print()

Fold: 1
Epochs  1


---

# Save Weights

In [None]:
# Persist the trained weights. The .h5 suffix selects the HDF5 format, and `ver`
# tags the file with this experiment's version.
weights_dir = '/home/kevinteng/Desktop/model_weights'
os.makedirs(weights_dir, exist_ok=True)  # create the directory so saving works on a fresh machine
model.save_weights(os.path.join(weights_dir, 'model_{}.h5'.format(ver)))

---

# Validation 

In [None]:
model.load_weights('/home/kevinteng/Desktop/model_weights/model_{}.h5'.format(ver))

def output_fn(image):
    """
    Predict class probabilities for `image` with the loaded model.

    Inference mode is requested through training=False rather than by setting
    `model.trainable = False`, which would mutate the shared global model
    (freezing it for any later training cell) on every call.
    """
    return model(image, training=False)

In [None]:
# ds = '/home/kevinteng/Desktop/ssd02/BraTS2020_preprocessed03/'
# save_path = '/home/kevinteng/Desktop/ssd02/submission/'
# actual_label = '/home/kevinteng/Desktop/ssd02/MICCAI_BraTS2020_TrainingData/BraTS20_Training_001/BraTS20_Training_001_seg.nii.gz'
# #all brain affine are the same just pick one 
# brain_affine = nib.load(actual_label).affine
# steps = 1
# acc_list = []
# for train_or_val in sorted(os.listdir(ds)):
#     save_dir = save_path + train_or_val+'_'+ver
#     if not os.path.exists(save_dir):
#         os.makedirs(save_dir)
#     merge01 = os.path.join(ds+train_or_val)
#     for patient in sorted(os.listdir(merge01)):
#         patient_id = patient.split('.')[0]
#         merge02 = os.path.join(merge01,patient)
#         imgs = np.load(merge02)
#         image = imgs[:,:,:,:4]
#         seg_output = 0 #flush RAM
#         seg_output = np.zeros((240,240,155))
#         for i in range(image.shape[0]):
#             inp = tf.expand_dims(image[i],0)
#             img_seg = output_fn(inp) #validation function 
#             #map from sparse to label
#             seg_output[:,:,i] = np.argmax(img_seg,-1) 
#         #convert label 3 back to 4 (the original enhancing-tumor label) and cast to uint8
#         seg_output= np.where(seg_output==3,4,seg_output).astype(np.uint8)
#         prediction_ni = nib.Nifti1Image(seg_output, brain_affine)
#         prediction_ni.to_filename(save_dir+'/{}.nii.gz'.format(patient_id))

---

# Model Summary

In [None]:
# Print the layer-by-layer architecture and parameter counts of the built model
model.summary()