In [1]:
import os

from models import VariationalAutoEncoder
from utils import load_mnist

In [2]:
# run params
# Artifacts for this run are written under run/<SECTION>/<RUN_ID>_<DATA_NAME>.
SECTION = "vae"
RUN_ID = "0002"
DATA_NAME = "digits"
RUN_FOLDER = "run/{}/".format(SECTION)
RUN_FOLDER += "_".join([RUN_ID, DATA_NAME])

# Sub-folders created inside RUN_FOLDER for this run's outputs.
RUN_SUBFOLDERS = ("viz", "images", "weights")


def ensure_run_folder(run_folder, subfolders=RUN_SUBFOLDERS):
    """Create ``run_folder`` and every name in ``subfolders`` inside it.

    Each directory is created with ``exist_ok=True``, so the call is
    idempotent (safe on Restart & Run All) and also repairs a run folder
    that already exists but is missing some of its sub-folders. The previous
    version only created sub-folders when the root was absent, so it skipped
    them in that case.

    Args:
        run_folder: Path of the run directory.
        subfolders: Names of the sub-directories to create inside ``run_folder``.

    Returns:
        ``run_folder`` unchanged, for convenient chaining.
    """
    os.makedirs(run_folder, exist_ok=True)
    for name in subfolders:
        os.makedirs(os.path.join(run_folder, name), exist_ok=True)
    return run_folder


ensure_run_folder(RUN_FOLDER)

# "build": the model cell calls vae.save(RUN_FOLDER).
# "load":  the model cell restores weights from RUN_FOLDER/weights/weights.h5.
mode = "build"

In [3]:
# load_mnist yields a training pair and a test pair, each (inputs, labels).
[(x_train, y_train),
 (x_test, y_test)] = load_mnist()

In [4]:
# Convolutional VAE for 28x28 single-channel digits with a 2-D latent space.
# The two stride-2 layers take 28x28 down to 7x7; the decoder summary's
# Reshape to (7, 7, 64) mirrors this.
vae = VariationalAutoEncoder(
    input_dim=(28, 28, 1),
    encoder_conv_filters=[32, 64, 64, 64],
    encoder_conv_kernel_size=[3, 3, 3, 3],
    encoder_conv_strides=[1, 2, 2, 1],
    decoder_conv_t_filters=[64, 64, 32, 1],
    decoder_conv_t_kernel_size=[3, 3, 3, 3],
    decoder_conv_t_strides=[1, 2, 2, 1],
    z_dim=2
)

# "build": persist the new model to RUN_FOLDER via vae.save.
# "load":  restore previously saved weights from RUN_FOLDER/weights/weights.h5.
# Any other value is rejected rather than silently treated as "load".
if mode == "build":
    vae.save(RUN_FOLDER)
elif mode == "load":
    vae.load_weights(os.path.join(RUN_FOLDER, "weights", "weights.h5"))
else:
    raise ValueError("mode must be 'build' or 'load', got {!r}".format(mode))

In [5]:
# Print the encoder's layer-by-layer summary (layer types, output shapes, parameter counts).
vae.encoder.summary()

Model: "model_1"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
encoder_input (InputLayer)      [(None, 28, 28, 1)]  0                                            
__________________________________________________________________________________________________
encoder_conv_0 (Conv2D)         (None, 28, 28, 32)   320         encoder_input[0][0]              
__________________________________________________________________________________________________
leaky_re_lu (LeakyReLU)         (None, 28, 28, 32)   0           encoder_conv_0[0][0]             
__________________________________________________________________________________________________
encoder_conv_1 (Conv2D)         (None, 14, 14, 64)   18496       leaky_re_lu[0][0]                
____________________________________________________________________________________________

In [6]:
# Print the decoder's layer-by-layer summary (layer types, output shapes, parameter counts).
vae.decoder.summary()

Model: "model_2"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
decoder_input (InputLayer)   [(None, 2)]               0         
_________________________________________________________________
dense (Dense)                (None, 3136)              9408      
_________________________________________________________________
reshape (Reshape)            (None, 7, 7, 64)          0         
_________________________________________________________________
decoder_conv_t_0 (Conv2DTran (None, 7, 7, 64)          36928     
_________________________________________________________________
leaky_re_lu_4 (LeakyReLU)    (None, 7, 7, 64)          0         
_________________________________________________________________
decoder_conv_t_1 (Conv2DTran (None, 14, 14, 64)        36928     
_________________________________________________________________
leaky_re_lu_5 (LeakyReLU)    (None, 14, 14, 64)        0   

In [7]:
# Settings passed positionally to vae.compile.
# NOTE(review): R_LOSS_FACTOR presumably scales the reconstruction term. At the
# first logged batch vae_r_loss (231.16) dwarfs vae_kl_loss (0.0013); confirm
# the weighting in VariationalAutoEncoder.compile.
LEARNING_RATE = 0.0005
R_LOSS_FACTOR = 1000

# Settings passed to vae.train. Set INITIAL_EPOCH > 0 to continue numbering
# from a previous run.
INITIAL_EPOCH = 0
EPOCHS = 200
BATCH_SIZE = 32
PRINT_EVERY_N_BATCHES = 100

vae.compile(LEARNING_RATE, R_LOSS_FACTOR)

vae.train(
    x_train,
    initial_epoch=INITIAL_EPOCH,
    epochs=EPOCHS,
    batch_size=BATCH_SIZE,
    print_every_n_batches=PRINT_EVERY_N_BATCHES,
    run_folder=RUN_FOLDER,
)

Train on 60000 samples
Epoch 1/200
   32/60000 [..............................] - ETA: 5:40:13 - loss: 231.1615 - vae_r_loss: 231.1602 - vae_kl_loss: 0.0013



Epoch 00001: saving model to run/vae/0002_digits\weights\weights-001-58.70.h5

Epoch 00001: saving model to run/vae/0002_digits\weights\weights.h5
Epoch 2/200
Epoch 00002: saving model to run/vae/0002_digits\weights\weights-002-51.96.h5

Epoch 00002: saving model to run/vae/0002_digits\weights\weights.h5
Epoch 3/200
Epoch 00003: saving model to run/vae/0002_digits\weights\weights-003-50.43.h5

Epoch 00003: saving model to run/vae/0002_digits\weights\weights.h5
Epoch 4/200
Epoch 00004: saving model to run/vae/0002_digits\weights\weights-004-49.46.h5

Epoch 00004: saving model to run/vae/0002_digits\weights\weights.h5
Epoch 5/200
Epoch 00005: saving model to run/vae/0002_digits\weights\weights-005-48.78.h5

Epoch 00005: saving model to run/vae/0002_digits\weights\weights.h5
Epoch 6/200
Epoch 00006: saving model to run/vae/0002_digits\weights\weights-006-48.22.h5

Epoch 00006: saving model to run/vae/0002_digits\weights\weights.h5
Epoch 7/200
Epoch 00007: saving model to run/vae/0002_digi

KeyboardInterrupt: 