In [0]:
#required imports 
import numpy as np
import pandas as pd
import matplotlib as mlt
import tensorflow as tf
import os
import pathlib
import keras
from keras import regularizers
from keras.models import Model
from keras.models import Sequential
from keras.layers import Convolution2D
from keras.layers import Conv2D
from keras.layers import MaxPooling2D
from keras.layers import Flatten
from keras.layers import Dense, Input
from keras.layers import Activation, ZeroPadding2D, Lambda,Concatenate
from keras.layers import AveragePooling2D, MaxPooling2D, Dropout, GlobalMaxPooling2D, GlobalAveragePooling2D, Add
from keras.preprocessing.image import ImageDataGenerator, array_to_img, img_to_array, load_img
import h5py
from keras.utils import layer_utils

In [0]:
#library for pre-processing mha files
!pip3 install medpy




In [0]:
# Mount Google Drive inside the Colab filesystem. The dataset directory and the
# saved model/weight files used by later cells are all addressed under
# /content/gdrive.
from google.colab import drive
drive.mount('/content/gdrive')


Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).


In [0]:
'''A utility function for applying N4ITK bias correction to MRI slices'''

import SimpleITK as sitk
def preprocess(path):
    """Run N4 bias-field correction on one image file and save the result.

    The image is read with SimpleITK, an Otsu threshold of the raw image is
    used as the correction mask, and the corrected image is written next to
    the input: the last four characters of ``path`` (e.g. ``.mha``) are
    replaced by ``_n.mha``.

    Args:
        path: Location of the input image file, for example
            '.../VSD.Brain.XX.O.MR_T1.54513.mha'. Adjust to the files in
            your drive.

    Returns:
        None. The corrected image is written to disk as a side effect.
    """
    raw_image = sitk.ReadImage(path)

    # The Otsu threshold of the un-cast image serves as the default mask.
    otsu_mask = sitk.OtsuThreshold(raw_image, 0, 1, 200)

    # The filter is run on a Float32 copy (the original notes the stored
    # images are int16).
    float_image = sitk.Cast(raw_image, sitk.sitkFloat32)

    # No iteration count is set, so the filter uses its default.
    n4_filter = sitk.N4BiasFieldCorrectionImageFilter()
    corrected = n4_filter.Execute(float_image, otsu_mask)

    # Same folder and base name as the input, with an "_n.mha" suffix.
    sitk.WriteImage(corrected, path[:-4] + '_n.mha')

In [0]:
'''This function extracts 2d slices from each modality, converts them into numpy array and stacks them together
Input: paths of four modalities
Output: A list containing stacked axial 155 slices and the ground truth of these slices'''
#imports required for pre-processing of the images
from sklearn.preprocessing import normalize
from PIL import Image
import numpy as np
import matplotlib.cm as cm
import matplotlib.pyplot as plt
from medpy.io import load, save
#conversion of .mha files into n-dimensional numpy arrays
def gen_image(path1, path2, path3,path4,gtpath):
    """Load the four MRI modalities of one patient and stack them slice by slice.

    Each modality volume is indexed as [row, col, slice]. Every axial slice is
    passed through ``normalize`` (sklearn.preprocessing.normalize, imported in
    this cell) and the four normalised slices are stacked along a new last
    axis in the order T1, T2, T1c, FLAIR.

    The output size is taken from the T1 volume, so volumes other than
    240 x 240 x 155 are handled as well.

    Args:
        path1: Path of the FLAIR volume.
        path2: Path of the T1 volume.
        path3: Path of the T1c volume.
        path4: Path of the T2 volume.
        gtpath: Path of the ground-truth volume.

    Returns:
        A tuple ``(final, image_gt)``. ``final`` is a float array of shape
        (n_slices, rows, cols, 4); ``image_gt`` is the ground-truth volume
        exactly as loaded, indexed [row, col, slice].
    """
    # Only the image data is needed; the headers returned by load() are ignored.
    image_flair, _ = load(path1)
    image_t1, _ = load(path2)
    image_t1c, _ = load(path3)
    image_t2, _ = load(path4)
    image_gt, _ = load(gtpath)

    rows, cols, n_slices = image_t1.shape
    final = np.zeros((n_slices, rows, cols, 4))

    # Extract each slice of the four modalities, normalise it and stack them.
    for a in range(n_slices):
        final[a] = np.dstack([
            normalize(image_t1[:, :, a]),
            normalize(image_t2[:, :, a]),
            normalize(image_t1c[:, :, a]),
            normalize(image_flair[:, :, a]),
        ])

    return final, image_gt

In [0]:
'''This function creates balanced input patches for each class
Input: list of input patches and their groundtruth (gt)
Output: Two lists of balanced input patches one sized 65x65x4 and other sized 33x33x4
        List of middle pixel of ground truth with one hot encoding and integer labels
        The number of pixels of each class '''
def create_patch(array, gt):
    """Sample class-balanced training patches from one patient volume.

    Sampling happens in two stages, both scanning slice by slice:

    1. Tumour labels (1-4): patch origins on a stride-2 grid starting at
       row/col 30 (at most 80 positions per axis), slices 30-129. A patch is
       taken when the label of its centre pixel (origin + 32) is 1-4 and fewer
       than 2000 patches of that label have been taken so far.
    2. Background (label 0): patch origins on a stride-20 grid (30, 50, ...,
       130), slices 50-129. Sampling stops as soon as 8000 patches exist in
       total or 1999 background patches have been taken.

    Only origins whose full 65 x 65 window lies inside the image are used, so
    every returned patch has the same shape.

    Args:
        array: Stacked modalities indexed as [slice, row, col, channel]
            (the first value returned by gen_image).
        gt: Ground-truth label volume (NumPy array) indexed as
            [row, col, slice].

    Returns:
        An 8-tuple ``(patches65, patches33, gt2, actual_gt, count0, count2,
        count3, count4)``:

        * patches65 - list of (65, 65, channels) patches (views into array).
        * patches33 - list of the central (33, 33, channels) crop of each patch.
        * gt2 - float array of shape (k, 1, 1, 5), one-hot centre-pixel labels.
        * actual_gt - int array of shape (k,), integer centre-pixel labels,
          aligned one-to-one with the patches and with gt2.
        * count0, count2, count3, count4 - number of patches taken for labels
          0, 2, 3 and 4 (the count for label 1 is not part of the tuple).
    """
    patch_size = 65
    half = 32                 # offset from patch origin to its centre pixel
    per_class_cap = 2000      # maximum patches for each tumour label
    total_cap = 8000          # background sampling stops at this many patches
    background_cap = 1999     # ... or at this many background patches

    n_slices, n_rows, n_cols = array.shape[0], array.shape[1], array.shape[2]

    patches65, patches33, labels = [], [], []
    counts = [0, 0, 0, 0, 0]

    def _take(slice_idx, row, col, label):
        # Record one 65x65 patch, its central 33x33 crop and its label.
        patch = array[slice_idx, row:row + patch_size, col:col + patch_size, :]
        patches65.append(patch)
        patches33.append(patch[16:49, 16:49, :])
        labels.append(label)
        counts[label] += 1

    # ---- Stage 1: tumour labels ------------------------------------------
    # Origins are limited to those whose whole window fits inside the image.
    rows = np.arange(30, min(190, n_rows - patch_size + 1), 2)
    cols = np.arange(30, min(190, n_cols - patch_size + 1), 2)
    slices = np.arange(30, min(130, n_slices))

    if rows.size and cols.size and slices.size:
        centres = gt[np.ix_(rows + half, cols + half, slices)]
        # Reorder to (slice, col, row): argwhere then yields candidates in the
        # scan order slice -> column -> row.
        centres = centres.transpose(2, 1, 0)
        tumour = (centres >= 1) & (centres <= 4)
        for s_i, c_i, r_i in np.argwhere(tumour):
            label = int(centres[s_i, c_i, r_i])
            if counts[label] < per_class_cap:
                _take(int(slices[s_i]), int(rows[r_i]), int(cols[c_i]), label)
                if min(counts[1:]) >= per_class_cap:
                    break  # every tumour class is full

    # ---- Stage 2: background ---------------------------------------------
    background_sites = (
        (s, r, c)
        for s in range(50, min(130, n_slices))
        for r in range(30, 150, 20) if r + patch_size <= n_rows
        for c in range(30, 150, 20) if c + patch_size <= n_cols
    )
    for s, r, c in background_sites:
        if len(labels) >= total_cap or counts[0] == background_cap:
            break
        if gt[r + half, c + half, s] == 0:
            _take(s, r, c, 0)

    # ---- Assemble label arrays -------------------------------------------
    k = len(labels)
    actual_gt = np.array(labels, dtype=int)
    gt2 = np.zeros((k, 1, 1, 5))
    gt2[np.arange(k), 0, 0, actual_gt] = 1

    return patches65, patches33, gt2, actual_gt, counts[0], counts[2], counts[3], counts[4]

In [0]:
'''This function implements the two-way CNN architecture
Input: input patch in form of tensor
Output: an output patch(dimension depends on the input)'''
def two_way_CNN(img):
    """Build the two-pathway CNN on top of ``img`` and return its output tensor.

    Every convolution uses 'valid' padding (no padding) and stride 1. Each
    stage is a pair of parallel convolutions merged with an element-wise
    maximum.

    * Local pathway: 7x7 pair (64 filters) -> 4x4 max-pool (stride 1) ->
      3x3 pair (64 filters) -> 2x2 max-pool (stride 1).
    * Global pathway: 13x13 pair (160 filters), no pooling.

    The two pathways are concatenated and fed to a final 21x21 convolution
    with 5 filters, softmax activation and an L1/L2 kernel regulariser.

    Args:
        img: Input tensor (a patch, or a patch concatenated with an earlier
            network's output).

    Returns:
        The output tensor of the final convolution; its spatial size depends
        on the size of the input.
    """
    def _max_of_conv_pair(tensor, n_filters, size):
        # Two parallel convolutions on the same input, merged element-wise.
        branch_a = Conv2D(filters=n_filters, kernel_size=(size, size), strides=(1, 1), padding='valid')(tensor)
        branch_b = Conv2D(filters=n_filters, kernel_size=(size, size), strides=(1, 1), padding='valid')(tensor)
        return keras.layers.Maximum()([branch_a, branch_b])

    # Local pathway, first layer.
    local_path = _max_of_conv_pair(img, 64, 7)
    local_path = MaxPooling2D(
        pool_size=(4, 4), padding='valid', strides=(1, 1), data_format='channels_last'
    )(local_path)

    # Local pathway, second layer.
    local_path = _max_of_conv_pair(local_path, 64, 3)
    local_path = MaxPooling2D(
        pool_size=(2, 2), padding='valid', strides=(1, 1), data_format='channels_last'
    )(local_path)

    # Global pathway: one wide-kernel stage without pooling.
    global_path = _max_of_conv_pair(img, 160, 13)

    # Merge both pathways and produce the 5-channel softmax output.
    merged = Concatenate()([local_path, global_path])
    return Conv2D(
        filters=5,
        kernel_size=(21, 21),
        strides=(1, 1),
        padding='valid',
        activation='softmax',
        kernel_regularizer=regularizers.l1_l2(0.01, 0.01),
    )(merged)

In [0]:
'''Implements the input Cascade CNN architecture
Input: None
Output: an instance of the Model class implementing the Input Cascade CNN architecture'''
def inputCascadeCNN():
    """Assemble the input-cascade model from two chained two-way CNNs.

    The first two-way CNN runs on a 65x65x4 patch. Its output is concatenated
    with a second, 33x33x4 input patch, and that concatenation is the input of
    the second two-way CNN, whose output is the model output.

    Returns:
        A Model taking ``[patch_65, patch_33]`` as inputs.
    """
    large_patch = Input((65, 65, 4))

    # Output of the first CNN.
    first_stage = two_way_CNN(large_patch)

    small_patch = Input((33, 33, 4))

    # The first-stage output joined with the small patch feeds the second CNN.
    cascade_input = Concatenate()([first_stage, small_patch])
    second_stage = two_way_CNN(cascade_input)

    return Model(inputs=[large_patch, small_patch], outputs=second_stage)

In [0]:
#utility function for calculating the dice score
def dice_score(y_true, y_pred, smooth=1):
    """Return the batch mean of the smoothed Dice coefficient.

    Each sample is flattened and scored as

        (2 * sum(y_true * y_pred) + smooth) / (sum(y_true) + sum(y_pred) + smooth)

    The function is self-contained: the Keras backend is imported locally and
    the smoothing term is a parameter, so it does not rely on names defined in
    another cell.

    Args:
        y_true: Ground-truth tensor whose first axis is the batch.
        y_pred: Predicted tensor with the same shape as y_true.
        smooth: Term added to numerator and denominator; it keeps the ratio
            defined when both tensors sum to zero. Defaults to 1.

    Returns:
        A scalar: the mean of the per-sample scores.
    """
    from keras import backend as K

    y_true_f = K.batch_flatten(y_true)
    y_pred_f = K.batch_flatten(y_pred)
    numerator = 2. * K.sum(y_true_f * y_pred_f, axis=1, keepdims=True) + smooth
    denominator = K.sum(y_true_f, axis=1, keepdims=True) + K.sum(y_pred_f, axis=1, keepdims=True) + smooth
    return K.mean(numerator / denominator)

In [0]:

# Build the input-cascade network and print its layer-by-layer summary.
model = inputCascadeCNN()
model.summary()

In [0]:
'''This cell trains the CNN with a balanced dataset. Change the variable path as per the requirement'''
from sklearn.utils import class_weight
from keras import backend as K
# K and smooth are module-level names read by dice_score.
smooth=1
# Root of the training set; expected layout: <path>/<grade>/<patient>/<modality>/<file>.mha
path='/content/gdrive/My Drive/NNFL/DATASET/BRATS2015_Training'
j=0  # number of patients trained on so far
model=inputCascadeCNN()

# Compile once, before the loop, and pass the configured optimizer OBJECT.
# Passing the string 'sgd' builds a default SGD and ignores the learning rate
# and momentum set here; recompiling for every patient also resets the
# optimizer each time.
sgd = keras.optimizers.SGD(learning_rate=0.005, momentum=0.5, nesterov=False)
model.compile(
    optimizer=sgd,
    loss='categorical_crossentropy',
    metrics=[
        'accuracy',
        tf.keras.metrics.Precision(class_id=0, name = 'normal_tissue_precision'),
        tf.keras.metrics.Precision(class_id=2, name = 'edema_precision'),
        tf.keras.metrics.Precision(class_id=3, name = 'non_enhancing_tumor_precision'),
        tf.keras.metrics.Precision(class_id=4, name = 'enhancing_tumor_precision'),
        tf.keras.metrics.Recall(class_id=0, name = 'normal_tissue_recall'),
        tf.keras.metrics.Recall(class_id=2, name = 'edema_recall'),
        tf.keras.metrics.Recall(class_id=3, name = 'non_enhancing_tumor_recall'),
        tf.keras.metrics.Recall(class_id=4, name = 'enhancing_tumor_recall'),
    ],
)


def _patient_mha_files(patient_dir):
    """Return the sorted paths of every .mha file one folder below patient_dir."""
    found = []
    with os.scandir(patient_dir) as modality_dirs:
        for modality_dir in modality_dirs:
            if not modality_dir.is_dir():
                continue
            with os.scandir(modality_dir.path) as files:
                found.extend(f.path for f in files if f.name.endswith('mha'))
    return sorted(found)


# Collect the patient folders first so no directory handle stays open while
# the model is training. Stray files next to the folders are skipped.
patient_dirs = []
with os.scandir(path) as training:
    for grade_dir in training:
        if not grade_dir.is_dir():
            continue
        with os.scandir(grade_dir.path) as patients:
            patient_dirs.extend(p.path for p in patients if p.is_dir())

for patient_dir in patient_dirs:
    sorted_path = _patient_mha_files(patient_dir)
    print(sorted_path)

    # gen_image needs four modality files plus the ground truth.
    if len(sorted_path) < 5:
        print('Skipping', patient_dir, '- expected 5 .mha files, found', len(sorted_path))
        continue

    try:
        arr, gt = gen_image(sorted_path[0], sorted_path[1], sorted_path[2], sorted_path[3], sorted_path[4])
        patches65, patches33, gt_pixel, actual_gt, a, b, c, d = create_patch(arr, gt)
        print(a, b, c, d)
        if len(patches65) == 0:
            print('Skipping', patient_dir, '- no patches were extracted')
            continue

        # Class weights account for the remaining imbalance between labels.
        # classes/y are passed by keyword: newer scikit-learn releases reject
        # the positional form.
        class_weights = class_weight.compute_class_weight(
            'balanced', classes=np.unique(actual_gt), y=actual_gt)
        class_weights = class_weights[1:]
        # NOTE(review): Keras documents `class_weight` as a dict mapping class
        # index -> weight; here an array with its first entry dropped is
        # passed, unchanged from the original cell. Confirm that the installed
        # Keras version actually applies it.

        # Regular checkpoints of the model (written before this patient's fit).
        if j in (20, 80, 120, 150, 200, 250):
            model_json = model.to_json()
            with open("/content/gdrive/My Drive/mymodelinput.json","w") as json_file:
                json_file.write(model_json)
            model.save_weights("/content/gdrive/My Drive/myinputweight274.h5")
            print("Saved model to disk")

        # One epoch on this patient's patches; the patch lists are converted
        # to arrays explicitly.
        model.fit([np.asarray(patches65), np.asarray(patches33)], gt_pixel,
                  epochs=1, class_weight=class_weights)
        j=j+1
        print(j)
    except Exception as err:
        # Best effort: one failing patient must not stop the run, but the
        # reason is reported instead of being silently discarded.
        print('Patient', patient_dir, 'failed (j =', j, '):', repr(err))
                

In [0]:
#Save model results
# The architecture is serialised to JSON and the weights are written to an
# HDF5 file. These are the same two file names the training cell uses for its
# periodic checkpoints, so this overwrites the last checkpoint.
model_json = model.to_json()
with open("/content/gdrive/My Drive/mymodelinput.json","w") as json_file:
    json_file.write(model_json)
model.save_weights("/content/gdrive/My Drive/myinputweight274.h5")
print("Saved model to disk")

In [0]:
#Loading the model
from keras.models import model_from_json
# Read the architecture from the JSON file this notebook writes
# (mymodelinput.json), so that it matches the weights file loaded below.
# The `with` block closes the file even if reading fails.
with open('/content/gdrive/My Drive/mymodelinput.json', 'r') as json_file:
    loaded_model_json = json_file.read()
loaded_model = model_from_json(loaded_model_json)
loaded_model.load_weights("/content/gdrive/My Drive/myinputweight274.h5")