In [None]:
import sys, os
from pyprojroot import here

# Walk up from the current working directory to the project root, identified
# by a ".home" marker file, and make it importable.
root = here(project_files=[".home"])

# Append only once so re-running this cell does not keep growing sys.path.
if str(root) not in sys.path:
    sys.path.append(str(root))

%load_ext autoreload
%autoreload 2

In [None]:
# TYPE HINTS
from typing import Tuple, Optional, Dict, Callable, Union

# JAX SETTINGS
import jax
import jax.numpy as jnp
import jax.random as random
import treex as tx

import tensorflow_probability.substrates.jax as tfp

tfd = tfp.distributions
from sklearn.datasets import make_moons

# NUMPY SETTINGS
import numpy as np

np.set_printoptions(precision=3, suppress=True)

# MATPLOTLIB Settings
import matplotlib as mpl
import matplotlib.pyplot as plt

%matplotlib inline
%config InlineBackend.figure_format = 'retina'

import corner

# SEABORN SETTINGS
import seaborn as sns

sns.set_context(context="talk", font_scale=0.7)

# PANDAS SETTINGS
import pandas as pd

pd.set_option("display.max_rows", 120)
pd.set_option("display.max_columns", 120)

# LOGGING SETTINGS
import sys
import logging

logging.basicConfig(
    level=logging.INFO,
    stream=sys.stdout,
    format="%(asctime)s:%(levelname)s:%(message)s",
)
# Root logger. basicConfig above sets its level to INFO (when the root logger
# has no handlers yet), so no extra setLevel call is needed.
logger = logging.getLogger()

# NOTE: `%load_ext autoreload` / `%autoreload 2` are already run in the first
# cell; repeating `%load_ext` here would only print an "already loaded" notice.

## Data

In [None]:
def get_toy_data(dataset="classic", n_samples=1000, seed=123):
    """Sample a noisy sine curve as a 2-D point cloud.

    x = |2 * e1|,  y = sin(x) + 0.25 * e2,  with e1, e2 ~ N(0, 1) i.i.d.

    Args:
        dataset: name of the toy dataset. Only ``"classic"`` (the curve above)
            is implemented; any other value raises ``ValueError`` instead of
            being silently ignored.
        n_samples: number of points to draw.
        seed: seed for a private ``np.random.RandomState`` (the global numpy
            RNG state is not touched).

    Returns:
        Array of shape ``(n_samples, 2)`` with columns ``(x, y)``.

    Raises:
        ValueError: if ``dataset`` is not ``"classic"``.
    """
    if dataset != "classic":
        raise ValueError(f"Unknown dataset {dataset!r}; only 'classic' is supported.")

    rng = np.random.RandomState(seed=seed)

    # x is drawn before the y-noise; keep this order so seeded outputs are stable
    x = np.abs(2 * rng.randn(n_samples, 1))
    y = np.sin(x) + 0.25 * rng.randn(n_samples, 1)
    return np.hstack((x, y))


def plot_joint(data, color: str = "red", title: str = "", logger=None, kind="scatter"):
    """Joint plot (with marginals) of the first two columns of ``data``.

    Args:
        data: 2-D array; column 0 goes on the x-axis, column 1 on the y-axis.
        color: color used for the joint and marginal plots.
        title: figure super-title.
        logger: currently unused (the wandb logging branch below is disabled);
            kept for signature compatibility.
        kind: seaborn ``jointplot`` kind, e.g. ``"scatter"``, ``"kde"``, ``"hex"``.
    """
    # jointplot creates its own square figure; `height` sets its size. A prior
    # `plt.figure(...)` call would only leave an extra, empty figure behind.
    g = sns.jointplot(
        x=data[:, 0],
        y=data[:, 1],
        kind=kind,
        color=color,
        height=5,
    )
    # Label the joint axes explicitly: pyplot's "current axes" after jointplot is
    # not guaranteed to be the joint axes, so plt.xlabel/ylabel may miss it.
    g.set_axis_labels("X", "Y")
    g.figure.suptitle(title)
    g.figure.tight_layout()
    # if logger is not None:
    #     wandb.log({title: [wandb.Image(g.figure)]})
    #     plt.close(g.figure)
    # else:
    plt.show()

In [None]:
# Two-moons data: a large training set and a held-out test set drawn with a
# different seed. make_moons returns (X, labels); only the coordinates X are kept.
ntrain, ntest = 20_000, 5_000
noise = 0.1
random_state = 123

train_data = make_moons(n_samples=ntrain, noise=noise, random_state=random_state)[0]
test_data = make_moons(n_samples=ntest, noise=noise, random_state=10 * random_state)[0]

In [None]:
# Visual check of the raw two-moons training set.
fig = corner.corner(np.asarray(train_data), color="blue")

In [None]:
# x_2d_samples = get_toy_data(n_samples=2_000)

# fig = corner.corner(x_2d_samples, color="blue")


## Gaussianization Transforms

In [None]:
from flowjax._src.transforms.bijections.elementwise.invcdf import InverseGaussCDF, Logit
from flowjax._src.transforms.bijections.elementwise.mixturecdf import (
    GaussianMixtureCDF,
    LogisticMixtureCDF,
)
from flowjax._src.transforms.bijections.linear.orthogonal import RandomRotation
from flowjax._src.transforms.base import Composite

In [None]:
# layer params
num_mixtures = 12  # mixture components per marginal CDF
n_layers = 6  # number of stacked Gaussianization blocks


def make_bijector_block():
    """Return a NEW list of new bijector instances forming one block.

    One block = marginal uniformization (Gaussian-mixture CDF)
    -> marginal Gaussianization (inverse Gaussian CDF)
    -> orthogonal transform (random rotation).
    """
    return [
        # marginal uniformization
        GaussianMixtureCDF(num_mixtures=num_mixtures),
        # LogisticMixtureCDF(num_mixtures=num_mixtures),
        # marginal gaussianization
        InverseGaussCDF(),
        # Logit(),
        # orthogonal transform
        RandomRotation(),
    ]


# One block's layers, kept for inspection.
bijector_block = make_bijector_block()

# `bijector_block * n_layers` would repeat references to the SAME three layer
# objects n_layers times, so every block would share one set of instances (and
# whatever state/parameters they hold). Build independent instances per block.
bijectors = [layer for _ in range(n_layers) for layer in make_bijector_block()]

In [None]:
# create composite flow from the stacked bijectors
model = Composite(*bijectors)

# init keys and data
x_init = jnp.array(train_data)
key_init = jax.random.PRNGKey(123)

# init layer params (data-dependent). Use the key defined above rather than
# repeating the literal seed, so the seed has a single source of truth.
model = model.init(key=key_init, inputs=x_init)

In [None]:
# Round-trip check on the initialized (untrained) flow: one forward pass that
# also returns the log-det term, then the inverse. (A separate `model(...)`
# call is unnecessary: its result would be overwritten by the line below.)
z_mu, ldj = model.forward_and_log_det(x_init)
x_approx = model.inverse(z_mu)

In [None]:
# Input (red), latent (black) and reconstruction (blue). The reference is
# `x_init`, the data actually pushed through the flow above, so the blue
# reconstruction can be compared with exactly what it reconstructs.
fig = corner.corner(np.array(x_init), color="red")
fig = corner.corner(np.array(z_mu), color="black")
fig = corner.corner(np.array(x_approx), color="blue")

## Training

### Loss Function

In [None]:
from flowjax._src.utils.tensors import sum_except_batch

# Base (latent) distribution: standard 2-D Gaussian with diagonal covariance.
# Commented-out alternatives: tfd.Normal(jnp.zeros(2), jnp.ones(2)), tfd.Uniform().
base_dist = tfd.MultivariateNormalDiag(jnp.zeros(2), jnp.ones(2))


def loss_fn(params, model, x):
    """Mean negative log-likelihood of ``x`` under the flow.

    ``params`` is passed separately from ``model`` so that ``jax.grad``
    differentiates only with respect to the trainable parameters; they are
    merged back into ``model`` before the forward pass.

    Returns:
        ``(loss, model)`` -- the scalar NLL and the merged model (returned
        because the model may carry state updates).
    """
    merged = model.merge(params)
    z, log_det = merged.forward_and_log_det(x)

    # change of variables: log p(x) = log p_z(z) + log-det term, per sample
    per_sample = sum_except_batch(base_dist.log_prob(z)) + sum_except_batch(log_det)

    return -jnp.mean(per_sample), merged

In [None]:
# Sanity check: evaluate the loss once at the initial parameters (no gradients).
params = model.parameters()
loss, model_ = loss_fn(params, model, x_init)

loss

#### Gradients

In [None]:
# value_and_grad differentiates w.r.t. the first argument (params) by default;
# has_aux=True because loss_fn returns (loss, model) and only `loss` is differentiated.
grad_fn = jax.value_and_grad(loss_fn, has_aux=True)

# Quick check that gradients can be computed on the full initialization set.
(loss_, m_), grads_ = grad_fn(params, model, x_init)

loss_

### Train Step

In [None]:
# both model and optimizer are jit-able
@jax.jit
def train_step(model, x, optimizer):
    """Run one optimization step on the batch ``x``.

    ``model`` and ``optimizer`` are both passed in and returned explicitly so
    the whole step can be jit-compiled.

    Returns:
        ``(loss, model, optimizer)`` after the update has been applied.
    """
    current_params = model.parameters()

    # gradients w.r.t. the parameters; the aux output is the merged model
    (loss, merged_model), grads = grad_fn(current_params, model, x)

    updated_params = optimizer.update(grads, current_params)
    return loss, merged_model.merge(updated_params), optimizer

#### Optimizer

In [None]:
import optax

# learning rate (constant for now)
lr = 0.001

# scheduler (TODO): currently plain Adam, wrapped in a treex Optimizer and
# initialised from `model`.
optimizer = tx.Optimizer(optax.adam(lr)).init(model)

### Training Loop

In [None]:
from tqdm.notebook import trange, tqdm

n_iterations = 20_000
losses = []
batch_size = 64

with trange(n_iterations) as pbar:
    for i in pbar:
        # Fresh mini-batch each step, seeded by the iteration index for
        # reproducibility. Use a separate name so the full `train_data` set
        # created earlier is not overwritten (earlier cells read it).
        # batch_data = get_toy_data(n_samples=batch_size, seed=i)
        batch_data = make_moons(n_samples=batch_size, noise=noise, random_state=i)[0]

        ibatch = jnp.array(batch_data)
        loss, model, optimizer = train_step(model, ibatch, optimizer)

        pbar.set_description(f"Loss: {loss:.4f}")
        # store plain Python floats rather than device arrays
        losses.append(float(loss))

## Results

In [None]:
# Switch the trained model to evaluation mode (treex `Module.eval()`) before
# computing the results below.
model = model.eval()

### Losses

In [None]:
fig, ax = plt.subplots()

# Training curve: mean NLL of each mini-batch per iteration.
ax.plot(losses)
ax.set(xlabel="Iteration", ylabel="Negative log-likelihood", title="Training loss")

plt.show()

### Forward Transform

In [None]:
# Push the held-out data through the trained flow.
z_mg = model(test_data)

# Data space (red) and its image in latent space (black); after training the
# latent cloud should look like the standard 2-D Gaussian base distribution.
for panel_data, panel_color in ((test_data, "red"), (z_mg, "black")):
    fig = corner.corner(np.array(panel_data), color=panel_color)

### Inverse Transform

In [None]:
# Map the latent codes of the test data back to data space; for an invertible
# flow this should reproduce `test_data` up to numerical error.
x_approx = model.inverse(z_mg)

In [None]:
# Reconstructed test data; compare with the red `test_data` corner plot above.
fig = corner.corner(np.array(x_approx), color="blue")

### Generated Samples

In [None]:
# Draw latent samples from the base distribution to push through the inverse.
# `(100_000)` is just an int, not a 1-tuple, so spell the shape out explicitly.
# Derive a fresh key instead of reusing `key_init`: its seed (123) is the same
# one used for model initialization, and JAX keys should not be reused.
n_gen_samples = 100_000
key_sample = jax.random.fold_in(key_init, 1)
z_samples = base_dist.sample(sample_shape=(n_gen_samples,), seed=key_sample)

In [None]:
%time
x_samples = model.inverse(z_samples)

In [None]:
# Held-out data (red) vs. samples generated by the flow (green); the two
# distributions should look alike if the model fits well.
fig = corner.corner(np.array(test_data), color="red")
fig = corner.corner(np.array(x_samples), color="green")

### Density Estimation

In [None]:
def generate_2d_grid(
    data: np.ndarray, n_grid: int = 1_000, buffer: float = 0.01
) -> np.ndarray:
    """Regular evaluation grid covering the bounding box of ``data``.

    Args:
        data: array whose first two columns are the x / y coordinates.
        n_grid: number of points per axis; the result has ``n_grid**2`` rows.
        buffer: margin added below the per-axis minimum and above the maximum.

    Returns:
        ``(n_grid**2, 2)`` array of (x, y) points, with x varying fastest.
    """
    lower = data[:, :2].min(axis=0) - buffer
    upper = data[:, :2].max(axis=0) + buffer

    axis_points = [np.linspace(lower[dim], upper[dim], n_grid) for dim in range(2)]
    xgrid, ygrid = np.meshgrid(*axis_points)

    return np.column_stack([xgrid.ravel(), ygrid.ravel()])

In [None]:
# Regular 500 x 500 evaluation grid over the test-data bounding box (+0.1 margin).
xyinput = generate_2d_grid(data=test_data, n_grid=500, buffer=0.1)

In [None]:
# Flow log-density on every grid point via the change-of-variables formula
# log p(x) = log p_z(f(x)) + log-det term (same computation as in loss_fn,
# without the mean over samples).

# forward transformation
z, ldj = model.forward_and_log_det(xyinput)

# latent prob (log-density under the base distribution)
latent_prob = base_dist.log_prob(z)

# calculate log prob (both terms summed over all but the batch axis)
x_log_prob = sum_except_batch(latent_prob) + sum_except_batch(ldj)

In [None]:
from matplotlib import cm

# Estimated density: exponentiate the grid log-densities computed above.
cmap = cm.magma  # alternative: "Reds"
probs = np.exp(np.asarray(x_log_prob))

# Use the grid's extent for BOTH panels so they are directly comparable; the
# histogram is binned over that same extent rather than the raw data range.
xlim = (xyinput[:, 0].min(), xyinput[:, 0].max())
ylim = (xyinput[:, 1].min(), xyinput[:, 1].max())

fig, ax = plt.subplots(ncols=2, figsize=(12, 5))

# Left: empirical density of the held-out samples (color scale clipped to [0, 1]).
h = ax[0].hist2d(
    test_data[:, 0],
    test_data[:, 1],
    bins=512,
    range=[xlim, ylim],
    cmap=cmap,
    density=True,
    vmin=0.0,
    vmax=1.0,
)
ax[0].set(title="True Density", xlim=xlim, ylim=ylim)

# Right: model density exp(log p(x)) evaluated on the regular grid.
h1 = ax[1].scatter(
    xyinput[:, 0],
    xyinput[:, 1],
    s=1,
    c=probs,
    cmap=cmap,
)
ax[1].set(title="Estimated Density", xlim=xlim, ylim=ylim)

plt.tight_layout()
plt.show()