In [1]:
from target import QTargetIMQ, QTargetAuto, QTargetKGM

import numpy as np
from scipy.stats import wishart
from jax import numpy as jnp

np.random.seed(1234)

In [2]:
dim = 4

x_test = np.random.normal(size=dim)
mean = np.random.normal(size=dim)
cov = wishart.rvs(dim + 10, np.eye(dim), size=1)

print("x_test:", x_test)
print("mean  :", mean)
print("cov   :", cov)

x_test: [ 0.47143516 -1.19097569  1.43270697 -0.3126519 ]
mean  : [-0.72058873  0.88716294  0.85958841 -0.6365235 ]
cov   : [[11.68150068  0.05364739 -7.66509583  3.39029399]
 [ 0.05364739 12.34410368  4.00531349  3.36496236]
 [-7.66509583  4.00531349 25.21091916 -9.90588931]
 [ 3.39029399  3.36496236 -9.90588931 10.75391052]]


In [3]:
# Test manual IMQ class
log_p = lambda x:-np.log(2*np.pi) - 0.5*np.log(np.linalg.det(cov)) - 0.5*(x-mean)@np.linalg.inv(cov)@(x-mean)
grad_log_p = lambda x: -np.linalg.inv(cov)@(x-mean)
hess_log_p = lambda x: -np.linalg.inv(cov)
linv = np.linalg.inv(cov)

test_manual_imq = QTargetIMQ(log_p=log_p, grad_log_p=grad_log_p, hess_log_p=hess_log_p, linv=linv)

print("log-pdf   :", test_manual_imq.log_p(x_test))
print("glog-pdf  :", test_manual_imq.grad_log_p(x_test))
print("hlog-pdf  :", test_manual_imq.hess_log_p(x_test))
print("log-q-pdf :", test_manual_imq.log_q(x_test))
print("glog-q-pdf:", test_manual_imq.grad_log_q(x_test))

log-pdf   : -7.153352187026759
glog-pdf  : [-0.17921809  0.33648779 -0.25339807 -0.31232095]
hlog-pdf  : [[-0.10860438  0.013307   -0.03654333 -0.0035867 ]
 [ 0.013307   -0.12428333  0.05865073  0.08871951]
 [-0.03654333  0.05865073 -0.09838846 -0.09746129]
 [-0.0035867   0.08871951 -0.09746129 -0.20939533]]
log-q-pdf : -7.235925371128592
glog-q-pdf: [-0.13873336  0.23413024 -0.15708047 -0.17007607]


In [4]:
# Test auto IMQ class
log_p = lambda x:-jnp.log(2*jnp.pi) - 0.5*jnp.log(jnp.linalg.det(cov)) - 0.5*(x-mean)@jnp.linalg.inv(cov)@(x-mean)
linv = jnp.linalg.inv(cov)
imq_kernel = lambda x, y: (1 + (x - y)@linv@(x - y))**(-0.5)

test_auto_imq = QTargetAuto(log_p=log_p, reproducing_kernel=imq_kernel)

print("log-pdf   :", test_auto_imq.log_p(x_test))
print("glog-pdf  :", test_auto_imq.grad_log_p(x_test))
print("hlog-pdf  :", test_auto_imq.hess_log_p(x_test))
print("log-q-pdf :", test_auto_imq.log_q(x_test))
print("glog-q-pdf:", test_auto_imq.grad_log_q(x_test))

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


log-pdf   : -7.1533523
glog-pdf  : [-0.17921811  0.33648777 -0.2533981  -0.31232098]
hlog-pdf  : [[-0.10860439  0.013307   -0.03654334 -0.0035867 ]
 [ 0.013307   -0.12428331  0.05865073  0.08871952]
 [-0.03654334  0.05865073 -0.09838847 -0.09746131]
 [-0.0035867   0.08871952 -0.09746131 -0.20939535]]
log-q-pdf : -7.235925
glog-q-pdf: [-0.13873339  0.23413023 -0.15708049 -0.17007609]


In [5]:
# Test manual KGM class
log_p = lambda x:-np.log(2*np.pi) - 0.5*np.log(np.linalg.det(cov)) - 0.5*(x-mean)@np.linalg.inv(cov)@(x-mean)
grad_log_p = lambda x: -np.linalg.inv(cov)@(x-mean)
hess_log_p = lambda x: -np.linalg.inv(cov)
linv = np.linalg.inv(cov)

test_manual_kgm = QTargetKGM(log_p=log_p, grad_log_p=grad_log_p, hess_log_p=hess_log_p, linv=linv, s=3.0)

print("log-pdf   :", test_manual_kgm.log_p(x_test))
print("glog-pdf  :", test_manual_kgm.grad_log_p(x_test))
print("hlog-pdf  :", test_manual_kgm.hess_log_p(x_test))
print("log-q-pdf :", test_manual_kgm.log_q(x_test))
print("glog-q-pdf:", test_manual_kgm.grad_log_q(x_test))

log-pdf   : -7.153352187026759
glog-pdf  : [-0.17921809  0.33648779 -0.25339807 -0.31232095]
hlog-pdf  : [[-0.10860438  0.013307   -0.03654333 -0.0035867 ]
 [ 0.013307   -0.12428333  0.05865073  0.08871951]
 [-0.03654333  0.05865073 -0.09838846 -0.09746129]
 [-0.0035867   0.08871951 -0.09746129 -0.20939533]]
log-q-pdf : -6.832792089900934
glog-q-pdf: [-0.06523543  0.12608668 -0.05548275 -0.1164513 ]


In [6]:
# Test auto KGM class
s = 3.0
linv = jnp.linalg.inv(cov)
log_p = lambda x:-jnp.log(2*jnp.pi) - 0.5*jnp.log(jnp.linalg.det(cov)) - 0.5*(x-mean)@jnp.linalg.inv(cov)@(x-mean)

kgm_kernel = lambda x, y: (1 + x@linv@x)**((s-1)/2) *\
        (1 + y@linv@y)**((s-1)/2) *\
        (1 + (x-y)@linv@(x-y))**(-0.5) +\
        (1 + x@linv@y )/( jnp.sqrt(1+x@linv@x) * jnp.sqrt(1+y@linv@y) )

test_auto_kgm = QTargetAuto(log_p=log_p, reproducing_kernel=kgm_kernel)

print("log-pdf   :", test_auto_kgm.log_p(x_test))
print("glog-pdf  :", test_auto_kgm.grad_log_p(x_test))
print("hlog-pdf  :", test_auto_kgm.hess_log_p(x_test))
print("log-q-pdf :", test_auto_kgm.log_q(x_test))
print("glog-q-pdf:", test_auto_kgm.grad_log_q(x_test))

log-pdf   : -7.1533523
glog-pdf  : [-0.17921811  0.33648777 -0.2533981  -0.31232098]
hlog-pdf  : [[-0.10860439  0.013307   -0.03654334 -0.0035867 ]
 [ 0.013307   -0.12428331  0.05865073  0.08871952]
 [-0.03654334  0.05865073 -0.09838847 -0.09746131]
 [-0.0035867   0.08871952 -0.09746131 -0.20939535]]
log-q-pdf : -6.8327923
glog-q-pdf: [-0.06523542  0.12608667 -0.05548273 -0.11645128]
