In [None]:
import corner
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
from scipy.ndimage.filters import gaussian_filter

from analysis_utils import *
from data import data_loader, toy_data
import glow as model

%matplotlib inline
%load_ext autoreload
%autoreload 2

In [None]:
# Global matplotlib styling: serif font, 8 pt tick/axis/legend text,
# 1 pt lines, and 300 dpi figures with automatic layout.
plt.rcParams.update({
    'font.family': 'serif',
    'xtick.labelsize': 8,
    'ytick.labelsize': 8,
    'axes.labelsize': 8,
    'figure.autolayout': True,
    'figure.dpi': 300,
    'lines.linewidth': 1,
    'legend.fontsize': 8,
})

In [None]:
# Hyperparameter namespace: plain class attributes read as hps.<name>.
class hps:
    n_levels = 13  # number of splits
    depth = 6  # number of flow steps in each level
    final_depth = 64  # number of flow steps in the final level
    width = 32  # channels in revnet layers
    n_data = 16000  # number of input spectra
    batch_size = 50  # number of spectra in a batch
    n_batches = int(n_data / batch_size)  # batches needed to cover n_data spectra
    n_bins = 2**15

In [None]:
# Switch TensorFlow to graph mode first: disable_eager_execution() is
# documented as something to call before any graphs, ops, or tensors exist,
# so it has to precede the creation of the session.
tf.compat.v1.disable_eager_execution()
sess = tf.compat.v1.InteractiveSession()

In [None]:
# Data source selector: 'toy' or 'real' (the two values the loader cell handles).
data_source = 'real'

In [None]:
# Build the input stream, the label stream, and the function that initializes
# them, for the selected data source.
if data_source == 'toy':
    spectra, labels = toy_data.generate_spectra(hps.n_data, hps.n_bins)
    labels = labels[:, 1:3] # ignore temperature and sigma (leaving A, mu)
    input_stream, label_stream, initialize_stream = (
        data_loader.create_loader_from_array(sess, hps.batch_size, spectra, labels)
    )
elif data_source == 'real':
    input_stream, label_stream, initialize_stream = (
        data_loader.create_loader_from_hdf5(sess, hps.batch_size, 'data/sample_short.h5')
    )
else:
    # Fail here instead of with a NameError on `input_stream` in a later cell.
    raise ValueError(
        "Unknown data_source {!r}; expected 'toy' or 'real'".format(data_source)
    )

In [None]:
# Initialize the streams and pull one batch. Both tensors are fetched in a
# single sess.run call so the spectra and labels come from the same step;
# separate calls could return mismatched batches if the two streams share one
# underlying iterator (NOTE(review): depends on data_loader - confirm).
initialize_stream()
spectra, labels = sess.run([input_stream, label_stream])

In [None]:
# Bin range (start, stop) used by the plots and the peak remover below.
# Only the last assignment ever took effect, so the alternatives are kept as
# comments instead of being executed and immediately overwritten.
#window = (0, hps.n_bins)
#window = (int(hps.n_bins*.4), int(hps.n_bins*.6)) 
window = (13350, 13400)

lambdas = np.arange(0, hps.n_bins) # remap bins to wavelengths here
#lambdas = np.linspace(0, 30000, 2**12) # in angstrom

In [None]:
plt.figure(figsize=(12, 6))
# figure out where the window should be: show the first five spectra returned
# by sort_spectra over a wider bin range and mark the current window edges.

plot_window(lambdas, 
            sort_spectra(spectra, labels, 2)[0][:5], 
            window=(window[0]-750, window[1]+100))
# Take the marker positions from `window` rather than repeating its literal
# values, so the markers cannot drift apart from the window when it is changed.
plt.axvline(window[0], color='k')
plt.axvline(window[1], color='k')

In [None]:
def find(label):
    """Call `find_spectra` for `label` with the notebook-level context.

    Thin wrapper that forwards `label` together with `hps.n_batches`, the
    session, and the input/label streams, and returns whatever
    `find_spectra` returns (unpacked as `spectra, labels` by the callers).
    """
    return find_spectra(label, hps.n_batches, sess, input_stream, label_stream)

In [None]:
# Query with an all-NaN float32 label; the returned labels are inspected in
# the next cell to list the values present in each column.
label = np.array([np.nan] * 4, dtype='float32')
spectra, labels = find(label)

In [None]:
# List the distinct values present in each label column.
for column, name in enumerate(["Temperature", "log g", "Fe/H", "alpha/H"]):
    print(name + " options:", np.unique(labels[:, column]))

In [None]:
# T_eff, log_g, fe/h, alpha/h
# Pin T_eff = 6000 and log_g = 3; the NaN entries presumably act as wildcards
# in find_spectra (verify against analysis_utils).
label = np.array([6000, 3, np.nan, np.nan]).astype('float32')
# NOTE(review): with the column order above, fe/h is index 2 and alpha/h is
# index 3, so "sort by fe/h" disagrees with the 3 passed here unless
# sort_spectra counts columns from 1 - confirm which is intended.
spectra, labels = sort_spectra(*find(label), 3) # sort by fe/h
print(spectra.shape, labels.shape)

In [None]:
# Plot the selected spectra around the window, padded by 100 bins on each
# side, and mark both window edges.
plt.figure(figsize=(6, 2))
lower_edge, upper_edge = window[0], window[1]
plot_window(lambdas, spectra, window=(lower_edge - 100, upper_edge + 100))
for edge in (lower_edge, upper_edge):
    plt.axvline(edge, color='k')

In [None]:
# Pick one spectrum at random; the slice keeps the leading batch axis so
# `spectrum` has a length-1 first dimension.
i = np.random.randint(len(spectra))
spectrum, label = spectra[i:i+1, :, :], labels[i]
print(i, label)

In [None]:
# Plot the single selected spectrum around the window and mark the window edges.
plt.figure(figsize=(6, 2))
padded_window = (window[0] - 100, window[1] + 100)
plot_window(lambdas, spectrum, window=padded_window)
for edge in (window[0], window[1]):
    plt.axvline(edge, color='k')
plt.legend()

In [None]:
# Construct the model under the first GPU device scope, passing it the
# session, the hyperparameters, and the input stream.
with tf.device("/device:GPU:0"):
    m = model.model(sess, hps, input_stream)

In [None]:
# Restore the model from the saved checkpoint at this path.
m.restore('models/model-200716-234221')

In [None]:
# Encode the selected spectrum. encode returns the latent representation and a
# second value (`intermediate_zs`) that is passed back to decode later when
# the edited latents are turned into spectra.
latent_rep, intermediate_zs = m.encode(spectrum)
# Decode the unmodified latent representation (not referenced again in the
# cells shown here).
reconstruction = m.decode(latent_rep)

In [None]:
# Configure the model's peak remover for the chosen window and spectrum
# (m.remove_peak is called in the optimization loop below).
m.create_peak_remover(window, spectrum)

In [None]:
# Histories for the peak-removal loop: one gradient per step, and the latent
# representations starting from the original encoding. Re-running this cell
# resets both histories.
grads = []
latent_reps = [latent_rep]

In [None]:
# Run 100 peak-removal updates, each starting from the most recent latent
# representation, and record every new representation and its gradient.
for _step in tqdm(range(100)):
    updated_rep, gradient = m.remove_peak(latent_reps[-1], .01)
    grads.append(gradient)
    latent_reps.append(updated_rep)

In [None]:
# exploration analysis
# Flatten each stored latent representation / gradient into one row per step.
latent_reps_np = np.array(latent_reps).reshape((len(latent_reps), -1))
grads_np = np.array(grads).reshape((len(grads), -1))
# when plotting changes over time, plot around 10 things; clamp to at least 1
# because with fewer than 10 steps the stride would be 0, which range() rejects
print_freq = max(1, len(grads) // 10)

In [None]:
plt.figure(figsize=(6, 6))

# One entry per panel of the 3x2 grid: (position, curves, x label, y label).
# Panels 1-2 show averages over steps, 3-4 norms per step, and 5-6 every
# print_freq-th snapshot of the latent representation / gradient.
panels = [
    (1, [latent_reps_np.mean(axis=0)],
     'component position', 'latent rep (avg over steps)'),
    (2, [grads_np.mean(axis=0)],
     'component position', 'gradient (avg over steps)'),
    (3, [[np.linalg.norm(rep) for rep in latent_reps]],
     'step', 'norm of latent representation'),
    (4, [[np.linalg.norm(gradient) for gradient in grads]],
     'step', 'norm of gradient'),
    (5, [latent_reps_np[step] for step in range(0, len(latent_reps), print_freq)],
     'component position', 'latent rep over time'),
    (6, [grads_np[step] for step in range(0, len(grads), print_freq)],
     'component position', 'gradient over time'),
]

for position, curves, x_label, y_label in panels:
    plt.subplot(3, 2, position)
    for curve in curves:
        plt.plot(curve)
    plt.xlabel(x_label)
    plt.ylabel(y_label)

plt.tight_layout()

In [None]:
# Pick six evenly spaced optimization steps that always include the first (0)
# and the last stored representation, so the last decoded spectrum really is
# the final result. For 101 stored representations this gives 0, 20, ..., 100,
# the same as the previous integer stride. Unlike that stride (len // 5), this
# cannot become 0 when fewer than five representations are stored; in that
# case some indices simply repeat.
indices = np.linspace(0, len(latent_reps) - 1, 6).round().astype(int)
#indices = [25, 50, 75, 100]
colors = plt.cm.viridis(np.linspace(0, 1, len(indices)))

In [None]:
plt.figure(figsize=(6, 2))

# Decode the latent representation at each selected step, passing along the
# intermediate values obtained when the original spectrum was encoded.
intermediate_spectra = [m.decode(latent_reps[index], intermediate_zs) for index in indices]

# Original spectrum in black ('og'), then one curve per selected step.
plot_window(lambdas, [spectrum, *intermediate_spectra], 
            window=(window[0] - 100, window[1] + 100),
            colors=['k', *colors], labels=['og', *indices])

plt.axvline(window[0], color='k')
plt.axvline(window[1], color='k')
# Raw string: '\A' is not a valid Python escape sequence, and newer Python
# versions warn about it in ordinary string literals. The label text itself
# is unchanged.
plt.xlabel(r'Wavelength $[\AA]$')
plt.ylabel('Normalized flux')
plt.legend()

In [None]:
# Show the current value of `label` (the label of the randomly selected
# spectrum, if the cells were run in order).
print(label)

In [None]:
# Keep the first three label components of the selected spectrum and leave the
# last one as NaN. Built as a float32 array, like the other labels passed to
# `find` in this notebook, rather than as a plain Python list.
comparison_label = np.array([label[0], label[1], label[2], np.nan]).astype('float32')
comparison_spectra, comparison_labels = sort_spectra(*find(comparison_label), 3)
comparison_spectra.shape, comparison_labels.shape

In [None]:
colors = plt.cm.viridis(np.linspace(0, 1, len(comparison_spectra)))

# The comparison set is queried with label columns 0-2 pinned to the selected
# spectrum and only column 3 left as NaN, and it is sorted by 3. The legend
# therefore shows column 3; column 2 would be the same value for every curve
# (assuming NaN acts as a wildcard in find_spectra - confirm).
plot_window(lambdas, comparison_spectra, 
            window=(window[0] - 100, window[1] + 100),
            colors=colors,
            labels=comparison_labels[:, 3])

# Overlay the last decoded (peak-removed) spectrum in red.
plot_window(lambdas, [intermediate_spectra[-1]],
            window=(window[0] - 100, window[1] + 100),
            colors=['r'], labels=['peak removed'])

plt.axvline(window[0], color='k')
plt.axvline(window[1], color='k')
plt.legend()

In [None]:
# NOTE(review): bare element-wise comparison whose result is only displayed,
# never stored or used; looks like leftover debugging output.
label < 4

In [None]:
# Scan the data streams for spectra whose get_stats comparison against the
# last decoded (peak-removed) spectrum falls below the threshold.
target_spectrum = intermediate_spectra[-1]
similar_spectra = []
similar_labels = []
for iteration in tqdm(range(hps.n_batches)):
    # Fetch spectra and labels in one sess.run call so both come from the same
    # step (separate calls could desynchronize them if the streams share an
    # iterator - depends on data_loader). Batch-local names are used so the
    # notebook-level `spectra` / `labels` from earlier cells are not overwritten.
    batch_spectra, batch_labels = sess.run([input_stream, label_stream])
    for row, candidate in enumerate(batch_spectra):
        # Keep the candidate if any value returned by get_stats is below .032.
        if np.any(get_stats(target_spectrum, candidate) < .032):
            similar_spectra.append(candidate)
            similar_labels.append(batch_labels[row])
    # Stop once more than 10 matches were collected. This is checked once per
    # batch, so the final count can be larger than 11.
    if len(similar_spectra) > 10:
        break
similar_spectra = np.array(similar_spectra)
similar_labels = np.array(similar_labels)
print(similar_spectra.shape, similar_labels.shape)

In [None]:
# Print the get_stats result of each collected spectrum against the last
# decoded (peak-removed) spectrum.
final_spectrum = intermediate_spectra[-1]
for candidate in similar_spectra:
    print(get_stats(candidate, final_spectrum))

In [None]:
# Show at most 10 matches, and cut the colors and labels to the same subset so
# every plotted curve has exactly one color and one legend entry (the search
# above can collect more than 10 spectra).
shown_spectra = similar_spectra[:10]
shown_labels = similar_labels[:10]
colors = plt.cm.viridis(np.linspace(0, 1, len(shown_spectra)))

plot_window(lambdas, shown_spectra, 
            window=(window[0] - 700, window[1] + 100),
            colors=colors,
            labels=shown_labels)

# Overlay the last decoded (peak-removed) spectrum in red.
plot_window(lambdas, [intermediate_spectra[-1]],
            window=(window[0] - 700, window[1] + 100),
            colors=['r'], labels=['peak removed'])

plt.axvline(window[0], color='k')
plt.axvline(window[1], color='k')
plt.legend()

In [None]:
# Debugging helpers: smoothing and numerical differentiation of spectra.

In [None]:
def blur(s, sigma=3):
    """Smooth `s` with a Gaussian filter after dropping its singleton axes.

    Args:
        s: array holding the spectrum; axes of length 1 are squeezed away.
        sigma: standard deviation of the Gaussian kernel, in bins.

    Returns:
        The filtered array, with the squeezed shape.
    """
    squeezed = s.squeeze()
    return gaussian_filter(squeezed, sigma=sigma)

def differentiate(s):
    """Central-difference derivative of a spectrum on a unit-length axis.

    The input is squeezed to 1-D and the bin spacing is taken to be
    1 / len(s), so each interior sample k evaluates to
    (s[k+1] - s[k-1]) * len(s) / 2.

    np.convolve reverses its kernel, so the kernel must be written as
    [+bins/2, 0, -bins/2] to obtain s[k+1] - s[k-1]. The earlier kernel
    [-bins/2, 0, +bins/2] returned the negated derivative. Applying the
    function twice, as the second-derivative plots do, gives the same
    result as before because the two sign flips cancelled.

    Args:
        s: array holding the spectrum; axes of length 1 are squeezed away.

    Returns:
        1-D array of the same length as the squeezed input. The first and
        last samples are computed against zero padding (mode='same') and
        are not meaningful derivatives.
    """
    samples = s.squeeze()
    half_bins = len(samples) / 2
    return np.convolve(samples, [half_bins, 0, -half_bins], mode='same')

In [None]:
plt.figure(figsize=(6, 2))
s = intermediate_spectra[-1]


def _scaled_second_derivative(x):
    """Apply blur then differentiate, twice, and divide by 5000000 for plotting."""
    return differentiate(blur(differentiate(blur(x)))) / 5000000


# Original spectrum, last decoded spectrum, and the scaled second derivative
# of each, over the window padded by 100 bins on each side.
plot_window(lambdas,
            [
                spectrum,
                s,
                #blur(s),
                #differentiate(blur(s)) / 750,
                #blur(differentiate(blur(s))) / 750,
                _scaled_second_derivative(spectrum),
                _scaled_second_derivative(s),
            ],
            window=[window[0]-100, window[1]+100])
for edge in (window[0], window[1]):
    plt.axvline(edge, color='k')