In [1]:
from tensorflow.keras import Model
from tensorflow.keras.layers import Input, Layer, Conv2D, Flatten, Dense, Reshape, Conv2DTranspose
from tensorflow import keras

import tensorflow as tf

import time
import numpy as np
from pathlib import Path

import tensorflow as tf
from albumentations.core.serialization import load as load_albumentations_transform
from tensorflow.keras import backend as K
from tensorflow.keras.optimizers.schedules import ExponentialDecay

try:
    from tensorflow.keras.optimizers.legacy import Adam
except ImportError:
    from tensorflow.keras.optimizers import Adam

from src.data.image_datasets import ImageRandomDataset
from src.utils.config import read_json_config
from src.models.trainer import VAETrainer

2023-07-28 12:26:48.430107: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  SSE4.1 SSE4.2 AVX AVX2 AVX512F AVX512_VNNI FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-07-28 12:26:48.498514: I tensorflow/core/util/port.cc:104] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.

TensorFlow Addons (TFA) has ended development and introduction of new features.
TFA has entered a minimal maintenance and release mode until a planned end of life in May 2024.
Please modify downstream libraries to take dependencies from other repositories in our TensorFlow community (e.g. Keras, Keras-CV, and Keras-NLP). 



In [2]:
class Sampling(Layer):
    """Reparameterization layer: draws a latent vector z from (z_mean, z_log_var).

    The layer takes a pair ``(z_mean, z_log_var)`` and returns
    ``z_mean + exp(0.5 * z_log_var) * epsilon`` where ``epsilon`` is standard
    normal noise with one value per (batch, latent) entry.
    """

    def call(self, inputs):
        z_mean, z_log_var = inputs
        # Use the dynamic shape so the layer works with an unknown batch size.
        mean_shape = tf.shape(z_mean)
        noise = tf.random.normal(shape=(mean_shape[0], mean_shape[1]))
        # exp(0.5 * log(var)) is the standard deviation.
        std_dev = tf.exp(0.5 * z_log_var)
        return z_mean + std_dev * noise

In [3]:
def create_encoder(
    img_height,
    img_width,
    latent_dim=2,
    ):
    """Build the encoder half of the VAE as a Keras functional model.

    Args:
        img_height: Height of the single-channel input image.
        img_width: Width of the single-channel input image.
        latent_dim: Size of the latent vector produced by each output head.

    Returns:
        A ``Model`` named "encoder" that maps an image of shape
        ``(img_height, img_width, 1)`` to ``[z_mean, z_log_var, z]``.
    """
    image_in = Input(shape=(img_height, img_width, 1))

    # Convolutional feature extractor: two stride-2 convolutions in sequence.
    features = Conv2D(32, 3, activation="relu", strides=2, padding="same")(image_in)
    features = Conv2D(64, 3, activation="relu", strides=2, padding="same")(features)

    # Collapse the feature map and squeeze it through a small dense layer.
    flat = Flatten()(features)
    hidden = Dense(16, activation="relu")(flat)

    # Two parallel heads parameterize the latent distribution.
    z_mean = Dense(latent_dim, name="z_mean")(hidden)
    z_log_var = Dense(latent_dim, name="z_log_var")(hidden)

    # Draw the latent sample from the two heads.
    z = Sampling()([z_mean, z_log_var])

    return Model(image_in, [z_mean, z_log_var, z], name="encoder")

In [4]:
def create_decoder(
    img_height,
    img_width,
    latent_dim=2,
    ):
    """Build the decoder half of the VAE as a Keras functional model.

    The latent vector is expanded by a dense layer, reshaped to a
    ``(img_height // 4, img_width // 4, 64)`` feature map, upsampled by two
    stride-2 transposed convolutions and projected to a single sigmoid
    channel.

    Args:
        img_height: Height of the image to reconstruct; positive multiple of 4.
        img_width: Width of the image to reconstruct; positive multiple of 4.
        latent_dim: Size of the latent input vector.

    Returns:
        A ``Model`` named "decoder" mapping a ``(latent_dim,)`` vector to the
        single-channel output of the final transposed convolution.

    Raises:
        ValueError: If ``img_height`` or ``img_width`` is not a positive
            multiple of 4.
    """
    # Validate up front: the two stride-2 deconvolutions upsample by exactly 4,
    # so a floored `size // 4` would silently yield an output smaller than the
    # requested image and only fail later when compared with the input.
    for dim_name, size in (("img_height", img_height), ("img_width", img_width)):
        if size <= 0 or size % 4 != 0:
            raise ValueError(
                f"{dim_name} must be a positive multiple of 4, got {size}"
            )

    input_layer = Input(shape=(latent_dim,))

    # Two deconvolutions with stride 2, means out_dim = 2 * 2 * in_dim
    in_height, in_width = img_height // 4, img_width // 4

    dense1 = Dense(in_height * in_width * 64, activation="relu")
    reshape_layer = Reshape((in_height, in_width, 64))
    conv1 = Conv2DTranspose(64, 3, activation="relu", strides=2, padding="same")
    conv2 = Conv2DTranspose(32, 3, activation="relu", strides=2, padding="same")
    # Final stride-1 layer only maps the 32 feature channels to one sigmoid channel.
    conv3 = Conv2DTranspose(1, 3, activation="sigmoid", padding="same")

    x = dense1(input_layer)
    x = reshape_layer(x)
    x = conv1(x)
    x = conv2(x)
    decoder_outputs = conv3(x)
    return Model(input_layer, decoder_outputs, name="decoder")

In [5]:
# Load the training configuration once; `parameters` holds the values the
# later cells read (image size, learning-rate schedule, ...).
CONFIG_PATH = '../configs/debug_training_unet.json'
config = read_json_config(CONFIG_PATH)
parameters = config['parameters']

results_dir = Path('..')

# All tensors in this notebook are laid out as (height, width, channels).
K.set_image_data_format('channels_last')

# Timestamp of this run and the directory for model artifacts.
running_time = time.strftime('%b-%d-%Y_%H-%M')
model_dir = results_dir / 'model'

In [6]:
# Serialized albumentations pipeline and the directory of debug images.
TRANSFORM_PATH = '../configs/transforms/0_5_fold_res/val.json'
# Misspelled legacy name kept as an alias so any cell still using it keeps working.
TRANSOFRM_PATH = TRANSFORM_PATH
IMG_PATH = Path('../data/debug/img/')

transform = load_albumentations_transform(TRANSFORM_PATH)
dataset = ImageRandomDataset(
    IMG_PATH,
    transform=transform,
    batch_size=50,
)

In [11]:
# Encoder output and decoder input must have the same latent size, so the
# value is defined once instead of being repeated in both calls.
LATENT_DIM = 16

encoder = create_encoder(
    img_height=parameters['target_height'],
    img_width=parameters['target_width'],
    latent_dim=LATENT_DIM,
)

decoder = create_decoder(
    img_height=parameters['target_height'],
    img_width=parameters['target_width'],
    latent_dim=LATENT_DIM,
)

# Exponentially decaying learning rate; the decay interval is the configured
# scheduler step size multiplied by `samples_per_epoch`.
lr_schedule = ExponentialDecay(
    initial_learning_rate=parameters['start_lr'],
    decay_steps=parameters['samples_per_epoch'] * parameters['scheduler']['step_size'],
    decay_rate=parameters['scheduler']['gamma'],
)
optimizer = Adam(lr_schedule)

In [12]:
# Bundle both halves of the VAE with the optimizer in the project trainer.
trainer = VAETrainer(encoder, decoder, optimizer)

In [13]:
class VAE(keras.Model):
    """Variational autoencoder that chains an encoder and a decoder model.

    Args:
        encoder: Model whose outputs are ``[z_mean, z_log_var, z]`` (the
            structure returned by ``create_encoder``).
        decoder: Model that maps a latent sample ``z`` to a reconstruction.
        **kwargs: Forwarded to ``keras.Model``.
    """

    def __init__(self, encoder, decoder, **kwargs):
        super().__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder

    def call(self, inputs, training=None):
        """Encode ``inputs``, then decode the sampled latent vector.

        Without this method the subclassed model cannot be invoked on data.
        Assumes the encoder returns exactly three outputs, as
        ``create_encoder`` does.
        """
        _, _, z = self.encoder(inputs, training=training)
        return self.decoder(z, training=training)


model = VAE(encoder, decoder)

In [14]:
# NOTE(review): the same `dataset` object is passed twice; the printed
# "Train loss" / "Val loss" lines suggest the two positions are the training
# and validation data, and the leading 10 an epoch count - confirm against
# VAETrainer.fit.
trainer.fit(10, dataset, dataset)

Epoch #0
Train loss: 45307.78515625
Val loss: 45252.95703125
Epoch #1
Train loss: 45248.8125
Val loss: 45242.30078125
Epoch #2
Train loss: 45240.75390625
Val loss: 45238.890625
Epoch #3


KeyboardInterrupt: 