## Create Model

In [0]:
from abc import ABC, abstractmethod

import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Dense, Flatten, LeakyReLU, BatchNormalization, Reshape
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import Adam

print('Tensorflow version: {}'.format(tf.__version__))

In [0]:
class AbstractModelCreator(ABC):
    """Interface for factories that build (and possibly compile) a Keras model."""

    @abstractmethod
    def create_model(self):
        """Build and return the model; concrete subclasses must override this.

        Raises:
            NotImplementedError: if a subclass delegates here via ``super()``
                instead of providing its own implementation.
        """
        raise NotImplementedError(
            '{}.create_model() must be implemented by a concrete subclass'.format(
                type(self).__name__))

In [0]:
class GeneratorModelCreator(AbstractModelCreator):
    """Factory for a fully connected generator whose output shape equals its input shape.

    The network flattens its input, passes it through three hidden stages
    (Dense -> LeakyReLU -> BatchNormalization) of width 256, 512 and 1024,
    projects back to the flattened size with a ``tanh`` Dense layer and reshapes
    the result to ``input_shape``. The model is returned uncompiled.
    """

    def __init__(self, input_shape):
        """Remember the per-sample shape used for both the input and the output."""
        self.input_shape = input_shape

    def create_model(self):
        """Build the generator, print its summary and return it (not compiled)."""
        shape = self.input_shape
        generator = Sequential()

        generator.add(Flatten(input_shape=shape))

        # Hidden stages in order of increasing width.
        for units in (256, 512, 1024):
            generator.add(Dense(units))
            generator.add(LeakyReLU(alpha=0.2))
            generator.add(BatchNormalization(momentum=0.8))

        # Map back to the flattened input size, then restore the original shape.
        generator.add(Dense(np.prod(shape), activation='tanh'))
        generator.add(Reshape(shape))

        print('Generator model:')
        generator.summary()

        return generator

In [0]:
class DiscriminatorModelCreator(AbstractModelCreator):
    """Factory for a fully connected binary classifier over samples of ``input_shape``.

    Layout: Flatten -> Dense(512) -> LeakyReLU -> Dense(256) -> LeakyReLU ->
    Dense(1, sigmoid). The model is compiled with binary cross-entropy, Adam
    (learning rate 0.001) and an accuracy metric.
    """

    def __init__(self, input_shape):
        """Remember the per-sample input shape."""
        self.input_shape = input_shape

    def create_model(self):
        """Build and compile the discriminator, print its summary and return it."""
        discriminator = Sequential()

        discriminator.add(Flatten(input_shape=self.input_shape))

        # Hidden stages in order of decreasing width.
        for units in (512, 256):
            discriminator.add(Dense(units))
            discriminator.add(LeakyReLU(alpha=0.2))

        # Single probability output.
        discriminator.add(Dense(1, activation='sigmoid'))

        adam = Adam(learning_rate=0.001)
        discriminator.compile(
            loss='binary_crossentropy',
            optimizer=adam,
            metrics=['accuracy'],
        )

        print('Discriminator model:')
        discriminator.summary()

        return discriminator

In [0]:
class EncoderGanModelCreator(AbstractModelCreator):
    """Factory for the combined generator -> discriminator model.

    The combined model feeds the encoder generator's output straight into the
    encoder discriminator. It is compiled with binary cross-entropy and Adam
    (learning rate 0.001), so fitting it to "real" labels pushes the generator
    towards outputs the discriminator accepts. The discriminator is frozen inside
    this combined model.
    """

    def __init__(self,
                 encoder_generator,
                 encoder_discriminator):
        """Store the two sub-models.

        Args:
            encoder_generator: Keras model whose weights the combined model trains.
            encoder_discriminator: Keras model scoring the generator's output;
                presumably compiled separately (e.g. by DiscriminatorModelCreator)
                for its own training — confirm at the call site.
        """
        self.encoder_generator = encoder_generator
        self.encoder_discriminator = encoder_discriminator

    def create_model(self):
        """Create, compile, print a summary of, and return the combined model."""
        # Create logical model to combine encoder generator and encoder discriminator
        model = Sequential()

        model.add(self.encoder_generator)
        model.add(self.encoder_discriminator)

        # Keras "freezes" each layer's `trainable` value at compile() time; changing
        # `trainable` afterwards has no effect on this model until it is recompiled.
        # Without freezing here, training the combined model would also update the
        # discriminator, teaching it to label generated samples as real. Freeze it
        # only for the duration of compile(), then restore the caller's flag so the
        # discriminator's own training and any later compiles are unaffected.
        discriminator_was_trainable = self.encoder_discriminator.trainable
        self.encoder_discriminator.trainable = False
        try:
            optimizer = Adam(learning_rate=0.001)
            model.compile(loss='binary_crossentropy',
                          optimizer=optimizer)
        finally:
            self.encoder_discriminator.trainable = discriminator_was_trainable

        print('Encoder GAN model:')
        model.summary()

        return model

In [0]:
class DecoderGanModelCreator(AbstractModelCreator):
    """Factory for the combined encoder -> decoder reconstruction model.

    The combined model passes its input through ``encoder_generator`` and then
    ``decoder_generator``. It is compiled with mean absolute error and Adam
    (learning rate 0.001), so it can be fitted with the input batch as its own
    target. The encoder is frozen inside this combined model, so only the decoder
    is updated.
    """

    def __init__(self,
                 encoder_generator,
                 decoder_generator):
        """Store the two sub-models.

        Args:
            encoder_generator: Keras model applied first; frozen in the combined model.
            decoder_generator: Keras model applied to the encoder's output; the part
                the combined model trains.
        """
        self.encoder_generator = encoder_generator
        self.decoder_generator = decoder_generator

    def create_model(self):
        """Create, compile, print a summary of, and return the combined model."""
        # Create logical model to combine encoder generator and decoder generator
        model = Sequential()

        model.add(self.encoder_generator)
        model.add(self.decoder_generator)

        # Keras "freezes" each layer's `trainable` value at compile() time; setting
        # `encoder_generator.trainable = False` after this point would not stop the
        # reconstruction loss from updating the encoder. Freeze the encoder only
        # while compiling, then restore the caller's flag. Other models that share
        # the encoder (e.g. a generator/discriminator stack) are then not frozen
        # when they are compiled later, regardless of creation order.
        encoder_was_trainable = self.encoder_generator.trainable
        self.encoder_generator.trainable = False
        try:
            optimizer = Adam(learning_rate=0.001)
            model.compile(loss='mae',
                          optimizer=optimizer)
        finally:
            self.encoder_generator.trainable = encoder_was_trainable

        print('Decoder GAN model:')
        model.summary()

        return model

## Train Model

In [0]:
from numpy import ones
from numpy import zeros

In [0]:
class AbstractModelTrainer(ABC):
    """Interface for objects that run a training procedure via ``train_model``."""

    @abstractmethod
    def train_model(self):
        """Run the training procedure; concrete subclasses must override this.

        Raises:
            NotImplementedError: if a subclass delegates here via ``super()``
                instead of providing its own implementation.
        """
        raise NotImplementedError(
            '{}.train_model() must be implemented by a concrete subclass'.format(
                type(self).__name__))

In [0]:
class EncoderTrainer(AbstractModelTrainer):
    """Adversarial training loop for the encoder generator.

    Each iteration draws one random batch (indices sampled with replacement) from
    ``input_data`` and one from ``exp_output_data``. It updates the discriminator
    on the generator's outputs (label 0) and on the expected outputs (label 1),
    then updates the generator through ``encoder_gan`` with every sample labelled
    1. Both losses are printed after each iteration.
    """

    def __init__(self,
                 encoder_generator,
                 encoder_discriminator,
                 encoder_gan,
                 training_epochs,
                 batch_size,
                 input_data,  # Input data of encoder
                 exp_output_data):  # Expected output data of encoder
        """Store models, loop settings and data, and pre-build the label arrays.

        Args:
            encoder_generator: Keras model applied to ``input_data`` batches via predict().
            encoder_discriminator: Keras model trained with train_on_batch on
                generated (0) vs expected (1) samples.
            encoder_gan: combined model trained with train_on_batch on all-ones labels.
            training_epochs: number of iterations; each one uses a single batch.
            batch_size: samples drawn per iteration.
            input_data: array indexed along axis 0; generator inputs.
            exp_output_data: array indexed along axis 0; samples treated as real.
        """
        self.encoder_generator = encoder_generator
        self.encoder_discriminator = encoder_discriminator
        self.encoder_gan = encoder_gan

        self.training_epochs = training_epochs
        self.batch_size = batch_size

        self.input_data = input_data
        self.exp_output_data = exp_output_data

        # Fixed targets reused every iteration: 0 = generated, 1 = expected.
        label_shape = (batch_size, 1)
        self.y_zeros = zeros(label_shape)
        self.y_ones = ones(label_shape)

    def train_model(self):
        """Run ``training_epochs`` discriminator-then-generator update rounds."""
        input_pool_size = self.input_data.shape[0]
        output_pool_size = self.exp_output_data.shape[0]

        for epoch in range(1, self.training_epochs + 1):
            # Random batches (input indices are drawn before output indices).
            batch_in = self.input_data[
                np.random.randint(0, input_pool_size, self.batch_size)]
            batch_expected = self.exp_output_data[
                np.random.randint(0, output_pool_size, self.batch_size)]

            batch_generated = self.encoder_generator.predict(batch_in)

            disc_loss = self.__train_encoder_discriminator(batch_generated, batch_expected)
            gen_loss = self.__train_encoder_generator(batch_in)

            print('[Encoder] - epochs: {}, d_loss: {}, g_loss: {}'.format(epoch, disc_loss, gen_loss))

    def __train_encoder_discriminator(self, x_gen_output, x_exp_output):
        """Fit the discriminator on one fake and one real batch; return the mean loss."""
        discriminator = self.encoder_discriminator
        discriminator.trainable = True

        loss_on_fake = discriminator.train_on_batch(x_gen_output, self.y_zeros)
        loss_on_real = discriminator.train_on_batch(x_exp_output, self.y_ones)

        # Element-wise mean of the two results (loss and accuracy when metrics are set).
        return 0.5 * np.add(loss_on_real, loss_on_fake)

    def __train_encoder_generator(self, x_input):
        """Fit the generator through the combined model with all-ones labels."""
        self.encoder_discriminator.trainable = False
        self.encoder_generator.trainable = True
        return self.encoder_gan.train_on_batch(x_input, self.y_ones)

In [0]:
class DecoderTrainer(AbstractModelTrainer):
    """Reconstruction training loop for the decoder generator.

    Each iteration draws one random batch (indices sampled with replacement) from
    ``input_data``. It fits the combined ``decoder_gan`` model to reproduce that
    batch, then prints the resulting loss.
    """

    def __init__(self,
                 encoder_generator,
                 decoder_generator,
                 decoder_gan,
                 training_epochs,
                 batch_size,
                 input_data):  # Input data of encoder
        """Store the models, loop settings and data.

        Args:
            encoder_generator: Keras model that feeds the decoder; flagged non-trainable
                during decoder training.
            decoder_generator: Keras model being trained.
            decoder_gan: compiled combined encoder -> decoder model used for training.
            training_epochs: number of iterations; each one uses a single batch.
            batch_size: samples drawn per iteration.
            input_data: array indexed along axis 0; used as both input and target.
        """
        self.encoder_generator = encoder_generator
        self.decoder_generator = decoder_generator
        self.decoder_gan = decoder_gan

        self.training_epochs = training_epochs
        self.batch_size = batch_size

        self.input_data = input_data

    def train_model(self):
        """Run ``training_epochs`` reconstruction updates, printing the loss of each."""
        for current_epoch in range(self.training_epochs):

            # Select a random batch of data
            input_indexes = np.random.randint(0, self.input_data.shape[0], self.batch_size)
            x_input = self.input_data[input_indexes]

            # ---------------------
            #  Train decoder generator
            # ---------------------
            g_loss = self.__train_decoder_generator(x_input)

            # Plot the progress (mirrors the encoder trainer's output)
            print('[Decoder] - epochs: {}, g_loss: {}'.format((current_epoch + 1), g_loss))

    def __train_decoder_generator(self, x_input):
        """Fit the combined encoder -> decoder model to reproduce ``x_input``.

        Returns:
            Whatever ``decoder_gan.train_on_batch`` returns (the batch loss).
        """
        # 1) Set encoder generator to non-trainable
        self.encoder_generator.trainable = False

        # 2) Set decoder generator to trainable
        self.decoder_generator.trainable = True

        # 3) Train decoder generator via GAN model. The combined model is the one
        # that is compiled (with MAE) and that feeds the decoder the encoder's
        # output. GeneratorModelCreator returns an uncompiled model, so calling
        # train_on_batch on decoder_generator directly would fail, and it would see
        # raw input instead of encoded data.
        return self.decoder_gan.train_on_batch(x_input, x_input)