In [4]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np

import jax.numpy as jnp
from jax import grad, jit, random, partial, lax, vmap

import context
from models.bayesian_NN.NN_model import logprior, loglikelihood, init_network, accuracy_BNN
from models.bayesian_NN.NN_data import X_train, y_train, X_test, y_test
from models.bayesian_NN.util import load_NN_MAP, add_noise_NN_params

from tuning.mamba import timed_sampler
from tuning.ksd import imq_KSD, k_0_fun

# Linear-time KSD

In [2]:
@jit
def linear_imq_KSD(samples, grads):
    """
    Linear-time KSD estimate with the IMQ kernel (Liu 2016).

    Consecutive samples are paired as (0, 1), (2, 3), ... and `k_0_fun` is
    averaged over those pairs, so the cost grows linearly with the number of
    samples. When the number of samples is odd, the last (unpaired) sample is
    ignored instead of producing mismatched batch sizes.

    Parameters
    ----------
    samples : array whose leading axis indexes the samples
    grads : array with the same leading axis as `samples`
        (presumably the log-density gradient at each sample; confirm against `k_0_fun`)

    Returns
    -------
    Scalar: the sum of `k_0_fun` over all pairs divided by the number of pairs.

    Raises
    ------
    ValueError
        If fewer than two samples are supplied (no pair can be formed).
    """
    c, beta = 1., -0.5  # kernel parameters forwarded to k_0_fun
    n_pairs = samples.shape[0] // 2
    if n_pairs == 0:
        # Shapes are static under jit, so this check runs at trace time.
        raise ValueError("linear_imq_KSD needs at least two samples to form a pair")
    n_used = 2 * n_pairs  # even prefix: drops the trailing sample when N is odd

    batch_k_0_fun = vmap(k_0_fun, in_axes=(0, 0, 0, 0, None, None))

    le_sum = jnp.sum(
        batch_k_0_fun(
            samples[0:n_used:2],
            samples[1:n_used:2],
            grads[0:n_used:2],
            grads[1:n_used:2],
            c,
            beta,
        )
    )
    return le_sum / n_pairs


In [2]:

# Toy problem: N standard-normal draws in d dimensions.
N, d = 50, 5
key = random.PRNGKey(0)
samples = random.normal(key, shape=(N, d))
# For a standard normal target, grad log p(x) = -x.
grads = -samples

# J random test locations, later passed as V to the FSSD functions.
J = 10
key = random.PRNGKey(1)
V = random.normal(key, shape=(J, d))

In [4]:
%time imq_KSD(samples, grads).block_until_ready()

CPU times: user 170 ms, sys: 2.55 ms, total: 173 ms
Wall time: 171 ms


DeviceArray(0.4545508, dtype=float32)

In [5]:
%time linear_imq_KSD(samples, grads).block_until_ready()

CPU times: user 84.8 ms, sys: 2.15 ms, total: 87 ms
Wall time: 85.5 ms


DeviceArray(0.00021834, dtype=float32)

# FSSD (Finite Set Stein Discrepancy)

In [29]:
def imq_kernel(x, y):
    """Inverse multiquadric kernel (1 + ||x - y||^2)^(-1/2) between two vectors."""
    diff = x - y
    return (1. + jnp.dot(diff, diff)) ** -0.5

# Gradient of the kernel with respect to its first argument.
k_dx = grad(imq_kernel, argnums=0)

def xi(x, grad_x, v):
    """Per-location feature: grad_x * k(x, v) + d/dx k(x, v)."""
    kernel_value = imq_kernel(x, v)
    kernel_grad = k_dx(x, v)
    return grad_x * kernel_value + kernel_grad

# xi for one (x, grad_x) pair evaluated against every row of a stack of locations.
batch_xi = vmap(xi, in_axes=(None, None, 0))


def tau(x, grad_x, V):
    """
    Stack xi(x, grad_x, v) over all rows v of V, flatten to length J*d,
    and scale by 1/sqrt(d*J).
    """
    (dim,) = x.shape
    num_locations, loc_dim = V.shape
    assert loc_dim == dim
    stacked = batch_xi(x, grad_x, V)
    return stacked.reshape(num_locations * dim) / jnp.sqrt(dim * num_locations)

@jit
def delta_fn(x, y, g_x, g_y, V):
    """Inner product of the tau feature vectors of two samples."""
    tau_x = tau(x, g_x, V)
    tau_y = tau(y, g_y, V)
    return jnp.dot(tau_x, tau_y)



@jit
def FSSD_O2(sgld_samples, sgld_grads, V):
    """
    Quadratic-cost FSSD with imq kernel.

    Accumulates `delta_fn` over every ordered pair of samples (each sample is
    also paired with itself) and returns sqrt(total) / N, where N is the
    number of samples.
    """
    num_samples = sgld_samples.shape[0]

    # delta_fn of one fixed (sample, grad) against all (sample, grad) rows.
    delta_against_all = jit(vmap(delta_fn, in_axes=(None, 0, None, 0, None)))

    def accumulate_row(running_total, sample_and_grad):
        x_i, g_i = sample_and_grad
        row = delta_against_all(x_i, sgld_samples, g_i, sgld_grads, V)
        running_total += jnp.sum(row)
        return running_total, None

    total, _ = lax.scan(accumulate_row, 0., (sgld_samples, sgld_grads))
    return jnp.sqrt(total) / num_samples




In [33]:
xi(samples[0], grads[0], V[1])

DeviceArray([-0.13347642, -0.22785588,  0.26775482,  0.41801387,
              0.62521464], dtype=float32)

In [37]:
FSSD_O2(samples, grads, V)

DeviceArray(0.03130005, dtype=float32)

In [5]:

# Larger toy problem: N standard-normal draws in d dimensions.
N, d = 100, 5
key = random.PRNGKey(0)
samples = random.normal(key, shape=(N, d))
# For a standard normal target, grad log p(x) = -x.
grads = -samples

# J test locations, each coordinate shifted by +10 relative to a standard normal draw.
J = 10
key = random.PRNGKey(1)
V = random.normal(key, shape=(J, d)) + 10

In [None]:
    N = sgld_samples.shape[0]
    d = sgld_samples[0].shape[0]
    J = V.shape[0]
    
    def imq_kernel(x, y):
        c, beta = 1., -0.5
        return (c + jnp.dot(x-y, x-y))**beta

    k_dx = grad(imq_kernel, 0)

    def xi(x, grad_x, v):
        return grad_x*imq_kernel(x,v) + k_dx(x,v)

    
    def compute_xi_sum(v):
        batch_xi_sam = vmap(xi, in_axes=(0,0, None))
        all_xi = batch_xi_sam(sgld_samples, sgld_grads, v)
        return jnp.square(all_xi.sum(axis=0)) - jnp.square(all_xi).sum(axis=0)

    batch_compute_xi_sum = vmap(compute_xi_sum, in_axes=(0,))

    lesum = jnp.sum(jnp.sum(batch_compute_xi_sum(V), axis=0))
    return lesum/(d*J*N*(N-1))

In [39]:
%time imq_KSD(samples, grads).block_until_ready()

CPU times: user 1.81 ms, sys: 888 µs, total: 2.7 ms
Wall time: 6.38 ms


Buffer(0.29847407, dtype=float32)

In [41]:
%time FSSD_O2(samples, grads, V).block_until_ready()

CPU times: user 686 µs, sys: 151 µs, total: 837 µs
Wall time: 4.74 ms


Buffer(0.03130005, dtype=float32)

### Linear-time version

In [25]:

@jit
def FSSD(sgld_samples, sgld_grads, V):
    """
    FSSD with imq kernel, computed in time linear in the number of samples.

    For every test location v (row of V) the features
    xi_i = grad_i * k(x_i, v) + d/dx k(x_i, v) are formed for all samples, and
    the cross terms sum_{i != j} xi_i * xi_j are obtained per coordinate as
    (sum_i xi_i)^2 - sum_i xi_i^2. The result is the total over coordinates
    and locations divided by d * J * N * (N - 1). The value can be negative.
    """
    num_samples = sgld_samples.shape[0]
    dim = sgld_samples.shape[1]
    num_locations = V.shape[0]

    def imq_kernel(x, y):
        diff = x - y
        return (1. + jnp.dot(diff, diff)) ** -0.5

    k_dx = grad(imq_kernel, argnums=0)

    def xi(x, grad_x, v):
        return grad_x * imq_kernel(x, v) + k_dx(x, v)

    # Inner vmap runs over samples, outer vmap over test locations: (J, N, d).
    xi_over_samples = vmap(xi, in_axes=(0, 0, None))
    xi_all = vmap(xi_over_samples, in_axes=(None, None, 0))(sgld_samples, sgld_grads, V)

    # Square-of-sum minus sum-of-squares over the sample axis: (J, d).
    cross_terms = jnp.square(xi_all.sum(axis=1)) - jnp.square(xi_all).sum(axis=1)

    total = jnp.sum(jnp.sum(cross_terms, axis=0))
    return total / (dim * num_locations * num_samples * (num_samples - 1))


from jax.experimental.optimizers import adam
from functools import partial

@partial(jit, static_argnums=(3,))
def FSSD_opt(sgld_samples, sgld_grads, V, Niter):
    """
    FSSD evaluated at test locations tuned by `Niter` Adam steps.

    Starting from `V`, the test locations are updated with Adam (step size
    1e-3) to maximise `FSSD(sgld_samples, sgld_grads, .)`, and the FSSD at the
    final locations is returned.

    Parameters
    ----------
    sgld_samples : samples passed to `FSSD`; also the data the locations are tuned on
    sgld_grads : gradients passed to `FSSD` alongside `sgld_samples`
    V : initial test locations (one location per row, as consumed by `FSSD`)
    Niter : int, number of Adam steps; static under jit, so each distinct
        value triggers a recompilation

    Returns
    -------
    Scalar FSSD value at the optimised test locations.
    """

    def opt_fssd_fn(v):
        # Adam minimises, so negate FSSD to maximise it. The objective must use
        # this function's own arguments, not notebook-level globals.
        return -FSSD(sgld_samples, sgld_grads, v)

    # No inner jit: the whole function is already compiled by the decorator.
    grad_opt_fn = grad(opt_fssd_fn)
    init_fn, update, get_params = adam(1e-3)
    state = init_fn(V)

    def body(state, i):
        return update(i, grad_opt_fn(get_params(state)), state), None

    state, _ = lax.scan(body, state, jnp.arange(Niter))
    V_opt = get_params(state)
    return FSSD(sgld_samples, sgld_grads, V_opt)


    

In [2]:
# Re-draw J test locations with a different seed; `d` must already be defined
# by one of the setup cells above.
J = 10
key = random.PRNGKey(6)
V = random.normal(key, (J, d))

# FSSD(samples, grads, V)

NameError: name 'd' is not defined

In [3]:
# FSSD_opt(samples, grads, V, 400)

In [194]:
grad_FSSD(samples, grads, V)

DeviceArray([[-1.4949895e-05,  5.7517100e-06, -2.0329546e-05,
              -1.8899555e-05,  2.1406895e-06],
             [ 8.3543937e-06,  9.9315366e-07, -4.1955186e-06,
              -6.2971958e-07,  8.3952291e-06],
             [ 6.8008540e-06,  2.4068449e-05,  5.8266869e-07,
              -9.1672609e-07, -2.6546986e-06],
             [-1.8636318e-05,  9.6201402e-06, -2.4999619e-05,
               9.0209614e-06, -5.2451242e-06],
             [-8.1619528e-06,  1.1443335e-05,  5.7691977e-06,
              -3.8049802e-06, -2.2124510e-05],
             [-2.2405948e-06,  3.6173231e-05, -1.1673906e-05,
              -4.1901985e-06, -4.0308441e-06],
             [ 1.7546281e-06,  4.4181550e-05, -3.9278366e-05,
               4.0121464e-05, -9.9971294e-06],
             [ 3.5164259e-05,  8.9908526e-06, -1.5863337e-05,
               8.8084562e-06, -2.8339928e-06],
             [ 3.3414828e-05, -2.3344233e-05, -1.2204182e-05,
               9.7794109e-06,  7.4116906e-07],
             [ 9.73

In [195]:
# FSSD_O2(samples, grads, V)

FSSD(samples, grads, V)


Buffer(-0.00034945, dtype=float32)

In [196]:
V.shape

(10, 5)

In [206]:
update

<function jax.experimental.optimizers.adam.<locals>.update(i, g, state)>

In [228]:
from jax.scipy.optimize import minimize
from jax.experimental.optimizers import adam

def opt_fssd_fn(v):
    """Negated FSSD of the notebook-level `samples` / `grads` at test locations `v`."""
    fssd_value = FSSD(samples, grads, v)
    return -fssd_value

# Compiled gradient of the objective with respect to the test locations.
grad_opt_fn = jit(grad(opt_fssd_fn))

# Adam optimiser triple with step size 1e-3.
init_fn, update, get_params = adam(1e-3)



In [261]:
# Start the optimiser from the current test locations.
state = init_fn(V)

def body(state, i):
    """One Adam step on the test locations; shaped as a `lax.scan` body."""
    current_locations = get_params(state)
    next_state = update(i, grad_opt_fn(current_locations), state)
    return next_state, None

# 400 Adam steps, fused into a single scan.
state, _ = lax.scan(body, state, jnp.arange(400))

# for i in range(400):
#     state = update(i, grad_opt_fn(get_params(state)), state)

In [262]:
print(opt_fssd_fn(V), opt_fssd_fn(get_params(state)))

0.00034944955 6.0526898e-05


In [27]:
# Overlay histograms of the flattened test locations after and before optimisation.
for values, tag in ((get_params(state), "final"), (V, "initial")):
    plt.hist(values.flatten(), alpha=0.7, label=tag)

plt.legend()

NameError: name 'get_params' is not defined