In [1]:
import os
import pickle

from tensorflow.keras import Model
from tensorflow.keras.layers import Input, Conv2D, ReLU, BatchNormalization, Flatten, Dense, Reshape, Conv2DTranspose, Activation, Lambda
from tensorflow.keras import backend as K
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import MeanSquaredError
import numpy as np
import tensorflow as tf


# Run Keras in TF1-style graph mode instead of eager mode.
# NOTE(review): presumably required because VAE._calculate_kl_loss closes over the symbolic
# encoder tensors self.mu / self.log_variance, which a Keras loss/metric cannot reference under
# TF2 eager execution — confirm before removing.
tf.compat.v1.disable_eager_execution()


class VAE:
    """
    VAE is a convolutional variational autoencoder with a mirrored architecture: the encoder
    downsamples the input through a stack of conv blocks into a sampled latent vector, and the
    decoder upsamples that vector back to the input's shape with transposed conv blocks.
    """

    def __init__(self,
                 input_shape,
                 conv_filters,  # number of filters of each encoder conv layer
                 conv_kernels,  # kernel size of each conv layer
                 conv_strides,  # stride of each conv layer; a stride of 2 halves the spatial size
                 latent_space_dim  # size of the bottleneck (latent vector)
                 ):
        self.input_shape = input_shape
        self.conv_filters = conv_filters
        self.conv_kernels = conv_kernels
        self.conv_strides = conv_strides
        self.latent_space_dim = latent_space_dim
        # Multiplier on the reconstruction loss so it is not swamped by the KL loss when combined.
        self.reconstruction_loss_weight = 1000

        self.encoder = None
        self.decoder = None
        self.model = None
        # Symbolic mean / log-variance tensors of the latent distribution; assigned by
        # _add_bottleneck and read by _calculate_kl_loss.
        self.mu = None
        self.log_variance = None

        self._num_conv_layers = len(conv_filters)
        self._shape_before_bottleneck = None  # saved so the decoder can reshape back before upsampling
        self._model_input = None  # set when the encoder is built

        self._build()

    def summary(self):
        """Print the Keras summaries of the encoder, the decoder and the full model."""
        self.encoder.summary()
        self.decoder.summary()
        self.model.summary()

    def compile(self, learning_rate=0.0001):
        """Compile the full model with Adam, the combined VAE loss, and both loss terms as metrics."""
        optimizer = Adam(learning_rate=learning_rate)
        self.model.compile(optimizer=optimizer,
                           loss=self._calculate_combined_loss,
                           metrics=[self._calculate_reconstruction_loss,
                                    self._calculate_kl_loss])

    def train(self, x_train, y_train, batch_size, num_epochs):
        """Fit the full model mapping x_train to y_train.

        y_train is the reconstruction target (presumably x_train itself, or clean images when
        training for denoising — confirm against the caller).
        """
        self.model.fit(x_train,
                       y_train,
                       batch_size=batch_size,
                       epochs=num_epochs)

    def save(self, save_folder="."):
        """Write constructor parameters (parameters.pkl) and weights (weights.h5) to save_folder."""
        self._create_folder_if_it_doesnt_exist(save_folder)
        self._save_parameters(save_folder)
        self._save_weights(save_folder)

    def load_weights(self, weights_path):
        """Load weights for the full model from weights_path."""
        self.model.load_weights(weights_path)

    def denoise(self, images):
        """Encode then decode images.

        Returns (reconstructed_images, latent_representations). The latent representations are
        points sampled from the encoder's distribution, so repeated calls can differ.
        """
        latent_representations = self.encoder.predict(images)
        reconstructed_images = self.decoder.predict(latent_representations)
        return reconstructed_images, latent_representations

    @classmethod
    def load(cls, save_folder="."):
        """Rebuild a model from a folder written by `save` and load its weights.

        Uses `cls` so that subclasses are reconstructed as the subclass.
        NOTE: pickle.load can execute arbitrary code — only load parameter files you trust.
        """
        parameters_path = os.path.join(save_folder, "parameters.pkl")
        with open(parameters_path, "rb") as f:
            parameters = pickle.load(f)
        autoencoder = cls(*parameters)
        weights_path = os.path.join(save_folder, "weights.h5")
        autoencoder.load_weights(weights_path)
        return autoencoder

    def _calculate_combined_loss(self, y_target, y_predicted):
        """Weighted reconstruction loss plus KL loss.

        Keras always passes (y_target, y_predicted) to a custom loss, even to terms that ignore them.
        """
        reconstruction_loss = self._calculate_reconstruction_loss(y_target, y_predicted)
        kl_loss = self._calculate_kl_loss(y_target, y_predicted)
        return self.reconstruction_loss_weight * reconstruction_loss + kl_loss

    def _calculate_reconstruction_loss(self, y_target, y_predicted):
        """Per-sample mean squared error, averaged over axes 1, 2 and 3.

        Assumes 4-D batches (batch, height, width, channels).
        """
        error = y_target - y_predicted
        return K.mean(K.square(error), axis=[1, 2, 3])

    def _calculate_kl_loss(self, y_target, y_predicted):
        """KL divergence between the encoder's Gaussian and a standard normal, per sample.

        Computes -0.5 * sum(1 + log_var - mu^2 - exp(log_var)) over the latent dimensions,
        pushing every latent distribution towards mu = 0 and std = 1. The arguments are unused;
        the term reads self.mu and self.log_variance from the encoder graph.
        """
        return -0.5 * K.sum(1 + self.log_variance - K.square(self.mu) -
                            K.exp(self.log_variance), axis=1)

    def _create_folder_if_it_doesnt_exist(self, folder):
        """Create folder (and any missing parents); do nothing if it already exists."""
        # exist_ok avoids the check-then-create race of os.path.exists followed by os.makedirs.
        os.makedirs(folder, exist_ok=True)

    def _save_parameters(self, save_folder):
        """Pickle the constructor arguments so `load` can rebuild the same architecture."""
        parameters = [
            self.input_shape,
            self.conv_filters,
            self.conv_kernels,
            self.conv_strides,
            self.latent_space_dim
        ]
        save_path = os.path.join(save_folder, "parameters.pkl")
        with open(save_path, "wb") as f:
            pickle.dump(parameters, f)

    def _save_weights(self, save_folder):
        """Save the full model's weights to save_folder/weights.h5."""
        save_path = os.path.join(save_folder, "weights.h5")
        self.model.save_weights(save_path)

    def _build(self):
        """Build encoder, then decoder (needs the encoder's pre-bottleneck shape), then the full model."""
        self._build_encoder()
        self._build_decoder()
        self._build_autoencoder()

    def _build_autoencoder(self):
        """Chain encoder and decoder into one model sharing the encoder's input."""
        model_input = self._model_input
        model_output = self.decoder(self.encoder(model_input))
        self.model = Model(model_input, model_output, name="autoencoder")

    def _build_decoder(self):
        """Build the decoder: latent vector -> dense -> reshape -> transposed convs -> sigmoid."""
        decoder_input = Input(shape=(self.latent_space_dim,), name="decoder_input")
        # Size the dense layer to the flattened pre-bottleneck volume (e.g. [2, 2, 2] -> 8) so it
        # can be reshaped back into the volume the encoder flattened.
        num_neurons = int(np.prod(self._shape_before_bottleneck))
        dense_layer = Dense(num_neurons, name="decoder_dense")(decoder_input)
        reshape_layer = Reshape(self._shape_before_bottleneck)(dense_layer)
        conv_transpose_layers = self._add_conv_transpose_layers(reshape_layer)
        # The last transposed conv mirrors the encoder's first conv layer and emits as many
        # channels as the input has, so the reconstruction has exactly the input's shape
        # (a hard-coded single channel would silently broadcast in the loss for C > 1).
        last_conv_transpose_layer = Conv2DTranspose(
            filters=self.input_shape[-1],
            kernel_size=self.conv_kernels[0],
            strides=self.conv_strides[0],
            padding="same",
            name=f"decoder_conv_transpose_layer_{self._num_conv_layers}"
        )(conv_transpose_layers)
        # Sigmoid keeps outputs in (0, 1); inputs are presumably scaled to [0, 1] — confirm.
        decoder_output = Activation("sigmoid", name="sigmoid_layer")(last_conv_transpose_layer)
        self.decoder = Model(decoder_input, decoder_output, name="decoder")

    def _add_conv_transpose_layers(self, x):
        """Add conv transpose blocks.

        Walks the conv layers in reverse order and stops before layer 0; that first layer is
        mirrored by the final transposed conv in _build_decoder, which restores the input shape.
        """
        for layer_index in reversed(range(1, self._num_conv_layers)):
            x = self._add_conv_transpose_layer(layer_index, x)
        return x

    def _add_conv_transpose_layer(self, layer_index, x):
        """Add one conv-transpose + ReLU + batch-norm block mirroring encoder layer layer_index."""
        layer_num = self._num_conv_layers - layer_index
        conv_transpose_layer = Conv2DTranspose(
            filters=self.conv_filters[layer_index],
            kernel_size=self.conv_kernels[layer_index],
            strides=self.conv_strides[layer_index],
            padding="same",
            name=f"decoder_conv_transpose_layer_{layer_num}"
        )
        x = conv_transpose_layer(x)
        x = ReLU(name=f"decoder_relu_{layer_num}")(x)
        x = BatchNormalization(name=f"decoder_bn_{layer_num}")(x)
        return x

    def _build_encoder(self):
        """Build the encoder: input -> conv blocks -> Gaussian bottleneck (sampled latent point)."""
        encoder_input = Input(shape=self.input_shape, name="encoder_input")
        conv_layers = self._add_conv_layers(encoder_input)
        bottleneck = self._add_bottleneck(conv_layers)
        self._model_input = encoder_input
        self.encoder = Model(encoder_input, bottleneck, name="encoder")

    def _add_conv_layers(self, encoder_input):
        """Create all convolutional blocks in encoder."""
        x = encoder_input
        for layer_index in range(self._num_conv_layers):
            x = self._add_conv_layer(layer_index, x)
        return x

    def _add_conv_layer(self, layer_index, x):
        """Add a convolutional block to a graph of layers, consisting of
        conv 2d + ReLU + batch normalization.
        """
        layer_number = layer_index + 1
        conv_layer = Conv2D(
            filters=self.conv_filters[layer_index],
            kernel_size=self.conv_kernels[layer_index],
            strides=self.conv_strides[layer_index],
            padding="same",
            name=f"encoder_conv_layer_{layer_number}"
        )
        x = conv_layer(x)
        x = ReLU(name=f"encoder_relu_{layer_number}")(x)
        x = BatchNormalization(name=f"encoder_bn_{layer_number}")(x)
        return x

    def _add_bottleneck(self, x):
        """Flatten the data and add a Gaussian-sampling bottleneck.

        Two dense branches produce the mean (mu) and log-variance of the latent distribution; a
        Lambda layer then samples a point from it (reparameterization trick).
        """
        # Shape without the batch dimension, reused by the decoder to undo the flatten.
        self._shape_before_bottleneck = K.int_shape(x)[1:]
        x = Flatten()(x)
        self.mu = Dense(self.latent_space_dim, name="mu")(x)
        self.log_variance = Dense(self.latent_space_dim,
                                  name="log_variance")(x)

        def sample_point_from_normal_distribution(args):
            # Sample epsilon ~ N(0, 1) shaped like the given mu (not self.mu, so the Lambda
            # does not capture self), then shift/scale: z = mu + exp(log_var / 2) * epsilon.
            mu, log_variance = args
            epsilon = K.random_normal(shape=K.shape(mu), mean=0., stddev=1.)
            return mu + K.exp(log_variance / 2) * epsilon

        # A Lambda layer applies a custom function to Keras tensors.
        x = Lambda(sample_point_from_normal_distribution,
                   name="encoder_output")([self.mu, self.log_variance])
        return x




In [2]:
# Smoke-test the class: build a VAE for 28x28 single-channel inputs with a 2-D latent space,
# then print the encoder, decoder and full-model summaries.
autoencoder = VAE(
    input_shape=(28, 28, 1),
    conv_filters=(32, 64, 64, 64),
    conv_kernels=(3,) * 4,
    conv_strides=(1, 2, 2, 1),
    latent_space_dim=2,
)
autoencoder.summary()

Instructions for updating:
Colocations handled automatically by placer.
Model: "encoder"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
encoder_input (InputLayer)      [(None, 28, 28, 1)]  0                                            
__________________________________________________________________________________________________
encoder_conv_layer_1 (Conv2D)   (None, 28, 28, 32)   320         encoder_input[0][0]              
__________________________________________________________________________________________________
encoder_relu_1 (ReLU)           (None, 28, 28, 32)   0           encoder_conv_layer_1[0][0]       
__________________________________________________________________________________________________
encoder_bn_1 (BatchNormalizatio (None, 28, 28, 32)   128         encoder_relu_1[0][0]             
____________________