In [4]:
import numpy as np

import keras
from keras.models import Model
from keras.layers import Input, Flatten, Dense, Conv2D, LeakyReLU, Reshape, Conv2DTranspose, Activation, Lambda
from keras.optimizers import Adam
from keras import backend as K

In [5]:
# Autoencoder class. It builds an internal encoder, and an internal decoder
class Autoencoder:
    """Convolutional autoencoder (optionally variational) assembled from Keras layers.

    The constructor builds an encoder (a stack of Conv2D + LeakyReLU layers followed
    by a Dense projection to ``z_dim``), a decoder (a Dense layer, a Reshape and a
    stack of Conv2DTranspose layers ending in a sigmoid), chains them into a single
    end-to-end model and compiles that model with the Adam optimizer.

    Parameters
    ----------
    input_dim : tuple
        Shape passed to the encoder ``Input`` layer (batch dimension excluded).
        The reconstruction loss averages over axes 1, 2 and 3, so the model
        output is expected to be 4-D including the batch axis.
    encoder_conv_filters, encoder_conv_kernel_size, encoder_conv_strides : sequence
        Per-layer settings for the encoder ``Conv2D`` layers; all three must have
        the same length.
    decoder_conv_t_filters, decoder_conv_t_kernel_size, decoder_conv_t_strides : sequence
        Per-layer settings for the decoder ``Conv2DTranspose`` layers; all three
        must have the same length.
    z_dim : int
        Size of the latent vector produced by the encoder.
    variational : bool, optional
        When truthy, the encoder outputs a point sampled from the normal
        distribution described by its ``mu`` and ``log_var`` heads, and a KL term
        is added to the loss.
    r_loss_factor : number, optional
        Multiplier applied to the reconstruction loss. Only used when
        ``variational`` is truthy; otherwise it is ignored.

    Attributes
    ----------
    model : the compiled end-to-end model (encoder input -> decoder output).
    encoder : model mapping the encoder input to the latent vector.
    decoder : model mapping a latent vector to the reconstructed output.
    mu, log_var, encoder_mu_log_var : only set when ``variational`` is truthy;
        the two distribution heads and a model exposing both of them.
    """

    # Single source of truth for the default reconstruction-loss weight, so the
    # constructor can tell "left at default" apart from an explicit custom value.
    _DEFAULT_R_LOSS_FACTOR = 1000

    def __init__(self, 
                input_dim,
                encoder_conv_filters,
                encoder_conv_kernel_size,
                encoder_conv_strides,
                decoder_conv_t_filters,
                decoder_conv_t_kernel_size,
                decoder_conv_t_strides,
                z_dim,
                variational = False,
                r_loss_factor = _DEFAULT_R_LOSS_FACTOR):
        """Store the configuration, validate it, then build and compile the model.

        Raises
        ------
        AssertionError
            If the encoder (or decoder) per-layer parameter lists differ in length.
        """
        self.input_dim = input_dim
        self.encoder_conv_filters = encoder_conv_filters
        self.encoder_conv_kernel_size = encoder_conv_kernel_size
        self.encoder_conv_strides = encoder_conv_strides
        self.decoder_conv_t_filters = decoder_conv_t_filters
        self.decoder_conv_t_kernel_size = decoder_conv_t_kernel_size
        self.decoder_conv_t_strides = decoder_conv_t_strides
        self.z_dim = z_dim
        self.variational = variational
        self.r_loss_factor = r_loss_factor

        self.n_layers_encoder = len(encoder_conv_filters)
        self.n_layers_decoder = len(decoder_conv_t_filters)

        assert len(encoder_conv_filters) == len(encoder_conv_kernel_size) == len(encoder_conv_strides), 'len of encoder input params must be the same'
        assert len(decoder_conv_t_filters) == len(decoder_conv_t_kernel_size) == len(decoder_conv_t_strides), 'len of decoder input params must be the same'

        # r_loss_factor has no effect on a non-variational autoencoder. The default
        # value must not be treated as an error (otherwise the class cannot be
        # constructed with its own defaults), so only an explicit custom value
        # triggers a warning.
        if not variational and r_loss_factor not in (1, self._DEFAULT_R_LOSS_FACTOR):
            import warnings
            warnings.warn(
                'r_loss_factor is not 1, but the encoder is not variational. r_loss_factor is ignored if not a variational encoder',
                stacklevel = 2
            )

        encoder, encoder_input, encoder_output, shape_before_flattening = self.__build_encoder()
        # The second value returned by __build_decoder is the decoder output tensor,
        # which is not needed here.
        decoder, _ = self.__build_decoder(shape_before_flattening)

        # Keep the sub-models so callers can encode inputs / decode latent points.
        self.encoder = encoder
        self.decoder = decoder

        # End-to-end model: encoder input -> latent vector -> decoder output
        self.model = Model(encoder_input, decoder(encoder_output))

        self.__compile(learning_rate = 0.0005)

    def fit(self, x_train, batch_size, shuffle, epochs, callbacks):
        """Train the model to reproduce ``x_train`` (it is used as input and target).

        Returns whatever the underlying ``self.model.fit`` call returns.
        """
        return self.model.fit(
            x = x_train,
            y = x_train,
            batch_size = batch_size,
            shuffle = shuffle,
            epochs = epochs,
            callbacks = callbacks
        )

    def get_model(self):
        """Return the end-to-end model."""
        return self.model

    def load_model(self, model):
        """Replace the end-to-end model with ``model``.

        Only ``self.model`` is replaced; ``self.encoder`` and ``self.decoder``
        keep referring to the sub-models built by the constructor.
        """
        self.model = model

    # Optional: test different loss functions
    def __compile(self, learning_rate):
        """Compile ``self.model`` with Adam and the loss matching the model type."""
        # NOTE(review): `lr` is the keyword used by the original code; some Keras
        # releases name this argument `learning_rate` - confirm against the
        # installed version before changing it.
        optimizer = Adam(lr = learning_rate)

        if self.variational:
            # Total loss = weighted reconstruction + KL; both parts are also
            # reported separately as metrics.
            self.model.compile(optimizer = optimizer, loss = self.__vae_loss, metrics = [self.__vae_r_loss, self.__vae_kl_loss])
        else:
            self.model.compile(optimizer = optimizer, loss = self.__vae_r_loss)

    def __vae_r_loss(self, y_true, y_pred):
        """Reconstruction loss: mean squared error over axes 1, 2 and 3.

        The result is multiplied by ``r_loss_factor`` only for a variational model.
        """
        r_loss = K.mean(K.square(y_true - y_pred), axis=[1,2,3])
        if self.variational:
            return self.r_loss_factor * r_loss
        return r_loss

    def __vae_kl_loss(self, y_true, y_pred):
        """KL term computed from ``self.mu`` and ``self.log_var``.

        ``y_true`` and ``y_pred`` are unused; they are accepted so the method has
        the two-argument form used for Keras losses and metrics.
        """
        kl_loss = -0.5 * K.sum(1 + self.log_var - K.square(self.mu) - K.exp(self.log_var), axis = 1)
        return kl_loss

    def __vae_loss(self, y_true, y_pred):
        """Total variational loss: reconstruction loss plus KL term."""
        r_loss = self.__vae_r_loss(y_true, y_pred)
        kl_loss = self.__vae_kl_loss(y_true, y_pred)
        return r_loss + kl_loss

    def __build_encoder(self):
        """Build the encoder.

        Returns a tuple ``(encoder_model, encoder_input, encoder_output,
        shape_before_flattening)`` where the last item is the shape of the final
        convolutional feature map without the batch dimension.
        """
        # Create the input layer
        encoder_input = Input(shape = self.input_dim, name = 'encoder_input')

        x = encoder_input

        # Create the intermediate layers. Each intermediate layer is composed of a Conv2D layer and a
        # LeakyReLU layer.
        for i in range(self.n_layers_encoder):
            conv_layer = Conv2D(filters = self.encoder_conv_filters[i],
                                kernel_size = self.encoder_conv_kernel_size[i],
                                strides = self.encoder_conv_strides[i],
                                padding = 'same',
                                name = 'encoder_conv_' + str(i)
                               )
            x = conv_layer(x)
            x = LeakyReLU()(x)

        # K.int_shape returns the shape of the tensor as a tuple of integers (or None)
        # Skip the first element, as it's the size of the batch
        shape_before_flattening = K.int_shape(x)[1:]

        # Flatten the feature map before connecting it to a Dense layer
        x = Flatten()(x)

        # If it's a variational encoder, build the encoder so it encodes the images to two vectors,
        # the mu (mean point of the distribution) and the log_var (log of the variance of each dimension)
        if self.variational:
            self.mu = Dense(self.z_dim, name = 'mu')(x)
            self.log_var = Dense(self.z_dim, name = 'log_var')(x)

            # Model with two outputs (mu, log_var), kept for inspecting the distribution
            self.encoder_mu_log_var = Model(encoder_input, (self.mu, self.log_var))

            # Sampling function that receives mu and log_var and returns a point
            # sampled from the normal distribution defined by mu and log_var
            def sampling(args):
                mu, log_var = args
                epsilon = K.random_normal(shape = K.shape(mu), mean = 0., stddev = 1.)
                # exp(log_var / 2) is the standard deviation
                return mu + K.exp(log_var / 2) * epsilon

            # Create a layer from the sampling function
            encoder_output = Lambda(sampling, name = 'encoder_output')([self.mu, self.log_var])

        else:
            encoder_output = Dense(self.z_dim, name = 'encoder_output')(x)

        return Model(encoder_input, encoder_output), encoder_input, encoder_output, shape_before_flattening


    def __build_decoder(self, shape_before_flattening):
        """Build the decoder.

        ``shape_before_flattening`` is the feature-map shape (without batch
        dimension) the latent vector is expanded back to before the transposed
        convolutions. Returns a tuple ``(decoder_model, decoder_output)``.
        """
        decoder_input = Input(shape = (self.z_dim,), name = 'decoder_input')

        # Connect input to a dense layer. np.prod yields a NumPy integer; convert
        # it to a plain int before using it as the number of units.
        x = Dense(int(np.prod(shape_before_flattening)))(decoder_input)
        x = Reshape(shape_before_flattening)(x)

        for i in range(self.n_layers_decoder):
            conv_t_layer = Conv2DTranspose(
                filters = self.decoder_conv_t_filters[i],
                kernel_size = self.decoder_conv_t_kernel_size[i],
                strides = self.decoder_conv_t_strides[i],
                padding = 'same',
                name = 'decoder_conv_t_' + str(i)
            )

            x = conv_t_layer(x)

            # LeakyReLU after every layer except the last one, which uses a sigmoid
            if i < self.n_layers_decoder - 1:
                x = LeakyReLU()(x)
            else:
                x = Activation('sigmoid')(x)

        decoder_output = x

        return Model(decoder_input, decoder_output), decoder_output
        
    
            