In [1]:
import flax.nnx as nnx
import flax.typing as ftp
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
from functools import partial
from jax_autovmap import auto_vmap

import lfx

# Shared Flax NNX RNG container (default stream seeded with 0); it is passed
# to the modules constructed below.
rngs = nnx.Rngs(default=0)

# General coupling layers

In coupling layers, some subset of values are updated bijectively conditioned on the other values.

For some choices this can be implemented straightforwardly.
In real NVP, for example, we can employ binary masks and do the transformation directly.

More abstractly, we may have a bijection $x \mapsto f(x; \theta)$ that updates the *active* values $x$.
To turn this into a coupling layer generally, we want to take $\theta(y)$ given passive values $y$.
We can do this manually, of course.

But there is also a convenience class that:
- Converts any module (or state) into the map $\theta \mapsto (x \mapsto f(x; \theta))$
- Extracts the total size and shapes of the parameters (such that $\theta$ can be one large array, a list of parameters, or a dict of parameters)
- Handles automatic vmap over the batch dimension (since $\theta(y)$ would generally introduce a batch dimension)

## Example: single spline

In [2]:
# Spline bijection constructed with size 3 and 10 knots.
model = lfx.bijections.splines.MonotoneRQSpline(3, 10, rngs=rngs)
# Wrap the module so that it can be rebuilt from externally supplied parameters.
model_factory = lfx.bijections.coupling.ModuleReconstructor(model)

In [3]:
# Total number of scalar parameters: a single flat 1-D array of this size
# can be provided to specify all of them at once.
model_factory.params_total_size

np.int64(87)

In [4]:
# Alternatively, a dictionary matching this structure can be provided
# (parameter name -> abstract array carrying shape and dtype).
model_factory.params_dict

{'heights': ShapedArray(float32[3,10]),
 'slopes': ShapedArray(float32[3,9]),
 'widths': ShapedArray(float32[3,10])}

In [5]:
# ... or the "leaves" of the corresponding pytree, given as a flat list.
model_factory.params_leaves

[ShapedArray(float32[3,10]),
 ShapedArray(float32[3,9]),
 ShapedArray(float32[3,10])]

In [6]:
# Convenience attribute: the dtypes of the parameters, one entry per leaf.
model_factory.params_dtypes

[dtype('float32'), dtype('float32'), dtype('float32')]

In [7]:
# Convenience attribute: the shapes of the parameters, one entry per leaf.
model_factory.params_shapes

[(3, 10), (3, 9), (3, 10)]

In [8]:
# Convenience attribute: the parameter shapes keyed by parameter name.
model_factory.params_shape_dict

{'heights': (3, 10), 'slopes': (3, 9), 'widths': (3, 10)}

In [9]:
# Caveat: complex parameters are not fully supported, so some care is needed
# if this flag is True (it is False here: all parameters are float32).
model_factory.has_complex_params

False

In [10]:
# Dummy flat array of parameters (these would usually be, or depend on, the
# output of some NN). Drawn from a seeded generator so that re-running the
# notebook reproduces the same numbers.
params_array = np.random.default_rng(0).normal(
    size=(model_factory.params_total_size,))

# Example inputs: a batch of 10 identical points and zero initial log-density.
x = np.ones((10, 3)) / 2
lp = np.zeros((10,))

# Rebuild a concrete bijection from the (unbatched) parameters and apply it
# to the batched inputs.
model = model_factory.from_array(params_array)
model.forward(x, lp)

(Array([[0.44812006, 0.74798065, 0.3836273 ],
        [0.44812006, 0.74798065, 0.3836273 ],
        [0.44812006, 0.74798065, 0.3836273 ],
        [0.44812006, 0.74798065, 0.3836273 ],
        [0.44812006, 0.74798065, 0.3836273 ],
        [0.44812006, 0.74798065, 0.3836273 ],
        [0.44812006, 0.74798065, 0.3836273 ],
        [0.44812006, 0.74798065, 0.3836273 ],
        [0.44812006, 0.74798065, 0.3836273 ],
        [0.44812006, 0.74798065, 0.3836273 ]], dtype=float32),
 Array([1.6312416, 1.6312416, 1.6312416, 1.6312416, 1.6312416, 1.6312416,
        1.6312416, 1.6312416, 1.6312416, 1.6312416], dtype=float32))

Note that in the above example, the input is batched (which is fine because the model itself supports that),
but the parameters are not batched.
If we have batched parameters, we effectively have a "stack" of models.
Depending on the bijection, this may be nontrivial to handle: it is trivial for real NVP, where the affine transformation broadcasts naturally, but not, e.g., for splines.

`.auto_vmap_array`, `.auto_vmap_dict` and `.auto_vmap_leaves` effectively act like an instance of `model`, except that function calls add `params` as first argument and automatically vmap over these and the inputs.

In addition, the ranks (i.e. dimensions) of the inputs need to be provided to know how vmap should be applied.
This defaults to $(1, 0)$ which matches many NF cases (vector state and scalar log-density).

In [11]:
# Add a batch dimension to the parameters: one parameter vector per sample.
# Drawn from a seeded generator so that re-running the notebook reproduces
# the same numbers.
p = np.random.default_rng(1).normal(size=(10, model_factory.params_total_size))

model_factory.auto_vmap_array.forward(p, x, lp)

(Array([[0.41064703, 0.01066445, 0.2767329 ],
        [0.37461302, 0.7361562 , 0.30834916],
        [0.3074231 , 0.5192351 , 0.39163208],
        [0.12614258, 0.47200948, 0.53802824],
        [0.25864998, 0.53435785, 0.6444026 ],
        [0.32224256, 0.8636802 , 0.35441762],
        [0.43365756, 0.91749805, 0.17344713],
        [0.7867354 , 0.49086675, 0.39499205],
        [0.5754057 , 0.31122735, 0.48212758],
        [0.67516726, 0.19719765, 0.654301  ]], dtype=float32),
 Array([1.0165834e+01, 5.0352812e-03, 1.0950508e+00, 6.2632942e+00,
        6.2445688e+00, 7.2555504e+00, 4.3614316e+00, 7.4305475e-02,
        1.3601456e+01, 5.1075807e+00], dtype=float32))

In [12]:
# Explicitly provide the input ranks: x is treated as rank 1 (a vector per
# sample) and lp as rank 0 (a scalar per sample). (1, 0) is also the default.
autovmap_model = model_factory.auto_vmap_array

autovmap_model.forward(p, x, lp, input_ranks=(1, 0))

(Array([[0.41064703, 0.01066445, 0.2767329 ],
        [0.37461302, 0.7361562 , 0.30834916],
        [0.3074231 , 0.5192351 , 0.39163208],
        [0.12614258, 0.47200948, 0.53802824],
        [0.25864998, 0.53435785, 0.6444026 ],
        [0.32224256, 0.8636802 , 0.35441762],
        [0.43365756, 0.91749805, 0.17344713],
        [0.7867354 , 0.49086675, 0.39499205],
        [0.5754057 , 0.31122735, 0.48212758],
        [0.67516726, 0.19719765, 0.654301  ]], dtype=float32),
 Array([1.0165834e+01, 5.0352812e-03, 1.0950508e+00, 6.2632942e+00,
        6.2445688e+00, 7.2555504e+00, 4.3614316e+00, 7.4305475e-02,
        1.3601456e+01, 5.1075807e+00], dtype=float32))

In [13]:
# Wrapping this in jax.jit is not strictly needed, but in principle it avoids
# allocating the parameters, which are not used in the end.
@partial(jax.jit, static_argnums=(0, 1))
def get_spline_template(size, knots):
    """Build a reconstructor (params -> bijection) for a ``MonotoneRQSpline``.

    Both arguments are static under ``jax.jit`` and are therefore expected to
    be hashable.

    Args:
        size: First positional argument forwarded to ``MonotoneRQSpline``.
        knots: Second positional argument forwarded to ``MonotoneRQSpline``.

    Returns:
        A ``ModuleReconstructor`` wrapping the freshly constructed spline.
    """
    # The output is not numerical, so the particular seed does not matter here.
    template_rngs = nnx.Rngs(0)

    # Bijection that serves purely as a structural template.
    template_spline = lfx.bijections.splines.MonotoneRQSpline(
        size, knots, rngs=template_rngs)

    # Reconstructor: the map params -> bijection.
    return lfx.bijections.coupling.ModuleReconstructor(template_spline)

In [14]:
# Example: reconstructor template for a spline of size 1 with 10 knots.
spline_template = get_spline_template(1, 10)

In [None]:
class SplineCouplingLayer(lfx.Bijection):
    """Coupling layer that updates the active values with a spline.

    The entries of ``x`` selected by ``mask`` along the last axis (the active
    values) are transformed by a spline whose parameters are computed from the
    remaining entries (the passive values), which are left unchanged.

    Args:
        mask: Boolean array over the last axis of the inputs; ``True`` marks
            the active entries and ``False`` the passive ones.
        embeds: Output size of the linear embedding of the passive values.
        depth: Forwarded to ``SimpleResNet``.
        width: Forwarded to ``SimpleResNet``.
        rngs: ``nnx.Rngs`` used to initialise the embedding and the network.
        knots: Number of knots passed to ``get_spline_template`` (default 10).
    """

    def __init__(self, mask, embeds, depth, width, rngs, knots=10):
        # Plain Python ints: hashable, as required for the static arguments
        # of the jitted ``get_spline_template``.
        n_active = int(mask.sum())
        n_passive = int((~mask).sum())

        # The embedding is applied to the *passive* values, so its input size
        # must be the number of passive entries (not the number of active ones;
        # the two only coincide for masks that are half True, half False).
        self.embeds = nnx.Linear(n_passive, embeds, rngs=rngs)
        self.spline_template = get_spline_template(n_active, knots)
        # Network mapping the embedded passive values to one flat spline
        # parameter vector per sample.
        self.couplings = lfx.nn.simple_nets.SimpleResNet(
            embeds, self.spline_template.params_total_size, width, depth,
            final_kernel_init=nnx.initializers.normal(),
            final_bias_init=nnx.initializers.zeros,
            rngs=rngs)
        self.mask = mask

    def split(self, x):
        """Return the ``(active, passive)`` parts of ``x`` along the last axis."""
        return x[..., self.mask], x[..., ~self.mask]

    def _spline_params(self, passive):
        """Compute the flat spline parameters conditioned on the passive values."""
        h = jnp.sin(self.embeds(passive))
        return self.couplings(h)

    def forward(self, x, log_density, **kwargs):
        """Apply the spline to the active values of ``x``.

        Extra keyword arguments are accepted and ignored.

        Returns:
            Tuple ``(x, log_density)`` with the active entries of ``x``
            replaced and the log-density as returned by the spline.
        """
        active, passive = self.split(x)
        params = self._spline_params(passive)
        active, log_density = self.spline_template.auto_vmap_array.forward(
            params, active, log_density)
        x = x.at[..., self.mask].set(active)
        return x, log_density

    def reverse(self, x, log_density, **kwargs):
        """Apply the reverse spline map to the active values of ``x``.

        The passive values are untouched by ``forward``, so the same spline
        parameters are recomputed from them here. Extra keyword arguments are
        accepted and ignored.

        Returns:
            Tuple ``(x, log_density)`` with the active entries of ``x``
            replaced and the log-density as returned by the spline.
        """
        active, passive = self.split(x)
        params = self._spline_params(passive)
        active, log_density = self.spline_template.auto_vmap_array.reverse(
            params, active, log_density)
        x = x.at[..., self.mask].set(active)
        return x, log_density


In [16]:
# Start with the first component active and the second passive.
mask = np.array([True, False])

layers = []

# Stack five coupling layers, swapping active and passive roles after each.
for _ in range(5):
    layers += [
        SplineCouplingLayer(mask, embeds=12, depth=2, width=32, rngs=rngs)
    ]
    mask = np.logical_not(mask)

model = lfx.Chain(*layers)

In [17]:
# Sanity check: push a batch of 10 two-component inputs (all ones) with zero
# initial log-density through the chain.
model.forward(jnp.ones((10, 2)), jnp.zeros((10,)))

(Array([[0.9999998, 1.0000001],
        [0.9999998, 1.0000001],
        [0.9999998, 1.0000001],
        [0.9999998, 1.0000001],
        [0.9999998, 1.0000001],
        [0.9999998, 1.0000001],
        [0.9999998, 1.0000001],
        [0.9999998, 1.0000001],
        [0.9999998, 1.0000001],
        [0.9999998, 1.0000001]], dtype=float32),
 Array([-8.821487e-06, -8.821487e-06, -8.821487e-06, -8.821487e-06,
        -8.821487e-06, -8.821487e-06, -8.821487e-06, -8.821487e-06,
        -8.821487e-06, -8.821487e-06], dtype=float32))