In [None]:
import corner
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm

from data import data_loader, toy_data
import glow as model

%matplotlib inline
%load_ext autoreload
%autoreload 2

In [None]:
# Notebook-wide matplotlib defaults, grouped by rc section: serif font,
# size-8 tick/axis/legend text, width-1 lines, 300 dpi auto-layout figures.
rc_settings = {
    'font': dict(family='serif'),
    'xtick': dict(labelsize=8),
    'ytick': dict(labelsize=8),
    'axes': dict(labelsize=8),
    'figure': dict(autolayout=True, dpi=300),
    'lines': dict(linewidth=1),
    'legend': dict(fontsize=8),
}
for rc_group, rc_values in rc_settings.items():
    plt.rc(rc_group, **rc_values)

In [None]:
class hps:
    """Hyperparameter namespace handed to the model constructor and read by the training loop."""
    n_levels = 10  # number of splits
    depth = 3  # number of layers in revnet
    width = 16  # channels in revnet layers
    polyak_epochs = 1  # NOTE(review): presumably the Polyak-averaging horizon in epochs; confirm in glow
    beta1 = .9  # NOTE(review): presumably the optimizer's first-moment decay, not an LR annealing factor; confirm in glow
    weight_decay = 1  # NOTE(review): presumably a weight-decay factor, not an LR annealing factor; confirm in glow
    lr = .001  # base learning rate
    n_data = 4000  # number of input spectra
    batch_size = 50  # number of spectra in a batch
    n_batches = n_data // batch_size  # whole batches per epoch (integer division, no float round-trip)
    n_bins = 2**12  # number of bins per spectrum

In [None]:
# Graph mode must be selected before any TF1-style object exists, so eager
# execution is switched off first and the interactive session is opened after.
tf.compat.v1.disable_eager_execution()
sess = tf.compat.v1.InteractiveSession()

In [None]:
# Selects which loader the following cell builds: 'toy' (generated spectra)
# or 'real' (HDF5 file on disk).
data_source = 'toy'

In [None]:
# Build the batched input pipeline for the chosen data source. Both branches
# yield (input_stream, label_stream, initialize_stream).
if data_source == 'toy':
    spectra, labels = toy_data.generate_spectra(hps.n_data, hps.n_bins)
    labels = labels[:, 1:3] # ignore temperature and sigma (leaving A, mu)
    input_stream, label_stream, initialize_stream = (
        data_loader.create_loader_from_array(sess, hps.batch_size, spectra, labels)
    )
elif data_source == 'real':
    input_stream, label_stream, initialize_stream = (
        data_loader.create_loader_from_hdf5(sess, hps.batch_size, 'data/sample_short.h5')
    )
else:
    # Fail here rather than with a NameError on `input_stream` in a later cell.
    raise ValueError("data_source must be 'toy' or 'real', got {!r}".format(data_source))

In [None]:
# Fetch one batch from the input stream for inspection.
# NOTE(review): initialize_stream() presumably rewinds the loader — confirm in data_loader.
# This rebinds `spectra` (in the 'toy' branch it previously held the full generated array).
initialize_stream()
spectra = sess.run(input_stream)

In [None]:
# Overlay the first five spectra of the fetched batch in a single figure.
first_five = spectra[:5]
plt.figure(figsize=(6, 4))
for spectrum in first_five:
    plt.plot(spectrum)

In [None]:
# Construct the model from the session, the hyperparameters and the input stream.
# NOTE(review): the commented-out line looks like an optional GPU device scope;
# re-enabling it would require indenting the constructor call beneath it.
#with tf.device("/device:GPU:0"):
m = model.model(sess, hps, input_stream)

In [None]:
# Running training state, kept in its own cell so that re-running the training
# cell continues the counters instead of resetting them:
#   n_processed      - spectra consumed so far (drives learning-rate warm-up)
#   training_results - one entry per training step
#   lrs              - learning rate used at each step
n_processed, training_results, lrs = 0, [], []

In [None]:
hps.epochs = 50
hps.epochs_warmup = .01
hps.lr = .001

# Number of processed spectra over which the learning rate ramps linearly
# from 0 up to hps.lr (loop-invariant, so computed once).
warmup_samples = hps.batch_size * hps.n_batches * hps.epochs_warmup

for epoch in tqdm(range(1, hps.epochs + 1), desc='Epochs'):
    epoch_results = []
    initialize_stream()
    with tqdm(total=hps.n_batches, desc='Epoch ' + str(epoch)) as pbar:
        for iteration in range(hps.n_batches):
            # Linear warm-up capped at the base rate; a non-positive warm-up
            # length disables the ramp instead of dividing by zero.
            if warmup_samples > 0:
                lr = hps.lr * min(1., n_processed / warmup_samples)
            else:
                lr = hps.lr
            step_result = m.train(lr)
            epoch_results.append(step_result)
            training_results.append(step_result)
            lrs.append(lr)
            n_processed += hps.batch_size
            # Show the current rate and the running mean of this epoch's results.
            pbar.set_postfix(lr=lr, loss=np.mean(epoch_results))
            pbar.update()

In [None]:
# Training history against (fractional) epoch: per-step training result in the
# top panel, learning-rate schedule in the bottom panel.
epoch_axis = np.linspace(0, len(training_results) / hps.n_batches, len(training_results))

plt.figure(figsize=(12, 10))
for panel, series in enumerate((training_results, lrs), start=1):
    plt.subplot(2, 1, panel)
    plt.plot(epoch_axis, series)
    plt.xlabel('epochs')

In [None]:
# Compact version of the training curve, with axis labels and a legend.
n_steps = len(training_results)
epoch_axis = np.linspace(0, n_steps / hps.n_batches, n_steps)

plt.figure(figsize=(4, 2))
plt.plot(epoch_axis, training_results, label='Negative Log Likelihood')
plt.xlabel('Epochs')
plt.ylabel('Bits per component')
plt.legend()

In [None]:
# Draw a random index below the batch size and keep that spectrum as a
# batch of one (the i:i+1 slice preserves the leading axis).
i = np.random.randint(hps.batch_size)
spectrum = spectra[i:i + 1, :, :]
print(i)

In [None]:
# Index 37 is the spectrum used for the figures.

In [None]:
# Encode the selected spectrum, then decode it twice: once from the latent
# representation alone and once with the intermediate zs returned by the encoder.
latent_rep, intermediate_zs = m.encode(spectrum)
reconstruction = m.decode(latent_rep)
perfect_reconstruction = m.decode(latent_rep, intermediate_zs)
print(latent_rep.mean(), latent_rep.std())

# Mean of the residual and of its 2nd and 4th powers, for both decodings.
residual = spectrum - reconstruction
perfect_residual = spectrum - perfect_reconstruction
print(np.mean(residual), np.mean(perfect_residual))
print(np.mean(residual**2), np.mean(perfect_residual**2))
print(np.mean(residual**4), np.mean(perfect_residual**4))

In [None]:
# Plot the flattened latent representation of the selected spectrum.
plt.figure(figsize=(2, 2))
plt.plot(latent_rep.ravel())

In [None]:
# Plot window as (start, stop) bin indices: bins from 40% to 60% of the
# spectrum. Use (0, hps.n_bins) instead to show every bin.
window = tuple(int(hps.n_bins * fraction) for fraction in (.4, .6))
#window = (12000, 14000)

In [None]:
# Wavelength grid with one sample per spectral bin, restricted to the plot
# window. The grid length follows hps.n_bins so it stays aligned with `window`.
lambdas = np.linspace(0, 30000, hps.n_bins) # in angstrom
lambdas = lambdas[window[0]:window[1]]

In [None]:
# Compare the decoded reconstruction d(z) with the input spectrum x inside the
# plot window (top), and show their difference (bottom).
plt.figure(figsize=(6, 4))

plt.subplot(2, 1, 1)
# Legend labels follow the data: the reconstruction is d(z), the spectrum is x.
plt.plot(lambdas, np.squeeze(reconstruction)[window[0]:window[1]], label='d(z)')
plt.plot(lambdas, np.squeeze(spectrum)[window[0]:window[1]], label='x')
# Raw strings keep the mathtext backslashes without invalid-escape warnings.
plt.xlabel(r'Wavelength $[\AA]$')
plt.ylabel('Normalized flux')
plt.legend()

plt.subplot(2, 1, 2)
plt.plot(lambdas, np.squeeze(reconstruction - spectrum)[window[0]:window[1]])
plt.xlabel(r'Wavelength $[\AA]$')
plt.ylabel('d(z) - x')

In [None]:
# Shape of every intermediate z returned by the encoder, in encoder order.
intermediate_shapes = [z.shape for z in intermediate_zs]
intermediate_shapes

In [None]:
# Decode the same latent representation while keeping progressively more of the
# true intermediate zs: entry k keeps the last k encoder outputs and replaces
# the remaining leading ones with standard-normal samples.
n_intermediate = len(intermediate_shapes)
reconstructions = []
for intermediate_zs_used in range(n_intermediate + 1):
    n_sampled = n_intermediate - intermediate_zs_used
    new_intermediate_zs = [
        np.random.normal(0, 1, intermediate_shapes[level]) if level < n_sampled
        else intermediate_zs[level]
        for level in range(n_intermediate)
    ]
    reconstructions.append(m.decode(latent_rep, new_intermediate_zs))

In [None]:
# Number of reconstructions built above (one more than the number of intermediate zs).
len(reconstructions)

In [None]:
# Overlay the first three partial reconstructions on the input spectrum (top)
# and plot their residuals against it (bottom), printing each mean squared error.
# Labels are raw strings so the mathtext backslashes are not parsed as
# (invalid) Python escape sequences.
plt.figure(figsize=(6, 4))
for i in [0, 1, 2]:
    print(np.mean((reconstructions[i] - spectrum)**2))
    plt.subplot(2, 1, 1)
    plt.plot(lambdas, np.squeeze(reconstructions[i])[window[0]:window[1]],
             label=r'd($\widetilde{{h}}_{})$'.format(i), alpha=.75)
    plt.subplot(2, 1, 2)
    plt.plot(lambdas, np.squeeze(reconstructions[i] - spectrum)[window[0]:window[1]],
             label=r'd($\widetilde{{h}}_{}) - x$'.format(i), alpha=.75)

plt.subplot(2, 1, 1)
plt.plot(lambdas, np.squeeze(spectrum)[window[0]:window[1]], label='x')
plt.xlabel(r'Wavelength $[\AA]$')
plt.ylabel('Normalized flux')
plt.legend()

plt.subplot(2, 1, 2)
plt.axhline(0, color='k')
plt.xlabel(r'Wavelength $[\AA]$')
plt.ylabel(r'd($\widetilde{h}_i$) - x')
plt.legend()

In [None]:
# Error of every partial reconstruction against the input spectrum:
# mean absolute, root-mean-square and root-mean-quartic.
residuals = [r - spectrum for r in reconstructions]
l1s = [np.mean(np.abs(res)) for res in residuals]
l2s = [np.mean(res**2)**.5 for res in residuals]
l4s = [np.mean(res**4)**.25 for res in residuals]

In [None]:
# Error norms as a function of how many true intermediate zs were kept.
plt.figure(figsize=(4, 4))
error_curves = (
    (l1s, 'Mean Absolute Error'),
    (l2s, 'Root Mean Squared Error'),
    (l4s, 'Root Mean Quartic Error'),
)
for values, curve_label in error_curves:
    plt.plot(values, label=curve_label)

plt.xlabel('Number of Intermediate Representations used from true $h$')
plt.legend()

# latent variable behavior

In [None]:
# Shape of the single-spectrum latent representation; axes 1 and 2 size the array allocated below.
latent_rep.shape

In [None]:
# Encode the whole input stream batch by batch and collect the latent
# representations, one row per spectrum after the final reshape.
latent_reps = np.empty([hps.n_batches, hps.batch_size, latent_rep.shape[1], latent_rep.shape[2]])
initialize_stream()
with tqdm(total=hps.n_batches) as pbar:
    for i in range(hps.n_batches):
        data = sess.run(input_stream)
        latent_reps[i], _ = m.encode(data)
        # Report statistics over the batches filled so far only; the rest of
        # the np.empty buffer is still uninitialised memory.
        filled = latent_reps[:i + 1]
        pbar.set_postfix(mean=filled.mean(), std=filled.std())
        pbar.update()

# Flatten to (spectra, latent features). The row count comes from the buffer
# itself, so this also works when hps.n_data is not a multiple of the batch size.
latent_reps = latent_reps.reshape(hps.n_batches * hps.batch_size, -1)

In [None]:
# Sanity check: expected to be (number of spectra, flattened latent size).
print(latent_reps.shape)

In [None]:
# Per-feature mean minus standard deviation, computed directly from latent_reps
# so this cell does not depend on names defined in a later cell.
print(latent_reps.mean(axis=0) - latent_reps.std(axis=0))

In [None]:
# Per-feature mean of the latent representations with a +/- one standard
# deviation band.
means = latent_reps.mean(axis=0)
stds = latent_reps.std(axis=0)
band_lower, band_upper = means - stds, means + stds

plt.figure(figsize=(1, 1))
plt.plot(means)
plt.fill_between(range(len(means)), band_lower, band_upper, alpha=.25)

In [None]:
# Per-feature mean, then standard deviation, rounded to three decimals.
for statistic in (np.mean, np.std):
    print(statistic(latent_reps, axis=0).round(3))

In [None]:
# Corner plot of the latent features. The 4x4 figure is handed to corner
# explicitly; otherwise corner opens its own figure and this one stays empty.
figure = corner.corner(latent_reps, fig=plt.figure(figsize=(4, 4)))

# Mark zero with green guide lines in every off-diagonal (lower-triangle) panel.
n_features = latent_reps.shape[1]
axes = np.array(figure.axes).reshape((n_features, n_features))
for yi in range(n_features):
    for xi in range(yi):
        ax = axes[yi, xi]
        ax.axvline(0, color="g")
        ax.axhline(0, color="g")

# generate random realization

In [None]:
# Decode five standard-normal draws shaped like the latent representation
# and overlay the resulting spectra.
n_realizations = 5
plt.figure(figsize=(6, 2))
for _ in range(n_realizations):
    sampled_latent_rep = np.random.normal(size=latent_rep.shape)
    decoded_sample = m.decode(sampled_latent_rep)
    plt.plot(np.squeeze(decoded_sample))

# test dataset

In [None]:
# Number of spectra to generate for the held-out test set.
n_test = 4000

In [None]:
# Generate a fresh toy data set for evaluation; the labels are kept unsliced here.
test_spectra, test_labels = toy_data.generate_spectra(n_test, hps.n_bins)

In [None]:
# Encode the whole test set in one call (a trailing axis is added for the
# model), decode from the latent representation alone, and print the
# mean-absolute, root-mean-square and root-mean-quartic errors.
test_zs, test_intermediate_zs = m.encode(test_spectra[:, :, np.newaxis])
test_reconstructions = m.decode(test_zs)

test_residual = test_spectra - test_reconstructions.squeeze()
print(np.mean(np.abs(test_residual)))
print(np.mean(test_residual**2)**.5)
print(np.mean(test_residual**4)**.25)

In [None]:
# Shape of every intermediate z produced for the test set, in encoder order.
test_intermediate_shapes = [z.shape for z in test_intermediate_zs]

In [None]:
# Same sweep as for the single spectrum, applied to the test set: entry k keeps
# the last k true intermediate zs and samples the remaining leading ones from a
# standard normal. Rebinds test_reconstructions to a list of decodings.
n_test_intermediate = len(test_intermediate_shapes)
test_reconstructions = []
for intermediate_zs_used in range(n_test_intermediate + 1):
    n_sampled = n_test_intermediate - intermediate_zs_used
    new_intermediate_zs = [
        np.random.normal(0, 1, test_intermediate_shapes[level]) if level < n_sampled
        else test_intermediate_zs[level]
        for level in range(n_test_intermediate)
    ]
    test_reconstructions.append(m.decode(test_zs, new_intermediate_zs))

In [None]:
# Test-set error of every partial reconstruction: mean absolute,
# root-mean-square and root-mean-quartic.
test_residuals = [r.squeeze() - test_spectra for r in test_reconstructions]
test_l1s = [np.mean(np.abs(res)) for res in test_residuals]
test_l2s = [np.mean(res**2)**.5 for res in test_residuals]
test_l4s = [np.mean(res**4)**.25 for res in test_residuals]

In [None]:
# Test-set error norms as a function of how many true intermediate zs were kept.
plt.figure(figsize=(4, 4))
test_error_curves = (
    (test_l1s, 'Mean Absolute Error'),
    (test_l2s, 'Root Mean Squared Error'),
    (test_l4s, 'Root Mean Quartic Error'),
)
for values, curve_label in test_error_curves:
    plt.plot(values, label=curve_label)

plt.xlabel('Number of Intermediate Representations used from true $h$')
plt.legend()

# save

In [None]:
from datetime import datetime

# Timestamped checkpoint path: models/model-YYMMDD-HHMMSS (local time).
timestamp = datetime.now().strftime('%y%m%d-%H%M%S')
model_filename = f'models/model-{timestamp}'
print(model_filename)

In [None]:
# Save the model under the timestamped path built in the previous cell.
m.save(model_filename)

In [None]:
# Restore a specific, previously saved checkpoint.
# NOTE(review): the path is hard-coded; on a top-to-bottom run this replaces the
# model that was just trained and saved — confirm that is intended.
m.restore('models/model-200304-081901')