In [None]:
%load_ext autoreload
# Only the autoreload extension is loaded here; automatic reloading stays off
# until the next line is uncommented (mode 2 re-imports changed modules such as
# `src.*` before each cell is executed).
# %autoreload 2

### Import libraries

In [None]:
import sys
import numpy as np
import pandas as pd
import os
import glob 
import random
from PIL import Image
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.auto import tqdm
import cv2
from datetime import datetime
import tensorflow as tf

In [None]:
 # This notebook requires exactly TensorFlow 2.4.1; fail fast otherwise and
 # report the version that was actually found.
 assert tf.__version__ == '2.4.1', (
     f"TF version is not matching (found {tf.__version__})! "
     "Make sure you have tf 2.4.1-gpu installed!"
 )

In [None]:
# Enable GPU memory growth so TensorFlow allocates GPU memory on demand instead
# of reserving all of it at start-up. Memory growth can only be configured
# before the GPUs are initialized (i.e. before the model is built); re-running
# this cell afterwards raises RuntimeError, which is reported instead of
# aborting the notebook.
gpus = tf.config.list_physical_devices(device_type='GPU')
try:
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(device=gpu, enable=True)
except RuntimeError as err:
    print(err)

#### Import custom functions 

In [None]:
import sys
# Make the parent directory importable so the `src` package imports below
# resolve. Guarded so re-running the cell does not append duplicate entries.
if '../' not in sys.path:
    sys.path.append('../')
from src.dataloader.contrastive_learning_loader import debug_batch_of_data, get_training_tfdata,get_test_or_validation_tfdata,view_batch_of_images
from src.models import get_model_func
from src.utils.tf_utils import get_preprocess_func
from src.train_config import config
from src.utils.custom_losses import SupervisedContrastiveLoss

### Get Model and print summary

In [None]:
### MODEL ###
# get_model_func(config) returns a model factory (presumably chosen from the
# training config); calling it builds the Keras model compiled and trained below.
model_func = get_model_func(config)
model = model_func()

In [None]:
# Print the layer-by-layer architecture and parameter counts.
model.summary()

### Define Keras callbacks

In [None]:
def get_keras_callbacks(log_dir: str, monitor: str = 'loss', patience: int = 10):
    """Build the list of Keras callbacks used for training.

    Args:
        log_dir: Output directory for the CSV training log, the weight
            checkpoints and the TensorBoard logs. Created if it is missing.
        monitor: Metric watched by EarlyStopping and ModelCheckpoint. Defaults
            to the training ``'loss'`` because the ``model.fit`` call in this
            notebook passes no validation data; use ``'val_loss'`` if
            validation data is added.
        patience: Number of epochs without improvement of ``monitor`` after
            which EarlyStopping stops training.

    Returns:
        ``[EarlyStopping, CSVLogger, ModelCheckpoint, TensorBoard, TerminateOnNaN]``
    """
    # CSVLogger and ModelCheckpoint write into log_dir, so make sure it exists.
    os.makedirs(log_dir, exist_ok=True)

    early_stop = tf.keras.callbacks.EarlyStopping(monitor=monitor, patience=patience)
    csv_logger = tf.keras.callbacks.CSVLogger(os.path.join(log_dir, 'training.log'))
    model_ckp = tf.keras.callbacks.ModelCheckpoint(
        # Weights-only file named model_<epoch>.h5. Because save_best_only=True,
        # a file is written only when `monitor` improves: the highest-numbered
        # file holds the best weights, and a file for the final epoch may not
        # exist.
        filepath=os.path.join(log_dir, "model_{epoch}.h5"),
        save_best_only=True,
        save_weights_only=True,
        monitor=monitor,
    )
    # Stop training as soon as a NaN loss is encountered.
    term_nan = tf.keras.callbacks.TerminateOnNaN()

    tensorboard = tf.keras.callbacks.TensorBoard(
        log_dir=log_dir,
        histogram_freq=0,     # no weight-histogram logging
        embeddings_freq=0,    # no embedding-projector logging
        update_freq="epoch",  # write scalar logs once per epoch
        profile_batch=2,      # profile the second training batch
    )
    return [early_stop, csv_logger, model_ckp, tensorboard, term_nan]

In [None]:
## LOGDIR to save results: one timestamped sub-directory per training run
log_dir = os.path.join('/model_registry/output', datetime.now().strftime("%Y%m%d-%H%M%S"))
# exist_ok replaces a separate exists() check (no check-then-create race).
os.makedirs(log_dir, exist_ok=True)

### Define hyperparameters and optimizer for the model

In [None]:
# Training hyperparameters
num_epochs = 50
learning_rate = 0.001
# NOTE(review): dropout_rate is not used anywhere in this notebook (the model
# was already built from `config` above), so changing it here has no effect.
dropout_rate = 0.5
temperature = 0.05  # temperature passed to SupervisedContrastiveLoss

# Staircase exponential decay of the learning rate. It is only applied when
# use_lr_schedule is True; the default keeps the constant learning_rate.
use_lr_schedule = False
lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    learning_rate, decay_steps=100000, decay_rate=0.95, staircase=True
)

m_optimizer = tf.keras.optimizers.Adam(lr_schedule if use_lr_schedule else learning_rate)
## Compile model
model.compile(
    optimizer=m_optimizer,
    # custom supervised contrastive loss (from src.utils.custom_losses)
    loss=SupervisedContrastiveLoss(temperature),
)

### Load the training set as a `tf.data` dataset

In [None]:
# DATA: CSV describing the training split
train_csv_path = "/data/oct_train_filtered.csv"
print(train_csv_path)
train_df = pd.read_csv(filepath_or_buffer=train_csv_path)
# preview the first rows
train_df.head(n=5)

In [None]:
# read the set as tf.data
# Build the batched (batch_size=8) training pipeline from train_df using the
# project loader in src.dataloader.contrastive_learning_loader.
tf_train_set = get_training_tfdata(train_df,batch_size=8)

In [None]:
# visualize batch of images
# NOTE(review): the second argument (8) presumably is the number of images to
# show and matches batch_size above - confirm against view_batch_of_images.
view_batch_of_images(tf_train_set,8)

In [None]:
# Inspect one batch of the training pipeline (project debug helper).
debug_batch_of_data(tf_train_set)

### Get callbacks and train the model

In [None]:
# Train with the callbacks defined above; `history` holds the History object
# returned by fit.
cb_list = get_keras_callbacks(log_dir)
history = model.fit(
    tf_train_set,
    epochs=num_epochs,
    verbose=1,
    callbacks=cb_list,
)

### Extract embeddings for each sample in the training set and save them under the data folder
The embedding of each sample, produced by the trained contrastive model, is used to train the Conditional Variational Autoencoder (CVAE) model. Hence, this step must be executed after training.

In [None]:
# Load trained weights (change run id / epoch as needed). The root must match
# the absolute '/model_registry/output' directory that training writes to.
# Checkpoints use save_best_only=True, so only improving epochs were saved -
# pick an epoch file that actually exists in the run directory.
model.load_weights(os.path.join('/model_registry/output', '20220422-160825', 'model_50.h5'))

In [None]:
# Reload the training CSV so the embedding-extraction section does not depend
# on state left over from the training cells.
train_csv_path = "/data/oct_train_filtered.csv"
print(train_csv_path)
train_df = pd.read_csv(train_csv_path)
# The extraction loop below reads every image location from the 'path' column.
assert 'path' in train_df.columns, "expected a 'path' column in " + train_csv_path
train_df.head()

In [None]:
# UTILITY FUNCTIONS
from src.dataloader.contrastive_learning_loader import _denorm

# Model-specific preprocessing function obtained from `config`; it is applied
# to every image by read_image_3_channel.
preprocessing_func=get_preprocess_func(config)
def open_gray(fn):
    """Read an image as grayscale and replicate it into 3 RGB channels.

    The file is decoded by OpenCV (BGR order), converted to a single gray
    channel, then expanded back to 3 identical channels (presumably because
    the model expects 3-channel input).

    Args:
        fn: Path to the image file.

    Returns:
        Array of shape (H, W, 3) with identical channels.

    Raises:
        FileNotFoundError: if OpenCV cannot read the file.
    """
    img = cv2.imread(fn)
    if img is None:
        # cv2.imread returns None for missing/unreadable files instead of
        # raising; fail here with the path rather than an opaque cvtColor error.
        raise FileNotFoundError(f"Could not read image: {fn}")
    img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    return cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)

def read_image_3_channel(path, size=(256, 256)):
    """Load an image as 3-channel grayscale, resize it and apply model preprocessing.

    Args:
        path: Path to the image file.
        size: Target size as (width, height), passed to cv2.resize.
            Defaults to (256, 256).

    Returns:
        The result of ``preprocessing_func`` applied to the resized image.
    """
    img = open_gray(path)
    # Pass interpolation by keyword: the third positional parameter of
    # cv2.resize is `dst`, not `interpolation`.
    img = cv2.resize(img, size, interpolation=cv2.INTER_LINEAR)
    # No manual /255 scaling: normalisation is left to the model-specific
    # preprocessing function.
    return preprocessing_func(img)


In [None]:
# embedding save path: one .npy file per training image is written here
emb_save_path = "/data/processed/contrastive_learning/train_embeddings/"
# exist_ok so re-running the cell does not raise FileExistsError.
os.makedirs(emb_save_path, exist_ok=True)

In [None]:
# Extract one embedding per training image: save it to its own .npy file and
# keep it in memory for the aggregated arrays written in the next cell.
encoder_emb_list = []
for _, row in tqdm(train_df.iterrows(), total=len(train_df)):
    img = read_image_3_channel(row['path'])
    # add a leading batch axis so the single image is predicted as a batch of one
    img_batch = np.expand_dims(img, axis=0)
    # NOTE(review): one predict() call per image has per-call overhead;
    # batching images would be faster if this becomes a bottleneck.
    pred_emb = np.asarray(model.predict(img_batch)).ravel()
    # np.save appends ".npy", giving <emb_save_path>/<image basename>.npy
    # NOTE(review): images sharing a basename in different folders would
    # overwrite each other's file here.
    np.save(os.path.join(emb_save_path, os.path.basename(row['path'])), pred_emb)
    encoder_emb_list.append(pred_emb)

In [None]:
## save embeddings and train ids as complete numpy arrays; row i of both
## arrays corresponds to row i of train_df
save_folder = "/data/processed/contrastive_learning/"
os.makedirs(save_folder, exist_ok=True)
# np.stack (not np.concatenate) gives one row per image: every entry of
# encoder_emb_list is a 1-D vector, so concatenating along axis 0 would
# flatten all embeddings into a single long vector.
embeddings = np.stack(encoder_emb_list, axis=0)
np.save(os.path.join(save_folder, 'train_embeddings.npy'), embeddings)
# save filenames as a fixed-width string array (loadable without allow_pickle)
fn_arr = np.array(list(train_df.path.values))
assert len(fn_arr) == len(embeddings), "embeddings and train ids are misaligned"
print(fn_arr.shape, embeddings.shape)
np.save(os.path.join(save_folder, 'train_ids.npy'), fn_arr)