# Normalizing flow on 2D toy data sets

1. Data generation
2. Model training
3. Visual inspection of generated vs. actual data
4. Building a classifier to distinguish original from generated data

Based on: https://github.com/acids-ircam/pytorch_flows/blob/master/flows_01.ipynb

In [None]:
%load_ext autoreload
%autoreload 2

import matplotlib.pyplot as plt
import seaborn as sns

from fastai.tabular.all import *
import ipywidgets as widgets

from gen import utils, data, ae, vae, flow

from torch import distributions

In [None]:
# Hyper-parameters for the smoke tests below: input dimension, number of planar flows,
# latent dimension, hidden-layer width and p (passed to the ae.get_*_model builders;
# presumably a dropout probability — confirm in gen.ae).
n_in, n_flows, n_z, n_h, p = 2, 3, 2, 200, .01

In [None]:
# Gaussian encoder with diagonal covariance: one network for the mean and one for the
# log-variance, both built by ae.get_encoder_model(n_in, n_h, p, n_z).
# The construction order matters: each builder draws its initial weights from the torch RNG.
encoder = flow.NNDiagGaussian(
    mean_encoder_model=ae.get_encoder_model(n_in, n_h, p, n_z),
    logvar_encoder_model=ae.get_encoder_model(n_in, n_h, p, n_z)
)
# NOTE(review): presumably modelled on normflows' nf.distributions.NNDiagGaussian(encoder_nn).

In [None]:
# Dummy all-zero float32 batch of 3 inputs used to smoke-test the modules below
# (built directly in torch instead of round-tripping through NumPy).
x = torch.zeros((3, n_in), dtype=torch.float)
x

In [None]:
# Run the dummy batch through the encoder: it returns latent values z and log_q
# (presumably the log-density of z under the encoder's Gaussian — confirm in gen.flow).
z, log_q = encoder(x)
print(f'{z=}, \n{log_q=}')

In [None]:
# Log-density of z given x under the encoder (argument order: value first, then conditioning input).
encoder.log_prob(z, x)

In [None]:
# Push z through a single, freshly initialised planar flow (overwrites z);
# log_det is presumably log|det J| of the transform.
z, log_det = flow.Planar((n_z,))(z)
print(f'{z=}, \n{log_det=}')

In [None]:
# Standard-normal prior over the n_z-dimensional latent space.
prior = distributions.MultivariateNormal(
    loc=torch.zeros(n_z),
    covariance_matrix=torch.eye(n_z),
)

In [None]:
# Log-density of the flow-transformed z under the latent prior.
prior.log_prob(z)

In [None]:
# n_flows independent planar flow layers, each acting on vectors of shape (n_z,).
flows = [flow.Planar((n_z,)) for _ in range(n_flows)]

In [None]:
# Decoder: a diagonal-Gaussian module built from two ae.get_decoder_model networks.
# They are passed positionally — presumably (mean model, log-variance model), mirroring
# the encoder's keyword arguments; confirm against flow.NNDiagGaussian's signature.
# (An earlier version used flow.NNDiagGaussianDecoder with a single decoder network.)
decoder = flow.NNDiagGaussian(
    ae.get_decoder_model(n_in, n_h, p, n_z),
    ae.get_decoder_model(n_in, n_h, p, n_z),
)

In [None]:
# Log-likelihood of x given z under the decoder (value first, conditioning input second).
decoder.log_prob(x,z)

In [None]:
# Decode z: keep the first output as r and discard the second.
r, _ = decoder(z)
r

In [None]:
# Generate n points of a 2-D toy pattern.
n = 4200
y_col = 'target'
pattern = 'checkerboard'
# The target column holds random 0/1 labels unrelated to the points — presumably only
# present because TabularPandas below is given y_names=y_col.
all_train_data = data.DataGenerator.generate(pattern, n).assign(**{y_col: np.random.choice([0,1], size=n)})
all_train_data.head()

In [None]:
# Scatter plot of the generated points; the low alpha makes dense regions visible.
fig, ax = plt.subplots()
ax.scatter(all_train_data['x_0'], all_train_data['x_1'], alpha=.1)
plt.show()

In [None]:
# Random 80/20 train/validation split of the generated rows.
splits = RandomSplitter(valid_pct=.2)(all_train_data)

# Every column other than 'id' and the target is treated as a continuous feature.
original_features = L([col for col in all_train_data.columns if col not in ('id', y_col)])

to = TabularPandas(
    all_train_data,
    procs=[FillMissing, Normalize],
    cont_names=original_features,
    y_names=y_col,
    splits=splits,
)

# NOTE(review): kld_weight is never referenced afterwards (flow.Flow_Loss() is built
# without it) — presumably meant for the loss; confirm.
bs, kld_weight = 32, .05
dls = to.dataloaders(bs=bs)

In [None]:
# Build a fresh model for training, repeating the hyper-parameters and modules from the
# smoke-test cells above so that this cell is self-contained.
n_in = 2
n_flows = 3
n_z = 2
n_h = 200
p = .01

prior = distributions.MultivariateNormal(torch.zeros(n_z),
                                         torch.eye(n_z)) # prior in latent space

# Diagonal-Gaussian encoder: separate networks for mean and log-variance.
encoder = flow.NNDiagGaussian(
    mean_encoder_model=ae.get_encoder_model(n_in, n_h, p, n_z),
    logvar_encoder_model=ae.get_encoder_model(n_in, n_h, p, n_z)
)

# Stack of n_flows planar flows on the latent space.
flows = [flow.Planar((n_z,)) for k in range(n_flows)]

# Diagonal-Gaussian decoder built from two decoder networks (passed positionally).
decoder = flow.NNDiagGaussian(
    ae.get_decoder_model(n_in, n_h, p, n_z),
    ae.get_decoder_model(n_in, n_h, p, n_z),
)

model = flow.NormalizingFlow(prior, encoder, flows, decoder)

# NOTE(review): the loss is built with default arguments; kld_weight from the data cell is not passed.
loss_func = flow.Flow_Loss()

In [None]:
%%time
# Wrap the dataloaders, model and loss in a fastai Learner (the cell magic above must stay on the first line).
learn = Learner(dls, model, loss_func=loss_func)

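The variance is the squared standard deviation, $\mathrm{var} = \sigma^2$, hence
$\log \sigma = \log \sqrt{\mathrm{var}} = \tfrac{1}{2} \log \mathrm{var}$.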

In [None]:
# Run fastai's learning-rate finder; its `valley` suggestion is used for training below.
lrs = learn.lr_find()
lrs

In [None]:
# Train for 10 epochs with a one-cycle schedule peaking at the LR finder's valley suggestion.
learn.fit_one_cycle(10, lr_max=lrs.valley)

In [None]:
%%time
(rec, ori, mu, var), _ = learn.get_preds(ds_idx=1)

In [None]:
# Compare the generated data (rec) with the original validation data (ori);
# the helper returns a figure and a printable report.
fig, report = utils.check_identifiability_of_generated_data(ori, rec, original_features)
fig.show()
print(report)

## Interactively exploring the data patterns

In [None]:
# Pattern names offered in the dropdown; each is passed to data.DataGenerator.generate.
patterns = ['twospirals', 'twomoons', 'sign', 'abs', 'sinewave', 'crescentcube', 'crescent', 'gaussian', 'checkerboard']

# Controls for the interactive experiment below.
# NOTE(review): `pattern`, `n_flows` and `bs` rebind the global str/int values from the
# earlier cells to widget objects. Re-running an earlier cell that reads them without
# reassigning (e.g. `flows = [... for k in range(n_flows)]`) will then fail.
pattern = widgets.Dropdown(description='pattern', options=patterns, value='checkerboard')
n_data = widgets.IntText(description='data points', value=4_200)
n_epoch = widgets.IntText(description='epochs', value=7)
n_flows = widgets.IntText(description='n_flows', value=3)
bs = widgets.IntText(description='bs', value=256)
ui = widgets.VBox([pattern, n_data, n_epoch, n_flows, bs])

def _make_toy_dataloaders(pattern, n, bs, y_col='target'):
    """Generate ``n`` points of ``pattern`` and wrap them in fastai tabular dataloaders.

    Returns:
        (dls, original_features): the DataLoaders and the list of continuous feature names.
    """
    # Random 0/1 labels: presumably only present because TabularPandas is given y_names.
    all_train_data = data.DataGenerator.generate(pattern, n).assign(
        **{y_col: np.random.choice([0, 1], size=n)}
    )
    splits = RandomSplitter(valid_pct=.2)(all_train_data)
    original_features = L([c for c in all_train_data.columns if c not in ('id', y_col)])
    to = TabularPandas(all_train_data, procs=[FillMissing, Normalize],
                       cont_names=original_features,
                       y_names=y_col,
                       splits=splits)
    return to.dataloaders(bs=bs), original_features


def _make_flow_model(n_flows, n_in=2, n_z=2, n_h=200, p_drop=.01):
    """Build encoder -> ``n_flows`` planar flows -> decoder as a ``flow.NormalizingFlow``.

    ``p_drop`` is forwarded to the ae.get_*_model builders (presumably a dropout rate).
    """
    prior = distributions.MultivariateNormal(torch.zeros(n_z),
                                             torch.eye(n_z))  # prior in latent space
    encoder = flow.NNDiagGaussian(
        mean_encoder_model=ae.get_encoder_model(n_in, n_h, p_drop, n_z),
        logvar_encoder_model=ae.get_encoder_model(n_in, n_h, p_drop, n_z),
    )
    flows = [flow.Planar((n_z,)) for _ in range(n_flows)]
    decoder = flow.NNDiagGaussian(
        ae.get_decoder_model(n_in, n_h, p_drop, n_z),
        ae.get_decoder_model(n_in, n_h, p_drop, n_z),
    )
    return flow.NormalizingFlow(prior, encoder, flows, decoder)


def run_stuff(p, n, epochs, n_flows, bs):
    """Generate a toy data set, train a normalizing-flow model on it and show the result.

    Used as the callback of ``widgets.interactive_output``; each argument comes from the
    widget stored under the same key in ``params``.

    Args:
        p: pattern name passed to ``data.DataGenerator.generate``.
        n: number of data points to generate.
        epochs: number of epochs for ``fit_one_cycle``.
        n_flows: number of planar flow layers. Previously this argument was overwritten
            by a hard-coded 3, so the widget had no effect.
        bs: batch size of the dataloaders.
    """
    dls, original_features = _make_toy_dataloaders(p, n, bs)
    model = _make_flow_model(n_flows)

    # Train with the learning rate suggested by the LR finder.
    learn = Learner(dls, model, loss_func=flow.Flow_Loss())
    lrs = learn.lr_find()
    learn.fit_one_cycle(epochs, lr_max=lrs.valley)

    # Compare model output with the original validation data (ds_idx=1).
    (rec, ori, _mu, _logvar), _ = learn.get_preds(ds_idx=1)
    fig, report = utils.check_identifiability_of_generated_data(ori, rec,
                                                                original_features)
    fig.show()
    print(report)

# Map run_stuff's parameter names to the widgets supplying their values; interactive_output
# re-runs run_stuff whenever a widget changes and renders its output into `out`.
params = {
    'p':pattern, 
    'n':n_data, 
    'epochs':n_epoch, 
    'n_flows':n_flows,
    'bs': bs,
}
out = widgets.interactive_output(run_stuff, params)
display(ui, out)

In [None]:
import torch.distributions as distrib
import torch.distributions.transforms as transforms


# Evaluation grid: 1000 x 1000 points covering [-4, 4]^2, flattened to shape
# (1_000_000, 2) with the first coordinate varying fastest.
x = np.linspace(-4, 4, 1000)
z = np.stack(np.meshgrid(x, x), axis=-1).reshape(-1, 2)

In [None]:
# Inspect the NumPy evaluation grid.
print(f'{z=}')

In [None]:
# Convert the grid to a float32 torch tensor so it can be scored by torch distributions.
z = torch.as_tensor(z, dtype=torch.float)
print(f'{z=}')

In [None]:
# Mean and covariance of the standard 2-D Gaussian used as the base distribution below.
mu, cov = torch.zeros(2, dtype=torch.float), torch.eye(2, dtype=torch.float)
print(f'{mu=}\n{cov=}')

In [None]:
# Base distribution q0: Gaussian with mean mu and covariance cov.
q0 = distrib.MultivariateNormal(mu, 
                                covariance_matrix=cov)
# Invertible, non-affine transformation. PowerTransform(2) was used before, but x**2 is
# not injective on the reals (its domain is the positive reals while q0 covers all of R^2),
# so TransformedDistribution(q0, PowerTransform(2)).log_prob silently ignores the
# negative pre-images and returns a wrongly normalised density. exp is a bijection R -> R+.
f1 = transforms.ExpTransform()
# Pushforward distribution: samples of q1 are f1 applied to samples of q0.
q1 = distrib.TransformedDistribution(q0, f1)

In [None]:
# Log-density of every grid point under the base distribution q0 (one value per row of z).
q0.log_prob(z)

In [None]:
# Distribution objects are not callable, so the original `f1(q0(z))` raised a TypeError.
# Apply the transform to samples drawn from q0 instead; the results are samples of
# q1 = TransformedDistribution(q0, f1).
f1(q0.sample((5,)))

In [None]:
# q1 is supported only on f1's codomain (the positive reals for torch's PowerTransform and
# ExpTransform), but the grid z also contains negative coordinates. With argument validation
# enabled, log_prob raises on those rows, so evaluate only the grid points inside q1's
# support. The result has one log-density per retained row.
q1.log_prob(z[q1.support.check(z)])