In [13]:
import os
from tensorflow import keras
import tensorflow as tf
from tensorflow.keras import backend
import wandb
from wandb.keras import WandbCallback
from tensorflow.keras import regularizers


In [14]:
from encoders import EncoderResNet18, EncoderResNet34, EncoderResNet50, encoderCNN, EncoderMixNet18
from decoders import DecoderResNet18, DecoderResNet34, DecoderResNet50, decoderCNN
from datasets import data_loader
from embeddings import embedding
from reconstructions import reconstructions
from generations import Generations
from activations import VisualizeActivations
from gradcam import GradCam
from src.CVAE import CVAE

# Reset Keras' global state (layer-name counters, cached graphs) so that
# re-running the model-building cells starts from a clean session.
backend.clear_session()

In [15]:
# Run configuration.
# TO DO: this should be passed as arguments

# Data / model selection
dataset_name = 'celeba'
model_name = 'CVAE_2stage'  # no 'resnet' substring -> CNN encoder/decoder are built below

# Architecture
encoded_dim = 64      # first-stage latent size, passed to the encoder/decoder builders
second_dim = 2048     # intended hidden width of the second-stage model
kl_coefficient = 4    # KL weight passed to CVAE

# Optimisation
learning_rate = 1e-4
batch_size = 100
epoch_count = 50      # first-stage epochs
epoch_count2 = 50     # second-stage epochs
patience = 5          # early-stopping patience for the first stage

In [16]:
# Load the dataset splits. The location can be overridden with the DATASETS_ROOT
# environment variable; the default is the original hard-coded path.
#TO DO: move datasets in the repo and change root_folder
datasets_root = os.environ.get('DATASETS_ROOT', '/home/PERSONALE/nicolas.derus2/HistoDL/datasets/')

if dataset_name == 'experimental':
    # NOTE(review): this branch only defines train_ds/val_ds. Later cells use
    # train_x/test_x/val_x, *_y and *_y_one_hot, which only the branch below
    # defines, so the rest of the notebook does not run for this dataset as written.
    train_ds, val_ds, input_shape, category_count, labels = data_loader(name=dataset_name, root_folder=datasets_root)
else:
    (train_x, test_x, val_x,
     train_y, test_y, val_y,
     train_y_one_hot, test_y_one_hot, val_y_one_hot,
     input_shape, category_count, labels) = data_loader(name=dataset_name, root_folder=datasets_root)

In [17]:
# Sanity check: shape of the training images with singleton axes removed.
train_x.squeeze().shape

(100000, 64, 64, 3)

In [18]:
# Tell wandb the notebook's filename via WANDB_NOTEBOOK_NAME (must be set before wandb.init).
# The previous `%env "WANDB_NOTEBOOK_NAME" "train.ipynb"` kept the quote characters,
# creating a variable literally named "WANDB_NOTEBOOK_NAME" (quotes included) that
# wandb never reads; set the variable directly instead.
os.environ["WANDB_NOTEBOOK_NAME"] = "train.ipynb"

env: "WANDB_NOTEBOOK_NAME"="train.ipynb"


In [19]:

# Start the W&B run and record the hyperparameters from the configuration cell.
wandb.init(
    project="H-VAE",
    entity="nrderus",
    config={
        "dataset": dataset_name,
        "model": model_name,
        "encoded_dim": encoded_dim,
        "kl_coefficient": kl_coefficient,
        "learning_rate": learning_rate,
        "epochs": epoch_count,
        # Second-stage epoch count, so both training phases are reproducible from the run.
        "epochs2": epoch_count2,
        "batch_size": batch_size,
        "patience": patience,
    },
)

In [20]:
# Build the conditional encoder for the selected backbone.
if 'resnet' not in model_name:
    # CNN encoder, built with an L2(0.001) kernel regularizer.
    encoder = encoderCNN(input_shape, category_count, encoded_dim, regularizer=regularizers.L2(.001))
else:
    # ResNet-18 encoder: the input is widened by `category_count` channels for the labels.
    resnet_channels = input_shape[2] + category_count
    encoder = EncoderResNet18(encoded_dim=encoded_dim).model(
        input_shape=(input_shape[0], input_shape[1], resnet_channels))

encoder.summary()

Model: "encoder"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 Input (InputLayer)             [(None, 64, 64, 43)  0           []                               
                                ]                                                                 
                                                                                                  
 block1_conv1 (Conv2D)          (None, 64, 64, 16)   6208        ['Input[0][0]']                  
                                                                                                  
 block1_conv2 (Conv2D)          (None, 64, 64, 16)   2320        ['block1_conv1[0][0]']           
                                                                                                  
 batch_normalization (BatchNorm  (None, 64, 64, 16)  64          ['block1_conv2[0][0]']     

In [21]:

# NOTE(review): repeats the encoder summary already printed by the previous cell; safe to drop.
encoder.summary()

Model: "encoder"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 Input (InputLayer)             [(None, 64, 64, 43)  0           []                               
                                ]                                                                 
                                                                                                  
 block1_conv1 (Conv2D)          (None, 64, 64, 16)   6208        ['Input[0][0]']                  
                                                                                                  
 block1_conv2 (Conv2D)          (None, 64, 64, 16)   2320        ['block1_conv1[0][0]']           
                                                                                                  
 batch_normalization (BatchNorm  (None, 64, 64, 16)  64          ['block1_conv2[0][0]']     

In [22]:
# Build the conditional decoder; its input width is encoded_dim + category_count
# (104 in the run shown below).
if 'resnet' not in model_name:
    decoder = decoderCNN(input_shape, category_count, encoded_dim, final_stride=1, regularizer=regularizers.L2(.001))
else:
    decoder = DecoderResNet18(encoded_dim=encoded_dim, final_stride=2).model(
        input_shape=(encoded_dim + category_count,))

decoder.summary()

Model: "decoder"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 decoder_input (InputLayer)  [(None, 104)]             0         
                                                                 
 dense_3 (Dense)             (None, 65536)             6881280   
                                                                 
 reshape (Reshape)           (None, 32, 32, 64)        0         
                                                                 
 batch_normalization_3 (Batc  (None, 32, 32, 64)       256       
 hNormalization)                                                 
                                                                 
 leaky_re_lu_3 (LeakyReLU)   (None, 32, 32, 64)        0         
                                                                 
 up_block4_conv1 (Conv2DTran  (None, 32, 32, 64)       36928     
 spose)                                                    

In [23]:
# Build and compile the CVAE. When a Colab TPU is reachable the model is created
# inside a TPUStrategy scope; otherwise the default (CPU/GPU) placement is used.
# NOTE(review): `encoder` and `decoder` are built in earlier cells, outside any
# strategy scope; on TPU they probably need to be created inside `strategy.scope()`
# as well. TODO confirm.


def _connect_tpu_strategy():
    """Return a ``tf.distribute.TPUStrategy`` for the Colab TPU, or ``None``.

    Only TPU discovery and initialisation are guarded. A missing
    ``COLAB_TPU_ADDR`` or a failed TPU initialisation falls back to the default
    strategy. Errors raised later, while building or compiling the model, are
    not swallowed.
    """
    tpu_address = os.environ.get('COLAB_TPU_ADDR')
    if not tpu_address:
        return None
    try:
        tpu_resolver = tf.distribute.cluster_resolver.TPUClusterResolver('grpc://' + tpu_address)
        tf.config.experimental_connect_to_cluster(tpu_resolver)
        # This is the TPU initialization code that has to be at the beginning.
        tf.tpu.experimental.initialize_tpu_system(tpu_resolver)
    except (ValueError, RuntimeError, tf.errors.OpError) as tpu_error:
        print(f"TPU initialisation failed ({tpu_error!r}); using the default strategy.")
        return None
    print("All devices: ", tf.config.list_logical_devices('TPU'))
    return tf.distribute.TPUStrategy(tpu_resolver)


def _build_and_compile_cvae():
    """Create the CVAE from the global encoder/decoder and compile it with Adam.

    Returns:
        (model, optimizer): the compiled CVAE and its Adam optimizer.
    """
    model = CVAE(encoder, decoder, kl_coefficient, input_shape, category_count)
    # Set explicitly, as in the original cell; presumably so Keras treats the
    # model as built before fit. TODO confirm.
    model.built = True
    optimizer = keras.optimizers.Adam(learning_rate=learning_rate)
    model.compile(optimizer=optimizer, run_eagerly=False)
    return model, optimizer


strategy = _connect_tpu_strategy()
if strategy is None:
    cvae, opt = _build_and_compile_cvae()
else:
    with strategy.scope():
        cvae, opt = _build_and_compile_cvae()

# Symbolic handles on the model input, reconstruction output and latent layers.
cvae_input = cvae.encoder.input[0]
cvae_output = cvae.decoder.output
mu = cvae.encoder.get_layer('mu').output
log_var = cvae.encoder.get_layer('log_var').output

In [24]:
# Stop once val_loss has not improved for `patience` epochs. Weights are NOT rolled
# back to the best epoch (restore_best_weights=False): the model keeps the weights
# of the last epoch that ran.
early_stop = keras.callbacks.EarlyStopping(
    monitor='val_loss',
    patience=patience,
    restore_best_weights=False,
)

# save_model=False: per the original note, enabling save_weights_only raised
# "ValueError: Unable to create dataset (name already exists)".
stage1_callbacks = [early_stop, WandbCallback(save_model=False)]

history = cvae.fit(
    [train_x, train_y_one_hot],
    validation_data=([val_x, val_y_one_hot], None),
    epochs=epoch_count,
    batch_size=batch_size,
    callbacks=stage1_callbacks,
)



Epoch 1/50

In [None]:
# tf.saved_model.save(cvae.encoder, 'cvae_encoder')
# tf.saved_model.save(cvae.decoder, 'cvae_decoder')

In [None]:
def _encode_first(x, y_one_hot, count=1000):
    """Condition the first ``count`` samples on their labels and run the encoder.

    Returns the label tensor and encoder input produced by
    ``cvae.conditional_input`` (its 2nd and 3rd outputs), followed by the
    encoder's two predictions (mean, log-variance).
    """
    _, input_label, conditioned = cvae.conditional_input([x[:count], y_one_hot[:count]])
    z_mean, z_log_var = cvae.encoder.predict(conditioned)
    return input_label, conditioned, z_mean, z_log_var


input_label_train, train_input, train_x_mean, train_log_var = _encode_first(train_x, train_y_one_hot)
input_label_test, test_input, test_x_mean, test_log_var = _encode_first(test_x, test_y_one_hot)
input_label_val, val_input, val_x_mean, val_log_var = _encode_first(val_x, val_y_one_hot)

In [None]:
from src.CVAE import SecondStage
from encoders import encoder2
from decoders import decoder2

# Second stage: a model trained on conditioned samples drawn from the stage-1 latents.
# The hidden width comes from `second_dim` in the configuration cell, so that setting
# is no longer silently ignored (it was hard-coded to 1024 here).
second_depth = 3  # depth passed to encoder2/decoder2 (presumably number of hidden layers; TODO confirm)

# Slice the labels to however many latents were encoded above, keeping the two aligned.
n_encoded = train_x_mean.shape[0]
z_cond = cvae.sampling(train_x_mean, train_log_var, train_y_one_hot[:n_encoded])

encoder_2 = encoder2(encoded_dim, category_count, second_dim=second_dim, second_depth=second_depth)
decoder_2 = decoder2(encoded_dim, category_count, second_dim=second_dim, second_depth=second_depth)
cvae2 = SecondStage(encoder_2, decoder_2, category_count, batch_size)
opt2 = keras.optimizers.Adam(learning_rate=learning_rate)
cvae2.compile(optimizer=opt2, run_eagerly=False)

In [None]:
# Train the second-stage model on the conditioned stage-1 latent samples.
stage2_inputs = [z_cond, train_y_one_hot[:1000]]
history2 = cvae2.fit(
    stage2_inputs,
    validation_data=None,  # no held-out split is used for the second stage
    epochs=epoch_count2,
    batch_size=batch_size,
)

In [None]:
# Embedding visualisation of the latent statistics for the train/test/val splits.
split_means = (train_x_mean, test_x_mean, val_x_mean)
split_targets = (train_y, test_y, val_y)
split_log_vars = (train_log_var, test_log_var, val_log_var)
embedding(encoded_dim, category_count, *split_means, *split_targets, *split_log_vars, labels,
          quantity=1000, avg_latent=True)

In [None]:
# Reconstructions for the training split, using the latent statistics computed above.
reconstructions(
    cvae,
    train_x, train_y,
    train_x_mean, train_log_var,
    input_label_train,
    labels,
    set='train',
)

In [None]:
from reconstructions import reconstructions2

# Second-stage reconstructions: z_hat is cvae2's posterior for the stage-1 samples,
# handed to reconstructions2 together with the first-stage CVAE.
z_hat = cvae2.posterior(z_cond, train_y_one_hot[:1000])
reconstructions2(cvae, train_x, train_y, z_hat, labels)

In [None]:
# Reconstructions for the test split, using the latent statistics computed above.
reconstructions(
    cvae,
    test_x, test_y,
    test_x_mean, test_log_var,
    input_label_test,
    labels,
    set='test',
)

In [None]:
# Conditional generation using the first-stage CVAE only.
generator = Generations(
    cvae, encoded_dim, category_count, input_shape, labels,
)
generator()

In [None]:
from generations import generations2

# Conditional generation with the second-stage model passed alongside the CVAE.
# NOTE(review): the trailing positional `True` presumably enables the two-stage path,
# and `generations2` is not used in this cell. TODO confirm.
generator = Generations(
    cvae, encoded_dim, category_count, input_shape, labels,
    cvae2, True,
)
generator()

In [None]:

# Activation visualisations on the test set: one for the encoder, one for the decoder.
activations_encoder, activations_decoder = (
    VisualizeActivations(cvae, sub_model, test_x, test_y_one_hot)
    for sub_model in (cvae.encoder, cvae.decoder)
)
activations_encoder()
activations_decoder()

In [None]:
# Layer inspected by Grad-CAM; the name depends on which encoder backbone was built.
target_layer = "layer4" if 'resnet' in model_name else "block3_conv2"


In [None]:
# Grad-CAM on the test set at the backbone-specific `target_layer`.
# NOTE(review): the name `gc` shadows the conventional name of the stdlib `gc` module.
gc = GradCam(
    cvae,
    test_x,
    test_y_one_hot,
    HQ=True,
    target_layer=target_layer,
)
gc.gradcam()


In [None]:
# Guided Grad-CAM using the GradCam instance built in the previous cell.
gc.guided_gradcam()

In [None]:
# Close the W&B run, marking it as successful.
wandb.finish(
    exit_code=0,
    quiet=True,
)