In [1]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np

import jax.numpy as jnp
from jax import grad, jit, random, partial, lax, vmap

import context
from models.bayesian_NN.NN_model import logprior, loglikelihood, init_network, accuracy_BNN
from models.bayesian_NN.NN_data import X_train, y_train, X_test, y_test
from models.bayesian_NN.util import load_NN_MAP, add_noise_NN_params

from tuning.mamba import timed_sampler
from tuning.ksd import imq_KSD, k_0_fun



# linear KSD

In [2]:
@jit
def linear_imq_KSD(samples, grads):
    """
    linear KSD with imq kernel (Liu 2016)
    """
    c, beta = 1., -0.5
    N_2 = int(samples.shape[0]/2)
    batch_k_0_fun = vmap(k_0_fun, in_axes=(0,0,0,0,None, None))
    
    le_sum = jnp.sum(batch_k_0_fun(samples[::2], samples[1::2], grads[::2], grads[1::2], c, beta))
    return le_sum/N_2


In [2]:

key = random.PRNGKey(0)
N = 50 # num samples
d = 5 # dimension

samples = random.normal(key, shape=(N, d))
grads = -samples

J = 10
key = random.PRNGKey(1)
V = random.normal(key, shape=(J, d))

In [4]:
%time imq_KSD(samples, grads).block_until_ready()

CPU times: user 170 ms, sys: 2.55 ms, total: 173 ms
Wall time: 171 ms


DeviceArray(0.4545508, dtype=float32)

In [5]:
%time linear_imq_KSD(samples, grads).block_until_ready()

CPU times: user 84.8 ms, sys: 2.15 ms, total: 87 ms
Wall time: 85.5 ms


DeviceArray(0.00021834, dtype=float32)

# FSSD

In [5]:
def imq_kernel(x, y):
    c, beta = 1., -0.5
    return (c + jnp.dot(x-y, x-y))**beta

k_dx = grad(imq_kernel, 0)

def xi(x, grad_x, v):
    return grad_x*imq_kernel(x,v) + k_dx(x,v)

batch_xi = vmap(xi, in_axes=(None, None, 0))


def tau(x, grad_x, V):
    d, = x.shape
    J, dv = V.shape
    assert dv == d
    return batch_xi(x, grad_x, V).reshape(J*d)/jnp.sqrt(d*J)
    
@jit
def delta_fn(x, y, g_x, g_y, V):
    return jnp.dot(tau(x, g_x, V), tau(y, g_y, V))



@jit
def FSSD(sgld_samples, sgld_grads, V):
    """
    FSSD with imq kernel
    """
    N = sgld_samples.shape[0]
    
    batch_delta_fun_rows = jit(vmap(delta_fn, in_axes=(None,0,None,0, None)))

    def body_ksd(le_sum, x):
        my_sample, my_grad = x
        le_sum += jnp.sum(batch_delta_fun_rows(my_sample, sgld_samples, my_grad, sgld_grads, V))
        return le_sum, None

    le_sum, _ = lax.scan(body_ksd, 0., (sgld_samples, sgld_grads))
    return jnp.sqrt(le_sum)/N




In [6]:
FSSD(samples, grads, V)

DeviceArray(0.05833324, dtype=float32)

In [7]:

key = random.PRNGKey(0)
N = 10_000 # num samples
d = 5 # dimension

samples = random.normal(key, shape=(N, d))
grads = -samples

J = 10
key = random.PRNGKey(1)
V = random.normal(key, shape=(J, d))

In [8]:
%time imq_KSD(samples, grads).block_until_ready()

CPU times: user 3.75 s, sys: 13.9 ms, total: 3.76 s
Wall time: 3.77 s


DeviceArray(0.03080567, dtype=float32)

In [9]:
%time FSSD(samples, grads, V).block_until_ready()

CPU times: user 817 ms, sys: 7.6 ms, total: 824 ms
Wall time: 826 ms


DeviceArray(0.00288177, dtype=float32)