# Residual Flow model with triangular Jacobian

In [None]:
import jax
from jax import numpy as jnp
import numpy as np
import distrax
import haiku as hk
from residual import TriangularResidual, spectral_norm_init, spectral_normalization, masks_triangular_weights, make_weights_triangular, LipSwish

from jax.experimental.optimizers import adam

In [None]:
# Root PRNG key for everything random in this notebook (fixed seed for reproducibility).
key = jax.random.PRNGKey(seed=1)

## Generate data

In [None]:
from plotting import cart2pol, scatterplot_variables

N = 3000  # number of samples
D = 2     # number of dimensions

# Sources: N points drawn uniformly from [0, 1)^D, then shifted by -0.5 so the
# cloud is centred on the origin.
S = jax.random.uniform(key, shape=(N, D), minval=0.0, maxval=1.0)
S = S - 0.5

# The second output of cart2pol (presumably the polar angle -- confirm in
# plotting.py) is reused as a per-sample colour in all later scatter plots.
_, colors = cart2pol(S[:, 0], S[:, 1])

scatterplot_variables(S, 'Sources', colors=colors, savefig=False)

In [None]:
from jax import vmap

from mixing_functions import build_conformal_map


def nonlinearity(x):
    """Elementwise nonlinearity handed to ``build_conformal_map``: exp(1.3 * x)."""
    return jnp.exp(1.3 * x)


mixing, mixing_gridplot = build_conformal_map(nonlinearity)

# Vectorise the single-sample mixing function over the rows of S.
mixing_batched = vmap(mixing)

# Observations: mixed sources, standardised per dimension. The std is taken
# of the already-centred data, exactly as before.
X = mixing_batched(S)
X = X - jnp.mean(X, axis=0)
X = X / jnp.std(X, axis=0)

scatterplot_variables(X, 'Observations', colors=colors, savefig=False)

## Set up model

In [None]:
n_layers = 32
hidden_units = [128, 128]


def log_prob(x):
    """Return the log-density of ``x`` under the residual-flow model.

    Must be called inside ``hk.transform``. The model is a 2-D standard normal
    base distribution pushed through a ``distrax.Chain`` of ``n_layers``
    ``TriangularResidual`` bijectors named ``residual_0``, ``residual_1``, ...
    Any other transformed function that builds layers with these same names
    (e.g. ``inv_map_fn``) can reuse the same parameters.

    Args:
        x: points to evaluate, with last dimension 2.

    Returns:
        Log-probabilities, one per point in ``x``.
    """
    base_dist = distrax.Independent(
        distrax.Normal(loc=jnp.zeros(2), scale=jnp.ones(2)),
        reinterpreted_batch_ndims=1,
    )
    flows = distrax.Chain([TriangularResidual(hidden_units + [2], name='residual_' + str(i))
                           for i in range(n_layers)])
    model = distrax.Transformed(base_dist, flows)
    return model.log_prob(x)


# `key` is already consumed by the data draw, so it must not be reused for
# parameter init or spectral-norm init. Split it into independent subkeys
# instead (JAX: never reuse a PRNG key).
init_key, sn_key = jax.random.split(key)

# Init model. The dummy batch only fixes input shapes for haiku.
logp = hk.transform(log_prob)
params = logp.init(init_key, jnp.array(np.random.randn(5, 2)))

# Make triangular: mask the weights so each layer's Jacobian is triangular.
# NOTE(review): half of each hidden width is passed to masks_triangular_weights,
# presumably units per input dimension -- confirm in residual.py.
masks = masks_triangular_weights([h // 2 for h in hidden_units])
params = make_weights_triangular(params, masks)

# Apply spectral normalization. `uv` is state carried into every later
# spectral_normalization call (presumably power-iteration vectors).
uv = spectral_norm_init(params, sn_key)
params, uv = spectral_normalization(params, uv)

In [None]:
def loss(params, x):
    """Mean negative log-likelihood of the batch ``x`` under the flow with ``params``."""
    log_likelihood = logp.apply(params, None, x)
    return -log_likelihood.mean()

## Model training

In [None]:
opt_init, opt_update, get_params = adam(step_size=1e-3)


@jax.jit
def step(it, opt_state, uv, x):
    """Run one projected Adam step on the mean NLL of the batch ``x``.

    The current parameters are projected first: weights are masked to keep the
    Jacobian triangular, then spectrally normalised. The projected values are
    written into a fresh optimizer state, and the gradient is taken at them.

    Args:
        it: iteration counter passed to ``opt_update``.
        opt_state: current optimizer state (not modified).
        uv: spectral-normalisation state from the previous step.
        x: training batch.

    Returns:
        ``(loss_value, new_opt_state, new_uv)``.
    """
    params = get_params(opt_state)
    params = make_weights_triangular(params, masks)  # makes Jacobian triangular
    params, uv = spectral_normalization(params, uv)

    # Put the projected parameters into slot 0 of each per-leaf optimizer state
    # (the slot the parameter value lives in), keeping the remaining Adam slots.
    # A new OptimizerState is built rather than mutating the caller's lists, so
    # the input state is untouched even when this function runs without jit.
    params_flat = jax.tree_util.tree_leaves(params)
    if len(params_flat) != len(opt_state.packed_state):
        raise ValueError("parameter tree does not match the optimizer state")
    packed_state = [[leaf] + list(leaf_state[1:])
                    for leaf, leaf_state in zip(params_flat, opt_state.packed_state)]
    opt_state = opt_state._replace(packed_state=packed_state)

    value, grads = jax.value_and_grad(loss)(params, x)
    opt_out = opt_update(it, grads, opt_state)
    return value, opt_out, uv

In [None]:
iters = 50000      # number of optimisation steps
batch_size = 256   # observations per minibatch

# Optimizer state starts from the already projected initial parameters.
opt_state = opt_init(params)

# Per-step loss history (empty float array to start).
loss_hist = np.empty(0)

In [None]:
# Collect the losses in a Python list and convert once at the end. Calling
# np.append every iteration copies the whole history each time, which makes
# the loop quadratic in `iters`.
losses = list(loss_hist)
for i in range(iters):
    # Minibatch of observations sampled with replacement.
    x = X[np.random.choice(X.shape[0], batch_size)]
    value, opt_state, uv = step(i, opt_state, uv, x)
    losses.append(value.item())
loss_hist = np.array(losses)

In [None]:
# Apply the same projection `step` uses before every gradient evaluation
# (triangular masking, then spectral normalisation) to the trained parameters.
# The refreshed normalisation state returned alongside is discarded.
params_final, _ = spectral_normalization(
    make_weights_triangular(get_params(opt_state), masks), uv
)

In [None]:
import matplotlib.pyplot as plt

# Training curve: one mean negative log-likelihood value per optimisation step.
plt.plot(loss_hist)

In [None]:
# Evaluate the learned density on a regular npoints x npoints grid over [-3, 3]^2.
npoints = 300
x = jnp.linspace(-3.0, 3.0, npoints)
y = jnp.linspace(-3.0, 3.0, npoints)
xx, yy = jnp.meshgrid(x, y)

# Flatten the grid into an (npoints**2, 2) array of (x, y) points.
zz = jnp.stack([xx.ravel(), yy.ravel()], axis=-1)

prob = jnp.exp(logp.apply(params_final, None, zz))

In [None]:
# Heat map of the learned density. `prob` was computed on the grid flattened
# with reshape(-1), so reshaping it back gives the same layout as xx / yy.
# shading='auto' keeps every cell: with same-shaped X, Y and C, Matplotlib's
# legacy 'flat' shading silently drops the last row and column of C
# (deprecated since Matplotlib 3.3).
plt.figure(figsize=(15, 15))
plt.pcolormesh(
    np.asarray(xx),
    np.asarray(yy),
    np.asarray(prob.reshape(npoints, npoints)),
    shading='auto',
)

In [None]:
def inv_map_fn(x):
    """Map observations ``x`` back to the base (source) space via the flow's inverse.

    Must be called inside ``hk.transform``. It rebuilds the same chain of
    ``n_layers`` ``TriangularResidual`` layers, with the same module names
    (``residual_0``, ...), as ``log_prob``, so the trained parameters apply.
    """
    layers = []
    for layer_idx in range(n_layers):
        layers.append(TriangularResidual(hidden_units + [2], name='residual_' + str(layer_idx)))
    return distrax.Chain(layers).inverse(x)
inv_map = hk.transform(inv_map_fn)

In [None]:
# Recovered sources: the observations X pushed through the inverse of the trained flow.
S_rec = inv_map.apply(params_final, None, X)

In [None]:
# Scatter plot of the recovered sources, coloured with the same per-sample colours as the true sources.
scatterplot_variables(S_rec, 'Sources', colors=colors, savefig=False)

## Compute CIMA (the independent mechanism analysis contrast) of the learned map

In [None]:
def cima(x):
    """Per-sample IMA contrast of the learned map, evaluated at the observations ``x``.

    For each row of ``x``, compute the Jacobian ``J`` of the learned unmixing
    (``inv_map`` with ``params_final``). Return
    ``sum_i log ||J[i, :]|| - log|det J|``.

    For a 2x2 matrix, the row norms of ``J`` divided by ``|det J|`` equal the
    column norms of ``J^{-1}`` (the Jacobian of the learned mixing, by the
    inverse function theorem). So the value above equals
    ``sum_i log ||J^{-1}[:, i]|| - log|det J^{-1}|``, the IMA contrast of the
    mixing. That identity and the explicit 2x2 determinant below only hold for
    2-D data, which is why other shapes are rejected instead of returning a
    silently wrong value.

    Args:
        x: observations, one 2-D point per row.

    Returns:
        Array with one contrast value per row of ``x``.

    Raises:
        ValueError: if the per-sample Jacobian is not 2x2.
    """
    jac_fn = jax.vmap(jax.jacfwd(lambda y: inv_map.apply(params_final, None, y)))
    J = jac_fn(x)  # J[n, i, j] = d out_i / d in_j at sample n
    if J.shape[1:] != (2, 2):
        raise ValueError(f"cima only supports 2-D data; got per-sample Jacobians of shape {J.shape[1:]}")
    detJ = J[:, 0, 0] * J[:, 1, 1] - J[:, 0, 1] * J[:, 1, 0]
    row_norms = jnp.linalg.norm(J, axis=2)
    out = jnp.sum(jnp.log(row_norms), axis=1) - jnp.log(jnp.abs(detJ))
    return out

In [None]:
# One contrast value per observation (row of X).
c = cima(X)

In [None]:
# Average contrast over all observations (displayed as the cell output).
c.mean()