In [75]:
import jax.numpy as jnp
import jax
from tensorflow_probability.substrates import jax as tfp
from functools import partial
import matplotlib.pyplot as plt
from jax.config import config
import numpy as onp
config.update("jax_enable_x64", True)

tfd = tfp.distributions
tfb = tfp.bijectors
tfpk = tfp.math.psd_kernels

In [76]:
K = 5
eta = 2.5
zeros = jnp.zeros((K * (K - 1))//2)

In [77]:
@partial(jax.jit, static_argnames=['K'])
def tfp_transform(y, K):
    L = jnp.eye(K);
    log_det_jacobian = 0;
    counter = 0
    for i in range(1, K):
        for j in range(i):
            L = L.at[i, j].set(y[counter])
            counter += 1;
        s = jnp.linalg.norm(L[i])
        L = L.at[i].set(L[i] / s)
        log_det_jacobian -= (i + 2) * jnp.log(s)
    return L, log_det_jacobian

In [78]:
@partial(jax.jit, static_argnames=['K'])
def stan_transform(y, K):
    L = jnp.eye(K);
    z = jnp.tanh(y)
    log_det_jacobian = jnp.sum(jnp.log(1-jnp.square(z)))
    counter = 0
    for i in range(1, K):
        L = L.at[i, 0].set(z[counter])
        counter += 1
        sum_sqs = L[i, 0] ** 2
        for j in range(1,i):
            log_det_jacobian += 0.5 * jnp.log(1-sum_sqs)
            L = L.at[i, j].set(z[counter] * jnp.sqrt( 1 - sum_sqs))
            counter += 1
            sum_sqs = sum_sqs + jnp.square(L[i, j]);
        L = L.at[i, i].set(jnp.sqrt(1 - sum_sqs))
    return L, log_det_jacobian

In [79]:
@partial(jax.jit, static_argnames=['K'])
def lp_tfp(y, eta, K):
    L, log_det_jacobian = tfp_transform(y, K)
    return log_det_jacobian + tfd.CholeskyLKJ(K, eta.astype(jnp.float64)).log_prob(L)

In [80]:
@partial(jax.jit, static_argnames=['K'])
def lp_stan(y, eta, K):
    L, log_det_jacobian = stan_transform(y, K)
    return log_det_jacobian + tfd.CholeskyLKJ(K, eta.astype(jnp.float64)).log_prob(L)

Both Hessians are negative-definite, but TFP's version is "isotropic" in unconstrained space whereas Stan's version is not (isotropic in the sense that all directions have the same second derivative, i.e. "symmetric" curvature at the mode):

In [81]:
jax.hessian(lp_tfp)(zeros, eta, K)

DeviceArray([[-9.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
             [ 0., -9.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
             [ 0.,  0., -9.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
             [ 0.,  0.,  0., -9.,  0.,  0.,  0.,  0.,  0.,  0.],
             [ 0.,  0.,  0.,  0., -9.,  0.,  0.,  0.,  0.,  0.],
             [ 0.,  0.,  0.,  0.,  0., -9.,  0.,  0.,  0.,  0.],
             [ 0.,  0.,  0.,  0.,  0.,  0., -9.,  0.,  0.,  0.],
             [ 0.,  0.,  0.,  0.,  0.,  0.,  0., -9.,  0.,  0.],
             [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0., -9.,  0.],
             [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0., -9.]],            dtype=float64)

In [82]:
jax.hessian(lp_stan)(zeros, eta, K)

DeviceArray([[-8.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
             [ 0., -8.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
             [ 0.,  0., -7.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
             [ 0.,  0.,  0., -8.,  0.,  0.,  0.,  0.,  0.,  0.],
             [ 0.,  0.,  0.,  0., -7.,  0.,  0.,  0.,  0.,  0.],
             [ 0.,  0.,  0.,  0.,  0., -6.,  0.,  0.,  0.,  0.],
             [ 0.,  0.,  0.,  0.,  0.,  0., -8.,  0.,  0.,  0.],
             [ 0.,  0.,  0.,  0.,  0.,  0.,  0., -7.,  0.,  0.],
             [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0., -6.,  0.],
             [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0., -5.]],            dtype=float64)

We now move to K=3 where its easier to compute the probabilities on a grid, here using a grid of 200 points from -5 to 5 on all three axes:

In [83]:
K = 3
lb = -5
ub = 5
diff = ub - lb
n_points = 200

In [84]:
t = jnp.linspace(lb,ub,num=n_points)
res_tfp = jax.vmap(
    jax.vmap(
        jax.vmap(
            lambda x,y,z : lp_tfp(jnp.array([x,y,z]), eta, K), in_axes=(0,None,None)),
        in_axes=(None,0,None)),
    in_axes=(None,None,0))(t,t,t)

In [85]:
res_stan = jax.vmap(
    jax.vmap(
        jax.vmap(
            lambda x,y,z : lp_stan(jnp.array([x,y,z]), eta, K), in_axes=(0,None,None)),
        in_axes=(None,0,None)),
    in_axes=(None,None,0))(t,t,t)

This volume contains ~98% of the probability density for both transforms:

In [86]:
(((ub-lb)/n_points)**3*jnp.exp(res_tfp)).sum()

DeviceArray(0.98487558, dtype=float64)

In [87]:
(((ub-lb)/n_points)**3*jnp.exp(res_stan)).sum()

DeviceArray(0.98507487, dtype=float64)

# norm of gradient

In [88]:
t = jnp.linspace(lb,ub,num=n_points)
grad_norm_tfp = jax.vmap(
    jax.vmap(
        jax.vmap(
            lambda x,y,z : jnp.linalg.norm(jax.grad(lp_tfp)(jnp.array([x,y,z]), eta, K)), in_axes=(0,None,None)),
        in_axes=(None,0,None)),
    in_axes=(None,None,0))(t,t,t)

In [89]:
t = jnp.linspace(lb,ub,num=n_points)
grad_norm_stan = jax.vmap(
    jax.vmap(
        jax.vmap(
            lambda x,y,z : jnp.linalg.norm(jax.grad(lp_stan)(jnp.array([x,y,z]), eta, K)), in_axes=(0,None,None)),
        in_axes=(None,0,None)),
    in_axes=(None,None,0))(t,t,t)

# hessian condition number

In [90]:
t = jnp.linspace(lb,ub,num=n_points)
hess_cond_tfp = jax.vmap(
    jax.vmap(
        jax.vmap(
            lambda x,y,z : jnp.linalg.cond(jax.hessian(lp_tfp)(jnp.array([x,y,z]), eta, K)), in_axes=(0,None,None)),
        in_axes=(None,0,None)),
    in_axes=(None,None,0))(t,t,t)

In [91]:
t = jnp.linspace(lb,ub,num=n_points)
hess_cond_stan = jax.vmap(
    jax.vmap(
        jax.vmap(
            lambda x,y,z : jnp.linalg.cond(jax.hessian(lp_stan)(jnp.array([x,y,z]), eta, K)), in_axes=(0,None,None)),
        in_axes=(None,0,None)),
    in_axes=(None,None,0))(t,t,t)

stan's grad is generally 3 times larger in this volume:

In [92]:
jnp.median(grad_norm_stan/grad_norm_tfp), jnp.mean(grad_norm_stan/grad_norm_tfp)

(DeviceArray(2.98259064, dtype=float64),
 DeviceArray(3.12826176, dtype=float64))

and the median of the ratio of hessian condition numbers is ~40 times larger (the distribution is quite skewed):

In [93]:
jnp.median(hess_cond_stan/hess_cond_tfp), jnp.mean(hess_cond_stan/hess_cond_tfp)

(DeviceArray(39.44971372, dtype=float64),
 DeviceArray(186.42597119, dtype=float64))

in almost all "cubes" in this volume, stan's version assigns smaller probabilty; the places where stan's probability is higher are concentrated near the mode (again, un-isotropically) 

In [101]:
zeros = jnp.zeros((K * (K - 1))//2)

In [102]:
prob_ratio = jnp.exp(res_stan)/jnp.exp(res_tfp)
prob_ratio.mean()

DeviceArray(0.02308517, dtype=float64)

In [103]:
jnp.where(prob_ratio>1)

(DeviceArray([ 72,  72,  72, ..., 127, 127, 127], dtype=int64),
 DeviceArray([ 97,  97,  97, ..., 102, 102, 102], dtype=int64),
 DeviceArray([ 87,  88,  89, ..., 110, 111, 112], dtype=int64))