# Constructing normalizing flows

Here, we demonstrate how to use several inference funnels [@].

```python
import distrax
import haiku as hk
from jax import numpy as jnp
from jax import random as jr
```

## How to construct a Haiku module

We begin by demonstrating how to construct a normalizing flow, using a `Slice` layer [@] as an example. We first show the code and then explain what it does.

```python
from surjectors import Slice, TransformedDistribution
from surjectors.nn import make_mlp
```

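```python
def make_flow(n_dimensions):
    # The Slice layer keeps the first `n_keep` dimensions; the decoder
    # models the remaining `n_dimensions - n_keep` dimensions
    n_keep = n_dimensions // 2

    def flow(**kwargs):
        def decoder_fn(n_dim):
            def _fn(z):
                # An MLP maps the kept dimensions to the means and log-scales
                # of a diagonal Gaussian over the dropped dimensions
                params = make_mlp([4, n_dim * 2])(z)
                means, log_scales = jnp.split(params, 2, -1)
                return distrax.Independent(
                    distrax.Normal(means, jnp.exp(log_scales)),
                    reinterpreted_batch_ndims=1,
                )

            return _fn

        transform = Slice(n_keep, decoder_fn(n_dimensions - n_keep))
        # Standard-normal base distribution over the kept dimensions
        base_distribution = distrax.Independent(
            distrax.Normal(jnp.zeros(n_keep), jnp.ones(n_keep)),
            reinterpreted_batch_ndims=1,
        )
        pushforward = TransformedDistribution(base_distribution, transform)
        # Dispatches to `sample` or `log_prob` based on the `method` keyword
        return pushforward(**kwargs)

    td = hk.transform(flow)
    return td
```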

Constructing a Haiku module needs to be done within a `hk.transform` block. This can be done either by providing a function, as we do here, or an object. In our case, we apply `hk.transform` to a function that returns `pushforward(**kwargs)`, which calls the `__call__` method of `TransformedDistribution`. Since we are generally interested in two methods of a `TransformedDistribution`, `sample` and `log_prob`, its `__call__` method is implemented to dispatch to one of them based on what is provided in `**kwargs`. More on that later.
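The dispatch pattern itself can be illustrated without any library. The class `DispatchingDistribution` below is hypothetical and does not reproduce the `surjectors` implementation; it is a minimal sketch, assuming only that `__call__` forwards to the method named by a `method` keyword and passes the remaining keyword arguments along. The toy distribution is a one-dimensional standard normal written with Python's `math` and `random` modules.

```python
import math
import random


class DispatchingDistribution:
    """Hypothetical toy class illustrating method dispatch via __call__."""

    def __call__(self, method, **kwargs):
        # Look up the requested method by name and forward the remaining kwargs
        return getattr(self, method)(**kwargs)

    def sample(self, sample_shape=(1,), seed=0):
        # Draw standard-normal samples; only 1D sample shapes are supported here
        rng = random.Random(seed)
        return [rng.gauss(0.0, 1.0) for _ in range(sample_shape[0])]

    def log_prob(self, y):
        # Elementwise log-density of a standard normal
        return [-0.5 * v * v - 0.5 * math.log(2 * math.pi) for v in y]


dist = DispatchingDistribution()
draws = dist(method="sample", sample_shape=(3,))
log_densities = dist(method="log_prob", y=[0.0, 1.0])
```

Calling `dist(method="sample", ...)` plays the same role as passing `method="sample"` to `flow.apply` in this notebook.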

We can now initialize the flow. Let's define a random data set first and then initialize the parameters.

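```python
rng_key_seq = hk.PRNGSequence(0)

# A random data set with n = 1000 observations of dimension p = 10
n, p = 1000, 10
y = jr.normal(next(rng_key_seq), (n, p))

flow = make_flow(p)
# Initialize the parameters by tracing the `log_prob` method on the data
params = flow.init(next(rng_key_seq), method="log_prob", y=y)
params
```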

```
{'mlp/~/linear_0': {'w': Array([[ 0.00938151,  0.00138335,  0.0027178 ,  0.00052393],
         [ 0.00773658,  0.00142679, -0.00033992,  0.00513109],
         [ 0.01221598,  0.01620722,  0.009937  , -0.00184268],
         [-0.00059485,  0.00240193,  0.01265547, -0.00374089],
         [-0.0061826 , -0.01318036, -0.00686558, -0.01773127]],      dtype=float32),
  'b': Array([0., 0., 0., 0.], dtype=float32)},
 'mlp/~/linear_1': {'w': Array([[-0.01119876, -0.01138027,  0.0113112 , -0.01312024,  0.00154802,
           0.00816533, -0.00145846,  0.01340849, -0.01006911,  0.01952732],
         [ 0.01008801,  0.00135307, -0.0115126 , -0.00349136, -0.00175548,
           0.01259692, -0.01788977, -0.01747455, -0.00335201,  0.00770757],
         [-0.0051812 , -0.01671031, -0.01600558,  0.00488472,  0.00310151,
           0.01437935,  0.01761029, -0.00858634,  0.00084632,  0.01145467],
         [ 0.0065305 , -0.00679018, -0.00247612,  0.00194623, -0.00034392,
           0.01276916, -0.00199806, -0.00
```

The only trainable parameters that our flow defines are the weights and biases of the MLP. Inside the `decoder_fn` function, the MLP computes the means and log-scales of the conditional density of the dropped dimensions.
The `Slice` surjector itself doesn't have parameters.
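As a quick sanity check, one can count the entries in the parameter tree. The sketch below does not call Haiku. Instead it builds a hypothetical stand-in dictionary, `mock_params`, with the same layer names and shapes as the output printed above: a `[4, 10]` MLP acting on the 5 kept dimensions. It then counts the entries with `jax.tree_util.tree_leaves`. The helper `count_parameters` is illustrative and could equally be applied to the real `params`.

```python
import jax
import jax.numpy as jnp


def count_parameters(params):
    # Sum the number of entries over every array leaf in the nested parameter dict
    return sum(leaf.size for leaf in jax.tree_util.tree_leaves(params))


# Hypothetical stand-in mirroring the printed layer names and shapes:
# linear_0 maps the 5 kept dimensions to 4 hidden units,
# linear_1 maps 4 hidden units to 10 outputs (5 means and 5 log-scales).
mock_params = {
    "mlp/~/linear_0": {"w": jnp.zeros((5, 4)), "b": jnp.zeros(4)},
    "mlp/~/linear_1": {"w": jnp.zeros((4, 10)), "b": jnp.zeros(10)},
}

n_params = count_parameters(mock_params)  # 5*4 + 4 + 4*10 + 10 = 74
```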

We can now test the flow. Let's sample some data first.

```python
samples = flow.apply(
    params, next(rng_key_seq), method="sample", sample_shape=(2,)
)
samples
```

```
Array([[ 0.5136197 ,  0.4787564 , -0.79882836,  0.6568475 , -1.1301666 ,
         0.7758274 , -0.52150553, -0.72154933, -0.17268361,  1.2563365 ],
       [ 0.7778937 ,  0.6387316 ,  0.40471345, -0.10387196,  1.2993045 ,
         0.5106493 ,  1.2868667 , -0.14718308, -0.7556821 , -1.7023089 ]],      dtype=float32)
```
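Conceptually, a sample from this flow is generated in two steps. The base distribution produces the 5 kept dimensions `z`, and the decoder then samples the remaining 5 dimensions conditional on `z`. The following sketch reproduces this generative direction with plain JAX. It replaces the MLP decoder with a hypothetical fixed map `toy_decoder` and assumes that `Slice` concatenates the kept and sampled parts in the order `[z, y_minus]`. It is a simplified illustration, not the `surjectors` code.

```python
import jax
import jax.numpy as jnp


def toy_decoder(z):
    # Hypothetical stand-in for the MLP decoder: a fixed map from z to the
    # means and log-scales of a diagonal Gaussian over the dropped dimensions
    means = 0.5 * z
    log_scales = jnp.zeros_like(z)
    return means, log_scales


def sample_slice_flow(key, n_samples, n_keep):
    key_z, key_eps = jax.random.split(key)
    # Step 1: draw the kept dimensions from a standard-normal base distribution
    z = jax.random.normal(key_z, (n_samples, n_keep))
    # Step 2: draw the dropped dimensions from the decoder, conditional on z
    means, log_scales = toy_decoder(z)
    eps = jax.random.normal(key_eps, means.shape)
    y_minus = means + jnp.exp(log_scales) * eps
    # Step 3: concatenate kept and sampled dimensions into a full sample
    return jnp.concatenate([z, y_minus], axis=-1)


toy_samples = sample_slice_flow(jax.random.PRNGKey(0), n_samples=2, n_keep=5)
```

Because `toy_decoder` returns as many means as it receives kept dimensions, this sketch only covers the even split used in the notebook (5 kept, 5 dropped).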

As mentioned above, to dispatch to a method we simply provide a keyword argument, in this case `method="sample"`. The log probability of data can be computed by changing the method argument to `"log_prob"`.

```python
flow.apply(params, next(rng_key_seq), method="log_prob", y=samples)
```

```
Array([-12.110363, -13.330914], dtype=float32)
```
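For a slice layer that keeps the first $k$ of $p$ dimensions, the log-density decomposes as $\log p(y) = \log p_{\text{base}}(y_{1:k}) + \log p(y_{k+1:p} \mid y_{1:k})$. The second term is the conditional density given by the decoder. The sketch below evaluates this decomposition with `jax.scipy.stats.norm.logpdf`. It uses a hypothetical fixed map `toy_decoder` in place of the MLP and assumes the kept dimensions come first, so it illustrates the computation rather than reproducing the `surjectors` internals.

```python
import jax.numpy as jnp
from jax.scipy.stats import norm


def toy_decoder(z):
    # Hypothetical stand-in for the MLP decoder: returns means and log-scales
    return 0.5 * z, jnp.zeros_like(z)


def slice_log_prob(y, n_keep):
    # Split the data into kept dimensions z and dropped dimensions y_minus
    z, y_minus = y[..., :n_keep], y[..., n_keep:]
    # Log-density of the kept dimensions under a standard-normal base distribution
    base_lp = norm.logpdf(z).sum(axis=-1)
    # Conditional log-density of the dropped dimensions given z (decoder term)
    means, log_scales = toy_decoder(z)
    decoder_lp = norm.logpdf(
        y_minus, loc=means, scale=jnp.exp(log_scales)
    ).sum(axis=-1)
    return base_lp + decoder_lp


y_zero = jnp.zeros((3, 10))
log_probs = slice_log_prob(y_zero, n_keep=5)  # each entry: -5 * log(2 * pi)
```

At `y = 0`, both terms reduce to standard-normal densities at zero, so every entry equals the log-density of a 10-dimensional standard normal at the origin.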

## How to construct `TransformedDistribution` objects

The `TransformedDistribution` class takes a base distribution and a transformation, the latter of which is a `Slice` surjection in our example. The base distribution can be pretty much every