In [1]:
import jax.numpy as jnp
import jax
import math
from tensorflow_probability.substrates import jax as tfp


def jax_dist(x, y):
    return jnp.sqrt(((x - y) ** 2).sum(-1)).squeeze()

distance = jax.vmap(jax_dist, in_axes=(None, 0), out_axes=1)
sign_func = jax.vmap(jnp.greater, in_axes=(None, 0), out_axes=1)


# @jax.jit
def my_Matern(x, y, l):
    r = distance(x, y).squeeze()
    part1 = 1 + math.sqrt(3) * r / l
    part2 = jnp.exp(-math.sqrt(3) * r / l)
    return part1 * part2


# @jax.jit
def one_d_my_Matern(x, y, l):
    r = jax_dist(x, y).squeeze()
    part1 = 1 + math.sqrt(3) * r / l
    part2 = jnp.exp(-math.sqrt(3) * r / l)
    return part1 * part2


# @jax.jit
def dx_Matern(x, y, l):
    sign = sign_func(x, y).squeeze().astype(float) * 2 - 1
    r = distance(x, y).squeeze()
    part1 = jnp.exp(-math.sqrt(3) / l * r) * (math.sqrt(3) / l * sign)
    part2 = (-math.sqrt(3) / l * sign) * jnp.exp(-math.sqrt(3) / l * r) * (1 + math.sqrt(3) / l * r)
    return part1 + part2


# @jax.jit
def dy_Matern(x, y, l):
    sign = -(sign_func(x, y).squeeze().astype(float) * 2 - 1)
    r = distance(x, y).squeeze()
    part1 = jnp.exp(-math.sqrt(3) / l * r) * (math.sqrt(3) / l * sign)
    part2 = (-math.sqrt(3) / l * sign) * jnp.exp(-math.sqrt(3) / l * r) * (1 + math.sqrt(3) / l * r)
    return part1 + part2


# @jax.jit
def dxdy_Matern(x, y, l):
    r = distance(x, y).squeeze()
    const = math.sqrt(3) / l
    part1 = const * const * jnp.exp(-const * r)
    part2 = -const * const * jnp.exp(-const * r) * (1 + const * r)
    part3 = const * jnp.exp(-const * r) * const
    return part1 + part2 + part3


# @jax.jit
def my_RBF(x, y, l):
    r = distance(x, y).squeeze()
    return jnp.exp(- r ** 2 / 2 / (l ** 2))


def my_Laplace(x, y, l):
    r = distance(x, y).squeeze()
    return jnp.exp(- r / l)


def dx_Laplace(x, y, l):
    sign = sign_func(x, y).squeeze().astype(float) * 2 - 1
    r = distance(x, y).squeeze()
    part1 = jnp.exp(- r / l) * (-sign)
    return part1


def dy_Laplace(x, y, l):
    sign = sign_func(x, y).squeeze().astype(float) * 2 - 1
    r = distance(x, y).squeeze()
    part1 = jnp.exp(- r / l) * sign
    return part1


def dxdy_Laplace(x, y, l):
    r = distance(x, y).squeeze()
    part1 = jnp.exp(- r / l) * (-1)
    return part1


def one_d_my_Laplace(x, y, l):
    r = jax_dist(x, y).squeeze()
    return jnp.exp(- r / l)


# @jax.jit
def one_d_my_RBF(x, y, l):
    r = jax_dist(x, y).squeeze()
    return jnp.exp(- r ** 2 / 2 / (l ** 2))



In [2]:
seed = 0
rng_key = jax.random.PRNGKey(seed)
x = jax.random.uniform(rng_key, shape=(3, 2))
rng_key, _ = jax.random.split(rng_key)
y = jax.random.uniform(rng_key, shape=(3, 2))
l = 0.5
batch_kernel = tfp.math.psd_kernels.MaternThreeHalves(amplitude=1., length_scale=.5)
K1 = batch_kernel.matrix(x, y)


In [3]:
seed = 0
rng_key = jax.random.PRNGKey(seed)
x = jax.random.uniform(rng_key, shape=(2, ))
rng_key, _ = jax.random.split(rng_key)
y = jax.random.uniform(rng_key, shape=(2, ))
l = 0.5
batch_kernel = tfp.math.psd_kernels.MaternThreeHalves(amplitude=1., length_scale=.5)
print(batch_kernel.apply(x, y))

print(batch_kernel.matrix(x[None, :], y[None, :]))

0.36376435
[[0.36376435]]


In [4]:
grad_x_K_fn = jax.grad(batch_kernel.apply, argnums=(0, ))


In [5]:
vec_grad_x_K_fn = jax.vmap(grad_x_K_fn, in_axes=(0, 0), out_axes=1)

rng_key = jax.random.PRNGKey(seed)
x = jax.random.uniform(rng_key, shape=(2, 1))
rng_key, _ = jax.random.split(rng_key)
y = jax.random.uniform(rng_key, shape=(2, 1))

x_dummy = jnp.stack((x, x), axis=0).reshape(4, 1)
y_dummy = jnp.stack((y, y), axis=1).reshape(4, 1)

vec_grad_x_K_fn(x_dummy, y_dummy)[0].reshape(2, 2, 1)

Array([[[ 1.2694279 ],
        [-1.2656392 ]],

       [[ 0.24535462],
        [-0.9564513 ]]], dtype=float32)

In [6]:
print(grad_x_K_fn(x[0, :], y[0, :]))
print(grad_x_K_fn(x[0, :], y[1, :]))
print(grad_x_K_fn(x[1, :], y[0, :]))
print(grad_x_K_fn(x[1, :], y[1, :]))

(Array([1.2694279], dtype=float32),)
(Array([0.24535462], dtype=float32),)
(Array([-1.2656392], dtype=float32),)
(Array([-0.9564513], dtype=float32),)


In [7]:
dx_Matern(x, y, l)

Array([[ 1.2694278 ,  0.24535465],
       [-1.265639  , -0.9564513 ]], dtype=float32)

In [20]:
seed = 0
rng_key = jax.random.PRNGKey(seed)
N = 2
D = 3
l = 0.5

batch_kernel = tfp.math.psd_kernels.MaternThreeHalves(amplitude=1., length_scale=0.5)
grad_x_K_fn = jax.grad(batch_kernel.apply, argnums=(0))
vec_grad_x_K_fn = jax.vmap(grad_x_K_fn, in_axes=(0, 0), out_axes=0)
grad_y_K_fn = jax.grad(batch_kernel.apply, argnums=(1))
vec_grad_y_K_fn = jax.vmap(grad_y_K_fn, in_axes=(0, 0), out_axes=0)


rng_key = jax.random.PRNGKey(seed)
x = jax.random.uniform(rng_key, shape=(N, D))
rng_key, _ = jax.random.split(rng_key)
y = jax.random.uniform(rng_key, shape=(N, D))

x_dummy = jnp.stack([x] * N, axis=1).reshape(N * N, D)
y_dummy = jnp.stack([y] * N, axis=0).reshape(N * N, D)

dx_K = vec_grad_x_K_fn(x_dummy, y_dummy).reshape(N, N, D)
dy_K = vec_grad_y_K_fn(x_dummy, y_dummy).reshape(N, N, D)



In [21]:
dx_K

Array([[[ 0.3103785 ,  1.0000873 ,  0.00740481],
        [-0.04561898,  0.58510643,  0.10524212]],

       [[-0.628633  , -0.03943844,  1.0889013 ],
        [-0.49021354,  0.37495774,  0.5537195 ]]], dtype=float32)

In [22]:
seed = 0
rng_key = jax.random.PRNGKey(seed)
N = 2
D = 3
l = 0.5

batch_kernel = tfp.math.psd_kernels.MaternThreeHalves(amplitude=1., length_scale=1.0)
grad_x_K_fn = jax.grad(batch_kernel.apply, argnums=(0))
vec_grad_x_K_fn = jax.vmap(grad_x_K_fn, in_axes=(0, 0), out_axes=0)
grad_y_K_fn = jax.grad(batch_kernel.apply, argnums=(1))
vec_grad_y_K_fn = jax.vmap(grad_y_K_fn, in_axes=(0, 0), out_axes=0)


x = x / 0.5
y = y / 0.5

x_dummy = jnp.stack([x] * N, axis=1).reshape(N * N, D)
y_dummy = jnp.stack([y] * N, axis=0).reshape(N * N, D)

dx_K = vec_grad_x_K_fn(x_dummy, y_dummy).reshape(N, N, D)
dy_K = vec_grad_y_K_fn(x_dummy, y_dummy).reshape(N, N, D)


In [23]:
dx_K

Array([[[ 0.15518925,  0.50004363,  0.00370241],
        [-0.02280949,  0.29255322,  0.05262106]],

       [[-0.3143165 , -0.01971922,  0.54445064],
        [-0.24510677,  0.18747887,  0.27685976]]], dtype=float32)

In [9]:
grad_xy_K_fn = jax.jacfwd(jax.jacrev(batch_kernel.apply, argnums=1), argnums=0)

def diag_sum_grad_xy_K_fn(x, y):
    return jnp.diag(grad_xy_K_fn(x, y)).sum()

vec_grad_xy_K_fn = jax.vmap(diag_sum_grad_xy_K_fn, in_axes=(0, 0), out_axes=0)

vec_grad_xy_K_fn(x_dummy, y_dummy).reshape(N, N)


Array([[2.5421534 , 0.16093946],
       [6.81763   , 0.98615026]], dtype=float32)

In [10]:
print(jnp.diag(grad_xy_K_fn(x[0, :], y[0, :])).sum())
print(jnp.diag(grad_xy_K_fn(x[0, :], y[1, :])).sum())
print(jnp.diag(grad_xy_K_fn(x[1, :], y[0, :])).sum())
print(jnp.diag(grad_xy_K_fn(x[1, :], y[1, :])).sum())


2.5421534
0.16093946
6.81763
0.98615026


In [120]:
dx_K

Array([[[ 0.3103785 ,  1.0000873 ,  0.00740481],
        [-0.04561898,  0.58510643,  0.10524212]],

       [[-0.628633  , -0.03943844,  1.0889013 ],
        [-0.49021354,  0.37495774,  0.5537195 ]]], dtype=float32)

In [121]:
dy_K

Array([[[-0.3103785 , -1.0000873 , -0.00740481],
        [ 0.04561898, -0.58510643, -0.10524212]],

       [[ 0.628633  ,  0.03943844, -1.0889013 ],
        [ 0.49021354, -0.37495774, -0.5537195 ]]], dtype=float32)

In [123]:
print(grad_x_K_fn(x[0, :], y[0, :]))
print(grad_x_K_fn(x[0, :], y[1, :]))
print(grad_x_K_fn(x[1, :], y[0, :]))
print(grad_x_K_fn(x[1, :], y[1, :]))

print(grad_y_K_fn(x[0, :], y[0, :]))
print(grad_y_K_fn(x[0, :], y[1, :]))
print(grad_y_K_fn(x[1, :], y[0, :]))
print(grad_y_K_fn(x[1, :], y[1, :]))


[0.3103785  1.0000873  0.00740481]
[-0.04561898  0.58510643  0.10524212]
[-0.628633   -0.03943844  1.0889013 ]
[-0.49021354  0.37495774  0.5537195 ]
[-0.3103785  -1.0000873  -0.00740481]
[ 0.04561898 -0.58510643 -0.10524212]
[ 0.628633    0.03943844 -1.0889013 ]
[ 0.49021354 -0.37495774 -0.5537195 ]


In [137]:
def stein_Matern(x, y, l, d_log_px, d_log_py):
    """
    :param x: N*D
    :param y: M*D
    :param l: scalar
    :param d_log_px: N*D
    :param d_log_py: M*D
    :return: N*M
    """
    N, D = x.shape
    M = y.shape[0]

    batch_kernel = tfp.math.psd_kernels.MaternThreeHalves(amplitude=1., length_scale=l)
    grad_x_K_fn = jax.grad(batch_kernel.apply, argnums=(0))
    vec_grad_x_K_fn = jax.vmap(grad_x_K_fn, in_axes=(0, 0), out_axes=0)
    grad_y_K_fn = jax.grad(batch_kernel.apply, argnums=(1))
    vec_grad_y_K_fn = jax.vmap(grad_y_K_fn, in_axes=(0, 0), out_axes=0)
    
    grad_xy_K_fn = jax.jacfwd(jax.jacrev(batch_kernel.apply, argnums=1), argnums=0)

    def diag_sum_grad_xy_K_fn(x, y):
        return jnp.diag(grad_xy_K_fn(x, y)).sum()

    vec_grad_xy_K_fn = jax.vmap(diag_sum_grad_xy_K_fn, in_axes=(0, 0), out_axes=0)

    x_dummy = jnp.stack([x] * N, axis=1).reshape(N * M, D)
    y_dummy = jnp.stack([y] * M, axis=0).reshape(N * M, D)

    K = batch_kernel.matrix(x, y)
    dx_K = vec_grad_x_K_fn(x_dummy, y_dummy).reshape(N, M, D)
    dy_K = vec_grad_y_K_fn(x_dummy, y_dummy).reshape(N, M, D)
    dxdy_K = vec_grad_xy_K_fn(x_dummy, y_dummy).reshape(N, N)

    part1 = d_log_px @ d_log_py.T * K
    part2 = (d_log_py[None, :] * dx_K).sum(-1)
    part3 = (d_log_px[:, None, :] * dy_K).sum(-1)
    part4 = dxdy_K
    
    print(dx_K, 'dxK')
    print(dy_K, 'dyK')

    print(part1, 'part1')
    print(part2, 'part2')
    print(part3, 'part3')
    print(part4, 'part4')
    return part1 + part2 + part3 + part4

def dx_Matern(x, y, l):
    sign = sign_func(x, y).squeeze().astype(float) * 2 - 1
    r = distance(x, y).squeeze()
    part1 = jnp.exp(-math.sqrt(3) / l * r) * (math.sqrt(3) / l * sign)
    part2 = (-math.sqrt(3) / l * sign) * jnp.exp(-math.sqrt(3) / l * r) * (1 + math.sqrt(3) / l * r)
    return part1 + part2


# @jax.jit
def dy_Matern(x, y, l):
    sign = -(sign_func(x, y).squeeze().astype(float) * 2 - 1)
    r = distance(x, y).squeeze()
    part1 = jnp.exp(-math.sqrt(3) / l * r) * (math.sqrt(3) / l * sign)
    part2 = (-math.sqrt(3) / l * sign) * jnp.exp(-math.sqrt(3) / l * r) * (1 + math.sqrt(3) / l * r)
    return part1 + part2


# @jax.jit
def dxdy_Matern(x, y, l):
    r = distance(x, y).squeeze()
    const = math.sqrt(3) / l
    part1 = const * const * jnp.exp(-const * r)
    part2 = -const * const * jnp.exp(-const * r) * (1 + const * r)
    part3 = const * jnp.exp(-const * r) * const
    return part1 + part2 + part3

def my_Matern(x, y, l):
    r = distance(x, y).squeeze()
    part1 = 1 + math.sqrt(3) * r / l
    part2 = jnp.exp(-math.sqrt(3) * r / l)
    return part1 * part2

def stein_Matern_old(x, y, l, d_log_px, d_log_py):
    K = my_Matern(x, y, l)
    dx_K = dx_Matern(x, y, l)
    dy_K = dy_Matern(x, y, l)
    dxdy_K = dxdy_Matern(x, y, l)
    part1 = d_log_px @ d_log_py.T * K
    part2 = d_log_py.T * dx_K
    part3 = d_log_px * dy_K
    part4 = dxdy_K
    
    print(dx_K, 'dxK')
    print(dy_K, 'dyK')
    
    print(part1, 'part1')
    print(part2, 'part2')
    print(part3, 'part3')
    print(part4, 'part4')
    return part1 + part2 + part3 + part4

In [138]:
seed = 0
rng_key = jax.random.PRNGKey(seed)
x = jax.random.uniform(rng_key, shape=(2, 1))
rng_key, _ = jax.random.split(rng_key)
y = jax.random.uniform(rng_key, shape=(2, 1))

d_log_px = -x
d_log_py = -y

K1 = stein_Matern_old(x, y, l, d_log_px, d_log_py)
K2 = stein_Matern(x, y, l, d_log_px, d_log_py)

[[ 1.2694278   0.24535465]
 [-1.265639   -0.9564513 ]] dxK
[[-1.2694278  -0.24535465]
 [ 1.265639    0.9564513 ]] dyK
[[0.07969617 0.05141427]
 [0.26686415 0.07992584]] part1
[[-0.6096557  -0.05848424]
 [ 0.60783607  0.22798559]] part2
[[ 0.27457133  0.05306907]
 [-1.0177308  -0.76910555]] part3
[[ 0.4116516  10.266797  ]
 [-0.47636604 -1.6226783 ]] part4
[[[ 1.2694279 ]
  [ 0.24535462]]

 [[-1.2656392 ]
  [-0.9564513 ]]] dxK
[[[-1.2694279 ]
  [-0.24535462]]

 [[ 1.2656392 ]
  [ 0.9564513 ]]] dyK
[[0.07969617 0.05141427]
 [0.2668642  0.07992584]] part1
[[-0.60965574 -0.05848423]
 [ 0.6078362   0.22798559]] part2
[[ 0.27457136  0.05306907]
 [-1.017731   -0.76910555]] part3
[[ 0.41165173 10.266797  ]
 [-0.4763666  -1.6226785 ]] part4


In [139]:
K1

Array([[ 0.15626344, 10.312797  ],
       [-0.6193967 , -2.0838723 ]], dtype=float32)

In [140]:
K2

Array([[ 0.15626353, 10.312797  ],
       [-0.61939716, -2.0838726 ]], dtype=float32)