In [None]:
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
from tqdm.notebook import tqdm

from vae import VariationalAutoencoder
from conv_vae import ConvolutionalVAE
from data import toy_data
from data import data_loader

%matplotlib inline
%load_ext autoreload
%autoreload 2

In [None]:
# Publication-style plotting defaults: serif font, 8 pt tick/axis/legend text,
# 1 pt lines, 300 dpi figures with automatic layout.
plt.rcParams.update({
    'font.family': 'serif',
    'xtick.labelsize': 8,
    'ytick.labelsize': 8,
    'axes.labelsize': 8,
    'figure.autolayout': True,
    'figure.dpi': 300,
    'lines.linewidth': 1,
    'legend.fontsize': 8,
})

In [None]:
# Switch TensorFlow to graph mode BEFORE creating any session or graph: the docs state
# disable_eager_execution may only be called at program startup, ahead of graph creation.
tf.compat.v1.disable_eager_execution()
sess = tf.compat.v1.InteractiveSession()

In [None]:
# Model / training hyper-parameters.
input_size = 2**15  # previously 1569128 (presumably the full, unwindowed spectrum length)
latent_rep_size = 40
label_size = 2
n_data = 16000  # if using hdf5 file, should be set to match file
learning_rate = .001
batch_size = 50

# Hidden layers shared by the encoder and (in reverse order) the decoder.
# The encoder outputs 2 * latent_rep_size values — presumably mean and log-variance
# of the latent distribution; TODO confirm in vae.py.
hidden_layer_sizes = [100, 100, 90, 90, 80, 80]

network_architecture = dict(
    input_size=input_size,
    latent_representation_size=latent_rep_size,
    encoder_layer_sizes=[input_size, *hidden_layer_sizes, 2 * latent_rep_size],
    decoder_layer_sizes=[latent_rep_size, *reversed(hidden_layer_sizes), input_size],
    label_predictor_layer_sizes=[latent_rep_size, 40, 20, 20, 10, 10, 5, 5, label_size],
)

In [None]:
# Selects the input pipeline built in the next cell: 'toy' generates synthetic spectra with
# toy_data, 'real' streams spectra and labels from data/sample_short.h5.
data_source = 'real'

In [None]:
# Build the input pipeline. Every branch defines the same three names: the input tensor,
# the label tensor, and a callable that (re)initialises the stream before iterating.
if data_source == 'toy':
    spectra, labels = toy_data.generate_spectra(n_data, network_architecture['input_size'])
    labels = labels[:, 1:3]  # ignore temperature and sigma (leaving A, mu)
    input_stream, label_stream, initialize_stream = (
        data_loader.create_loader_from_array(sess, batch_size, spectra, labels)
    )
elif data_source == 'real':
    input_stream, label_stream, initialize_stream = (
        data_loader.create_loader_from_hdf5(sess, batch_size, 'data/sample_short.h5')
    )
else:
    # Fail here instead of with a confusing NameError on `input_stream` in a later cell.
    raise ValueError(f"unknown data_source {data_source!r}; expected 'toy' or 'real'")

In [None]:
# Build the VAE (VariationalAutoencoder from vae.py) on top of the input/label streams,
# passing the session, the architecture dict and the optimiser settings defined above.
vae = VariationalAutoencoder(sess, network_architecture, input_stream, label_stream, learning_rate, batch_size)

In [None]:
# Number of full batches per epoch; any leftover samples that do not fill a batch are not iterated.
total_batches = n_data // batch_size

# Per-step training history. The training cell appends to these, so re-running it extends them.
costs, l1_costs, l2_costs, l_costs, label_costs = [], [], [], [], []

In [None]:
training_epochs = 5

for epoch in tqdm(range(1, training_epochs + 1), desc='Epochs'):
    initialize_stream()  # restart the input pipeline at the start of every epoch
    # The description is constant within an epoch, so set it once rather than on every step.
    with tqdm(total=total_batches, desc='Epoch ' + str(epoch)) as pbar:
        for _batch in range(total_batches):
            # optimize() returns 7 values; the 1st and 6th are not recorded here.
            _, cost, l1_loss, r_cost, l_cost, _, m_cost = vae.optimize()

            costs.append(cost)
            l1_costs.append(l1_loss)
            l2_costs.append(r_cost)
            l_costs.append(l_cost)
            label_costs.append(m_cost)

            # The reconstruction cost is displayed normalised by input_size.
            pbar.set_postfix(loss=cost, l1=l1_loss, l2=r_cost / input_size, l=l_cost, m=m_cost)
            pbar.update()

In [None]:
# Training curves on a log scale. Step k (1-based) ends at epoch k / total_batches, so the
# axis stays correct even when training was interrupted mid-epoch or the training cell re-run.
epochs = np.arange(1, len(costs) + 1) / total_batches
n_epochs_run = int(np.ceil(len(costs) / total_batches))

plt.figure(figsize=(4, 2))
plt.plot(epochs, costs, label='Loss')
plt.plot(epochs, l2_costs, label='Reconstruction Loss')
plt.plot(epochs, l_costs, label='KL Divergence')

plt.xlabel('Epoch')
# Derive ticks from the epochs actually run (about 10 ticks). Explicit ticks beyond the data,
# e.g. a fixed 0..50, expand the x-axis and squash a short run into a corner of the plot.
plt.xticks(np.arange(0, n_epochs_run + 1, max(1, n_epochs_run // 10)))

plt.ylabel('Loss')
plt.yscale('log')

plt.legend()

In [None]:
# Re-initialise the training input stream and pull one batch of spectra for inspection.
initialize_stream()
spectra = sess.run(input_stream)

In [None]:
# Choose one example from the batch at random; its index is printed so the plots can be traced back.
i = np.random.randint(batch_size)
spectrum = spectra[i][:, 0]  # drop the trailing singleton axis -> 1-D spectrum
print(i)

In [None]:
# Reconstruct the chosen spectrum with the VAE (plotted below as d(e(x))).
reconstruction = vae.reconstruct(spectrum)

In [None]:
# Wavelength axis for the spectrum plots. The grid follows a piecewise recurrence from 500:
# a constant step of 0.1 up to 3000, then a step of lambda / 650000 up to 25000,
# lambda / 250000 up to 55000, and no further growth beyond that.
# The spectra used here are an input_size-long slice starting at grid index GRID_OFFSET.
window = (0, input_size)  # (start, stop) sample indices to plot; e.g. (12000, 14000) to zoom in

GRID_OFFSET = 700000  # index of the first grid point covered by the input spectra


def wavelength_grid(n_points):
    """Return the first `n_points` values of the piecewise wavelength recurrence as a list."""
    current = 500
    grid = [current]
    for _ in range(1, n_points):
        if 500 <= current <= 3000:
            current += .1
        elif 3000 < current <= 25000:
            current += current / 650000
        elif 25000 < current <= 55000:
            current += current / 250000
        grid.append(current)
    return grid


# Only generate the part of the grid that is used (the full grid has 1569128 points).
lambdas = wavelength_grid(GRID_OFFSET + input_size)[GRID_OFFSET:]
lambdas = lambdas[window[0]:window[1]]

In [None]:
# Top: the chosen spectrum x and its reconstruction d(e(x)). Bottom: the residual.
# Raw string: '\A' is not a valid Python escape sequence (SyntaxWarning on Python 3.12+).
wavelength_label = r'Wavelength $[\AA]$'
window_slice = slice(*window)

fig, (ax_top, ax_bottom) = plt.subplots(2, 1, figsize=(6, 4))
ax_top.plot(lambdas, spectrum[window_slice], label='x')
ax_top.plot(lambdas, reconstruction[window_slice], label='d(e(x))')
ax_top.set_xlabel(wavelength_label)
ax_top.set_ylabel('Normalized flux')
ax_top.legend()

ax_bottom.plot(lambdas, (reconstruction - spectrum)[window_slice])
ax_bottom.set_xlabel(wavelength_label)
ax_bottom.set_ylabel('d(e(x)) - x')
plt.show()

In [None]:
# Label-predictor cost (the m_cost recorded during training) per step.
# Step k (1-based) ends at epoch k / total_batches.
epochs = np.arange(1, len(label_costs) + 1) / total_batches
n_epochs_run = int(np.ceil(len(label_costs) / total_batches))

plt.figure(figsize=(4, 2))
plt.plot(epochs, label_costs)

plt.xlabel('Epoch')
# Ticks and limits follow the epochs actually trained; a fixed 0..50 range squashes shorter runs.
plt.xticks(np.arange(0, n_epochs_run + 1, max(1, n_epochs_run // 10)))
plt.xlim(0, max(n_epochs_run, 1))

plt.ylabel('Squared Error')

# latent space exploration

In [None]:
# (start, stop) sample-index window used by the latent-space plots below;
# (0, input_size) shows the whole spectrum.
window = (0, input_size)
#window = (12000, 14000)

In [None]:
# Encode the same spectrum five times and decode each code. If vae.encode samples from the
# latent distribution (presumably — TODO confirm in vae.py), the spread of the curves shows that
# variability. The input spectrum is overlaid, faded, on the decoded panel.
window_slice = slice(*window)
fig, (ax_latent, ax_decoded) = plt.subplots(2, 1, figsize=(6, 4))
for _ in range(5):
    z = vae.encode(spectrum)
    ax_latent.plot(z)
    ax_decoded.plot(vae.decode(z)[window_slice])

# Explicit axes: the overlay must land on the decoded panel, not on whichever axes pyplot last used.
ax_decoded.plot(spectrum[window_slice], alpha=.5)

ax_latent.set_xlabel('Latent dimension')
ax_latent.set_ylabel('z')
ax_decoded.set_xlabel('Sample index')
ax_decoded.set_ylabel('Normalized flux')
plt.show()

In [None]:
# Decode random latent vectors drawn from a standard normal (presumably the VAE's N(0, I)
# prior) to see what the decoder produces away from encoded data.
plt.figure(figsize=(6, 2))
for _ in range(5):
    # The vector length must match the latent dimensionality instead of a hard-coded 40.
    z = np.random.normal(size=latent_rep_size)
    plt.plot(vae.decode(z)[window[0]:window[1]])

In [None]:
# Overlay the first few spectra of the sampled batch, faded, to show typical variation.
n_overlay = 5
plt.figure(figsize=(6, 2))
for idx in range(n_overlay):
    plt.plot(spectra[idx], alpha=.25)

# evaluate on test set

In [None]:
# Input pipeline over the held-out test file, with the same loader and batch size as training.
test_stream, test_label_stream, initialize_test_stream = data_loader.create_loader_from_hdf5(
    sess, batch_size, 'data/test_short.h5'
)

In [None]:
# Per-spectrum reconstruction error on the test set: mean |r|, mean r**2 and mean r**4 of the
# residual r = d(e(x)) - x. The lists are aggregated in the next cell.
n_test = 1000  # number of test spectra to evaluate — presumably the size of data/test_short.h5; TODO confirm
initialize_test_stream()
test_l1s = []
test_l2s = []
test_l4s = []
# Derive the batch count from batch_size; a hard-coded `1000 // 50` would evaluate a
# different number of spectra whenever batch_size changes.
for _ in range(n_test // batch_size):
    test_batch = sess.run(test_stream)
    # Flatten each example to 1-D. Unlike .squeeze(), this keeps the batch axis when
    # batch_size == 1. Assumes one channel per example, e.g. (batch, input_size, 1) — TODO confirm.
    test_batch = test_batch.reshape(len(test_batch), -1)
    for test_spectrum in test_batch:
        residual = vae.reconstruct(test_spectrum) - test_spectrum
        test_l1s.append(np.mean(np.abs(residual)))
        test_l2s.append(np.mean(residual**2))
        test_l4s.append(np.mean(residual**4))

In [None]:
# Test-set reconstruction error, averaged over spectra: mean absolute error, root-mean-square
# error, and the 4th root of the mean 4th-power error (weights large residuals more heavily).
# These equal pixel-wise averages only if all test spectra have the same length (presumably input_size).
print('L1 (mean |error|):', np.mean(test_l1s))
print('L2 (RMS error):   ', np.mean(test_l2s)**.5)
print('L4 error:         ', np.mean(test_l4s)**.25)

In [None]:
from datetime import datetime

In [None]:
# Save the trained model under a timestamped prefix (yyMMdd-HHMMSS), so each run gets its own name.
timestamp = datetime.now().strftime('%y%m%d-%H%M%S')
outfile_name = 'output/output-' + timestamp
vae.save(outfile_name)

In [None]:
# Reload a previously saved model (prefix as produced by vae.save above). This presumably
# replaces the weights trained earlier in the notebook; skip this cell or set the prefix to None to keep them.
checkpoint_prefix = 'output/output-200505-215641'
if checkpoint_prefix is not None:
    vae.restore(checkpoint_prefix)

In [None]:
# Release the model's resources when finished (VariationalAutoencoder.close in vae.py);
# presumably this closes the TF session, so run it last.
vae.close()