In [None]:
import os
from tensorflow import keras
import tensorflow as tf
from tensorflow.keras import backend
import wandb
from wandb.keras import WandbCallback

In [None]:
from encoders import EncoderResNet18, encoderCNN
from decoders import DecoderResNet18, decoderCNN
from datasets import data_loader
from embedding import embedding
from src.CVAE import CVAE



# Reset the global Keras backend state before any model is built in this notebook.
backend.clear_session()

In [None]:
# TODO: these run settings should be passed in as arguments instead of being hard-coded.
dataset_name = 'cifar10'   # forwarded to data_loader(name=...)
# The encoder/decoder cells select the ResNet18 variants only when this name contains
# the substring 'resnet'; any other name falls back to the plain CNN variants.
# (Fixed typo: the previous value 'CVAE_resent' silently selected the CNN branch.)
model_name = 'CVAE_resnet'
kl_coefficient = .65       # passed to CVAE(...) as the KL coefficient
encoded_dim = 128          # latent size handed to the encoder and decoder builders
learning_rate = 0.0001     # Adam learning rate
epoch_count = 2            # epochs passed to fit()
batch_size = 100           # batch size passed to fit()
patience = 5               # EarlyStopping patience, in epochs

In [None]:
# TODO: move the datasets into the repo and switch root_folder to a repo-relative path.
# Until then, the dataset root can be overridden with the HISTODL_DATASETS environment
# variable; the original absolute path remains the default so existing setups still work.
root_folder = os.environ.get('HISTODL_DATASETS',
                             '/home/PERSONALE/nicolas.derus2/HistoDL/datasets/')

# data_loader returns the three data splits, their labels (raw and one-hot), and
# the dataset metadata (input_shape, category_count, labels) used by later cells.
(train_x, test_x, val_x,
 train_y, test_y, val_y,
 train_y_one_hot, test_y_one_hot, val_y_one_hot,
 input_shape, category_count, labels) = data_loader(name=dataset_name,
                                                    root_folder=root_folder)

In [None]:

# Start a Weights & Biases run and record this notebook's settings as the run config.
wandb.init(
    project="HistoDL",
    entity="nrderus",
    config=dict(
        dataset=dataset_name,
        model=model_name,
        encoded_dim=encoded_dim,
        kl_coefficient=kl_coefficient,
        learning_rate=learning_rate,
        epochs=epoch_count,
        batch_size=batch_size,
        patience=patience,
    ),
)

In [None]:
# Choose the encoder from the model name: plain CNN unless the name contains 'resnet'.
if 'resnet' not in model_name:
    encoder = encoderCNN(input_shape, category_count, encoded_dim)
else:
    # The ResNet18 encoder is built for an input whose last dimension is the image's
    # last dimension plus category_count.
    encoder = EncoderResNet18(encoded_dim=encoded_dim).model(
        input_shape=(input_shape[0], input_shape[1], input_shape[2] + category_count)
    )

# Display the resulting encoder as the cell output.
encoder

In [None]:
# Choose the decoder from the model name: plain CNN unless the name contains 'resnet'.
if 'resnet' not in model_name:
    decoder = decoderCNN(input_shape, category_count, encoded_dim)
else:
    # The ResNet18 decoder takes a flat vector of length encoded_dim + category_count.
    decoder = DecoderResNet18(encoded_dim=encoded_dim).model(
        input_shape=(encoded_dim + category_count,)
    )

# Display the resulting decoder as the cell output.
decoder

In [None]:
# Alternative (disabled) choice from the beta-VAE paper: beta = encoded dimension / pixel
# dimension, i.e. about 0.068 per the original note:
# kl_coefficient = encoded_dim / (input_shape[0] * input_shape[1] * input_shape[2])
print('kl coefficient: ' + format(kl_coefficient, '.3f'))

# Assemble the conditional VAE from the encoder and decoder built above.
cvae = CVAE(encoder, decoder, kl_coefficient, input_shape)
# Flag the model as built by hand before asking for its summary.
# NOTE(review): presumably needed because the subclassed model has not been called yet — confirm.
cvae.built = True
cvae.summary()

In [None]:
def _build_compiled_cvae():
    """Build the CVAE from the notebook's encoder/decoder and compile it with Adam.

    Reads the notebook-level `encoder`, `decoder`, `kl_coefficient`, `input_shape`
    and `learning_rate`.

    Returns:
        (model, optimizer): the compiled CVAE and the Adam optimizer it was compiled with.
    """
    model = CVAE(encoder, decoder, kl_coefficient, input_shape)
    model.built = True
    optimizer = keras.optimizers.Adam(learning_rate=learning_rate)
    model.compile(optimizer=optimizer, run_eagerly=False)
    return model, optimizer


# Prefer a Colab TPU when one is reachable; otherwise build on the default device.
try:
    resolver = tf.distribute.cluster_resolver.TPUClusterResolver('grpc://' + os.environ['COLAB_TPU_ADDR'])
    # This is the TPU initialization code that has to be at the beginning.
    tf.tpu.experimental.initialize_tpu_system(resolver)
    print("All devices: ", tf.config.list_logical_devices('TPU'))

    strategy = tf.distribute.experimental.TPUStrategy(resolver)
    # NOTE(review): `encoder` and `decoder` were created in earlier cells, outside this
    # scope — confirm their variables are placed as intended when running on TPU.
    with strategy.scope():
        cvae, opt = _build_compiled_cvae()
except Exception as tpu_error:
    # `except Exception` (not a bare `except:`) keeps the best-effort fallback for any
    # TPU setup failure, e.g. a missing COLAB_TPU_ADDR, while letting KeyboardInterrupt
    # and SystemExit propagate. The reason is printed so the fallback is not silent.
    print('TPU setup failed ({!r}); building the model on the default device.'.format(tpu_error))
    cvae, opt = _build_compiled_cvae()

# Handles to the model's symbolic tensors, kept as notebook-level names.
cvae_input = cvae.encoder.input[0]
cvae_output = cvae.decoder.output
mu = cvae.encoder.get_layer('mu').output
log_var = cvae.encoder.get_layer('log_var').output

In [None]:

# Stop training once 'val_loss' has not improved for `patience` epochs. The best weights
# are not restored afterwards (restore_best_weights=False).
early_stop = keras.callbacks.EarlyStopping(
    monitor='val_loss',
    patience=patience,
    restore_best_weights=False,
)

# Train on (image, one-hot label) pairs; validation data uses the same input structure
# with no separate target.
# Original note on the W&B callback: save_weights_only -> ValueError: Unable to create
# dataset (name already exists); the callback is therefore created with save_model=False.
history = cvae.fit(
    [train_x, train_y_one_hot],
    validation_data=([val_x, val_y_one_hot], None),
    epochs=epoch_count,
    batch_size=batch_size,
    callbacks=[
        early_stop,
        WandbCallback(save_model=False),
    ],
)

In [None]:
# Run the project's embedding helper on the fitted model over the train / test /
# validation splits (inputs, labels and one-hot labels).
embedding(
    cvae, encoded_dim, category_count,
    train_x, test_x, val_x,
    train_y, test_y, val_y,
    train_y_one_hot, test_y_one_hot, val_y_one_hot,
    labels,
    xy_lim=80,
    avg_latent=True,
)