# VAE Training

## imports

In [1]:
import os

from models.VAE import VariationalAutoencoder
from utils.loaders import load_mnist

  from ._conv import register_converters as _register_converters
Using TensorFlow backend.


In [2]:
# run params
# Identify this run; together these give run_folder == 'run/vae/0002_digits'.
section = 'vae'
run_id = '0002'
data_name = 'digits'
run_folder = 'run/{}/'.format(section)
run_folder += '_'.join([run_id, data_name])

# os.makedirs also creates the missing parents ('run/' and 'run/<section>/'),
# and exist_ok=True makes the cell safe to re-run and repairs a run folder
# that exists but is missing one of its subfolders.
for subdir in ('viz', 'images', 'weights'):
    os.makedirs(os.path.join(run_folder, subdir), exist_ok=True)

# 'build' starts a new model; 'load' is the alternative value.
mode = 'build'  # 'load'

## data

In [3]:
# load_mnist() yields a (train, test) pair, each itself an (x, y) pair.
mnist_train, mnist_test = load_mnist()
x_train, y_train = mnist_train
x_test, y_test = mnist_test

## architecture

In [4]:
# Shape of a single input image, as fed to the encoder's input layer.
INPUT_DIM = (28, 28, 1)

# Encoder: one entry per Conv2D layer (filters / kernel size / stride).
CONV_FILTERS = [32, 64, 64, 64]
CONV_KERNEL_SIZES = [3, 3, 3, 3]
CONV_STRIDES = [1, 2, 2, 1]

# Decoder: one entry per Conv2DTranspose layer (filters / kernel size / stride).
CONV_T_FILTERS = [64, 64, 32, 1]
CONV_T_KERNEL_SIZES = [3, 3, 3, 3]
CONV_T_STRIDES = [1, 2, 2, 1]

# Width of the latent vector consumed by the decoder.
Z_DIM = 2

In [5]:
# Instantiate the model from the architecture constants defined in the
# architecture cell (arguments are positional, in the constructor's order).
VAE = VariationalAutoencoder(
    INPUT_DIM,
    CONV_FILTERS,
    CONV_KERNEL_SIZES,
    CONV_STRIDES,
    CONV_T_FILTERS,
    CONV_T_KERNEL_SIZES,
    CONV_T_STRIDES,
    Z_DIM,
)

# Honour the `mode` flag from the run-params cell, which was previously
# ignored: a fresh build saves the model into the run folder, while 'load'
# restores the weights checkpointed there by an earlier training run.
if mode == 'build':
    VAE.save(run_folder)
elif mode == 'load':
    VAE.load_weights(os.path.join(run_folder, 'weights/weights.h5'))
else:
    raise ValueError("mode must be 'build' or 'load', got {!r}".format(mode))

In [6]:
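# To resume from previously saved weights, uncomment the line below
# (note: the run directory variable is `run_folder`, not `RUN_FOLDER`).
#VAE.load_weights(os.path.join(run_folder, 'weights/weights.h5'))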

In [7]:
# Print the encoder's layer table (layer, output shape, parameter count, inputs).
VAE.encoder.summary()

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
encoder_input (InputLayer)      (None, 28, 28, 1)    0                                            
__________________________________________________________________________________________________
encoder_conv_0 (Conv2D)         (None, 28, 28, 32)   320         encoder_input[0][0]              
__________________________________________________________________________________________________
leaky_re_lu_1 (LeakyReLU)       (None, 28, 28, 32)   0           encoder_conv_0[0][0]             
__________________________________________________________________________________________________
encoder_conv_1 (Conv2D)         (None, 14, 14, 64)   18496       leaky_re_lu_1[0][0]              
__________________________________________________________________________________________________
leaky_re_l

In [8]:
# Print the decoder's layer table (layer, output shape, parameter count).
VAE.decoder.summary()

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
decoder_input (InputLayer)   (None, 2)                 0         
_________________________________________________________________
dense_1 (Dense)              (None, 3136)              9408      
_________________________________________________________________
reshape_1 (Reshape)          (None, 7, 7, 64)          0         
_________________________________________________________________
decoder_conv_t_0 (Conv2DTran (None, 7, 7, 64)          36928     
_________________________________________________________________
leaky_re_lu_5 (LeakyReLU)    (None, 7, 7, 64)          0         
_________________________________________________________________
decoder_conv_t_1 (Conv2DTran (None, 14, 14, 64)        36928     
_________________________________________________________________
leaky_re_lu_6 (LeakyReLU)    (None, 14, 14, 64)        0         
__________

## training

In [9]:
# Hyperparameters handed to VAE.compile().
LEARNING_RATE = 5e-4
# NOTE(review): presumably the weight on the reconstruction loss term
# (reported as vae_r_loss in the training log) — confirm in models/VAE.py.
R_LOSS_FACTOR = 1000

In [10]:
# Compile the model with the hyperparameters from the previous cell
# (passed positionally: learning rate first, then the loss factor).
VAE.compile(LEARNING_RATE, R_LOSS_FACTOR)

In [11]:
# Settings forwarded to VAE.train() as keyword arguments.
BATCH_SIZE = 32              # batch_size
EPOCHS = 200                 # epochs
PRINT_EVERY_N_BATCHES = 100  # print_every_n_batches
INITIAL_EPOCH = 0            # initial_epoch

In [12]:
# Train on the training images; the log below shows weight checkpoints
# being written under <run_folder>/weights after every epoch.
train_options = dict(
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    run_folder=run_folder,
    print_every_n_batches=PRINT_EVERY_N_BATCHES,
    initial_epoch=INITIAL_EPOCH,
)
VAE.train(x_train, **train_options)

Epoch 1/200

Epoch 00001: saving model to run/vae/0002_digits/weights/weights-001-58.40.h5

Epoch 00001: saving model to run/vae/0002_digits/weights/weights.h5
Epoch 2/200

Epoch 00002: saving model to run/vae/0002_digits/weights/weights-002-51.57.h5

Epoch 00002: saving model to run/vae/0002_digits/weights/weights.h5
Epoch 3/200
  416/60000 [..............................] - ETA: 2:12 - loss: 52.1034 - vae_r_loss: 48.0262 - vae_kl_loss: 4.0772

KeyboardInterrupt: 