In [1]:
import flax.nnx as nnx
import flax.typing as ftp
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
from functools import partial
from jax_autovmap import auto_vmap
from einops import rearrange

import lfx

# RNG streams (seed 0) handed to the modules constructed in the cells below
rngs = nnx.Rngs(default=0)

# presumably registers the `%shapes` line magic used in later cells
lfx.utils.load_shapes_magic()

# General coupling layers

In coupling layers, some subset of values are updated bijectively conditioned on the other values.

For some choices this can be implemented straightforwardly.
In real NVP, for example, we can employ binary masks and do the transformation directly.

More abstractly, we may have a bijection $x \mapsto f(x; \theta)$ that updates the *active* values $x$.
To turn this into a coupling layer generally, we want to take $\theta(y)$ given passive values $y$.
We can do this manually, of course.

But there is also a convenience class implemented that:
- Converts any module (or state) into the map $\theta \mapsto (x \mapsto f(x; \theta))$
- Extracts the total size and shapes of parameters (such that $\theta$ can be one large array, a list of parameters, or a dict of parameters)
- Handles automatic vmap over the batch dimension (since $\theta(y)$ would generally introduce a batch dimension)

[Note: for most (all?) scalar bijections the reconstructor is enough, even without auto_vmap, as parameter broadcasting is supported]

## Example: single spline

In [2]:
# transform three entries of an array independently -> shape (3,)
# (positional arguments follow the same order as the call in
# `get_spline_template` further down: knots first, then the event shape)
model = lfx.bijections.splines.MonotoneRQSpline(10, (3,), rngs=rngs)
# wrap the module so that it can be rebuilt from externally supplied parameters
model_factory = lfx.bijections.coupling.ModuleReconstructor(model)

In [3]:
# can provide single 1d array of parameters of this size
# (the total over all parameter arrays: 3*9 + 3*9 + 3*10 = 84)
model_factory.params_total_size

np.int64(84)

In [4]:
# can also provide dictionary which matches this structure
# (parameter name -> abstract array carrying shape and dtype)
model_factory.params_dict

{'heights': ShapedArray(float32[3,9]),
 'slopes': ShapedArray(float32[3,9]),
 'widths': ShapedArray(float32[3,10])}

In [5]:
# or "leaves" of the corresponding pytree
# (a flat list of abstract arrays, one per parameter)
model_factory.params_leaves

[ShapedArray(float32[3,9]),
 ShapedArray(float32[3,9]),
 ShapedArray(float32[3,10])]

In [6]:
# convenience attribute: dtype of each parameter leaf (one entry per leaf)
model_factory.params_dtypes

[dtype('float32'), dtype('float32'), dtype('float32')]

In [7]:
# convenience attribute: shape of each parameter leaf, as plain tuples
model_factory.params_shapes

[(3, 9), (3, 9), (3, 10)]

In [8]:
# convenience attribute: the parameter shapes keyed by parameter name
model_factory.params_shape_dict

{'heights': (3, 9), 'slopes': (3, 9), 'widths': (3, 10)}

In [9]:
# caveat: need to be careful if parameters are complex (not fully supported);
# this flag reports whether the wrapped module has any complex parameters
model_factory.has_complex_params

False

In [10]:
# dummy array of parameters (would usually be/depend on the output of some NN)
params_array = np.random.normal(size=(model_factory.params_total_size,))

# example inputs
x = np.ones((10, 3)) / 2
lp = np.zeros((10,))

model = model_factory.from_params(params_array)
%shapes model.forward(x, lp)

((10, 3), (10,))


Note that in the above example, the input is batched (which is fine because the model itself supports that),
but the parameters are not batched.
If we have batched parameters, we effectively have a "stack" of models.
Depending on the bijection, this may be nontrivial to handle (it is trivial for real NVP where the affine transformation is trivially broadcastable, but not e.g. for splines).

If we provide the optional argument `auto_vmap=True` to `from_params`, the object effectively acts like an instance of `model`, except that function calls add `params` as first argument and automatically vmap over these and the inputs.

In addition, the ranks (i.e. dimensions) of the inputs need to be provided to know how vmap should be applied.
This defaults to $(0, 0)$, which matches many NF cases (vector state and scalar log-density), but the splines here are actually implemented to assume 1D array inputs, so `input_ranks=(1, 0)` has to be passed explicitly.

In [11]:
# add batch dimension to parameters
p = np.random.normal(size=(10, model_factory.params_total_size,))
rec = model_factory.from_params(p, auto_vmap=True)

# spline above was initialized to expect 1D inputs, not scalar, so need to specify
%shapes rec.forward(x, lp, input_ranks=(1, 0))

((10, 3), (10,))


In [12]:
# Wrapping this in jax.jit is not strictly needed, but in principle it avoids
# allocating the parameters, which are not used in the end.
@partial(jax.jit, static_argnums=(0, 1))
def get_spline_template(size, knots):
    """Return a reconstructor (params -> bijection) for a monotone RQ spline.

    Args:
        size: Length of the 1D array the spline acts on; the spline is built
            with event shape ``(size,)``.
        knots: First positional argument of ``MonotoneRQSpline``.

    Returns:
        A ``ModuleReconstructor`` wrapping a freshly constructed spline.
    """
    # the seed is arbitrary: the output is not numerical, so it does not matter
    template_spline = lfx.bijections.splines.MonotoneRQSpline(
        knots, (size,), rngs=nnx.Rngs(0)
    )

    # map params -> bijection
    return lfx.bijections.coupling.ModuleReconstructor(template_spline)

In [13]:
class SplineCouplingLayer(lfx.Bijection):
    """Coupling layer that updates the active entries with a monotone RQ spline.

    The spline parameters are produced by a conditioner network (linear
    embedding, sine activation, residual network) that is evaluated on the
    passive entries.

    Args:
        mask: Boolean array over the event entries; converted with
            ``lfx.BinaryMask.from_boolean_mask``.
        embeds: Output size of the linear embedding, which is also the first
            argument of ``SimpleResNet``.
        depth: Passed to ``SimpleResNet`` (fourth positional argument).
        width: Passed to ``SimpleResNet`` (third positional argument).
        rngs: ``nnx.Rngs`` used to initialise the networks.
    """

    def __init__(self, mask, embeds, depth, width, rngs):
        mask = lfx.BinaryMask.from_boolean_mask(mask)

        # The spline acts on the active entries, while the conditioner reads
        # the passive ones; the two counts differ for unbalanced masks, so
        # they must not be mixed up. Cast to int so the active count is a
        # hashable static argument of the jitted template factory.
        count_active, count_passive = (int(count) for count in mask.counts)

        spline_template = get_spline_template(count_active, 10)

        # NOTE: the resnet is built before the linear embedding so that the
        # shared `rngs` stream is consumed in the same order as before.
        resnet = lfx.nn.simple_nets.SimpleResNet(
            embeds, spline_template.params_total_size, width, depth,
            final_kernel_init=nnx.initializers.normal(),
            final_bias_init=nnx.initializers.zeros,
            rngs=rngs,
        )
        param_net = nnx.Sequential(
            nnx.Linear(count_passive, embeds, rngs=rngs),
            jnp.sin,
            resnet,
        )

        self.bijection = lfx.GeneralCouplingLayer(
            param_net, mask, spline_template,
            # by default assume bijections are scalar, but here the splines
            # operate (independently) on entries of a 1D array
            bijection_event_rank=1,
        )

    def forward(self, x, log_density, **kwargs):
        """Apply the wrapped coupling layer, treating `x` as rank-1 events."""
        return self.bijection.forward(x, log_density, input_ranks=(1, 0), **kwargs)

    def reverse(self, x, log_density, **kwargs):
        """Apply the reverse of the wrapped coupling layer (rank-1 events)."""
        return self.bijection.reverse(x, log_density, input_ranks=(1, 0), **kwargs)

In [14]:
# boolean mask over the two entries of the event
mask = np.array([True, False])

layers = []

# stack five coupling layers (embeds=12, depth=2, width=32), inverting the
# mask after each one so that consecutive layers use complementary masks
for _ in range(5):
    layers.append(SplineCouplingLayer(mask, 12, 2, 32, rngs=rngs))
    mask = ~mask

# compose the layers into a single bijection
model = lfx.Chain(*layers)

In [15]:
# sanity check: output shapes for a batch of 10 two-component samples
%shapes model.forward(jnp.ones((10, 2)), jnp.zeros((10,)))

((10, 2), (10,))


Note that above, the fundamental transformation of splines was taken to operate on a whole array.
In principle, it need not have applied independent operations per entry, which shows the generality of the coupling layer implementation.
However, in this specific case, splines do act element-wise.
Some bijections in fact may only be defined on single scalar values.
Below is the same example as above, but treating the "array" index as part of the "batch" index.

In [16]:
# a single spline built without an event shape; its reconstructor is shared
# as the template by `spline_coupling_layer` below
knots = 11
spline_template = lfx.ModuleReconstructor(
      lfx.MonotoneRQSpline(knots, rngs=nnx.Rngs(0))
)

def spline_coupling_layer(mask, embeds, depth, width, rngs):
    """Build a coupling layer that applies one spline per active entry.

    The conditioner reads the passive entries and emits one parameter vector
    of length ``spline_template.params_total_size`` for every active entry.

    Args:
        mask: Boolean array; converted with ``lfx.BinaryMask.from_boolean_mask``.
        embeds: Output size of the linear embedding / first ``SimpleResNet`` argument.
        depth: Passed to ``SimpleResNet`` (fourth positional argument).
        width: Passed to ``SimpleResNet`` (third positional argument).
        rngs: ``nnx.Rngs`` used to initialise the networks.

    Returns:
        The ``lfx.GeneralCouplingLayer`` built from the conditioner, the mask
        and the module-level ``spline_template``.
    """
    binary_mask = lfx.BinaryMask.from_boolean_mask(mask)
    n_active, n_passive = binary_mask.counts
    params_per_spline = spline_template.params_total_size

    # built first, so the shared `rngs` stream is consumed resnet -> linear
    trunk = lfx.nn.simple_nets.SimpleResNet(
        embeds,
        n_active * params_per_spline,
        width,
        depth,
        final_kernel_init=nnx.initializers.normal(),
        final_bias_init=nnx.initializers.zeros,
        rngs=rngs,
    )

    conditioner = nnx.Sequential(
        nnx.Linear(n_passive, embeds, rngs=rngs),
        jnp.sin,
        trunk,
        # split the flat output into one parameter vector per active entry
        lambda x: rearrange(x, '... (t b) -> ... t b', t=n_active),
    )

    return lfx.GeneralCouplingLayer(conditioner, binary_mask, spline_template)

In [17]:
# boolean mask over the two entries of the event
mask = np.array([True, False])

layers = []

# stack five coupling layers (embeds=12, depth=2, width=32), inverting the
# mask after each one so that consecutive layers use complementary masks
for _ in range(5):
    layers.append(spline_coupling_layer(mask, 12, 2, 32, rngs=rngs))
    mask = ~mask

# compose the layers into a single bijection
model = lfx.Chain(*layers)

In [18]:
# shape of the first element of the forward output (the transformed samples)
model.forward(jnp.ones((10, 2)), jnp.zeros((10,)))[0].shape

(10, 2)

In [19]:
class ManualAffineCoupling(lfx.Bijection):
    """
    Hand-written affine coupling layer.

    The mask is applied by multiplication rather than by indexing, so all
    arrays keep the full event shape throughout.

    Example:
        ```python
        space_shape = (16, 16)  # no channel/feature dim (add dummy axis below)

        affine_flow = lfx.Chain(
            lfx.ExpandDims(),
            ManualAffineCoupling(
                lfx.SimpleConvNet(1, 2, rngs=rngs),
                lfx.checker_mask(space_shape + (1,), True)),
            ManualAffineCoupling(
                lfx.SimpleConvNet(1, 2, rngs=rngs),
                lfx.checker_mask(space_shape + (1,), False)),
            lfx.ExpandDims().invert(),
        )
        ```

    `net` maps the frozen part `x_f` to an activation that is split in two
    along the last axis, `s, t = split(act, 2, -1)`. With `m` the active mask
    and `x_a = m * x`, the forward map is
    `x_out = x_f + m * t + x_a * exp(s)`.

    Args:
        net: Network that maps frozen features to s, t.
        mask: BinaryMask to apply to input.
        rngs: Accepted but not used by this class.
    """

    def __init__(self, net: nnx.Module, mask: lfx.BinaryMask, *, rngs=None):
        self.mask = mask
        self.net = net

    @property
    def mask_active(self):
        """Mask selecting the entries that are transformed."""
        return self.mask

    @property
    def mask_frozen(self):
        """Complement of the active mask: entries passed through unchanged."""
        return ~self.mask

    def _conditioner(self, frozen):
        """Return `(s, t, log_jac)` computed from the frozen entries."""
        s, t = jnp.split(self.net(frozen), 2, -1)
        # sum the masked log-scales over the trailing event axes
        event_axes = tuple(range(-len(self.mask.event_shape), 0))
        log_jac = jnp.sum((self.mask_active * s), axis=event_axes)
        return s, t, log_jac

    def forward(self, x, log_density):
        """Shift and scale the active entries; subtract the log-Jacobian."""
        frozen = self.mask_frozen * x
        active = self.mask_active * x
        s, t, log_jac = self._conditioner(frozen)

        out = frozen + (self.mask_active * t) + active * jnp.exp(s)
        return out, log_density - log_jac

    def reverse(self, fx, log_density):
        """Undo `forward` on the active entries; add the log-Jacobian back."""
        frozen = self.mask_frozen * fx
        active = self.mask_active * fx
        s, t, log_jac = self._conditioner(frozen)

        restored = (active - (self.mask_active * t)) * jnp.exp(-s) + frozen
        return restored, log_density + log_jac

In [20]:
space_shape = (16, 16)  # no channel/feature dim (add dummy axis below)

# two manual affine couplings with opposite checkerboard masks, wrapped
# between ExpandDims and its inverse; the masks are built for the shape
# with the extra trailing axis, `space_shape + (1,)`
affine_flow_1 = lfx.Chain(
    lfx.ExpandDims(),
    ManualAffineCoupling(
        lfx.SimpleConvNet(1, 2, rngs=rngs),
        lfx.checker_mask(space_shape + (1,), True)),
    ManualAffineCoupling(
        lfx.SimpleConvNet(1, 2, rngs=rngs),
        lfx.checker_mask(space_shape + (1,), False)),
    lfx.ExpandDims().invert(),
)

In [21]:
# sanity check: output shapes for a batch of two 16x16 configurations
%shapes affine_flow_1.forward(np.ones((2, 16, 16)), jnp.zeros((2,)))

((2, 16, 16), (2,))


In [22]:
def masked_affine_coupling(
    net,
    mask,
    template=lfx.ModuleReconstructor(lfx.LinearAffine(rngs=nnx.Rngs(0)))
):
    """Build an affine coupling layer from ``lfx.GeneralCouplingLayer``.

    Args:
        net: Network passed to the coupling layer as its first argument.
        mask: An ``lfx.BinaryMask``, or anything accepted by
            ``lfx.BinaryMask.from_boolean_mask``.
        template: Reconstructor of the bijection to apply; defaults to one
            wrapping ``lfx.LinearAffine``. The default is created once, when
            the function is defined, and reused by every call.

    Returns:
        An ``lfx.GeneralCouplingLayer`` constructed with ``split=False``.
    """
    binary_mask = (
        mask
        if isinstance(mask, lfx.BinaryMask)
        else lfx.BinaryMask.from_boolean_mask(mask)
    )
    return lfx.GeneralCouplingLayer(net, binary_mask, template, split=False)


In [23]:
# same two-layer flow built from the general coupling layer; here the trailing
# feature axis is added inside each net (`x[..., None]`), so the masks use
# `space_shape` directly and no ExpandDims layers are needed
affine_flow = lfx.Chain(
    masked_affine_coupling(
        nnx.Sequential(lambda x: x[..., None], lfx.SimpleConvNet(1, 2, rngs=rngs)),
        lfx.checker_mask(space_shape, True)),
    masked_affine_coupling(
        nnx.Sequential(lambda x: x[..., None], lfx.SimpleConvNet(1, 2, rngs=rngs)),
        lfx.checker_mask(space_shape, False)),
)

In [24]:
# sanity check: output shapes for a single 16x16 configuration
%shapes affine_flow.forward(np.ones((1, 16, 16)), jnp.zeros((1,)))

((1, 16, 16), (1,))
