
# SETUP FOR THE KAGGLE ENVIRONMENT

In [8]:
# Define the paths.
# Assumes the notebook runs on Kaggle: inputs are read from ../input and
# results are written to /kaggle/working (the notebook's working directory).
dataset_path = '../input/neuroengineering-project/Data/Data'
model_path = '../input/neuroengineering-project/Data/Model'
# Was '..output/kaggle/working' (missing '/'), a non-existent relative folder
# that made saving the trained model fail.
output_path = '/kaggle/working'

# MODULES TO BE IMPORTED

In [9]:
import os
import sys
import time
import gc
import pickle
import math
from random import shuffle

import nibabel as nib
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow_addons as tfa

# CONSTANTS DEFINITION

In [10]:
# Unzip the files Model.zip and Data.zip (only needed when the data is zipped)
#!unzip 'Data.zip'
#!unzip 'Model.zip'

# ---- Folders to load data from ------------------------------------------------
DATA_PATH = dataset_path
MODEL_PATH = model_path

# ---- Input (low-resolution) volume size: rows, columns, slices -----------------
N_ROWS_VOLUME, N_COLUMNS_VOLUME, N_SLICES_VOLUME = 128, 128, 64
# Noise level that appears in the input file names. If your input volumes
# carry no noise, drop this and use the noise-free VOLUME_TEMPLATE below.
NOISE = 0.0001

# ---- Label (high-resolution) volume size: rows, columns, slices ----------------
N_ROWS_LABEL, N_COLUMNS_LABEL, N_SLICES_LABEL = 256, 256, 64

# ---- Patch sizes: rows, columns, slices ----------------------------------------
IN_PATCH_SIZE = (16, 16, 16)
OUT_PATCH_SIZE = (32, 32, 16)

# ---- File type -----------------------------------------------------------------
VOLUME_TYPE = 'nii'

# File-name templates. The '%s' placeholder is left in place and filled later
# with the zero-padded case id.
# Input volumes with noise:
VOLUME_TEMPLATE = (f"{DATA_PATH}/VolumeCT_%s_{N_ROWS_VOLUME}_{N_COLUMNS_VOLUME}_"
                   f"{N_SLICES_VOLUME}_n{str(NOISE)}.{VOLUME_TYPE}")
# Input volumes without noise (use instead of the template above):
#VOLUME_TEMPLATE = (f"{DATA_PATH}/VolumeCT_%s_{N_ROWS_VOLUME}_{N_COLUMNS_VOLUME}_"
#                   f"{N_SLICES_VOLUME}.{VOLUME_TYPE}")

LABEL_TEMPLATE = (f"{DATA_PATH}/VolumeCT_%s_{N_ROWS_LABEL}_{N_COLUMNS_LABEL}_"
                  f"{N_SLICES_LABEL}.{VOLUME_TYPE}")

# ---- Data split ----------------------------------------------------------------
OVERALL_NUMBER_OF_CASES = 58
# Fractions of the available cases used for training / validation;
# the test set receives whatever remains.
TRAINING_PERC_CASES = 0.80
VALIDATION_PERC_CASES = 0.10
TEST_PERC_CASES = 1 - TRAINING_PERC_CASES - VALIDATION_PERC_CASES

# ---- Training hyper-parameters -------------------------------------------------
ModelID = 'DensNet_2'
MAX_EPOCHS = 100       # maximum number of training epochs
BATCH_SIZE = 4         # number of patches per training batch
LEARNING_RATE = 0.001  # Adam learning rate

# PATCHING AND MERGING

In [11]:
# Patching grid -> list contains locations of patch centers

def make_patch_grid(patch_size, N_ROWS_VOLUME, N_COLUMNS_VOLUME, N_SLICES_VOLUME, overlap_bool=False, overlap = 1/2):
    dim_patch = patch_size
    if overlap_bool is False:
        overlaps = (0, 0, 0)
        center_dist = tuple(s//2 for s in dim_patch) 
        num_patch_dim1 = N_ROWS_VOLUME//dim_patch[0]
        num_patch_dim2 = N_COLUMNS_VOLUME//dim_patch[1]
        num_patch_dim3 = N_SLICES_VOLUME//dim_patch[2]
        
        patches=[]
        for i in range(num_patch_dim1):
            for j in range(num_patch_dim2):
                for k in range(num_patch_dim3):    
                    patch=(i*dim_patch[0]+center_dist[0], 
                           j*dim_patch[1]+center_dist[1], 
                           k*dim_patch[2]+center_dist[2])
                    patches.append(patch)   
    else:
        # For example the overlap is 1/4 of the patch size     
        overlaps = tuple(int(o*overlap) for o in dim_patch)
        center_dist = tuple(d-o for d,o in zip(dim_patch, overlaps))
        num_patch_dim1 = (N_ROWS_VOLUME-dim_patch[0])//center_dist[0] + 1
        num_patch_dim2 = (N_COLUMNS_VOLUME-dim_patch[1])//center_dist[1] + 1
        num_patch_dim3 = (N_SLICES_VOLUME-dim_patch[2])//center_dist[2] + 1
        
        patches=[]
        for i in range(num_patch_dim1):
            for j in range(num_patch_dim2):
                for k in range(num_patch_dim3):    
                    patch=((i+1)*dim_patch[0]//2+center_dist[0]-overlaps[0], 
                          (j+1)*dim_patch[1]//2+center_dist[1]-overlaps[1], 
                          (k+1)*dim_patch[2]//2+center_dist[2]-overlaps[2])
                    patches.append(patch)

    num_patches = len(patches)

    return patches

# Non-overlapping patch-centre grids for the input volume and the label volume.
input_patches = make_patch_grid(IN_PATCH_SIZE, N_ROWS_VOLUME, N_COLUMNS_VOLUME,
                                N_SLICES_VOLUME, overlap_bool=False)
output_patches = make_patch_grid(OUT_PATCH_SIZE, N_ROWS_LABEL, N_COLUMNS_LABEL,
                                 N_SLICES_LABEL, overlap_bool=False)

# Each input patch is expected to pair with one label patch, so these two
# counts should agree.
for grid in (input_patches, output_patches):
    print(len(grid))


In [12]:
def create_patches(image, patch_list, patch_size):
    """Cut ``image`` into patches centred on the points of ``patch_list``.

    Along each of the first three axes a patch spans ``[c - s//2, c + s//2)``,
    where ``c`` is the centre coordinate and ``s`` the matching entry of
    ``patch_size``. Patches are returned as a list, in ``patch_list`` order
    (basic slices, i.e. views when ``image`` is a NumPy array).
    """
    half = tuple(size // 2 for size in patch_size)

    def _window(centre):
        return tuple(slice(int(centre[axis]) - half[axis], int(centre[axis]) + half[axis])
                     for axis in range(3))

    return [image[_window(centre)] for centre in patch_list]


def create_3D_image(patches, patch_size, image_size, patch_list, inout, overlap_bool=False, overlap = 1/2):
    """Merge patches back into a single 3-D volume.

    Patch ``n`` is written to ``[c - s//2, c + s//2)`` along each axis, where
    ``c = patch_list[n]`` and ``s`` is the matching entry of ``patch_size``.
    Voxels not covered by any patch stay 0.

    Args:
        patches: sequence of patches, indexed in the same order as
            ``patch_list``.
        patch_size: (rows, columns, slices) size of one patch.
        image_size: shape of the merged volume.
        patch_list: patch centres, e.g. from ``make_patch_grid``.
        inout: if truthy, every patch carries a trailing channel axis and
            channel 0 is used.
        overlap_bool: if falsy, each patch is copied into place (later
            patches overwrite earlier ones). If truthy, overlapping patches
            are averaged voxel by voxel.
        overlap: kept for backward compatibility. The overlap is implied by
            ``patch_list``, so this value is not needed.

    Returns:
        float64 ndarray of shape ``image_size``.

    The previous overlap branch modified the caller's patches in place (``+=``
    on a view), blended each patch only with the next list entry on one
    diagonal octant, and ignored ``inout``. It is replaced by
    accumulate-and-average.
    """
    step = tuple(int(s // 2) for s in patch_size)
    image = np.zeros(image_size)

    def _region(centre):
        # Slices covering the patch centred at `centre`.
        return tuple(slice(int(centre[a]) - step[a], int(centre[a]) + step[a])
                     for a in range(3))

    def _voxels(idx):
        # Patch `idx` without its channel axis when inout is set.
        patch = patches[idx]
        return patch[:, :, :, 0] if inout else patch

    if not overlap_bool:
        for idx, centre in enumerate(patch_list):
            image[_region(centre)] = _voxels(idx)
        return image

    # Sum every contribution and divide by the number of patches covering
    # each voxel. New arrays are written; the input patches are not modified.
    counts = np.zeros(image_size)
    for idx, centre in enumerate(patch_list):
        region = _region(centre)
        image[region] += _voxels(idx)
        counts[region] += 1
    covered = counts > 0
    image[covered] /= counts[covered]
    return image

# Data Preprocessing & Augmentation

## Removing mainly black patches

In [13]:
def remove_black_patches(patches, intensity_threshold=0.1):
    """Drop patches that are mostly background.

    A voxel counts as "black" when its value is below ``intensity_threshold``.
    A patch is kept only if it has fewer than ``num_voxels // 5 * 4``
    (about 80%) black voxels, where ``num_voxels`` is the size of the first
    patch.

    Args:
        patches: sequence (or array along axis 0) of equally shaped patches.
        intensity_threshold: value below which a voxel is treated as black.

    Returns:
        (non_black_patches, non_black_idx): the kept patches and their
        positions in ``patches``, so matching labels can be selected with
        the same indices. Empty input gives ``([], [])``.
    """
    if len(patches) == 0:
        return [], []

    num_voxels = np.asarray(patches[0]).size
    # Integer arithmetic kept as before: floor-divide first, then scale.
    max_black_voxels = num_voxels // 5 * 4

    non_black_patches = []
    non_black_idx = []
    for i, patch in enumerate(patches):
        if np.sum(np.asarray(patch) < intensity_threshold) < max_black_voxels:
            non_black_patches.append(patch)
            non_black_idx.append(i)
    return non_black_patches, non_black_idx

# DATA LOADING

In [14]:
# Read available data: collect the (volume, label) file pairs present on disk.
AVAILABLE_NUMBER_OF_CASES = 0
# Release arrays left over from a previous run of this cell before
# re-allocating. Each name is removed independently, so every existing array
# is freed and no unrelated error is silenced (previously a bare `except:`).
for _stale_name in ('trainVolumes', 'trainLabels', 'validationVolumes',
                    'validationLabels', 'testVolumes', 'testLabels'):
    globals().pop(_stale_name, None)
gc.collect()
volumes_list = []
labels_list = []

for index_case in range(1, OVERALL_NUMBER_OF_CASES + 1):
    case_id = "{:0>3}".format(index_case)   # zero-padded, e.g. 7 -> '007'
    volume_path = VOLUME_TEMPLATE % (case_id)
    label_path = LABEL_TEMPLATE % (case_id)
    if os.path.exists(volume_path) and os.path.exists(label_path):
        AVAILABLE_NUMBER_OF_CASES += 1
        volumes_list.append(volume_path)
        labels_list.append(label_path)
    else:
        print("Not found")
        print(volume_path)
        print(label_path)

# SPLIT TRAINING, VALIDATION AND TEST SETS
# Whole cases are assigned to one split. The *_NUMBER_OF_CASES values count
# patches: every case contributes PATCHES_PER_CASE patches.
PATCHES_PER_CASE = len(input_patches)
TRAINING_NUMBER_OF_CASES   = int(AVAILABLE_NUMBER_OF_CASES * TRAINING_PERC_CASES) * PATCHES_PER_CASE
VALIDATION_NUMBER_OF_CASES = int(AVAILABLE_NUMBER_OF_CASES * VALIDATION_PERC_CASES) * PATCHES_PER_CASE
TEST_NUMBER_OF_CASES       = (AVAILABLE_NUMBER_OF_CASES * PATCHES_PER_CASE
                              - TRAINING_NUMBER_OF_CASES - VALIDATION_NUMBER_OF_CASES)
print("Number of cases for training: " + str(TRAINING_NUMBER_OF_CASES))
print("Number of cases for validation: " + str(VALIDATION_NUMBER_OF_CASES))
print("Number of cases for testing: " + str(TEST_NUMBER_OF_CASES))

# Pre-allocate the patch arrays. Label patches use OUT_PATCH_SIZE on every
# axis (the slice axis previously used IN_PATCH_SIZE[2], which only worked
# because both sizes happen to be 16).
# Training set
trainVolumes = np.empty((TRAINING_NUMBER_OF_CASES,) + tuple(IN_PATCH_SIZE))
trainLabels = np.empty((TRAINING_NUMBER_OF_CASES,) + tuple(OUT_PATCH_SIZE))
# Validation set
validationVolumes = np.empty((VALIDATION_NUMBER_OF_CASES,) + tuple(IN_PATCH_SIZE))
validationLabels = np.empty((VALIDATION_NUMBER_OF_CASES,) + tuple(OUT_PATCH_SIZE))
# Test set
testVolumes = np.empty((TEST_NUMBER_OF_CASES,) + tuple(IN_PATCH_SIZE))
testLabels = np.empty((TEST_NUMBER_OF_CASES,) + tuple(OUT_PATCH_SIZE))

# Grids of patch centres, derived from the configured patch sizes.
grid_volumes = make_patch_grid(IN_PATCH_SIZE, N_ROWS_VOLUME, N_COLUMNS_VOLUME, N_SLICES_VOLUME, overlap_bool=False, overlap = 1/2)
grid_labels = make_patch_grid(OUT_PATCH_SIZE, N_ROWS_LABEL, N_COLUMNS_LABEL, N_SLICES_LABEL, overlap_bool=False, overlap = 1/2)
print('Number of patches in one image: ', len(grid_volumes))
# Input patch n must pair with label patch n.
assert len(grid_volumes) == len(grid_labels) == PATCHES_PER_CASE, \
    "input and label grids must contain the same number of patches"

countTraining   = 0
countValidation = 0
countTest       = 0
for volume, label in zip(volumes_list, labels_list):
    # Fill the training set first, then validation, then test.
    if countTraining < TRAINING_NUMBER_OF_CASES:
        volumes, labels, index = trainVolumes, trainLabels, countTraining
        countTraining += PATCHES_PER_CASE
    elif countValidation < VALIDATION_NUMBER_OF_CASES:
        volumes, labels, index = validationVolumes, validationLabels, countValidation
        countValidation += PATCHES_PER_CASE
    else:
        volumes, labels, index = testVolumes, testLabels, countTest
        countTest += PATCHES_PER_CASE

    # Loading label -> getting patches
    temp = np.asarray(nib.load(label).get_fdata())
    patches = create_patches(temp, grid_labels, OUT_PATCH_SIZE)
    labels[index:index + len(patches), :, :, :] = patches

    # Loading input -> getting patches
    temp = np.asarray(nib.load(volume).get_fdata())
    patches = create_patches(temp, grid_volumes, IN_PATCH_SIZE)
    volumes[index:index + len(patches), :, :, :] = patches


# Add a trailing channel axis: the model input is (rows, columns, slices, 1).
# (No explicit shuffling here; batches are shuffled at training time.)
trainVolumes = trainVolumes.reshape(trainVolumes.shape + (1,))
validationVolumes = validationVolumes.reshape(validationVolumes.shape + (1,))
testVolumes = testVolumes.reshape(testVolumes.shape + (1,))

remove_black = False
if remove_black:
    # Remove mostly black patches. All patches must be passed in, because
    # the returned indices also select the labels. (The previous [1:-1]
    # slice shifted every index by one, pairing each kept volume patch with
    # the label of the preceding patch.)
    trainVolumes_non_black, idx_vol = remove_black_patches(trainVolumes[:, :, :, :, 0])
    trainLabels = np.array(trainLabels[idx_vol, :, :, :])

    validationVolumes_non_black, idx_vol = remove_black_patches(validationVolumes[:, :, :, :, 0])
    validationLabels = np.array(validationLabels[idx_vol, :, :, :])

    trainVolumes = np.array(trainVolumes_non_black)
    trainVolumes = trainVolumes.reshape(trainVolumes.shape + (1,))
    validationVolumes = np.array(validationVolumes_non_black)
    validationVolumes = validationVolumes.reshape(validationVolumes.shape + (1,))

    print('-----------------------------------------------------------')
    print("Number of cases for training after removing black patches: " + str(trainVolumes.shape[0]))
    print("Number of cases for validation after removing black patches: " + str(validationVolumes.shape[0]))
    print("Number of cases for testing after removing black patches: " + str(TEST_NUMBER_OF_CASES))


In [15]:
# Check splitting and merging functions on one image
# Split/merge round-trip check: read the first test volume from disk,
# merge its stored patches back together, and compare the two.
input_patches = make_patch_grid(IN_PATCH_SIZE,
                                N_ROWS_VOLUME,
                                N_COLUMNS_VOLUME,
                                N_SLICES_VOLUME,
                                overlap_bool = False)
ONE_IMAGE_N_PATCHES = len(input_patches)
VOLUME_SHAPE = (N_ROWS_VOLUME, N_COLUMNS_VOLUME, N_SLICES_VOLUME)
assert TEST_NUMBER_OF_CASES > 0, "no test case available for the split/merge check"

# Cases were assigned train -> validation -> test in file order, so the first
# test case follows every training and validation case in volumes_list.
first_test_case = (TRAINING_NUMBER_OF_CASES + VALIDATION_NUMBER_OF_CASES) // ONE_IMAGE_N_PATCHES
image_temp = np.asarray(nib.load(volumes_list[first_test_case]).get_fdata())
print(image_temp.shape)

image_patches = testVolumes[:ONE_IMAGE_N_PATCHES, :, :, :, :]
image_2 = create_3D_image(image_patches, IN_PATCH_SIZE, VOLUME_SHAPE, input_patches,
                          inout=True, overlap_bool=False, overlap = 1/2)
print(image_2.shape)
if image_temp.shape == image_2.shape:
    print('Max absolute split/merge difference: {}'.format(np.max(np.abs(image_temp - image_2))))

mid_slice = N_SLICES_VOLUME // 2
fig, axis = plt.subplots(nrows=1, ncols=2, figsize=(10,18))
axis[0].imshow(image_temp[:, :, mid_slice], cmap=plt.get_cmap('gray'))
axis[0].set_title('Before Splitting Image')
axis[1].imshow(image_2[:, :, mid_slice], cmap=plt.get_cmap('gray'))
axis[1].set_title('After Merging Image')
plt.tight_layout()


# MODEL ARCHITECTURE

## Define the class for padding = 'reflect'

In [18]:
#from keras.engine.topology import Layer
#from keras.engine import InputSpec

#class ReflectionPadding3D(Layer):
#    def __init__(self, padding=(1, 1, 1), **kwargs):
#        self.padding = tuple(padding)
 #       self.input_spec = [InputSpec(ndim=4)]
#        super(ReflectionPadding3D, self).__init__(**kwargs)

#    def get_output_shape_for(self, s):
#        """ If you are using "channels_last" configuration"""
#        return (s[0], s[1] + 2 * self.padding[0], s[2] + 2 * self.padding[1], s[3] + 2 * self.padding[2], s[4])

 #   def call(self, x, mask=None):
 #       w_pad,h_pad = self.padding
 #       return tf.pad(x, [[0,0,0], [h_pad,h_pad,h_pad], [w_pad,w_pad,w_pad], [0,0,0] ], 'REFLECT')

In [17]:
# Initialize input size
# Input: one single-channel low-resolution patch, shape IN_PATCH_SIZE + (1,)
input_tensor = tf.keras.layers.Input(shape=(IN_PATCH_SIZE[0],
                                            IN_PATCH_SIZE[1],
                                            IN_PATCH_SIZE[2],
                                            1))
# Number of filters produced by each block convolution (the first conv uses 2*k)
k = 48

#--------------------------------------------------------------------------------------------
#----------------------------Create Model----------------------------------------------------
#--------------------------------------------------------------------------------------------
# All Conv3D layers below use 3x3x3 kernels, stride 1 and 'same' padding, so the
# spatial size stays IN_PATCH_SIZE until the UpSampling3D layers.
# NOTE(review): layers are auto-named in creation order; reordering these
# statements changes layer names and weight initialisation.

conv1_out = tf.keras.layers.Conv3D(filters = 2*k, kernel_size=(3,3,3),
                                   strides = (1,1,1), padding='same')(input_tensor)

# Each feature map is also upsampled x2 along rows and columns only (slices
# unchanged): with the current constants (16,16,16) -> (32,32,16) = OUT_PATCH_SIZE.
# All upsampled maps are concatenated at the output.
up6_1 = tf.keras.layers.UpSampling3D(size=(2,2,1))(conv1_out)

print('Output shape after conv1: {}'.format(conv1_out.shape))
print('Output upsamp shape after conv1: {}'.format(up6_1.shape))

#----------------------------Block 1---------------------------------------------------------
# Each block: LayerNormalization (default axis=-1, the channel axis) -> ELU -> Conv3D(k).

block1_out = tf.keras.layers.LayerNormalization()(conv1_out)
#block1_out = tfa.layers.SpectralNormalization()(conv1_out)
block1_out = tf.keras.layers.ELU()(block1_out)
block1_out = tf.keras.layers.Conv3D(filters = k, kernel_size=(3,3,3),
                                    strides = (1,1,1), padding='same')(block1_out)
up6_2 = tf.keras.layers.UpSampling3D(size=(2,2,1))(block1_out)

#----------------------------Block 2---------------------------------------------------------
# Block 2 takes only block1_out; conv1_out is not concatenated into any block input.

#block2_out = tfa.layers.SpectralNormalization()(block1_out)
block2_out = tf.keras.layers.LayerNormalization()(block1_out)
block2_out = tf.keras.layers.ELU()(block2_out)
block2_out = tf.keras.layers.Conv3D(filters = k, kernel_size=(3,3,3),
                                    strides = (1,1,1), padding='same')(block2_out)
up6_3 = tf.keras.layers.UpSampling3D(size=(2,2,1))(block2_out)
print('Output shape after block2: {}'.format(block2_out.shape))

#----------------------------Block 3---------------------------------------------------------
# Dense connectivity: blocks 3-5 take the channel-wise (axis 4, channels-last)
# concatenation of all previous block outputs.
block3_input = tf.keras.layers.Concatenate(axis=4)([block1_out, block2_out])

#block3_out = tfa.layers.SpectralNormalization()(block3_input)
block3_out = tf.keras.layers.LayerNormalization()(block3_input)
block3_out = tf.keras.layers.ELU()(block3_out)
block3_out = tf.keras.layers.Conv3D(filters = k, kernel_size=(3,3,3),
                                    strides = (1,1,1), padding='same')(block3_out)
up6_4 = tf.keras.layers.UpSampling3D(size=(2,2,1))(block3_out)
print('Output shape after block3: {}'.format(block3_out.shape))

#----------------------------Block 4---------------------------------------------------------

block4_input = tf.keras.layers.Concatenate(axis=4)([block1_out, block2_out, block3_out])

#block4_out = tfa.layers.SpectralNormalization()(block4_input)
block4_out = tf.keras.layers.LayerNormalization()(block4_input)
block4_out = tf.keras.layers.ELU()(block4_out)
block4_out = tf.keras.layers.Conv3D(filters = k, kernel_size=(3,3,3),
                                    strides = (1,1,1), padding='same')(block4_out)
up6_5 = tf.keras.layers.UpSampling3D(size=(2,2,1))(block4_out)
print('Output shape after block4: {}'.format(block4_out.shape))

#----------------------------Block 5---------------------------------------------------------

block5_input = tf.keras.layers.Concatenate(axis=4)([block1_out, block2_out,
                                                    block3_out, block4_out])
#block5_out = tfa.layers.SpectralNormalization()(block5_input)
block5_out = tf.keras.layers.LayerNormalization()(block5_input)
block5_out = tf.keras.layers.ELU()(block5_out)
block5_out = tf.keras.layers.Conv3D(filters = k, kernel_size=(3,3,3),
                                    strides = (1,1,1), padding='same')(block5_out)
up6_6 = tf.keras.layers.UpSampling3D(size=(2,2,1))(block5_out)

# ----------------------------Output---------------------------------------------------------
# Concatenate all six upsampled maps (2k channels from up6_1 + k from each of
# up6_2..up6_6 = 7k channels) and project to a single output channel.

last_conv_input = tf.keras.layers.Concatenate(axis=4)([up6_6, up6_5, up6_4,
                                                       up6_3, up6_2, up6_1])
out = tf.keras.layers.Conv3D(filters = 1, kernel_size=(3,3,3),
                             strides = (1,1,1), padding='same')(last_conv_input)

# NOTE(review): with the current constants `out` already has shape
# OUT_PATCH_SIZE + (1,), so this Reshape appears to act only as a shape check.
output_tensor = tf.keras.layers.Reshape((OUT_PATCH_SIZE[0],
                                         OUT_PATCH_SIZE[1],
                                         OUT_PATCH_SIZE[2],
                                         1))(out)

my_model = tf.keras.Model(inputs = [input_tensor],
                          outputs = [output_tensor])
my_model.summary()
tf.keras.utils.plot_model(my_model)

# PSNR & SSIM

In [None]:
from tensorflow.keras import backend as K
from tensorflow.image import psnr

def PSNR(y_true, y_pred):
    """Peak signal-to-noise ratio of ``y_pred`` against ``y_true``.

    ``max_val`` is fixed at 1.0, so intensities are assumed to lie in
    [0, 1] (TODO confirm against the data normalisation).
    """
    peak_value = 1.0
    return psnr(y_true, y_pred, peak_value)

def SSIM(y_true, y_pred):
    """Negated mean structural similarity (``max_val`` 1.0, 5-wide filter).

    The sign is flipped so that lower values mean more similar, which makes
    the function usable directly as a loss term.
    """
    per_image_ssim = tf.image.ssim(y_true, y_pred, 1.0, filter_size = 5)
    mean_ssim = tf.reduce_mean(per_image_ssim)
    return -mean_ssim

def MAE_and_SSIM_loss(y_true, y_pred):
    """Average of the mean absolute error and the SSIM dissimilarity.

    ``SSIM`` (defined above) returns the *negated* mean SSIM, so
    ``1 + SSIM(...)`` equals ``1 - mean_ssim``. That term is 0 for identical
    volumes and grows as they differ. The previous ``abs(mae - SSIM)/2``
    equalled ``|mae + mean_ssim| / 2``, which is minimised by *lowering*
    structural similarity.
    """
    y_true = tf.cast(y_true, y_pred.dtype)
    mae = tf.reduce_mean(tf.abs(y_true - y_pred))
    ssim_dissimilarity = 1.0 + SSIM(y_true, y_pred)
    return (mae + ssim_dissimilarity) / 2

def PSNR_and_MAE_loss(y_true, y_pred):
    """Average of the mean absolute error and a PSNR penalty scaled to [0, 1].

    PSNR is averaged, clipped to [-psnr_limit, psnr_limit] and mapped linearly
    to [0, 1]. The penalty ``1 - psnr_scaled`` falls as PSNR rises, so
    minimising the loss improves both terms.

    Fixes compared with the previous version:
    - Python ``if`` on tensors, which fails for a non-scalar PSNR in graph mode.
    - The lower clip set the value to +limit instead of -limit.
    - A high PSNR was penalised.
    - Element-wise MAE was added to per-image PSNR of a different shape.
    - ``tf.constant`` on a computed tensor cut the gradient.
    """
    y_true = tf.cast(y_true, y_pred.dtype)

    # Since PSNR can be (-inf, +inf) we clip it and scale it to [0, 1].
    psnr_limit = tf.constant(80.0, dtype=y_pred.dtype)
    psnr_val = tf.reduce_mean(PSNR(y_true, y_pred))
    psnr_val = tf.clip_by_value(psnr_val, -psnr_limit, psnr_limit)
    psnr_scaled = (psnr_val + psnr_limit) / (2 * psnr_limit)

    mae = tf.reduce_mean(tf.abs(y_true - y_pred))
    return (mae + (1.0 - psnr_scaled)) / 2


# LOSS FUNCTION

Useful links for loss functions:
* https://towardsdatascience.com/deep-learning-image-enhancement-insights-on-loss-function-engineering-f57ccbb585d7#:~:text=Peak%20signal%2Dto%2Dnoise%20ratio%20definition%20(PSNR)%20is,created%20by%20compressing%20the%20image.
* https://stackoverflow.com/questions/49404309/how-does-keras-handle-multiple-losses

In [None]:
# Loss terms: voxel-wise mean squared error and the negated mean SSIM
# (SSIM defined above); their relative weights are set at compile time.
my_loss = [
    tf.keras.losses.MeanSquaredError(),
    SSIM,
]
# Alternative single objective:
#my_loss = MAE_and_SSIM_loss

# CUSTOM METRICS

Choose the metrics for your model from tf.keras.metrics or write your own custom metrics here

In [None]:
# Metrics reported every epoch. Note that SSIM (defined above) is the
# *negated* mean SSIM, so lower values mean more similar volumes.
my_metrics = [
    tf.keras.metrics.MeanSquaredError(),
    PSNR,
    SSIM,
    tf.keras.metrics.MeanAbsoluteError(),
]

# CUSTOM CALLBACKS

Choose the callbacks of interest (e.g. Tensorboard or ModelCheckpoint) and append them to my_callbacks_list

**EarlyStopping** callback is used very often. It monitors a chosen metric and stops training when that metric stops improving. For example, to stop training once the validation accuracy has not improved by at least 0.05 for a few epochs, configure it with `monitor='val_accuracy'`, `min_delta=0.05` and a suitable `patience`. This is useful in preventing overfitting of a model.

**ModelCheckpoint** callback allows us to save the model regularly during training. This is especially useful when training deep learning models which take a long time to train. This callback monitors the training and saves model checkpoints at regular intervals, based on the metric.




In [None]:
from datetime import datetime

def make_callbacks(model_name):
    """Create the experiment folders and the training callbacks.

    Folders: ``experiments/<model_name>_<timestamp>/ckpts`` and
    ``experiments/<model_name>_<timestamp>/tb_logs``.

    Returns, in this order:
      * EarlyStopping on ``val_loss`` (patience 10, best weights restored),
      * ModelCheckpoint saving the full model into ``ckpts`` every epoch,
      * TensorBoard logging into ``tb_logs`` with ``update_freq=1``.
    """
    def ensure_dir(path):
        # Create the folder only when nothing exists at that path yet.
        if not os.path.exists(path):
            os.makedirs(path)
        return path

    stamp = datetime.now().strftime('%b%d_%H-%M-%S')
    root = ensure_dir('experiments')
    run_dir = ensure_dir(os.path.join(root, model_name + '_' + str(stamp)))

    early_stop = tf.keras.callbacks.EarlyStopping(monitor='val_loss',
                                                  patience=10,
                                                  restore_best_weights = True)

    ckpt_dir = ensure_dir(os.path.join(run_dir, 'ckpts'))
    checkpoint = tf.keras.callbacks.ModelCheckpoint(ckpt_dir,
                                                    monitor='val_loss',
                                                    save_best_only=False,
                                                    save_weights_only=False,
                                                    mode='auto',
                                                    save_freq='epoch')

    tb_dir = ensure_dir(os.path.join(run_dir, 'tb_logs'))
    tensorboard = tf.keras.callbacks.TensorBoard(log_dir=tb_dir, update_freq=1)

    return [early_stop, checkpoint, tensorboard]



# MODEL TRAINING

In [None]:
# Set model optimizer and compile.
# The model has a single output, and Keras maps a *list* of losses and
# loss_weights one-to-one onto model outputs, so passing two losses for one
# output does not add them up. Build the weighted sum explicitly instead.
LOSS_WEIGHTS = [2, 1]
if isinstance(my_loss, (list, tuple)):
    assert len(my_loss) == len(LOSS_WEIGHTS), "need one weight per loss term"
    _loss_terms = list(zip(my_loss, LOSS_WEIGHTS))
else:
    _loss_terms = [(my_loss, 1)]

def weighted_total_loss(y_true, y_pred):
    """Weighted sum of the configured loss terms, each reduced to a scalar."""
    return tf.add_n([w * tf.reduce_mean(fn(y_true, y_pred)) for fn, w in _loss_terms])

my_model.compile(
    optimizer = tf.keras.optimizers.Adam(learning_rate = LEARNING_RATE),
    loss      = weighted_total_loss,
    metrics   = my_metrics
    )

# Create the folder for callbacks
ModelID = 'Model_3D_DensNet_Patches_'
my_callbacks = make_callbacks(model_name=ModelID)

# Run Model Training
monitoring = my_model.fit(x = trainVolumes,
                          y = trainLabels,
                          batch_size = BATCH_SIZE,
                          shuffle = True,
                          epochs = MAX_EPOCHS,
                          validation_data = (validationVolumes,
                                             validationLabels),
                          callbacks = my_callbacks)

In [None]:
# Save the net
# The target folder must exist: h5py cannot create a file inside a missing
# directory.
os.makedirs(output_path, exist_ok=True)
my_model.save(os.path.join(output_path,
                           f"model_{ModelID}.h5"))

# MODEL EVALUATION

In [None]:
PLOT = True
# Feed the test volumes with their channel axis, as for training and for
# evaluate() below. The model input is (rows, columns, slices, 1);
# previously the channel axis was stripped here.
prediction = my_model.predict(testVolumes)

print(my_model.evaluate(x = testVolumes, y = testLabels))
          
# just visualizing slice by slice:
#if PLOT:
#    for case in range(len(testVolumes)):
 #       for i in range(0, N_SLICES_VOLUME):
#            fig = plt.figure(figsize = [9, 3])
#            plt.subplot(1, 3, 1)
#            plt.imshow(testVolumes[case, :, :, i, 0], cmap = 'gray')
#            plt.subplot(1, 3, 2)
#            plt.imshow(testLabels[case, :, :, i], cmap = 'gray')
#            plt.subplot(1, 3, 3)
#            plt.imshow(prediction[case, :, :, i], cmap = 'gray')
#            plt.show(fig)

In [None]:
# If you want to load a pre-trained model
# Set this parameter 'True' only if you want to load another model, otherwise leave it 'False'
LOAD = False
#MODEL_NAME =       # The name of the model to load

if LOAD is True:
    model_to_load = os.path.join(output_path, f"model_{ModelID}.h5")
    # compile=False: load architecture and weights only; the model is
    # compiled again just below.
    my_model = tf.keras.models.load_model(model_to_load, compile = False)

    # Previously referenced `keras` (never imported) and an undefined
    # `custom_learning_rate_program`. Use the same optimizer as training.
    # NOTE(review): `my_loss` may be a list of losses; for this single-output
    # model only a single combined loss is applied as intended.
    my_model.compile(
        optimizer = tf.keras.optimizers.Adam(learning_rate = LEARNING_RATE),
        loss    = my_loss,
        metrics = my_metrics,
        )

    print("Model loaded!")

# MERGING Neural Network Outputs

In [None]:
image_idx = 1                       # 1-based index of the test image to display
num_patches = len(input_patches)    # patches per image
VOLUME_SHAPE = (N_ROWS_VOLUME, N_COLUMNS_VOLUME, N_SLICES_VOLUME)
LABEL_SHAPE = (N_ROWS_LABEL, N_COLUMNS_LABEL, N_SLICES_LABEL)

# Patches of test image `image_idx` occupy rows [start, stop) of the test
# arrays. (Slicing from 0 used to select image 1 regardless of image_idx.)
start = (image_idx - 1) * num_patches
stop = image_idx * num_patches
assert stop <= len(testVolumes), "image_idx exceeds the number of test images"

testVol = testVolumes[start:stop, :, :, :, :]
testLab = testLabels[start:stop, :, :, :]
# Drop only the trailing channel axis. The previous prediction[..., :, :, 0]
# indexed the slice axis and broadcast slice 0 over the whole patch depth.
predVol = prediction[start:stop, :, :, :, 0]

# Merge patches
image_in = create_3D_image(testVol, IN_PATCH_SIZE, VOLUME_SHAPE, input_patches, inout=True)
image_out = create_3D_image(predVol, OUT_PATCH_SIZE, LABEL_SHAPE, output_patches, inout=False)
image_true = create_3D_image(testLab, OUT_PATCH_SIZE, LABEL_SHAPE, output_patches, inout=False)

fig, axis = plt.subplots(nrows=1, ncols=3, figsize=(13,20))
axis[0].imshow(image_in[:, :, N_SLICES_VOLUME // 2], cmap=plt.get_cmap('gray'))
axis[0].set_title('Input Image')
axis[1].imshow(image_out[:, :, N_SLICES_LABEL // 2], cmap=plt.get_cmap('gray'))
axis[1].set_title('Output Image')
axis[2].imshow(image_true[:, :, N_SLICES_LABEL // 2], cmap=plt.get_cmap('gray'))
axis[2].set_title('True Image')
plt.tight_layout()

now = datetime.now().strftime('%b%d_%H-%M-%S')
plt.savefig('foo' + now + '.png')

# Quantitative comparison on the whole merged volume
from skimage.metrics import mean_squared_error, peak_signal_noise_ratio, structural_similarity

# skimage expects (reference, test). data_range=1.0 matches max_val=1.0 used by
# the PSNR/SSIM training metrics; recent scikit-image versions require an
# explicit data_range for floating-point images.
mse_ = mean_squared_error(image_true, image_out)
psnr_ = peak_signal_noise_ratio(image_true, image_out, data_range=1.0)
ssim_ = structural_similarity(image_true, image_out, data_range=1.0)
print('MSE: {}, PSNR: {}, SSIM: {}'.format(mse_,psnr_,ssim_))