In [0]:
from google.colab import drive
drive.mount('/content/drive')

import sys
import time
import copy
import random

import keras
from keras import layers
from keras import models
from keras import optimizers
from keras.preprocessing.image import ImageDataGenerator
from keras import backend
from keras.utils import plot_model

import numpy as np
import matplotlib.pyplot as plt

import math



# Load Dataset and Preprocessing
Load the patient volumes, zero-pad every 3D volume on all sides, and perform histogram equalization on the raw intensities.

In [0]:
# Each Dataset_final_<k>.npy file appears to hold a subset of the patients;
# list the part numbers to load (e.g. [1, 2, 3] to use all three files).
# NOTE(review): allow_pickle=True unpickles object arrays, which can execute
# arbitrary code -- only load files from a trusted source.
DATASET_DIR = "drive/My Drive/MRI_Brain_Segmentation/dataset"
DATASET_PARTS = [1]

_dataset_parts = [np.load("%s/Dataset_final_%d.npy" % (DATASET_DIR, k), allow_pickle=True)
                  for k in DATASET_PARTS]
# Skip the concatenate (and its copy) in the common single-file case
dataset = _dataset_parts[0] if len(_dataset_parts) == 1 else np.concatenate(_dataset_parts, axis=0)
del _dataset_parts

In [0]:
def histeq(im, nbr_bins=256):
  """Histogram-equalize an image or volume of any shape.

  Args:
    im: numeric ndarray (any number of dimensions).
    nbr_bins: number of histogram bins used to build the intensity mapping.

  Returns:
    (im2, cdf):
      im2 -- array with the same shape as `im`, dtype float32, values in [0, 1].
      cdf -- uint8 lookup table of length `nbr_bins` giving the equalized
             intensity (0..255) at each bin's left edge.
  """
  if im.size == 0:
    # Nothing to equalize; np.histogram would yield NaN densities here.
    return im.astype(np.float32), np.zeros(nbr_bins, dtype=np.uint8)

  flat = im.ravel()  # a view when possible -- avoids copying large volumes

  # `density=True` replaces `normed=True`, which NumPy removed in 1.24; with
  # equal-width bins (as here) both produce the same values.
  imhist, bins = np.histogram(flat, nbr_bins, density=True)

  cdf = imhist.cumsum()  # cumulative distribution function

  # Mask bins whose cumulative density is still 0, i.e. leading empty bins.
  cdf_m = np.ma.masked_equal(cdf, 0)

  # The main step of histogram equalization: stretch the CDF onto [0, 255]
  cdf_m = (cdf_m - cdf_m.min())*255/(cdf_m.max()-cdf_m.min())

  # Masked bins are filled back with 0
  cdf = np.ma.filled(cdf_m, 0).astype('uint8')

  # Map every value through the CDF by linear interpolation between the bin
  # left edges. np.interp returns float64; cast to float32 because the float64
  # result of a full 3D volume can exhaust RAM.
  im2 = (np.interp(flat, bins[:-1], cdf)/255).astype(np.float32)

  return im2.reshape(im.shape), cdf



# Zero-pad all six faces of every 3D volume (raw image and label map alike).
# NOTE(review): 43 presumably comes from the 87x87 input of
# buildSingleResNetLarge (87 // 2 == 43), so a patch around any original voxel
# stays in bounds -- confirm. The training cell hard-codes imgH = imgW = 342,
# i.e. 256 + 2*43, which suggests 256-voxel volumes before padding.
# WARNING: this cell modifies `dataset` in place; re-running it pads and
# equalizes the volumes a second time. Reload the data before re-running.
npad = ((43,43), (43,43), (43,43))


def _plot_nonzero_histogram(volume):
  """Show a 256-bin histogram of `volume`, skipping the zero-valued voxels."""
  # Lower bound 0.000001 instead of 0: zero background voxels are far too
  # numerous and would swamp the plot.
  plt.hist(volume.ravel(), 256, [0.000001, np.amax(volume)])
  plt.show()


for i in range(len(dataset)):
  dataset[i][0] = np.pad(dataset[i][0], pad_width=npad, mode='constant', constant_values=0)
  dataset[i][1] = np.pad(dataset[i][1], pad_width=npad, mode='constant', constant_values=0)

  if i == 0:  # diagnostics for the first patient only
    _plot_nonzero_histogram(dataset[i][0])  # raw (padded) intensities
    tmp_total = np.size(dataset[i][0])      # total number of voxels
    tmp_xyz = np.shape(dataset[i][0])
    countNZ = np.count_nonzero(dataset[i][0])  # number of non-zero raw voxels
    print("Next is Image Equalization")

  # To carry out histogram equalization of the raw intensities (labels untouched)
  dataset[i][0], _ = histeq(dataset[i][0])

  if i == 0:
    _plot_nonzero_histogram(dataset[i][0])  # equalized intensities


print("Number of patient (Length of dataset) = ",len(dataset) )


# Fundamental building block - Residual Network
**buildResBlock** returns a residual block; the other functions stack these blocks to form larger networks.

Each `buildSingleResNet*` function creates the network for a different input: a small 2D patch (16x16), a large 2D patch (87x87), or a 3D cube of voxels (26x26x26).

In [0]:


# Residual Block

def buildResBlock(inputs, layer_type='cv',
                  filter1=16, filter2=16, 
                  conv_size1=(3,3), conv_size2=(3,3), 
                  stride1=(1,1), stride2=(1,1), stridec=(2,2),
                  input_size_cross=False
                  ):
  """Build a residual block on top of `inputs` and return its output tensor.

  Main path:  BN -> ReLU -> Conv(filter1, conv_size1, stride1)
              -> BN -> ReLU -> Conv(filter2, conv_size2, stride2)
  Shortcut:   the identity, or (when `input_size_cross` is True) a 1x1 / 1x1x1
              Conv(filter2, stride=stridec) projection of `inputs`.
  The two paths are summed with `layers.add`, so their output shapes must match.

  Args:
    inputs: Keras tensor the block is applied to.
    layer_type: 'cv' for 2D convolutions, 'cv3' for 3D convolutions.
    filter1, filter2: number of filters of the first / second convolution.
    conv_size1, conv_size2: kernel sizes of the first / second convolution.
    stride1, stride2: strides of the first / second convolution.
    stridec: stride of the projection shortcut (used only when
      `input_size_cross` is True).
    input_size_cross: use a projection shortcut instead of the identity, e.g.
      when the main path changes the channel count or the spatial size.

  Returns:
    Keras tensor: main path + shortcut.

  Raises:
    ValueError: if `layer_type` is neither 'cv' nor 'cv3'.
  """
  # layer_type -> (convolution layer class, kernel size of the projection shortcut)
  conv_by_type = {
      'cv': (layers.Conv2D, (1, 1)),
      'cv3': (layers.Conv3D, (1, 1, 1)),
  }
  if layer_type not in conv_by_type:
    # Previously this printed "Type Error" and returned None, which only failed
    # later with an obscure Keras error.
    raise ValueError("Unknown layer_type %r: expected 'cv' (2D) or 'cv3' (3D)" % (layer_type,))
  Conv, shortcut_kernel = conv_by_type[layer_type]

  # Layers are created in the same order as the original 2D/3D branches
  # (main path first, then the shortcut conv), so auto-generated layer names
  # and the order of saved weights are unchanged.
  x = layers.BatchNormalization()(inputs)
  x = layers.Activation('relu')(x)
  x = Conv(filter1, conv_size1, strides=stride1, padding='SAME')(x)

  x = layers.BatchNormalization()(x)
  x = layers.Activation('relu')(x)
  x = Conv(filter2, conv_size2, strides=stride2, padding='SAME')(x)

  if input_size_cross:
    shortcut = Conv(filter2, shortcut_kernel, strides=stridec, padding='SAME')(inputs)
  else:
    shortcut = inputs

  return layers.add([x, shortcut])



In [0]:
def buildSingleResNetSmall():
  """Feature extractor for 16x16 single-channel 2D patches.

  Stem: stride-2 3x3 conv (64 filters, ReLU) and 2x2 max-pooling, followed by
  two residual stages of two blocks each (128 then 256 filters; the 256 stage
  starts with a stride-2 block), then 2x2 average pooling and flattening.

  Returns:
    keras Model mapping a (16, 16, 1) input to a flat feature vector.
  """
  patch = layers.Input(shape=(16,16,1))
  net = layers.Conv2D(64, (3, 3), strides=(2, 2), padding='SAME', activation='relu')(patch)
  net = layers.MaxPooling2D((2, 2))(net)

  # Residual blocks, applied in this order; the first block of each stage uses
  # a projection shortcut because the channel count (and, for 256, the size) changes.
  block_settings = (
      dict(filter1=128, filter2=128, stridec=(1,1), input_size_cross=True),
      dict(filter1=128, filter2=128),
      dict(filter1=256, filter2=256, stride1=(2,2), input_size_cross=True),
      dict(filter1=256, filter2=256),
  )
  for settings in block_settings:
    net = buildResBlock(inputs=net, layer_type='cv', **settings)

  net = layers.AveragePooling2D((2, 2))(net)
  features = layers.Flatten()(net)
  return models.Model(inputs=patch, outputs=features)

  

def buildSingleResNetLarge():
  """Feature extractor for 87x87 single-channel 2D patches.

  Stem: 7x7 stride-3 'VALID' conv (64 filters, ReLU) and 2x2 max-pooling.
  Then one 128-filter stage (two blocks) and another 2x2 max-pool. After that
  come two 256-filter stages, each starting with a stride-2 projection block.
  The output is 2x2 average pooled and flattened.

  Returns:
    keras Model mapping an (87, 87, 1) input to a flat feature vector.
  """
  patch = layers.Input(shape=(87,87,1))

  def res(tensor, **settings):
    # 2D residual block on `tensor`
    return buildResBlock(inputs=tensor, layer_type='cv', **settings)

  net = layers.Conv2D(64, (7, 7), strides=(3, 3), padding='VALID', activation='relu')(patch)
  net = layers.MaxPooling2D((2, 2))(net)

  net = res(net, filter1=128, filter2=128, stridec=(1,1), input_size_cross=True)
  net = res(net, filter1=128, filter2=128)

  net = layers.MaxPooling2D((2, 2))(net)

  # Two identical downsampling stages
  for _ in range(2):
    net = res(net, filter1=256, filter2=256, stride1=(2,2), input_size_cross=True)
    net = res(net, filter1=256, filter2=256)

  net = layers.AveragePooling2D((2, 2))(net)
  return models.Model(inputs=patch, outputs=layers.Flatten()(net))

  

def buildSingleResNetCube():
  """Feature extractor for 26x26x26 single-channel 3D cubes.

  Stem: stride-2 3x3x3 conv (64 filters, ReLU) and 2x2x2 max-pooling.
  Then four 3D residual blocks with 128, 128, 256 and 256 filters. The third
  block downsamples with stride 2 and a projection shortcut; the last one uses
  2x2x2 kernels. The output is 2x2x2 average pooled and flattened (256
  features, per the model summary).

  Returns:
    keras Model mapping a (26, 26, 26, 1) input to a flat feature vector.
  """
  cube = layers.Input(shape=(26,26,26,1))
  net = layers.Conv3D(64, (3, 3, 3), strides=(2, 2, 2), padding='SAME', activation='relu')(cube)
  net = layers.MaxPooling3D((2, 2, 2))(net)

  unit = (1, 1, 1)
  # (filters, kernel size, stride of the first conv,
  #  projection-shortcut stride -- or None for an identity shortcut)
  block_plan = [
      (128, (3, 3, 3), unit,      unit),
      (128, (3, 3, 3), unit,      None),
      (256, (3, 3, 3), (2, 2, 2), (2, 2, 2)),
      (256, (2, 2, 2), unit,      None),
  ]
  for filters, kernel, first_stride, shortcut_stride in block_plan:
    if shortcut_stride is None:
      projection = {}
    else:
      projection = {'stridec': shortcut_stride, 'input_size_cross': True}
    net = buildResBlock(inputs=net, layer_type='cv3',
                        conv_size1=kernel, conv_size2=kernel,
                        filter1=filters, filter2=filters,
                        stride1=first_stride, stride2=unit, **projection)

  net = layers.AveragePooling3D((2, 2, 2))(net)
  features = layers.Flatten()(net)
  return models.Model(inputs=cube, outputs=features)

  

#res_cube = buildSingleResNetCube()
#res_cube.summary()
#plot_model(res_cube, to_file='drive/My Drive/MRI_Brain_Segmentation/resnet_diagram.png',show_shapes=True)

#res_small = buildSingleResNetSmall()
#res_small.summary()


In [0]:
def buildFullNetwork(num_classes=135, dropout_rate=0.5):
  """Voxel classifier for 26x26x26 single-channel cubes.

  3D residual feature extractor (buildSingleResNetCube) -> dropout ->
  dense softmax over `num_classes` labels.

  Args:
    num_classes: size of the softmax output (default 135, i.e. labels 0..134);
      must match the one-hot encoding used for training.
    dropout_rate: dropout rate applied to the extracted features.

  Returns:
    Uncompiled keras Model.
  """
  cube = layers.Input(shape=(26,26,26,1))
  features = buildSingleResNetCube()(cube)
  x = layers.Dropout(rate=dropout_rate)(features)
  probs = layers.Dense(num_classes, activation='softmax')(x)
  return models.Model(inputs=cube, outputs=probs)

resnet = buildFullNetwork()
resnet.summary()
#plot_model(resnet, to_file='drive/My Drive/MRI_Brain_Segmentation/resnet_diagram_highlv.png',show_shapes=True)

Model: "model_4"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_3 (InputLayer)         (None, 26, 26, 26, 1)     0         
_________________________________________________________________
model_3 (Model)              (None, 256)               5301120   
_________________________________________________________________
dropout_2 (Dropout)          (None, 256)               0         
_________________________________________________________________
dense_2 (Dense)              (None, 135)               34695     
Total params: 5,335,815
Trainable params: 5,333,127
Non-trainable params: 2,688
_________________________________________________________________




---



---



# **Model Compilation and Training**

During training, voxels with a non-zero (non-background) label are randomly sampled from the dataset. For each chosen voxel, the surrounding 26x26x26 cube of raw intensities is extracted. These cubes are then fed into the model for training, with the chosen voxel's label as the target.


In [0]:
# Adam optimizer; the loss matches the one-hot targets built in the training cell.
# `learning_rate` is the current argument name -- `lr` is only a deprecated
# alias. This notebook already needs Keras >= 2.3, whose history keys are
# 'accuracy'/'val_accuracy', and that version accepts `learning_rate`.
adam = optimizers.Adam(learning_rate=2e-4, beta_1=0.9, beta_2=0.999, epsilon=1e-07, decay=1e-6, amsgrad=False)
resnet.compile(optimizer=adam,loss='categorical_crossentropy',metrics=['accuracy'])


In [0]:
# ---- Training configuration ----
T_epochs = 1                  # number of passes over the sampling rounds below
train_batch_size = 32
samples_size = 1000           # patches collected before each resnet.fit call
vald_size = 200               # patches of each chunk held out for validation
test_size = samples_size - vald_size  # NOTE: despite the name, the number of *training* patches per chunk
samples_epochs = 1            # epochs per resnet.fit call
total_voxel_sample = 1000000  # patches sampled per round (one round per entry of `dataset`)
# Padded volume geometry. The sampler now derives coordinates from each
# volume's own shape; these are kept for reference.
imgH = 342
imgW = 342
layerSize = imgH*imgW
patch_half = 13               # patch spans [c - 13, c + 13) on each axis -> 26x26x26
num_classes = 135             # must match the Dense output of the network
SEED = None                   # set to an int for reproducible voxel sampling

weight_path = "drive/My Drive/MRI_Brain_Segmentation/model/resnet_cube_26_3.h5"

# To resume from weights saved by a previous training session, uncomment:
# resnet.load_weights("drive/My Drive/MRI_Brain_Segmentation/resnet_cube_26_2.h5")
# print("Model Restored")
train_hist_acc = []
train_hist_val = []

if SEED is not None:
  random.seed(SEED)


def _sample_labelled_patch():
  """Draw one training example from a uniformly chosen patient.

  Voxels are drawn uniformly at random until one with a non-zero label is
  found (background voxels are never used). This loops forever if the chosen
  patient has no labelled voxel.

  Returns:
    (vol, value): the (26, 26, 26, 1) raw-intensity patch around the voxel and
    the voxel's label.
  """
  pat = random.randint(0, len(dataset)-1)
  raw = dataset[pat][0]
  labels = dataset[pat][1]
  total = np.size(raw)

  value = 0
  while value == 0:
    voxel = random.randint(0, total-1)
    # Flat index -> (z, row, col) from the real volume shape. The old formula
    # divided the in-slice offset by imgH instead of imgW; that was only
    # correct because both are 342.
    z_layer, row, col = np.unravel_index(voxel, np.shape(raw))
    value = labels[z_layer][row][col]

  # The padded border has label 0, so the chosen voxel is at least 43 voxels
  # from every face and this slice stays in bounds.
  vol = raw[z_layer-patch_half:z_layer+patch_half,
            row-patch_half:row+patch_half,
            col-patch_half:col+patch_half]
  side = 2*patch_half
  return np.reshape(vol, (side, side, side, 1)), value


print("Start Training")
stime = time.time()
# Samples drawn in the current round. Initialised here so the summary messages
# (including the interrupt handler) work even before the first sample is drawn.
samples_done = 0

try:
  for epoch in range(T_epochs):
    t7 = []  # sampled patches
    t8 = []  # their labels

    for i in range(len(dataset)):
      # NOTE: `i` only counts rounds; every sample picks its patient at random.
      print("\n Train Loop Number " + str(i))
      for j in range(total_voxel_sample):
        vol, value = _sample_labelled_patch()
        t7.append(vol)
        t8.append(value)
        samples_done = j + 1

        # Train on every full chunk of `samples_size` patches; the last
        # `vald_size` patches of the chunk are used for validation.
        if samples_done % samples_size == 0:
          x_train = np.array(t7)
          # One-hot encode the whole chunk at once (labels are in [1, 134])
          y_train = keras.utils.to_categorical(t8, num_classes=num_classes)

          fit_log = resnet.fit(x_train[0:test_size], y_train[0:test_size], 
                               epochs=samples_epochs, batch_size=train_batch_size, verbose=1,
                               validation_data=(x_train[test_size:], y_train[test_size:]))

          train_hist_acc.append(fit_log.history['accuracy'])
          train_hist_val.append(fit_log.history['val_accuracy'])

          # reset the lists for the next chunk
          t7 = []; t8 = []

          # Back up the weights after every trained chunk
          resnet.save_weights(weight_path)

    print("\n Training Completed: Time Taken %s s" % (time.time()-stime))
    print('Epoch', epoch+1, 'completed out of',T_epochs)

  print("Training Complete and Saving Model with " + str(samples_done) +" Iterations per patient")
  # Save weights after training completes
  resnet.save_weights(weight_path)
  print("Saved model to disk")
  print("All training completed !")

except KeyboardInterrupt:
  # The message always promised a save; actually perform it so an interrupted
  # session keeps its progress.
  print("Training Interrupted Saving Model with " + str(samples_done) +" Iterations/patient")
  resnet.save_weights(weight_path)

Plot the training and validation accuracy recorded after each `fit` call during training (one point per chunk of sampled patches).

In [0]:
# Training accuracy per fit call (first 2500 entries), saved at 200 dpi
ax = plt.gca()
ax.plot(train_hist_acc[:2500])
ax.set_title("Training Accuracy")
ax.set_xlabel("Iterations")
ax.set_ylabel("Accuracy")
ax.figure.savefig('drive/My Drive/MRI_Brain_Segmentation/train_acc.png', dpi=200)

In [0]:
# Validation accuracy per fit call (first 2500 entries), drawn in red, saved at 200 dpi
ax = plt.gca()
ax.plot(train_hist_val[:2500], 'r')
ax.set_title("Validation Accuracy")
ax.set_xlabel("Iterations")
ax.set_ylabel("Accuracy")
ax.figure.savefig('drive/My Drive/MRI_Brain_Segmentation/train_val_acc.png', dpi=200)