In [1]:
from target import QTargetIMQ, QTargetAuto, QTargetKGM

import numpy as np
from jax import numpy as jnp

In [2]:
# Test the manual IMQ class on a 2-D Gaussian target N(mean, cov), with log_p and
# its gradient/Hessian supplied in closed form.
x_test = np.array([0., 3.])

mean = np.array([1., 2.])
cov = np.array([[2., 3.], [3., 5.]])
linv = np.linalg.inv(cov)


def gaussian_np(mean, cov):
    """Return closed-form NumPy ``(log_p, grad_log_p, hess_log_p)`` for N(mean, cov).

    The precision matrix and the log-normalising constant are computed once and
    captured by the returned closures. This avoids re-inverting ``cov`` on every
    call. It also stops the closures from reading the notebook globals
    ``mean``/``cov``, which other cells rebind.

    Parameters
    ----------
    mean : 1-D array of length d
    cov : (d, d) covariance matrix (assumed symmetric positive definite)
    """
    prec = np.linalg.inv(cov)
    # -d/2 log(2*pi) - 1/2 log|cov|; slogdet is more stable than log(det(cov)).
    log_norm = -0.5 * mean.shape[0] * np.log(2 * np.pi) - 0.5 * np.linalg.slogdet(cov)[1]

    def log_p(x):
        diff = x - mean
        return log_norm - 0.5 * (diff @ prec @ diff)

    def grad_log_p(x):
        return -prec @ (x - mean)

    def hess_log_p(x):
        # Constant for a Gaussian; x is accepted only to match the interface.
        return -prec

    return log_p, grad_log_p, hess_log_p


log_p, grad_log_p, hess_log_p = gaussian_np(mean, cov)
test_manual_imq = QTargetIMQ(log_p=log_p, grad_log_p=grad_log_p, hess_log_p=hess_log_p, linv=linv)

print("log-pdf   :", test_manual_imq.log_p(x_test))
print("glog-pdf  :", test_manual_imq.grad_log_p(x_test))
print("hlog-pdf  :", test_manual_imq.hess_log_p(x_test))
print("log-q-pdf :", test_manual_imq.log_q(x_test))
print("glog-q-pdf:", test_manual_imq.grad_log_q(x_test))

log-pdf   : -8.337877066409348
glog-pdf  : [ 8. -5.]
hlog-pdf  : [[-5.  3.]
 [ 3. -2.]]
log-q-pdf : -6.05570297067543
glog-q-pdf: [ 7.42708333 -4.64583333]


In [3]:
# Test the auto IMQ class: QTargetAuto receives only log_p and the reproducing
# kernel, and derives the remaining quantities itself (presumably via JAX
# autodiff). Its results are checked against the hand-written IMQ target below.
x_test = jnp.array([0., 3.])

mean = jnp.array([1., 2.])
cov = jnp.array([[2., 3.], [3., 5.]])
linv = jnp.linalg.inv(cov)


def gaussian_log_p_jnp(mean, cov):
    """Return a JAX ``log_p`` for N(mean, cov) with its constants captured.

    The closure captures the precision matrix and normalising constant, so it
    does not depend on notebook globals that other cells rebind.
    """
    prec = jnp.linalg.inv(cov)
    # -d/2 log(2*pi) - 1/2 log|cov|; slogdet is more stable than log(det(cov)).
    log_norm = -0.5 * mean.shape[0] * jnp.log(2 * jnp.pi) - 0.5 * jnp.linalg.slogdet(cov)[1]

    def log_p(x):
        diff = x - mean
        return log_norm - 0.5 * (diff @ prec @ diff)

    return log_p


def make_imq_kernel(linv):
    """Return k(x, y) = (1 + (x-y)^T linv (x-y))^(-1/2) with ``linv`` captured."""
    def imq_kernel(x, y):
        diff = x - y
        return (1 + diff @ linv @ diff) ** (-0.5)

    return imq_kernel


log_p = gaussian_log_p_jnp(mean, cov)
imq_kernel = make_imq_kernel(linv)
test_auto_imq = QTargetAuto(log_p=log_p, reproducing_kernel=imq_kernel)

print("log-pdf   :", test_auto_imq.log_p(x_test))
print("glog-pdf  :", test_auto_imq.grad_log_p(x_test))
print("hlog-pdf  :", test_auto_imq.hess_log_p(x_test))
print("log-q-pdf :", test_auto_imq.log_q(x_test))
print("glog-q-pdf:", test_auto_imq.grad_log_q(x_test))

# Cross-check against the manual target `test_manual_imq`. The tolerance allows
# for JAX's default float32 arithmetic. The manual (NumPy) target is fed a
# float64 NumPy copy of the test point.
x_test_np = np.asarray(x_test, dtype=np.float64)
for method in ("log_p", "grad_log_p", "hess_log_p", "log_q", "grad_log_q"):
    np.testing.assert_allclose(
        getattr(test_auto_imq, method)(x_test),
        getattr(test_manual_imq, method)(x_test_np),
        rtol=1e-5, atol=1e-6, err_msg=method,
    )

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


log-pdf   : -8.337875
glog-pdf  : [ 7.999998 -4.999999]
hlog-pdf  : [[-4.9999986  2.999999 ]
 [ 2.999999  -1.9999995]]
log-q-pdf : -6.0557013
glog-q-pdf: [ 7.4270816 -4.6458325]


In [4]:
# Test the manual KGM class on a 2-D Gaussian target N(mean, cov), with log_p and
# its gradient/Hessian supplied in closed form.
x_test = np.array([0., 3.])

mean = np.array([1., 2.])
cov = np.array([[2., 3.], [3., 5.]])
linv = np.linalg.inv(cov)
s = 3.0  # KGM kernel exponent parameter


def gaussian_np(mean, cov):
    """Return closed-form NumPy ``(log_p, grad_log_p, hess_log_p)`` for N(mean, cov).

    This helper is defined in this cell so that the cell runs on its own. The
    precision matrix and log-normalising constant are computed once and
    captured. As a result the closures neither re-invert ``cov`` per call nor
    read the notebook globals ``mean``/``cov``, which other cells rebind.

    Parameters
    ----------
    mean : 1-D array of length d
    cov : (d, d) covariance matrix (assumed symmetric positive definite)
    """
    prec = np.linalg.inv(cov)
    # -d/2 log(2*pi) - 1/2 log|cov|; slogdet is more stable than log(det(cov)).
    log_norm = -0.5 * mean.shape[0] * np.log(2 * np.pi) - 0.5 * np.linalg.slogdet(cov)[1]

    def log_p(x):
        diff = x - mean
        return log_norm - 0.5 * (diff @ prec @ diff)

    def grad_log_p(x):
        return -prec @ (x - mean)

    def hess_log_p(x):
        # Constant for a Gaussian; x is accepted only to match the interface.
        return -prec

    return log_p, grad_log_p, hess_log_p


log_p, grad_log_p, hess_log_p = gaussian_np(mean, cov)
test_manual_kgm = QTargetKGM(log_p=log_p, grad_log_p=grad_log_p, hess_log_p=hess_log_p, linv=linv, s=s)

print("log-pdf   :", test_manual_kgm.log_p(x_test))
print("glog-pdf  :", test_manual_kgm.grad_log_p(x_test))
print("hlog-pdf  :", test_manual_kgm.hess_log_p(x_test))
print("log-q-pdf :", test_manual_kgm.log_q(x_test))
print("glog-q-pdf:", test_manual_kgm.grad_log_q(x_test))

log-pdf   : -8.337877066409348
glog-pdf  : [ 8. -5.]
hlog-pdf  : [[-5.  3.]
 [ 3. -2.]]
log-q-pdf : -3.227615211790754
glog-q-pdf: [ 6.36721585 -3.93922533]


In [5]:
# Test the auto KGM class: QTargetAuto receives only log_p and the KGM
# reproducing kernel written out explicitly. Its results are checked against
# the hand-written KGM target below.
x_test = jnp.array([0., 3.])

mean = jnp.array([1., 2.])
cov = jnp.array([[2., 3.], [3., 5.]])
linv = jnp.linalg.inv(cov)
s = 3.0  # KGM kernel exponent parameter


def gaussian_log_p_jnp(mean, cov):
    """Return a JAX ``log_p`` for N(mean, cov) with its constants captured.

    This helper is defined in this cell so that the cell runs on its own. The
    closure does not depend on notebook globals that other cells rebind.
    """
    prec = jnp.linalg.inv(cov)
    # -d/2 log(2*pi) - 1/2 log|cov|; slogdet is more stable than log(det(cov)).
    log_norm = -0.5 * mean.shape[0] * jnp.log(2 * jnp.pi) - 0.5 * jnp.linalg.slogdet(cov)[1]

    def log_p(x):
        diff = x - mean
        return log_norm - 0.5 * (diff @ prec @ diff)

    return log_p


def make_kgm_kernel(linv, s):
    """Return the KGM kernel with ``linv`` (L) and exponent ``s`` captured.

    k(x, y) = (1 + x'Lx)^((s-1)/2) (1 + y'Ly)^((s-1)/2) (1 + (x-y)'L(x-y))^(-1/2)
              + (1 + x'Ly) / (sqrt(1 + x'Lx) sqrt(1 + y'Ly))
    """
    def kgm_kernel(x, y):
        qxx = 1 + x @ linv @ x
        qyy = 1 + y @ linv @ y
        diff = x - y
        return (qxx ** ((s - 1) / 2) * qyy ** ((s - 1) / 2) * (1 + diff @ linv @ diff) ** (-0.5)
                + (1 + x @ linv @ y) / (jnp.sqrt(qxx) * jnp.sqrt(qyy)))

    return kgm_kernel


log_p = gaussian_log_p_jnp(mean, cov)
kgm_kernel = make_kgm_kernel(linv, s)
test_auto_kgm = QTargetAuto(log_p=log_p, reproducing_kernel=kgm_kernel)

print("log-pdf   :", test_auto_kgm.log_p(x_test))
print("glog-pdf  :", test_auto_kgm.grad_log_p(x_test))
print("hlog-pdf  :", test_auto_kgm.hess_log_p(x_test))
print("log-q-pdf :", test_auto_kgm.log_q(x_test))
print("glog-q-pdf:", test_auto_kgm.grad_log_q(x_test))

# Cross-check against the manual target `test_manual_kgm`. The tolerance allows
# for JAX's default float32 arithmetic. The manual (NumPy) target is fed a
# float64 NumPy copy of the test point.
x_test_np = np.asarray(x_test, dtype=np.float64)
for method in ("log_p", "grad_log_p", "hess_log_p", "log_q", "grad_log_q"):
    np.testing.assert_allclose(
        getattr(test_auto_kgm, method)(x_test),
        getattr(test_manual_kgm, method)(x_test_np),
        rtol=1e-5, atol=1e-6, err_msg=method,
    )

log-pdf   : -8.337875
glog-pdf  : [ 7.999998 -4.999999]
hlog-pdf  : [[-4.9999986  2.999999 ]
 [ 2.999999  -1.9999995]]
log-q-pdf : -3.227614
glog-q-pdf: [ 6.367214  -3.9392242]
