# Train a Variational Autoencoder for generating cat images

In [1]:
import numpy as np
from glob import glob
from keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau
from src.model import genModel
from src.vaeHelpers import *

%matplotlib inline

Using TensorFlow backend.


## Load the data

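Load the list of cropped cat images and split it into training and validation sets. The split is saved to text files (`trainFiles.txt`, `valFiles.txt`) so that a resumed training run reuses exactly the same held-out images.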

In [2]:
# Set trainFresh = True to draw a new train/validation split from the cropped
# cat images and save it to disk; False reuses the split saved by an earlier run
# so that resumed training validates on the same held-out images.
trainFresh = False

if trainFresh:
    # Sort the glob result: glob order depends on the filesystem, and a fixed
    # file order plus a fixed seed makes the split reproducible.
    catFiles = np.array( sorted( glob( './catCroped/*.jpg' ) ) )

    valFrac   = 0.1   # fraction of images held out for validation
    splitSeed = 0     # seed for the shuffle; change it to draw a different split
    n = len(catFiles)

    if n == 0:
        raise FileNotFoundError( "No .jpg images found in ./catCroped/" )

    # Compute the split sizes explicitly rather than slicing with -int(n*valFrac):
    # when that value is 0 (n < 10), inds[:-0] is empty and inds[-0:] is the whole
    # array, which would silently put every image into the validation set.
    nVal   = int( n * valFrac )
    nTrain = n - nVal

    inds = np.random.RandomState( splitSeed ).permutation( n )
    trainInds, valInds = inds[ :nTrain ], inds[ nTrain: ]

    train, val = catFiles[ trainInds ], catFiles[ valInds ]
    assert len(train) + len(val) == n

    # Persist the split so later runs (trainFresh = False) reuse it.
    writeFilesList( "trainFiles.txt", train )
    writeFilesList( "valFiles.txt", val )

else:
    val   = readSavedFiles( "valFiles.txt" )
    train = readSavedFiles( "trainFiles.txt" )

## Generate the model 

In [3]:
# Hyper-parameters shared by the batch generators in the training cell.
# imgSize should agree with the model's input resolution: the summary below
# reports an input of (None, 256, 256, 3).
imgSize   = 256
batchSize = 16

# genModel hands back three models. Per the printed summary, VAE feeds its
# input through `encoder` and then `decoder`.
encoder, decoder, VAE = genModel()
VAE.summary()

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         (None, 256, 256, 3)       0         
_________________________________________________________________
encoder (Model)              [(None, 2048), (None, 204 12621608  
_________________________________________________________________
decoder (Model)              (None, 256, 256, 3)       8431379   
Total params: 21,052,987
Trainable params: 21,040,235
Non-trainable params: 12,752
_________________________________________________________________


## Train the model

In [None]:
import os

# Resume from the last saved checkpoint unless a fresh run was requested in the
# data cell (a fresh run also draws a new train/validation split).
if not trainFresh:
    VAE.load_weights( "weights/catGen.hdf5" )

# ModelCheckpoint writes into weights/; create it so the first save does not
# fail when the directory is missing.
os.makedirs( "weights", exist_ok = True )

# Set to True to also halve the learning rate whenever val_loss plateaus.
useLrSchedule = False

# Run at least one batch per epoch / validation pass: floor division yields 0
# when a split holds fewer than batchSize images.
trainSteps = max( 1, len(train) // batchSize )
valSteps   = max( 1, len(val)   // batchSize )

earlyStopper = EarlyStopping( monitor = 'val_loss', patience = 50, verbose = 1 )
checkPointer = ModelCheckpoint( filepath = "weights/catGen.hdf5", monitor = 'val_loss',
                                save_best_only = True, verbose = 1 )
rateReduce   = ReduceLROnPlateau( monitor = 'val_loss', factor = 0.5, patience = 20, cooldown = 5 )

if not trainFresh:
    # When resuming, ModelCheckpoint would start from best = +inf and overwrite the
    # saved weights after the first epoch even if val_loss got worse. Seed it with
    # the validation loss of the weights just loaded.
    # evaluate_generator returns a scalar, or a list whose first entry is the total loss.
    # NOTE(review): relies on ModelCheckpoint keeping `best` across train start
    # (Keras 2.x behaviour) — confirm if the Keras version changes.
    initialValLoss = VAE.evaluate_generator( genBatch( val, batchSize, imgSize, False ),
                                             steps = valSteps )
    checkPointer.best = float( np.atleast_1d( initialValLoss )[0] )

trainCallbacks = [ earlyStopper, checkPointer ]
if useLrSchedule:
    trainCallbacks.append( rateReduce )

# fit_generator returns a Keras History object; its .history is plotted later.
losses = VAE.fit_generator( genBatch( train, batchSize, imgSize, True ),
                            validation_data  = genBatch( val, batchSize, imgSize, False ),
                            epochs           = 5000,
                            validation_steps = valSteps,
                            steps_per_epoch  = trainSteps,
                            callbacks        = trainCallbacks )

Epoch 1/5000
  1/575 [..............................] - ETA: 7:19:07 - loss: 1718.6731

## Plot a learning curve

In [None]:
# Plot the loss history that fit_generator recorded in the training cell
# (`losses` is the object returned by VAE.fit_generator; `.history` is its
# per-epoch record). Presumably plots training vs. validation loss — see
# plotLosses in src/vaeHelpers to confirm.
plotLosses( losses.history )