In [8]:
from sobol_seq import i4_sobol_generate


In [37]:

import jax.numpy as jnp
import jax
from tensorflow_probability.substrates import jax as tfp
from functools import partial
from jax.scipy.stats import norm


def stein_Matern(x, y, l, d_log_px, d_log_py):
    """
    :param x: N*D
    :param y: M*D
    :param l: scalar
    :param d_log_px: N*D
    :param d_log_py: M*D
    :return: N*M
    """
    N, D = x.shape
    M = y.shape[0]

    batch_kernel = tfp.math.psd_kernels.MaternThreeHalves(amplitude=1., length_scale=l)
    grad_x_K_fn = jax.grad(batch_kernel.apply, argnums=0)
    vec_grad_x_K_fn = jax.vmap(grad_x_K_fn, in_axes=(0, 0), out_axes=0)
    grad_y_K_fn = jax.grad(batch_kernel.apply, argnums=1)
    vec_grad_y_K_fn = jax.vmap(grad_y_K_fn, in_axes=(0, 0), out_axes=0)

    grad_xy_K_fn = jax.jacfwd(jax.jacrev(batch_kernel.apply, argnums=1), argnums=0)

    def diag_sum_grad_xy_K_fn(x, y):
        return jnp.diag(grad_xy_K_fn(x, y)).sum()

    vec_grad_xy_K_fn = jax.vmap(diag_sum_grad_xy_K_fn, in_axes=(0, 0), out_axes=0)

    x_dummy = jnp.stack([x] * N, axis=1).reshape(N * M, D)
    y_dummy = jnp.stack([y] * M, axis=0).reshape(N * M, D)

    K = batch_kernel.matrix(x, y)
    dx_K = vec_grad_x_K_fn(x_dummy, y_dummy).reshape(N, M, D)
    dy_K = vec_grad_y_K_fn(x_dummy, y_dummy).reshape(N, M, D)
    dxdy_K = vec_grad_xy_K_fn(x_dummy, y_dummy).reshape(N, M)

    part1 = d_log_px @ d_log_py.T * K
    part2 = (d_log_py[None, :] * dx_K).sum(-1)
    part3 = (d_log_px[:, None, :] * dy_K).sum(-1)
    part4 = dxdy_K

    return part1 + part2 + part3 + part4


def score_fn(y, mu, sigma):
    """
    return \nabla_y log p(y|mu, sigma)
    :param y: (N, D)
    :param mu: (D, )
    :param sigma: (D, D)
    :return: (N, D)
    """
    return -(y - mu[None, :]) @ jnp.linalg.inv(sigma)

def log_llk(y, mu, sigma):
    return jax.scipy.stats.multivariate_normal.logpdf(y, mu, sigma).sum()
    

def qmc_gaussian(mu, sigma, nsamples):
    """
    :param mu: (D, )
    :param sigma: (D, D)
    :param nsamples:
    :return: samples: (nsamples, D)
    """
    D = mu.shape[0]
    u = i4_sobol_generate(D, nsamples)
    L = jnp.linalg.cholesky(sigma)
    samples = mu[:, None] + (norm.ppf(u) @ L).T
    samples = samples.T
    return samples, u

In [32]:
log_llk(samples, mu, sigma).shape

()

In [34]:
grad = jax.grad(log_llk, argnums=0)(samples, mu, sigma)
grad

Array([[ 0.        ,  0.        ],
       [-0.83907676,  1.0036637 ],
       [ 0.83907676, -1.0036637 ],
       ...,
       [-0.5193622 ,  0.45856872],
       [ 1.2219273 , -1.7677315 ],
       [ 1.146997  ,  0.55294514]], dtype=float32)

In [38]:
seed = 0
rng_key = jax.random.PRNGKey(seed)
N = 100

mu = jnp.array([0.1, 0.1])
sigma = jnp.array([[1.0, 0.5], [0.5, 1.0]])
samples, _ = qmc_gaussian(mu, sigma, N)
score = score_fn(samples, mu, sigma)


In [40]:
l = 1.0
K = stein_Matern(samples, samples, l, score, score)

In [41]:
K.mean(1)

Array([-0.01541214, -0.02288976, -0.02268224, -0.05494424,  0.023611  ,
       -0.15775913,  0.06250537,  0.09608445,  0.00113019, -0.01712696,
       -0.02659494, -0.00278812,  0.09662905, -0.02559544,  0.00747745,
        0.14961518, -0.1373193 , -0.01368134, -0.07155657, -0.13544227,
        0.16497128, -0.02587449, -0.07058392,  0.09313472, -0.01122014,
        0.09472355, -0.153085  , -0.00428645,  0.04264457, -0.17822246,
        0.21884274,  0.23187469, -0.06738494, -0.20360675, -0.04399686,
       -0.09875187,  0.1334598 , -0.03871112, -0.08141498, -0.00494679,
        0.02810924,  0.15093832, -0.10725578, -0.04930729, -0.03664687,
       -0.1280179 ,  0.05464987,  0.19470344, -0.04909984,  0.06619162,
       -0.18760559, -0.06878809,  0.06685413, -0.15949498, -0.0099932 ,
        0.10129671, -0.10415832, -0.01529307, -0.07082627, -0.04559669,
        0.04281029, -0.09472591,  0.14663856,  0.24400835, -0.12239994,
       -0.10217958, -0.03305624, -0.12522101,  0.12372877, -0.07