In [1]:
import os
import shutil
import zipfile
import numpy as np

from google.colab import drive
from glob import glob
from tensorflow.keras.preprocessing.image import ImageDataGenerator

from tensorflow.keras.layers import Input, Conv2D, Flatten, Dense, Conv2DTranspose, Reshape, Lambda, Activation, BatchNormalization, LeakyReLU, Dropout
from tensorflow.keras.models import Model
from tensorflow.keras import backend as K
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ModelCheckpoint 
from tensorflow.keras.utils import plot_model

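# Needed because the KL loss in VariationalAutoencoder.compile reads the encoder's symbolic
# mu/log_var tensors; unless disable_eager_execution() is called before any model is built,
# Keras raises an error about symbolic tensors.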
from tensorflow.python.framework.ops import disable_eager_execution

## VARIATIONAL AUTOENCODER TEMPLATE CLASS
class VariationalAutoencoder():
    """Convolutional variational autoencoder (VAE) built with the Keras functional API.

    The encoder maps an image to the mean ``mu`` and log-variance ``log_var`` of a
    diagonal Gaussian over a ``z_dim``-dimensional latent space and samples ``z`` from
    it with the reparameterisation trick. The decoder mirrors the encoder with
    ``Conv2DTranspose`` layers and ends in a sigmoid.

    Models created by ``_build``:
        encoder_mu_log_var: image -> (mu, log_var)
        encoder:            image -> sampled latent vector z
        decoder:            latent vector -> reconstructed image
        model:              full VAE (encoder followed by decoder)

    Note:
        ``compile`` computes the KL loss from the symbolic tensors ``self.mu`` and
        ``self.log_var`` instead of from ``y_true``/``y_pred``. In TF2 this requires
        eager execution to be disabled (``disable_eager_execution()``) before the
        model is constructed.
    """

    def __init__(
        self,
        input_dim,
        encoder_conv_filters,
        encoder_conv_kernel_size,
        encoder_conv_strides,
        decoder_conv_t_filters,
        decoder_conv_t_kernel_size,
        decoder_conv_t_strides,
        z_dim,
        use_batch_norm=False,
        use_dropout=False,
        version=1
    ):
        """Store the architecture hyperparameters and build the Keras models.

        Args:
            input_dim: Shape of one input image, ``(height, width, channels)``.
            encoder_conv_filters: Number of filters of each encoder ``Conv2D`` layer.
            encoder_conv_kernel_size: Kernel size of each encoder layer.
            encoder_conv_strides: Stride of each encoder layer.
            decoder_conv_t_filters: Number of filters of each decoder
                ``Conv2DTranspose`` layer. The last entry is the number of output
                channels and should equal ``input_dim[-1]``.
            decoder_conv_t_kernel_size: Kernel size of each decoder layer.
            decoder_conv_t_strides: Stride of each decoder layer.
            z_dim: Dimensionality of the latent space.
            use_batch_norm: Add ``BatchNormalization`` after each hidden conv layer.
            use_dropout: Add ``Dropout(rate=0.25)`` after each hidden activation.
            version: Suffix used to build ``self.name``.
        """
        self.name = 'variational_autoencoder_{}'.format(version)
        self.version = version
        self.input_dim = input_dim

        self.encoder_conv_filters = encoder_conv_filters
        self.encoder_conv_kernel_size = encoder_conv_kernel_size
        self.encoder_conv_strides = encoder_conv_strides

        self.decoder_conv_t_filters = decoder_conv_t_filters
        self.decoder_conv_t_kernel_size = decoder_conv_t_kernel_size
        self.decoder_conv_t_strides = decoder_conv_t_strides

        self.z_dim = z_dim
        self.use_batch_norm = use_batch_norm
        self.use_dropout = use_dropout

        self.n_layers_encoder = len(encoder_conv_filters)
        self.n_layers_decoder = len(decoder_conv_t_filters)

        self._build()

    def train_with_generator(self, data_flow, epochs, lr_decay=1):
        """Train ``self.model`` on batches drawn from ``data_flow``.

        Args:
            data_flow: Batch source accepted by ``Model.fit``, e.g. an iterator from
                ``ImageDataGenerator``. For an autoencoder it presumably yields
                ``(x, x)`` pairs (``class_mode='input'``) -- confirm at the call site.
            epochs: Number of epochs to train.
            lr_decay: Per-epoch multiplicative learning-rate decay. Epoch ``e`` uses
                ``lr_0 * lr_decay ** e``. The default ``1`` keeps the rate constant.

        Returns:
            The Keras ``History`` object returned by ``fit``.
        """
        callbacks = []
        if lr_decay != 1:
            from tensorflow.keras.callbacks import LearningRateScheduler

            # The scheduler passes the current rate; multiply once per epoch after the first.
            callbacks.append(
                LearningRateScheduler(lambda epoch, lr: lr if epoch == 0 else lr * lr_decay)
            )

        # `fit_generator` is deprecated in TF2 and removed in later releases;
        # `fit` accepts generators / `keras.utils.Sequence` objects directly.
        return self.model.fit(
            data_flow,
            shuffle=True,
            epochs=epochs,
            callbacks=callbacks,
        )

    def _build(self):
        """Create ``encoder_mu_log_var``, ``encoder``, ``decoder`` and ``model``."""

        # ---- Encoder --------------------------------------------------------
        encoder_input = Input(shape=self.input_dim, name='encoder_input')
        x = encoder_input

        for i in range(self.n_layers_encoder):
            x = Conv2D(
                filters=self.encoder_conv_filters[i],
                kernel_size=self.encoder_conv_kernel_size[i],
                strides=self.encoder_conv_strides[i],
                padding='same',
                name='encoder_conv_' + str(i),
            )(x)
            if self.use_batch_norm:
                x = BatchNormalization()(x)
            x = LeakyReLU()(x)
            if self.use_dropout:
                x = Dropout(rate=0.25)(x)

        # Remembered so the decoder can reshape its dense output back into a feature map.
        shape_before_flattening = K.int_shape(x)[1:]

        x = Flatten()(x)
        self.mu = Dense(self.z_dim, name='mu')(x)
        self.log_var = Dense(self.z_dim, name='log_var')(x)

        self.encoder_mu_log_var = Model(encoder_input, (self.mu, self.log_var))

        def sampling(args):
            """Reparameterisation trick: z = mu + exp(log_var / 2) * eps, with eps ~ N(0, I)."""
            mu, log_var = args
            epsilon = K.random_normal(shape=K.shape(mu), mean=0., stddev=1.)
            return mu + K.exp(log_var / 2) * epsilon

        encoder_output = Lambda(sampling, name='encoder_output')([self.mu, self.log_var])
        self.encoder = Model(encoder_input, encoder_output)

        # ---- Decoder --------------------------------------------------------
        decoder_input = Input(shape=(self.z_dim,), name='decoder_input')
        x = Dense(int(np.prod(shape_before_flattening)))(decoder_input)
        x = Reshape(shape_before_flattening)(x)

        for i in range(self.n_layers_decoder):
            x = Conv2DTranspose(
                filters=self.decoder_conv_t_filters[i],
                kernel_size=self.decoder_conv_t_kernel_size[i],
                strides=self.decoder_conv_t_strides[i],
                padding='same',
                name='decoder_conv_t_' + str(i),
            )(x)

            if i == self.n_layers_decoder - 1:
                # Output layer: reconstructed values squashed into [0, 1].
                x = Activation('sigmoid')(x)
            else:
                if self.use_batch_norm:
                    x = BatchNormalization()(x)
                x = LeakyReLU()(x)
                if self.use_dropout:
                    x = Dropout(rate=0.25)(x)

        decoder_output = self._crop_to_input_size(x)
        self.decoder = Model(decoder_input, decoder_output)

        # ---- Full VAE -------------------------------------------------------
        model_input = encoder_input
        model_output = self.decoder(encoder_output)
        self.model = Model(model_input, model_output)

    def _crop_to_input_size(self, x):
        """Center-crop the decoder output ``x`` to the spatial size of ``input_dim``.

        Stride-2 'same' convolutions round odd sizes up (e.g. 316 -> 158 -> 79 -> 40
        -> 20), so the mirrored transposed convolutions can overshoot the input
        (20 -> 40 -> 80 -> 160 -> 320). The reconstruction loss would then compare
        tensors of different shapes. ``x`` is returned unchanged when it already
        matches, is smaller, or has unknown (None) spatial dimensions.
        """
        out_h, out_w = K.int_shape(x)[1:3]
        in_h, in_w = self.input_dim[0], self.input_dim[1]

        if None in (out_h, out_w, in_h, in_w):
            return x
        if (out_h, out_w) == (in_h, in_w) or out_h < in_h or out_w < in_w:
            return x

        from tensorflow.keras.layers import Cropping2D

        dh, dw = out_h - in_h, out_w - in_w
        return Cropping2D(
            cropping=((dh // 2, dh - dh // 2), (dw // 2, dw - dw // 2)),
            name='decoder_crop',
        )(x)

    def compile(self, learning_rate, r_loss_factor):
        """Compile ``self.model`` with the VAE loss (weighted reconstruction + KL).

        Args:
            learning_rate: Learning rate of the Adam optimizer.
            r_loss_factor: Weight applied to the per-sample mean squared
                reconstruction error, balancing it against the KL term.
        """

        def vae_r_loss(y_true, y_pred):
            """Per-sample mean squared error over height, width and channels, scaled by ``r_loss_factor``."""
            r_loss = K.mean(K.square(y_true - y_pred), axis=[1, 2, 3])
            return r_loss_factor * r_loss

        def vae_kl_loss(y_true, y_pred):
            """KL divergence between q(z|x) = N(mu, sigma^2) and the prior N(0, I).

            For diagonal Gaussians it has the closed form

                D_KL = -1/2 * sum_j (1 + log(sigma_j^2) - mu_j^2 - sigma_j^2)

            summed over the latent dimensions. ``y_true``/``y_pred`` are ignored; the
            value comes from the encoder tensors ``self.mu`` and ``self.log_var``.
            """
            kl_loss = -.5 * K.sum(1 + self.log_var - K.square(self.mu) - K.exp(self.log_var), axis=1)
            return kl_loss

        def vae_loss(y_true, y_pred):
            """Total VAE loss: weighted reconstruction error plus KL divergence."""
            r_loss = vae_r_loss(y_true, y_pred)
            kl_loss = vae_kl_loss(y_true, y_pred)
            return r_loss + kl_loss

        optimizer = Adam(learning_rate=learning_rate)
        self.model.compile(optimizer=optimizer, loss=vae_loss, metrics=[vae_r_loss, vae_kl_loss])



In [8]:
# Input images: 316 x 316 pixels, 1 channel.
INPUT_DIMENSION = (316, 316, 1)

# Hyperparameters
LEARNING_RATE = 0.0001
LOSS_FACTOR = 10000  # weight of the reconstruction loss relative to the KL loss

BATCH_SIZE = 32

# The KL loss in VariationalAutoencoder.compile uses the encoder's symbolic tensors,
# which needs graph mode; this must run before any Keras model is created.
disable_eager_execution()

vae = VariationalAutoencoder(
    input_dim=INPUT_DIMENSION,
    encoder_conv_filters=[32, 64, 64, 64],
    encoder_conv_kernel_size=[3, 3, 3, 3],
    encoder_conv_strides=[2, 2, 2, 2],
    # The last decoder layer emits the reconstructed image, so its filter count
    # must equal the number of input channels.
    decoder_conv_t_filters=[64, 64, 32, INPUT_DIMENSION[-1]],
    decoder_conv_t_kernel_size=[3, 3, 3, 3],
    decoder_conv_t_strides=[2, 2, 2, 2],
    z_dim=192,
    use_batch_norm=True,
    use_dropout=True,
)

vae.compile(LEARNING_RATE, LOSS_FACTOR)
vae.encoder.summary()

Model: "model_17"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
encoder_input (InputLayer)      [(None, 316, 316, 1) 0                                            
__________________________________________________________________________________________________
encoder_conv_0 (Conv2D)         (None, 158, 158, 32) 320         encoder_input[0][0]              
__________________________________________________________________________________________________
batch_normalization_32 (BatchNo (None, 158, 158, 32) 128         encoder_conv_0[0][0]             
__________________________________________________________________________________________________
leaky_re_lu_32 (LeakyReLU)      (None, 158, 158, 32) 0           batch_normalization_32[0][0]     
___________________________________________________________________________________________