<a href="https://colab.research.google.com/github/lindermanlab/hackathons/blob/master/notebooks/TFP_Normal_Inverse_Wishart.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Implementing a Normal Inverse Wishart Distribution in TensorFlow Probability
_Scott Linderman_

_Oct 27, 2021_

---

The normal inverse Wishart (NIW) is a conjugate prior for a multivariate Gaussian with unknown mean and covariance. Specifically, it's a joint distribution on a vector $\mu \in \mathbb{R}^D$ and a positive definite matrix $\Sigma \in \mathbb{R}_{\succ 0}^{D \times D}$ defined by the following generative model,
\begin{align}
\Sigma &\sim \mathrm{IW}(\nu, \Psi) \\
\mu \mid \Sigma &\sim \mathcal{N}(\mu_0, \kappa^{-1} \Sigma)
\end{align}
The NIW is parameterized by the degrees of freedom $\nu$, a scale matrix $\Psi$, a location $\mu_0$, and a scaling parameter $\kappa$ that I'll call the "mean precision."
Unfortunately, TensorFlow Probability (TFP) doesn't have a premade NIW distribution, and implementing it was a bit of a pain. This notebook shows how I ended up doing it.
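
Before turning to TFP, the two-step generative model above can be sketched with plain NumPy and SciPy. This is only an illustration of the sampling order, not the implementation used later in the notebook; the helper name `sample_niw` is made up for this example, and the parameter values are arbitrary.

```python
import numpy as np
from scipy.stats import invwishart

def sample_niw(rng, loc, mean_precision, df, scale):
    """Draw one (mu, Sigma) pair by following the generative model (hypothetical helper)."""
    # Step 1: Sigma ~ IW(nu, Psi)
    Sigma = invwishart.rvs(df=df, scale=scale, random_state=rng)
    # Step 2: mu | Sigma ~ N(mu_0, Sigma / kappa)
    mu = rng.multivariate_normal(loc, Sigma / mean_precision)
    return mu, Sigma

rng = np.random.default_rng(0)
dim = 3
mu, Sigma = sample_niw(rng, loc=np.zeros(dim), mean_precision=1.0,
                       df=dim + 3, scale=np.eye(dim))
print(mu.shape, Sigma.shape)
```

The draw of $\mu$ depends on the sampled $\Sigma$, which is exactly the dependency a `JointDistribution` has to express.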

TFP's `JointDistribution` objects seem well suited to this problem. Ideally, it would be as simple as,
```
niw = tfd.JointDistributionNamed(dict(
    Sigma=lambda: tfd.InverseWishartTriL(df, np.linalg.cholesky(scale)), 
    mu=lambda Sigma: tfd.MultivariateNormalFullCovariance(loc, Sigma / mean_precision)
))
```

The first problem is that TFP doesn't have an `InverseWishartTriL` distribution (ugh!). Ok, we'll just specify the inverse Wishart as a `TransformedDistribution` that wraps the `WishartTriL` distribution... something like,
```
niw = tfd.JointDistributionNamed(dict(
    Sigma=lambda: tfd.TransformedDistribution(
        tfd.WishartTriL(df, np.linalg.cholesky(np.linalg.inv(scale))),
        bijector=tfb.MatrixInverse()), 
    mu=lambda Sigma: tfd.MultivariateNormalFullCovariance(loc, Sigma / mean_precision)
))
```
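
The idea behind wrapping a Wishart in a matrix-inverse transform can be checked outside of TFP. The sketch below uses SciPy (not TFP) and assumes the standard change-of-variables result for inverting a symmetric $D \times D$ matrix, whose log Jacobian term is $-(D + 1) \log |\Sigma|$. The scale matrix is an arbitrary positive definite example.

```python
import numpy as np
from scipy.stats import invwishart, wishart

dim, df = 3, 6
scale = np.eye(dim) + 0.1 * np.ones((dim, dim))   # arbitrary positive definite Psi
Sigma = invwishart.rvs(df=df, scale=scale, random_state=0)

# Density of J = Sigma^{-1} under a Wishart with scale Psi^{-1}
J = np.linalg.inv(Sigma)
J = 0.5 * (J + J.T)                               # remove round-off asymmetry
lp_wishart = wishart.logpdf(J, df=df, scale=np.linalg.inv(scale))

# Change of variables for Sigma = J^{-1} on symmetric matrices:
# log p(Sigma) = log p_W(Sigma^{-1}) - (D + 1) * log|Sigma|
_, logdet = np.linalg.slogdet(Sigma)
lp_transformed = lp_wishart - (dim + 1) * logdet

# Reference value from SciPy's inverse Wishart
lp_direct = invwishart.logpdf(Sigma, df=df, scale=scale)
print(lp_transformed, lp_direct)
```

If the two printed values agree, then a Wishart with scale $\Psi^{-1}$ pushed through matrix inversion is indeed $\mathrm{IW}(\nu, \Psi)$, which is what the `TransformedDistribution` is meant to compute.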

The second problem is that TFP doesn't have a `MatrixInverse` bijector (ugh!!). Here's the workaround I arrived at: invert the (positive definite) matrix $J$ output by the Wishart distribution by chaining three bijectors:
1. $J \mapsto \mathrm{chol}(J)$
2. $\mathrm{chol}(J) \mapsto \mathrm{chol}(J^{-1})$
3. $\mathrm{chol}(J^{-1}) \mapsto \mathrm{chol}(J^{-1}) \mathrm{chol}(J^{-1})^\top = J^{-1} \equiv \Sigma$

That would look like,
```
niw = tfd.JointDistributionNamed(dict(
    Sigma=lambda: tfd.TransformedDistribution(
        tfd.WishartTriL(df, np.linalg.cholesky(np.linalg.inv(scale))),
        bijector=tfb.Chain([
            tfb.CholeskyOuterProduct(),
            tfb.CholeskyToInvCholesky(),
            tfb.Cholesky()])), 
    mu=lambda Sigma: tfd.MultivariateNormalFullCovariance(loc, Sigma / mean_precision)
))
```

The third problem is that TFP doesn't have a `Cholesky` bijector (ugh!!!). Thankfully, the Cholesky bijector is just the inverse of the `CholeskyOuterProduct` bijector, and we can use `tfb.Invert` to get it.
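
The three-step chain, and the claim that the Cholesky map inverts the outer product, can both be checked numerically. This is a plain NumPy sketch rather than TFP code; the test matrix is arbitrary and the variable names are chosen for this example only.

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.normal(size=(3, 3))
J = A @ A.T + 3 * np.eye(3)            # arbitrary positive definite test matrix

# Step 1: J -> chol(J)
L = np.linalg.cholesky(J)

# Step 2: chol(J) -> chol(J^{-1}), using J^{-1} = L^{-T} L^{-1}
L_inv = np.linalg.inv(L)
M = np.linalg.cholesky(L_inv.T @ L_inv)

# Step 3: chol(J^{-1}) -> J^{-1} = Sigma via the outer product
Sigma = M @ M.T

print(np.allclose(Sigma, np.linalg.inv(J)))          # chain equals matrix inverse
print(np.allclose(np.linalg.cholesky(L @ L.T), L))   # Cholesky undoes the outer product
```

The second check works because the Cholesky factor with a positive diagonal is unique, so mapping $L \mapsto L L^\top$ and then taking the Cholesky factor returns the original $L$.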

The code below puts all of this together. I'm sure there are other implementations. For example, the TFP [Bayesian GMM example](https://github.com/tensorflow/probability/blob/main/tensorflow_probability/examples/jupyter_notebooks/Bayesian_Gaussian_Mixture_Model.ipynb) defines a new multivariate normal distribution parameterized by the square root of the inverse covariance (aka precision) matrix. That's probably a bit more efficient since it doesn't have so many conversions to/from Cholesky decompositions. Personally, I prefer the approach below because it outputs a distribution on $(\mu, \Sigma)$, as the NIW is typically specified.



In [1]:
import jax.numpy as np
import jax.random as jr
import tensorflow_probability.substrates.jax as tfp
tfd = tfp.distributions
tfb = tfp.bijectors

In [2]:
dim = 3
params = dict(
    loc=np.zeros(dim),
    mean_precision=1.0,
    df=dim + 3,
    scale=np.eye(dim)
)



In [3]:
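def make_niw(loc, mean_precision, df, scale):
    r"""
    Build an NIW distribution over (mu, Sigma) as a TFP joint distribution.

    loc: \mu_0 in the math above
    mean_precision: \kappa
    df: \nu
    scale: \Psi
    """
    # A Wishart with scale \Psi^{-1} on J implies IW(\nu, \Psi) on Sigma = J^{-1}
    wishart_scale_tril = np.linalg.cholesky(np.linalg.inv(scale))
    niw = tfd.JointDistributionNamed(dict(
        Sigma=lambda: tfd.TransformedDistribution(
            tfd.WishartTriL(df, scale_tril=wishart_scale_tril),
            # Chain applies its bijectors right to left (last list entry first)
            tfb.Chain([tfb.CholeskyOuterProduct(),                 # chol(J^{-1}) -> Sigma
                       tfb.CholeskyToInvCholesky(),                # chol(J) -> chol(J^{-1})
                       tfb.Invert(tfb.CholeskyOuterProduct())      # J -> chol(J)
                       ])),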
        mu=lambda Sigma: tfd.MultivariateNormalFullCovariance(
            loc, Sigma / mean_precision)
    ))
    return niw

niw = make_niw(**params)
smpl = niw.sample(seed=jr.PRNGKey(0))
print(smpl)

{'Sigma': DeviceArray([[ 0.39998385, -0.08702334, -0.01624717],
             [-0.08702334,  0.28204948, -0.0854842 ],
             [-0.01624717, -0.0854842 ,  0.18808396]], dtype=float32), 'mu': DeviceArray([0.47917017, 0.43772495, 0.16959237], dtype=float32)}


## Double check the log prob
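
The check in this section relies on the NIW log density splitting into an inverse Wishart term and a Gaussian term. Written out with the parameters defined at the top of the notebook, and with $\Gamma_D$ denoting the multivariate gamma function (a symbol introduced here, not used elsewhere in the notebook), the standard forms are
\begin{align}
\log p(\mu, \Sigma) &= \log \mathrm{IW}(\Sigma \mid \nu, \Psi) + \log \mathcal{N}(\mu \mid \mu_0, \kappa^{-1} \Sigma) \\
\log \mathrm{IW}(\Sigma \mid \nu, \Psi) &= \tfrac{\nu}{2} \log |\Psi| - \tfrac{\nu D}{2} \log 2 - \log \Gamma_D\left(\tfrac{\nu}{2}\right) - \tfrac{\nu + D + 1}{2} \log |\Sigma| - \tfrac{1}{2} \mathrm{tr}\left(\Psi \Sigma^{-1}\right) \\
\log \mathcal{N}(\mu \mid \mu_0, \kappa^{-1} \Sigma) &= -\tfrac{D}{2} \log (2 \pi) + \tfrac{D}{2} \log \kappa - \tfrac{1}{2} \log |\Sigma| - \tfrac{\kappa}{2} (\mu - \mu_0)^\top \Sigma^{-1} (\mu - \mu_0)
\end{align}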

In [4]:
from scipy.stats import invwishart
from scipy.stats import multivariate_normal as mvn

def manual_niw_log_prob(mu, Sigma, loc, mean_precision, df, scale):
    # log IW(Sigma | df, scale)
    lp = invwishart.logpdf(Sigma, df, scale)
    # log N(mu | loc, Sigma / mean_precision)
    lp += mvn.logpdf(mu, loc, Sigma / mean_precision)
    return lp

manual_niw_log_prob(**smpl, **params)
assert np.allclose(niw.log_prob(smpl), manual_niw_log_prob(**smpl, **params))