In [None]:
import os
from datetime import datetime

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import tftables
from tqdm.notebook import tqdm

from conv_vae import ConvolutionalVAE
from data import data_loader
from data import toy_data
from vae import VariationalAutoencoder

%matplotlib inline
%load_ext autoreload
%autoreload 2

In [None]:
# Switch TensorFlow to graph mode *before* anything graph-related is created:
# the TF docs state that disable_eager_execution() may only be called before
# any Graphs, Ops, or Tensors exist, so it has to precede the session.
tf.compat.v1.disable_eager_execution()

# Interactive session shared by the data loader and the model below.
sess = tf.compat.v1.InteractiveSession()

In [None]:
# --- Run configuration -------------------------------------------------------
# Number of toy spectra to generate, and the optimiser / batching settings that
# are handed to the data loader and the model further down.
n_data = 16000
batch_size = 50
learning_rate = 5e-4

# Layer layout passed to the VAE constructor. The encoder list starts at the
# input width and the decoder list ends at it; the decoder and label-predictor
# lists both start at the latent width.
network_architecture = dict(
    input_size=2 ** 15,
    latent_representation_size=40,
    encoder_layer_sizes=[2 ** 15, 100, 100, 90, 90, 80, 80, 80],
    decoder_layer_sizes=[40, 80, 80, 90, 90, 100, 100, 2 ** 15],
    label_predictor_layer_sizes=[40, 40, 20, 20, 10, 10, 5, 5, 3],
)

In [None]:
# Generate the synthetic data set: `n_data` spectra, each as wide as the
# network input.
spectra, labels = toy_data.generate_spectra(
    n_data,
    network_architecture['input_size'],
)

# Drop label column 0 (temperature); the remaining columns are A, mu, sigma.
labels = labels[:, 1:]

In [None]:
# Wrap the in-memory arrays in a batched loader bound to the session. The
# loader hands back the input stream, the label stream, and a callable that
# the training loop invokes at the start of every epoch.
input_stream, label_stream, initialize_stream = data_loader.create_loader_from_array(
    sess, batch_size, spectra, labels,
)

In [None]:
# Build the model on the shared session, handing it the architecture dict, the
# loader's input/label streams, and the optimiser settings defined above.
vae = VariationalAutoencoder(sess, network_architecture, input_stream, label_stream, learning_rate, batch_size)

In [None]:
# Number of optimiser steps per epoch (any partial final batch is dropped).
total_batches = int(n_data / batch_size)

# Per-batch history filled in by the training loop. Kept in its own cell so the
# training cell can be re-run to continue training without losing history.
#   costs       <- cost
#   l1_costs    <- l1_loss
#   l2_costs    <- r_cost
#   l_costs     <- l_cost
#   label_costs <- m_cost
costs, l1_costs, l2_costs, l_costs, label_costs = [], [], [], [], []

In [None]:
# Number of passes over the data set performed by this cell.
training_epochs = 25

for epoch in tqdm(range(1, training_epochs + 1), desc='Epochs'):
    # Called once per epoch, before any batch is consumed.
    # NOTE(review): presumably rewinds/reshuffles the loader — confirm in data_loader.
    initialize_stream()

    with tqdm(total=total_batches) as pbar:
        # The label depends only on the epoch, so set it once per bar instead
        # of once per batch.
        pbar.set_description('Epoch ' + str(epoch))

        for batch in range(total_batches):
            # One optimiser step; the first and sixth return values are unused here.
            _, cost, l1_loss, r_cost, l_cost, _, m_cost = vae.optimize()

            # Record the per-batch metrics for the plots further down.
            costs.append(cost)
            l1_costs.append(l1_loss)
            l2_costs.append(r_cost)
            l_costs.append(l_cost)
            label_costs.append(m_cost)

            pbar.set_postfix(loss=cost, l1=l1_loss, r=r_cost, l=l_cost, m=m_cost)
            pbar.update()

In [None]:
# compare to Epoch 20 Batch 160 Iter 003200 | r_cost=17809.174 l_cost=1558.191 l1_loss=0.032 time=8.9

In [None]:
# Training curves, one point per optimiser step, on log-scaled y axes.
# Top panel: cost, r_cost and l_cost; bottom panel: m_cost (the label cost).
fig, (ax_losses, ax_label) = plt.subplots(2, 1, figsize=(16, 6))

ax_losses.plot(costs, label='cost')
ax_losses.plot(l2_costs, label='r_cost')
ax_losses.plot(l_costs, label='l_cost')
ax_losses.set_yscale('log')
ax_losses.set_ylabel('loss')
ax_losses.legend()

ax_label.plot(label_costs, label='m_cost')
ax_label.set_yscale('log')
ax_label.set_xlabel('training batch')
ax_label.set_ylabel('loss')
ax_label.legend()

In [None]:
# Pick one training spectrum uniformly at random; the index is printed so the
# same example can be looked up again later.
i = np.random.randint(n_data)
spectrum = spectra[i]
print(i)

In [None]:
# Pass the sampled spectrum through the model; the result is compared with the
# original in the plot below.
# NOTE(review): assumes the returned array has the same shape as `spectrum` — confirm.
reconstruction = vae.reconstruct(spectrum)

In [None]:
# Overlay the original spectrum and its reconstruction (top) and show their
# difference (bottom).
fig, (ax_overlay, ax_residual) = plt.subplots(2, 1, figsize=(16, 12))

# Hide exact zeros in the plots by masking them with NaN. This is done on a
# copy so that `reconstruction` itself is left untouched and the cell can be
# re-run (or the array reused) without display-only edits leaking into it.
reconstruction_masked = np.where(reconstruction == 0, np.nan, reconstruction)

ax_overlay.plot(spectrum, label='spectrum')
ax_overlay.plot(reconstruction_masked, label='reconstruction')
ax_overlay.legend()

# Residual; points where the reconstruction was exactly zero appear as gaps.
ax_residual.plot(spectrum - reconstruction_masked)
ax_residual.set_title('spectrum - reconstruction')

In [None]:
# Save the model under a timestamped name (yymmdd-HHMMSS) inside ./output.
# Whether vae.save creates missing directories is not visible from here, so
# make sure the target directory exists first.
os.makedirs('output', exist_ok=True)
outfile_name = 'output/output-{}'.format(datetime.now().strftime('%y%m%d-%H%M%S'))
vae.save(outfile_name)

In [None]:
# Load model state from a previously saved run.
# NOTE(review): the path is hard-coded to one specific earlier output
# (191002-000332), not the file written by the save cell above — confirm it
# exists, or substitute `outfile_name`, before running.
vae.restore('output/output-191002-000332')

In [None]:
# Shut the model down at the end of the notebook.
# NOTE(review): presumably this also closes the underlying TF session — confirm in vae.py.
vae.close()