In [1]:
from typing import Optional

import distrax

In [34]:
import jax
import jax.numpy as jnp
import jax.random as jr
from flowjax.train import fit_to_data
from flowjax.train.losses import MaximumLikelihoodLoss

In [35]:
import math
from functools import partial
from typing import Callable, Optional, Sequence

import equinox as eqx
import jax
from jax.random import PRNGKey

In [19]:
class MLPConditioner(eqx.Module):
    """Multi-Layer Perceptron (MLP) Conditioner using Equinox.

    This conditioner takes the flow input, passes it through an MLP, and produces
    parameters for the bijector as required. Typically used with Masked Coupling Bijector
    to ensure the lower triangular Jacobian is preserved.
    """

    mlp: eqx.nn.MLP

    def __init__(
        self,
        output_dim: int,
        hidden_sizes: Sequence[int],
        num_bijector_params: int,
        key: PRNGKey,
        activation: Callable = jax.nn.relu,
        final_activation: Callable = lambda x: x,
        name: Optional[str] = None,
    ):
        super().__init__()
        depth = len(hidden_sizes) + 1  # Including the output layer
        width_size = (
            hidden_sizes[0] if hidden_sizes else output_dim
        )  # Handle depth=1 case

        # The last layer's size is determined by output_dim and num_bijector_params
        out_size = output_dim * num_bijector_params
        # MLP setup
        self.mlp = eqx.nn.MLP(
            in_size="scalar",  # Assuming input to MLPConditioner is already flattened
            out_size=out_size,  # Output size is tailored for the bijector parameterization
            width_size=width_size,
            depth=depth,
            activation=activation,
            final_activation=final_activation,
            use_bias=True,
            use_final_bias=True,
            key=key,
        )

    def __call__(self, x, *, key: Optional[PRNGKey] = None):
        # Assuming x has shape (batch_size, output_dim), and we're flattening it
        # to work with Equinox's MLP expecting scalar inputs. Adjust as necessary.
        x = x.reshape(
            (x.shape[0], -1)
        )  # Flatten if necessary, preserving batch dimension
        out = self.mlp(x, key=key)  # Key is optional, here for compatibility
        # Reshaping the output to have shape (batch_size, output_dim, num_bijector_params)
        out = out.reshape((x.shape[0], -1, num_bijector_params))
        return out

In [20]:
MLPConditioner(
    key=PRNGKey(9),
    output_dim=math.prod((2,)),
    hidden_sizes=[5, 5, 5],
    num_bijector_params=3 * 10 + 1,
    name="conditioner_phi",
)

MLPConditioner(
  mlp=MLP(
    layers=(
      Linear(
        weight=f32[5,1],
        bias=f32[5],
        in_features='scalar',
        out_features=5,
        use_bias=True
      ),
      Linear(
        weight=f32[5,5],
        bias=f32[5],
        in_features=5,
        out_features=5,
        use_bias=True
      ),
      Linear(
        weight=f32[5,5],
        bias=f32[5],
        in_features=5,
        out_features=5,
        use_bias=True
      ),
      Linear(
        weight=f32[5,5],
        bias=f32[5],
        in_features=5,
        out_features=5,
        use_bias=True
      ),
      Linear(
        weight=f32[62,5],
        bias=f32[62],
        in_features=5,
        out_features=62,
        use_bias=True
      )
    ),
    activation=<wrapped function relu>,
    final_activation=<function <lambda>>,
    use_bias=True,
    use_final_bias=True,
    in_size='scalar',
    out_size=62,
    width_size=5,
    depth=4
  )
)

In [21]:
def NSF(
    key: PRNGKey,
    phi_dim: int,
    num_layers: int,
    hidden_sizes: Sequence[int],
    num_bins: int,
    range_min: float = 0.0,
    range_max: float = 1.0,
    **_,
) -> distrax.Transformed:
    """Creates the Rational Quadratic Flow model.

    Args:
    range_min: the lower bound of the spline's range. Below `range_min`, the
      bijector defaults to a linear transformation.
    range_max: the upper bound of the spline's range. Above `range_max`, the
      bijector defaults to a linear transformation.
    """

    flow_dim = phi_dim

    event_shape = (flow_dim,)

    flow_layers = []

    # Number of parameters required by the bijector (rational quadratic spline)
    num_bijector_params = 3 * num_bins + 1

    def bijector_fn(params):
        return distrax.RationalQuadraticSpline(
            params, range_min=range_min, range_max=range_max
        )

    # Alternating binary mask.
    mask = jnp.arange(0, math.prod(event_shape)) % 2
    mask = jnp.reshape(mask, event_shape)
    mask = mask.astype(bool)

    # Number of parameters for the rational-quadratic spline:
    # - `num_bins` bin widths
    # - `num_bins` bin heights
    # - `num_bins + 1` knot slopes
    # for a total of `3 * num_bins + 1` parameters.

    for _ in range(num_layers):
        layer = distrax.MaskedCoupling(
            mask=mask,
            bijector=bijector_fn,
            conditioner=MLPConditioner(
                key=key,
                output_dim=math.prod(event_shape),
                hidden_sizes=hidden_sizes,
                num_bijector_params=num_bijector_params,
                name="conditioner_phi",
            ),
        )
        flow_layers.append(layer)
        # Flip the mask after each layer.
        mask = jnp.logical_not(mask)

    # Last layer: Map values to parameter domain
    # phi goes to [0,1]
    flow_layers.append(distrax.Block(distrax.Sigmoid(), 1))

    flow = distrax.Chain(flow_layers[::-1])

    # base_distribution = distrax.Independent(
    #     distrax.Uniform(low=jnp.zeros(event_shape), high=jnp.ones(event_shape)),
    #     reinterpreted_batch_ndims=len(event_shape))

    base_distribution = distrax.MultivariateNormalDiag(
        loc=jnp.zeros(event_shape), scale_diag=jnp.ones(event_shape)
    )

    return distrax.Transformed(base_distribution, flow)

In [22]:
nvars = 2
key, x_key = jr.split(jr.PRNGKey(0))
x = jr.normal(x_key, shape=(5000, nvars))
# x = jr.beta(x_key, a=0.4, b=0.4, shape=(5000, nvars))

In [23]:
key, subkey = jr.split(jr.PRNGKey(0))

# Create the flow
untrained_flow = NSF(
    key=subkey,
    phi_dim=2,  # x dim
    num_layers=8,
    hidden_sizes=[5, 5, 5],
    num_bins=10,
    range_min=0.0,
    range_max=1.0,
)

nsf_constructor = partial(
    NSF,
    phi_dim=2,  # Example configuration; adjust as needed
    num_layers=3,
    hidden_sizes=[64, 64],
    num_bins=10,
    range_min=0.0,
    range_max=1.0,
    # Include other parameters as required by your NSF model
)

In [16]:
untrained_flow

<distrax._src.distributions.transformed.Transformed at 0x1899fff50>

In [31]:
import optax
from tqdm.auto import tqdm


def fit_to_data(
    key: jnp.ndarray,
    nsf_constructor: Callable,
    x: jnp.ndarray,
    *,
    condition: Optional[jnp.ndarray] = None,
    max_epochs: int = 100,
    batch_size: int = 100,
    learning_rate: float = 5e-4,
    optimizer: Optional[optax.GradientTransformation] = None,
    return_best: bool = True,
    show_progress: bool = True,
):
    # Assuming the NSF constructor accepts a PRNGKey and returns an initialized flow
    nsf_flow = nsf_constructor(key=key)

    if optimizer is None:
        optimizer = optax.adam(learning_rate)

    loss_fn = MaximumLikelihoodLoss()

    params, static = eqx.partition(nsf_flow, eqx.is_array)
    opt_state = optimizer.init(params)
    best_params, best_loss = params, jnp.inf

    for epoch in tqdm(range(max_epochs), disable=not show_progress):
        # Shuffle the data at the beginning of each epoch
        indices = jr.permutation(key, len(x))
        x_shuffled = x[indices]

        epoch_loss = []
        for i in range(0, len(x), batch_size):
            batch_x = x_shuffled[i : i + batch_size]
            batch_cond = (
                condition[i : i + batch_size] if condition is not None else None
            )

            # Gradient update step
            grads, batch_loss = jax.value_and_grad(loss_fn, argnums=0)(
                params, static, batch_x, batch_cond
            )
            updates, opt_state = optimizer.update(grads, opt_state)
            params = optax.apply_updates(params, updates)

            epoch_loss.append(batch_loss)

        epoch_loss = jnp.mean(jnp.stack(epoch_loss))
        if epoch_loss < best_loss:
            best_loss = epoch_loss
            best_params = params

        tqdm.write(f"Epoch {epoch+1}, Loss: {epoch_loss}")

    if return_best:
        trained_flow = eqx.combine(static, best_params)
    else:
        trained_flow = eqx.combine(static, params)

    return trained_flow, best_loss

In [32]:
untrained_flow

<distrax._src.distributions.transformed.Transformed at 0x18a567f50>

In [33]:
key, subkey = jr.split(key)

# Train on the unbounded space
flow, losses = fit_to_data(
    key=subkey,
    nsf_constructor=nsf_constructor,
    x=x,
    learning_rate=5e-4,
    max_epochs=70,
    return_best=False,
    batch_size=5000,
)

ValueError: Mismatch custom node data: ([], [], PyTreeDef({'_batch_shape': None, '_bijector': CustomNode(Chain[([1, 1, False, False], [False, False, False, False], PyTreeDef({'_bijectors': [CustomNode(Block[([1, 1, False, False, 1], [False, False, False, False, False], PyTreeDef({'_bijector': CustomNode(Sigmoid[([0, 0, False, False], [False, False, False, False], PyTreeDef({'_event_ndims_in': *, '_event_ndims_out': *, '_is_constant_jacobian': *, '_is_constant_log_det': *}))], [None, None, None, None]), '_event_ndims_in': *, '_event_ndims_out': *, '_is_constant_jacobian': *, '_is_constant_log_det': *, '_ndims': *}))], [None, None, None, None, None]), CustomNode(MaskedCoupling[([<function NSF.<locals>.bijector_fn at 0x18a30aca0>, True, True, True, True, True, True, True, True, <jax._src.custom_derivatives.custom_jvp object at 0x14fcaea50>, <function MLPConditioner.<lambda> at 0x18a57e700>, True, 1, 1, 1, 0, False, False, True], [False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False], PyTreeDef({'_bijector': *, '_conditioner': CustomNode(MLPConditioner[('mlp',), (), ()], [CustomNode(MLP[('layers', 'activation', 'final_activation'), ('use_bias', 'use_final_bias', 'in_size', 'out_size', 'width_size', 'depth'), (True, True, 'scalar', 62, 64, 3)], [(CustomNode(Linear[('weight', 'bias'), ('in_features', 'out_features', 'use_bias'), ('scalar', 64, True)], [*, *]), CustomNode(Linear[('weight', 'bias'), ('in_features', 'out_features', 'use_bias'), (64, 64, True)], [*, *]), CustomNode(Linear[('weight', 'bias'), ('in_features', 'out_features', 'use_bias'), (64, 64, True)], [*, *]), CustomNode(Linear[('weight', 'bias'), ('in_features', 'out_features', 'use_bias'), (64, 62, True)], [*, *])), *, *])]), '_event_mask': *, '_event_ndims': *, '_event_ndims_in': *, '_event_ndims_out': *, '_inner_event_ndims': *, '_is_constant_jacobian': *, '_is_constant_log_det': *, '_mask': *}))], [None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None]), CustomNode(MaskedCoupling[([<function NSF.<locals>.bijector_fn at 0x18a30aca0>, True, True, True, True, True, True, True, True, <jax._src.custom_derivatives.custom_jvp object at 0x14fcaea50>, <function MLPConditioner.<lambda> at 0x18a57e700>, True, 1, 1, 1, 0, False, False, True], [False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False], PyTreeDef({'_bijector': *, '_conditioner': CustomNode(MLPConditioner[('mlp',), (), ()], [CustomNode(MLP[('layers', 'activation', 'final_activation'), ('use_bias', 'use_final_bias', 'in_size', 'out_size', 'width_size', 'depth'), (True, True, 'scalar', 62, 64, 3)], [(CustomNode(Linear[('weight', 'bias'), ('in_features', 'out_features', 'use_bias'), ('scalar', 64, True)], [*, *]), CustomNode(Linear[('weight', 'bias'), ('in_features', 'out_features', 'use_bias'), (64, 64, True)], [*, *]), CustomNode(Linear[('weight', 'bias'), ('in_features', 'out_features', 'use_bias'), (64, 64, True)], [*, *]), CustomNode(Linear[('weight', 'bias'), ('in_features', 'out_features', 'use_bias'), (64, 62, True)], [*, *])), *, *])]), '_event_mask': *, '_event_ndims': *, '_event_ndims_in': *, '_event_ndims_out': *, '_inner_event_ndims': *, '_is_constant_jacobian': *, '_is_constant_log_det': *, '_mask': *}))], [None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None]), CustomNode(MaskedCoupling[([<function NSF.<locals>.bijector_fn at 0x18a30aca0>, True, True, True, True, True, True, True, True, <jax._src.custom_derivatives.custom_jvp object at 0x14fcaea50>, <function MLPConditioner.<lambda> at 0x18a57e700>, True, 1, 1, 1, 0, False, False, True], [False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False], PyTreeDef({'_bijector': *, '_conditioner': CustomNode(MLPConditioner[('mlp',), (), ()], [CustomNode(MLP[('layers', 'activation', 'final_activation'), ('use_bias', 'use_final_bias', 'in_size', 'out_size', 'width_size', 'depth'), (True, True, 'scalar', 62, 64, 3)], [(CustomNode(Linear[('weight', 'bias'), ('in_features', 'out_features', 'use_bias'), ('scalar', 64, True)], [*, *]), CustomNode(Linear[('weight', 'bias'), ('in_features', 'out_features', 'use_bias'), (64, 64, True)], [*, *]), CustomNode(Linear[('weight', 'bias'), ('in_features', 'out_features', 'use_bias'), (64, 64, True)], [*, *]), CustomNode(Linear[('weight', 'bias'), ('in_features', 'out_features', 'use_bias'), (64, 62, True)], [*, *])), *, *])]), '_event_mask': *, '_event_ndims': *, '_event_ndims_in': *, '_event_ndims_out': *, '_inner_event_ndims': *, '_is_constant_jacobian': *, '_is_constant_log_det': *, '_mask': *}))], [None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None])], '_event_ndims_in': *, '_event_ndims_out': *, '_is_constant_jacobian': *, '_is_constant_log_det': *}))], [None, None, None, None]), '_distribution': CustomNode(MultivariateNormalDiag[([dtype('float32'), 2, True, True], [False, False, False, False], PyTreeDef({'_batch_shape': (), '_bijector': CustomNode(Chain[([1, 1, True, True], [False, False, False, False], PyTreeDef({'_bijectors': [CustomNode(Block[([1, 1, True, True, 1], [False, False, False, False, False], PyTreeDef({'_bijector': CustomNode(Shift[([2, 0, 0, True, True, True], [False, False, False, False, False, False], PyTreeDef({'_batch_shape': (*,), '_event_ndims_in': *, '_event_ndims_out': *, '_is_constant_jacobian': *, '_is_constant_log_det': *, '_shift': *}))], [None, None, None, None, None, None]), '_event_ndims_in': *, '_event_ndims_out': *, '_is_constant_jacobian': *, '_is_constant_log_det': *, '_ndims': *}))], [None, None, None, None, None]), CustomNode(DiagLinear[([True, dtype('float32'), 2, 1, 1, True, True, <bound method Block.forward of <distrax._src.bijectors.block.Block object at 0x18a6fd750>>, <bound method Block.forward_log_det_jacobian of <distrax._src.bijectors.block.Block object at 0x18a6fd750>>, <bound method Block.inverse of <distrax._src.bijectors.block.Block object at 0x18a6fd750>>, <bound method Block.inverse_and_log_det of <distrax._src.bijectors.block.Block object at 0x18a6fd750>>, <bound method Block.inverse_log_det_jacobian of <distrax._src.bijectors.block.Block object at 0x18a6fd750>>], [False, False, False, False, False, False, False, False, False, False, False, False], PyTreeDef({'_batch_shape': (), '_bijector': CustomNode(Block[([1, 1, True, True, 1], [False, False, False, False, False], PyTreeDef({'_bijector': CustomNode(ScalarAffine[([2, 0, 0, True, True, True, True, True, 0.0], [False, False, False, False, False, False, False, False, False], PyTreeDef({'_batch_shape': (*,), '_event_ndims_in': *, '_event_ndims_out': *, '_inv_scale': *, '_is_constant_jacobian': *, '_is_constant_log_det': *, '_log_scale': *, '_scale': *, '_shift': *}))], [None, None, None, None, None, None, None, None, None]), '_event_ndims_in': *, '_event_ndims_out': *, '_is_constant_jacobian': *, '_is_constant_log_det': *, '_ndims': *}))], [None, None, None, None, None]), '_diag': *, '_dtype': *, '_event_dims': *, '_event_ndims_in': *, '_event_ndims_out': *, '_is_constant_jacobian': *, '_is_constant_log_det': *, 'forward': *, 'forward_log_det_jacobian': *, 'inverse': *, 'inverse_and_log_det': *, 'inverse_log_det_jacobian': *}))], [None, None, None, None, None, None, None, None, None, None, None, None])], '_event_ndims_in': *, '_event_ndims_out': *, '_is_constant_jacobian': *, '_is_constant_log_det': *}))], [None, None, None, None]), '_distribution': CustomNode(Independent[([1], [False], PyTreeDef({'_distribution': CustomNode(Normal[([True, True], [False, False], PyTreeDef({'_loc': *, '_scale': *}))], [None, None]), '_reinterpreted_batch_ndims': *}))], [None]), '_dtype': *, '_event_shape': (*,), '_loc': *, '_scale': CustomNode(DiagLinear[([True, dtype('float32'), 2, 1, 1, True, True, <bound method Block.forward of <distrax._src.bijectors.block.Block object at 0x18a6fd750>>, <bound method Block.forward_log_det_jacobian of <distrax._src.bijectors.block.Block object at 0x18a6fd750>>, <bound method Block.inverse of <distrax._src.bijectors.block.Block object at 0x18a6fd750>>, <bound method Block.inverse_and_log_det of <distrax._src.bijectors.block.Block object at 0x18a6fd750>>, <bound method Block.inverse_log_det_jacobian of <distrax._src.bijectors.block.Block object at 0x18a6fd750>>], [False, False, False, False, False, False, False, False, False, False, False, False], PyTreeDef({'_batch_shape': (), '_bijector': CustomNode(Block[([1, 1, True, True, 1], [False, False, False, False, False], PyTreeDef({'_bijector': CustomNode(ScalarAffine[([2, 0, 0, True, True, True, True, True, 0.0], [False, False, False, False, False, False, False, False, False], PyTreeDef({'_batch_shape': (*,), '_event_ndims_in': *, '_event_ndims_out': *, '_inv_scale': *, '_is_constant_jacobian': *, '_is_constant_log_det': *, '_log_scale': *, '_scale': *, '_shift': *}))], [None, None, None, None, None, None, None, None, None]), '_event_ndims_in': *, '_event_ndims_out': *, '_is_constant_jacobian': *, '_is_constant_log_det': *, '_ndims': *}))], [None, None, None, None, None]), '_diag': *, '_dtype': *, '_event_dims': *, '_event_ndims_in': *, '_event_ndims_out': *, '_is_constant_jacobian': *, '_is_constant_log_det': *, 'forward': *, 'forward_log_det_jacobian': *, 'inverse': *, 'inverse_and_log_det': *, 'inverse_log_det_jacobian': *}))], [None, None, None, None, None, None, None, None, None, None, None, None]), '_scale_diag': *}))], [None, None, None, None]), '_dtype': None, '_event_shape': None})) != ([None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None], [True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True], PyTreeDef({'_batch_shape': None, '_bijector': CustomNode(Chain[([None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, 1, 1, False, False], [True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, False, False, False, False], PyTreeDef({'_bijectors': [CustomNode(Block[([1, 1, False, False, 1], [False, False, False, False, False], PyTreeDef({'_bijector': CustomNode(Sigmoid[([0, 0, False, False], [False, False, False, False], PyTreeDef({'_event_ndims_in': *, '_event_ndims_out': *, '_is_constant_jacobian': *, '_is_constant_log_det': *}))], [None, None, None, None]), '_event_ndims_in': *, '_event_ndims_out': *, '_is_constant_jacobian': *, '_is_constant_log_det': *, '_ndims': *}))], [None, None, None, None, None]), CustomNode(MaskedCoupling[([<function NSF.<locals>.bijector_fn at 0x18a30aca0>, None, None, None, None, None, None, None, None, <jax._src.custom_derivatives.custom_jvp object at 0x14fcaea50>, <function MLPConditioner.<lambda> at 0x18a57e700>, None, 1, 1, 1, 0, False, False, None], [False, True, True, True, True, True, True, True, True, False, False, True, False, False, False, False, False, False, True], PyTreeDef({'_bijector': *, '_conditioner': CustomNode(MLPConditioner[('mlp',), (), ()], [CustomNode(MLP[('layers', 'activation', 'final_activation'), ('use_bias', 'use_final_bias', 'in_size', 'out_size', 'width_size', 'depth'), (True, True, 'scalar', 62, 64, 3)], [(CustomNode(Linear[('weight', 'bias'), ('in_features', 'out_features', 'use_bias'), ('scalar', 64, True)], [*, *]), CustomNode(Linear[('weight', 'bias'), ('in_features', 'out_features', 'use_bias'), (64, 64, True)], [*, *]), CustomNode(Linear[('weight', 'bias'), ('in_features', 'out_features', 'use_bias'), (64, 64, True)], [*, *]), CustomNode(Linear[('weight', 'bias'), ('in_features', 'out_features', 'use_bias'), (64, 62, True)], [*, *])), *, *])]), '_event_mask': *, '_event_ndims': *, '_event_ndims_in': *, '_event_ndims_out': *, '_inner_event_ndims': *, '_is_constant_jacobian': *, '_is_constant_log_det': *, '_mask': *}))], [None, *, *, *, *, *, *, *, *, None, None, *, None, None, None, None, None, None, *]), CustomNode(MaskedCoupling[([<function NSF.<locals>.bijector_fn at 0x18a30aca0>, None, None, None, None, None, None, None, None, <jax._src.custom_derivatives.custom_jvp object at 0x14fcaea50>, <function MLPConditioner.<lambda> at 0x18a57e700>, None, 1, 1, 1, 0, False, False, None], [False, True, True, True, True, True, True, True, True, False, False, True, False, False, False, False, False, False, True], PyTreeDef({'_bijector': *, '_conditioner': CustomNode(MLPConditioner[('mlp',), (), ()], [CustomNode(MLP[('layers', 'activation', 'final_activation'), ('use_bias', 'use_final_bias', 'in_size', 'out_size', 'width_size', 'depth'), (True, True, 'scalar', 62, 64, 3)], [(CustomNode(Linear[('weight', 'bias'), ('in_features', 'out_features', 'use_bias'), ('scalar', 64, True)], [*, *]), CustomNode(Linear[('weight', 'bias'), ('in_features', 'out_features', 'use_bias'), (64, 64, True)], [*, *]), CustomNode(Linear[('weight', 'bias'), ('in_features', 'out_features', 'use_bias'), (64, 64, True)], [*, *]), CustomNode(Linear[('weight', 'bias'), ('in_features', 'out_features', 'use_bias'), (64, 62, True)], [*, *])), *, *])]), '_event_mask': *, '_event_ndims': *, '_event_ndims_in': *, '_event_ndims_out': *, '_inner_event_ndims': *, '_is_constant_jacobian': *, '_is_constant_log_det': *, '_mask': *}))], [None, *, *, *, *, *, *, *, *, None, None, *, None, None, None, None, None, None, *]), CustomNode(MaskedCoupling[([<function NSF.<locals>.bijector_fn at 0x18a30aca0>, None, None, None, None, None, None, None, None, <jax._src.custom_derivatives.custom_jvp object at 0x14fcaea50>, <function MLPConditioner.<lambda> at 0x18a57e700>, None, 1, 1, 1, 0, False, False, None], [False, True, True, True, True, True, True, True, True, False, False, True, False, False, False, False, False, False, True], PyTreeDef({'_bijector': *, '_conditioner': CustomNode(MLPConditioner[('mlp',), (), ()], [CustomNode(MLP[('layers', 'activation', 'final_activation'), ('use_bias', 'use_final_bias', 'in_size', 'out_size', 'width_size', 'depth'), (True, True, 'scalar', 62, 64, 3)], [(CustomNode(Linear[('weight', 'bias'), ('in_features', 'out_features', 'use_bias'), ('scalar', 64, True)], [*, *]), CustomNode(Linear[('weight', 'bias'), ('in_features', 'out_features', 'use_bias'), (64, 64, True)], [*, *]), CustomNode(Linear[('weight', 'bias'), ('in_features', 'out_features', 'use_bias'), (64, 64, True)], [*, *]), CustomNode(Linear[('weight', 'bias'), ('in_features', 'out_features', 'use_bias'), (64, 62, True)], [*, *])), *, *])]), '_event_mask': *, '_event_ndims': *, '_event_ndims_in': *, '_event_ndims_out': *, '_inner_event_ndims': *, '_is_constant_jacobian': *, '_is_constant_log_det': *, '_mask': *}))], [None, *, *, *, *, *, *, *, *, None, None, *, None, None, None, None, None, None, *])], '_event_ndims_in': *, '_event_ndims_out': *, '_is_constant_jacobian': *, '_is_constant_log_det': *}))], [*, *, *, *, *, *, *, *, *, *, *, *, *, *, *, *, *, *, *, *, *, *, *, *, *, *, *, *, *, *, None, None, None, None]), '_distribution': CustomNode(MultivariateNormalDiag[([None, None, None, None, None, None, None, dtype('float32'), 2, None, None, None, None, None, None], [True, True, True, True, True, True, True, False, False, True, True, True, True, True, True], PyTreeDef({'_batch_shape': (), '_bijector': CustomNode(Chain[([None, None, None, None, None, 1, 1, True, True], [True, True, True, True, True, False, False, False, False], PyTreeDef({'_bijectors': [CustomNode(Block[([None, 1, 1, True, True, 1], [True, False, False, False, False, False], PyTreeDef({'_bijector': CustomNode(Shift[([2, 0, 0, True, True, None], [False, False, False, False, False, True], PyTreeDef({'_batch_shape': (*,), '_event_ndims_in': *, '_event_ndims_out': *, '_is_constant_jacobian': *, '_is_constant_log_det': *, '_shift': *}))], [None, None, None, None, None, *]), '_event_ndims_in': *, '_event_ndims_out': *, '_is_constant_jacobian': *, '_is_constant_log_det': *, '_ndims': *}))], [*, None, None, None, None, None]), CustomNode(DiagLinear[([None, None, None, None, dtype('float32'), 2, 1, 1, True, True, <bound method Block.forward of <distrax._src.bijectors.block.Block object at 0x18a6fd750>>, <bound method Block.forward_log_det_jacobian of <distrax._src.bijectors.block.Block object at 0x18a6fd750>>, <bound method Block.inverse of <distrax._src.bijectors.block.Block object at 0x18a6fd750>>, <bound method Block.inverse_and_log_det of <distrax._src.bijectors.block.Block object at 0x18a6fd750>>, <bound method Block.inverse_log_det_jacobian of <distrax._src.bijectors.block.Block object at 0x18a6fd750>>], [True, True, True, True, False, False, False, False, False, False, False, False, False, False, False], PyTreeDef({'_batch_shape': (), '_bijector': CustomNode(Block[([None, None, None, 1, 1, True, True, 1], [True, True, True, False, False, False, False, False], PyTreeDef({'_bijector': CustomNode(ScalarAffine[([2, 0, 0, None, True, True, None, None, 0.0], [False, False, False, True, False, False, True, True, False], PyTreeDef({'_batch_shape': (*,), '_event_ndims_in': *, '_event_ndims_out': *, '_inv_scale': *, '_is_constant_jacobian': *, '_is_constant_log_det': *, '_log_scale': *, '_scale': *, '_shift': *}))], [None, None, None, *, None, None, *, *, None]), '_event_ndims_in': *, '_event_ndims_out': *, '_is_constant_jacobian': *, '_is_constant_log_det': *, '_ndims': *}))], [*, *, *, None, None, None, None, None]), '_diag': *, '_dtype': *, '_event_dims': *, '_event_ndims_in': *, '_event_ndims_out': *, '_is_constant_jacobian': *, '_is_constant_log_det': *, 'forward': *, 'forward_log_det_jacobian': *, 'inverse': *, 'inverse_and_log_det': *, 'inverse_log_det_jacobian': *}))], [*, *, *, *, None, None, None, None, None, None, None, None, None, None, None])], '_event_ndims_in': *, '_event_ndims_out': *, '_is_constant_jacobian': *, '_is_constant_log_det': *}))], [*, *, *, *, *, None, None, None, None]), '_distribution': CustomNode(Independent[([None, None, 1], [True, True, False], PyTreeDef({'_distribution': CustomNode(Normal[([None, None], [True, True], PyTreeDef({'_loc': *, '_scale': *}))], [*, *]), '_reinterpreted_batch_ndims': *}))], [*, *, None]), '_dtype': *, '_event_shape': (*,), '_loc': *, '_scale': CustomNode(DiagLinear[([None, None, None, None, dtype('float32'), 2, 1, 1, True, True, <bound method Block.forward of <distrax._src.bijectors.block.Block object at 0x18a6fd750>>, <bound method Block.forward_log_det_jacobian of <distrax._src.bijectors.block.Block object at 0x18a6fd750>>, <bound method Block.inverse of <distrax._src.bijectors.block.Block object at 0x18a6fd750>>, <bound method Block.inverse_and_log_det of <distrax._src.bijectors.block.Block object at 0x18a6fd750>>, <bound method Block.inverse_log_det_jacobian of <distrax._src.bijectors.block.Block object at 0x18a6fd750>>], [True, True, True, True, False, False, False, False, False, False, False, False, False, False, False], PyTreeDef({'_batch_shape': (), '_bijector': CustomNode(Block[([None, None, None, 1, 1, True, True, 1], [True, True, True, False, False, False, False, False], PyTreeDef({'_bijector': CustomNode(ScalarAffine[([2, 0, 0, None, True, True, None, None, 0.0], [False, False, False, True, False, False, True, True, False], PyTreeDef({'_batch_shape': (*,), '_event_ndims_in': *, '_event_ndims_out': *, '_inv_scale': *, '_is_constant_jacobian': *, '_is_constant_log_det': *, '_log_scale': *, '_scale': *, '_shift': *}))], [None, None, None, *, None, None, *, *, None]), '_event_ndims_in': *, '_event_ndims_out': *, '_is_constant_jacobian': *, '_is_constant_log_det': *, '_ndims': *}))], [*, *, *, None, None, None, None, None]), '_diag': *, '_dtype': *, '_event_dims': *, '_event_ndims_in': *, '_event_ndims_out': *, '_is_constant_jacobian': *, '_is_constant_log_det': *, 'forward': *, 'forward_log_det_jacobian': *, 'inverse': *, 'inverse_and_log_det': *, 'inverse_log_det_jacobian': *}))], [*, *, *, *, None, None, None, None, None, None, None, None, None, None, None]), '_scale_diag': *}))], [*, *, *, *, *, *, *, None, None, *, *, *, *, *, *]), '_dtype': None, '_event_shape': None})); value: <distrax._src.distributions.transformed.Transformed object at 0x18a8839d0>.