#Important Note:
You need to run DAPT Phase 3 only for the strategies DA_L1SB_PFT and DA_L2SB_PFT.

In [None]:
#import the necessary libraries
import numpy as np
import tensorflow as tf
import keras
from keras.callbacks import ModelCheckpoint
import os
from keras import models 
from keras import layers
#import gc

In [None]:
#Mount Google Drive at /content/gdrive.
#Run this cell only if your data (the .npy files of LGG and HGG) reside on
#Google Drive.
from google.colab import drive
drive.mount('/content/gdrive')

Mounted at /content/gdrive


In [None]:
#Set base_path to the location where the data and results of your project
#reside.
#Keep the trailing '/': later cells build file paths by plain string
#concatenation (e.g. base_path + 'Datasets/...').
base_path = '/content/gdrive/MyDrive/HPT/'

In [None]:
# Image shape kept as a constant for further use.
# 240 x 240 is the size of a slice in the BraTS dataset. The same image/slice
# is copied to all 3 channels. We need 3 channels because we
# are using a pre-trained ResNet50 (or its variant).
IMG_SHAPE = (240, 240, 3)

#Preparing TrainX and TrainY
Loading HGG and LGG data stored in .npy files. Creating their labels: 0 for LGG and 1 for HGG. Finally, the data and the corresponding labels will be shuffled randomly.

In [None]:
#Number of HGG and LGG slices. HGG_cases is also the row offset at which
#read_LGG() starts writing into TrainX, so these values must match the first
#dimension of the arrays stored in the two .npy files.
HGG_cases = 19496
LGG_cases = 4926
Total_cases = HGG_cases + LGG_cases

In [None]:
#Create a zero-filled NumPy array for holding all the training data (TrainX):
#one 240 x 240 x 3 float16 image per case.
#With Total_cases = 24422 this is about 8.4 GB (24422 * 240 * 240 * 3 * 2 bytes).
TrainX = np.zeros((Total_cases, 240 , 240, 3), dtype=np.float16)

In [None]:
#Confirm the dtype of TrainX (float16 as requested at allocation)
print (TrainX.dtype)

float16


In [None]:
#function to load HGG cases stored in BraTS2020_Tumorous_HGG_T1_f16.npy
def read_HGG():
  """Load the HGG slices from disk into the first rows of TrainX.

  Reads BraTS2020_Tumorous_HGG_T1_f16.npy from below base_path, prints the
  shape and dtype of the loaded array, and copies slice i into TrainX[i],
  replicating the single-channel image across all three channels.

  Relies on the module-level names base_path and TrainX. Returns nothing;
  TrainX is modified in place.
  """
  HGG_data_one_channel = np.load(base_path + 'Datasets/BraTS2020/BraTS2020_Tumorous_HGG_T1_f16.npy')
  print (HGG_data_one_channel.shape)
  print (HGG_data_one_channel.dtype)

  n_slices = HGG_data_one_channel.shape[0]
  # Adding a trailing axis of length 1 lets NumPy broadcast each slice over
  # the 3 channels in a single assignment instead of a per-slice Python loop.
  TrainX[:n_slices] = HGG_data_one_channel[..., np.newaxis]

In [None]:
#function to load LGG cases stored in BraTS2020_Tumorous_LGG_T1_f16.npy
def read_LGG():
  """Load the LGG slices from disk into TrainX, right after the HGG rows.

  Reads BraTS2020_Tumorous_LGG_T1_f16.npy from below base_path, prints the
  shape and dtype of the loaded array, and copies slice i into
  TrainX[HGG_cases + i], replicating the single-channel image across all
  three channels.

  Relies on the module-level names base_path, HGG_cases and TrainX. Returns
  nothing; TrainX is modified in place.
  """
  LGG_data_one_channel = np.load(base_path + 'Datasets/BraTS2020/BraTS2020_Tumorous_LGG_T1_f16.npy')
  print (LGG_data_one_channel.shape)
  print (LGG_data_one_channel.dtype)

  n_slices = LGG_data_one_channel.shape[0]
  # Adding a trailing axis of length 1 lets NumPy broadcast each slice over
  # the 3 channels in a single assignment instead of a per-slice Python loop.
  TrainX[HGG_cases:HGG_cases + n_slices] = LGG_data_one_channel[..., np.newaxis]

In [None]:
#call the function read_HGG() to load HGG data into TrainX
#(rows 0 .. number of HGG slices - 1); it also prints the shape and dtype
#of the loaded array
read_HGG()

(19496, 240, 240)
float16


In [None]:
#call the function read_LGG() to load LGG data into TrainX
#(rows starting at HGG_cases); it also prints the shape and dtype
#of the loaded array
read_LGG()

(4926, 240, 240)
float16


In [None]:
#function to build the label column: 1 for HGG cases, 0 for LGG cases
def define_labels():
  """Return a (HGG_cases + LGG_cases, 1) uint8 array of class labels.

  The first HGG_cases rows hold 1 (HGG) and the remaining LGG_cases rows
  hold 0 (LGG), i.e. the same order in which the images are stored in TrainX.
  """
  class_values = np.array([1, 0], dtype='uint8')
  # Repeat 1 HGG_cases times, then 0 LGG_cases times, and make it a column
  flat_labels = np.repeat(class_values, [HGG_cases, LGG_cases])
  return flat_labels[:, np.newaxis]



In [None]:
#Call the function to create labels and store them in TrainY
#(one uint8 label per row of TrainX: 1 = HGG, 0 = LGG)
TrainY = define_labels()

In [None]:
#Print the shapes of TrainX and TrainY; their first dimensions must be equal
#(one label per image)
print (TrainX.shape)
print (TrainY.shape)

(24422, 240, 240, 3)
(24422, 1)


In [None]:
#Shuffle the data in TrainX and TrainY with one shared random permutation so
#that every image stays paired with its label.
#NOTE: the permutation is not seeded, so the order differs between runs.
#NOTE: indexing with an index array returns a copy, so TrainX is briefly
#held in memory twice while this cell runs.
p = np.random.permutation(TrainX.shape[0])
TrainX = TrainX[p]
TrainY = TrainY[p]



# Phase 3 of DAPT (Start for the First Time)
Run the following cells only when starting DAPT (phase 3) for a particular 
strategy for the first time. **DO NOT** run the following cells if you are
resuming phase 3 of DAPT after some epochs.

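#All the layers of the architecture, except those whose names contain "_bn" (the batch-normalization layers, which keep their current trainable setting), will be un-frozen in phase 3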

In [None]:
#Set the name of the strategy ("DA_L1SB_PFT" or "DA_L2SB_PFT").
#The name is used to build the checkpoint paths below.
Strategy = "DA_L1SB_PFT"

In [None]:
#Load the phase 2 model of the selected strategy, i.e. the checkpoint written
#after 100 epochs of phase 2 training; it is the starting point of phase 3.
phase_3_model = models.load_model(
    f"{base_path}DAPT/Checkpoints/{Strategy}/phase_2/checkpoint-100.h5"
)


In [None]:
#Un-freeze every layer whose name does not contain "_bn"; layers with "_bn"
#in their name are left untouched and keep their current trainable setting.
#Iterate over the layers directly: enumerate() would yield (index, layer)
#tuples, which have no .name attribute.
for layer in phase_3_model.layers:
  if "_bn" not in layer.name:
    layer.trainable = True

In [None]:
#Print the index, name and trainable flag of every layer to verify which
#layers are now trainable
for j, layer in enumerate(phase_3_model.layers):
  print (j, layer.name, layer.trainable)

In [None]:
#Since we have changed the trainable property of some layers, we need to 
#re-compile the model.
#NOTE(review): from_logits=True assumes the model's output layer emits raw
#logits (no sigmoid activation) - confirm against the phase 2 model.
#NOTE(review): BinaryAccuracy() is used with its default threshold; if the
#outputs really are logits, check that this threshold is the intended one.
phase_3_model.compile(
    optimizer=tf.keras.optimizers.Adam(1e-5),  # Low learning rate
    loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
    metrics=[tf.keras.metrics.BinaryAccuracy()],
)

In [None]:
#Directory that receives the phase 3 checkpoints of the selected strategy
checkpoint_path = f"{base_path}DAPT/Checkpoints/{Strategy}/phase_3"

#One callback only: save the model to checkpoint-<epoch>.h5 after every epoch
#({epoch} is filled in by Keras, so that part must stay a plain string)
callbacks = [ModelCheckpoint(checkpoint_path + '/checkpoint-{epoch}.h5')]

In [None]:
#Set total epochs, batch size and initial epoch number.
#initial_epoch_number is 0 because this section starts phase 3 from the
#beginning; all three values are passed to fit() in the next cell.
total_epochs = 50
batchSize=16
initial_epoch_number = 0

In [None]:
#Start DAPT phase 3 training. The checkpoint callback saves the model after
#every epoch; the returned training history is kept in `history`.
history = phase_3_model.fit(TrainX, TrainY,  batch_size=batchSize, 
                    epochs=total_epochs,
                    initial_epoch=initial_epoch_number,
                    verbose=2,
                    callbacks=callbacks)

# Phase 3 of DAPT (Resume Training)
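Run the following cells only when resuming DAPT (phase 3) for a particular 
strategy from a particular epoch number. **DO NOT** run the following cells if you are
starting phase 3 of DAPT from epoch no. 0. The cells below use `Strategy`, so in a fresh session first run the cell above that sets the name of the strategy (together with the setup and data-loading cells).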

#The checkpoint saved at the epoch given by `initial_epoch_number` (set it to the most recent saved epoch) will be loaded and training will be resumed from where it was interrupted.

In [None]:
#Directory that holds the phase 3 checkpoints of the selected strategy
checkpoint_path = f"{base_path}DAPT/Checkpoints/{Strategy}/phase_3"

#One callback only: save the model to checkpoint-<epoch>.h5 after every epoch
#({epoch} is filled in by Keras, so that part must stay a plain string)
callbacks = [ModelCheckpoint(checkpoint_path + '/checkpoint-{epoch}.h5')]

In [None]:
#Set total epochs and batch size (same values as when phase 3 is started
#for the first time)
total_epochs = 50
batchSize=16

In [None]:
#Epoch number from which training is resumed; it also selects the checkpoint
#file (checkpoint-<initial_epoch_number>.h5) that is loaded below.
initial_epoch_number = 20 

#Load the saved checkpoint from which training will resume
phase_3_model = models.load_model(f"{checkpoint_path}/checkpoint-{initial_epoch_number}.h5")

In [None]:
#Resume DAPT phase 3 training from initial_epoch_number, using the model
#loaded from the checkpoint; the checkpoint callback keeps saving the model
#after every epoch.
history = phase_3_model.fit(TrainX, TrainY,  batch_size=batchSize, 
                    epochs=total_epochs,
                    initial_epoch=initial_epoch_number,
                    verbose=2,
                    callbacks=callbacks)

#END OF PHASE 3 DAPT