# Simple implementation of a rectangular normalizing flow

In [None]:
import numpy as np
import jax
from jax import numpy as jnp

import distrax
import haiku as hk

from ima.upsampling import Pad

from jax.experimental.optimizers import adam

from tqdm import tqdm
from matplotlib import pyplot as plt

In [None]:
# Root PRNG key for the notebook; later cells split fresh subkeys off it.
key = jax.random.PRNGKey(seed=1)

In [None]:
# Generate training data: Gaussian sources are mapped through a random linear
# map and perturbed with additive Gaussian noise.

d = 10       # dimension of the sources
D = 1000     # dimension of the observations
N = 10000    # number of samples

# Standard-normal sources, one row per sample.
key, subkey = jax.random.split(key)
sources = jax.random.normal(subkey, shape=(N, d))

# Random D x d mixing matrix with entries scaled by 1 / sqrt(d).
key, subkey = jax.random.split(key)
A = jax.random.normal(subkey, shape=(D, d)) / np.sqrt(d)


def mv(m, v):
    """Return the matrix-vector product of `m` and `v`."""
    return jnp.matmul(m, v)


# Batched matrix-vector product: the matrix is shared across the batch while
# the vectors are mapped over their leading axis.
mbv = jax.vmap(mv, in_axes=(None, 0), out_axes=0)

# Mixed sources plus Gaussian noise with standard deviation 0.2.
key, subkey = jax.random.split(key)
observations = mbv(A, sources) + 0.2 * jax.random.normal(subkey, shape=(N, D))

In [None]:
# Define Real NVP flow with Distrax
def mk_flow(K = 16, nl = 2, hu = 256):
    """Build the padding bijector and a chain of affine coupling layers.

    Must be called inside a Haiku-transformed function, since it instantiates
    Haiku modules. Reads the module-level dimensions `D` and `d`.

    Args:
        K: number of coupling layers in the chain.
        nl: number of hidden layers in each conditioner MLP.
        hu: number of units in each hidden layer.

    Returns:
        A tuple `(pad, flow)` where `pad` is `Pad((0, D - d))` and `flow` is a
        `distrax.Chain` of the `K` coupling layers.
    """
    split = D // 2

    def make_affine(cond_out):
        # The conditioner output is cut at `split`: the leading part is the
        # shift, the remainder is the log-scale of an elementwise affine map.
        shift = cond_out[..., :split]
        log_scale = cond_out[..., split:]
        return distrax.Block(distrax.ScalarAffine(shift=shift, log_scale=log_scale), 1)

    def make_conditioner():
        # Modules are created in the order MLP -> Linear -> Sequential so that
        # Haiku assigns the same parameter names on every build.
        hidden = hk.nets.MLP((hu,) * nl, activate_final=True)
        # Zero-initialised output layer: shift and log-scale are zero at
        # initialisation.
        head = hk.Linear(D, w_init=jnp.zeros, b_init=jnp.zeros)
        return hk.Sequential([hidden, head])

    # NOTE(review): `Pad` comes from ima.upsampling; presumably it embeds
    # d-dimensional inputs into D dimensions - confirm against that module.
    pad = Pad((0, D - d))

    # Alternate which half is transformed by swapping on every odd layer.
    couplings = [
        distrax.SplitCoupling(split, 1, make_conditioner(), make_affine, swap=(i % 2 == 1))
        for i in range(K)
    ]
    return (pad, distrax.Chain(couplings))

def fwd_(x):
    """Apply `pad.forward` followed by `flow.forward` to `x`.

    Intended to be wrapped with `hk.transform`; the bijectors are rebuilt on
    every call via `mk_flow()`.
    """
    pad, flow = mk_flow()
    return flow.forward(pad.forward(x))

def inv_(x):
    """Apply `flow.inverse` followed by `pad.inverse` to `x`.

    Intended to be wrapped with `hk.transform`; the bijectors are rebuilt on
    every call via `mk_flow()`.
    """
    pad, flow = mk_flow()
    return pad.inverse(flow.inverse(x))

In [None]:
# Turn the forward and inverse maps into pure Haiku functions.
fwd, inv = (hk.transform(fn) for fn in (fwd_, inv_))

# Initialise the parameters by running the forward map on a dummy batch of
# 5 random d-dimensional inputs.
key, subkey = jax.random.split(key)
params = fwd.init(subkey, jnp.array(np.random.randn(5, d)))

In [None]:
# Loss function

def loss_(args):
    """Weighted negative log-likelihood plus reconstruction error.

    Intended to be wrapped with `hk.transform`.

    Args:
        args: tuple `(x, lam, beta)` where `x` is a batch of observations,
            `lam` weights the likelihood term and `beta` weights the
            reconstruction term.

    Returns:
        The batch mean of `-lam * (log p(z) - log_det) + beta * diff`.
    """
    x, lam, beta = args
    pad, flow = mk_flow()

    def decode(latent):
        return flow.forward(pad.forward(latent))

    def encode(obs):
        return pad.inverse(flow.inverse(obs))

    # Standard-normal base distribution over the d-dimensional latents.
    prior = distrax.Independent(
        distrax.Normal(loc=jnp.zeros(d), scale=jnp.ones(d)),
        reinterpreted_batch_ndims=1,
    )

    z = encode(x)

    # Per-sample Jacobian J of the decoder, evaluated at the encoded points.
    jac = jax.vmap(jax.jacfwd(decode))(z)

    # Summing the log-diagonal of the Cholesky factor of J^T J yields
    # 0.5 * log det(J^T J).
    gram = jax.lax.batch_matmul(jnp.transpose(jac, (0, 2, 1)), jac)
    chol = jax.vmap(jax.scipy.linalg.cholesky)(gram)
    log_det = jnp.log(jnp.diagonal(chol, axis1=1, axis2=2)).sum(-1)

    # Mean squared reconstruction error over the whole batch (a scalar).
    recon_err = jnp.mean((x - decode(z)) ** 2)

    log_lik = prior.log_prob(z) - log_det
    return jnp.mean(-lam * log_lik + beta * recon_err)
    

In [None]:
# Wrap the loss as a pure Haiku function, then initialise its parameters on a
# dummy batch of 5 random D-dimensional points with lam = beta = 1.
# This rebinds `params`.
loss = hk.transform(loss_)
key, subkey = jax.random.split(key)
params = loss.init(subkey, (jnp.array(np.random.randn(5, D)), 1.0, 1.0))

In [None]:
# Random batch of 5 D-dimensional points used to try out the loss.
b = jnp.array(np.random.standard_normal(size=(5, D)))

In [None]:
# Evaluate the loss once on the batch `b` with lam = beta = 1; no RNG key is
# passed to `apply`.
loss.apply(params, None, (b, 1., 1.))

In [None]:
# Adam optimiser: the triple returned by `adam` is split into its init,
# update and parameter-extraction functions.
lr = 1.e-3
optimizer = adam(step_size=lr)
opt_init, opt_update, get_params = optimizer

# Optimiser state wrapping the current parameters.
opt_state = opt_init(params)

In [None]:
@jax.jit
def step(it_, opt_state_, x_, lam_, beta_):
    """Run one optimisation step on a minibatch.

    Args:
        it_: iteration index passed to the optimiser update.
        opt_state_: current optimiser state.
        x_: minibatch of observations.
        lam_: weight of the likelihood term of the loss.
        beta_: weight of the reconstruction term of the loss.

    Returns:
        A tuple `(value, opt_state_)` with the loss value before the update
        and the updated optimiser state.
    """
    def objective(p):
        # Loss as a function of the parameters only; no RNG key is needed.
        return loss.apply(p, None, (x_, lam_, beta_))

    value, grads = jax.value_and_grad(objective)(get_params(opt_state_))
    return value, opt_update(it_, grads, opt_state_)

In [None]:
# Training hyper-parameters.
num_iter = 1000       # number of optimisation steps
batch_size = 256      # samples drawn per step
beta = 20.            # weight of the reconstruction term

# Iteration interval over which lam is interpolated from 0 to 1.
# NOTE(review): the interval ends at 2000 while num_iter is 1000, so lam stays
# below 1 within a single run of the loop - confirm this is intended.
lam_int = [40, 2000]

# Loss history: one (iteration, loss) row per step, initially empty.
loss_hist = np.zeros((0, 2))

In [None]:
for it in tqdm(range(num_iter)):
    # Draw a minibatch; indices are sampled with replacement.
    batch_idx = np.random.choice(N, batch_size)
    x = observations[batch_idx]

    # Need to warm up lambda due to stability issues: lam is 0 up to the start
    # of lam_int and then grows linearly towards 1.
    lam = np.interp(it, lam_int, [0, 1])

    loss_val, opt_state = step(it, opt_state, x, lam, beta)

    # Record a (1-based iteration, loss) row. Appending on every step keeps
    # the history usable even if the loop is interrupted.
    loss_append = np.array([it + 1, loss_val.item()]).reshape(1, 2)
    loss_hist = np.concatenate((loss_hist, loss_append), axis=0)

In [None]:
# Plot the recorded loss against the iteration number.
plt.figure(figsize=(15, 10))
iterations, losses = loss_hist[:, 0], loss_hist[:, 1]
plt.plot(iterations, losses)
plt.show()