In [1]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np

import jax.numpy as jnp
from jax import grad, jit, random, partial, lax, vmap

import context
from models.bayesian_NN.NN_model import logprior, loglikelihood, init_network, accuracy_BNN
from models.bayesian_NN.NN_data import X_train, y_train, X_test, y_test
from models.bayesian_NN.util import load_NN_MAP, add_noise_NN_params

from tuning.mamba import timed_sampler
from tuning.ksd import imq_KSD, k_0_fun



# linear KSD

In [2]:
@jit
def linear_imq_KSD(samples, grads):
    """
    linear KSD with imq kernel (Liu 2016)
    """
    c, beta = 1., -0.5
    N_2 = int(samples.shape[0]/2)
    batch_k_0_fun = vmap(k_0_fun, in_axes=(0,0,0,0,None, None))
    
    le_sum = jnp.sum(batch_k_0_fun(samples[::2], samples[1::2], grads[::2], grads[1::2], c, beta))
    return le_sum/N_2


In [38]:

key = random.PRNGKey(0)
N = 500 # num samples
d = 5 # dimension

samples = random.normal(key, shape=(N, d))
grads = -samples

J = 10
key = random.PRNGKey(1)
V = random.normal(key, shape=(J, d))

In [42]:
%time imq_KSD(samples, grads).block_until_ready()

CPU times: user 10.5 ms, sys: 1.01 ms, total: 11.6 ms
Wall time: 10.1 ms


Buffer(0.13167463, dtype=float32)

In [43]:
%time linear_imq_KSD(samples, grads).block_until_ready()

CPU times: user 286 µs, sys: 74 µs, total: 360 µs
Wall time: 192 µs


Buffer(-0.00093835, dtype=float32)

# FSSD

In [8]:
def imq_kernel(x, y):
    c, beta = 1., -0.5
    return (c + jnp.dot(x-y, x-y))**beta

k_dx = grad(imq_kernel, 0)

def xi(x, grad_x, v):
    return grad_x*imq_kernel(x,v) + k_dx(x,v)

batch_xi = vmap(xi, in_axes=(None, None, 0))


def tau(x, grad_x, V):
    d, = x.shape
    J, dv = V.shape
    assert dv == d
    return batch_xi(x, grad_x, V).reshape(J*d)/jnp.sqrt(d*J)
    
@jit
def delta_fn(x, y, g_x, g_y, V):
    return jnp.dot(tau(x, g_x, V), tau(y, g_y, V))



@jit
def FSSD(sgld_samples, sgld_grads, V):
    """
    FSSD with imq kernel
    """
    N = sgld_samples.shape[0]
    
    batch_delta_fun_rows = jit(vmap(delta_fn, in_axes=(None,0,None,0, None)))

    def body_ksd(le_sum, x):
        my_sample, my_grad = x
        le_sum += jnp.sum(batch_delta_fun_rows(my_sample, sgld_samples, my_grad, sgld_grads, V))
        return le_sum, None

    le_sum, _ = lax.scan(body_ksd, 0., (sgld_samples, sgld_grads))
    return jnp.sqrt(le_sum)/N

In [9]:
delta_fn(samples[5], -samples[5], samples[2], -samples[2], V)

DeviceArray(-0.15478922, dtype=float32)

In [10]:
FSSD(samples, grads, V)

DeviceArray(0.03130005, dtype=float32)

In [61]:

key = random.PRNGKey(0)
N = 10_000 # num samples
d = 5 # dimension

samples = random.normal(key, shape=(N, d))
grads = -samples

J = 10
key = random.PRNGKey(1)
V = random.normal(key, shape=(J, d))

In [64]:
%time imq_KSD(samples, grads).block_until_ready()

CPU times: user 3.42 s, sys: 3.01 ms, total: 3.42 s
Wall time: 3.42 s


Buffer(0.03080567, dtype=float32)

In [65]:
%time FSSD(samples, grads, V).block_until_ready()

CPU times: user 546 ms, sys: 1.94 ms, total: 548 ms
Wall time: 545 ms


Buffer(0.00288177, dtype=float32)