In [1]:
import os
# TensorFlow reads TF_CPP_MIN_LOG_LEVEL when its native runtime is loaded, so
# set it before importing tensorflow or any project module that may import it.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

from gan import Generator, Discriminator
from hierarchical_negbin import RecordGenerator
from peak_detector7 import Features, SignalHead, DeconvHead


import json
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

from matplotlib import cm
# `matplotlib.cm.get_cmap` was deprecated in matplotlib 3.7 and removed in 3.9;
# `pyplot.get_cmap(name, lut)` is the supported equivalent.
viridis = plt.get_cmap('viridis', 12)
cols = ["#926cb6", "#93b793", "#d31d00", "#ff900d", "#fefb03", "black"]

In [2]:
# Flag read by the training loop: while True, the running loss averages are
# (re)initialized from the current step instead of being updated.
first_run = True

In [3]:
# Every histogram is cropped / zero-padded to exactly this many bins.
MAXBINS = 400
# Entries per training step (max size is 64).
BATCHSIZE = 32

In [4]:
# Settings for the synthetic-record simulator. Tuples presumably give
# (low, high) sampling ranges and lists give choices — see
# hierarchical_negbin.RecordGenerator to confirm.
simulator_config = dict(
    n_obs=(1500, 100000),
    n_bins=(50, MAXBINS),
    n_meanings=(1, 3),
    noise_ratio=(0.1, 1.0),
    noise_dispersion=(0.05, 1.5),
    alpha_meanings=(1.0, 4.0),
    rounding=list(range(1, 17)),
    inner_mode_dist_tol=0.25,
    inner_sigma_ratio=3.0,
    max_sigma_to_bins_ratio=0.125,
    sigmas=(1.0, 50.0),
    trim_corners=True,
)
simulator = RecordGenerator(**simulator_config)

In [5]:
# Symbolic Keras inputs. Signals are (MAXBINS, 1) column vectors; the noise
# ratio is one value per sample, shaped (1, 1).
# NOTE(review): keep the construction order below unchanged — auto-generated
# layer names / weight order presumably have to match previously saved .h5 files.
geninput_noise = tf.keras.Input(shape=(MAXBINS, 1), dtype=tf.float32)
geninput_signal = tf.keras.Input(shape=(MAXBINS, 1), dtype=tf.float32)
geninput_noiseratio = tf.keras.Input(shape=(1, 1), dtype=tf.float32)
disinput = tf.keras.Input(shape=(MAXBINS, 1), dtype=tf.float32)

# The generator consumes [noise, signal pdf, noise ratio]; the discriminator
# scores a single (MAXBINS, 1) signal.
genoutput = Generator(ksize=5, filters=32, nblocks=16)([geninput_noise, geninput_signal, geninput_noiseratio])
disoutput = Discriminator(ksize=7, filters=32, nblocks=16)(disinput)

generator = tf.keras.Model(inputs=[geninput_noise, geninput_signal, geninput_noiseratio], outputs=genoutput)
discriminator = tf.keras.Model(inputs=disinput, outputs=disoutput)

# Smoother model (disabled): `smoother` is not defined while these lines stay
# commented out.
# inputs_smoother = tf.keras.Input(shape=(MAXBINS, 1), dtype=tf.float32)
# feats_smoother = Features(ksize=7, filters=32, nblocks=12)(inputs_smoother)
# signal, peaks = SignalHead(ksize=5, filters=32, nblocks_signal=2, nblocks_peaks=4)([feats_smoother, inputs_smoother])
# smoother = tf.keras.Model(inputs=inputs_smoother, outputs=[signal, peaks])

In [6]:
# Optionally resume the GAN from saved checkpoints (disabled: train from scratch).
# generator.load_weights("generator.h5")
# discriminator.load_weights("discriminator.h5")

# The smoother model only exists when its definition in the model cell is
# uncommented; guard the load so this cell does not raise NameError otherwise.
if "smoother" in globals():
    smoother.load_weights("tmp_back_to_two_smoother.h5")
    # smoother.load_weights("smoother.h5")

In [7]:
# generator.summary(line_length=150)

In [8]:
# discriminator.summary(line_length=150)

In [9]:
def standardize_bins(obs):
    """Center a 1-D sequence inside a zero-padded vector of length ``MAXBINS``.

    Sequences longer than ``MAXBINS`` are first cropped symmetrically, keeping
    the central ``MAXBINS`` values.

    Args:
        obs: 1-D sequence (list or array) of values; may be empty.

    Returns:
        (out, start): ``out`` is a length-``MAXBINS`` array with the dtype of
        ``np.asarray(obs)`` holding the (cropped) values at
        ``out[start:start + len]``; ``start`` is the offset of the first value.
    """
    obs = np.asarray(obs)
    if len(obs) > MAXBINS:
        crop = (len(obs) - MAXBINS) // 2
        obs = obs[crop:crop + MAXBINS]
    n = len(obs)
    start = MAXBINS // 2 - n // 2
    # Take the dtype from the whole array rather than type(obs[0]): works for
    # empty input and does not truncate floats that follow a leading int.
    out = np.zeros(MAXBINS, dtype=obs.dtype)
    out[start:start + n] = obs
    return out, start

    
def preprocess_input(obs):
    """Turn one histogram of counts into a network input.

    Assumes a single observation vector (possibly with multiple peaks). The
    counts are centered/cropped to ``MAXBINS`` bins by ``standardize_bins``,
    normalized to sum to one, and rescaled by ``sqrt(MAXBINS)``.

    Args:
        obs: 1-D sequence of counts for a single observation.

    Returns:
        Array of shape ``(MAXBINS, 1)``. An all-zero histogram yields an
        all-zero array instead of NaNs.
    """
    padded, _ = standardize_bins(obs)
    x = np.array(padded, dtype=np.float32)
    total = x.sum()
    # Guard against 0/0: NaNs here would poison every gradient of the step.
    if total != 0:
        x /= total
    xinput = x * np.sqrt(MAXBINS)
    return np.expand_dims(xinput, -1)


def preprocess_batch(batch):
    """Preprocess every observation of ``batch`` and stack them.

    Each element is passed through ``preprocess_input``; the results are
    stacked along a new leading batch axis.
    """
    return np.stack([preprocess_input(observation) for observation in batch])


def get_inputs(batch):
    """Build generator inputs whose bin counts match those of a real batch.

    For every observation in ``batch`` (only its length is used) this draws
    Student-t (df=2) noise and a simulated record with the same number of
    bins; each vector is centered in ``MAXBINS`` bins via ``standardize_bins``.

    Returns:
        (noises, signals, nr, modes): ``noises``, ``signals`` (the simulated
        pdfs) and ``modes`` (the simulated ``modes_onehot``) are stacked along
        axis 0 with a trailing channel axis; ``nr`` holds per-sample noise
        ratios drawn uniformly from [0.1, 0.5) with shape (len(batch), 1, 1).
    """
    noise_rows, signal_rows, mode_rows = [], [], []
    for n_bins in (len(observation) for observation in batch):
        raw_noise = np.random.standard_t(df=2, size=n_bins)
        noise_rows.append(standardize_bins(raw_noise)[0])
        record = simulator.generate(n_bins=n_bins)
        signal_rows.append(standardize_bins(record['pdf'])[0])
        mode_rows.append(standardize_bins(record['modes_onehot'])[0])

    noises, signals, modes = (
        np.stack(rows, 0)[..., np.newaxis]
        for rows in (noise_rows, signal_rows, mode_rows)
    )
    nr = np.random.uniform(0.1, 0.5, size=(len(batch), 1, 1))
    return noises, signals, nr, modes


def get_batch(file):
    """Load one batch file and return its parsed JSON content.

    Args:
        file: path to a JSON file.

    Returns:
        The decoded JSON value (entries are later read via ``entry['counts']``).
    """
    # JSON is UTF-8 by specification; don't depend on the platform's locale
    # encoding (e.g. cp1252 on Windows).
    with open(file, "r", encoding="utf-8") as f:
        return json.load(f)


def get_gan_inputs(batch):
    """Assemble one training step's inputs from a list of raw entries.

    Returns:
        (x, noises, pdfs, noiseratio, modes): the preprocessed real counts
        followed by the four generator inputs produced by ``get_inputs``.
    """
    counts = [entry['counts'] for entry in batch]
    real = preprocess_batch(counts)
    noise, pdf, noise_ratio, mode = get_inputs(counts)
    return real, noise, pdf, noise_ratio, mode

In [10]:
# Binary cross-entropy that expects raw logits (from_logits=True), shared by
# the generator and discriminator losses below.
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)


def discriminator_loss(real_output, fake_output):
    """GAN discriminator loss on logits.

    Real samples are scored against label 1 and generated samples against
    label 0; the two binary cross-entropy terms are summed.
    """
    loss_on_real = cross_entropy(tf.ones_like(real_output), real_output)
    loss_on_fake = cross_entropy(tf.zeros_like(fake_output), fake_output)
    return loss_on_real + loss_on_fake


def generator_loss(fake_output):
    """Non-saturating generator loss: BCE of the discriminator's logits on
    generated samples against the "real" label 1."""
    real_labels = tf.ones_like(fake_output)
    return cross_entropy(real_labels, fake_output)


# def smoother_loss(y, yhat):
#     # symmetric kl
#     x = - 0.5 * y * tf.math.log((yhat + 1e-10) / (y + 1e-10))
#     x += - 0.5 * yhat * tf.math.log((y + 1e-10) / (yhat + 1e-10))
#     x = tf.math.reduce_sum(x, 1)
#     loss = tf.math.reduce_mean(x)
#     return loss

# def peaks_loss(y, yhat):
#     # symmetric kl
#     x = - 0.5 * y * tf.math.log((yhat + 1e-10) / (y + 1e-10))
#     x += - 0.5 * yhat * tf.math.log((y + 1e-10) / (yhat + 1e-10))
#     x = tf.math.reduce_sum(x, 1)
#     loss = tf.math.reduce_mean(x)
#     return loss


# Adam for both networks; the generator's learning rate is 10x the
# discriminator's.
generator_optimizer = tf.keras.optimizers.Adam(learning_rate=5e-4)
discriminator_optimizer = tf.keras.optimizers.Adam(learning_rate=5e-5)
# smoother_optimizer = tf.keras.optimizers.Adam(learning_rate=5e-4)


# `tf.function` traces this step into a TensorFlow graph on its first call.
@tf.function
def train_step(real_signal, noise, pdf, nr, modes):
    """Run one simultaneous update of the generator and the discriminator.

    Args:
        real_signal: batch of preprocessed real histograms.
        noise: noise input of the generator.
        pdf: simulated clean-signal input of the generator.
        nr: per-sample noise ratios fed to the generator.
        modes: mode indicators; unused while the smoother branch is disabled.

    Returns:
        (gen_loss, disc_loss) as scalar tensors.
    """
    # Smoother variant: add `tf.GradientTape() as sm_tape` to this `with`.
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        fake_signal = generator([noise, pdf, nr], training=True)
        logits_real = discriminator(real_signal, training=True)
        logits_fake = discriminator(fake_signal, training=True)
        # smooth, peaks = smoother(fake_signal, training=True)

        gen_loss = generator_loss(logits_fake)
        disc_loss = discriminator_loss(logits_real, logits_fake)
        # sm_loss = smoother_loss(real_signal, smooth) + peaks_loss(modes, peaks)

    # Compute both gradients before applying either update.
    gen_vars = generator.trainable_variables
    disc_vars = discriminator.trainable_variables
    gen_grads = gen_tape.gradient(gen_loss, gen_vars)
    disc_grads = disc_tape.gradient(disc_loss, disc_vars)
    # sm_grads = sm_tape.gradient(sm_loss, smoother.trainable_variables)

    generator_optimizer.apply_gradients(zip(gen_grads, gen_vars))
    discriminator_optimizer.apply_gradients(zip(disc_grads, disc_vars))
    # smoother_optimizer.apply_gradients(zip(sm_grads, smoother.trainable_variables))

    # return gen_loss, disc_loss, sm_loss
    return gen_loss, disc_loss

In [11]:
# Every file in `fdir` is read by `get_batch` as one JSON batch of entries.
fdir = "../SmoothDoQ/doq_noun_batches/"
files = [os.path.join(fdir, fname) for fname in os.listdir(fdir)]

In [12]:
# Training configuration.
print_every = 5       # print the smoothed losses every N iterations
plot_every = 50       # plot one generated sample every N iterations
save_every = 250      # checkpoint both models every N iterations
lam = 0.01            # update rate of the exponential moving average of the losses
epochs = 10
entries_per_file = 64
batches_per_file = entries_per_file // BATCHSIZE
iters_per_epoch = len(files) * batches_per_file  # denominator of the progress report
np.random.shuffle(files)


for epoch in range(epochs):
    print(f"===== Epoch {epoch + 1}: =====")
    i = 0
    for file in files:
        file_data = get_batch(file)
        for b in range(batches_per_file):
            batch = file_data[b*BATCHSIZE:(b + 1)*BATCHSIZE]
            if not batch:
                # The file holds fewer than `entries_per_file` entries; an empty
                # batch would make np.stack fail in get_gan_inputs.
                break
            x, z, s, nr, modes = get_gan_inputs(batch)
            x_ = tf.constant(x, tf.float32)
            z_ = tf.constant(z, tf.float32)
            s_ = tf.constant(s, tf.float32)
            nr_ = tf.constant(nr, tf.float32)
            modes_ = tf.constant(modes, tf.float32)
#             gen_loss, disc_loss, sm_loss = train_step(x_, z_, s_, nr_, modes_)
            gen_loss, disc_loss = train_step(x_, z_, s_, nr_, modes_)

            # Record losses: seed the running averages with the first step's
            # values, then track them as an exponential moving average.
            if first_run:
                gen_loss_ = gen_loss.numpy()
                disc_loss_ = disc_loss.numpy()
#                 sm_loss_ = sm_loss.numpy()
                first_run = False
            else:
                gen_loss_ += lam * (gen_loss.numpy() - gen_loss_)
                disc_loss_ += lam * (disc_loss.numpy() - disc_loss_)
#                 sm_loss_ += lam * (sm_loss.numpy() - sm_loss_)

            # Print progress within the current epoch.
            if i % print_every == 0:
                fr = f"iter: {i}, completed: {100 * (i + 1) / iters_per_epoch:.2f}%"
                gl = f"gen_loss: {gen_loss_:.3f}"
                dl = f"disc_loss: {disc_loss_:.3f}"
#                 sl = f"smoother_loss: {sm_loss_:.3f}"
#                 msg = f"{fr}, {gl}, {dl}, {sl}"
                msg = f"{fr}, {gl}, {dl}"
                print(msg)

            # Plot the generator's output for the first sample against its pdf.
            if i % plot_every == 0:
                z0 = np.expand_dims(z[0, :, :], 0).astype(np.float32)
                s0 = np.expand_dims(s[0, :, :], 0).astype(np.float32)
                nr0 = np.expand_dims(nr[0, :, :], 0).astype(np.float32)
                pdf = np.squeeze(s0)
                fake = generator([z0, s0, nr0])
                fake = np.squeeze(fake.numpy())
                in_range = np.where(pdf > 0.0)[0]
                # +1 so the last bin with positive density is included.
                h = range(in_range[0], in_range[-1] + 1)
                plt.figure(figsize=(15, 4))
                plt.bar(h, fake[h], width=1.0, alpha=0.75)
                plt.plot(h, pdf[h], c="red")
                u = pdf[h]
                plt.ylim(0, 25 * np.mean(u[u > 0]))
                plt.title(f"noise ratio: {nr0[0,0,0]:.2f}")
                plt.show()

            # Checkpoint both models every `save_every` iterations.
            if (i + 1) % save_every == 0:
                generator.save_weights("generator.h5")
                discriminator.save_weights("discriminator.h5")
#                 smoother.save_weights("smoother.h5")

            i += 1

===== Epoch 1: =====


KeyboardInterrupt: 