In [None]:
%load_ext autoreload
%autoreload 2

### Import libraries

In [None]:
import sys, os
import numpy as np
import glob 
import random
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import cv2
from IPython import display
# import imageio
import PIL
import time
import tensorflow as tf

import tensorflow_addons as tfa
from tensorflow import keras
from tensorflow.keras import layers

from tensorflow.keras import backend as K
from tensorflow.keras.losses import mse, binary_crossentropy
from datetime import datetime
from tqdm import tqdm

In [None]:
# Fail fast when the installed TensorFlow is not the version this notebook was written for.
# (The statement must start at column 0; a leading space makes the cell raise IndentationError.)
assert tf.__version__ == '2.4.1', "TF version is not matching! Make sure you have tf 2.4.1-gpu installed!"

In [None]:
# Enable GPU memory growth so TensorFlow does not allocate all GPU memory at start-up.
# NOTE: per the TensorFlow GPU guide, memory growth has to be configured before the
# GPUs are initialized, so keep this cell ahead of any TensorFlow computation.
gpus = tf.config.experimental.list_physical_devices(device_type='GPU')
for gpu in gpus:
    tf.config.experimental.set_memory_growth(device=gpu, enable=True)

In [None]:
# custom functions 
sys.path.append('../')
from src.dataloader.cvae_loader import debug_batch_of_data, get_training_tfdata
from src.models.cvae_model import CCVAE

### Define losses


In [None]:
class FeatureMatchingLoss(keras.losses.Loss):
    """Sum of mean-absolute-error terms over indexed slices of ``y_true`` / ``y_pred``.

    NOTE(review): the loop stops at ``len(y_true) - 1``, so the last element is
    never compared. In this notebook the loss is called from ``vae_loss`` with
    the input and reconstructed images (presumably batched along axis 0), which
    would mean the last image of every batch does not contribute and a batch of
    one yields 0 — confirm this is intended before changing it, since the loss
    weights were tuned with this behaviour.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Single MAE object reused for every slice comparison.
        self.mae = keras.losses.MeanAbsoluteError()

    def call(self, y_true, y_pred):
        """Return the accumulated MAE over elements ``0 .. len(y_true) - 2``."""
        loss = 0
        for i in range(len(y_true) - 1):
            # One MAE term per pair of matching slices.
            loss += self.mae(y_true[i], y_pred[i])
        return loss


class VGGFeatureMatchingLoss(keras.losses.Loss):
    """Perceptual loss: weighted MAE between VGG19 feature maps of two image batches.

    Five ``block*_conv1`` activations of an ImageNet-initialised VGG19 are
    compared; the per-layer weights grow from 1/32 to 1 with depth.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.encoder_layers = [
            "block1_conv1",
            "block2_conv1",
            "block3_conv1",
            "block4_conv1",
            "block5_conv1",
        ]
        # One weight per entry of ``encoder_layers`` (same order).
        self.weights = [1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0]
        backbone = keras.applications.VGG19(include_top=False, weights="imagenet")
        feature_maps = [
            backbone.get_layer(layer_name).output for layer_name in self.encoder_layers
        ]
        # Multi-output model: one forward pass yields all selected activations.
        self.vgg_model = keras.Model(backbone.input, feature_maps, name="VGG")
        self.mae = keras.losses.MeanAbsoluteError()

    def call(self, y_true, y_pred):
        """Return the weighted sum of per-layer MAE between VGG19 features.

        Both inputs are rescaled with ``127.5 * (x + 1)`` before the VGG19
        preprocessing, i.e. they are assumed to lie in [-1, 1] — TODO confirm
        against the data loader.
        """
        preprocess = keras.applications.vgg19.preprocess_input
        real_features = self.vgg_model(preprocess(127.5 * (y_true + 1)))
        fake_features = self.vgg_model(preprocess(127.5 * (y_pred + 1)))
        loss = 0
        for layer_weight, real_map, fake_map in zip(
            self.weights, real_features, fake_features
        ):
            loss += layer_weight * self.mae(real_map, fake_map)
        return loss

# KL-Divergence loss
def kl_divergence_loss(mean, variance):
    """Return ``-0.5 * sum(1 + variance - mean**2 - exp(variance))`` over all elements.

    ``variance`` is exponentiated inside the formula, so despite its name it is
    used as a log-variance here.
    """
    kl_terms = 1 + variance - tf.square(mean) - tf.exp(variance)
    return -0.5 * tf.reduce_sum(kl_terms)
# MSE loss object, created once at module level and shared by reconstruction_loss.
MSE = tf.keras.losses.MeanSquaredError()


def reconstruction_loss(y_true, y_pred):
    """Return the mean squared error between ``y_true`` and ``y_pred`` via the shared ``MSE`` object."""
    squared_error = MSE(y_true, y_pred)
    return squared_error

# Perceptual loss
# (constructing it builds the VGG19 feature extractor with ImageNet weights, so it is created once here)
vgg_loss = VGGFeatureMatchingLoss()

# Feature matching loss
# (used by vae_loss as an MAE term between the input and the reconstructed images)
feature_matching_loss = FeatureMatchingLoss()

def vae_loss(inputs, outputs, z_mean, z_log_var, image_size=256):
    """Combine the perceptual, feature-matching and KL terms into one scalar loss.

    Args:
        inputs: original images, passed as ``y_true`` to both image losses.
        outputs: reconstructed images, passed as ``y_pred`` to both image losses.
        z_mean: latent mean produced by the encoder.
        z_log_var: latent log-variance produced by the encoder.
        image_size: accepted for interface compatibility; not used in the computation.

    Returns:
        ``mean(25 * vgg + 10 * feature_matching + 0.1 * kl)``, where ``kl`` is
        summed over the last axis of the latent tensors.
    """
    beta = 1.0
    # KL divergence to the unit Gaussian, reduced over the last (latent) axis.
    kl_term = K.sum(1 + z_log_var - K.square(z_mean) - K.exp(z_log_var), axis=-1)
    kl_term = kl_term * (-0.5 * beta)
    perceptual_term = vgg_loss(inputs, outputs)
    matching_term = feature_matching_loss(inputs, outputs)
    # Weights 25 / 10 / 0.1: best result so far.
    return K.mean(25*perceptual_term + 10*matching_term + 0.1*kl_term)

In [None]:
# utility function to compute loss
def compute_loss_v3(model, x_image, x_cond_labels):
    """Run encode -> reparameterize -> decode on one batch and return ``vae_loss`` for it.

    Args:
        model: object exposing ``encode``, ``reparameterize`` and ``decode``.
        x_image: batch of input images.
        x_cond_labels: conditioning input fed to both the encoder and the decoder.
    """
    z_mean, z_log_var = model.encode([x_image, x_cond_labels])
    latent = model.reparameterize(z_mean, z_log_var)
    reconstruction = model.decode([latent, x_cond_labels])
    return vae_loss(x_image, reconstruction, z_mean, z_log_var, image_size=256)

### Training hyperparameters

In [None]:
# hyperparameters
# NOTE: the training loop iterates over range(0, epochs + 1), i.e. epoch indices 0..epochs.
epochs = 200
# dimensionality of the latent space
# NOTE(review): latent_dim is not passed to CCVAE(image_shape) in the cells shown here —
# confirm the model's own latent size matches this value.
latent_dim = 128
image_shape = (256,256,3)
# NOTE(review): num_examples_to_generate is not referenced by the cells shown here.
num_examples_to_generate = 8
batch_size = 4

### Tf function for custom training loop

In [None]:
# Optimizer
optimizer = tf.keras.optimizers.Adam(1e-3)


# tf function for custom training loop
@tf.function
def train_step(model, x_image, x_cond_labels, optimizer):
    """Executes one training step and returns the loss.

    The loss is computed with ``compute_loss_v3`` under a gradient tape, and the
    resulting gradients are applied to ``model.trainable_variables``.

    Args:
        model: model whose trainable variables are updated.
        x_image: batch of input images.
        x_cond_labels: conditioning input for the batch.
        optimizer: optimizer used to apply the gradients.

    Returns:
        The loss tensor of this step (previously the function returned nothing,
        contrary to its docstring; callers that ignore the result are unaffected).
    """
    with tf.GradientTape() as tape:
        loss = compute_loss_v3(model, x_image, x_cond_labels)
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    return loss

### Read training set of CNV cases 
In this notebook, we will train a model for CNV class only. Hence, while loading the training set paths, we will filter only CNV cases. 


In [None]:
# Index of the training images; later cells use its 'label' and 'path' columns.
train_csv_path =  "/data/oct_train_filtered.csv" 
train_df = pd.read_csv(train_csv_path)
# Preview the first rows as the cell output.
train_df.head()

In [None]:
# Keep only the rows labelled CNV; this notebook trains a single-class model.
train_df = train_df.loc[train_df['label'] == 'CNV']
# Build the batched tf.data training set from the filtered frame.
tf_train_set = get_training_tfdata(train_df, batch_size=batch_size)

### Create output directory under `model_registry` to save training artifacts

In [None]:
# Run directory named after the start time (second resolution) plus the experiment tag.
LOG_DIR = os.path.join('/model_registry/output',datetime.now().strftime("%Y%m%d-%H%M%S")+'_cvae_subpixel_kl_reco_feat_loss_CNV')
# exist_ok=True replaces the separate exists() check and makes the cell safe to re-run.
os.makedirs(LOG_DIR, exist_ok=True)

### Subsample a test batch to visualize the generation while training

In [None]:
# Take one batch from the training set and keep it as the fixed sample that is
# reconstructed and plotted after every epoch.
for test_batch in tf_train_set.take(1):
    test_sample = test_batch

### Utility function to save generated images over the test batch while training

In [None]:
def generate_and_save_images(model, epoch, test_sample,apply_sigmo=False):
    """Reconstruct a batch with the current model and save a real/prediction grid.

    Args:
        model: object exposing ``encode``, ``reparameterize`` and ``decode``.
        epoch: integer used (zero-padded to 4 digits) in the output file name.
        test_sample: pair ``(images, condition_labels)``.
        apply_sigmo: forwarded to ``model.decode`` as ``apply_sigmoid``; also
            selects the ``_sigmoid`` / ``_nosigmoid`` file-name suffix.

    The figure is written to ``LOG_DIR`` and then closed; nothing is returned.
    """
    x_images_test, x_cond_labels_test = test_sample
    mean, logvar = model.encode([x_images_test, x_cond_labels_test])
    z = model.reparameterize(mean, logvar)
    predictions = model.decode([z, x_cond_labels_test],apply_sigmoid=apply_sigmo)

    num_samples = predictions.shape[0]
    n_cols = 4
    # Two panels (real, prediction) per sample. At least 4 rows keeps the layout
    # identical to the former fixed 4x4 grid for up to 8 samples; larger batches
    # get extra rows instead of failing on an out-of-range subplot index.
    n_rows = max(4, (2 * num_samples + n_cols - 1) // n_cols)

    fig = plt.figure(figsize=(20, 20))
    for i in range(num_samples):
        # NOTE(review): (x + 1) / 2 assumes values in [-1, 1]; confirm this is
        # still the right rescale for predictions when apply_sigmo=True.
        panels = (
            (x_images_test[i, :, :, :], "Real"),
            (predictions[i, :, :, :], "Predictions"),
        )
        for offset, (image, title) in enumerate(panels, start=1):
            ax = fig.add_subplot(n_rows, n_cols, i*2 + offset)
            ax.imshow((image + 1) / 2)
            ax.set_title(title)
            ax.axis('off')

    suffix = 'sigmoid' if apply_sigmo else 'nosigmoid'
    fig.savefig(LOG_DIR +'/image_at_epoch_{:04d}_{}.png'.format(epoch, suffix))
    # Close the figure so one figure per epoch does not accumulate in memory.
    plt.close(fig)

### Define CVAE model to train


In [None]:
# Build the conditional VAE. Only the image shape is passed here, so any other
# setting (e.g. the latent size) presumably comes from CCVAE's own defaults — TODO confirm.
model = CCVAE(image_shape)

### Train the model 

In [None]:
# Main training loop. It covers epoch indices 0..epochs inclusive, and every
# 20th index (including 0) writes encoder/decoder weights to LOG_DIR.
# NOTE: generate_and_save_images must still be the plotting helper
# (model, epoch, test_sample); a later cell redefines that name with a different
# signature, so re-run the helper cell before re-running this loop.
for epoch in range(0, epochs + 1):
    start_time = time.time()
    for train_x in tqdm(tf_train_set,"running training"):
        tmp_x_images, tmp_x_labels = train_x
        train_step(model, tmp_x_images, tmp_x_labels, optimizer)
    end_time = time.time()

    # Replace the previous epoch's progress output, then report this epoch's duration.
    display.clear_output(wait=False)
    print('Epoch {} finished in {:.1f}s'.format(epoch, end_time - start_time))
    generate_and_save_images(model, epoch, test_sample)
    # save model
    if epoch % 20==0:
        filepath_encoder=os.path.join(LOG_DIR,"model_encoder_"+str(epoch)+".h5")
        filepath_decoder=os.path.join(LOG_DIR,"model_decoder_"+str(epoch)+".h5")
        model.encoder.save_weights(filepath_encoder)
        model.decoder.save_weights(filepath_decoder)

### Post-training data generation sample (over training set)

This section demonstrates how to generate images with the trained model. The same procedure can be applied to the test set or to any OCT sample for which embeddings are available.

In [None]:
# Directory with the training-set embeddings (extracted in contrastive_model_training).
# The generation loop looks up "<image basename>.npy" inside this directory.
EMBEDDING_PATH = "/data/processed/contrastive_learning/train_embeddings"


In [None]:
# Output folders for the original ("real") images and the corresponding generated ("pred") images.
SAVE_PATH = "/data/processed/cvae_train/train_embeddings/CNV/"
SAVE_DIR_REAL = os.path.join(SAVE_PATH,"real")
SAVE_DIR_PRED = os.path.join(SAVE_PATH,"pred")
# Create both leaf folders unconditionally (makedirs also creates SAVE_PATH itself).
# Guarding on SAVE_PATH alone skipped the sub-folders whenever SAVE_PATH already
# existed, leaving the later image writes without a target directory.
os.makedirs(SAVE_DIR_REAL,exist_ok=True)
os.makedirs(SAVE_DIR_PRED,exist_ok=True)

In [None]:
from src.dataloader.contrastive_learning_loader import _denorm
from src.utils.cvae_utils import read_image_3_channel, load_np_embedding_feature

In [None]:
def generate_and_save_images(my_model, x_images_test, x_cond_labels_test,apply_sigmo=False):
    """Encode, sample and decode the given inputs and return element 0 of the decoder output.

    NOTE: this definition reuses the name of the plotting helper defined earlier
    in the notebook (which takes ``model, epoch, test_sample``). Once this cell
    has run, the earlier helper is no longer reachable under this name, and
    despite the name nothing is saved here.

    Args:
        my_model: object exposing ``encode``, ``reparameterize`` and ``decode``.
        x_images_test: image input for the encoder.
        x_cond_labels_test: conditioning input for the encoder and the decoder.
        apply_sigmo: forwarded to ``decode`` as ``apply_sigmoid``.
    """
    encoder_inputs = [x_images_test, x_cond_labels_test]
    z_mean, z_log_var = my_model.encode(encoder_inputs)
    latent = my_model.reparameterize(z_mean, z_log_var)
    decoded = my_model.decode([latent, x_cond_labels_test], apply_sigmoid=apply_sigmo)
    return decoded[0]

In [None]:
# generate samples: for every training row, reconstruct the image from its
# embedding and write both the original and the reconstruction to disk.
# total= gives tqdm a length (iterrows() is a generator), so a real progress bar is shown.
for indx,row in tqdm(train_df.iterrows(), total=len(train_df)):
    # File name shared by the embedding lookup and both output images.
    tmp_basename = os.path.basename(row['path'])
    # read image and add a leading batch axis
    pre_procc_img = read_image_3_channel(row['path'])
    pre_procc_img_2 = np.expand_dims(pre_procc_img,axis=0)
    # read embedding ("<image basename>.npy") and add a leading batch axis
    tmp_embedding_path = os.path.join(EMBEDDING_PATH,tmp_basename+".npy")
    tmp_embedding = load_np_embedding_feature(tmp_embedding_path)
    emb_exp_dim = np.expand_dims(tmp_embedding,axis=0)
    # predict
    tmp_pred = generate_and_save_images(model, pre_procc_img_2, emb_exp_dim,apply_sigmo=False)
    tmp_pred_np = tmp_pred.numpy()
    # rescale each image with its own min/max before saving
    tmp_pred_np_denorm = _denorm(tmp_pred_np, np.min(tmp_pred_np), np.max(tmp_pred_np))
    denorm_img_to_save = _denorm(pre_procc_img, np.min(pre_procc_img), np.max(pre_procc_img))
    # save
    # NOTE(review): cv2.imwrite interprets 3-channel data as BGR — confirm the
    # loader / decoder output use that order (irrelevant if the channels are identical copies).
    cv2.imwrite(os.path.join(SAVE_DIR_REAL,tmp_basename),denorm_img_to_save*255,[cv2.IMWRITE_JPEG_QUALITY, 100])
    cv2.imwrite(os.path.join(SAVE_DIR_PRED,tmp_basename),tmp_pred_np_denorm*255,[cv2.IMWRITE_JPEG_QUALITY, 100])