# Today you are an MLE@Samsung Research and your goal is to segment cystic regions in OCT images.
## This work is based on the recent publication https://arxiv.org/abs/2008.02952
## This model is adapted from the original codebase in https://github.com/sohiniroych/U-net_using_TF2

# Optical Coherence Tomography (OCT) images are grayscale images that show the depth structure of the retina. Cystic regions appear as gaps in the retina, as shown below:
<img src='https://drive.google.com/uc?id=1YRljOSUMEBLKBCSiU1TOAfwnoBcrV7LS' width="600">


## Your goal is to segment the cysts (dark gaps) in the images using the U-net model.

# Your Deliverables are as follows:
### 1. Train a U-net model from scratch and test its performance on test images from 2 OCT data sets.
### 2. Vary the loss function, kernel dilation, depthwise separability of the kernels, and report results.
### 3. Report observations with and without Batch normalization and Dropout at test time.
### 4. If you use Dropout at test time and generate 2-3 test predictions, what do you observe from these predictions? 

# Task 1: Construct U-net model from scratch for the 'cirrus_3' data set. Report performance on test set and save the model to disk.

### If using Colab, mount your Google Drive

In [None]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


### If you're running locally, especially with RTX-series GPUs, enabling GPU memory growth can be helpful. Otherwise, skip this cell.

In [None]:
#This code snippet helps if your computer has an RTX 2070 GPU. If not, comment out this cell.
from tensorflow.compat.v1 import ConfigProto
from tensorflow.compat.v1 import InteractiveSession

config = ConfigProto()
config.gpu_options.allow_growth = True
session = InteractiveSession(config=config)

## Let's start by defining, step by step, all the libraries and functions needed to build the model and pre-process the data

In [None]:
#Step 1: Load libraries for the U-net Model
import os
import numpy as np
import skimage.io as io
import tensorflow as tf
from tensorflow.keras.models import *
from tensorflow.keras.layers import *
from tensorflow.keras.optimizers import *
from tensorflow.keras.callbacks import ModelCheckpoint, LearningRateScheduler
# Note: `keras` below is the Keras *backend* module (keras.sum, keras.mean), not the full Keras API
from tensorflow.keras import backend as keras

In [None]:
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import skimage.io as io
import matplotlib.pyplot as plt

In [None]:
#Step 2: Define the U-net model
def unet(pretrained_weights = None,input_size = (256,256,1)):
    inputs = tf.keras.Input(shape=input_size)
    conv1 = Conv2D(64, 3, activation = 'relu',padding = 'same', kernel_initializer = 'he_normal')(inputs)
    conv1 = BatchNormalization()(conv1)
    conv1 = Conv2D(64, 3, activation = 'relu',padding = 'same', kernel_initializer = 'he_normal')(conv1)
    conv1 = BatchNormalization()(conv1)
    pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)
    conv2 = Conv2D(128, 3, activation = 'relu', dilation_rate=2,padding = 'same', kernel_initializer = 'he_normal')(pool1)
    conv2 = BatchNormalization()(conv2)
    conv2 = Conv2D(128, 3, activation = 'relu', dilation_rate=2, padding = 'same', kernel_initializer = 'he_normal')(conv2)
    conv2 = BatchNormalization()(conv2)
    pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)
    conv3 = Conv2D(256, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool2)
    conv3 = BatchNormalization()(conv3)
    conv3 = Conv2D(256, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv3)
    conv3 = BatchNormalization()(conv3)
    pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)
    conv4 = Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool3)
    conv4 = BatchNormalization()(conv4)
    conv4 = Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv4)
    conv4 = BatchNormalization()(conv4)
    drop4 = Dropout(0.5)(conv4)
    pool4 = MaxPooling2D(pool_size=(2, 2))(drop4)

    conv5 = Conv2D(1024, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool4)
    conv5 = BatchNormalization()(conv5)
    conv5 = Conv2D(1024, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv5)
    conv5 = BatchNormalization()(conv5)
    drop5 = Dropout(0.5)(conv5)

    up6 = Conv2D(512, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(drop5))
    merge6 = concatenate([drop4,up6], axis = 3)
    conv6 = Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge6)
    conv6 = Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv6)
    

    up7 = Conv2D(256, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv6))
    merge7 = concatenate([conv3,up7], axis = 3)
    conv7 = Conv2D(256, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge7)
    conv7 = Conv2D(256, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv7)
    

    up8 = Conv2D(128, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv7))
    merge8 = concatenate([conv2,up8], axis = 3)
    conv8 = Conv2D(128, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge8)
    conv8 = Conv2D(128, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv8)
    

    up9 = Conv2D(64, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv8))
    merge9 = concatenate([conv1,up9], axis = 3)
    conv9 = Conv2D(64, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge9)
    conv9 = Conv2D(64, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv9)
    conv9 = Conv2D(2, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv9)
   
    conv10 = Conv2D(1, 1, activation = 'sigmoid')(conv9)

    model = tf.keras.Model(inputs = inputs, outputs = conv10)

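    # `lr` is deprecated in favor of `learning_rate`; metrics are passed as a list
    model.compile(optimizer = Adam(learning_rate = 1e-4), loss = 'binary_crossentropy', metrics = ['accuracy'])

    # `keras` is the backend alias here, so load saved models through tf.keras.models
    if(pretrained_weights):
        model=tf.keras.models.load_model(pretrained_weights)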

    return model

In [None]:
# Change directory to wherever you've stored unet_helper_functions, for instance for my Colab:
import os
os.chdir('/content/drive/MyDrive/Live_session_notebooks/week_7/')  # change this for your system

#All additional functions for data prep and evaluation are housed in unet_helper_functions.py
from unet_helper_functions import *

## All definitions are now done! Let's start using the functions now...
## B. Call to image data generator, model initialization, followed by model fitting.

In [None]:
#Step 1: Call to image data generator in keras
data_gen_args = dict(rotation_range=0.2,
                    width_shift_range=0.05,
                    height_shift_range=0.05,
                    shear_range=0.05,
                    zoom_range=[0.7,1],
                    horizontal_flip=True,
                    fill_mode='nearest')
PATH='/content/drive/MyDrive/Datasets/week_7/Data/cirrus_3/'  # give the path to where you've stored and decompressed Data.zip, the cirrus_3 subdirectory

In [None]:
data_gen = trainGenerator(10,PATH+'train/','Image','GT',data_gen_args)

############### 
# If you want to view the augmented training images you can run these three lines instead of the one above
###############
# if not os.path.exists(PATH +'train/aug'):
#     os.makedirs(PATH+'train/aug')
# data_gen = trainGenerator(10,PATH+'train/','Image','GT',data_gen_args, save_to_dir = PATH+'train/aug')

In [None]:
#Step 2: Initialize the model. We're going to train it from scratch!
model = unet()
model.summary()

Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            [(None, 256, 256, 1) 0                                            
__________________________________________________________________________________________________
conv2d (Conv2D)                 (None, 256, 256, 64) 640         input_1[0][0]                    
__________________________________________________________________________________________________
batch_normalization (BatchNorma (None, 256, 256, 64) 256         conv2d[0][0]                     
__________________________________________________________________________________________________
conv2d_1 (Conv2D)               (None, 256, 256, 64) 36928       batch_normalization[0][0]        
______________________________________________________________________________________________

  "The `lr` argument is deprecated, use `learning_rate` instead.")


In [None]:
##### Comment this cell out if you have any issues with tensorboard

#Step 3: Initialize Tensorboard to monitor changes in Model Loss 
import datetime
%load_ext tensorboard
log_dir = "logs/fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)

In [None]:
# If you used the TensorBoard callback, launch TensorBoard to view the logs:
# If developing locally: `tensorboard --logdir <THIS_DIRECTORY>/logs`
# If developing on Colab: `%tensorboard --logdir logs`
%load_ext tensorboard
%tensorboard --logdir logs

In [None]:
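#Step 4: Fit the u-net model

# Save the best version of the model as `unet_cirrus3_V1.hdf5`.
# No validation data is passed to fit(), so the checkpoint monitors the training loss.
model_checkpoint_callback = ModelCheckpoint('unet_cirrus3_V1.hdf5', monitor='loss', verbose=1, save_best_only=True)
model.fit(data_gen,steps_per_epoch=15,epochs=50,verbose=1,callbacks=[model_checkpoint_callback, tensorboard_callback])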




## C. Run the trained model on test images and save the outputs, and evaluate pixel-level segmentation performance 

In [None]:
# Step 1: create a directory to store predicted segmentations
if not os.path.exists(PATH+'test/pred_V1'):
    os.makedirs(PATH+'test/pred_V1')

#Step 2: Run model on test images and save the images
#number of test images
n_i=len(os.listdir(PATH+'test/Image/'))
#Call test generator
test_gen = testGenerator(PATH+'test/Image/')
#Return model outcome for each test image
results = model.predict(test_gen,steps=n_i,verbose=1)
#If dropout is activated for test data, then calling this function multiple times will generate different outputs!
saveResult(PATH+'test/Image/', PATH+'test/pred_V1/',results)

In [None]:
#Step 3: Evaluate the predicted outcome
gt_path=PATH+'test/GT/'
evalResult(gt_path,results)
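
The pixel-level metrics requested in the results table (precision, recall, IoU, accuracy, F1) can all be derived from the confusion counts between a ground-truth mask and a thresholded prediction. The following is a minimal NumPy sketch of that arithmetic, not the implementation inside `evalResult` (which lives in the helper file); the function name `segmentation_metrics` and the 0.5 threshold are assumptions made for illustration.

```python
import numpy as np

def segmentation_metrics(gt, pred, threshold=0.5):
    """Pixel-level metrics for one binary mask pair (hypothetical helper)."""
    gt_bin = np.asarray(gt) > threshold      # ground-truth cyst pixels
    pred_bin = np.asarray(pred) > threshold  # predicted cyst pixels

    tp = int(np.sum(gt_bin & pred_bin))      # cyst predicted as cyst
    fp = int(np.sum(~gt_bin & pred_bin))     # background predicted as cyst
    fn = int(np.sum(gt_bin & ~pred_bin))     # cyst predicted as background
    tn = int(np.sum(~gt_bin & ~pred_bin))    # background predicted as background

    def safe_div(num, den):
        # Return 0.0 instead of dividing by zero (e.g. an empty mask)
        return num / den if den else 0.0

    return {
        "precision": safe_div(tp, tp + fp),
        "recall": safe_div(tp, tp + fn),
        "iou": safe_div(tp, tp + fp + fn),
        "accuracy": safe_div(tp + tn, tp + fp + fn + tn),
        "f1": safe_div(2 * tp, 2 * tp + fp + fn),
    }

# Toy 4x4 example: 4 true cyst pixels, 6 predicted, all 4 true pixels covered
gt_mask = np.zeros((4, 4))
gt_mask[1:3, 1:3] = 1.0
pred_mask = np.zeros((4, 4))
pred_mask[1:3, 1:4] = 0.9
metrics = segmentation_metrics(gt_mask, pred_mask)
print(metrics)
```

Because cysts usually cover a small fraction of the image, accuracy tends to look high even for poor masks, so IoU and F1 are generally the more informative columns.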

# Task 2: Make some modifications to the model (kernels and loss function)

* First create a version of the model `unet_mod` which uses dilated kernels (first try a dilation rate of 2 then you can experiment with other values). Train this model with the dice coefficient loss, while tracking the dice coefficient metric.

Save the best version of this model (using the checkpoint callback as we did above) as `unet_cirrus3_V2.hdf5`
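
To build intuition for what `dilation_rate` changes: a dilated kernel keeps the same number of weights but spreads them over a wider window. The sketch below computes the span of a dilated kernel using the standard relation k_eff = k + (k - 1)(d - 1); the helper name `effective_kernel_size` is made up for this illustration.

```python
def effective_kernel_size(kernel_size: int, dilation_rate: int) -> int:
    """Side length (in pixels) of the window covered by a dilated kernel."""
    return kernel_size + (kernel_size - 1) * (dilation_rate - 1)

# A 3x3 kernel always has 9 weights, but its footprint grows with the dilation rate
for rate in (1, 2, 4):
    k_eff = effective_kernel_size(3, rate)
    print(f"dilation_rate={rate}: 3x3 kernel spans {k_eff}x{k_eff} pixels")
```

With `padding='same'`, the output size stays unchanged for any dilation rate, so dilated layers can be swapped into the U-net without affecting the skip connections.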

In [None]:
#Define Additional loss functions for this task
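def dice_coef(y_true, y_pred, smooth=1):
    # Sum over height, width and channels so each image gets its own score
    intersection = keras.sum(y_true * y_pred, axis=[1,2,3])
    # Note: despite its name, `union` is the sum of both mask areas (|A| + |B|), as the Dice formula requires
    union = keras.sum(y_true, axis=[1,2,3]) + keras.sum(y_pred, axis=[1,2,3])
    # `smooth` avoids division by zero when both masks are empty; average over the batch
    return keras.mean( (2. * intersection + smooth) / (union + smooth), axis=0)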

def dice_coef_loss(y_true, y_pred):
    return -dice_coef(y_true, y_pred)
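
As a sanity check on the Dice definition above, here is a NumPy mirror of the same formula applied to a tiny hand-checkable mask pair. This is only an illustrative sketch; `dice_coef_np` is a name introduced here, and the Keras version above is what should be passed to `model.compile`.

```python
import numpy as np

def dice_coef_np(y_true, y_pred, smooth=1.0):
    """NumPy mirror of dice_coef for arrays shaped (batch, height, width, channels)."""
    y_true = np.asarray(y_true, dtype=np.float64)
    y_pred = np.asarray(y_pred, dtype=np.float64)
    intersection = np.sum(y_true * y_pred, axis=(1, 2, 3))
    # Sum of both mask areas (|A| + |B|), not the set union
    total = np.sum(y_true, axis=(1, 2, 3)) + np.sum(y_pred, axis=(1, 2, 3))
    return float(np.mean((2.0 * intersection + smooth) / (total + smooth)))

# Ground truth: 4 cyst pixels. Prediction: 6 pixels that cover all 4 true pixels.
y_true = np.zeros((1, 4, 4, 1))
y_true[0, 1:3, 1:3, 0] = 1.0
y_pred = np.zeros((1, 4, 4, 1))
y_pred[0, 1:3, 1:4, 0] = 1.0

dice = dice_coef_np(y_true, y_pred)  # (2*4 + 1) / (4 + 6 + 1) = 9/11
print(f"Dice = {dice:.3f}")
```

A perfect prediction gives a Dice of 1, so `dice_coef_loss` reaches its minimum of -1 there; loss values in the training log are therefore expected to be negative.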

In [None]:
#Modify U-net Definition
def unet_mod(pretrained_weights = None,input_size = (256,256,1)):
    ### PUT YOUR MODIFIED VERSION OF UNET HERE ####
    ###############################################

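    # `keras` is the backend alias, so load through tf.keras.models;
    # the custom dice loss/metric must be supplied to deserialize the saved model
    if(pretrained_weights):
        model=tf.keras.models.load_model(pretrained_weights, custom_objects={'dice_coef_loss': dice_coef_loss, 'dice_coef': dice_coef})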

    return model

In [None]:
# Fit the model
model_mod=unet_mod()
## YOUR CODE TO CHECKPOINT AND FIT THE MODEL HERE ##


In [None]:
def saveResult_raw(img_path,save_path,npyfile,flag_multi_class = False,num_class = 2):
    files=os.listdir(img_path)
    
    for i,item in enumerate(npyfile):
        img = labelVisualize(num_class,COLOR_DICT,item) if flag_multi_class else item[:,:,0]
        io.imsave(os.path.join(save_path, files[i]+'_predict.png'),img)

In [None]:
if not os.path.exists(PATH+'test/pred_V2'):
    os.makedirs(PATH+'test/pred_V2')

n_i=len(os.listdir(PATH + 'test/Image/'))
test_gen = testGenerator(PATH+'test/Image/')
results_mod = model_mod.predict(test_gen,steps=n_i,verbose=1)
saveResult_raw(PATH+'test/Image/', PATH+'test/pred_V2/',results_mod)


In [None]:
gt_path=PATH+'test/GT/'
evalResult(gt_path,results_mod)

* Now try making some of the convolution layers DepthwiseConv2D instead of Conv2D. Keep the dice coefficient as the metric and loss. Save a checkpoint of this model as `unet_cirrus3_V3.hdf5`
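
To see why depthwise kernels shrink the model (the "Size" column in the table below), it helps to count parameters by hand. The sketch below uses the standard parameter formulas for Keras `Conv2D`, `DepthwiseConv2D` and `SeparableConv2D` layers with bias; the helper function names are invented for this illustration.

```python
def conv2d_params(k, c_in, c_out, use_bias=True):
    """Standard convolution: one k x k x c_in filter per output channel."""
    return k * k * c_in * c_out + (c_out if use_bias else 0)

def depthwise_conv2d_params(k, c_in, depth_multiplier=1, use_bias=True):
    """Depthwise convolution: one k x k filter per input channel, no channel mixing."""
    c_out = c_in * depth_multiplier
    return k * k * c_in * depth_multiplier + (c_out if use_bias else 0)

def separable_conv2d_params(k, c_in, c_out, use_bias=True):
    """Depthwise k x k filters followed by a 1x1 pointwise convolution."""
    return k * k * c_in + c_in * c_out + (c_out if use_bias else 0)

# These two values match the first Conv2D rows of model.summary() above
print(conv2d_params(3, 1, 64))             # 640
print(conv2d_params(3, 64, 64))            # 36928

# The same 64 -> 64 layer with depthwise / separable kernels
print(depthwise_conv2d_params(3, 64))      # 640
print(separable_conv2d_params(3, 64, 64))  # 4736
```

Note that `DepthwiseConv2D` takes no `filters` argument: its output has `c_in * depth_multiplier` channels. Where the U-net changes the channel count (for example 64 to 128), a depthwise layer needs a following 1x1 `Conv2D`, or `SeparableConv2D` can be used instead.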

In [None]:
def unet_depth(pretrained_weights = None,input_size = (256,256,1)):
    ### PUT YOUR MODIFIED VERSION OF UNET HERE ####
    ###############################################

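    # The custom dice loss/metric must be supplied to deserialize the saved model
    if(pretrained_weights):
        model=tf.keras.models.load_model(pretrained_weights, custom_objects={'dice_coef_loss': dice_coef_loss, 'dice_coef': dice_coef})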

    return model

In [None]:
# fit the model
model_depth=unet_depth()
## YOUR CODE TO CHECKPOINT AND FIT THE MODEL HERE ##

In [None]:
if not os.path.exists(PATH+'test/pred_V3'):
    os.makedirs(PATH+'test/pred_V3')

n_i=len(os.listdir(PATH+'test/Image/'))
test_gen = testGenerator(PATH+'test/Image/')
results_depth = model_depth.predict(test_gen,steps=n_i,verbose=1)
saveResult_raw(PATH+'test/Image/',PATH+'test/pred_V3/',results_depth)


In [None]:
gt_path=PATH+'test/GT/'
evalResult(gt_path,results_depth)

## Select the best network parameters for semantic segmentation here and save the best model as unet_cirrus3.hdf5! Enter metrics for the three versions of your model into the table below:

|U-net Parameters  (cirrus_3)          | Precision|Recall|IoU   |acc   |F1    | Size |
|------|-------|---------|-------------|----------|--------------|--------|
|binary cross entropy loss    |    **   |  **   |  **  | ** | **  | ** (MB) |
|dilated kernels, dice coef|     **  | ** | ** | ** | ** | ** (MB) |
|depthwise separable kernels, dice coef|   **    |  **  |  **  | ** | ** | ** (MB) |

# Task 3: Perform transfer learning with each of the `unet_cirrus3_Vx.hdf5` models as the base weights and retrain (fine-tune) on the 'nidek1' data set. Report the same table as above for the 'nidek1' test data.

# Task 4: Report test performance on Cirrus3 and Nidek1 for the following:
## A. Remove the BatchNormalization commands.
## B. Activate dropout on test data (enable training=True) and create 2 cyst masks for each test image. Comment on the overlap between the cyst masks per image. What do you learn here?
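
One way to quantify the overlap asked for in part B is to compare the masks from repeated stochastic forward passes. The sketch below assumes the probability maps for one image have already been collected, for instance by calling the Keras model as `model(x, training=True)` several times; the function name `mask_agreement`, the 0.5 threshold and the toy values are assumptions for illustration only.

```python
import numpy as np

def mask_agreement(prob_maps, threshold=0.5):
    """Compare stochastic predictions for one image.

    prob_maps: array-like of shape (n_passes, height, width) with sigmoid outputs.
    Returns the IoU between the thresholded masks and the per-pixel std map.
    """
    probs = np.asarray(prob_maps, dtype=np.float64)
    masks = probs > threshold

    # Pixels marked as cyst in every pass vs. in at least one pass
    intersection = np.logical_and.reduce(masks, axis=0).sum()
    union = np.logical_or.reduce(masks, axis=0).sum()
    iou = intersection / union if union else 1.0  # all-empty masks agree fully

    # High standard deviation marks pixels where the passes disagree
    pixel_std = probs.std(axis=0)
    return float(iou), pixel_std

# Two toy 2x2 passes that disagree only on the top-right pixel
pass_a = [[0.9, 0.8], [0.1, 0.2]]
pass_b = [[0.9, 0.4], [0.1, 0.2]]
iou, pixel_std = mask_agreement([pass_a, pass_b])
print(f"IoU between passes: {iou:.2f}")
print(pixel_std)
```

Pixels with a large standard deviation are where the dropout passes disagree, which is typically along cyst boundaries; a low IoU between passes suggests the model is uncertain about that image.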