In [1]:
import jax
import jax.numpy as jnp
from tensorflow_probability.substrates import jax as tfp
from functools import partial


In [13]:
seed = 1
rng_key = jax.random.PRNGKey(seed)

N = 1000
D = 2

rng_key, _ = jax.random.split(rng_key)
mu = jax.random.uniform(rng_key, shape=(D,))

rng_key, _ = jax.random.split(rng_key)
x_dummy = jax.random.uniform(rng_key, shape=(100, D))
Sigma = x_dummy.T @ x_dummy + jnp.eye(D)

rng_key, _ = jax.random.split(rng_key)
x = jax.random.multivariate_normal(rng_key, mean=mu, cov=Sigma, shape=(N, ))

rng_key, _ = jax.random.split(rng_key)
y = jax.random.multivariate_normal(rng_key, mean=mu, cov=Sigma, shape=(N, ))

l = 1.0


In [28]:
def kme_RBF_Gaussian(mu, Sigma, l, y):
    """
    :param mu: Gaussian mean, (D, )
    :param sigma: Gaussian covariance, (D, D)
    :param l: lengthscale, scalar
    :param y: sample: (N, D)
    :return:
    """
    kme_RBF_Gaussian_func_ = partial(kme_RBF_Gaussian_func, mu, Sigma, l)
    kme_RBF_Gaussian_vmap_func = jax.vmap(kme_RBF_Gaussian_func_)
    return kme_RBF_Gaussian_vmap_func(y)


def kme_RBF_Gaussian_func(mu, Sigma, l, y):
    """
    :param mu: Gaussian mean, (D, )
    :param sigma: Gaussian covariance, (D, D)
    :param l: lengthscale, scalar
    :param y: sample: D,
    :return: scalar
    """
    # From the kernel mean embedding document
    D = mu.shape[0]
    Lambda = jnp.eye(D) * l
    Lambda_inv = jnp.eye(D) / l
    part1 = jnp.linalg.det(jnp.eye(D) + Sigma @ Lambda_inv)
    part2 = jnp.exp(-0.5 * (mu - y).T @ jnp.linalg.inv(Lambda + Sigma) @ (mu - y))
    return part1 ** (-0.5) * part2

def my_RBF(x, y, l):
    """
    :param x: N*D
    :param y: M*D
    :param l: scalar
    :return: N*M
    """
    kernel = tfp.math.psd_kernels.ExponentiatedQuadratic(amplitude=1., length_scale=l)
    K = kernel.matrix(x, y)
    return K

In [29]:
empirical = my_RBF(x, y, l).mean(0)

In [30]:
analytical = empirical * 0
for i in range(N):
    Y = y[i, :]
    analytical_dummy = kme_RBF_Gaussian_func(mu, Sigma, l, Y)
    analytical = analytical.at[i].set(analytical_dummy)

In [31]:
analytical[:10]

Array([0.04002754, 0.00739984, 0.03849194, 0.00446801, 0.02718731,
       0.01028908, 0.01717951, 0.00129134, 0.01635272, 0.00551911],      dtype=float32)

In [32]:
empirical[:10]

Array([0.03678151, 0.00621564, 0.04432947, 0.00332187, 0.03083159,
       0.00990483, 0.01637228, 0.00169566, 0.01606366, 0.00648652],      dtype=float32)

In [34]:
kme_RBF_Gaussian(mu, Sigma, l, y)[:10]

Array([0.04002754, 0.00739984, 0.03849194, 0.00446801, 0.02718731,
       0.01028908, 0.01717951, 0.00129134, 0.01635272, 0.00551911],      dtype=float32)

In [37]:
Sigma.shape

(2, 2)

In [70]:
kme_RBF_Gaussian_func_ = partial(kme_RBF_Gaussian_func, mu, sigma, l)
kme_RBF_Gaussian_vmap_func = jax.vmap(kme_RBF_Gaussian_func_)

In [73]:
kme_RBF_Gaussian_vmap_func(y)[:10]

Array([0.7683728 , 0.7709943 , 0.7710005 , 0.7691586 , 0.64575416,
       0.6365243 , 0.7714575 , 0.7085045 , 0.642029  , 0.67188025],      dtype=float32)