In [None]:
import corner
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import tftables
import time
from tqdm import tqdm_notebook as tqdm

import data_loader
import model_short as model
import toy_data_loader

%matplotlib inline
# Re-import edited project modules (data_loader, model_short, toy_data_loader)
# automatically before each cell executes, so changes to them are picked up live.
%load_ext autoreload
%autoreload 2

In [None]:
# Print arrays compactly: 4 decimals, no scientific notation for small values.
np.set_printoptions(precision=4, suppress=True)

In [None]:
class hps:
    """Hyperparameter namespace passed to `model.model` and read by the cells below.

    Used as a plain attribute bag: the class object itself carries the
    attributes and is never instantiated.
    """
hps.n_levels = 3  # number of splits
hps.depth = 3  # number of layers in revnet
hps.width = 16  # channels in revnet layers
hps.polyak_epochs = 1  # NOTE(review): presumably epochs of Polyak averaging — confirm in model_short
hps.beta1 = .9  # NOTE(review): presumably the Adam beta1 (momentum) term — confirm in model_short
hps.weight_decay = 1  # NOTE(review): meaning defined in model_short; not a learning-rate annealing factor
hps.lr = .001  # base (peak) learning rate; the training loop ramps up to it during warmup
hps.n_data = 16000  # number of input spectra
hps.batch_size = 50  # number of spectra in a batch
# Whole batches per epoch. Loops that run n_batches times never consume a
# trailing partial batch of (n_data % batch_size) spectra.
hps.n_batches = hps.n_data // hps.batch_size
hps.n_bins = 2**15  # number of bins per spectrum

In [None]:
# TF1 interactive session; it is passed to both the data loader and the model below.
sess = tf.InteractiveSession()

In [None]:
# Select the data source with a flag instead of commenting lines in and out.
# The real-data loader requires n_data=8000 and n_bins=40000, which differ from
# the toy defaults configured in `hps`, so that combination is checked up front.
USE_REAL_DATA = False
spectra_loader = data_loader if USE_REAL_DATA else toy_data_loader
if USE_REAL_DATA and (hps.n_data, hps.n_bins) != (8000, 40000):
    raise ValueError(
        'real data requires hps.n_data=8000 and hps.n_bins=40000, got '
        'n_data={} and n_bins={}'.format(hps.n_data, hps.n_bins)
    )
input_stream, initialize_input_stream, data_init = spectra_loader.create_data_loader(
    sess, hps.batch_size, hps.n_data, hps.n_bins
)

In [None]:
# Show the shape of the initialization data and overlay its first five spectra.
print(data_init.shape)
first_five = data_init[:5]
fig, ax = plt.subplots(figsize=(15, 5))
for example in first_five:
    ax.plot(example)

In [None]:
# Build the model with its ops placed on the first GPU.
# NOTE(review): hard placement on GPU:0 may fail on CPU-only machines unless
# soft placement is enabled — confirm for the target environment.
MODEL_DEVICE = "/device:GPU:0"
with tf.device(MODEL_DEVICE):
    m = model.model(sess, hps, input_stream, data_init)

In [None]:
%%time
# Time a single m.train call at lr=.001 (the same value as hps.lr).
# NOTE(review): presumably this also performs a real training step at the full
# learning rate, ahead of the warmup schedule used in the training loop — confirm intended.
m.train(.001)

In [None]:
# Training-loop hyperparameters.
hps.epochs = 20
hps.epochs_warmup = 1  # epochs over which lr ramps linearly from 0 to hps.lr
hps.print_freq = 10  # NOTE(review): not read in this cell


def warmup_lr(base_lr, n_processed, n_warmup):
    """Linear learning-rate warmup.

    Returns base_lr * min(1, n_processed / n_warmup), i.e. a ramp from 0 to
    `base_lr` over the first `n_warmup` examples. When `n_warmup <= 0`, warmup
    is disabled and `base_lr` is returned; this avoids a ZeroDivisionError.
    """
    if n_warmup <= 0:
        return base_lr
    return base_lr * min(1., n_processed / n_warmup)


n_processed = 0  # examples seen so far across all epochs; drives the warmup
n_warmup = hps.batch_size * hps.n_batches * hps.epochs_warmup  # examples in the warmup period

for epoch in tqdm(range(1, hps.epochs + 1), desc='Epochs'):
    train_results = []  # m.train outputs for the current epoch only
    initialize_input_stream()  # presumably rewinds the input pipeline for a new pass
    # The description is fixed per epoch, so set it once when the bar is created.
    with tqdm(total=hps.n_batches, desc='Epoch ' + str(epoch)) as pbar:
        for _ in range(hps.n_batches):
            lr = warmup_lr(hps.lr, n_processed, n_warmup)
            train_results.append(m.train(lr))
            n_processed += hps.batch_size
            pbar.set_postfix(lr=lr, loss=np.mean(train_results))
            pbar.update()

In [None]:
# Pick one spectrum from data_init to inspect. The index is drawn from the
# actual number of rows in data_init: hps.batch_size need not match that length,
# and an index past the end would silently yield an empty slice.
# The i:i+1 slice keeps a leading batch axis of size 1 for m.encode.
i = np.random.randint(0, data_init.shape[0])
spectrum = data_init[i:i+1, :, :]
print(i)

In [None]:
# Round-trip the chosen spectrum through the model (encode, then decode) and
# report summary statistics of its latent representation.
latent_rep = m.encode(spectrum)
reconstruction = m.decode(latent_rep)
latent_mean, latent_std = latent_rep.mean(), latent_rep.std()
print(latent_mean, latent_std)

In [None]:
# One curve per channel (last axis) of the single encoded spectrum.
n_latent_channels = latent_rep.shape[-1]
for ch in range(n_latent_channels):
    plt.plot(latent_rep[0, :, ch])

In [None]:
# Overlay the model's reconstruction on the input spectrum. The legend tells
# the curves apart; the trailing ';' suppresses the artist repr.
plt.plot(np.squeeze(reconstruction), label='reconstruction')
plt.plot(np.squeeze(spectrum), label='input')
plt.xlabel('bin')
plt.legend();

In [None]:
# Zoomed comparison around the middle 5% of the bins.
#   top:    reconstruction (clipped to [0, 1]) overlaid on the input
#   bottom: residual, clipped reconstruction minus input
# NOTE(review): the clip assumes spectra lie in [0, 1] — confirm for real data.
# A fixed window of bins 12000-14000 was previously kept here as commented-out code.
zoom = (hps.n_bins * .475, hps.n_bins * .525)
recon_clipped = np.squeeze(np.clip(reconstruction, 0, 1))
target = np.squeeze(spectrum)

fig, (ax_top, ax_bottom) = plt.subplots(2, 1, figsize=(15, 10))
ax_top.plot(recon_clipped, label='reconstruction (clipped)')
ax_top.plot(target, label='input')
ax_top.set_xlim(*zoom)
ax_top.legend()

ax_bottom.plot(recon_clipped - target)
ax_bottom.set_xlim(*zoom)
ax_bottom.set(xlabel='bin', ylabel='reconstruction - input');

In [None]:
# Collect channel 0 of the latent representation, at every position, for every
# spectrum in one pass over the data; the corner plot below uses a window of
# these positions.
# Batches are gathered in a list and concatenated once. Growing the array with
# np.append (and recomputing mean/std over all of it) every batch is quadratic.
# Taking the width from the encoder output also avoids a hard-coded width
# (previously 2500), which np.append requires to match exactly.
latent_chunks = []
n_values = 0
value_sum = 0.
value_sq_sum = 0.
initialize_input_stream()
with tqdm(total=hps.n_batches) as pbar:
    for _ in range(hps.n_batches):
        batch = sess.run(input_stream)
        # Use its own name so the single-spectrum `latent_rep` inspected above is not clobbered.
        batch_latent = m.encode(batch)[:, :, 0]  # all positions of channel 0
        latent_chunks.append(batch_latent)
        # Running float64 moments for the progress display, O(batch) per step.
        batch_latent64 = np.asarray(batch_latent, dtype=np.float64)
        n_values += batch_latent64.size
        value_sum += batch_latent64.sum()
        value_sq_sum += np.square(batch_latent64).sum()
        running_mean = value_sum / n_values
        running_std = np.sqrt(max(value_sq_sum / n_values - running_mean ** 2, 0.))
        pbar.set_postfix(mean=running_mean, std=running_std)
        pbar.update()
latent_reps = (np.concatenate(latent_chunks, axis=0) if latent_chunks
               else np.empty((0, 0)))

In [None]:
# Window of latent positions (of channel 0) examined in the cells below:
# positions start_position .. start_position + components - 1.
start_position = 700
components = 8
window = slice(start_position, start_position + components)
selected = latent_reps[:, window]
print(latent_reps.shape)
print(selected.mean(axis=0))
print(selected.std(axis=0))

In [None]:
# Corner plot of the selected latent positions, each axis limited to [-.5, .5],
# with green crosshairs at 0 on every lower-triangle panel.
subset = latent_reps[:, start_position:start_position + components]
figure = corner.corner(subset, range=[(-.5, .5)] * components)

panel_grid = np.array(figure.axes).reshape((components, components))
lower_triangle = [(row, col) for row in range(components) for col in range(row)]
for row, col in lower_triangle:
    ax = panel_grid[row, col]
    ax.axvline(0, color="g")
    ax.axhline(0, color="g")

In [None]:
# Histogram of the first corner-plot position (start_position of channel 0),
# replacing the previously hard-coded 700.
# NOTE(review): bin edges are asymmetric (an edge at -.25 but none at +.25), so
# bar heights on either side of 0 cover different widths — confirm intended.
plt.hist(latent_reps[:, start_position], bins=[-1.5, -1, -.5, -.25, 0, .5, 1, 1.5])
plt.xlabel('latent value at position {}'.format(start_position))
plt.ylabel('count');

In [None]:
# Presumably persists the model state under the name 'test'.
# NOTE(review): how 'test' maps to files on disk is defined in model_short — confirm;
# a more descriptive name would avoid overwriting between experiments.
m.save('test')

In [None]:
# Presumably reloads the model state previously saved under the name 'test'.
# NOTE(review): must match the name used with m.save — confirm the on-disk layout in model_short.
m.restore('test')