In [None]:
import time
from datetime import datetime

import corner
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import tftables
from tqdm import tqdm_notebook as tqdm

import data_loader
import model_short as model
import toy_data_loader

%matplotlib inline
%load_ext autoreload
%autoreload 2

In [None]:
# Print NumPy arrays with 4 digits of precision and without scientific notation.
np.set_printoptions(suppress=True, precision=4)

In [None]:
class hps:
    # Hyperparameter namespace: plain class attributes, read (and later extended)
    # as `hps.<name>` by the rest of the notebook and passed to the model as-is.
    n_levels = 3  # number of splits
    depth = 3  # number of layers in revnet
    width = 16  # channels in revnet layers
    polyak_epochs = 1  # NOTE(review): presumably the Polyak-averaging horizon in epochs -- confirm in model_short
    beta1 = .9  # NOTE(review): was labeled "learning rate annealing factor"; presumably the optimizer's beta1 -- confirm in model_short
    weight_decay = 1  # NOTE(review): was labeled "learning rate annealing factor"; exact meaning is defined in model_short -- confirm
    lr = .001  # base learning rate
    n_data = 100000  # number of input spectra
    batch_size = 50  # number of spectra in a batch
    n_batches = n_data // batch_size  # whole batches per epoch (integer division, remainder dropped)
    n_bins = 2**12

In [None]:
# Single interactive TensorFlow session; it is handed to both the data loader and the model.
sess = tf.InteractiveSession()

In [None]:
# Select real or toy data by uncommenting the appropriate `create_data_loader` line.
# Real data must have n_data=8000, n_bins=40000.
#input_stream, initialize_input_stream, data_init = data_loader.create_data_loader(
input_stream, initialize_input_stream, data_init = toy_data_loader.create_data_loader(
    sess, hps.batch_size, hps.n_data, hps.n_bins
)

# Disabled alternative: load spectra from 'sample_short.npz', standardize each
# spectrum, and feed them through an in-notebook loader.
# It is kept as comments instead of a bare triple-quoted string: a string
# expression at the end of a cell is echoed as the cell's output.
#
# spectra = np.load('sample_short.npz')['spectra']
# sqrt = np.sqrt(spectra)
#
# # add noise
# #sums = spectra.sum(axis=1)
# #sqrtsums = sqrt.sum(axis=1)
# #As = .02 * sums / (np.sqrt(2 / 3.14) * sqrtsums)
# #noise = np.random.normal(scale=(np.repeat(As[:, np.newaxis], hps.n_bins, axis=1) * sqrt))
# #print((np.abs(noise).sum(axis=1) / spectra.sum(axis=1)))
#
# scaled_spectra = spectra / spectra.std(axis=1)[:, np.newaxis]
# #scaled_spectra = (spectra + noise) / (spectra + noise).std(axis=1)[:, np.newaxis]
# centered_spectra = scaled_spectra - scaled_spectra.mean(axis=1)[:, np.newaxis]
# #normalized_spectra = spectra / np.max(spectra, axis=1)[:, np.newaxis]
#
# def create_data_loader(sess, data, batch_size):
#     placeholder_data = tf.compat.v1.placeholder(tf.float32, data.shape)
#     dataset = tf.data.Dataset.from_tensor_slices(placeholder_data)
#     dataset = dataset.batch(batch_size)
#     iterator = dataset.make_initializable_iterator()
#     input_stream = iterator.get_next()
#
#     def initialize_input_stream():
#         sess.run(iterator.initializer, feed_dict={placeholder_data: data})
#
#     initialize_input_stream()
#     data_init = sess.run(input_stream)
#     return input_stream, initialize_input_stream, data_init
#
# input_stream, initialize_input_stream, data_init = create_data_loader(
#     sess, centered_spectra[:, :, np.newaxis], hps.batch_size
# )

In [None]:
# Report the shape of the initial batch and overlay its first five spectra.
print(data_init.shape)
fig, ax = plt.subplots(figsize=(15, 5))
for sample_spectrum in data_init[:5]:
    ax.plot(sample_spectrum)

In [None]:
# Construct the model under the first GPU's device scope. It receives the session,
# the hyperparameters, the input stream and the initial batch.
with tf.device("/device:GPU:0"):
    m = model.model(sess, hps, input_stream, data_init)

In [None]:
# Reset the training bookkeeping used by the training and plotting cells.
n_processed = 0  # spectra consumed so far; drives the learning-rate warmup
training_results, lrs = [], []  # per-batch training results and learning rates

In [None]:
# Settings for this training run; note that hps.lr is overwritten here.
hps.epochs = 100
hps.epochs_warmup = .01
hps.lr = .0001

# Number of processed spectra after which the learning rate reaches hps.lr.
warmup_samples = hps.batch_size * hps.n_batches * hps.epochs_warmup

for epoch in tqdm(range(1, hps.epochs + 1), desc='Epochs'):
    epoch_results = []
    initialize_input_stream()
    with tqdm(total=hps.n_batches) as pbar:
        pbar.set_description('Epoch ' + str(epoch))
        for _ in range(hps.n_batches):
            # Linear warmup in the number of processed spectra, capped at the base rate.
            lr = hps.lr * min(1., n_processed / warmup_samples)
            batch_result = m.train(lr)
            epoch_results.append(batch_result)
            training_results.append(batch_result)
            lrs.append(lr)
            n_processed += hps.batch_size
            # The progress bar shows the running mean over the current epoch.
            pbar.set_postfix(lr=lr, loss=np.mean(epoch_results))
            pbar.update()

In [None]:
# Per-batch training results (top) and learning rate (bottom), both against epochs.
epoch_axis = np.linspace(0, len(training_results) / hps.n_batches, len(training_results))

fig, (loss_ax, lr_ax) = plt.subplots(2, 1, figsize=(12, 10))

loss_ax.plot(epoch_axis, training_results)
#training_results_per_epoch = np.reshape(training_results, [-1, hps.n_batches]).mean(axis=1)
#loss_ax.plot(np.arange(0+hps.n_batches/2, len(training_results)+hps.n_batches/2, hps.n_batches), training_results_per_epoch)
#loss_ax.set_yscale('symlog')
loss_ax.set_ylim(-5, -4)  # fixed zoom range on the y axis
loss_ax.set_xlabel('epochs')

lr_ax.plot(epoch_axis, lrs)
lr_ax.set_xlabel('epochs')

In [None]:
# Pick a random index within the batch size and keep that spectrum from the
# initial batch as a length-1 slice, so the leading axis is preserved.
i = np.random.randint(0, hps.batch_size)
print(i)
spectrum = data_init[i:i + 1]

In [None]:
# Round-trip the selected spectrum through the model, then print mean/std of the
# latent representation and of the reconstruction, and the mean squared error.
latent_rep = m.encode(spectrum)
reconstruction = m.decode(latent_rep)
for stage_output in (latent_rep, reconstruction):
    print(stage_output.mean(), stage_output.std())
print(np.square(spectrum - reconstruction).mean())

In [None]:
# One curve per channel (last axis) of the first latent representation.
for channel_trace in latent_rep[0].T:
    plt.plot(channel_trace)

In [None]:
# (start, stop) index range used to zoom the reconstruction plots.
# The commented pair is an alternative range, presumably for the longer real spectra -- confirm.
window = (1850, 2200) #(12000, 14000)

In [None]:
# Reconstruction vs. input: full range with the window marked (top), the zoomed
# window (middle), and the reconstruction-minus-input residual in the window (bottom).
lo, hi = window
window_x = range(lo, hi)
reconstruction_flat = np.squeeze(reconstruction)
spectrum_flat = np.squeeze(spectrum)
residual_flat = np.squeeze(reconstruction - spectrum)

fig, (full_ax, zoom_ax, residual_ax) = plt.subplots(3, 1, figsize=(10, 7))

full_ax.plot(reconstruction_flat)
full_ax.plot(spectrum_flat)
for edge in (lo, hi):
    full_ax.axvline(edge)

zoom_ax.plot(window_x, reconstruction_flat[lo:hi])
zoom_ax.plot(window_x, spectrum_flat[lo:hi])

residual_ax.plot(window_x, residual_flat[lo:hi])

In [None]:
# Encode the whole input stream batch by batch into a preallocated array, then
# merge the (batch, item) axes into a single sample axis.
latent_reps = np.empty([hps.n_batches, hps.batch_size, latent_rep.shape[1], latent_rep.shape[2]])

# Running totals for the progress-bar statistics. Taking mean/std of the whole
# `latent_reps` array would include the still-uninitialized np.empty slots and
# would rescan the full array on every iteration.
running_sum = 0.0
running_sq_sum = 0.0
n_values = 0

initialize_input_stream()
with tqdm(total=hps.n_batches) as pbar:
    for batch_idx in range(hps.n_batches):
        data = sess.run(input_stream)
        latent_reps[batch_idx] = m.encode(data)

        encoded = latent_reps[batch_idx]
        running_sum += encoded.sum()
        running_sq_sum += np.square(encoded).sum()
        n_values += encoded.size
        mean = running_sum / n_values
        # Clamp at zero: rounding can push E[x^2] - mean^2 slightly negative.
        std = np.sqrt(max(running_sq_sum / n_values - mean**2, 0.0))
        pbar.set_postfix(mean=mean, std=std)
        pbar.update()

latent_reps = latent_reps.reshape(hps.n_data, latent_rep.shape[1], latent_rep.shape[2])

In [None]:
# Mean latent representation over all samples, drawn for the first four channels.
mean_latent = latent_reps.mean(axis=0)
for channel_idx in range(4):
    plt.plot(mean_latent[:, channel_idx])

In [None]:
# Look at `components` consecutive positions centred on the middle of axis 1,
# in channel 0, and print their per-position mean and std over all samples.
components = 16
start_position = int(latent_reps.shape[1] / 2 - components / 2)
central_positions = slice(start_position, start_position + components)
central_channel0 = latent_reps[:, central_positions, 0]

print(latent_reps.shape)
print(central_channel0.mean(axis=0))
print(central_channel0.std(axis=0))

In [None]:
# Corner plot of the central positions, with the channel axis summed out.
central_summed = latent_reps[:, start_position:start_position + components].sum(axis=2)
figure = corner.corner(central_summed)

# View the figure's axes as a components x components grid and draw green
# zero lines in every panel strictly below the diagonal.
axes = np.array(figure.axes).reshape((components, components))
for yi, xi in zip(*np.tril_indices(components, k=-1)):
    panel = axes[yi, xi]
    panel.axvline(0, color="g")
    panel.axhline(0, color="g")

In [None]:
# Draw a standard-normal sample with the same shape as a single latent representation.
# No seed is set, so the sample differs on every run.
sampled_latent_rep = np.random.normal(size=latent_rep.shape)
# Optional: overwrite individual channels with the dataset-mean latent values.
#sampled_latent_rep[0, :, 0] = latent_reps.mean(axis=0)[:, 0]
#sampled_latent_rep[0, :, 1] = latent_reps.mean(axis=0)[:, 1]
#sampled_latent_rep[0, :, 2] = latent_reps.mean(axis=0)[:, 2]
#sampled_latent_rep[0, :, 3] = latent_reps.mean(axis=0)[:, 3]

In [None]:
# First four channels of the sampled latent representation.
for channel_idx in range(4):
    plt.plot(sampled_latent_rep[0, :, channel_idx])

In [None]:
# Decode the sampled latent representation and plot the result with singleton axes removed.
plt.plot(np.squeeze(m.decode(sampled_latent_rep)))

In [None]:
# Build a checkpoint path stamped with the current local time (yymmdd-HHMMSS).
# `datetime` comes from the import cell at the top of the notebook.
model_filename = 'models/model-{}'.format(datetime.now().strftime('%y%m%d-%H%M%S'))
print(model_filename)

In [None]:
# Save the model under the timestamped path built in the previous cell.
m.save(model_filename)

In [None]:
# Restore a specific, previously saved checkpoint.
# NOTE(review): the path is hard-coded; under "Run All" this executes right after the
# save above and presumably replaces the freshly trained state -- confirm this is intended.
m.restore('models/model-200304-081901')