# Model Reloading

Goal of this notebook is to double check that I can load Keras model and Keras model checkpoint. Ultimately want to continue training on model on 2020 data set.

In [42]:
import keras, glob, re, os, math
import numpy as np

import keras.backend as K
from keras.losses import mse
from keras.layers import Conv3D, Activation, Add, UpSampling3D, Lambda, Dense
from keras.layers import Input, Reshape, Flatten, Dropout, SpatialDropout3D
from keras.optimizers import Adam as adam
from keras.models import Model
import tensorflow as tf

import sys
sys.path.append('../auto_encoder/')
from train_model import preprocess, read_img, preprocess_label
from model import build_model  # For creating the model

import SimpleITK as sitk  # For loading the dataset


In [2]:
def dice_coefficient(y_true, y_pred):
    intersection = K.sum(K.abs(y_true * y_pred), axis=[-3,-2,-1])
    dn = K.sum(K.square(y_true) + K.square(y_pred), axis=[-3,-2,-1]) + 1e-8
    return K.mean(2 * intersection / dn, axis=[0,1])


def loss_gt(e=1e-8):
    """
    loss_gt(e=1e-8)
    ------------------------------------------------------
    Since keras does not allow custom loss functions to have arguments
    other than the true and predicted labels, this function acts as a wrapper
    that allows us to implement the custom loss used in the paper. This function
    only calculates - L<dice> term of the following equation. (i.e. GT Decoder part loss)

    L = - L<dice> + weight_L2 ∗ L<L2> + weight_KL ∗ L<KL>

    Parameters
    ----------
    `e`: Float, optional
        A small epsilon term to add in the denominator to avoid dividing by
        zero and possible gradient explosion.

    Returns
    -------
    loss_gt_(y_true, y_pred): A custom keras loss function
        This function takes as input the predicted and ground labels, uses them
        to calculate the dice loss.

    """
    def loss_gt_(y_true, y_pred):
        y_true_float = Lambda(lambda x: K.cast(x, 'float32'), name='change_to_float')(y_true)

        intersection = K.sum(K.abs(y_true_float * y_pred), axis=[-3,-2,-1])
        dn = K.sum(K.square(y_true_float) + K.square(y_pred), axis=[-3,-2,-1]) + e

        return - K.mean(2 * intersection / dn, axis=[0,1])

    return loss_gt_

# def loss_gt_(y_true, y_pred, e=1e-8):
#     y_true_float = Lambda(lambda x: K.cast(x, 'float32'), name='change_to_float')(y_true)

#     intersection = K.sum(K.abs(y_true_float * y_pred), axis=[-3,-2,-1])
#     dn = K.sum(K.square(y_true_float) + K.square(y_pred), axis=[-3,-2,-1]) + e

#     return - K.mean(2 * intersection / dn, axis=[0,1])

# return loss_gt_

def loss_VAE(input_shape, z_mean, z_var, weight_L2=0.1, weight_KL=0.1):
    """
    loss_VAE(input_shape, z_mean, z_var, weight_L2=0.1, weight_KL=0.1)
    ------------------------------------------------------
    Since keras does not allow custom loss functions to have arguments
    other than the true and predicted labels, this function acts as a wrapper
    that allows us to implement the custom loss used in the paper. This function
    calculates the following equation, except for -L<dice> term. (i.e. VAE decoder part loss)

    L = - L<dice> + weight_L2 ∗ L<L2> + weight_KL ∗ L<KL>

    Parameters
    ----------
     `input_shape`: A 4-tuple, required
        The shape of an image as the tuple (c, H, W, D), where c is
        the no. of channels; H, W and D is the height, width and depth of the
        input image, respectively.
    `z_mean`: An keras.layers.Layer instance, required
        The vector representing values of mean for the learned distribution
        in the VAE part. Used internally.
    `z_var`: An keras.layers.Layer instance, required
        The vector representing values of variance for the learned distribution
        in the VAE part. Used internally.
    `weight_L2`: A real number, optional
        The weight to be given to the L2 loss term in the loss function. Adjust to get best
        results for your task. Defaults to 0.1.
    `weight_KL`: A real number, optional
        The weight to be given to the KL loss term in the loss function. Adjust to get best
        results for your task. Defaults to 0.1.

    Returns
    -------
    loss_VAE_(y_true, y_pred): A custom keras loss function
        This function takes as input the predicted and ground labels, uses them
        to calculate the L2 and KL loss.

    """
    def loss_VAE_(y_true, y_pred):
        c, H, W, D = input_shape
        n = c * H * W * D

        loss_L2 = K.mean(K.square(y_true - y_pred), axis=(1, 2, 3, 4)) # original axis value is (1,2,3,4).

        loss_KL = (1 / n) * K.sum(
            K.exp(z_var) + K.square(z_mean) - 1. - z_var,
            axis=-1
        )

        return weight_L2 * loss_L2 + weight_KL * loss_KL

    return loss_VAE_

In [3]:
# path = '/home/ubuntu/model_ae_400_2020-11-03-0700/'
path = '/home/ubuntu/model_ae_3_2020-11-03-0635/'

model = keras.models.load_model(path,custom_objects={'dice_coefficient':dice_coefficient,'loss_gt_':loss_gt, 'loss_VAE_':loss_VAE})



In [13]:
input_shape = (4, 80, 96, 64)

model2 = build_model(input_shape=input_shape, output_channels=3)

In [21]:
path_checkpoints = '/home/ubuntu/checkpoints/ae_weights.400-0.00843.hdf5'
model2.load_weights(path_checkpoints)

In [60]:
# Import data
# Get a list of files for all modalities individually
end_index = 10
input_shape = (4, 80, 96, 64)
output_channels = 3

data = np.empty((len(data_paths[:end_index]),) + input_shape, dtype=np.float32) 
labels = np.empty((len(data_paths[:end_index]), output_channels) + input_shape[1:], dtype=np.uint8)

path = '/home/ubuntu/data/brats-data/MICCAI_BraTS_2018_Data_Training/'
t1 = glob.glob(os.path.join(path, '*GG/*/*t1.nii.gz'))
t2 = glob.glob(os.path.join(path, '*GG/*/*t2.nii.gz'))
flair = glob.glob(os.path.join(path, '*GG/*/*flair.nii.gz'))
t1ce = glob.glob(os.path.join(path, '*GG/*/*t1ce.nii.gz'))
seg = glob.glob(os.path.join(path, '*GG/*/*seg.nii.gz'))  # Ground Truth

pat = re.compile('.*_(\w*)\.nii\.gz')

data_paths = [{
    pat.findall(item)[0]: item
    for item in items
}
    for items in list(zip(t1, t2, t1ce, flair, seg))]

total = len(data_paths[:end_index])
step = 25 / total

for i, imgs in enumerate(data_paths[:2]):
    data[i] = np.array([preprocess(read_img(imgs[m]), input_shape[1:]) for m in ['t1', 't2', 't1ce', 'flair']],
                       dtype=np.float32)
    labels[i] = preprocess_label(read_img(imgs['seg']), input_shape[1:])[None, ...]

    if ~np.isfinite(data[i]).any() or ~np.isfinite(labels[i]).any():
        print('bad frame found:')
        print(data_paths[i])
        bad_frames.append(i)

    # Print the progress bar
    print('\r' + f'Progress: '
                 f"[{'=' * int((i + 1) * step) + ' ' * (24 - int((i + 1) * step))}]"
                 f"({math.ceil((i + 1) * 100 / (total))} %)",
          end='')

Progress: [=====                   ](20 %)

In [61]:
data.shape

(10, 4, 80, 96, 64)

In [51]:
pred = model2.predict(data)

In [56]:
np.unique(pred[0])

array([0., 1.], dtype=float32)

In [66]:
np.unique(labels[0])

array([0, 1], dtype=uint8)

In [68]:
labels[0][0].shape

(80, 96, 64)

In [72]:
np.unique(labels[:,1,1,1])

array([  0,  73, 115, 166, 190, 204, 205], dtype=uint8)

In [73]:
labels.shape

(10, 3, 80, 96, 64)

In [74]:
labels[0,:,0,0,0]

array([0, 0, 0], dtype=uint8)

In [75]:
data.shape

(10, 4, 80, 96, 64)

In [76]:
labels.shape

(10, 3, 80, 96, 64)

In [82]:
pred[0].shape

(2, 3, 80, 96, 64)