In [None]:
import nibabel as nib
from glob import glob
import os
import shutil
#import gif_your_nifti.core as gif2nif
import math
from matplotlib import pyplot as plt
from matplotlib.pyplot import figure
import cv2
import numpy as np
import random
import tensorflow as tf
from skimage.util import montage 
from skimage.transform import rotate
from scipy import ndimage, misc
import time
import datetime
import csv


In [None]:
!pip install segmentation_models_3D
!pip install volumentations-3D
!pip install tensorflow-addons

In [None]:
print("Num GPUs Available: ", len(tf.config.list_physical_devices('GPU')))
from google.colab import drive
drive.mount('/content/drive')
import zipfile

ZIP_PATH = r"/content/drive/MyDrive/BRATS/MICCAI_BraTS2020_TrainingData.zip"
# Note the leading '.': this is relative to the working directory, not the
# mounted '/content/drive'.
EXTRACT_DIR = r"./content/drive/MyDrive/BRATS/MICCAI_BraTS2020_TrainingData"
# The marker is written only after a complete extraction. Re-running the cell
# therefore skips the slow unzip, while an interrupted extraction is redone.
_EXTRACTED_MARKER = os.path.join(EXTRACT_DIR, ".extracted")
if not os.path.exists(_EXTRACTED_MARKER):
    with zipfile.ZipFile(ZIP_PATH, 'r') as zip_ref:
        zip_ref.extractall(EXTRACT_DIR)
    open(_EXTRACTED_MARKER, 'w').close()
print(glob(EXTRACT_DIR + "/MICCAI_BraTS2020_TrainingData/*")[:2])

TRAINING_DIR = r"./content/drive/MyDrive/BRATS/MICCAI_BraTS2020_TrainingData/MICCAI_BraTS2020_TrainingData"
SPLIT_SEED = 42        # fixed so the train/validation split is reproducible across sessions
TRAIN_FRACTION = 0.8

# Keep only case directories. The loader globs files *inside* each entry, so a
# stray file at this level (e.g. metadata CSVs) could only ever produce its
# all-zero fallback sample.
# sorted() removes the filesystem-dependent glob order before the seeded shuffle.
# NOTE(review): a checkpoint trained on an earlier random split may have seen
# some of these validation cases.
all_case_folders = sorted(p for p in glob(TRAINING_DIR + "/*") if os.path.isdir(p))
random.Random(SPLIT_SEED).shuffle(all_case_folders)
n_train = int(TRAIN_FRACTION * len(all_case_folders))
training_folder_list = all_case_folders[:n_train]
validation_folder_list = all_case_folders[n_train:]
print(len(training_folder_list), len(validation_folder_list))

In [None]:
%load_ext tensorboard

In [None]:
def convert_to_gif(path):
    """Convert a NIfTI file with gif_your_nifti's ``write_gif_normal``.

    The file is copied into ./gif/ and converted from there. The staged copy is
    removed afterwards, also when the conversion raises.

    NOTE(review): relies on the global ``gif2nif``, whose import at the top of
    the notebook is commented out; as-is, calling this raises NameError.
    """
    os.makedirs("./gif", exist_ok=True)
    staged = "./gif/" + os.path.basename(path)
    shutil.copy2(path, staged)
    try:
        gif2nif.write_gif_normal(staged)
    finally:
        os.remove(staged)


In [None]:
def plot_mri(path, crop_size=40, zoom=0.9, n_cols=10):
    """Show every slice along axis 0 of an MRI volume in a grid of subplots.

    Args:
        path: either a NIfTI file path (str) or an already-loaded 3-D array,
            which is used as-is. A file is loaded, trimmed by ``crop_size``
            voxels at both ends of every axis, and rescaled by ``zoom`` with
            scipy.ndimage.zoom.
        crop_size: voxels trimmed from each end of every axis (file input only).
        zoom: scale factor passed to ndimage.zoom (file input only).
        n_cols: subplots per grid row.

    Prints the processed shape (file input only) and the shape of one slice.
    """
    figure(figsize=(15, 15), dpi=80)
    if isinstance(path, str):
        mri = nib.load(path).get_fdata()
        mri = mri[crop_size:mri.shape[0] - crop_size,
                  crop_size:mri.shape[1] - crop_size,
                  crop_size:mri.shape[2] - crop_size]
        mri = ndimage.zoom(mri, zoom)
        print(mri.shape)
    else:
        mri = path
    n_slices = mri.shape[0]
    # Size the grid by the axis that is iterated (axis 0). Sizing it by the last
    # axis overflowed the grid (ValueError) whenever shape[0] > shape[-1].
    n_rows = max(1, math.ceil(n_slices / n_cols))
    for index in range(n_slices):
        plt.subplot(n_rows, n_cols, index + 1)
        plt.tick_params(left=False, right=False, labelleft=False,
                        labelbottom=False, bottom=False)
        plt.imshow(mri[index, ...], cmap='gray')
    if n_slices:
        print(mri[n_slices - 1, ...].shape)
    plt.show()
    
#plot_mri(glob(training_folder_list[70]+"/*_flair*")[0] )

In [None]:
def get_filepath(paths,mod):
    """Return the first path in ``paths`` that refers to modality ``mod``, or None.

    A path whose file name ends in ``_<mod>.nii`` or ``_<mod>.nii.gz`` is
    preferred. This stops ``"t1"`` from returning a ``..._t1ce.nii`` file that
    merely happens to come first. If there is no such exact match, the first
    path containing ``mod`` anywhere is returned (the original substring rule).
    """
    paths = list(paths)  # scanned twice below
    suffixes = ("_" + mod + ".nii", "_" + mod + ".nii.gz")
    for path in paths:
        if os.path.basename(path).endswith(suffixes):
            return path
    for path in paths:
        if mod in path:
            return path
    return None

In [None]:
def bb_volume(volume, mask=None):
    """Crop ``volume`` (and optionally ``mask``) to the in-plane box of positive voxels.

    Axis 0 is treated as the slice axis and is never cropped. The box over axes
    1 and 2 is computed from ``volume`` alone and applied unchanged to ``mask``.
    If ``volume`` has no positive voxel, both are cut down to a 1x1 in-plane window.

    Returns:
        (cropped_volume, cropped_mask); cropped_mask is None when mask is None.
    """
    as_hwd = np.transpose(volume, (1, 2, 0))
    # Indices along axis 2 / axis 1 of `volume` where some voxel is positive.
    w_hits = np.where(np.max(as_hwd, 0) > 0)[0]
    h_hits = np.where(np.max(as_hwd, 1) > 0)[0]
    if w_hits.size == 0:
        window = (slice(None, 1), slice(None, 1))
    else:
        window = (slice(h_hits[0], h_hits[-1] + 1), slice(w_hits[0], w_hits[-1] + 1))

    cropped_volume = np.transpose(as_hwd[window], (2, 0, 1))
    if mask is None:
        return cropped_volume, None
    cropped_mask = np.transpose(np.transpose(mask, (1, 2, 0))[window], (2, 0, 1))
    return cropped_volume, cropped_mask

def depth_filter(volume, depth):
    """Return ``depth`` consecutive slices along axis 0, centred on the middle slice.

    With an odd ``depth`` the window still has ``depth`` slices (the previous
    ``mid - depth//2 : mid + depth//2`` returned ``depth - 1``). The start is
    clamped at 0, so a ``depth`` larger than the volume yields the whole
    volume instead of a wrapped negative-index tail.
    """
    start = max(0, volume.shape[0] // 2 - depth // 2)
    return volume[start:start + depth]

def plot_volume(volume):
    """Display ``volume`` as one grayscale montage of its slices, rotated 90 degrees.

    Note: a two-argument ``plot_volume(volume, mask)`` defined later in the
    notebook replaces this function once that cell runs.
    """
    _, ax = plt.subplots(1, 1, figsize=(20, 20))
    tiled = montage(volume)
    ax.imshow(rotate(tiled, 90, resize=True), cmap='gray')

def read(path):
    """Load the NIfTI image at ``path`` and return its voxel data from ``get_fdata()``."""
    image = nib.load(path)
    return image.get_fdata()

def read_mri(path,stain):
    """Load the ``stain`` volume (e.g. "flair", "t1ce", "seg") from case folder ``path``.

    The file is looked up by the exact suffix ``_<stain>.nii*`` first, so
    "t1" cannot pick up the "..._t1ce.nii" file. If nothing matches, the looser
    ``*<stain>*`` pattern is used. Among several matches the sorted-first one
    is taken, so the choice does not depend on directory listing order.

    Returns:
        The voxel data from ``read`` transposed with axes (2, 0, 1), i.e. the
        last file axis becomes axis 0.

    Raises:
        FileNotFoundError: no file in ``path`` matches ``stain``.
    """
    candidates = glob(os.path.join(path, "*_" + stain + ".nii*"))
    if not candidates:
        candidates = glob(path + "/*" + stain + "*")
    if not candidates:
        raise FileNotFoundError("no '{}' file in {}".format(stain, path))
    return np.transpose(read(sorted(candidates)[0]), (2, 0, 1))

def clip_and_normalize(volume):
    """Clip ``volume`` to its [1st, 99.995th] percentile range and stretch it to uint8 0..255.

    The stretch is a min-max rescale of the clipped values. The 1e-7 in the
    denominator keeps a constant volume from dividing by zero (it maps to 0).
    """
    low = np.percentile(volume, 1, keepdims=True)
    high = np.percentile(volume, 99.995, keepdims=True)
    clipped = np.clip(volume, a_min=low, a_max=high)
    floor = np.min(clipped)
    span = np.max(clipped) - floor
    unit = (clipped - floor) / (span + 1e-7)
    return (unit * 255).astype(np.uint8)

def resize_height_width(volume, width, height,  mask= None):
    """Resize each slice (axis 0) of ``volume`` and/or ``mask`` to an in-plane (width, height).

    Image slices use cv2.INTER_AREA. Mask slices use cv2.INTER_NEAREST, so no
    new label values are interpolated. Either input may be None.

    Returns:
        (resized_volume, resized_mask): float64 arrays of shape
        (n_slices, width, height), or None for each input that was None.
        n_slices comes from ``volume`` (from ``mask`` when volume is None).
    """
    reference = volume if volume is not None else mask
    if reference is None:
        return None, None
    width, height = int(width), int(height)  # also accepts scalar tensors
    n_slices = reference.shape[0]
    # cv2.resize takes dsize as (columns, rows) and returns (rows, columns).
    # Passing (height, width) yields (width, height) slices that match the
    # buffers below; the old (width, height) only fitted when width == height.
    dsize = (height, width)
    temp_vol = None if volume is None else np.zeros((volume.shape[0], width, height))
    temp_mask = None if mask is None else np.zeros((mask.shape[0], width, height))
    for index in range(n_slices):
        if temp_vol is not None:
            temp_vol[index, ...] = cv2.resize(volume[index], dsize, interpolation=cv2.INTER_AREA)
        if temp_mask is not None:
            temp_mask[index, ...] = cv2.resize(mask[index], dsize, interpolation=cv2.INTER_NEAREST)
    return temp_vol, temp_mask

def create_patch_from_id_v3(folder,
                            depth = 128,
                            width = 128,
                            height = 128,
                            allowed_mod=("flair", "t1ce")):
    """Load one case folder as a multi-modal patch plus its segmentation mask.

    Can be called directly with Python values or through ``tf.py_function``.
    In the latter case every argument arrives as an eager tensor, and strings
    are decoded here.

    Per modality: read_mri -> depth_filter (central slices) -> bb_volume (crop
    to the positive in-plane box) -> clip_and_normalize (0..255) ->
    resize_height_width. Only the last modality carries the mask through these
    steps, so the mask is cropped with that modality's box.
    NOTE(review): the other modalities are cropped with their own boxes. If the
    boxes differ, the channels are not pixel-aligned with each other or with
    the mask.

    Args:
        folder: case directory (str or scalar string tensor).
        depth: number of central slices requested from depth_filter.
        width, height: in-plane size passed to resize_height_width.
        allowed_mod: modality name fragments; one output channel each, in order.

    Returns:
        (volume, mask): a float32 array (slices, width, height, len(allowed_mod))
        and a uint8 array (slices, width, height, 1) with label 4 remapped to 3.
        If loading fails, a warning is printed and zero arrays of shape
        (depth, width, height, ...) are returned so the tf.data pipeline keeps
        running. Dtypes match the ``[tf.float32, tf.uint8]`` Tout used by the
        tf.data wrapper.
    """
    if not isinstance(folder, str):
        # tf.py_function hands over eager tensors; decode the path.
        folder = folder.numpy().decode('utf-8')
    # Plain ints work for NumPy shapes, slicing and cv2 alike.
    depth, width, height = int(depth), int(width), int(height)
    n_mod = len(allowed_mod)
    try:
        mask = read_mri(folder, "seg")
        channels = []
        mask_resized = None
        for index, mod in enumerate(allowed_mod):
            if not isinstance(mod, str):
                mod = mod.numpy().decode('utf-8')
            is_last = index == n_mod - 1
            mask_depth = depth_filter(mask, depth) if is_last else None
            mri = read_mri(folder, mod)
            mri_chopped, mask_chopped = bb_volume(depth_filter(mri, depth), mask_depth)
            mri_clipped_normalized = clip_and_normalize(mri_chopped)
            mri_resized, mask_step = resize_height_width(
                mri_clipped_normalized, mask=mask_chopped, width=width, height=height)
            channels.append(mri_resized)
            if is_last:
                mask_resized = mask_step
        result = np.transpose(np.array(channels), (1, 2, 3, 0))
        mask_resized[mask_resized == 4] = 3
    except Exception as exc:  # best effort: keep the input pipeline alive, but say so
        print("create_patch_from_id_v3: could not load {!r}: {!r}; returning zeros".format(folder, exc))
        return (np.zeros((depth, width, height, n_mod), dtype=np.float32),
                np.zeros((depth, width, height, 1), dtype=np.uint8))
    return result.astype(np.float32), np.expand_dims(mask_resized, -1).astype(np.uint8)
    



In [None]:
#@title
def create_patch_from_id_v2(path, 
                         depth_per_mod = 50,
                         allowed_mod=["flair","t1","t1ce", "t2"],
                         dimentions=(240,240),
                         crop_size= 40,
                         classes = 3,
                         frame_skip=2
                       ):
    """Legacy (v2) loader: build a (depth_per_mod, d0, d1, n_mod) patch and its mask.

    The tf.data pipeline in this notebook calls create_patch_from_id_v3 instead.

    For every modality:
      1. crop ``crop_size`` voxels off both ends of all three axes;
      2. crop axes 0/1 to the bounding box of positive voxels;
      3. take ``depth_per_mod`` slices along the last axis, ``frame_skip`` apart
         around its centre;
      4. resize each slice with ``dimentions`` and min-max scale it to 0..255.
    The stacked volume is then clipped to its global 1st/99th percentiles. The
    "seg" file is cropped with the bounding box of the LAST modality, resampled
    nearest-neighbour, and label 4 is remapped to 3.

    Args:
        path: case directory (str, or a string tensor under tf.py_function).
        depth_per_mod: number of slices to sample.
        allowed_mod: substrings used by get_filepath to find each modality file.
        dimentions: dsize passed to cv2.resize for each slice.
        crop_size: voxels trimmed from both ends of every axis.
        classes: not used in this function.
        frame_skip: stride between sampled slices.

    Returns:
        (volume, mask[..., None]).
        NOTE(review): any exception (missing file, out-of-range slice index, ...)
        is swallowed by the bare ``except`` below, so a partially filled or
        all-zero patch can come back silently.
    """
    if type(path) is not str:
        path=str(path.numpy().decode('utf-8'))
        #print(path)
        depth_per_mod = tf.cast(depth_per_mod, tf.int32)
        dimentions = (tf.cast(dimentions[0],tf.int32),tf.cast(dimentions[1],tf.int32))
    paths = glob(path+"/*")
    volume = np.zeros((depth_per_mod, dimentions[0],dimentions[1],len(allowed_mod)))
    mask = np.zeros((depth_per_mod, dimentions[0],dimentions[1]))
    #clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8,8))
    try:
      for index, mod in enumerate(allowed_mod):
          if type(mod) is not str:
              mod = mod.numpy().decode('utf-8')
          # NOTE: rebinds the `path` argument to this modality's file path.
          path = get_filepath(paths,mod)
          mri = nib.load(path).get_fdata() 
          mri = mri[crop_size:mri.shape[0]-crop_size, crop_size:mri.shape[1]-crop_size, crop_size: mri.shape[2]-crop_size]
          #mri = ndimage.zoom(mri, 0.9,order=0)
          # rows: axis-1 indices, cols: axis-0 indices that contain a positive voxel.
          rows = np.where(np.max(mri, 0) > 0)[0]
          cols = np.where(np.max(mri, 1) > 0)[0]
          if rows.size:
              mri = mri[cols[0]: cols[-1] + 1, rows[0]: rows[-1] + 1]
          else:
              mri = mri[:1, :1]
          mid = mri.shape[-1]//2
          count=0
          # NOTE(review): if frame_skip*(depth_per_mod//2) exceeds the space around
          # `mid`, i can go negative (wrapping from the end) or reach
          # mri.shape[-1] (IndexError, swallowed by the except below).
          for i in range(mid-(frame_skip*(depth_per_mod//2)),mid+(frame_skip*(depth_per_mod//2)),frame_skip):
              # cv2 returns shape (dimentions[1], dimentions[0]); this only fits
              # volume[count, ..., index] when dimentions is square.
              temp = cv2.resize(mri[...,i], dimentions,interpolation = cv2.INTER_AREA)
              # Per-slice min-max scaling to 0..255.
              temp = (((temp-np.min(temp)) / (np.max(temp)-np.min(temp)+1e-7))*255).astype(np.uint8)
              volume[count,...,index] = temp
              count+=1
      perc01 = np.percentile(volume, 1,  keepdims=True)
      perc99 = np.percentile(volume, 99, keepdims=True)

      #  Clip with one global [1st, 99th] percentile range: np.percentile is given
      #  no axis, so the limits are shared by all slices and modalities.
      volume = np.clip(volume, a_min=perc01, a_max=perc99)

      path = get_filepath(paths,"seg")
      mri = nib.load(path).get_fdata()
      mri = mri[crop_size:mri.shape[0]-crop_size, crop_size:mri.shape[1]-crop_size, crop_size: mri.shape[2]-crop_size]
      #mri = ndimage.zoom(mri, 0.9,order=0)
      # Reuses rows/cols from the LAST modality processed above.
      if rows.size:
          mri = mri[cols[0]: cols[-1] + 1, rows[0]: rows[-1] + 1]
      else:
          mri = mri[:1, :1]
      mid = mri.shape[-1]//2
      count=0
      for i in range(mid-(frame_skip*(depth_per_mod//2)),mid+(frame_skip*(depth_per_mod//2)),frame_skip):
          # Nearest-neighbour keeps the mask's label values intact.
          mask[count,...] = cv2.resize(mri[...,i], dimentions,interpolation =  cv2.INTER_NEAREST)
          count+=1
      mask[mask==4] = 3;
    except:
      # Bare except: every error is silently ignored.
      pass
      #raise
    #print(np.unique(mask))

    return volume, np.expand_dims(mask,-1)
  

In [None]:
# Sanity-check the v3 loader on one case: 3 modalities (flair, t1ce, t2),
# 96 central slices, 64x64 in-plane.
volume, mask = create_patch_from_id_v3(training_folder_list[38],
                                       depth=96,
                                       allowed_mod=["flair", 't1ce', 't2'],
                                       height=64,
                                       width=64)
print(volume.shape, np.max(volume), np.min(volume))
print(mask.shape, np.unique(mask))

fig, ax1 = plt.subplots(1, 1, figsize=(15, 15))
ax1.imshow(rotate(montage(volume[:, :, :, 2]), 90, resize=True), cmap='gray')
ax1.set_title("channel 2 (t2)")

fig, ax1 = plt.subplots(1, 1, figsize=(15, 25))
ax1.imshow(rotate(montage(mask[:, :, :, 0]), 90, resize=True), cmap='gray')
ax1.set_title("segmentation labels")
plt.show()

In [None]:
def random_saturation(image, mask):
    """With probability 0.5, jitter saturation by a random factor in [0.75, 1.25]; mask unchanged.

    NOTE(review): tf.image.random_saturation expects an RGB last dimension of 3.
    This helper is not used by the datasets built below.
    """
    apply_it = tf.random.uniform([], 0, 1.0) < 0.5
    augmented = tf.cond(apply_it,
                        lambda: tf.image.random_saturation(image, 0.75, 1.25),
                        lambda: image)
    return augmented, mask

def random_brightness(image, mask):
    """With probability 0.5, add a random brightness delta in [-0.2, 0.2]; mask unchanged.

    NOTE(review): the delta is absolute, so it is tiny relative to 0..255-scaled
    inputs. This helper is not used by the datasets built below.
    """
    apply_it = tf.random.uniform([], 0, 1.0) < 0.5
    augmented = tf.cond(apply_it,
                        lambda: tf.image.random_brightness(image, 0.20),
                        lambda: image)
    return augmented, mask

def random_contrast(image, mask):
    """With probability 0.5, scale contrast by a random factor in [0.75, 1.75]; mask unchanged."""
    apply_it = tf.random.uniform([], 0, 1.0) < 0.5
    augmented = tf.cond(apply_it,
                        lambda: tf.image.random_contrast(image, 0.75, 1.75),
                        lambda: image)
    return augmented, mask

def normalize_function(image, mask):
    """Scale ``image`` to at most 1 by its global maximum and return it as float16.

    The division is done in float32 and only the result is cast to float16.
    The original cast to float16 first and divided in half precision, rounding
    twice. The 1e-7 guards an all-zero image (the result stays 0). ``mask``
    passes through unchanged.
    """
    image = tf.cast(image, dtype=tf.float32)
    image = tf.truediv(image, tf.reduce_max(image) + 1e-7)
    return tf.cast(image, dtype=tf.float16), mask

In [None]:
# Data-pipeline configuration; read by pyfunc_create_patch_from_id, one_hot and the model cell.
depth_per_mod = 128                # slices kept along axis 0
allowed_mod = ["flair", 't1ce']    # modalities stacked as input channels, in this order
dimentions = (128, 128)            # in-plane size after resizing
classes = 3                        # output channels (labels 1..3; background dropped)
crop_size = 0                      # not passed to the v3 loader
frame_skip = 1                     # not passed to the v3 loader

def pyfunc_create_patch_from_id(path):
    """tf.data map function: load one case folder via create_patch_from_id_v3.

    Wraps the NumPy/nibabel loader in tf.py_function. The notebook globals
    depth_per_mod, dimentions and allowed_mod are passed as its extra arguments.
    (The previous version copied six globals into locals; three of those copies
    were never used.) The returned (volume float32, mask uint8) tensors carry no
    static shape; one_hot() sets it later in the pipeline.
    """
    volume, mask = tf.py_function(
        create_patch_from_id_v3,
        [path, depth_per_mod, dimentions[0], dimentions[1], allowed_mod],
        [tf.float32, tf.uint8],
    )
    return volume, mask

In [None]:
from volumentations import *
def one_hot(x,y):
    """Restore static shapes and one-hot encode the integer mask without background.

    ``x`` gets its static shape from the pipeline config. ``y`` holds integer
    labels with a trailing singleton axis. It is one-hot encoded over labels
    0..classes, and channel 0 (background) is dropped, leaving ``classes``
    channels.
    """
    x.set_shape([depth_per_mod, dimentions[0], dimentions[1], len(allowed_mod)])
    labels = tf.squeeze(y, axis=-1)
    # classes + 1 labels including background. Deriving the depth from
    # `classes` (instead of a literal 4) keeps it consistent with set_shape below.
    y_new = tf.one_hot(labels, classes + 1)[..., 1:]
    y_new.set_shape([depth_per_mod, dimentions[0], dimentions[1], classes])
    return x, y_new
  
def get_augmentation():
    """Build the volumentations Compose used to augment one (image, mask) sample.

    The outer Compose has p=0.5, and each listed transform carries its own p=0.5.
    """
    transforms = [
        Rotate((-10, 10), (0, 0), (0, 0), p=0.5),
        GaussianNoise(var_limit=(0, 15), p=0.5),
        # NOTE(review): volumentations' README examples pass gamma_limit in
        # percent, e.g. (80, 120). Confirm (0.7, 1.3) gives gammas near 1 here.
        RandomGamma(gamma_limit=(0.7, 1.3), p=0.5),
        RandomRotate90((1, 2), p=0.5),
    ]
    return Compose(transforms, p=0.5)


def augmentation(image, mask):
    """Apply one random draw of get_augmentation() to an (image, mask) pair.

    Runs eagerly inside tf.py_function. The tensors are converted to NumPy
    (image float32, mask uint8) and handed to the Compose together as
    'image'/'mask'. The augmented NumPy arrays are returned.
    """
    sample = {
        'image': tf.cast(image, tf.float32).numpy(),
        "mask": tf.cast(mask, tf.uint8).numpy(),
    }
    augmented = get_augmentation()(**sample)
    return augmented['image'], augmented['mask']


def pyfunc_augmentation(image, mask):
    """tf.data map function: run augmentation() eagerly via tf.py_function.

    Returns (image float32, mask uint8) tensors. Static shapes are lost here
    and restored by one_hot() later in the pipeline.
    """
    outputs = tf.py_function(augmentation, [image, mask], [tf.float32, tf.uint8])
    image_out, mask_out = outputs
    return image_out, mask_out

In [None]:
def _build_dataset(folders, augment):
    """Build an infinite, shuffled tf.data pipeline of batch size 1 over case ``folders``.

    Stages: folder path -> (volume, mask) via the py_function loader ->
    optional volumentations augmentation -> one-hot mask without background
    -> float16 image scaling -> batch(1). A single prefetch at the end overlaps
    the pipeline with training. The parallel maps already run concurrently, so
    the former prefetch after every stage only added buffers of full volumes.
    """
    ds = tf.data.Dataset.from_tensor_slices(folders)
    ds = ds.repeat()
    ds = ds.shuffle(100, reshuffle_each_iteration=True)
    ds = ds.map(pyfunc_create_patch_from_id, num_parallel_calls=tf.data.AUTOTUNE)
    if augment:
        ds = ds.map(pyfunc_augmentation, num_parallel_calls=tf.data.AUTOTUNE)
    ds = ds.map(one_hot, num_parallel_calls=tf.data.AUTOTUNE)
    ds = ds.map(normalize_function, num_parallel_calls=tf.data.AUTOTUNE)
    ds = ds.batch(1)
    return ds.prefetch(tf.data.AUTOTUNE)


train_dataset = _build_dataset(training_folder_list, augment=True)
# NOTE(review): validation is also repeated and shuffled. Each validation pass
# therefore sees a random (possibly repeating) subset of cases, which makes
# val_loss noisy between epochs.
valid_dataset = _build_dataset(validation_folder_list, augment=False)



In [None]:
# Pull one augmented training batch and look at channel 0 of the image and
# channel 1 of the one-hot mask.
volume, mask = next(train_dataset.as_numpy_iterator())
print(volume.shape, np.max(volume), np.min(volume))
print(mask.shape, np.unique(mask))

fig, ax1 = plt.subplots(1, 1, figsize=(20, 20))
image_slices = volume[0, :, :, :, 0].astype(np.float32)  # skimage does not take float16
ax1.imshow(rotate(montage(image_slices), 90, resize=True), cmap='gray')

fig, ax1 = plt.subplots(1, 1, figsize=(20, 20))
ax1.imshow(rotate(montage(mask[0, :, :, :, 1]), 90, resize=True), cmap='gray')

In [None]:
import keras.backend as K
def scheduler(epoch, lr, decay_every=None, decay_rate=0.4, min_lr=0.00005):
    """Learning-rate schedule for ``tf.keras.callbacks.LearningRateScheduler``.

    With the default arguments this is a no-op and returns ``lr`` unchanged,
    which is what the former ``if False:`` body did. Pass ``decay_every=N`` to
    multiply the rate by ``exp(-decay_rate)`` on every N-th epoch (epoch
    indices N-1, 2N-1, ...), but only while the current rate is above ``min_lr``.
    With N=4, this is the schedule that was commented out.
    """
    if decay_every and (epoch + 1) % decay_every == 0 and lr > min_lr:
        return lr * math.exp(-decay_rate)
    return lr

def custom_CE_loss_function(weights = [1,1,1]):
    """Build a channel-weighted categorical cross-entropy loss.

    Args:
        weights: per-channel multipliers applied to the per-voxel cross-entropy
            terms (broadcast against the last axis).

    Returns:
        ``weighted_categorical_crossentropy(target, output)``: the mean over all
        voxels of sum_c weights[c] * -target_c * log(p_c), where p is ``output``
        renormalised to sum to 1 over the last axis.
    """
    def weighted_categorical_crossentropy(target, output):
        # Out-of-place division: the former `output /= ...` could overwrite a
        # NumPy `output` passed in by the caller. K.epsilon() keeps an all-zero
        # prediction from giving 0/0 = NaN.
        output = output / (K.sum(output, axis=-1, keepdims=True) + K.epsilon())
        output = K.clip(output, 1e-6, 1 - 1e-6)
        loss_per_pixel = K.sum((target * -K.log(output + 1e-6)) * weights, axis=-1, keepdims=False)
        return K.mean(loss_per_pixel)
    return weighted_categorical_crossentropy
    
def focal_cross_entropy():
    """Build a loss that sums binary focal cross-entropy over the ``classes`` output channels.

    Each channel is scored independently (gamma=2, probabilities rather than
    logits) after clipping predictions to [1e-7, 1 - 1e-7]. The channel count
    is read from the global ``classes`` when the loss is called.
    """
    bfce = tf.keras.losses.BinaryFocalCrossentropy(gamma=2.0, from_logits=False)

    def focal_loss(y_true, y_pred):
        clipped = K.clip(y_pred, 1e-7, 1 - 1e-7)
        total = 0
        for channel in range(classes):
            term = bfce(y_true[..., channel], clipped[..., channel])
            total = term if channel == 0 else total + term
        return total
    return focal_loss

def dice_coef(y_true, y_pred, smooth=1e-7):
    """Mean soft Dice over the first three channels, with a squared-sum denominator.

    Per channel c: (2*sum(t*p) + smooth) / (sum(t^2) + sum(p^2) + smooth).
    The three channel scores are averaged.
    """
    class_num = 3
    total_loss = None
    for channel in range(class_num):
        t = K.flatten(y_true[..., channel])
        p = K.flatten(y_pred[..., channel])
        overlap = K.sum(t * p)
        score = (2. * overlap + smooth) / (K.sum(t * t) + K.sum(p * p) + smooth)
        total_loss = score if total_loss is None else total_loss + score
    return total_loss / class_num

def dice_loss(y_true, y_pred):
    """Soft Dice loss: one minus ``dice_coef`` (smaller is better)."""
    coefficient = dice_coef(y_true, y_pred)
    return 1 - coefficient
    
def CE_Diceloss(y_true, y_pred):
    """Soft Dice loss plus unweighted (all-ones) channel cross-entropy."""
    dice_term = dice_loss(y_true, y_pred)
    ce_term = custom_CE_loss_function([1, 1, 1])(y_true, y_pred)
    return dice_term + ce_term

def dice_focal_loss(y_true, y_pred):
    """Training loss: soft Dice loss plus the per-channel binary focal cross-entropy sum."""
    dice_term = dice_loss(y_true, y_pred)
    focal_term = focal_cross_entropy()(y_true, y_pred)
    return dice_term + focal_term

def _dice_on_channel(y_true, y_pred, channel, epsilon):
    """Soft Dice for one channel: 2*sum(t*p) / (sum(t) + sum(p) + epsilon).

    There is no smoothing in the numerator, so a channel that is empty in both
    truth and prediction scores 0.
    """
    true_c = y_true[..., channel]
    pred_c = y_pred[..., channel]
    overlap = K.sum(true_c * pred_c)
    return (2. * overlap) / (K.sum(true_c) + K.sum(pred_c) + epsilon)


def dice_coef_necrotic(y_true, y_pred, epsilon=1e-6):
    """Dice metric for channel 0 (the 'necrotic' class)."""
    return _dice_on_channel(y_true, y_pred, 0, epsilon)


def dice_coef_edema(y_true, y_pred, epsilon=1e-6):
    """Dice metric for channel 1 (the 'edema' class)."""
    return _dice_on_channel(y_true, y_pred, 1, epsilon)


def dice_coef_enhancing(y_true, y_pred, epsilon=1e-6):
    """Dice metric for channel 2 (the 'enhancing' class)."""
    return _dice_on_channel(y_true, y_pred, 2, epsilon)

def _precision_on_channel(y_true, y_pred, channel):
    """Precision for one channel, counting voxels after rounding clipped values to {0, 1}."""
    hits = K.sum(K.round(K.clip(y_true[..., channel] * y_pred[..., channel], 0, 1)))
    predicted = K.sum(K.round(K.clip(y_pred[..., channel], 0, 1)))
    return hits / (predicted + K.epsilon())


def precision_necrotic(y_true, y_pred):
    """Precision metric for channel 0 (the 'necrotic' class)."""
    return _precision_on_channel(y_true, y_pred, 0)


def precision_edema(y_true, y_pred):
    """Precision metric for channel 1 (the 'edema' class)."""
    return _precision_on_channel(y_true, y_pred, 1)


def precision_enhancing(y_true, y_pred):
    """Precision metric for channel 2 (the 'enhancing' class)."""
    return _precision_on_channel(y_true, y_pred, 2)




# Sensitivity (recall): true positives / actual positives, per channel.
def _sensitivity_on_channel(y_true, y_pred, channel):
    """Sensitivity for one channel, counting voxels after rounding clipped values to {0, 1}."""
    hits = K.sum(K.round(K.clip(y_true[..., channel] * y_pred[..., channel], 0, 1)))
    actual = K.sum(K.round(K.clip(y_true[..., channel], 0, 1)))
    return hits / (actual + K.epsilon())


def sensitivity_necrotic(y_true, y_pred):
    """Sensitivity metric for channel 0 (the 'necrotic' class)."""
    return _sensitivity_on_channel(y_true, y_pred, 0)


def sensitivity_edema(y_true, y_pred):
    """Sensitivity metric for channel 1 (the 'edema' class)."""
    return _sensitivity_on_channel(y_true, y_pred, 1)


def sensitivity_enhancing(y_true, y_pred):
    """Sensitivity metric for channel 2 (the 'enhancing' class)."""
    return _sensitivity_on_channel(y_true, y_pred, 2)


# Specificity: true negatives / actual negatives, per channel.
def _specificity_on_channel(y_true, y_pred, channel):
    """Specificity for one channel, counting voxels after rounding clipped values to {0, 1}."""
    true_neg = 1 - y_true[..., channel]
    pred_neg = 1 - y_pred[..., channel]
    hits = K.sum(K.round(K.clip(true_neg * pred_neg, 0, 1)))
    actual = K.sum(K.round(K.clip(true_neg, 0, 1)))
    return hits / (actual + K.epsilon())


def specificity_necrotic(y_true, y_pred):
    """Specificity metric for channel 0 (the 'necrotic' class)."""
    return _specificity_on_channel(y_true, y_pred, 0)


def specificity_edema(y_true, y_pred):
    """Specificity metric for channel 1 (the 'edema' class)."""
    return _specificity_on_channel(y_true, y_pred, 1)


def specificity_enhancing(y_true, y_pred):
    """Specificity metric for channel 2 (the 'enhancing' class)."""
    return _specificity_on_channel(y_true, y_pred, 2)


In [None]:
import tensorflow_addons as tfa
def Ranger(sync_period=6,
           slow_step_size=0.5,
           learning_rate=5e-5,
           beta_1=0.9,
           beta_2=0.999,
           epsilon=1e-7,
           weight_decay=0.,
           amsgrad=False,
           sma_threshold=5.0,
           total_steps=0,
           warmup_proportion=0.3,
           min_lr=0.,
           name="Ranger"):
    """Build the "Ranger" optimizer: tfa RectifiedAdam wrapped in tfa Lookahead.

    Every argument except ``sync_period``/``slow_step_size`` goes to
    RectifiedAdam; those two configure Lookahead. Both optimizers get ``name``.
    Arguments are passed by keyword so they cannot shift if a signature changes.
    """
    radam = tfa.optimizers.RectifiedAdam(
        learning_rate=learning_rate,
        beta_1=beta_1,
        beta_2=beta_2,
        epsilon=epsilon,
        weight_decay=weight_decay,
        amsgrad=amsgrad,
        sma_threshold=sma_threshold,
        total_steps=total_steps,
        warmup_proportion=warmup_proportion,
        min_lr=min_lr,
        name=name,
    )
    return tfa.optimizers.Lookahead(radam, sync_period=sync_period,
                                    slow_step_size=slow_step_size, name=name)

In [None]:
from tensorflow import keras
#keras.backend.set_image_data_format('channels_last')
import segmentation_models_3D as sm
# The input shape (depth, H, W, channels) comes from the pipeline config, so the
# model always matches the static shape one_hot() gives the dataset.
custom_model = sm.Unet('efficientnetb2',
                       input_shape=(depth_per_mod, dimentions[0], dimentions[1], len(allowed_mod)),
                       encoder_weights=None, classes=classes, activation='sigmoid')

In [None]:
# Choose between resuming a saved run (load=True) and starting a fresh one.
# Both branches define model_save_path, logdir, stats_save_path, csvfile,
# csv_writer and headers for the training cells below.
load = True
if load:
    weight_path = "/content/drive/MyDrive/BRATS/models/v2_1646179339/basic_customRangerLoss-35-0.664.hdf5"
    custom_model.load_weights(weight_path)
    model_save_path = os.path.dirname(weight_path)
    # Same "<run>/logs/<timestamp>" layout as a fresh run and as the
    # %tensorboard cell; the previous value was missing the "/" after "logs".
    logdir = model_save_path + "/logs/20220302-000219"
else:
    model_save_path = '/content/drive/MyDrive/BRATS/models/v2_' + str(int(time.time()))
    os.makedirs(model_save_path, exist_ok=True)
    logdir = model_save_path + "/logs/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")

stats_save_path = model_save_path + "/history.csv"
# Write the header row only if the history file does not already have content
# (a resumed run with a missing/empty history.csv previously got no header).
headers = os.path.isfile(stats_save_path) and os.path.getsize(stats_save_path) > 0
# Deliberately left open: the training cell appends to it through csv_writer.
csvfile = open(stats_save_path, 'a', newline='')
csv_writer = csv.writer(csvfile)

In [None]:


# Training configuration: callbacks, optimizer, loss/metrics and checkpointing.

# Halve the learning rate when val_loss has not improved for 3 epochs (floor 1e-5).
lr = tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.5,
                              patience=3, min_lr=0.00001)
# NOTE(review): not included in the callbacks passed to fit(); `scheduler` is
# also a no-op with its default arguments.
lr_scheduler = tf.keras.callbacks.LearningRateScheduler(scheduler)


# Plain Adam; the Ranger() helper defined above is not used.
opt = tf.keras.optimizers.Adam(learning_rate=0.0001)

# Logs to `logdir` from the run-setup cell; histogram_freq=1 -> weight histograms every epoch.
tensorboard_callback = tf.keras.callbacks.TensorBoard(logdir, histogram_freq=1)

# Unweighted cross-entropy; used by the evaluation cells further down, not in training.
CE_loss = custom_CE_loss_function([1,1,1])

# Loss = soft Dice + per-channel binary focal CE. The order of `metrics` fixes
# the key order of history.history, and therefore the columns of history.csv.
# NOTE(review): OneHotMeanIoU argmaxes the targets. Background voxels have all
# three channels at zero (one_hot drops label 0), so they fall into channel 0;
# confirm this metric means what is intended.
custom_model.compile(optimizer=opt, loss=dice_focal_loss, metrics = ["accuracy",
                                                                      dice_loss,
                                                                      focal_cross_entropy(),
                                                                      tf.keras.metrics.OneHotMeanIoU(num_classes=3),
                                                                      precision_necrotic,
                                                                      precision_edema,
                                                                      precision_enhancing,
                                                                      sensitivity_necrotic,
                                                                      sensitivity_edema,
                                                                      sensitivity_enhancing,
                                                                      specificity_necrotic,
                                                                      specificity_edema,
                                                                      specificity_enhancing,
                                                                      dice_coef_necrotic, 
                                                                      dice_coef_edema ,
                                                                      dice_coef_enhancing])

# Save weights only, and only when val_loss improves. The "Ranger" in the file
# name is historical; the optimizer compiled above is Adam.
model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=model_save_path+ "/basic_customRangerLoss-{epoch:02d}-{val_loss:.3f}.hdf5",
    save_weights_only=True,
    monitor = "val_loss",
    mode='min',save_best_only=True,
    save_freq='epoch')


In [None]:
# Train. Both datasets repeat() forever, so steps_per_epoch/validation_steps
# define what an epoch is.
# NOTE(review): no initial_epoch is passed, so a resumed run numbers its epochs
# (and checkpoint file names) from 01 again.
history = custom_model.fit(train_dataset,
                    epochs=40 ,
                    steps_per_epoch=200,
                    validation_data = valid_dataset,
                    validation_steps=70, 
                    callbacks=[lr , model_checkpoint_callback, tensorboard_callback])
# Append this run's per-epoch history to history.csv; write the header once.
# NOTE(review): the column order is assumed to match any existing header; this is not checked.
if not headers:
  csv_writer.writerow(list(history.history.keys()))
  headers = True
# One row per epoch: transpose the {metric: [per-epoch values]} table.
csv_writer.writerows(list(np.array(list(history.history.values())).T))
csvfile.flush()


In [None]:
 
import pandas as pd
# Also dump the Keras history of the last fit() as a CSV table.
# NOTE(review): relative path, so the file lands in the current working
# directory rather than next to the checkpoints in model_save_path.
df_history = pd.DataFrame.from_dict(history.history)
df_history.to_csv("BASE_MODEl_v24_with_augmentation_ramger_loss_v2.csv")

In [None]:
# Show TensorBoard for a training run.
# NOTE(review): hard-coded path; keep it in sync with `logdir` from the run-setup cell.
%tensorboard --logdir "/content/drive/MyDrive/BRATS/models/v2_1646179339/logs/20220302-000219"

In [None]:
from tensorflow import keras
import segmentation_models_3D as sm
# Older model variant (efficientnetb3, 64-slice input, softmax head) with its
# own checkpoint. NOTE(review): `model_test` is not used by the cells below.
model_test = sm.Unet(
    'efficientnetb3',
    input_shape=(64, 128, 128, len(allowed_mod)),
    encoder_weights=None,
    classes=classes,
    activation='softmax',
)
model_test.load_weights('/content/baisc_v20_customCELoss-04-0.911.hdf5')

In [None]:
def plot_volume(volume,mask):
    """Show two 3-D arrays side by side as rotated grayscale slice montages.

    This redefines the one-argument plot_volume from earlier in the notebook.
    Despite the parameter names, the callers below pass
    (ground truth, prediction).

    Uses explicit axes. The former ``plt.subplots(1, 1)`` followed by
    ``plt.subplot(1, 2, i)`` left an extra full-figure axes under the two
    panels on current matplotlib.

    Args:
        volume: array shown in the left panel.
        mask: array shown in the right panel.
    """
    _, (ax_left, ax_right) = plt.subplots(1, 2, figsize=(20, 15))
    for ax, data in ((ax_left, volume), (ax_right, mask)):
        ax.set_xticks([])
        ax.set_yticks([])
        ax.imshow(rotate(montage(data), 90, resize=True), cmap='gray')


In [None]:
# Qualitative and Dice check of the current model on a few batches.
# NOTE(review): the batches come from the augmented, shuffled *training* stream.
N_EVAL_BATCHES = 1  # raise for a steadier estimate
count, ne, ed, eh, cel = 0, 0, 0, 0, 0
x, y = None, None
for volume, mask in train_dataset.take(N_EVAL_BATCHES).as_numpy_iterator():
    count += 1
    print(volume.shape, np.unique(mask))
    print(mask.shape, np.unique(mask))
    result = custom_model.predict(np.expand_dims(volume[0, ...], 0))
    x, y = mask, result
    for channel in range(3):
        plot_volume(mask[0, ..., channel], result[0, ..., channel] > 0.5)
    print("mask", np.unique(np.argmax(np.expand_dims(mask[0, ...], 0), axis=-1), return_counts=True))
    print("result", np.unique(np.argmax(result, axis=-1), return_counts=True))
    # Compute each score once, with the Dice terms before CE_loss runs on `result`.
    necrotic = dice_coef_necrotic(mask, result).numpy()
    edema = dice_coef_edema(mask, result).numpy()
    enhancing = dice_coef_enhancing(mask, result).numpy()
    cross_entropy = CE_loss(mask, result).numpy()
    ne += necrotic
    ed += edema
    eh += enhancing
    cel += cross_entropy
    print(necrotic.round(3), edema.round(3), enhancing.round(3), cross_entropy.round(3))
# Average over the batches actually evaluated. The old code divided by a fixed
# 10 although only one batch was processed.
if count:
    print(ne / count, ed / count, eh / count, cel / count)
    

In [None]:
# Channel 0: ground truth (left) vs. raw predicted probability (right) of the last evaluated batch.
plot_volume(volume=x[0, ..., 0], mask=y[0, ..., 0])

In [None]:
# Channel-0 probabilities of the last evaluated prediction (NumPy truncates the repr).
y[0][..., 0]

In [None]:
# Thresholded channel-2 prediction from the evaluation loop, slice by slice.
enhancing_pred = result[0, ..., 2] > 0.5
plot_mri(enhancing_pred)
print(np.unique(np.argmax(result[0, ...], axis=-1), return_counts=True))
print(np.unique(enhancing_pred, return_counts=True))

In [None]:
# Run the model on one case loaded directly with the v3 loader, and report
# per-channel Dice against its ground truth.
temp1, temp2, temp3 = None, None, None
for path in training_folder_list[29:30]:
    print(os.path.basename(path))
    volume, mask = create_patch_from_id_v3(path,
                                           depth=128,
                                           allowed_mod=["flair", 't1ce'],
                                           height=128,
                                           width=128)
    # Max-scaling, mirroring normalize_function in the tf.data pipeline.
    volume = volume / (np.max(volume) + 1e-7)
    temp1 = volume
    result = custom_model.predict(np.expand_dims(volume, 0))
    temp3 = result
    print("result", np.unique(np.argmax(result, axis=-1), return_counts=True))
    # Fix num_classes. Otherwise to_categorical infers it from mask.max() + 1,
    # and a case without label 3 yields too few channels for dice_coef_enhancing.
    mask = tf.keras.utils.to_categorical(mask, num_classes=classes + 1, dtype="float32")
    temp2 = mask[..., 1:]  # drop background, like one_hot()
    print(dice_coef_necrotic(temp2, result).numpy(),
          dice_coef_edema(temp2, result).numpy(),
          dice_coef_enhancing(temp2, result).numpy())

In [None]:
# Channel 0 of the single case above: ground truth (left) vs. predicted probability (right).
print(temp2.shape, temp3.shape)
plot_volume(volume=temp2[..., 0], mask=temp3[0, ..., 0])

In [None]:
# One slice of the single case above: each row is a channel, with ground truth
# (left) and prediction (right) blended in green over the first input channel.
SLICE_INDEX = 45

v = np.expand_dims(temp1[SLICE_INDEX, ..., 0], -1)  # (H, W, 1) grayscale backdrop
ground_truth = [np.expand_dims(temp2[SLICE_INDEX, ..., c], -1) for c in range(3)]
predicted = [np.expand_dims(temp3[0, SLICE_INDEX, ..., c], -1) for c in range(3)]


def _green_overlay(weight, backdrop):
    """Blend pure green over ``backdrop``, using ``weight`` (H, W, 1) as per-pixel alpha."""
    # dtype=float: the np.float alias was removed in NumPy 1.24.
    green = np.ones(weight.shape, dtype=float) * (0, 1, 0)
    return green * weight + backdrop * (1.0 - weight)


plt.figure(figsize=(9, 12), dpi=80)
for row, pair in enumerate(zip(ground_truth, predicted)):
    for col, weight in enumerate(pair):
        plt.subplot(3, 2, 2 * row + col + 1)
        plt.xticks([])
        plt.yticks([])
        plt.imshow(_green_overlay(weight, v))
plt.show()
