In [29]:
import jax
import jax.numpy as jnp
from tensorflow_probability.substrates import jax as tfp
from jax.scipy.stats import norm

from jax.config import config

config.update('jax_platform_name', 'cpu')
config.update("jax_enable_x64", True)


In [51]:
seed = 10
rng_key = jax.random.PRNGKey(seed)

D = 1
N = 1000
mean = jnp.zeros(D) + 0.0
cov = jnp.eye(D) + 0.0

rng_key, _ = jax.random.split(rng_key)
x = jax.random.multivariate_normal(rng_key, mean, cov, shape=(N, ))
rng_key, _ = jax.random.split(rng_key)
y = jax.random.multivariate_normal(rng_key, mean, cov, shape=(2, ))


In [42]:
l = 0.3

kernel = tfp.math.psd_kernels.MaternThreeHalves(amplitude=1., length_scale=l)
K = kernel.matrix(x, y)

In [43]:
K.mean(0)

Array([0.2578935, 0.2391616], dtype=float64)

In [49]:
def Matern_kme(y, l):
    E10 = 1 - jnp.sqrt(3) * y / l
    E11 = jnp.sqrt(3) / l
    muA = -jnp.sqrt(3) / l

    part11 = jnp.exp((3 + 2*jnp.sqrt(3) * y * l) / (2 * l ** 2))
    part12 = (E10 + E11 * muA) * norm.cdf(muA - y)
    part13 = E11 / jnp.sqrt(2 * jnp.pi) * jnp.exp(-(y - muA) ** 2 / 2)
    part1 = part11 * (part12 + part13)

    E20 = 1 + jnp.sqrt(3) * y / l
    E21 = jnp.sqrt(3) / l
    muB = jnp.sqrt(3) / l

    part21 = jnp.exp((3 - 2*jnp.sqrt(3) * y * l) / (2 * l ** 2))
    part22 = (E20 - E21 * muB) * norm.cdf(y - muB)
    part23 = E21 / jnp.sqrt(2 * jnp.pi) * jnp.exp(-(y - muB) ** 2 / 2)
    part2 = part21 * (part22 + part23)

    final = part1 + part2
    return final

In [50]:
Matern_kme(y, l)

Array([[0.25933702],
       [0.23876826]], dtype=float64)

In [64]:
seed = 10
rng_key = jax.random.PRNGKey(seed)

D = 2
N = 100000
mean = jnp.zeros(D) + 0.0
cov = jnp.eye(D) + 0.0

rng_key, _ = jax.random.split(rng_key)
x = jax.random.multivariate_normal(rng_key, mean, cov, shape=(N, ))
rng_key, _ = jax.random.split(rng_key)
y = jax.random.multivariate_normal(rng_key, mean, cov, shape=(2, ))

x1 = x[:, 0][:, None]
x2 = x[:, 1][:, None]

y1 = y[:, 0][:, None]
y2 = y[:, 1][:, None]

In [65]:
K1 = kernel.matrix(x1, y1)
K2 = kernel.matrix(x2, y2)
K = K1 + K2

print(K.mean(0))

[0.47497255 0.40209477]


In [66]:
Matern_kme(y1, l) + Matern_kme(y2, l)

Array([[0.47497702],
       [0.40287098]], dtype=float64)