In [None]:
import corner
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
import tables

from analysis_utils import *
from data import data_loader, toy_data
import glow as model

%matplotlib inline
%load_ext autoreload
%autoreload 2

In [None]:
# Global matplotlib styling for every figure in this notebook:
# serif font, 8pt tick / axis / legend text, 1pt lines, 300 dpi, autolayout on.
plt.rcParams.update({
    'font.family': 'serif',
    'xtick.labelsize': 8,
    'ytick.labelsize': 8,
    'axes.labelsize': 8,
    'figure.autolayout': True,
    'figure.dpi': 300,
    'lines.linewidth': 1,
    'legend.fontsize': 8,
})

In [None]:
# Hyperparameter namespace: a bare class used as a mutable attribute bag.
# Later cells attach further settings to it (epochs, epochs_warmup, lr).
class hps:
    level_depths = [*[3]*3, *[3]*10]  # array of length n_levels (alternative tried: [3, *[1]*11, 3])
    n_levels = len(level_depths)      # number of splits
    width = 16                        # channels in revnet layers
    window_size = 25                  # conv window size in f()
    n_data = 16000                    # number of input spectra
    batch_size = 50                   # number of spectra in a batch
    n_batches = int(n_data / batch_size)  # whole batches per epoch
    n_bins = 2**15                    # bins per spectrum

In [None]:
# Switch TensorFlow to graph mode *before* anything else is created:
# disable_eager_execution() is documented to be called before any graphs, ops or
# tensors exist, and opening a session already picks up a graph.
tf.compat.v1.disable_eager_execution()

# Interactive session handed to the data loaders and the model in the cells below.
sess = tf.compat.v1.InteractiveSession()

In [None]:
# Selects the dataset used by the loading cells below; they branch on 'toy' and 'real'.
data_source = 'real'

In [None]:
# Load the training spectra/labels for the selected data source, then build the
# input pipeline from the in-memory arrays.
if data_source == 'toy':
    spectra, labels = toy_data.generate_spectra(hps.n_data, hps.n_bins)
    labels = labels[:, 1:3] # ignore temperature and sigma (leaving A, mu)
elif data_source == 'real':
    # np.load on an .npz returns an archive handle backed by an open file;
    # the context manager closes it once both arrays have been read.
    # (Disabled alternative kept for reference:
    #  data_loader.create_loader_from_hdf5(sess, hps.batch_size, 'data/sample_short.h5'))
    with np.load('data/sample_short.npz') as archive:
        spectra, labels = archive['spectra'], archive['labels']
else:
    # Fail here with a clear message instead of a NameError in a later cell.
    raise ValueError("unknown data_source {!r}; expected 'toy' or 'real'".format(data_source))

# The loader returns the input stream, the label stream and a callable that
# (re)initializes the stream; all three are used by later cells.
input_stream, label_stream, initialize_stream = (
    data_loader.create_loader_from_array(sess, hps.batch_size, spectra, labels)
)

In [None]:
# Reset the training stream, then evaluate it once in the session.
# Note: this rebinds `spectra` (previously the full array given to the loader) to the
# fetched result, which the preview plot and the spectrum-picking cell use below.
# NOTE(review): presumably one batch of hps.batch_size spectra — confirm against data_loader.
initialize_stream()
spectra = sess.run(input_stream)

In [None]:
# Preview: overlay the first 25 fetched spectra on one wide axis.
fig, ax = plt.subplots(figsize=(6, 2))
for spectrum in spectra[:25]:
    ax.plot(spectrum, alpha=.75)

In [None]:
# Build the model under an explicit GPU device scope, wiring it to the session,
# the hyperparameters and the training input stream.
# NOTE(review): the device string is hard-coded to GPU 0; behaviour on a machine
# without a GPU depends on the session's placement settings — confirm.
with tf.device("/device:GPU:0"):
    m = model.model(sess, hps, input_stream)

In [None]:
# List the trainable variables of the graph. The TF1-style collection API is reached
# through tf.compat.v1 (like the session setup in this notebook); the top-level
# `tf.trainable_variables` alias does not exist in TensorFlow 2.x.
tf.compat.v1.trainable_variables()

In [None]:
# Training bookkeeping shared by the training cells below; re-running this cell resets it.
n_processed = 0                  # spectra consumed so far (drives the learning-rate warm-up)
training_results, lrs = [], []   # per-iteration m.train outputs / learning rates used
prev_loss = 100                  # mean loss of the previous epoch; starts at a large sentinel

In [None]:
# Level passed to m.train by the level-wise training cell, which increments it
# whenever the epoch-mean loss stops changing.
training_level = 0

In [None]:
# Level-wise training: train at `training_level` and move on to the next level
# whenever the epoch-mean loss plateaus. State (n_processed, training_results, lrs,
# prev_loss, training_level) lives in earlier cells, so re-running this cell continues
# training rather than restarting it.
hps.epochs = 30
hps.epochs_warmup = .01
hps.lr = .001

# An epoch-to-epoch change in mean loss smaller than this counts as a plateau.
LOSS_PLATEAU_TOL = .001
# Number of processed spectra over which the learning rate ramps linearly from 0 to
# hps.lr. Loop-invariant, so it is computed once instead of on every iteration.
warmup_samples = hps.batch_size * hps.n_batches * hps.epochs_warmup

for epoch in tqdm(range(1, hps.epochs + 1), desc='Epochs'):
    epoch_results = []
    initialize_stream()
    # The description is constant within an epoch, so it is set once on creation.
    with tqdm(total=hps.n_batches, desc='Epoch ' + str(epoch)) as pbar:
        for iteration in range(hps.n_batches):
            lr = hps.lr * min(1., n_processed / warmup_samples)
            step_result = m.train(lr, training_level)
            epoch_results.append(step_result)
            training_results.append(step_result)
            lrs.append(lr)
            n_processed += hps.batch_size
            pbar.set_postfix(lr=lr, loss=np.mean(epoch_results), training_level=training_level)
            pbar.update()

    current_loss = np.mean(epoch_results)
    plateaued = np.abs(prev_loss - current_loss) < LOSS_PLATEAU_TOL
    # Advance to the next level on a plateau. On the last level nothing happens and
    # training simply runs for the remaining epochs (no early stop).
    if plateaued and training_level < hps.n_levels - 1:
        training_level += 1
    prev_loss = current_loss

In [None]:
# Further training without a level argument, continuing from the shared training
# state (n_processed, training_results, lrs) of the earlier cells.
# NOTE(review): m.train(lr) is called without a training level; presumably this
# trains all levels together — confirm against the model's train() signature.
hps.epochs = 15
hps.epochs_warmup = .01
hps.lr = .001

# Number of processed spectra over which the learning rate ramps linearly from 0 to
# hps.lr. Loop-invariant, so it is computed once instead of on every iteration.
warmup_samples = hps.batch_size * hps.n_batches * hps.epochs_warmup

for epoch in tqdm(range(1, hps.epochs + 1), desc='Epochs'):
    epoch_results = []
    initialize_stream()
    # The description is constant within an epoch, so it is set once on creation.
    with tqdm(total=hps.n_batches, desc='Epoch ' + str(epoch)) as pbar:
        for iteration in range(hps.n_batches):
            lr = hps.lr * min(1., n_processed / warmup_samples)
            step_result = m.train(lr)
            epoch_results.append(step_result)
            training_results.append(step_result)
            lrs.append(lr)
            n_processed += hps.batch_size
            pbar.set_postfix(lr=lr, loss=np.mean(epoch_results))
            pbar.update()

In [None]:
# Training curves: per-iteration training result (top) and learning rate (bottom),
# both against the number of epochs elapsed.
fig, (result_ax, lr_ax) = plt.subplots(2, 1, figsize=(6, 4))

# One x value per recorded training iteration, expressed in epochs.
n_iterations = len(training_results)
epochs = np.linspace(0, n_iterations / hps.n_batches, n_iterations)

result_ax.plot(epochs, training_results)
result_ax.set_ylim(-7, 0)
result_ax.set_xlabel('epochs')

lr_ax.plot(epochs, lrs)
lr_ax.set_xlabel('epochs')

# generate reconstructions of spectra

In [None]:
# Pick one random spectrum from the previously fetched `spectra`. The upper bound is
# taken from the array itself, so the index is valid whatever its length (the
# original bound, hps.batch_size, silently produced an empty slice for a shorter array).
i = np.random.randint(0, len(spectra))
spectrum = spectra[i:i+1, :, :]  # slice rather than index, keeping the leading axis
print(i)

In [None]:
# Encode the selected spectrum. m.encode returns the final latent representation and a
# sequence of intermediate latents; the reconstruction cell below uses its length and
# tail slices.
latent_rep, intermediate_zs = m.encode(spectrum)

In [None]:
SAMPLES_PER_INDEX = 10

# For zs_used = 0 .. len(intermediate_zs): decode the latent representation while
# supplying only the last `zs_used` intermediate latents, SAMPLES_PER_INDEX times each,
# and score every reconstruction against the original spectrum.
n_intermediate = len(intermediate_zs)
reconstructions, stats = [], []
for zs_used in tqdm(range(n_intermediate + 1)):
    index = n_intermediate - zs_used  # the slice below is empty when zs_used == 0
    reconstructions_i = []
    for sample_idx in range(SAMPLES_PER_INDEX):
        reconstructions_i.append(m.decode(latent_rep, intermediate_zs[index:]))
    stats_i = [get_stats(spectrum, r) for r in reconstructions_i]
    reconstructions.append(reconstructions_i)
    stats.append(stats_i)

# Stack into arrays indexed by [zs_used, sample, ...].
reconstructions = np.array(reconstructions)
stats = np.array(stats)

In [None]:
# Bin range shown in the reconstruction plots. Alternatives that were tried:
#   (0, hps.n_bins)                                  -- the full range
#   (int(hps.n_bins*.4), int(hps.n_bins*.6))         -- the central fifth
# (The full-range assignment used to be executed and then immediately overwritten.)
window = (12000, 14000)

lambdas = np.arange(0, hps.n_bins) # remap bins to wavelengths here
#lambdas = np.linspace(0, 30000, 2**12) # in angstrom

In [None]:
# Top panel: reconstructions for several numbers of intermediate latents, together with
# the original spectrum. Bottom panel: the corresponding residuals.
plt.figure(figsize=(6, 4))

# First-axis indices into `reconstructions`, i.e. how many intermediate latents were used.
zs_used = [0, 3, 6, 9, 12]
# Raw strings: "\w" and "\A" are not valid Python escape sequences, so the non-raw
# literals triggered invalid-escape warnings. The resulting text is unchanged.
reconstruction_labels = [r"d($\widetilde{{h}}_{{{}}}$)".format(i) for i in zs_used]
residual_labels = [s + "-x" for s in reconstruction_labels]

plt.subplot(2, 1, 1)
plot_window(lambdas,
            [reconstructions[i][0] for i in zs_used],  # first sample of each setting
            window=window, labels=reconstruction_labels, alpha=.75)
plot_window(lambdas, [spectrum], labels=["x"], window=window)
plt.xlabel(r'Wavelength $[\AA]$')
plt.ylabel("Normalized flux")
plt.legend()

plt.subplot(2, 1, 2)
plot_window(lambdas,
            [reconstructions[i][0] - spectrum for i in zs_used],
            window=window, labels=residual_labels, alpha=.75)
plt.axhline(0, color="k")
plt.xlabel(r'Wavelength $[\AA]$')
plt.ylabel("d(z)-x")
plt.legend()

In [None]:
# Reconstruction statistics versus the number of intermediate latents used.
# `stats` is indexed [zs_used, sample, statistic]; one tick per row, so the axis follows
# the data instead of a hard-coded 13.
# NOTE(review): the "L1"/"L2"/"L4" labels presumably match the order of the values
# returned by get_stats — confirm.
plt.figure(figsize=(3, 2))
plt.xticks(range(len(stats)))
plot_mean_w_error(stats[:, :, 0], label="L1")
plot_mean_w_error(stats[:, :, 1], label="L2")
plot_mean_w_error(stats[:, :, 2], label="L4")
plt.legend()

# latent variable behavior

In [None]:
# Encode the training stream batch by batch and collect the final latent
# representation of every spectrum.
latent_reps = np.empty([hps.n_batches, hps.batch_size, latent_rep.shape[1], latent_rep.shape[2]])
initialize_stream()

# A dedicated loop variable keeps the spectrum index `i` chosen earlier intact, and the
# unused second return value is not bound to `_` (IPython's last-output variable).
for batch_idx in tqdm(range(hps.n_batches)):
    data = sess.run(input_stream)
    latent_reps[batch_idx], _batch_intermediate_zs = m.encode(data)

# Flatten to one row per encoded spectrum. -1 lets numpy infer the row count
# (n_batches * batch_size), which equals hps.n_data only when n_data is a multiple of
# the batch size; reshaping to hps.n_data would raise otherwise.
latent_reps = latent_reps.reshape(-1, latent_rep.shape[1] * latent_rep.shape[2])

In [None]:
# Summary of the latent variables over the encoded dataset (axis=0 runs over spectra),
# with a horizontal reference line at zero.
# NOTE(review): plot_mean_w_error is not defined in this notebook (presumably from
# analysis_utils); it looks like it draws a mean with an error band — confirm.
plt.figure(figsize=(3, 1))
plot_mean_w_error(latent_reps, axis=0)
plt.axhline(0, color="k")

In [None]:
# Corner plot of the flattened latent representations. corner.corner builds its own
# figure when no `fig` argument is passed, so no plt.figure() call precedes it: the
# earlier one only left an empty extra figure behind and its figsize was never applied.
figure = corner.corner(latent_reps)

# The figure holds an n_dims x n_dims grid of axes; mark the origin with green lines in
# every panel below the diagonal.
n_dims = latent_reps.shape[1]
axes = np.array(figure.axes).reshape((n_dims, n_dims))
for yi in range(n_dims):
    for xi in range(yi):
        ax = axes[yi, xi]
        ax.axvline(0, color="g")
        ax.axhline(0, color="g")

# generate random realization

In [None]:
plt.figure(figsize=(6, 2))
N = 15

# Draw N latent arrays from a standard normal, shaped like an encoded spectrum,
# decode each one and overlay the results.
latent_shape = latent_rep.shape
random_latent_reps = [np.random.normal(size=latent_shape) for _ in range(N)]
random_spectra = [m.decode(z) for z in random_latent_reps]
plot_window(lambdas, random_spectra, window=None, alpha=.75)

# test dataset

In [None]:
# Number of test spectra; passed to the toy generator and, as the second positional
# argument (the batch-size slot in the training loader calls), to the test loaders.
n_test = 4000

In [None]:
# Build the test stream for the selected data source. n_test is passed where the
# training loaders take the batch size, so a single fetch returns the test data.
if data_source == 'toy':
    test_spectra, test_labels = toy_data.generate_spectra(n_test, hps.n_bins)
    test_labels = test_labels[:, 1:3] # ignore temperature and sigma (leaving A, mu)
    test_input_stream, test_label_stream, initialize_test_stream = (
        data_loader.create_loader_from_array(sess, n_test, test_spectra, test_labels)
    )
elif data_source == 'real':
    test_input_stream, test_label_stream, initialize_test_stream = (
        data_loader.create_loader_from_hdf5(sess, n_test, 'data/test_short.h5')
    )
else:
    # Fail with a clear message instead of a NameError on the stream names below.
    raise ValueError("unknown data_source {!r}; expected 'toy' or 'real'".format(data_source))

# Initialize the stream and fetch spectra and labels together in one session call.
initialize_test_stream()
test_spectra, test_labels = sess.run([test_input_stream, test_label_stream])

In [None]:
# Encode the test spectra, then decode from the final latent representation alone
# (no intermediate latents supplied).
# Note: `test_reconstructions` is rebound by the statistics loop in the next cell.
test_zs, test_intermediate_zs = m.encode(test_spectra)
test_reconstructions = m.decode(test_zs)

In [None]:
# Reconstruction statistics on the test set as a function of how many intermediate
# latents are supplied to the decoder (0 .. all of them).
# The loop bounds come from test_intermediate_zs itself, the sequence that is sliced,
# rather than from `intermediate_zs` of the single-spectrum section, so this cell no
# longer depends on that section having been run.
n_test_zs = len(test_intermediate_zs)
stats = []
for zs_used in tqdm(range(n_test_zs + 1)):
    index = n_test_zs - zs_used  # the slice below is empty when zs_used == 0
    test_reconstructions = m.decode(test_zs, test_intermediate_zs[index:])
    stats_i = get_stats(test_spectra, test_reconstructions)
    stats.append(stats_i)
# Rows are indexed by zs_used; columns hold the values returned by get_stats.
stats = np.array(stats)

In [None]:
# Test-set reconstruction statistics versus the number of intermediate latents used.
# One tick per row of `stats`, so the axis follows the data instead of a hard-coded 13.
# NOTE(review): the "L1"/"L2"/"L4" labels presumably match the column order returned
# by get_stats — confirm.
plt.figure(figsize=(3, 2))
plt.xticks(range(len(stats)))
plt.plot(stats[:, 0], label="L1")
plt.plot(stats[:, 1], label="L2")
plt.plot(stats[:, 2], label="L4")
plt.legend()

# save

In [None]:
from datetime import datetime

# Timestamped checkpoint path of the form models/model-yymmdd-HHMMSS.
timestamp = datetime.now().strftime('%y%m%d-%H%M%S')
model_filename = 'models/model-{}'.format(timestamp)
print(model_filename)

In [None]:
# Save the model under the timestamped path generated in the previous cell.
m.save(model_filename)

In [None]:
# Restore a model from a fixed, previously saved path (hard-coded; independent of
# `model_filename` above).
m.restore('models/model-200304-081901')