In [None]:
import os

from utils.loaders import load_mnist

from models.VariationalAutoEncoder import VariationalAutoEncoder

## Set Run Parameters

In [None]:
# Run parameters: identify this experiment and where its artifacts are written.
SECTION = 'vae'
RUN_ID = '0001'
DATA_NAME = 'digits'

# Artifacts for this run live under run/<SECTION>/<RUN_ID>_<DATA_NAME>.
RUN_FOLDER = f'run/{SECTION}/{RUN_ID}_{DATA_NAME}'

# makedirs(exist_ok=True) creates any missing parents (run/, run/<SECTION>/,
# RUN_FOLDER) and is idempotent. The cell can therefore be re-run safely, and a
# run folder whose subdirectories were deleted is repaired rather than skipped.
for subdir in ('viz', 'images', 'weights'):
    os.makedirs(os.path.join(RUN_FOLDER, subdir), exist_ok=True)

# 'build': save the freshly constructed model into RUN_FOLDER.
# 'load':  restore weights from RUN_FOLDER/weights/weights.h5.
MODE = 'build'  # or 'load'

## Load the data

In [None]:
# load_mnist returns a (train, test) pair; each split is unpacked into x/y
# (presumably images and labels — confirm against utils.loaders).
train_split, test_split = load_mnist()
x_train, y_train = train_split
x_test, y_test = test_split

## Define the encoder and decoder architectures

In [None]:
# Each conv layer is described by a dict with 'filter', 'kernel' and 'stride';
# every layer here uses a 3x3 kernel, so only (filters, stride) vary.
encoder_architecture = [
    {'filter': n_filters, 'kernel': (3, 3), 'stride': stride}
    for n_filters, stride in [(32, 1), (64, 2), (64, 2), (64, 1)]
]
decoder_architecture = [
    {'filter': n_filters, 'kernel': (3, 3), 'stride': stride}
    for n_filters, stride in [(64, 1), (64, 2), (32, 2), (1, 1)]
]

# Input image shape (height, width, channels) and size of the latent code.
input_dim = (28, 28, 1)
latent_dim = 2

In [None]:
# Gather the constructor arguments in one place, then build the VAE from them.
vae_config = {
    'input_dim': input_dim,
    'latent_dim': latent_dim,
    'encoder_params': encoder_architecture,
    'decoder_params': decoder_architecture,
}
ae = VariationalAutoEncoder(**vae_config)

In [None]:
# Persist the freshly built model ('build') or restore saved weights ('load').
# Any other MODE is rejected explicitly: previously a typo silently fell
# through to the load branch.
if MODE == 'build':
    ae.save(RUN_FOLDER)
elif MODE == 'load':
    ae.load_weights(os.path.join(RUN_FOLDER, 'weights', 'weights.h5'))
else:
    raise ValueError(f"Unknown MODE {MODE!r}; expected 'build' or 'load'")

In [None]:
# Print the layer summary of the combined autoencoder model held by `ae`.
full_vae_model = ae.autoencoder_model
full_vae_model.summary()

In [None]:
# Print the layer summary of the encoder sub-model held by `ae`.
encoder_net = ae.encoder_model
encoder_net.summary()

In [None]:
# Print the layer summary of the decoder sub-model held by `ae`.
decoder_net = ae.decoder_model
decoder_net.summary()

## Train the autoencoder

In [None]:
# Training hyperparameters consumed by ae.compile() and ae.train() below.
BATCH_SIZE = 32
LEARNING_RATE = 5e-4
TOTAL_EPOCHS = 50
INITIAL_EPOCH = 0  # presumably >0 resumes a previous run — confirm with ae.train

In [None]:
# Compile the model, passing the configured LEARNING_RATE positionally.
lr_for_compile = LEARNING_RATE
ae.compile(lr_for_compile)

In [None]:
# Train on x_train only; checkpoints/artifacts are handed RUN_FOLDER.
train_options = dict(
    batch_size=BATCH_SIZE,
    epochs=TOTAL_EPOCHS,
    run_folder=RUN_FOLDER,
    initial_epoch=INITIAL_EPOCH,
)
ae.train(x_train, **train_options)