1. Upload BraTS2021_Training_Data.tar to Google Drive.
2. Extract BraTS2021_Training_Data.tar to /content/brats.
3. Install the MONAI package.
4. Upload the pre-trained model to /content.
5. Upload train_labels.csv to /content.

In [None]:
# Stage the dataset, the pre-trained weights and the label CSV under /content.
# NOTE(review): every source path is under /content/drive, so Google Drive
# presumably has to be mounted before this cell runs — confirm.
!mkdir brats
# 7z flags: `x` extracts with full paths, `-aos` skips files that already exist,
# `-o` sets the output directory.
!7z x -aos /content/drive/MyDrive/RSNA_seg/BraTS2021_Training_Data.tar -o/content/brats
!cp /content/drive/MyDrive/RSNA_seg/exp_b_1_pretrained.h5 /content
!cp /content/drive/MyDrive/train_labels.csv /content
# NOTE(review): the MONAI version is not pinned; the import cell uses `AddChannel`,
# which newer MONAI releases may no longer provide — consider pinning a version.
!pip install monai

In [2]:
import os
import cv2
import glob
import PIL
import shutil
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from skimage import data
from skimage.util import montage 
import skimage.transform as skTrans
from skimage.transform import rotate
from skimage.transform import resize
from PIL import Image, ImageOps  

import scipy
import nibabel as nib

from monai.transforms import Compose, Resize,AddChannel,RandGaussianNoise,RandAdjustContrast,ScaleIntensity
import monai
from sklearn.metrics import roc_auc_score

import keras
import keras.backend as K
from keras.callbacks import CSVLogger
import tensorflow as tf
from tensorflow.keras.utils import plot_model
from sklearn.preprocessing import MinMaxScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
from tensorflow.keras.models import *
from tensorflow.keras.layers import *
from tensorflow.keras.optimizers import *
from tensorflow.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau, EarlyStopping, TensorBoard
from tensorflow.keras.layers.experimental import preprocessing
from keras.layers import *
from keras.models import *
np.set_printoptions(precision=3, suppress=True)

In [3]:
# Segmentation classes, keyed by the label index used for the one-hot masks.
SEGMENT_CLASSES = {
    0 : 'NOT tumor',
    1 : 'NECROTIC/CORE', # or NON-ENHANCING tumor CORE
    2 : 'EDEMA',
    3 : 'ENHANCING' # original 4 -> converted into 3 later
}

# there are 155 slices per volume
# The values below describe a window of 100 slices starting at slice 22.
# NOTE(review): neither VOLUME_SLICES nor VOLUME_START_AT is referenced in the
# cells visible here (volumes are resized as a whole) — confirm they are still needed.
VOLUME_SLICES = 100 
VOLUME_START_AT = 22 # first slice of volume that we will include
# Size every volume is resized to before it is fed to the network.
IMG_SIZE1=128
IMG_SIZE2=128
IMG_SIZE3=80

# Root directory holding one sub-directory per case.
TRAIN_DATASET_PATH = '/content/brats'
# Weights file loaded into the segmentation U-Net.
pre_trained_model_path = "/content/exp_b_1_pretrained.h5"

In [17]:
def build_unet(inputs):
    """Build a 3D U-Net with densely concatenated convolution stages.

    Each stage applies two 3x3x3 convolutions; after each convolution the result
    is concatenated (channel axis 4) with the tensor that entered the stage.
    The encoder has four 2x2x2 max-pooling steps and the decoder four stride-2
    transposed convolutions, so every spatial dimension of ``inputs`` has to be
    divisible by 16 for the skip concatenations to line up.

    Args:
        inputs: Keras input tensor, channels-last with five dimensions.

    Returns:
        A ``Model`` mapping ``inputs`` to a 4-channel softmax map (layer ``seg_m``).
    """
    # ---- encoder stage 1 (32 filters) ----
    conv11 = Conv3D(32, (3, 3, 3), activation='relu', padding='same')(inputs)
    conc11 = concatenate([inputs, conv11], axis=4)
    conv12 = Conv3D(32, (3, 3, 3), activation='relu', padding='same')(conc11)
    conc12 = concatenate([inputs, conv12], axis=4)
    pool1 = MaxPooling3D(pool_size=(2, 2, 2))(conc12)

    # ---- encoder stage 2 (64 filters) ----
    conv21 = Conv3D(64, (3, 3, 3), activation='relu', padding='same')(pool1)
    conc21 = concatenate([pool1, conv21], axis=4)
    conv22 = Conv3D(64, (3, 3, 3), activation='relu', padding='same')(conc21)
    conc22 = concatenate([pool1, conv22], axis=4)
    pool2 = MaxPooling3D(pool_size=(2, 2, 2))(conc22)

    # ---- encoder stage 3 (128 filters) ----
    conv31 = Conv3D(128, (3, 3, 3), activation='relu', padding='same')(pool2)
    conc31 = concatenate([pool2, conv31], axis=4)
    # NOTE(review): the layer name '123' looks like a leftover debugging label;
    # it is kept because renaming could matter if weights are ever loaded by name.
    conv32 = Conv3D(128, (3, 3, 3), activation='relu', padding='same',name='123')(conc31)
    conc32 = concatenate([pool2, conv32], axis=4)
    pool3 = MaxPooling3D(pool_size=(2, 2, 2),name='pool3')(conc32)

    # ---- encoder stage 4 (256 filters) ----
    conv41 = Conv3D(256, (3, 3, 3), activation='relu', padding='same',name='conv41')(pool3)
    conc41 = concatenate([pool3, conv41], axis=4)
    conv42 = Conv3D(256, (3, 3, 3), activation='relu', padding='same')(conc41)
    conc42 = concatenate([pool3, conv42], axis=4)
    pool4 = MaxPooling3D(pool_size=(2, 2, 2))(conc42)

    # ---- bottleneck (512 filters); the named layers expose mid-level features ----
    conv51 = Conv3D(512, (3, 3, 3), activation='relu', padding='same',name='mid_feature_1')(pool4)
    conc51 = concatenate([pool4, conv51], axis=4,name='mid_feature')



    conv52 = Conv3D(512, (3, 3, 3), activation='relu', padding='same')(conc51)
    conc52 = concatenate([pool4, conv52], axis=4,name='mid_feature_2')

    # ---- decoder stage 6: upsample and join with the stage-4 skip (conc42) ----
    up6 = concatenate([Conv3DTranspose(256, (2, 2, 2), strides=(2, 2, 2), padding='same')(conc52), conc42], axis=4)
    conv61 = Conv3D(256, (3, 3, 3), activation='relu', padding='same')(up6)
    conc61 = concatenate([up6, conv61], axis=4)
    conv62 = Conv3D(256, (3, 3, 3), activation='relu', padding='same')(conc61)
    conc62 = concatenate([up6, conv62], axis=4)

    # ---- decoder stage 7 ----
    # Unlike up6, the skips for up7/up8/up9 use the raw convolution output
    # (conv32/conv22/conv12), not the concatenated tensor (conc32/...).
    # NOTE(review): unclear whether this asymmetry is intentional; do not change
    # it without retraining, since the pre-trained weights depend on these shapes.
    up7 = concatenate([Conv3DTranspose(128, (2, 2, 2), strides=(2, 2, 2), padding='same')(conc62), conv32], axis=4)
    conv71 = Conv3D(128, (3, 3, 3), activation='relu', padding='same')(up7)
    conc71 = concatenate([up7, conv71], axis=4)
    conv72 = Conv3D(128, (3, 3, 3), activation='relu', padding='same')(conc71)
    conc72 = concatenate([up7, conv72], axis=4)

    # ---- decoder stage 8 ----
    up8 = concatenate([Conv3DTranspose(64, (2, 2, 2), strides=(2, 2, 2), padding='same')(conc72), conv22], axis=4)
    conv81 = Conv3D(64, (3, 3, 3), activation='relu', padding='same')(up8)
    conc81 = concatenate([up8, conv81], axis=4)
    conv82 = Conv3D(64, (3, 3, 3), activation='relu', padding='same')(conc81)
    conc82 = concatenate([up8, conv82], axis=4)

    # ---- decoder stage 9; 'conc92' is the tensor later reused as a feature extractor output ----
    up9 = concatenate([Conv3DTranspose(32, (2, 2, 2), strides=(2, 2, 2), padding='same')(conc82), conv12], axis=4)
    conv91 = Conv3D(32, (3, 3, 3), activation='relu', padding='same')(up9)
    conc91 = concatenate([up9, conv91], axis=4)
    conv92 = Conv3D(32, (3, 3, 3), activation='relu', padding='same')(conc91)
    conc92 = concatenate([up9, conv92], axis=4,name='conc92')

    # Per-voxel 4-class softmax head.
    conv10 = Conv3D(4, (1, 1, 1), activation='softmax',name='seg_m')(conc92)

    return Model(inputs=inputs, outputs=conv10)

# Single-channel volume of size IMG_SIZE1 x IMG_SIZE2 x IMG_SIZE3.
input_layer = Input((IMG_SIZE1, IMG_SIZE2,IMG_SIZE3, 1))

model = build_unet(input_layer)

In [18]:
# Load the pre-trained segmentation weights into the U-Net.
# NOTE(review): presumably the checkpoint was saved from this exact architecture;
# a topology mismatch would make the load fail — confirm.
model.load_weights(pre_trained_model_path)  

In [6]:
# Feature extractor: same input as the U-Net, but it stops at the 'conc92'
# concatenation, i.e. the tensor that feeds the final 1x1x1 softmax layer.
lastlayer_model = Model(
    inputs=model.input,
    outputs=model.get_layer(name='conc92').output,
)

In [19]:
class _ArgmaxMeanIoU(tf.keras.metrics.MeanIoU):
    """Mean IoU for one-hot targets and per-class probability predictions.

    ``tf.keras.metrics.MeanIoU`` compares integer class ids. Feeding it one-hot
    targets and softmax outputs directly makes it cast the probabilities to
    integers, which yields a meaningless score. This wrapper converts both
    tensors to class ids (argmax over the last axis) before accumulating.
    """

    def update_state(self, y_true, y_pred, sample_weight=None):
        return super().update_state(
            tf.argmax(y_true, axis=-1),
            tf.argmax(y_pred, axis=-1),
            sample_weight,
        )


def mask_model(g,inputs):
    """Attach two new 4-class softmax heads to a frozen feature extractor.

    Args:
        g: feature-extractor model; all of its parameters are set untrainable.
        inputs: Keras input tensor that is passed to ``g``.

    Returns:
        A ``Model`` with two outputs, named ``seg_m`` and ``seg_n``, both
        computed from the output of ``g`` by separate 1x1x1 convolutions.
    """
    # Freeze the backbone so that only the two new heads are trained.
    g.trainable = False
    concat = g(inputs)
    output1 = Conv3D(4, (1, 1, 1), activation='softmax',name='seg_m')(concat)
    output2 = Conv3D(4, (1, 1, 1), activation='softmax',name='seg_n')(concat)
    return Model(inputs=inputs,outputs=[output1,output2])

# Same input size as the pre-trained U-Net (shared constants instead of literals).
input_layer = Input((IMG_SIZE1, IMG_SIZE2, IMG_SIZE3, 1))
new_model = mask_model(lastlayer_model,input_layer)
new_model.compile(
    loss="categorical_crossentropy",
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
    # The metric keeps the default MeanIoU name so logged column names do not change.
    metrics=['accuracy', _ArgmaxMeanIoU(num_classes=4, name='mean_io_u')],
)


In [8]:
# Per-case labels: column 'BraTS21ID' is the key, 'MGMT_value' the value.
cls_df = pd.read_csv('train_labels.csv')
cls_id = cls_df.loc[:, 'BraTS21ID'].values
cls_label = cls_df.loc[:, 'MGMT_value'].values

d1 = zip(cls_id, cls_label)
label_dict = {case_id: mgmt_value for case_id, mgmt_value in d1}

# Every sub-directory of the dataset root is treated as one study.
train_and_val_directories = [
    entry.path
    for entry in os.scandir(TRAIN_DATASET_PATH)
    if entry.is_dir()
]

def pathListIntoIds(dirList):
    """Return, for each path, the text that follows its last '/'.

    A path without any '/' is returned unchanged; a path ending in '/' yields
    an empty string.
    """
    return [path.rpartition('/')[2] for path in dirList]

train_and_test_ids = pathListIntoIds(train_and_val_directories)

# Keep only the cases that have MRI images, masks AND a methylation label.
# A set gives O(1) membership tests instead of scanning the numpy array each time.
labelled_case_ids = set(cls_id.tolist())

only_id = []
new_train_and_test_ids = []
for case_name in train_and_test_ids:
    try:
        only = int(case_name.split('_')[-1])
    except ValueError:
        # Not a case directory (no numeric suffix after the last '_'),
        # e.g. a stray tool-generated folder; skip it instead of crashing.
        continue
    if only in labelled_case_ids:
        only_id.append(only)
        new_train_and_test_ids.append(case_name)

# NOTE(review): os.scandir returns entries in arbitrary order, so the split is only
# reproducible across machines if the directory listing order is the same,
# despite random_state=42. Sorting the ids first would fix that but would also
# change the existing split — confirm before changing.
train_ids, val_ids = train_test_split(new_train_and_test_ids,test_size=0.2,random_state=42)

In [9]:
class DataGenerator(tf.keras.utils.Sequence):
    """Keras ``Sequence`` that loads FLAIR volumes and builds the batch targets.

    Depending on the flags, ``generator[i]`` returns:

    * ``test`` set: ``(X, labels)``, one label per case looked up in ``label``;
    * otherwise, ``is_train`` set: ``(X, [seg_m, seg_n])``;
    * otherwise: ``(X, batch_ids)``.

    For a case whose label equals 1, ``seg_m`` is an all-background mask and
    ``seg_n`` is the one-hot ground-truth segmentation; for any other label the
    two are swapped.

    Args:
        list_IDs: case directory names (relative to ``TRAIN_DATASET_PATH``); the
            integer after the last underscore is the key into ``label``.
        dim: spatial size every volume is resized to.
        batch_size: cases per batch; a trailing partial batch is dropped.
        n_channels: number of input channels allocated (only channel 0 is filled).
        shuffle: reshuffle the case order at the end of every epoch.
        is_train: see the return modes above.
        label: mapping from integer case id to its label. It is read for every
            batch in every mode, so it must be provided.
        test: see the return modes above.
    """

    def __init__(self, list_IDs, dim=(IMG_SIZE1,IMG_SIZE2,IMG_SIZE3), batch_size = 1, n_channels = 1, shuffle=True, is_train = True, label = None,test=False):
        """Store the configuration and build the resize pipelines."""
        super().__init__()
        self.dim = dim
        self.batch_size = batch_size
        self.list_IDs = list_IDs
        self.n_channels = n_channels
        self.shuffle = shuffle
        self.on_epoch_end()
        target_size = tuple(dim)
        # Images use the transform's default interpolation.
        self.transforms = Compose([AddChannel(), Resize(target_size)])
        # Label maps must be resized with nearest-neighbour: any averaging
        # interpolation produces fractional values that turn into wrong class
        # ids once the mask is converted to integers.
        self.seg_transforms = Compose([AddChannel(), Resize(target_size, mode='nearest')])
        self.is_train = is_train
        self.label = label
        self.test = test

    def __len__(self):
        """Number of full batches per epoch."""
        return int(np.floor(len(self.list_IDs) / self.batch_size))

    def __getitem__(self, index):
        """Generate the batch at position ``index``."""
        batch_slice = self.indexes[index*self.batch_size:(index+1)*self.batch_size]
        Batch_ids = [self.list_IDs[k] for k in batch_slice]

        X, [y1, y2] = self.__data_generation(Batch_ids)

        if self.test:
            return X, self._labels_for(Batch_ids)
        if self.is_train:
            return X, [y1, y2]
        return X, Batch_ids

    def on_epoch_end(self):
        """Reset the index order, shuffling it when ``shuffle`` is set."""
        self.indexes = np.arange(len(self.list_IDs))
        if self.shuffle:
            np.random.shuffle(self.indexes)

    def _labels_for(self, Batch_ids):
        """Return the label of every case in the batch as a float array."""
        return np.array(
            [self.label[int(case.split('_')[-1])] for case in Batch_ids],
            dtype=float,
        )

    def __data_generation(self, Batch_ids):
        """Load, resize and encode the volumes of one batch.

        Returns:
            ``(X, [seg_m, seg_n])`` where ``X`` is scaled by its batch maximum
            and both masks are one-hot encoded over 4 classes.
        """
        X = np.zeros((self.batch_size, *self.dim, self.n_channels))
        y = np.zeros((self.batch_size, *self.dim))
        label_z = self._labels_for(Batch_ids)

        for c, i in enumerate(Batch_ids):
            case_path = os.path.join(TRAIN_DATASET_PATH, i)

            flair = nib.load(os.path.join(case_path, f'{i}_flair.nii.gz')).get_fdata()
            seg = nib.load(os.path.join(case_path, f'{i}_seg.nii.gz')).get_fdata()

            # The transforms add a leading channel axis; drop it again.
            X[c,:,:,:,0] = self.transforms(flair).squeeze(0)
            y[c] = self.seg_transforms(seg).squeeze(0)

        # Label 4 is remapped to 3 so that the class ids are contiguous (0..3).
        y[y==4] = 3
        mask = tf.one_hot(y.astype(np.int64), 4)

        # "Nothing here" target: every voxel belongs to class 0.
        background = np.zeros((self.batch_size, *self.dim, 4))
        background[..., 0] = 1

        seg_m = np.zeros((self.batch_size, *self.dim, 4))
        seg_n = np.zeros((self.batch_size, *self.dim, 4))

        for row, case_label in enumerate(label_z):
            if case_label == 1:
                seg_m[row] = background[row]
                seg_n[row] = mask[row]
            else:
                seg_m[row] = mask[row]
                seg_n[row] = background[row]

        # Scale by the batch maximum; an all-zero batch is left as is instead
        # of being turned into NaNs by a division by zero.
        peak = np.max(X)
        if peak > 0:
            X = X / peak
        return X, [seg_m, seg_n]

In [15]:
# Generators for the segmentation stage. The first two use the default
# is_train=True mode; the last one returns the case ids with each batch.
training_generator = DataGenerator(list_IDs=train_ids, label=label_dict)
valid_generator = DataGenerator(list_IDs=val_ids, label=label_dict, shuffle=False)
test_generator = DataGenerator(
    list_IDs=new_train_and_test_ids,
    label=label_dict,
    shuffle=False,
    is_train=False,
    test=False,
)

In [11]:
callbacks = [
    # Multiply the learning rate by 0.2 when val_loss has stopped improving
    # for 2 epochs, never going below 1e-6.
    keras.callbacks.ReduceLROnPlateau(
        monitor='val_loss',
        factor=0.2,
        patience=2,
        min_lr=0.000001,
        verbose=1,
    ),
    # Per-epoch metrics go to training.log (the file is overwritten, not appended).
    CSVLogger('training.log', separator=',', append=False),
    # Weights-only checkpoints, written only when the monitored value improves.
    keras.callbacks.ModelCheckpoint(
        filepath='model_.{epoch:02d}-{val_loss:.6f}.h5',
        verbose=1,
        save_best_only=True,
        save_weights_only=True,
    ),
]

In [None]:
# NOTE(review): clear_session() resets Keras' global state although new_model
# has already been built; it is kept as in the original — confirm it is needed.
K.clear_session()

# Train the two segmentation heads.
# steps_per_epoch is taken from the generator itself: len(train_ids) is only
# the right number of batches when batch_size == 1.
history = new_model.fit(
    training_generator,
    epochs=35,
    steps_per_epoch=len(training_generator),
    callbacks=callbacks,
    validation_data=valid_generator,
)

After training the model above, there are two ways to perform the classification:
1. Generate seg-m and seg-n for every case and save them, then use the saved masks to train a classification model.
2. Connect the segmentation model to a classification model, with the segmentation model frozen (untrainable) and the classification model trainable, then feed the MRI images through the combined model.

The first way:

In [None]:
# Run the segmentation model over every labelled case and collect both masks.
# The buffers are sized from the generator instead of a hard-coded case count,
# so the cell keeps working when the number of cases or the volume size changes.
n_cases = len(test_generator)
mask_shape = (n_cases, *test_generator.dim, 4)
o_masks1 = np.zeros(mask_shape, dtype=np.float16)
o_masks2 = np.zeros(mask_shape, dtype=np.float16)
# Integer dtype: float16 cannot represent every integer above 2048 exactly.
id1 = np.zeros(n_cases, dtype=np.int32)

# This assumes test_generator uses batch_size == 1 (its default), i.e. one case per batch.
for idx in range(n_cases):
    images, batch_ids = test_generator[idx]
    # predict returns one array per output head, each with a leading batch axis.
    output_masks1, output_masks2 = new_model.predict(images)
    o_masks1[idx] = output_masks1[0]
    o_masks2[idx] = output_masks2[0]
    id1[idx] = int(batch_ids[0].split('_')[-1])

In [None]:
# Persist both predicted masks and the matching case ids.
np.save("mask1.npy", o_masks1)
# Fixed: this previously saved o_masks1 a second time, so mask2.npy was a copy of mask1.npy.
np.save("mask2.npy", o_masks2)
np.save('id_mask.npy', id1)

The second way:

In [None]:
def final_model(g,inputs):
    """Stack a small 3D CNN classifier on top of a frozen two-output model.

    Args:
        g: model returning two masks for ``inputs``; it is set untrainable.
        inputs: Keras input tensor passed to ``g``.

    Returns:
        A ``Model`` mapping ``inputs`` to a single sigmoid output.
    """
    g.trainable = False
    mask1, mask2 = g(inputs)

    # Both masks are stacked along the channel axis and reduced by two
    # convolution + 2x2x2 max-pooling stages (16, then 32 filters).
    features = concatenate([mask1, mask2], axis=4)
    for n_filters in (16, 32):
        features = Conv3D(n_filters, (3, 3, 3), activation='relu', padding='same')(features)
        features = MaxPooling3D(pool_size=(2, 2, 2))(features)

    features = Flatten()(features)
    features = Dense(1024, activation='relu')(features)
    probability = Dense(1, activation='sigmoid')(features)
    return Model(inputs=inputs, outputs=probability)

input_layer = Input(shape=(128, 128, 80, 1))
# From here on, new_model refers to the combined segmentation + classification model.
new_model = final_model(new_model, input_layer)
new_model.compile(
    loss="binary_crossentropy",
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.01),
    metrics=['accuracy', tf.keras.metrics.AUC()],
)

new_model.summary()

In [21]:
callbacks = [
#     keras.callbacks.EarlyStopping(monitor='loss', min_delta=0,
#                               patience=2, verbose=1, mode='auto'),
    # Multiply the learning rate by 0.2 when val_loss has stopped improving
    # for 2 epochs, never going below 1e-6.
    keras.callbacks.ReduceLROnPlateau(
        monitor='val_loss',
        factor=0.2,
        patience=2,
        min_lr=0.000001,
        verbose=1,
    ),
    # Per-epoch metrics go to training2.log (the file is overwritten, not appended).
    CSVLogger('training2.log', separator=',', append=False),
    # Weights-only checkpoints, written only when the monitored value improves.
    # NOTE(review): the file name pattern is the same one used for the
    # segmentation stage, so checkpoints of both stages share one naming scheme.
    keras.callbacks.ModelCheckpoint(
        filepath='model_.{epoch:02d}-{val_loss:.6f}.h5',
        verbose=1,
        save_best_only=True,
        save_weights_only=True,
    ),
]

In [22]:
# Generators for the classification stage: with test=True each batch is
# returned together with the per-case labels.
training2_generator = DataGenerator(list_IDs=train_ids, label=label_dict, test=True)
valid2_generator = DataGenerator(
    list_IDs=val_ids,
    label=label_dict,
    shuffle=False,
    test=True,
)

In [None]:
# Train the classification head on top of the frozen segmentation model.
# steps_per_epoch is taken from the generator itself: len(train_ids) is only
# the right number of batches when batch_size == 1.
history = new_model.fit(
    training2_generator,
    epochs=35,
    steps_per_epoch=len(training2_generator),
    callbacks=callbacks,
    validation_data=valid2_generator,
)