In [None]:
%load_ext autoreload
%autoreload 2

from functions.chase import *

In [None]:
# Run the project's GPU setup helper before any model is built.
# NOTE(review): setup_gpus comes from the star import of functions.chase; what it
# actually configures (device visibility, memory growth, ...) is not visible here — confirm.
setup_gpus()

In [None]:
# Load the data through the project's load_data helper, which returns five values.
# Later cells slice ds_all / ds_all_centered by frame index and iterate ds_counts
# as numbers (one per dataset, presumably frame counts — TODO confirm).
# datasets / datasets_centered are not used again in the visible cells.
ds_all, ds_all_centered, datasets, datasets_centered, ds_counts = load_data()

In [None]:
# Preview a randomly chosen window of `seq_len` consecutive frames.
seq_len = 100

# NOTE(review): the upper bound is taken from ds_all_centered while the slice below
# reads ds_all — presumably both have the same length; confirm.
index_start = np.random.randint(low=0, high=len(ds_all_centered) - seq_len)
print("Seeding with frame {}".format(index_start))

xtest = ds_all[slice(index_start, index_start + seq_len)]
HTML(animate_stick(xtest))

In [None]:
# --- Hyperparameters --------------------------------------------------------
# "previously" notes record the alternative values left in the original cell.
seq_len      = 128
extrap_len   = seq_len // 2
latent_dim   = 256
n_units      = 384   # previously 256
n_layers     = 3     # previously 2
use_dense    = True
do_rotations = True
kl_weight    = 1     # previously 1e-2
resolution   = 3e-1  # previously 1e-2
lr           = 3e-4
# (do_shift / do_inplace options are currently disabled and not passed on.)

# --- Build the encoder / decoder / autoencoder ------------------------------
encoder, decoder, auto, mk_continuizer = mk_seq_ae(
    ds_all,
    seq_len=seq_len,
    latent_dim=latent_dim,
    n_units=n_units,
    n_layers=n_layers,
    use_dense=use_dense,
    kl_weight=kl_weight,
    resolution=resolution,
    do_rotations=do_rotations,
    extrap_len=extrap_len,
)
continuizer = mk_continuizer(1)

# Print the layer summaries in the order encoder, decoder, autoencoder.
for _net in (encoder, decoder, auto):
    _net.summary()

# Push the chosen learning rate into the autoencoder's optimizer.
K.set_value(auto.optimizer.lr, lr)

# Start a fresh loss record for the newly built model.
loss_history = []

In [None]:
# Write each network's architecture (as returned by to_json()) to its own file.
for _arch_path, _net in (
    ('vae_lstm_enc_model.json', encoder),
    ('vae_lstm_dec_model.json', decoder),
    ('vae_lstm_auto_model.json', auto),
):
    with open(_arch_path, 'w') as f:
        f.write(_net.to_json())

# Train:

In [None]:
# Training settings for this run. lr and kl_weight deliberately override the
# values assigned when the model was built.
batch_size = 128  # previously 32
epochs = 10
lr = 1e-5
kl_weight = 2e-4  # range from 1e-5 to 1e-2

# Steps per epoch: for every entry c of ds_counts, (c - seq_len) is summed,
# and the total is divided by the batch size.
nstep = sum([c-seq_len for c in ds_counts])//batch_size

# Apply the learning rate and KL weight to the already-compiled model.
K.set_value(auto.optimizer.lr, lr)
K.set_value(auto.hp_kl_weight, kl_weight)

try:
    auto.fit_generator(
        gen_batches_safe(ds_all_centered, ds_counts, batch_size, seq_len),
        steps_per_epoch=nstep,
        epochs=epochs,
        verbose=1
    )
except KeyboardInterrupt:
    # Training may be stopped by hand; epochs completed so far are still recorded below.
    print("Interrupted.")

print("Updating loss history")
# The History callback only creates the 'loss' entry once an epoch has finished,
# so an interrupt during the first epoch leaves the dict empty. Using .get()
# avoids a KeyError in that case and simply appends nothing.
loss_history.extend(auto.history.history.get('loss', []))

In [None]:
# Plot the recorded training loss against a 1-based epoch index.
nskip = 0  # number of leading epochs to leave out of the plot
xepochs = np.arange(1, len(loss_history) + 1)
plt.plot(xepochs[nskip:], loss_history[nskip:])

In [None]:
# Persist the trained networks, tagging each filename with the current learning rate.
# lr is a float, so it must be formatted into the name: the earlier
# 'learning_rate_' + lr + '...' concatenation raised TypeError (str + float).
# The filename layout is otherwise unchanged: the lr value is followed directly
# by 'vae_lstm_...' with no separator.

# Save weights:
encoder.save_weights('learning_rate_{}vae_lstm_enc_weights.h5'.format(lr))
decoder.save_weights('learning_rate_{}vae_lstm_dec_weights.h5'.format(lr))
auto.save_weights('learning_rate_{}vae_lstm_auto_weights.h5'.format(lr))

# Save model:
encoder.save('learning_rate_{}vae_lstm_enc_model.h5'.format(lr))
decoder.save('learning_rate_{}vae_lstm_dec_model.h5'.format(lr))
auto.save('learning_rate_{}vae_lstm_auto_model.h5'.format(lr))

## Check autoencoder reconstruction performance

In [None]:
# Restore previously saved checkpoint weights into each network, then print
# the autoencoder summary.
for _net, _weights_path in (
    (encoder, 'weights/checkpoint_weights_vae_lstm_continued2_lr_0.001_encoder.h5'),
    (decoder, 'weights/checkpoint_weights_vae_lstm_continued2_lr_0.001_decoder.h5'),
    (auto, 'weights/checkpoint_weights_vae_lstm_continued2_lr_0.001_autoencoder.h5'),
):
    _net.load_weights(_weights_path)

auto.summary()

The autoencoder (red) tries to imitate the real Mariel (pink):

In [None]:
# Use a fixed start frame so the comparison is repeatable; swap in the
# commented line to pick a random window instead.
index_start = 9259
# index_start = np.random.randint(0,len(ds_all_centered)-seq_len)
print("Seeding with frame {}".format(index_start))

# Take seq_len consecutive frames and run them through the autoencoder as a batch of one.
xtest = ds_all_centered[slice(index_start, index_start + seq_len)]
_xtest_batch = np.expand_dims(xtest, axis=0)
xpred = auto.predict(_xtest_batch)[0]

# Animate the input with the reconstruction overlaid as the "ghost" figure.
HTML(animate_stick(xtest, ghost=xpred, ghost_shift=0.2))

## Try some variations by adding noise to the latent space

In [None]:
# Encode the test clip; only the second of the encoder's three outputs is kept.
_, ztest, _ = encoder.predict(np.expand_dims(xtest, 0))

# Perturb the latent code with Gaussian noise (std 0.25) and decode it.
_latent_noise = np.random.normal(loc=0, scale=0.25, size=latent_dim)
xproj = decoder.predict(ztest + _latent_noise)[0]

HTML(animate_stick(xtest, ghost=xproj, ghost_shift=0.2))

In [None]:
# Encode the test clip; only the second of the encoder's three outputs is kept.
_, ztest, _ = encoder.predict(np.expand_dims(xtest, 0))

# Perturb the latent code with Gaussian noise (std 0.5) and decode it.
_latent_noise = np.random.normal(loc=0, scale=0.5, size=latent_dim)
xproj = decoder.predict(ztest + _latent_noise)[0]

HTML(animate_stick(xtest, ghost=xproj, ghost_shift=0.2))

In [None]:
# Encode the test clip; only the second of the encoder's three outputs is kept.
_, ztest, _ = encoder.predict(np.expand_dims(xtest, 0))

# Perturb the latent code with Gaussian noise (std 1) and decode it.
_latent_noise = np.random.normal(loc=0, scale=1, size=latent_dim)
xproj = decoder.predict(ztest + _latent_noise)[0]

HTML(animate_stick(xtest, ghost=xproj, ghost_shift=0.2))

## Try sampling randomly from the latent space

In [None]:
# Decode a single latent vector drawn from a zero-mean Gaussian with std `sigma`.
sigma = 1

_z_random = np.random.normal(loc=0, scale=sigma, size=(1, latent_dim))
xgen = decoder.predict(_z_random)[0]

HTML(animate_stick(xgen))