This notebook provides some simple tests for the distributions in this repo, and is meant to help users understand their functionality.

In [1]:
import os
# These environment variables are set before `import jax` below so that they
# are already in place when JAX/XLA initialise.
# Don't let JAX preallocate a large fraction of GPU memory up front; allocate on demand.
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'
# Quiet TensorFlow/XLA C++ logging ('3' hides INFO, WARNING and ERROR messages).
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

import jax
# Use 64-bit (double precision) arrays by default; presumably so the `isclose`
# comparisons in this notebook are not limited by float32 round-off.
jax.config.update("jax_enable_x64", True)

from jax import vmap, grad, random, jit
import jax.numpy as jnp
# Project-local distribution modules under test.
from distributions import categorical, dirichlet, normal, niw, mniw

In [2]:
# Fixed seed so every run of the notebook draws the same random test inputs.
SEED = 47
key = random.PRNGKey(SEED)
key

Array([ 0, 47], dtype=uint32)

In [3]:
# A batch of random parameter vectors, uniform on [0, 1); axis 0 is the batch
# axis that `vmap` maps over in the checks below.
key, subkey = random.split(key)
param_shape = (3, 3)  # (batch, dimension)
batched_params = random.uniform(subkey, shape=param_shape)

# Test distributions via exponential family properties
One key property of exponential families is that the gradient of the log partition function with respect to the natural parameters equals the expected sufficient statistics of the distribution: $\nabla_\eta \log Z(\eta) = \mathbb{E}[t(x)]$. We test this for each distribution below.

### Categorical

In [4]:
# Autodiff gradient of logZ, mapped row-by-row over the batch ...
logZgrad = vmap(grad(categorical.logZ))(batched_params)
# ... versus the closed-form expected statistics, called on the whole batch at once.
es = categorical.expected_stats(batched_params)
assert jnp.allclose(logZgrad, es)

### Dirichlet

In [5]:
# Autodiff gradient of logZ, mapped row-by-row over the batch ...
logZgrad = vmap(grad(dirichlet.logZ))(batched_params)
# ... versus the closed-form expected statistics, called on the whole batch at once.
es = dirichlet.expected_stats(batched_params)
assert jnp.allclose(logZgrad, es)

### Normal

In [6]:
# Random 4x4 positive-definite matrix: A A^T is positive semi-definite, and the
# small diagonal ridge makes it strictly positive definite.
key, subkey = random.split(key)
A = random.normal(subkey, (4, 4))
pd = jnp.dot(A, A.T) + 1e-5 * jnp.identity(4)
# Random 4x1 column vector used as the mean.
key, subkey = random.split(key)
mu = random.normal(subkey, (4, 1))
# Moment parameters (mean, PD matrix -- presumably the covariance), then natural parameters.
params = (mu, pd)
natparam = normal.moment_to_nat(params)

Check that `moment_to_nat` and `nat_to_moment` agree, i.e. converting to natural parameters and back recovers the original moment parameters.

In [7]:
# Round trip moment -> natural -> moment should reproduce both components.
recon = normal.nat_to_moment(natparam)
for component in range(2):
    assert jnp.allclose(recon[component], params[component])

Confirm that $\nabla_\eta \log Z(\eta) = \mathbb{E}[t(x)]$, where $t(x)$ is the sufficient statistic.

In [8]:
# Autodiff gradient of logZ vs. the hand-written expected statistics, component by component.
logZgrad = grad(normal.logZ)(natparam)
es = normal.expected_stats(natparam)
for idx in range(2):
    assert jnp.allclose(logZgrad[idx], es[idx])

How much faster is the hard-coded `expected_stats` than taking the JAX autodiff gradient of `logZ`?

In [9]:
f = jit(grad(normal.logZ))
f(natparam) # run once for just-in-time compilation before testing the speed
%timeit f(natparam)

189 µs ± 3.05 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [10]:
f = jit(normal.expected_stats)
f(natparam)
%timeit f(natparam)

113 µs ± 1.97 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


Let's confirm that batching with `vmap` works properly.

In [11]:
# Stack 3 copies of (mu, pd) along a new leading batch axis, then perturb some
# entries so the batch is not uniform: entry 0 gets +0.7 on the diagonal of its
# PD matrix, entry 1 has its mean shifted by 0.7, entry 2 is left unchanged.
batched_pd = jnp.tile(pd[None], (3, 1, 1))
batched_pd = batched_pd.at[0].add(0.7 * jnp.identity(4))
batched_mu = jnp.tile(mu[None], (3, 1, 1))
batched_mu = batched_mu.at[1].add(0.7)
batch_natparam = vmap(normal.moment_to_nat)((batched_mu, batched_pd))

In [12]:
# The identity should hold independently for every element of the batch.
logZgrad = vmap(grad(normal.logZ))(batch_natparam)
es = vmap(normal.expected_stats)(batch_natparam)
for idx in range(2):
    assert jnp.allclose(logZgrad[idx], es[idx])

### NIW

In [13]:
# Fresh random draws for the NIW test: a 4x4 positive-definite matrix
# (A A^T plus a small diagonal ridge) and a 4x1 vector.
key, subkey = random.split(key)
A = random.normal(subkey, (4, 4))
pd = jnp.dot(A, A.T) + 1e-5 * jnp.identity(4)
key, subkey = random.split(key)
mu = random.normal(subkey, (4, 1))
# NIW moment parameters; the meaning of the two scalars (7. and 15.) is
# defined by niw.moment_to_nat -- TODO document which is which.
params = (pd, mu, 7., 15.)
natparam = niw.moment_to_nat(params)

Check that `moment_to_nat` and `nat_to_moment` agree, i.e. converting to natural parameters and back recovers the original moment parameters.


In [14]:
# Round trip moment -> natural -> moment should reproduce all four components.
recon = niw.nat_to_moment(natparam)
for idx in range(4):
    assert jnp.allclose(recon[idx], params[idx])

Confirm that $\nabla_\eta \log Z(\eta) = \mathbb{E}[t(x)]$, where $t(x)$ is the sufficient statistic.

In [15]:
# Autodiff gradient of logZ vs. the hand-written expected statistics, component by component.
logZgrad = grad(niw.logZ)(natparam)
es = niw.expected_stats(natparam)
for idx in range(4):
    assert jnp.allclose(logZgrad[idx], es[idx])

Confirm that batching with `vmap` works properly.

In [16]:
# Batch of 3 NIW parameter sets built from the single draw above: entry 0 gets
# +0.7 on the diagonal of its PD matrix, entry 1 has its vector shifted by 0.7.
# The two scalar parameters are repeated unchanged for every batch entry.
batched_pd = jnp.tile(pd[None], (3, 1, 1)).at[0].add(0.7 * jnp.identity(4))
batched_mu = jnp.tile(mu[None], (3, 1, 1)).at[1].add(0.7)
batched_scalar_a = jnp.array([7.] * 3)
batched_scalar_b = jnp.array([15.] * 3)
batch_natparam = vmap(niw.moment_to_nat)((batched_pd, batched_mu, batched_scalar_a, batched_scalar_b))

In [17]:
# The identity should hold independently for every element of the batch.
logZgrad = vmap(grad(niw.logZ))(batch_natparam)
es = vmap(niw.expected_stats)(batch_natparam)
for idx in range(4):
    assert jnp.allclose(logZgrad[idx], es[idx])

### MNIW

In [18]:
# Fresh random draws for the MNIW test:
#   pd : 4x4 positive-definite matrix (A A^T plus a small diagonal ridge)
#   M  : 4x5 matrix
#   V  : 5x5 positive-definite matrix (B B^T plus a small diagonal ridge)
key, subkey = random.split(key)
A = random.normal(subkey, (4, 4))
pd = jnp.dot(A, A.T) + 1e-5 * jnp.identity(4)
key, subkey = random.split(key)
M = random.normal(subkey, (4, 5))
key, subkey = random.split(key)
B = random.normal(subkey, (5, 5))
V = jnp.dot(B, B.T) + 1e-5 * jnp.identity(5)
# MNIW moment parameters; the scalar 15. is interpreted by mniw.moment_to_nat.
params = (pd, M, V, 15.)
natparam = mniw.moment_to_nat(params)

Check that `moment_to_nat` and `nat_to_moment` agree, i.e. converting to natural parameters and back recovers the original moment parameters.

In [19]:
# Round trip moment -> natural -> moment should reproduce all four components.
recon = mniw.nat_to_moment(natparam)
for idx in range(4):
    assert jnp.allclose(recon[idx], params[idx])

Confirm that $\nabla_\eta \log Z(\eta) = \mathbb{E}[t(x)]$, where $t(x)$ is the sufficient statistic.

In [20]:
# Autodiff gradient of logZ vs. the hand-written expected statistics, component by component.
logZgrad = grad(mniw.logZ)(natparam)
es = mniw.expected_stats(natparam)
for idx in range(4):
    assert jnp.allclose(logZgrad[idx], es[idx])

Confirm that batching with `vmap` works properly.

In [21]:
# Batch of 3 MNIW parameter sets built from the single draw above, with a
# different entry perturbed per component:
#   entry 0: +0.7 on the diagonal of pd
#   entry 1: M shifted by 0.7
#   entry 2: +0.7 on the diagonal of V
batched_pd = jnp.tile(pd[None], (3, 1, 1)).at[0].add(0.7 * jnp.identity(4))
batched_M = jnp.tile(M[None], (3, 1, 1)).at[1].add(0.7)
batched_V = jnp.tile(V[None], (3, 1, 1)).at[2].add(0.7 * jnp.identity(5))
batched_scalar = jnp.array([15.] * 3)
batch_natparam = vmap(mniw.moment_to_nat)((batched_pd, batched_M, batched_V, batched_scalar))

In [22]:
# The identity should hold independently for every element of the batch.
logZgrad = vmap(grad(mniw.logZ))(batch_natparam)
es = vmap(mniw.expected_stats)(batch_natparam)
for idx in range(4):
    assert jnp.allclose(logZgrad[idx], es[idx])