In [1]:
import os
os.chdir('/home/zongchen/nest_bq')
import jax.numpy as jnp
import jax
from jax import config
config.update("jax_enable_x64", True)
from utils.kernel_means import *

rng_key = jax.random.PRNGKey(0)

An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.


In [2]:
N = 10000
X1 = jnp.exp(jax.random.normal(rng_key, shape=(N, 1)))
rng_key, _ = jax.random.split(rng_key)
X2 = jnp.exp(jax.random.normal(rng_key, shape=(5, 1)))
l = 2.0

K = my_log_RBF(X1, X2, l)
empirical_mean = jnp.mean(K, axis=0)
print('Empirical kernel mean=', empirical_mean)
analytical_mean = kme_log_normal_log_RBF(0., 1., X2, l)
print('Analytic kernel mean=', analytical_mean)


Empirical kernel mean= [0.45349489 0.89313244 0.79373742 0.71834174 0.87206135]
Analytic kernel mean= [[0.45245266]
 [0.89432469]
 [0.79486433]
 [0.71842054]
 [0.87329693]]


In [3]:
N = 10000
a, b = -5., 1.
d = 1
X1 = jax.random.uniform(rng_key, shape=(N, 1), minval=a, maxval=b)
rng_key, _ = jax.random.split(rng_key)
X2 = jax.random.uniform(rng_key, shape=(5, 1), minval=a, maxval=b)
l = 1.0

K = my_Matern_12_product(X1, X2, l * jnp.ones(d))
empirical_mean = jnp.mean(K, axis=0)
print('Empirical kernel mean=', empirical_mean)
analytical_mean = kme_Matern_12_Uniform_1d(a * jnp.ones(d), b* jnp.ones(d), l * jnp.ones(d), X2)
print('Analytic kernel mean=', analytical_mean)


Empirical kernel mean= [0.26477975 0.26029472 0.27209775 0.29459142 0.3111509 ]
Analytic kernel mean= [[0.26478407]
 [0.25750361]
 [0.27222503]
 [0.29232695]
 [0.31552523]]


In [4]:
N = 100000
d = 5
a, b = -5., 1.
X1 = jax.random.uniform(rng_key, shape=(N, d), minval=a, maxval=b)
rng_key, _ = jax.random.split(rng_key)
X2 = jax.random.uniform(rng_key, shape=(5, d), minval=a, maxval=b)
l = 1.0 

K = my_Matern_12_product(X1, X2, l * jnp.ones(d))
empirical_mean = jnp.mean(K, axis=0)
print('Empirical kernel mean=', empirical_mean)
analytical_mean = kme_Matern_12_Uniform(a * jnp.ones(d), b * jnp.ones(d), l * jnp.ones(d), X2)
print('Analytic kernel mean=', analytical_mean)


Empirical kernel mean= [0.31081491 0.26196342 0.29382618 0.27768245 0.2660458 ]
Analytic kernel mean= [0.00292447 0.00108009 0.00214197 0.00151346 0.00120335]


In [5]:
N = 10000
a, b = -5., 1.
d = 1
X1 = jax.random.uniform(rng_key, shape=(N, 1), minval=a, maxval=b)
rng_key, _ = jax.random.split(rng_key)
X2 = jax.random.uniform(rng_key, shape=(5, 1), minval=a, maxval=b)
l = 1.0

K = my_Matern_32_product(X1, X2, l * jnp.ones(d))
empirical_mean = jnp.mean(K, axis=0)
print('Empirical kernel mean=', empirical_mean)
analytical_mean = kme_Matern_32_Uniform_1d(a* jnp.ones(d), b* jnp.ones(d), l* jnp.ones(d), X2)
print('Analytic kernel mean=', analytical_mean)


Empirical kernel mean= [0.37386373 0.37379593 0.22079658 0.33009174 0.34942444]
Analytic kernel mean= [[0.37722939]
 [0.37709775]
 [0.22183448]
 [0.32721466]
 [0.35329883]]


In [6]:
N = 1000000
d = 5
a, b = -5., 1.
X1 = jax.random.uniform(rng_key, shape=(N, d), minval=a, maxval=b)
rng_key, _ = jax.random.split(rng_key)
X2 = jax.random.uniform(rng_key, shape=(5, d), minval=a, maxval=b)
l = 10.0

K = my_Matern_32_product(X1, X2, l * jnp.ones(d))
empirical_mean = jnp.mean(K, axis=0)
print('Empirical kernel mean=', empirical_mean)
analytical_mean = kme_Matern_32_Uniform(a * jnp.ones(d), b * jnp.ones(d), l * jnp.ones(d), X2)
print('Analytic kernel mean=', analytical_mean)


Empirical kernel mean= [0.73787522 0.75689361 0.68579912 0.71766462 0.75204973]
Analytic kernel mean= [0.73776941 0.75693616 0.68589801 0.71774337 0.75205041]


In [7]:
N = 10000
d = 5
X1 = jax.random.normal(rng_key, shape=(N, d))
rng_key, _ = jax.random.split(rng_key)
X2 = jax.random.normal(rng_key, shape=(5, d))
l = 1.0


K = my_Matern_12_product(X1, X2, l * jnp.ones(d))
empirical_mean = jnp.mean(K, axis=0)
print('Empirical kernel mean=', empirical_mean)
analytical_mean = kme_Matern_12_Gaussian(l * jnp.ones(d), X2)
print('Analytic kernel mean=', analytical_mean)


Empirical kernel mean= [0.46942333 0.37730606 0.37065108 0.42514889 0.4227503 ]
Analytic kernel mean= [0.4681925  0.37747902 0.37137085 0.42498039 0.42337043]


In [8]:
N = 10000
a, b = -5., 1.
X1 = jax.random.normal(rng_key, shape=(N, 1))
rng_key, _ = jax.random.split(rng_key)
X2 = jax.random.normal(rng_key, shape=(5, 1))
l = 1.0

K = my_Matern_12_product(X1, X2, l * jnp.ones(1))
empirical_mean = jnp.mean(K, axis=0)
print('Empirical kernel mean=', empirical_mean)
analytical_mean = kme_Matern_12_Gaussian_1d(l, X2.squeeze())
print('Analytic kernel mean=', analytical_mean)


Empirical kernel mean= [0.48726468 0.50487266 0.30886715 0.31149302 0.31920672]
Analytic kernel mean= [0.48612827 0.5044938  0.30968273 0.31176245 0.32003897]


In [11]:
N = 10000
a, b = -1., 1.
X1 = jax.random.uniform(rng_key, shape=(N, 2), minval=a, maxval=b)
X2 = jax.random.uniform(rng_key, shape=(5, 2), minval=a, maxval=b)
l = 2.0

K = my_RBF(X1, X2, l)
empirical_mean = jnp.mean(K, axis=0)
print('Empirical kernel mean=', empirical_mean)
analytical_mean = kme_RBF_uniform(a, b, l, X2)
print('Analytic kernel mean=', analytical_mean)


Empirical kernel mean= [0.82546891 0.89780632 0.8007673  0.82570303 0.84545707]
Analytic kernel mean= [0.82424402 0.89678603 0.79974699 0.82471063 0.84419394]


In [10]:
N = 10000
a, b = 0., 1.
d = 2
X1 = jax.random.normal(rng_key, shape=(N, d))
rng_key, _ = jax.random.split(rng_key)
X2 = jax.random.normal(rng_key, shape=(5, d))
l = 0.2

K = my_RBF(X1, X2, l)
empirical_mean = jnp.mean(K, axis=0)
print('Empirical kernel mean=', empirical_mean)
analytical_mean = kme_RBF_Gaussian(a * jnp.ones([d]), b * jnp.eye(d), l, X2)
print('Analytic kernel mean=', analytical_mean)


Empirical kernel mean= [0.00381389 0.03175655 0.00434104 0.0132915  0.01822839]
Analytic kernel mean= [0.0034869  0.03140728 0.00469772 0.01202661 0.01834633]
