First, we import the required Python modules.

In [None]:
import tensorflow as tf
import tensorflow.keras.backend as K
from tensorflow.keras import Model, layers, Input
from tensorflow.keras import metrics
from tensorflow.keras.optimizers import Nadam
#from tensorflow.keras.optimizer_v2 import Nadam
from tensorflow.keras.losses import MeanSquaredError
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, TensorBoard, ReduceLROnPlateau
from tensorflow.keras.layers import Conv2D, Conv2DTranspose, BatchNormalization, Dense, Flatten, Reshape, Activation, InputLayer
import cv2
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import norm
import os
import pandas as pd
import math

We use the following functions to load and normalize the images. Only three images are used as input. As a byproduct, I expect the model may create some interesting images that blend the inputs together. `ImageDataGenerator` is used here for flexibility, since we may want to augment the images later.

In [None]:
# Rescale values linearly into the range [0, 1].
# The images are standardized before this step, so simply dividing by 255
# would not work here.
def normalize(numpy_array):
    """Linearly rescale ``numpy_array`` so that its values span [0, 1].

    Args:
        numpy_array: numpy array of numbers (any shape).

    Returns:
        An array of the same shape where the minimum maps to 0 and the
        maximum maps to 1. A constant array (max == min) maps to all zeros
        instead of the NaNs a 0/0 division would produce.
    """
    minimum = numpy_array.min()
    maximum = numpy_array.max()
    value_range = maximum - minimum

    if value_range == 0:
        # Constant input (e.g. a blank image): avoid 0/0 -> NaN.
        return np.zeros_like(numpy_array, dtype=float)

    return (numpy_array - minimum) / value_range

In [None]:
# Load a single picture from disk as a normalized batch of one.
def loadImage(path, input_shape=None):
    """Read an image file, resize it and rescale its pixels to [0, 1].

    Args:
        path: path of the image file to read.
        input_shape: target ``(height, width, ...)``. Defaults to the
            notebook-level ``INPUT_SHAPE``. It is resolved at call time
            because that constant is only defined in a later cell.

    Returns:
        A float array of shape ``(1, height, width, 3)`` (RGB) with values
        in [0, 1].

    Raises:
        FileNotFoundError: if OpenCV cannot read ``path``.
    """
    if input_shape is None:
        input_shape = INPUT_SHAPE

    img = cv2.imread(path, cv2.IMREAD_COLOR)
    if img is None:
        # cv2.imread reports failure by returning None rather than raising.
        raise FileNotFoundError('could not read image: %s' % path)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    # cv2.resize expects (width, height), the reverse of the array shape order.
    img = cv2.resize(img, (input_shape[1], input_shape[0]))
    # Add a leading batch axis.
    img = img.reshape((1,) + img.shape)

    # Standardize, then min-max rescale to [0, 1].
    # NOTE: because standardization is an increasing affine map, the final
    # result equals min-max scaling of the raw pixels (up to float rounding).
    mean = np.mean(img)
    std = np.std(img)
    data = (img - mean) / (std + 1e-7)
    data = normalize(data)

    return data

In [None]:
def loadTrainTensor():
    """Load the three training pictures and wrap them in a Keras iterator.

    Returns:
        The iterator produced by ``ImageDataGenerator.flow``. The stacked
        images are passed as both inputs and targets.
    """
    image_paths = (
        '/Users/justin/projects/autoencoder/20210208_073907.jpg',
        '/Users/justin/projects/autoencoder/2021-02-08_07-45-44.jpg',
        '/Users/justin/projects/autoencoder/20210208_074628.jpg',
    )
    # Each loadImage call returns a batch of one; stack them on axis 0.
    stacked = np.concatenate([loadImage(p) for p in image_paths], axis=0)

    # Every augmentation range is zero for now; the generator is kept so that
    # augmentation can be switched on later.
    augmenter = ImageDataGenerator(
        rotation_range=0,
        width_shift_range=0.0,
        height_shift_range=0.0,
        zoom_range=0.0,
        fill_mode='nearest',
    )
    return augmenter.flow(stacked, stacked)

Since we only have three input images, we need some extra logic to make the image data generator output larger batches (a multiple of three images).

In [None]:
def custom_data_flow(batch_size=12, mix_match=False):
    """Yield ``(inputs, targets)`` batches built from the training iterator, forever.

    The iterator from ``loadTrainTensor`` only covers three images. Several of
    its batches are therefore concatenated into one larger batch.

    Args:
        batch_size: requested batch size. ``floor(batch_size / 3)`` iterator
            batches are concatenated per yield, and at least one is always
            drawn.
        mix_match: if True, the targets are shuffled along the batch axis, so
            inputs are paired with (possibly) different target images.

    Yields:
        ``(batch, target)`` numpy arrays of identical shape.
    """
    input_data = loadTrainTensor()
    # With zero pulls there would be nothing to yield (and nothing to copy).
    pulls = max(1, int(batch_size // 3))

    while True:
        chunks = []
        for _ in range(pulls):
            # next() draws exactly one batch. The iterator yields (x, y) pairs;
            # only x is needed because the targets are derived from it below.
            x_batch, _ = next(input_data)
            chunks.append(x_batch)
        batch = np.concatenate(chunks, axis=0)

        target = batch.copy()
        if mix_match:
            # Shuffle along the batch axis only; the inputs keep their order.
            np.random.shuffle(target)

        yield batch, target

We subclass the Keras `Model` class to create our encoder and decoder models. The decoder mirrors the encoder: its layers appear in reverse order.

In [None]:
class Jencoder(Model):
    """Convolutional encoder: strided convolutions followed by a dense head.

    Maps a batch of images to vectors of size ``latent_dim``. After
    construction, ``last_conv_size`` holds the ``(height, width)`` of the
    final convolution's feature map.
    """

    def __init__(self, nlayers=7, filters=(32, 64, 128, 196, 256, 512, 1024),
                 kernel_sizes=(3, 3, 3, 3, 3, 3, 3), strides=(2, 2, 2, 2, 2, 2, 2),
                 image_shape=None, latent_dim=None):
        """Create the layers and compute the spatial size after the last conv.

        Args:
            nlayers: number of convolution layers.
            filters: output channels per convolution (``nlayers`` entries).
            kernel_sizes: square kernel size per convolution (``nlayers`` entries).
            strides: stride per convolution (``nlayers`` entries).
            image_shape: input ``(height, width, channels)``. Defaults to the
                notebook-level ``INPUT_SHAPE``.
            latent_dim: size of the output vector. Defaults to ``LATENT_DIM``.

        Raises:
            ValueError: if a per-layer sequence does not have ``nlayers`` entries.
        """
        super().__init__()

        # INPUT_SHAPE / LATENT_DIM are defined in a later notebook cell.
        # Using them directly as default values would raise NameError when
        # this class is defined, so they are resolved here instead.
        if image_shape is None:
            image_shape = INPUT_SHAPE
        if latent_dim is None:
            latent_dim = LATENT_DIM

        for label, values in (('filters', filters),
                              ('kernel_sizes', kernel_sizes),
                              ('strides', strides)):
            if len(values) != nlayers:
                raise ValueError('%s must have %d entries, got %d'
                                 % (label, nlayers, len(values)))

        self.nlayers = nlayers
        self.conv_list = []
        self.dense_list = []
        self.batch_norm_list = []
        activ = 'selu'

        height, width = image_shape[0], image_shape[1]
        for i in range(nlayers):
            self.conv_list.append(Conv2D(filters[i], kernel_sizes[i], activation=activ,
                                         strides=(strides[i], strides[i]),
                                         name='encode_conv_%d' % i))
            self.batch_norm_list.append(BatchNormalization())
            # Output size of a 'valid'-padded convolution (Conv2D's default padding).
            height = (height - kernel_sizes[i]) // strides[i] + 1
            width = (width - kernel_sizes[i]) // strides[i] + 1

        self.last_conv_size = (height, width)

        # Keras layer widths must be integers; the second dense layer is half
        # as wide as the first.
        flat_units = height * width * filters[-1]
        self.flatten_layer = Flatten()
        self.dense_list.append(Dense(flat_units, activation=activ, name='encode_dense_1'))
        self.dense_list.append(Dense(flat_units // 2, activation=activ, name='encode_dense_2'))
        self.dense_list.append(Dense(latent_dim, name='encode_dense_3'))

    def call(self, inputs, training=False):
        """Encode ``inputs`` into latent vectors.

        Each convolution is followed by batch normalization. The result is
        flattened and passed through the three dense layers in order.
        """
        x = inputs
        for conv, batch_norm in zip(self.conv_list, self.batch_norm_list):
            x = conv(x)
            x = batch_norm(x, training=training)

        x = self.flatten_layer(x)
        for dense in self.dense_list:
            x = dense(x)

        return x

In [None]:
class Jdecoder(Model):
    """Transposed-convolution decoder, the mirror image of ``Jencoder``.

    The latent input goes through a SELU activation and a dense layer, and is
    reshaped to an ``input_kernel_size[0] x input_kernel_size[1] x filters[0]``
    feature map. ``nlayers`` transposed convolutions then upsample it, each
    preceded by batch normalization. The last convolution uses a sigmoid, so
    outputs lie in (0, 1).
    """

    def __init__(self, nlayers=8, filters=(1024, 512, 256, 196, 128, 64, 32, 3),
                 kernel_sizes=(3, 3, 3, 3, 3, 3, 3, 1), strides=(2, 2, 2, 2, 2, 2, 2, 1),
                 input_kernel_size=None):
        """Create the decoder layers.

        Args:
            nlayers: number of transposed convolutions (at least 1).
            filters: output channels per layer (``nlayers`` entries).
            kernel_sizes: square kernel size per layer (``nlayers`` entries).
            strides: stride per layer (``nlayers`` entries).
            input_kernel_size: ``(height, width)`` of the starting feature map,
                typically the encoder's ``last_conv_size``. Required.

        Raises:
            ValueError: on inconsistent per-layer sequences, ``nlayers < 1`` or
                a missing ``input_kernel_size``.
        """
        super().__init__()

        if nlayers < 1:
            raise ValueError('nlayers must be at least 1, got %d' % nlayers)
        for label, values in (('filters', filters),
                              ('kernel_sizes', kernel_sizes),
                              ('strides', strides)):
            if len(values) != nlayers:
                raise ValueError('%s must have %d entries, got %d'
                                 % (label, nlayers, len(values)))
        if input_kernel_size is None:
            raise ValueError('input_kernel_size (height, width) is required')

        self.nlayers = nlayers
        self.conv_pos_list = []
        self.batch_norm_list = []
        activ = 'selu'
        last = nlayers - 1

        for i in range(nlayers):
            # Hidden layers use SELU; the output layer uses a sigmoid.
            layer_activation = 'sigmoid' if i == last else activ
            self.conv_pos_list.append(Conv2DTranspose(
                filters[i], kernel_sizes[i], activation=layer_activation,
                strides=(strides[i], strides[i]),
                # Index-based name keeps every layer name unique.
                name='decode_conv_pos_%d' % i))
            self.batch_norm_list.append(BatchNormalization())

        height, width = input_kernel_size[0], input_kernel_size[1]
        self.activation_layer = Activation(activ)
        # Keras layer widths must be integers.
        self.dense_layer = Dense(height * width * filters[0], activation=activ,
                                 name='decode_dense_1')
        self.reshape_layer = Reshape((height, width, filters[0]))

    def call(self, inputs, training=False):
        """Decode latent vectors back into images."""
        x = self.activation_layer(inputs)
        x = self.dense_layer(x)
        x = self.reshape_layer(x)

        for batch_norm, conv in zip(self.batch_norm_list, self.conv_pos_list):
            # Batch normalization is applied before (not after) each transposed conv.
            x = batch_norm(x, training=training)
            x = conv(x)
        return x

The VAE model is composed of an encoder and a decoder. Two additional layers estimate the mean and the log-variance of the latent distribution. Note the `isVaem` switch: it lets this class support both a plain autoencoder (AEM) and a variational autoencoder (VAEM).

In [None]:
class Vaem(Model):
    """Image autoencoder that runs either as a plain AE or as a VAE.

    ``isVaem`` selects the mode:

    - VAE mode: the encoder output is projected to ``z_mean`` and
      ``z_log_var`` (each batch-normalized), and a latent vector is sampled
      from them.
    - AE mode: the encoder output goes straight to the decoder.

    Encoder and decoder are created in ``compile()``, which also installs the
    matching ``train_step`` / ``call`` implementation.
    """

    def __init__(self, latent_dim=None, image_shape=None, learning_rate=0.00015, isVaem=False):
        """Store the settings and create the size-independent layers.

        Args:
            latent_dim: size of the latent space. Defaults to the
                notebook-level ``LATENT_DIM``. It is resolved here because that
                constant is defined in a later cell.
            image_shape: input ``(height, width, channels)``. When None,
                ``compile()`` falls back to ``INPUT_SHAPE``.
            learning_rate: learning rate of the Nadam optimizer.
            isVaem: True for a variational autoencoder, False for a plain one.
        """
        super().__init__()

        if latent_dim is None:
            latent_dim = LATENT_DIM

        self.isVaem = isVaem
        self.latent_dim = latent_dim
        self.image_shape = image_shape

        self.z_mean_layer = Dense(latent_dim, name='z_mean_layer')
        self.z_var_layer = Dense(latent_dim, name='z_var_layer')

        self.optimizer = Nadam(learning_rate=learning_rate)
        self.batch_norm_z_mean = BatchNormalization(name='batch_norm_z_mean')
        self.batch_norm_z_log_var = BatchNormalization(name='batch_norm_z_log_var')

    def aem_loss_func(self, input, decoded):
        """Mean squared reconstruction error used in plain autoencoder mode."""
        return MeanSquaredError()(input, decoded)

    def vaem_loss_func(self, input, decoded, z_mean, z_log_var):
        """VAE loss: MSE reconstruction error plus a weighted KL term.

        ``kl_loss`` is proportional to the closed-form
        KL(N(z_mean, exp(z_log_var)) || N(0, I)). It averages over the latent
        axis instead of summing, and uses a weight of 5e-4 instead of the
        usual 0.5.
        """
        x = K.flatten(input)
        x_decoded = K.flatten(decoded)

        recon_loss = MeanSquaredError()(x, x_decoded)
        kl_loss = -5e-4 * K.mean(1 + z_log_var - K.square(z_mean) - K.exp(z_log_var), axis=-1)

        return K.mean(recon_loss + kl_loss)

    def z_sampling(self, args):
        """Sample z ~ N(z_mean, exp(z_log_var)) with the reparameterization trick.

        Args:
            args: ``(z_mean, z_log_var)`` tensors whose last axis has size
                ``latent_dim``.
        """
        z_mean, z_log_var = args
        epsilon = K.random_normal(shape=(K.shape(z_mean)[0], self.latent_dim), mean=0., stddev=1.)
        # z_log_var is a log *variance* (the KL term in vaem_loss_func treats
        # it that way), so the standard deviation is exp(0.5 * z_log_var).
        return z_mean + K.exp(0.5 * z_log_var) * epsilon

    def _latent_params(self, x, training):
        """Project encoder features to batch-normalized ``(z_mean, z_log_var)``."""
        z_mean = self.z_mean_layer(x)
        z_mean = self.batch_norm_z_mean(z_mean, training=training)
        z_log_var = self.z_var_layer(x)
        z_log_var = self.batch_norm_z_log_var(z_log_var, training=training)
        return z_mean, z_log_var

    def compile(self):
        """Build encoder/decoder, set up metrics and select the AE/VAE code paths."""
        # Hand over the Nadam optimizer created in __init__. A bare
        # super().compile() would replace it with Keras' default optimizer
        # ('rmsprop'), silently discarding the requested learning rate.
        super().compile(optimizer=self.optimizer)

        # Fall back to the notebook-level INPUT_SHAPE when no shape was given.
        image_shape = INPUT_SHAPE if self.image_shape is None else self.image_shape

        self.encoder = Jencoder(image_shape=image_shape, latent_dim=self.latent_dim)
        # (height, width) of the encoder's last feature map; the decoder starts there.
        self.last_conv_size = self.encoder.last_conv_size
        self.encoder.build((None,) + tuple(image_shape))

        if self.isVaem:
            self.z_mean_layer.build((None, self.latent_dim))
            self.z_var_layer.build((None, self.latent_dim))

        self.decoder = Jdecoder(input_kernel_size=self.last_conv_size)
        self.decoder.build((None, self.latent_dim))

        self.loss_tracker = tf.keras.metrics.Mean(name="loss")
        self.mae_metric = tf.keras.metrics.MeanAbsoluteError(name="diff")

        # Set the train-step and call logic based on the isVaem switch.
        self.train_step = self.VAEM_step if self.isVaem else self.AEM_step
        self.call = self.VAEM_call if self.isVaem else self.AEM_call

        self.encoder.summary()
        self.decoder.summary()

    def VAEM_call(self, input, training=False):
        """Forward pass in VAE mode: encode, sample a latent vector, decode."""
        x = self.encoder(input, training=training)
        z_mean, z_log_var = self._latent_params(x, training)
        z_sample = self.z_sampling((z_mean, z_log_var))
        return self.decoder(z_sample, training=training)

    def AEM_call(self, input, training=False):
        """Forward pass in plain autoencoder mode: encode, then decode."""
        z_sample = self.encoder(input, training=training)
        return self.decoder(z_sample, training=training)

    def VAEM_step(self, data):
        """One optimization step in VAE mode; returns the running metrics."""
        inp, tar = data
        with tf.GradientTape() as tape:
            x = self.encoder(inp, training=True)
            z_mean, z_log_var = self._latent_params(x, True)
            z = self.z_sampling((z_mean, z_log_var))
            decoded = self.decoder(z, training=True)
            loss = self.vaem_loss_func(tar, decoded, z_mean, z_log_var)

        trainable_vars = self.trainable_variables
        grads = tape.gradient(loss, trainable_vars)
        self.optimizer.apply_gradients(zip(grads, trainable_vars))

        # Compute our own metrics.
        self.loss_tracker.update_state(loss)
        self.mae_metric.update_state(tar, decoded)

        return {"loss": self.loss_tracker.result(), "diff": self.mae_metric.result()}

    def AEM_step(self, data):
        """One optimization step in plain autoencoder mode; returns the running metrics."""
        inp, tar = data
        with tf.GradientTape() as tape:
            z = self.encoder(inp, training=True)
            decoded = self.decoder(z, training=True)
            # Compare against the provided target, as VAEM_step and the MAE
            # metric do, so that shuffled (mix_match) targets are honored.
            loss = self.aem_loss_func(tar, decoded)

        trainable_vars = self.encoder.trainable_variables + self.decoder.trainable_variables
        grads = tape.gradient(loss, trainable_vars)
        self.optimizer.apply_gradients(zip(grads, trainable_vars))

        # Compute our own metrics.
        self.loss_tracker.update_state(loss)
        self.mae_metric.update_state(tar, decoded)

        return {"loss": self.loss_tracker.result(), "diff": self.mae_metric.result()}

    # List loss_tracker and mae_metric here so that they are reset at the start
    # of each training epoch.
    @property
    def metrics(self):
        return [self.loss_tracker, self.mae_metric]


The following function trains the model.

In [None]:
def train_model(model):
    """Fit ``model`` on the custom image stream and save its loss history.

    Callbacks:

    - lower the learning rate (factor 0.618) after 25 epochs without
      improvement;
    - checkpoint the weights with the lowest training loss;
    - stop early after 100 epochs without improvement, restoring the best
      weights;
    - write TensorBoard logs.

    The per-epoch loss is written to a CSV file whose name carries the best
    loss.

    Args:
        model: a compiled Keras model, e.g. a ``Vaem`` instance.

    Returns:
        The ``History`` object returned by ``model.fit``.
    """
    # datetime is not imported at the top of the notebook.
    from datetime import datetime

    logdir = 'jmodel' + datetime.now().strftime('%m_%d_%YT%H')
    tbpath = '/Users/justin/projects/autoencoder/tensorboard/%s' % logdir
    filepathstr = 'weights_%s.tf' % logdir
    tensor_board_callback = TensorBoard(
        log_dir=tbpath,
        histogram_freq=1,
        update_freq='epoch',
        write_graph=True,
        embeddings_freq=0)
    reduce_lr_callback = ReduceLROnPlateau(
        monitor='loss',
        factor=0.618,
        patience=25,
        mode='auto',
        verbose=1,
        min_delta=0.0001)
    model_checkpoint_callback = ModelCheckpoint(
        filepath=filepathstr,
        save_weights_only=True,
        monitor='loss',
        mode='min',
        verbose=1,
        save_best_only=True)
    early_stopping_callback = EarlyStopping(
        monitor='loss',
        patience=100,
        restore_best_weights=True)

    # Train the model that was passed in, not a global. The generator already
    # yields whole batches, so batch_size must not be passed to fit(), and
    # shuffle is ignored for generators.
    hist = model.fit(custom_data_flow(BATCH_SIZE), epochs=EPOCHS, steps_per_epoch=STEPS_IN_EPOCH,
                     callbacks=[reduce_lr_callback, model_checkpoint_callback,
                                early_stopping_callback, tensor_board_callback])

    loss = min(hist.history['loss'])
    df = pd.DataFrame(hist.history['loss'])
    df.to_csv('weights_%s.tf.loss%.4f' % (logdir, loss))

    return hist


# Model / data hyper-parameters.
INPUT_SHAPE = (255, 255, 3)  # (height, width, channels) of the training images
LATENT_DIM = 4

# Training schedule.
BATCH_SIZE = 12
STEPS_IN_EPOCH = 20
EPOCHS = 1000

# Variational autoencoder (isVaem=True) with a small Nadam learning rate.
vae = Vaem(image_shape=INPUT_SHAPE, latent_dim=LATENT_DIM, learning_rate=6.18e-5, isVaem=True)
vae.compile()

# Debugging only: uncomment to run train_step eagerly.
# vae.run_eagerly = True

vae.build((None, *INPUT_SHAPE))
vae.summary()
train_model(vae)