In [17]:
import numpy as np
import matplotlib.pyplot as plt
import jax
import jax.numpy as jnp
from tensorflow_probability.substrates import jax as tfp
from functools import partial


In [18]:
@jax.jit
def kme_RBF_Gaussian(mu, Sigma, l, y):
    """
    :param mu: Gaussian mean, (D, )
    :param Sigma: Gaussian covariance, (D, D)
    :param l: lengthscale, scalar
    :param y: sample: (N, D)
    :return:
    """
    kme_RBF_Gaussian_func_ = partial(kme_RBF_Gaussian_func, mu, Sigma, l)
    kme_RBF_Gaussian_vmap_func = jax.vmap(kme_RBF_Gaussian_func_)
    return kme_RBF_Gaussian_vmap_func(y)


@jax.jit
def kme_RBF_Gaussian_func(mu, Sigma, l, y):
    """
    :param mu: Gaussian mean, (D, )
    :param Sigma: Gaussian covariance, (D, D)
    :param l: lengthscale, scalar
    :param y: sample: D,
    :return: scalar
    """
    # From the kernel mean embedding document
    D = mu.shape[0]
    l_ = l ** 2
    Lambda = jnp.eye(D) * l_
    Lambda_inv = jnp.eye(D) / l_
    part1 = jnp.linalg.det(jnp.eye(D) + Sigma @ Lambda_inv)
    part2 = jnp.exp(-0.5 * (mu - y).T @ jnp.linalg.inv(Lambda + Sigma) @ (mu - y))
    return part1 ** (-0.5) * part2

@jax.jit
def my_RBF(x, y, l):
    """
    :param x: N*D
    :param y: M*D
    :param l: scalar
    :return: N*M
    """
    kernel = tfp.math.psd_kernels.ExponentiatedQuadratic(amplitude=1., length_scale=l)
    K = kernel.matrix(x, y)
    return K

In [68]:
seed = 1
rng_key = jax.random.PRNGKey(seed)
D = 10
N = 10

rng_key, _ = jax.random.split(rng_key)
mean = jax.random.uniform(rng_key, shape=(D, ), minval=0.0, maxval=1.0)

rng_key, _ = jax.random.split(rng_key)
dummy = jax.random.uniform(rng_key, shape=(D, 100), minval=0.0, maxval=0.1)
var = dummy @ dummy.T + jnp.eye(D)

l = 1.0
rng_key, _ = jax.random.split(rng_key)
y = jax.random.multivariate_normal(rng_key, mean=mean, cov=var, shape=(N, ))
rng_key, _ = jax.random.split(rng_key)
y_fixed = jax.random.multivariate_normal(rng_key, mean=mean, cov=var, shape=(10, ))

K = my_RBF(y_fixed, y, l)
kme = kme_RBF_Gaussian(mean, var, l, y_fixed)

In [69]:
K.mean(1)

Array([8.7965076e-04, 4.5721470e-03, 1.2602427e-04, 2.1201411e-05,
       2.3726935e-03, 4.6493388e-03, 4.3867859e-03, 2.3953275e-04,
       1.1579233e-03, 3.4512789e-04], dtype=float32)

In [70]:
kme

Array([0.0010185 , 0.01007888, 0.00100589, 0.00078874, 0.00124499,
       0.0027315 , 0.00986949, 0.00081317, 0.00297458, 0.00039722],      dtype=float32)

In [57]:
var

Array([[1.3641884 , 0.2633632 , 0.2464004 , 0.27786878, 0.26133367,
        0.25889683, 0.2586669 , 0.2575637 , 0.30468386, 0.25391233],
       [0.2633632 , 1.3181638 , 0.22231829, 0.2663336 , 0.24106838,
        0.2320017 , 0.23834854, 0.2182107 , 0.2695743 , 0.22789618],
       [0.2464004 , 0.22231829, 1.3087564 , 0.23935854, 0.2408936 ,
        0.22616975, 0.21328543, 0.21933575, 0.26172552, 0.21886317],
       [0.27786878, 0.2663336 , 0.23935854, 1.3690783 , 0.26243916,
        0.25589228, 0.23990871, 0.23971394, 0.30577105, 0.24893215],
       [0.26133367, 0.24106838, 0.2408936 , 0.26243916, 1.3118457 ,
        0.24977423, 0.2408855 , 0.2263063 , 0.2831559 , 0.23346937],
       [0.25889683, 0.2320017 , 0.22616975, 0.25589228, 0.24977423,
        1.3196143 , 0.24970306, 0.23052216, 0.27194983, 0.23551844],
       [0.2586669 , 0.23834854, 0.21328543, 0.23990871, 0.2408855 ,
        0.24970306, 1.3046267 , 0.22181647, 0.28493407, 0.24839784],
       [0.2575637 , 0.2182107 , 0.2193357