# VAE on 2D toy datasets

1. Data generation
2. Model training
3. Visual inspection of model-generated vs. actual data
4. Building a classifier to distinguish original from generated data

Reference: https://github.com/didriknielsen/survae_flows

In [None]:
# IPython autoreload: with mode 2, imported modules (e.g. the local `gen` package)
# are reloaded before each cell executes, so edits show up without restarting.
%load_ext autoreload
%autoreload 2

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

from fastai.tabular.all import *
import ipywidgets as widgets

from gen import utils, data, vae

In [None]:
n = 4_200
y_col = 'target'
all_train_data = (
    data.DataGenerator.generate('gaussian', n)
    # Random 0/1 labels: TabularPandas below is built with `y_names=y_col`, so the
    # frame needs a target column, but these labels carry no signal.
    .assign(**{y_col: np.random.choice([0, 1], size=n)})
)
all_train_data.head()

In [None]:
fig, ax = plt.subplots()
# Raw (un-normalized) generated points; the low alpha makes dense regions stand out.
ax.scatter(
    x=all_train_data['x_0'],
    y=all_train_data['x_1'],
    alpha=0.1,
)
plt.show()

In [None]:
# Random 80/20 train/validation split of the rows.
splits = RandomSplitter(valid_pct=0.2)(all_train_data)

# Continuous inputs: every column except an (optional) id column and the dummy target.
original_features = L([col for col in all_train_data.columns if col not in ('id', y_col)])

to = TabularPandas(
    all_train_data,
    procs=[FillMissing, Normalize],
    cont_names=original_features,
    y_names=y_col,
    splits=splits,
)

bs = 256
kld_weight = 0.05  # weight handed to vae.VAE_Loss in the next cell
dls = to.dataloaders(bs=bs)

In [None]:
# n_in follows the number of continuous features; n_h / n_z are presumably the
# hidden width and latent size of gen.vae.VAE -- confirm there.
model = vae.VAE(n_in=len(original_features), n_h=200, n_z=2)
# Pass the weight by keyword (as run_stuff does) so it cannot silently bind to a
# different positional parameter of VAE_Loss.
loss_func = vae.VAE_Loss(kld_weight=kld_weight)

In [None]:
%%time
# Bundle DataLoaders, model and loss into a fastai Learner (cell timed by %%time,
# which must remain the first line of the cell).
learn = Learner(dls, model, loss_func=loss_func)

In [None]:
# Learning-rate range test; `lrs.valley` is used as lr_max in the training cell.
lrs = learn.lr_find()
# Bare expression so the notebook displays the suggested learning rate(s).
lrs

In [None]:
# One-cycle training for 10 epochs, peaking at the lr_find "valley" suggestion.
learn.fit_one_cycle(10, lr_max=lrs.valley)

In [None]:
%%time
(ori, rec, mu, var), _ = learn.get_preds(ds_idx=1)

In [None]:
# Compare validation inputs (`ori`) with the model's outputs for them (`rec`,
# presumably reconstructions -- not samples drawn from the prior). Presumably fits
# a classifier to tell the two apart (step 4 in the intro); confirm in gen.utils.
utils.check_identifiability_of_generated_data(ori, rec, original_features)

## Clicking through the data patterns interactively

In [None]:
patterns = ['twospirals', 'twomoons', 'sign', 'abs', 'sinewave', 'crescentcube', 'crescent', 'gaussian', 'checkerboard']

pattern = widgets.Dropdown(description='pattern', options=patterns, value='checkerboard')
n_data = widgets.IntText(description='data points', value=4_200)
n_epoch = widgets.IntText(description='epochs', value=7)
# Deliberately NOT named `kld_weight`: that global holds the float used by the
# training cells above, and rebinding it to a widget would break re-running them.
kld_weight_input = widgets.FloatText(description='kld_weight', value=1.)
ui = widgets.VBox([pattern, n_data, n_epoch, kld_weight_input])


def run_stuff(p, n, epochs, kld_weight, bs=256, n_h=200, n_z=2):
    """Generate a toy dataset, train a VAE on it and inspect the model's outputs.

    Repeats the pipeline of the cells above for an arbitrary pattern so that it
    can be driven from the widgets.

    Args:
        p: pattern name passed to ``data.DataGenerator.generate``.
        n: number of data points to generate; must be positive.
        epochs: number of one-cycle training epochs; must be positive.
        kld_weight: weight passed to ``vae.VAE_Loss(kld_weight=...)``.
        bs: batch size of the DataLoaders.
        n_h: ``n_h`` argument of ``vae.VAE``.
        n_z: ``n_z`` argument of ``vae.VAE``.

    Prints ``loss_func.loss`` on the validation-split outputs and calls
    ``utils.check_identifiability_of_generated_data``; returns None.
    """
    # IntText accepts any integer; report bad input instead of failing deep inside
    # data generation / training.
    if n <= 0 or epochs <= 0:
        print(f'n ({n}) and epochs ({epochs}) must both be positive.')
        return

    # generating data; the random 0/1 target exists because TabularPandas is
    # given y_names (presumably ignored by VAE_Loss -- confirm)
    y_col = 'target'
    all_train_data = data.DataGenerator.generate(p, n).assign(**{y_col: np.random.choice([0, 1], size=n)})

    # pre-processing data: random 80/20 split, normalized continuous features
    splits = RandomSplitter(valid_pct=.2)(all_train_data)
    original_features = L([c for c in all_train_data.columns if c not in ('id', y_col)])
    to = TabularPandas(all_train_data, procs=[FillMissing, Normalize],
                       cont_names=original_features,
                       y_names=y_col,
                       splits=splits)
    dls = to.dataloaders(bs=bs)

    # setting up the model
    model = vae.VAE(n_in=len(original_features), n_h=n_h, n_z=n_z)
    loss_func = vae.VAE_Loss(kld_weight=kld_weight)

    # training
    learn = Learner(dls, model, loss_func=loss_func)
    lrs = learn.lr_find()
    learn.fit_one_cycle(epochs, lr_max=lrs.valley)

    # inspecting model outputs on the validation split (ds_idx=1)
    (ori, rec, mu, logvar), _ = learn.get_preds(ds_idx=1)
    print('validation loss:', loss_func.loss(rec, ori, mu, logvar))
    utils.check_identifiability_of_generated_data(ori, rec, original_features)


out = widgets.interactive_output(
    run_stuff,
    {'p': pattern, 'n': n_data, 'epochs': n_epoch, 'kld_weight': kld_weight_input},
)
display(ui, out)