In [None]:
import plotly.express as px
import plotly.graph_objects as go
from plotly.colors import n_colors
from scipy.stats import norm
import numpy as np
import pandas as pd
from functools import partial

In [None]:
def gaussian(xaxis, A, B, mu, sigma):
    """Scaled, offset Gaussian survival curve evaluated on ``xaxis``.

    Despite its name this is not a bell curve: it returns
    ``B * (P(X > xaxis) + A)`` for ``X ~ Normal(mu, sigma)``. The result is a
    smooth step going from ``B * (1 + A)`` (far left) to ``B * A`` (far right).

    Parameters
    ----------
    xaxis : float or array-like
        Points at which to evaluate the curve.
    A : float
        Additive offset applied inside the scaling.
    B : float
        Multiplicative scale.
    mu, sigma : float
        Location and scale of the underlying normal distribution.

    Returns
    -------
    float or numpy.ndarray
        Curve values, with the same shape as ``xaxis``.
    """
    # norm.sf is the survival function 1 - cdf, computed directly.
    # This keeps precision in the upper tail, where 1 - cdf cancels to 0.
    return B * (norm.sf(xaxis, loc=mu, scale=sigma) + A)

In [None]:
from doepy import build

In [None]:
def generate_dataset(num_samples=1000, levels=None, offset=1e-2, x_range=(-4, 4), num_points=50):
    """Build a synthetic dataset of ``gaussian`` curves from a Latin-hypercube design.

    Parameters
    ----------
    num_samples : int
        Number of parameter sets requested from doepy's ``space_filling_lhs``.
    levels : dict or None
        Mapping ``parameter name -> [low, high]`` passed to ``space_filling_lhs``.
        It must contain ``'B'``, ``'mu'`` and ``'sigma'``. It may also contain
        ``'A'``; in that case the offset is sampled instead of fixed at ``offset``.
        Defaults to ``{'B': [1, 3], 'mu': [-2, 2], 'sigma': [0.2, 4]}``.
    offset : float
        Value of ``A`` used when ``levels`` has no ``'A'`` key.
    x_range : tuple of float
        ``(start, stop)`` of the evaluation grid (inclusive).
    num_points : int
        Number of grid points (``np.linspace``'s default is 50).

    Returns
    -------
    pandas.DataFrame
        One row per sampled parameter set and one column per grid point.
    """
    if levels is None:
        levels = {'B': [1, 3], 'mu': [-2, 2], 'sigma': [0.2, 4]}
    xaxis = np.linspace(x_range[0], x_range[1], num_points)

    samples_params = build.space_filling_lhs(levels, num_samples=num_samples)

    # Look parameters up by column name rather than by position.
    # The curves then do not depend on the column order of the design table.
    curves = [
        gaussian(xaxis, params.get('A', offset), params['B'], params['mu'], params['sigma'])
        for params in samples_params.to_dict('records')
    ]
    return pd.DataFrame(curves)

In [None]:
import random

def random_split(values: pd.DataFrame, size: float, seed=None):
    """Randomly split ``values`` into a train part and a test part.

    Parameters
    ----------
    values : pandas.DataFrame
        Frame to split. Duplicate index labels are handled correctly.
    size : float
        Fraction of rows (0..1) that go into the train part. The row count is
        truncated with ``int(n_rows * size)``.
    seed : int or None
        With ``None`` (the default), the module-level ``random`` state is used.
        Otherwise a private ``random.Random(seed)`` gives a reproducible split.

    Returns
    -------
    tuple of (pandas.DataFrame, pandas.DataFrame)
        ``(train, test)``. Train rows come in sampled order; test rows keep
        their original order.

    Raises
    ------
    ValueError
        If ``int(n_rows * size)`` is negative or larger than the number of rows.
    """
    n_rows = values.shape[0]
    rng = random if seed is None else random.Random(seed)
    # Sample row *positions*, not index labels.
    # With duplicated labels, .loc would pull every matching row into train
    # and drop them all from test.
    train_pos = rng.sample(range(n_rows), int(n_rows * size))
    in_train = np.zeros(n_rows, dtype=bool)
    in_train[train_pos] = True
    return values.iloc[train_pos], values.iloc[~in_train]

In [None]:
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
from sklearn.metrics import mean_squared_error
from sklearn.pipeline import make_pipeline

In [None]:
from sklearn.cluster import KMeans
from sklearn.metrics import pairwise_distances_argmin_min

In [None]:
SEED = 42
# Seed the global RNGs so the dataset and the split reproduce under Restart & Run All.
# NOTE(review): assumes doepy's LHS sampler draws from numpy's global RNG — confirm.
np.random.seed(SEED)
random.seed(SEED)

# Full set of synthetic curves: one row per curve, one column per x position.
Y = generate_dataset()
# 50/50 split into training curves and held-out curves.
train, test = train_test_split(Y, test_size=0.5, random_state=SEED)

In [None]:
import tensorflow.keras as K

In [None]:
class GANMonitor(K.callbacks.Callback):
    """Keras callback that live-plots critic and generator losses in a FigureWidget.

    Every ``every``-th epoch (10 by default, counting Keras' 0-based epoch index),
    it appends one point to each of the first two traces of ``figure_widget``:
    ``logs['d_loss']`` goes to the first trace and ``logs['g_loss']`` to the second.

    The x value is an update counter (1, 2, 3, ...) continuing from the last point
    already in the figure. Repeated ``fit`` calls therefore extend the same curves.
    """

    def __init__(self, figure_widget, every=10):
        # Initialise the Keras Callback base class so its attributes are set up.
        super().__init__()
        self.figure_widget = figure_widget
        self.every = every

    def on_epoch_end(self, epoch, logs=None):
        if epoch % self.every != 0:
            return
        # Assumes the model reports 'd_loss' and 'g_loss' in its epoch logs.
        with self.figure_widget.batch_update():
            critic_trace, generator_trace = self.figure_widget.data
            xs = list(critic_trace.x)
            xs.append(xs[-1] + 1 if xs else 1)

            critic_trace.x = xs
            critic_trace.y = list(critic_trace.y) + [logs['d_loss']]
            generator_trace.x = xs
            generator_trace.y = list(generator_trace.y) + [logs['g_loss']]


In [None]:
# Reload edited modules (e.g. the local GAN module below) automatically before each
# cell runs, so changes are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

from GAN import WGANGP

In [None]:
# Both networks use RMSprop with a small learning rate.
# (An Adam variant with beta_1=0.5, beta_2=0.9 was tried earlier.)
# RMSprop is taken from tensorflow.keras (imported as K).
generator_optimizer = K.optimizers.RMSprop(learning_rate=1e-5)
discriminator_optimizer = K.optimizers.RMSprop(learning_rate=1e-5)


# Hidden-layer widths; the reversed tuple is passed as the second layer spec.
layers = (100, 50, 25)
# NOTE(review): the positional args are presumably (n_features, latent_dim,
# layers, reversed layers) — confirm against GAN.WGANGP.
model = WGANGP(train.shape[1], 5, layers, layers[::-1],
               gp_weight=1e-2, activation='relu',
               critic_extra_steps=3,
               critic_dropout=None, generator_dropout=None)

model.compile(discriminator_optimizer, generator_optimizer)

BATCH_SIZE = 500

# Live loss plot filled in by GANMonitor.
# Trace 0 holds the critic loss, trace 1 the generator loss.
fig = go.FigureWidget(
    [
        go.Scatter(x=[], y=[], name='critic'),
        go.Scatter(x=[], y=[], name='generator')
    ],
    layout=go.Layout(
        title='WGAN-GP training losses',
        xaxis={'title': 'monitor update'},
        yaxis={'title': 'loss'},
    ),
)

callback = GANMonitor(fig)
fig

In [None]:
epochs = 10
# Train with the BATCH_SIZE chosen alongside the model definition instead of Keras' default.
# GANMonitor records only every 10th (0-based) epoch, so 10 epochs yields a
# single plotted point; raise `epochs` to see a loss curve.
model.fit(train.values, epochs=epochs, batch_size=BATCH_SIZE, callbacks=[callback])

In [None]:
# Draw 20 latent vectors from a seeded generator so the plotted samples are
# reproducible. Each generated curve is shown as its own trace.
latent_rng = np.random.default_rng(0)
generated = model.generator(latent_rng.standard_normal((20, model.latent_dim))).numpy()
fig = go.Figure([go.Scatter(y=row) for row in generated])
fig.update_layout(title='Generator samples (20 curves)',
                  xaxis_title='x index', yaxis_title='value', showlegend=False)
fig

In [None]:
# 20 real training curves (seeded selection) to compare visually against generator output.
real_samples = train.sample(20, random_state=0)
fig = go.Figure([go.Scatter(y=row) for row in real_samples.to_numpy()])
fig.update_layout(title='Training samples (20 curves)',
                  xaxis_title='x index', yaxis_title='value', showlegend=False)
fig

In [None]:
# Two latent vectors, matching the two real rows taken from Y for the gradient-penalty check.
random_latent_vectors = np.random.normal(size=(2, model.latent_dim))
fake = model.generator(random_latent_vectors)

In [None]:
# Rows labelled 0 through 1 (label slicing is inclusive) as a NumPy array of "real" samples.
first_rows = Y.loc[0:1]
tmp = first_rows.to_numpy()

In [None]:
# Evaluate the model's gradient penalty on the real rows in `tmp` and the generated `fake` samples.
# NOTE(review): the leading `2` is presumably the batch size — confirm against WGANGP.gradient_penalty.
model.gradient_penalty(2, tmp, fake)