In [1]:
import sys
sys.path.append('/home/kevinteng/Desktop/BrainTumourSegmentation')
import numpy as np 
import tensorflow as tf
import matplotlib.pyplot as plt
import os 
import utils
from utils_vis import plot_comparison, plot_labels_color 
from sklearn.metrics import confusion_matrix
%matplotlib inline

  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


---

# Hyperparameters

In [2]:
SHUFFLE_BUFFER = 4000
BATCH_SIZE = 16
lr = 1e-6  # Adam learning rate
opt = tf.keras.optimizers.Adam(lr)
ver = '06' #save version 
dropout=0.2 #dropout rate
hn = 'he_normal' #kernel initializer 
tfrecords_read_dir = '/home/kevinteng/Desktop/ssd02/BraTS20_tfrecords03/HGG/'
# Global TF seed: ops without their own seed (dataset shuffling, weight
# initialisers) derive from it, so runs repeat across kernel restarts.
SEED = 42
tf.random.set_seed(SEED)

---

# Helper Functions

In [3]:
def dice_coef(y_true, y_pred, smooth=1e-5):
    '''
    Dice coefficient for tensorflow:
        (2 * sum(y_true * y_pred) + smooth) / (sum(y_true) + sum(y_pred) + smooth)
    :param y_true: Ground truth (any shape; flattened internally)
    :param y_pred: Prediction from the model (same number of elements as y_true)
    :param smooth: added to numerator and denominator so empty masks give 1.0, not 0/0
    :return: dice coefficient as a scalar float32 tensor
    '''
    # Always flatten: reshaping an already-flat tensor is a no-op. The old
    # "both inputs not rank 1" check left a mixed-rank pair unflattened, so the
    # product below would broadcast instead of comparing element-wise.
    # Cast int labels to float32 for the arithmetic.
    y_true = tf.cast(tf.reshape(y_true, [-1]), tf.float32)
    y_pred = tf.cast(tf.reshape(y_pred, [-1]), tf.float32)
    intersection = tf.reduce_sum(y_true * y_pred)
    return (2.0 * intersection + smooth) / \
        (tf.reduce_sum(y_true) + tf.reduce_sum(y_pred) + smooth)

def ss_metric(y_true, y_pred, label_type = 'binary', mode = 'global', smooth=1e-5):
    '''
    Compute sensitivity and specificity for groundtruth and prediction
    :param y_true: Ground truth labels (any shape; flattened internally)
    :param y_pred: Prediction from the model (same number of elements as y_true)
    :param label_type: 'binary': input labels are binarized {0, 1}
                       'multi': multi-class labels {0, 1, 2, 3}
    :param mode: only used when label_type == 'multi'
                 'local': sensitivity/specificity per label (arrays of length 4)
                 'global' (or anything else but 'local'... see below): counts pooled over labels
                 Any value other than 'global' is treated as 'local'.
    :param smooth: added to numerator and denominator to avoid 0/0
    :return: sensitivity & specificity
    :raises ValueError: if label_type is neither 'binary' nor 'multi'
    '''
    # Always flatten (no-op for 1-D input). The old check only flattened when
    # *both* inputs were non-1-D.
    y_true = np.reshape(np.asarray(y_true), [-1])
    y_pred = np.reshape(np.asarray(y_pred), [-1])
    if label_type == 'binary':
        tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()
        sensitivity = (tp + smooth) / (tp + fn + smooth)
        specificity = (tn + smooth) / (tn + fp + smooth)
    elif label_type == 'multi':
        # rows: ground truth, columns: prediction
        cm = confusion_matrix(y_true, y_pred, labels=[0, 1, 2, 3])
        tp_per_class = np.diag(cm).copy()      # copy: np.diag may return a view
        actual_per_class = cm.sum(axis=1)      # TP + FN per class
        predicted_per_class = cm.sum(axis=0)   # TP + FP per class
        fp_per_class = predicted_per_class - tp_per_class
        # One-vs-rest true negatives: samples neither labelled nor predicted
        # as class i. (The previous version summed the *other* diagonal entries,
        # excluding them by value, which both missed confusions among the other
        # classes and dropped classes that happened to have equal counts.)
        tn_per_class = cm.sum() - actual_per_class - predicted_per_class + tp_per_class
        if mode == 'global':
            tp, tp_fn = tp_per_class.sum(), actual_per_class.sum()
            tn, fp = tn_per_class.sum(), fp_per_class.sum()
        else: #local
            tp, tp_fn = tp_per_class, actual_per_class
            tn, fp = tn_per_class, fp_per_class
        #true positive rate
        sensitivity = (tp + smooth) / (tp_fn + smooth)
        #true negative rate
        specificity = (tn + smooth) / (tn + fp + smooth)
    else:
        raise ValueError("label_type must be 'binary' or 'multi', got {!r}".format(label_type))
    return sensitivity, specificity

def compute_metric(y_true, y_pred, label_type='binary'):
    '''
    This function computes the metrics specified by the BraTS competition:
    dice coefficient, sensitivity and specificity, for the three tumour regions.
    :param y_true: Ground truth batch; y_true[i] is one sample. For a single
                   image pass label[:1] (not label[0]), otherwise each image
                   row is treated as a separate sample.
    :param y_pred: Prediction batch from the model, same layout as y_true
    :param label_type: 'binary': each region is binarized to {0, 1}
                       'multi': each region keeps its original label values
    :return: dice coefficient, sensitivity & specificity lists (batch means)
             with order ['core', 'enhancing', 'complete']
    :raises ValueError: if label_type is neither 'binary' nor 'multi'
    NOTE(review): with 'multi', dice_coef multiplies raw label values, which is
    not a standard Dice score; 'binary' is the mode used by this notebook.
    '''
    if label_type not in ('binary', 'multi'):
        raise ValueError("label_type must be 'binary' or 'multi', got {!r}".format(label_type))
    # Region membership per tumour type (label 4 is expected to be remapped to 3).
    region_masks = {
        'core': lambda lbl: (lbl == 1) | (lbl == 3),  # labels 1, 3(4)
        'enhancing': lambda lbl: lbl == 3,            # label 3(4)
        # labels 1, 2, 3(4); the old `lbl >= 0` also matched background (0),
        # so every voxel counted as "complete" tumour.
        'complete': lambda lbl: lbl > 0,
    }
    dc_output = []
    sens_output = []
    spec_output = []
    for tumour_type in ['core', 'enhancing', 'complete']:
        in_region = region_masks[tumour_type]
        if label_type == 'multi':
            region_true, region_pred = [np.where(in_region(lbl), lbl, 0) for lbl in (y_true, y_pred)]
        else:
            region_true, region_pred = [np.where(in_region(lbl), 1, 0) for lbl in (y_true, y_pred)]
        dc_list = []
        sens_list = []
        spec_list = []
        for idx in range(len(region_true)):
            y_true_f = tf.reshape(region_true[idx], [-1]) #flatten
            y_pred_f = tf.reshape(region_pred[idx], [-1]) #flatten
            dc_list.append(dice_coef(y_true_f, y_pred_f))
            # forward label_type so multi-class regions are not squeezed into a 2x2 matrix
            sensitivity, specificity = ss_metric(y_true_f, y_pred_f, label_type=label_type)
            sens_list.append(sensitivity)
            spec_list.append(specificity)
        # mean over the batch axis, one value per tumour type
        dc_output.append(np.mean(dc_list))
        sens_output.append(np.mean(sens_list))
        spec_output.append(np.mean(spec_list))
    #for each list the order is as following=> 'core','enhancing','complete'
    return dc_output, sens_output, spec_output

----

# Model

In [4]:
from utils_model import conv_block, coordconv_block, up, pool
from tensorflow.keras import Model
from tensorflow.keras.layers import Input, Conv2D

def Unet_model(input_layer, num_classes=4):
    '''
    Build a 4-level U-Net on top of `input_layer` and return its output tensor.
    Encoder: conv_block + pool at 64/128/256/512 filters (the 512 level also
    gets `dropout_rate=dropout`), then a 1024-filter bottleneck with dropout.
    Decoder: `up` merges with the matching encoder output, followed by a
    conv_block, at 512/256/128/64 filters.
    Reads the notebook globals `hn` (kernel initializer) and `dropout` (rate).
    :param input_layer: Keras input tensor
    :param num_classes: channels of the final 1x1 softmax convolution (default 4)
    :return: softmax output tensor with `num_classes` channels
    '''
    encoder_filters = (64, 128, 256, 512)
    skips = []
    x = input_layer
    #downsampling
    for n_filters in encoder_filters:
        # only the deepest encoder level passes dropout_rate; shallower levels
        # keep conv_block's own default
        block_kwargs = {'dropout_rate': dropout} if n_filters == 512 else {}
        x = conv_block(x, filters=n_filters, kernel_initializer=hn, **block_kwargs)
        skips.append(x)
        x = pool(x)

    x = conv_block(x, filters=1024, kernel_initializer=hn, dropout_rate=dropout)

    #upsampling, merging with the encoder output of the same resolution
    for n_filters, skip in zip(reversed(encoder_filters), reversed(skips)):
        x = up(x, filters=n_filters, merge=skip, kernel_initializer=hn)
        x = conv_block(x, filters=n_filters, kernel_initializer=hn)

    return Conv2D(num_classes, (1, 1), activation='softmax')(x)

In [5]:
# 240x240 inputs with 4 channels (the training loop feeds imgs[..., :4]).
input_layer = Input(shape=(240, 240, 4))
unet_output = Unet_model(input_layer)
Unet = Model(inputs=input_layer, outputs=unet_output)

In [6]:
# TODO (original author note): Sensitivity
# Categorical cross-entropy between one-hot labels and the softmax output.
xent = tf.keras.losses.CategoricalCrossentropy()

@tf.function
def train_fn(image, label):
    '''
    Run one optimisation step of `Unet` with the global optimizer `opt`.
    :param image: input batch
    :param label: one-hot ground truth batch
    :return: (softmax model output, loss, gradients for Unet.trainable_variables)
    '''
    with tf.GradientTape() as tape:
        # training=True so Dropout (and any other training-dependent layers)
        # run in training mode; a bare Unet(image) call uses inference mode.
        model_output = Unet(image, training=True)
        loss = xent(label, model_output)
    gradients = tape.gradient(loss, Unet.trainable_variables)
    opt.apply_gradients(zip(gradients, Unet.trainable_variables))

    return model_output, loss, gradients

In [None]:
epochs = 1
max_epochs = 30
# histories kept across all epochs
loss_list = []  # loss sampled every 2000 steps
acc_list = []   # mean batch accuracy of each tfrecord file, for every epoch
loss_inner = []
while epochs <= max_epochs:
    print()
    print("Epochs {:2d}".format(epochs))
    steps = 1
    # per-epoch metric buffers, only filled on summary epochs
    dc_app = []
    sens_app = []
    spec_app = []
    epoch_acc = []  # per-file mean accuracy for this epoch only
    summary_epoch = epochs % 5 == 0
    for tf_re in sorted(os.listdir(tfrecords_read_dir)):
        # join directory and file name properly instead of relying on a trailing '/'
        tf_dir = os.path.join(tfrecords_read_dir, tf_re)
        dataset = utils.parse_tfrecord(tf_dir).shuffle(SHUFFLE_BUFFER).batch(BATCH_SIZE)
        acc_inner = []
        for imgs in dataset:
            # channels 0-3 are the model input, the last channel is the label map
            image = imgs[:,:,:,:4]
            label = imgs[:,:,:,-1]
            #for simplicity label 4 will be converted to 3 for sparse encoding
            label = tf.where(label==4,3,label)
            label = tf.keras.utils.to_categorical(label, num_classes=4)
            img_seg, loss, gradients = train_fn(image, label) #training function
            #map from one-hot / softmax back to integer labels
            img_seg = tf.math.argmax(img_seg,-1,output_type=tf.int32)
            label = tf.math.argmax(label,-1,output_type=tf.int32)
            #accuracy of the output values for that batch
            acc = tf.reduce_mean(tf.cast(tf.equal(img_seg,label), tf.float32))
            acc_inner.append(acc)
            if summary_epoch:
                # compute_metric builds confusion matrices per image, so only
                # run it on epochs whose results are actually kept
                dc_list, sens_list, spec_list = compute_metric(label, img_seg)
                dc_app.append(dc_list)
                sens_app.append(sens_list)
                spec_app.append(spec_list)
            #output
            if steps%2000==0:
                input_img = [image[0,:,:,0], plot_labels_color(label[0]), plot_labels_color(img_seg[0])]
                caption = ['Input Image', 'Ground Truth', 'Model Output']
                plot_comparison(input_img, caption, n_col = 3, figsize=(10,10))
                loss_list.append(loss)
                acc_stp = tf.reduce_mean(tf.cast(tf.equal(img_seg[0],label[0]), tf.float32))
                # slice with [:1] to keep the batch axis: compute_metric iterates
                # over axis 0, so label[0] would be scored row by row
                dc_list_stp, sens_list_stp, spec_list_stp = compute_metric(label[:1], img_seg[:1])
                print("Steps: {}, Loss:{}".format(steps, loss))
                print("Accuracy: {}".format(acc_stp))
                print("Dice coefficient: {}".format(dc_list_stp))
                print("Sensitivity: {}".format(sens_list_stp))
                print("Specificity: {}".format(spec_list_stp))
                print("Gradient min:{}, max:{}".format(np.min(gradients[0]), np.max(gradients[0])))
            steps+=1
        file_acc = np.mean(acc_inner)
        acc_list.append(file_acc)
        epoch_acc.append(file_acc)
    if summary_epoch:
        mean_dc = np.mean(np.array(dc_app),0)
        mean_sens = np.mean(np.array(sens_app),0)
        mean_spec = np.mean(np.array(spec_app),0)
        print()
        print('-----------<Summary for Epoch:{}>------------'.format(epochs))
        # this epoch only (acc_list holds the history of every epoch)
        print("Mean Accuracy: {}".format(np.mean(epoch_acc)))
        print("Mean Dice coefficient: {}".format(mean_dc))
        print("Mean Sensitivity: {}".format(mean_sens))
        print("Mean Specificity: {}".format(mean_spec))
        print('------------------------------------------------')
        print()
    epochs+=1


Epochs  1


---

# Save Weights

In [None]:
# Create the target directory first so a missing folder does not make the
# save fail after the whole training run has finished.
weights_dir = '/home/kevinteng/Desktop/model_weights'
os.makedirs(weights_dir, exist_ok=True)
Unet.save_weights(os.path.join(weights_dir, 'Unet_{}.h5'.format(ver)))

---

# Validation 

In [None]:
def val_fn(image, label):
    '''
    Run `Unet` in inference mode on one batch and compute the loss.
    The saved weights for the current `ver` are loaded on the first call (and
    again if `ver` changes). The previous version re-read the .h5 file for
    every batch. Re-run this cell to force a reload after re-saving the weights.
    :param image: input batch
    :param label: one-hot ground truth batch
    :return: (softmax model output, categorical cross-entropy loss)
    '''
    weights_path = '/home/kevinteng/Desktop/model_weights/Unet_{}.h5'.format(ver)
    if getattr(val_fn, '_loaded_weights', None) != weights_path:
        Unet.load_weights(weights_path)
        Unet.trainable = False
        val_fn._loaded_weights = weights_path
    model_output = Unet(image, training=False)
    loss = xent(label, model_output)
    return model_output, loss

In [None]:
tfrecords_val = '/home/kevinteng/Desktop/ssd02/BraTS20_tfrecords03/LGG/'

steps = 1
acc_list = []  # per-batch accuracy over all validation files
for tf_re in sorted(os.listdir(tfrecords_val)):
    # join directory and file name properly instead of relying on a trailing '/'
    tf_dir = os.path.join(tfrecords_val, tf_re)
    # no shuffle: evaluation does not need it, and a fixed order makes runs reproducible
    dataset = utils.parse_tfrecord(tf_dir).batch(BATCH_SIZE)
    # per-file buffers
    file_acc = []
    dc_app = []
    sens_app = []
    spec_app = []
    for imgs in dataset:
        image = imgs[:,:,:,:4]
        label = imgs[:,:,:,-1]
        #for simplicity label 4 will be converted to 3 for sparse encoding
        label = tf.where(label==4,3,label)
        label = tf.keras.utils.to_categorical(label, num_classes=4)
        img_seg, loss = val_fn(image, label) #validation function
        #map from one-hot / softmax back to integer labels
        img_seg = tf.math.argmax(img_seg,-1,output_type=tf.int32)
        label = tf.math.argmax(label,-1,output_type=tf.int32)
        #accuracy of the output values for that batch
        acc = tf.reduce_mean(tf.cast(tf.equal(img_seg,label), tf.float32))
        dc_list, sens_list, spec_list = compute_metric(label,img_seg)
        #append
        acc_list.append(acc)
        file_acc.append(acc)
        dc_app.append(dc_list)
        sens_app.append(sens_list)
        spec_app.append(spec_list)
        #output
        if steps%100==0:
            input_img = [image[0,:,:,0], plot_labels_color(label[0]), plot_labels_color(img_seg[0])]
            caption = ['Input Image', 'Ground Truth', 'Model Output']
            plot_comparison(input_img, caption, n_col = 3, figsize=(10,10))
            acc_stp = tf.reduce_mean(tf.cast(tf.equal(img_seg[0],label[0]), tf.float32))
            # [:1] keeps the batch axis; label[0] would be scored row by row
            dc_stp, sens_stp, spec_stp = compute_metric(label[:1], img_seg[:1])
            print("Steps: {}, Loss:{}".format(steps, loss))
            print("Accuracy: {}".format(acc_stp))
            print("Dice coefficient: {}".format(dc_stp))
            print("Sensitivity: {}".format(sens_stp))
            print("Specificity: {}".format(spec_stp))
        steps+=1
    mean_dc = np.mean(np.array(dc_app),0)
    mean_sens = np.mean(np.array(sens_app),0)
    mean_spec = np.mean(np.array(spec_app),0)
    # all four means below are for this file only
    print("Mean Accuracy: {}".format(np.mean(file_acc)))
    print("Mean Dice coefficient: {}".format(mean_dc))
    print("Mean Sensitivity: {}".format(mean_sens))
    print("Mean Specificity: {}".format(mean_spec))

---

# Model Summary

In [None]:
# Print the layer-by-layer architecture and parameter counts.
Unet.summary(line_length=None)