In [None]:
import corner
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm

from data import data_loader, toy_data
import glow as model

%matplotlib inline
%load_ext autoreload
%autoreload 2

In [None]:
# Notebook-wide matplotlib styling: serif text, 8 pt tick / axis / legend
# labels, thin lines, and 300 dpi figures that lay themselves out automatically.
plt.rcParams.update({
    'font.family': 'serif',
    'xtick.labelsize': 8,
    'ytick.labelsize': 8,
    'axes.labelsize': 8,
    'figure.autolayout': True,
    'figure.dpi': 300,
    'lines.linewidth': 1,
    'legend.fontsize': 8,
})

In [None]:
class hps:
    """Hyperparameter namespace passed to the model and used by the data cells.

    Used purely as an attribute bag; the class itself is never instantiated
    in this notebook.
    """

    n_levels = 10  # number of splits
    depth = 3  # number of layers in revnet
    width = 16  # channels in revnet layers
    # NOTE(review): not explained in this notebook; presumably the Polyak
    # averaging horizon in epochs - confirm against glow.py.
    polyak_epochs = 1
    # NOTE(review): the previous comment ("learning rate annealing factor") was
    # duplicated on beta1 and weight_decay; presumably beta1 is the optimizer's
    # first-moment coefficient and weight_decay a decay multiplier - confirm
    # against glow.py.
    beta1 = .9
    weight_decay = 1
    lr = .001  # base learning rate
    n_data = 4000  # number of input spectra
    batch_size = 50  # number of spectra in a batch
    n_batches = int(n_data / batch_size)  # batches needed to cover n_data once
    n_bins = 2**12  # bins per spectrum (passed to toy_data.generate_spectra)

In [None]:
# Graph mode must be selected before any graph, op or session exists, so eager
# execution is switched off first and the session is created afterwards.
tf.compat.v1.disable_eager_execution()
# InteractiveSession installs itself as the default session for later cells.
sess = tf.compat.v1.InteractiveSession()

In [None]:
# Selects which loader the next cell builds: "toy" (generated spectra) or
# "real" (spectra read from an HDF5 file).
data_source = "toy"

In [None]:
# Build the input pipeline for the selected data source. Both loaders hand back
# the same triple: a spectra stream, a label stream, and a callable that
# initializes the stream (called in the next cell).
if data_source == 'toy':
    spectra, labels = toy_data.generate_spectra(hps.n_data, hps.n_bins)
    labels = labels[:, 1:3] # ignore temperature and sigma (leaving A, mu)
    input_stream, label_stream, initialize_stream = (
        data_loader.create_loader_from_array(sess, hps.batch_size, spectra, labels)
    )
elif data_source == 'real':
    input_stream, label_stream, initialize_stream = (
        data_loader.create_loader_from_hdf5(sess, hps.batch_size, 'data/sample_short.h5')
    )
else:
    # Fail here with a clear message instead of a NameError on `input_stream`
    # in a later cell.
    raise ValueError(
        "unknown data_source {!r}; expected 'toy' or 'real'".format(data_source)
    )

In [None]:
# Initialize the input stream, then evaluate it once to get data to inspect.
# Note: this rebinds `spectra`, replacing the array produced in the 'toy'
# branch of the previous cell.
# NOTE(review): presumably one evaluation yields hps.batch_size spectra - confirm in data_loader.
initialize_stream()
spectra = sess.run(input_stream)

In [None]:
# Quick look at the first five spectra of the fetched batch, overlaid on one axis.
fig, ax = plt.subplots(figsize=(6, 4))
for example in spectra[:5]:
    ax.plot(example)

In [None]:
# Build the model on top of the input stream, using the session and the
# hyperparameters defined above. To pin construction to the first GPU,
# uncomment the device scope below and indent the constructor call under it.
#with tf.device("/device:GPU:0"):
m = model.model(sess, hps, input_stream)

In [None]:
# Restore the model from a saved checkpoint.
# NOTE(review): the path is hard-coded and relative to the working directory;
# the checkpoint must match the hps values above - confirm before rerunning.
m.restore('models/model-200507-075409')

In [None]:
# Draw a random position in [0, batch_size) and keep that spectrum as a
# length-1 slice, so the leading batch axis is preserved.
# No seed is set, so a different spectrum is picked on every run.
i = np.random.randint(hps.batch_size)
spectrum = spectra[i:i + 1, :, :]
print(i)

In [None]:
# Encode the selected spectrum. encode() returns the final latent code and a
# sequence of intermediate z's (get_feed_dict below indexes it by position).
latent_rep, intermediate_zs = m.encode(spectrum)
# Decode from the final latent code alone.
reconstruction = m.decode(latent_rep)
# Decode again, this time also supplying the intermediate z's from the encoder.
# NOTE(review): presumably this inverts the encoder exactly, hence the name - confirm in glow.py.
perfect_reconstruction = m.decode(latent_rep, intermediate_zs)

In [None]:
def get_feed_dict(z):
    """Build the feed dict for evaluating decoder-side tensors at latent `z`.

    `z` is bound to the model's z placeholder. The notebook-level
    `intermediate_zs` (from the encode cell) are bound, position by position,
    to the model's intermediate-z placeholders, so only `z` varies between
    calls.

    Args:
        z: Value to feed for `m.z_placeholder`.

    Returns:
        A dict mapping placeholders to the values to feed.
    """
    feed_dict = {m.z_placeholder: z}
    for position, intermediate_z in enumerate(intermediate_zs):
        feed_dict[m.intermediate_z_placeholders[position]] = intermediate_z
    return feed_dict

In [None]:
# (start, stop) bin indices of the region that will be modified; everything
# outside it is constrained to stay close to the original spectrum.
# Only one assignment is active, so there is no dead store. Alternatives:
#window = (0, hps.n_bins)  # whole spectrum
#window = (12000, 14000)  # NOTE(review): exceeds n_bins = 2**12; stale or meant in angstrom?
window = (int(hps.n_bins*.45), int(hps.n_bins*.55))

In [None]:
# Wavelength grid with one value per bin, restricted to the window.
# Sized from hps.n_bins (not a literal 2**12) so it stays aligned with the
# spectra if the bin count is changed.
lambdas = np.linspace(0, 30000, hps.n_bins) # in angstrom
lambdas = lambdas[window[0]:window[1]]

In [None]:
# Compare the selected spectrum with its reconstruction from z alone.
# Top panel: full wavelength range with the window edges marked.
# Bottom panel: zoom on the window.
wavelengths = np.linspace(0, 30000, hps.n_bins)  # full grid, built once
lo, hi = window

plt.figure(figsize=(6, 4))

plt.subplot(2, 1, 1)
plt.plot(wavelengths, np.squeeze(spectrum), label='x')
plt.plot(wavelengths, np.squeeze(reconstruction), label='d(z)')
plt.axvline(lambdas[0])
plt.axvline(lambdas[-1])
# Raw string: '\A' is not a valid Python escape and triggers a warning otherwise.
plt.xlabel(r'Wavelength $[\AA]$')
plt.ylabel('Normalized flux')
plt.legend()

plt.subplot(2, 1, 2)
plt.plot(lambdas, np.squeeze(spectrum)[lo:hi])
plt.plot(lambdas, np.squeeze(reconstruction)[lo:hi])
plt.xlabel(r'Wavelength $[\AA]$')
plt.ylabel('Normalized flux')

In [None]:
def create_gaussian_kernel(size, mean, std):
    """Build a normalized 1-D Gaussian filter for use with `tf.nn.conv1d`.

    Args:
        size: Nominal kernel width. Taps sit at integer offsets
            -int(size/2) .. int(size/2), so the kernel has 2*int(size/2) + 1
            entries (an even `size` therefore yields `size + 1` taps).
        mean: Mean of the Gaussian, in taps relative to the kernel centre.
        std: Standard deviation of the Gaussian, in taps.

    Returns:
        A float32 tensor of shape (taps, 1, 1) whose entries sum to 1.
    """
    half_width = int(size / 2)
    # `tf.distributions` is a TF 1.x-only name; the compat.v1 path matches the
    # rest of this notebook and also resolves under TF 2.x.
    d = tf.compat.v1.distributions.Normal(tf.cast(mean, tf.float32), tf.cast(std, tf.float32))
    vals = d.prob(tf.range(start=-half_width, limit=half_width + 1, dtype=tf.float32))

    # Add the two trailing axes conv1d expects for a filter (one channel in, one out).
    kernel = vals[:, np.newaxis, np.newaxis]
    return kernel / tf.reduce_sum(kernel)

In [None]:
# Smoothing filter: nominal width 51 taps, centred, standard deviation 25 taps.
gaussian_kernel = create_gaussian_kernel(size=51, mean=0, std=25)

# Three-tap difference stencil [-n_bins/2, 0, +n_bins/2] shaped (3, 1, 1),
# i.e. a central difference if the bin spacing is taken to be 1 / n_bins.
half_inverse_spacing = hps.n_bins / 2
derivative_kernel = tf.constant(
    [[[-half_inverse_spacing]], [[0]], [[half_inverse_spacing]]]
)

In [None]:
# Smooth -> differentiate -> smooth -> differentiate: a low-pass filter before
# each finite difference keeps the derivative estimates from amplifying noise.
# `stride` is passed explicitly because TF 2.x's tf.nn.conv1d requires it.
# A unit stride with SAME padding keeps every output as long as the input.
smoothed = tf.nn.conv1d(m.s_from_intermediate_zs, gaussian_kernel, stride=1, padding="SAME")
first_derivative = tf.nn.conv1d(smoothed, derivative_kernel, stride=1, padding="SAME")
smoothed_first_derivative = tf.nn.conv1d(first_derivative, gaussian_kernel, stride=1, padding="SAME")
second_derivative = tf.nn.conv1d(smoothed_first_derivative, derivative_kernel, stride=1, padding="SAME")

In [None]:
# Evaluate all three tensors in a single run at the encoded latent code.
# They share the decoder and smoothing sub-graph, so one call avoids
# recomputing it (and rebuilding the feed dict) three times.
smoothed_spectra, first_derivative_spectra, second_derivative_spectra = sess.run(
    [smoothed, first_derivative, second_derivative],
    get_feed_dict(latent_rep),
)

In [None]:
# Overlay the reconstruction, its smoothed version and both derivative
# estimates. The derivatives are scaled to unit standard deviation so all
# curves fit one y-range; vertical lines mark the window edges (bin indices).
fig, ax = plt.subplots(figsize=(6, 2))
ax.plot(np.squeeze(reconstruction))
ax.plot(np.squeeze(smoothed_spectra))
for derivative in (first_derivative_spectra, second_derivative_spectra):
    ax.plot(np.squeeze(derivative) / derivative.std())
for edge in window:
    ax.axvline(edge)
ax.set_ylim(-5, 5)

In [None]:
lo, hi = window
decoded = m.s_from_intermediate_zs  # decoder output as a function of z

# Outside the window: squared deviation from the original spectrum, summed
# over the part left of the window and the part right of it.
left_squared_error = tf.reduce_sum((spectrum[:, :lo] - decoded[:, :lo])**2)
right_squared_error = tf.reduce_sum((spectrum[:, hi:] - decoded[:, hi:])**2)
outside_cost = left_squared_error + right_squared_error

# Inside the window: negated sum of squared second derivatives, so a larger
# value means a flatter (less curved) spectrum inside the window.
# (An earlier variant used the squared error against m.decoded_spectra instead.)
inside_cost = -tf.reduce_sum(second_derivative[:, lo:hi]**2)

# Likelihood term: -0.5 * ||z||^2.
logp = -.5 * tf.reduce_sum(m.z_placeholder**2)

# Objective that the update loop ascends: reward inside_cost and logp, and
# penalize any change outside the window very heavily.
inside_weight, outside_weight, logp_weight = 1, 1e6, 1
cost = inside_weight*inside_cost - outside_weight*outside_cost + logp_weight * logp
# tf.gradients returns a list with one entry per tensor it differentiates against.
gradient = tf.gradients(cost, m.z_placeholder)

In [None]:
# Optimization history. `latent_reps` starts at the encoded latent code and
# `grads` is empty, so latent_reps always holds one more entry than grads.
# Re-run this cell to restart the search from the encoded spectrum.
latent_reps = [latent_rep]
grads = []

In [None]:
# Normalized gradient ascent on `cost`: every step moves the latent code a
# fixed distance of 0.01 along the gradient direction. Re-running this cell
# continues from the last stored latent code rather than restarting.
for _ in range(500):
    grad = sess.run(gradient, get_feed_dict(latent_reps[-1]))[0]
    grad_norm = np.linalg.norm(grad[0])
    if not np.isfinite(grad_norm) or grad_norm == 0:
        # A zero norm would divide by zero; a non-finite norm would poison
        # every later step. In both cases the history is left as it is.
        break
    grads.append(grad)
    step_size = .01 / grad_norm
    latent_reps.append(latent_reps[-1] + step_size * grad[0])

In [None]:
# exploration analysis
# Flatten every stored latent code / gradient into one row per step.
latent_reps_np = np.array(latent_reps).reshape((len(latent_reps), -1))
grads_np = np.array(grads).reshape((len(grads), -1))
# When plotting changes over time, plot around 10 things. Clamped to at least
# 1: with fewer than 10 steps the stride would be 0 and range() would raise.
print_freq = max(1, len(grads) // 10)

In [None]:
# Six-panel summary of the search: latent codes in the left column and
# gradients in the right column. Rows: average over steps, norm per step,
# and snapshots taken every `print_freq` steps.
fig, axes = plt.subplots(3, 2, figsize=(6, 6))
(ax_rep_mean, ax_grad_mean), (ax_rep_norm, ax_grad_norm), (ax_rep_steps, ax_grad_steps) = axes

ax_rep_mean.plot(latent_reps_np.mean(axis=0))
ax_rep_mean.set_xlabel('component position')
ax_rep_mean.set_ylabel('latent rep (avg over steps)')

ax_grad_mean.plot(grads_np.mean(axis=0))
ax_grad_mean.set_xlabel('component position')
ax_grad_mean.set_ylabel('gradient (avg over steps)')

ax_rep_norm.plot([np.linalg.norm(rep) for rep in latent_reps])
ax_rep_norm.set_xlabel('step')
ax_rep_norm.set_ylabel('norm of latent representation')

ax_grad_norm.plot([np.linalg.norm(grad) for grad in grads])
ax_grad_norm.set_xlabel('step')
ax_grad_norm.set_ylabel('norm of gradient')

for step in range(0, len(latent_reps), print_freq):
    ax_rep_steps.plot(latent_reps_np[step])
ax_rep_steps.set_xlabel('component position')
ax_rep_steps.set_ylabel('latent rep over time')

for step in range(0, len(grads), print_freq):
    ax_grad_steps.plot(grads_np[step])
ax_grad_steps.set_xlabel('component position')
ax_grad_steps.set_ylabel('gradient over time')

fig.tight_layout()

In [None]:
# Optimization steps whose decoded spectra are compared in the next cell.
# Only one assignment is active, so there is no dead store. Evenly spaced
# alternative: np.arange(100, len(latent_reps), 100)
indices = [50, 100, 150, 200]
# One viridis color per selected step, sampled evenly across the colormap.
colors = plt.cm.viridis(np.linspace(0, 1, len(indices)))

In [None]:
# Show how the decoded spectrum changes along the search.
# Top panel: full range with the window edges marked. Bottom panel: zoom on the window.
wavelengths = np.linspace(0, 30000, hps.n_bins)  # full grid, built once
lo, hi = window

# Decode each selected step once and reuse the result in both panels.
decoded_steps = [
    np.squeeze(m.decode(latent_reps[step], intermediate_zs)) for step in indices
]

# NOTE(review): the black '0' curve is `reconstruction` (decoded without the
# intermediate z's) while the colored curves are decoded with them; confirm
# that `perfect_reconstruction` was not the intended baseline.
baseline = np.squeeze(reconstruction)

plt.figure(figsize=(6, 4))

plt.subplot(2, 1, 1)
plt.plot(wavelengths, baseline, color='k', label='0')
for step, color, decoded in zip(indices, colors, decoded_steps):
    plt.plot(wavelengths, decoded, color=color, label=step)
plt.axvline(lambdas[0])
plt.axvline(lambdas[-1])
# Raw string: '\A' is not a valid Python escape and triggers a warning otherwise.
plt.xlabel(r'Wavelength $[\AA]$')
plt.ylabel('Normalized flux')
plt.legend()

plt.subplot(2, 1, 2)
plt.plot(lambdas, baseline[lo:hi], color='k', label='0')
for step, color, decoded in zip(indices, colors, decoded_steps):
    plt.plot(lambdas, decoded[lo:hi], color=color, label=step)
plt.xlabel(r'Wavelength $[\AA]$')
plt.ylabel('Normalized flux')
#plt.legend()