## inverse mass matrix in NumPyro/NUTS

In [1]:
import numpy as np
import jax.numpy as jnp
from jax import random
import numpyro
import numpyro.distributions as dist
from numpyro_ext import information as information

Here we sample from a posterior $\pi(x) \propto \mathcal{L}(x)\,p_0(x)$ with a multivariate normal likelihood
$$
\mathcal{L}(x) = {1\over\sqrt{|2\pi\Sigma|}}\,\exp\left[-{1\over 2}(\mu-x)^T\Sigma^{-1}(\mu-x)\right], \quad \mu = \begin{pmatrix} 2 \\ 1\end{pmatrix},\ \Sigma = \begin{pmatrix} 1 & 0.09 \\ 0.09 & 0.01 \end{pmatrix}
$$
and a uniform, **bounded** prior on $x$:
$$
p_0(x) = \begin{cases}
{1/(2x_\mathrm{max})^2} &\text{for}\ |x_1| < x_\mathrm{max}, |x_2| < x_\mathrm{max}\\
0  &\text{otherwise}
\end{cases}.
$$

We use NUTS and tune the dense inverse mass matrix $M^{-1}$ during warmup. Its ideal value is the parameter covariance matrix, so we expect to find
$$
    M^{-1} = \Sigma = \mathcal{I}^{-1},
$$
where $\mathcal{I}$ is the Fisher information matrix.
Is this the case here?

In [2]:
mu_true = jnp.array([2., 1.])
cov_true = jnp.array([[1., 0.09], [0.09, 0.01]]) 
xabs_max = 10.

def model():
    x = numpyro.sample("x", dist.Uniform(-xabs_max, xabs_max), sample_shape=(2,))
    numpyro.sample("obs", dist.MultivariateNormal(loc=x, covariance_matrix=cov_true), obs=mu_true) # put x in loc to make numpyro_ext information work


def run_mcmc(model, dense_mass=True, adapt_mass_matrix=True, inverse_mass_matrix=None):
    kernel = numpyro.infer.NUTS(model, dense_mass=dense_mass, adapt_mass_matrix=adapt_mass_matrix, inverse_mass_matrix=inverse_mass_matrix)
    mcmc = numpyro.infer.MCMC(kernel, num_warmup=2000, num_samples=2000)
    rng_key = random.PRNGKey(0)
    mcmc.run(rng_key)
    mcmc.print_summary()

    return mcmc

mcmc = run_mcmc(model)
invM = mcmc.last_state.adapt_state.inverse_mass_matrix[('x',)] # this is M^{-1}

sample: 100%|██████████| 4000/4000 [00:01<00:00, 3886.97it/s, 3 steps of size 7.26e-01. acc. prob=0.94]


                mean       std    median      5.0%     95.0%     n_eff     r_hat
      x[0]      2.01      0.98      2.00      0.38      3.61   1582.10      1.00
      x[1]      1.00      0.10      1.00      0.84      1.17   1514.56      1.00

Number of divergences: 0





First we evaluate the inverse Fisher information using the information function in numpyro_ext. This agrees with $\Sigma$, as expected:

In [3]:
x_eval = {'x': mu_true}
fisher_inv = information(model, invert=True, include_prior=False, unconstrained=False)(x_eval)['x']['x']
cov_true / fisher_inv

Array([[1.       , 1.0000001],
       [1.0000001, 1.0000001]], dtype=float32)

But fisher_inv does NOT agree with the $M^{-1}$ tuned by NumPyro:

In [4]:
invM / fisher_inv

Array([[0.04571855, 0.04275488],
       [0.04275488, 0.04069575]], dtype=float32)

Why? This is because NumPyro actually sees *unconstrained variables* $z$ that have real support and are mapped to the constrained $x \sim \mathcal{U}(-x_\mathrm{max}, x_\mathrm{max})$.
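
As a sketch of what this mapping looks like, assume that NumPyro transforms an interval-supported site with a sigmoid followed by an affine rescaling (this is an assumption about the library internals, not something derived in this notebook). Each component then satisfies
$$
x_i = -x_\mathrm{max} + 2x_\mathrm{max}\,\sigma(z_i), \quad \sigma(z) = {1\over 1+e^{-z}},
$$
so the Jacobian $J = \partial x/\partial z$ is diagonal with
$$
J_{ii} = 2x_\mathrm{max}\,\sigma(z_i)\left[1-\sigma(z_i)\right] = {x_\mathrm{max}^2 - x_i^2 \over 2x_\mathrm{max}}.
$$
Under this assumption the likelihood-only Fisher information and its inverse transform as
$$
\mathcal{I}_z = J\,\mathcal{I}_x\,J, \quad \mathcal{I}_z^{-1} = J^{-1}\,\Sigma\,J^{-1},
$$
where $\mathcal{I}_x = \Sigma^{-1}$ is the Fisher information w.r.t. $x$. Because $J$ depends on $z$, so does $\mathcal{I}_z^{-1}$.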

The inverse mass matrix in NumPyro NUTS is also defined w.r.t. these unconstrained variables. The inverse Fisher information w.r.t. $z$ can be calculated from the same function by setting unconstrained=True:

In [5]:
x_eval = {'x': mu_true}
fisher_inv_unc = information(model, invert=True, include_prior=False, unconstrained=True)(x_eval)['x']['x']
invM / fisher_inv_unc

Array([[0.20159441, 0.3530357 ],
       [0.35303566, 0.62925816]], dtype=float32)

The two are closer but still don't agree. The key point is that the inverse Fisher information w.r.t. $z$, and hence the ideal $M^{-1}(z)$, is no longer constant even for the Gaussian $\mathcal{L}(x)$, because the mapping $x(z)$ is non-linear.
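
The position dependence can be illustrated with a few lines of NumPy. This is a minimal sketch that assumes the interval transform is a sigmoid followed by an affine rescaling to $(-x_\mathrm{max}, x_\mathrm{max})$. The helper names jacobian_diag and cov_unconstrained are hypothetical and not part of NumPyro or numpyro_ext.

```python
import numpy as np

xabs_max = 10.0
cov_x = np.array([[1.0, 0.09], [0.09, 0.01]])  # cov_true from In [2]


def jacobian_diag(z):
    """Diagonal of dx/dz for the ASSUMED map x = -xabs_max + 2 * xabs_max * sigmoid(z)."""
    s = 1.0 / (1.0 + np.exp(-np.asarray(z, dtype=float)))
    return 2.0 * xabs_max * s * (1.0 - s)


def cov_unconstrained(z):
    """Sigma_z = J^{-1} Sigma_x J^{-1} for diagonal J, evaluated at z."""
    j = jacobian_diag(z)
    return cov_x / np.outer(j, j)


z_a = np.array([0.0, 0.0])  # centre of the interval, where dx/dz = xabs_max / 2
z_b = np.array([2.0, 1.0])  # the point used in In [5]

# The same constant Sigma_x gives a different Sigma_z at each z.
print(cov_unconstrained(z_a))
print(cov_unconstrained(z_b))
```

The two printed matrices differ even though cov_x is fixed, which is the non-constancy described above.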

So it depends on where it's evaluated, and we need to pass the value of $z$, not $x$, at which we'd like to evaluate $M^{-1}$. The value of $z$ for a given $x$ can be found as follows:

In [6]:
from numpyro import handlers
from numpyro.distributions.transforms import biject_to

def to_unconstrained(model, params_constrained, *args, **kwargs):
    """
    Convert constrained parameter values to their unconstrained representations
    using the model's sample sites.

    Args:
        model (callable): A NumPyro model function.
        params_constrained (dict): Dictionary mapping parameter names to values
            in the *constrained* space (e.g., positive scales, simplex).
        *args: Positional arguments passed to `model` when tracing.
        **kwargs: Keyword arguments passed to `model` when tracing.

    Returns:
        dict: Dictionary mapping parameter names to values in the
        *unconstrained* space, suitable for use in inference algorithms
        (e.g., HMC/NUTS).
    """
    tr = handlers.trace(handlers.seed(model, 0)).get_trace(*args, **kwargs)
    bij = {}
    for name, site in tr.items():
        if site["type"] == "sample" and not site["is_observed"]:
            bij[name] = biject_to(site["fn"].support)
    # Map constrained -> unconstrained
    z_params = {k: bij[k].inv(v) for k, v in params_constrained.items()}
    return z_params

In [7]:
z_eval = to_unconstrained(model, {'x': mu_true})
print('Unconstrained parameters corresponding to mu_true:', z_eval)

Unconstrained parameters corresponding to mu_true: {'x': Array([0.4054652 , 0.20067078], dtype=float32)}
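
As a cross-check on the printed values, the same $z$ can be obtained in closed form if we assume that the interval transform is a sigmoid followed by an affine rescaling, so that its inverse is a rescaling to $(0, 1)$ followed by a logit. This is a sketch under that assumption, not a replacement for to_unconstrained, which reads the transform from the model itself.

```python
import numpy as np

xabs_max = 10.0
x = np.array([2.0, 1.0])  # mu_true

# ASSUMED inverse map: rescale x from (-xabs_max, xabs_max) to (0, 1), then apply the logit.
u = (x + xabs_max) / (2.0 * xabs_max)
z = np.log(u) - np.log1p(-u)

print(z)  # compare with the output of In [7]
```

The result matches the output of In [7] to float32 precision, which supports the assumed form of the transform for this model.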


Using this value of $z$, $M^{-1}$ and the inverse Fisher information agree to within a few percent:

In [8]:
fisher_inv_unc0 = information(model, invert=True, include_prior=False, unconstrained=True)(z_eval)['x']['x']
invM / fisher_inv_unc0

Array([[1.0533552 , 1.0158558 ],
       [1.0158558 , 0.99714756]], dtype=float32)
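
The ratios printed in In [4] and In [8] can be related directly. Assuming the sigmoid-plus-affine form of the interval transform, the diagonal Jacobian at $x$ is $(x_\mathrm{max}^2 - x_i^2)/(2x_\mathrm{max})$, and the inverse Fisher information in $z$ is the one in $x$ divided elementwise by the outer product of that Jacobian. The sketch below copies the printed outputs by hand and checks this relation. The variable names are illustrative only.

```python
import numpy as np

xabs_max = 10.0
mu_true = np.array([2.0, 1.0])

# Printed outputs copied from the notebook cells.
ratio_x = np.array([[0.04571855, 0.04275488],
                    [0.04275488, 0.04069575]])   # In [4]: invM / fisher_inv
ratio_z = np.array([[1.0533552, 1.0158558],
                    [1.0158558, 0.99714756]])    # In [8]: invM / fisher_inv_unc0

# dx/dz at x = mu_true under the ASSUMED sigmoid + affine map.
jac = (xabs_max**2 - mu_true**2) / (2.0 * xabs_max)

# fisher_inv_unc0 = fisher_inv / outer(jac, jac), so the ratio picks up outer(jac, jac).
predicted = ratio_x * np.outer(jac, jac)

print(predicted)
print(np.max(np.abs(predicted / ratio_z - 1.0)))  # largest relative mismatch
```

The predicted matrix reproduces the In [8] output up to rounding of the printed digits, so the whole discrepancy seen in In [4] is accounted for by the change of variables.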

To summarize:

- We want $M^{-1}$ to be the parameter covariance $\Sigma$, which is the inverse of the Fisher information $\mathcal{I}$.
- But when some parameters $x$ are sampled from bounded priors, NumPyro samples unconstrained variables $z$ that are mapped to the bounded $x$. So we need to evaluate $\Sigma$ w.r.t. $z$, at the value of $z$ that corresponds to the $x$ of interest.

### when the mapping is identity

When $x$ is sampled from dist.Normal, the mapping between $x$ and $z$ is just the identity, so this complication doesn't arise.

In this case, the parameter covariance evaluated for $x$ works directly as $M^{-1}$. Let's check this:

In [9]:
def model_norm():
    x = numpyro.sample("x", dist.Normal(0., xabs_max), sample_shape=(2,)) # normal
    numpyro.sample("obs", dist.MultivariateNormal(loc=x, covariance_matrix=cov_true), obs=mu_true) 


# run_mcmc is the helper already defined in In [2]
mcmc_norm = run_mcmc(model_norm)
invM_norm = mcmc_norm.last_state.adapt_state.inverse_mass_matrix[('x',)] # this is M^{-1}

sample: 100%|██████████| 4000/4000 [00:01<00:00, 3936.08it/s, 3 steps of size 9.15e-01. acc. prob=0.91]


                mean       std    median      5.0%     95.0%     n_eff     r_hat
      x[0]      1.97      0.98      1.97      0.45      3.65   2104.61      1.00
      x[1]      1.00      0.10      1.00      0.84      1.15   2042.06      1.00

Number of divergences: 0





In [10]:
x_eval = {'x': mu_true}
fisher_inv_constrained = information(model_norm, invert=True, include_prior=False, unconstrained=False)(x_eval)['x']['x']

In [11]:
invM_norm / fisher_inv_constrained

Array([[1.0130707, 1.048032 ],
       [1.048032 , 1.0706003]], dtype=float32)

So in this case, the inverse Fisher information computed for the physical parameters $x$ can be used directly as the inverse mass matrix in NUTS.