In [12]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np

import jax.numpy as jnp
from jax import grad, jit, random, partial, lax, vmap

import context
from models.bayesian_NN.NN_model import logprior, loglikelihood, init_network, accuracy_BNN
from models.bayesian_NN.NN_data import X_train, y_train, X_test, y_test
from models.bayesian_NN.util import load_NN_MAP, add_noise_NN_params

from tuning.mamba import timed_sampler
from tuning.ksd import imq_KSD, k_0_fun, FSSD, FSSD_O2, FSSD_opt

# linear KSD

In [2]:
@jit
def linear_imq_KSD(samples, grads):
    """
    linear KSD with imq kernel (Liu 2016)
    """
    c, beta = 1., -0.5
    N_2 = int(samples.shape[0]/2)
    batch_k_0_fun = vmap(k_0_fun, in_axes=(0,0,0,0,None, None))
    
    le_sum = jnp.sum(batch_k_0_fun(samples[::2], samples[1::2], grads[::2], grads[1::2], c, beta))
    return le_sum/N_2


In [3]:

def get_test_locations(samples, J=10, key=random.PRNGKey(0)):
    _, dim = jnp.shape(samples)
    gauss_mean = jnp.mean(samples, axis=0)
    gauss_cov = jnp.cov(samples.T) + 1e-10*jnp.eye(dim)
    gauss_chol = jnp.linalg.cholesky(gauss_cov)
    batch_get_samples = vmap(lambda k: jnp.dot(gauss_chol, random.normal(key, shape=(dim,))) + gauss_mean)
    V = batch_get_samples(random.split(key, J))
    return V


In [4]:

key = random.PRNGKey(60)
N = 500 # num samples
d = 5 # dimension

samples = random.normal(key, shape=(N, d))
grads = -samples

J = 10
key = random.PRNGKey(1)
# V = random.normal(key, shape=(J, d))

In [5]:
%time imq_KSD(samples, grads).block_until_ready()

CPU times: user 191 ms, sys: 2.83 ms, total: 194 ms
Wall time: 192 ms


DeviceArray(0.14592385, dtype=float32)

In [6]:
%time linear_imq_KSD(samples, grads).block_until_ready()

CPU times: user 103 ms, sys: 2.3 ms, total: 105 ms
Wall time: 103 ms


DeviceArray(0.04582699, dtype=float32)

In [8]:
%time FSSD(samples, grads, get_test_locations(samples)).block_until_ready()

CPU times: user 11.7 ms, sys: 2.98 ms, total: 14.7 ms
Wall time: 10.7 ms


DeviceArray(0.00029859, dtype=float32)

In [10]:
%time FSSD_O2(samples, grads, get_test_locations(samples)).block_until_ready()

CPU times: user 27.1 ms, sys: 4.97 ms, total: 32.1 ms
Wall time: 19.7 ms


DeviceArray(0.00029859, dtype=float32)

In [14]:
%time FSSD_opt(samples, grads, get_test_locations(samples), 100).block_until_ready()

CPU times: user 39 ms, sys: 3.5 ms, total: 42.5 ms
Wall time: 37.6 ms


DeviceArray(0.00034239, dtype=float32)

In [17]:
from jax.experimental.optimizers import adam

In [42]:
# @partial(jit, static_argnums=(3,))
# def FSSD_opt(sgld_samples, sgld_grads, V, Niter):

V = get_test_locations(samples)
Niter = 100

def opt_fssd_fn(v):
    J = 10
    _, dim = jnp.shape(samples)
    return -FSSD(samples, grads, v.reshape(J,dim))

# grad_opt_fn = jit(grad(opt_fssd_fn))
# init_fn, update, get_params = adam(1e-3)
# state = init_fn(V)

# def body(state, i):
#     return update(i, grad_opt_fn(get_params(state)), state), None

# state, _ = lax.scan(body, state, jnp.arange(Niter))
# V_opt = get_params(state)



In [48]:
from jax.scipy.optimize import minimize

def opt_v(V):
    J, dim = jnp.shape(V)
    
    def opt_fssd_fn(v):
        return -FSSD(samples, grads, v.reshape(J, dim))

    result = minimize(opt_fssd_fn, V.flatten(), method="BFGS")
    return result.x

In [49]:
opt_v(V)

DeviceArray([ 0.19024403, -1.1654866 , -0.35848635,  1.1371242 ,
              0.22448991,  0.19024403, -1.1654866 , -0.35848635,
              1.1371242 ,  0.22448991,  0.19024403, -1.1654866 ,
             -0.35848635,  1.1371242 ,  0.22448991,  0.19024403,
             -1.1654866 , -0.35848635,  1.1371242 ,  0.22448991,
              0.19024403, -1.1654866 , -0.35848635,  1.1371242 ,
              0.22448991,  0.19024403, -1.1654866 , -0.35848635,
              1.1371242 ,  0.22448991,  0.19024403, -1.1654866 ,
             -0.35848635,  1.1371242 ,  0.22448991,  0.19024403,
             -1.1654866 , -0.35848635,  1.1371242 ,  0.22448991,
              0.19024403, -1.1654866 , -0.35848635,  1.1371242 ,
              0.22448991,  0.19024403, -1.1654866 , -0.35848635,
              1.1371242 ,  0.22448991], dtype=float32)

In [44]:
V.shape

(10, 5)

In [46]:
minimize(opt_fssd_fn, V.flatten(), method="BFGS")

OptimizeResults(x=DeviceArray([ 0.19024403, -1.1654866 , -0.35848635,  1.1371242 ,
              0.22448991,  0.19024403, -1.1654866 , -0.35848635,
              1.1371242 ,  0.22448991,  0.19024403, -1.1654866 ,
             -0.35848635,  1.1371242 ,  0.22448991,  0.19024403,
             -1.1654866 , -0.35848635,  1.1371242 ,  0.22448991,
              0.19024403, -1.1654866 , -0.35848635,  1.1371242 ,
              0.22448991,  0.19024403, -1.1654866 , -0.35848635,
              1.1371242 ,  0.22448991,  0.19024403, -1.1654866 ,
             -0.35848635,  1.1371242 ,  0.22448991,  0.19024403,
             -1.1654866 , -0.35848635,  1.1371242 ,  0.22448991,
              0.19024403, -1.1654866 , -0.35848635,  1.1371242 ,
              0.22448991,  0.19024403, -1.1654866 , -0.35848635,
              1.1371242 ,  0.22448991], dtype=float32), success=DeviceArray(False, dtype=bool), status=Buffer(5, dtype=int32), fun=DeviceArray(-0.00029859, dtype=float32), jac=DeviceArray([-6.7624378e-0

In [None]:
FSSD(sgld_samples, sgld_grads, V_opt)

# random code


## FSSD

In [14]:
def imq_kernel(x, y):
    c, beta = 1., -0.5
    return (c + jnp.dot(x-y, x-y))**beta

k_dx = grad(imq_kernel, 0)

def xi(x, grad_x, v):
    return grad_x*imq_kernel(x,v) + k_dx(x,v)

batch_xi = vmap(xi, in_axes=(None, None, 0))


def tau(x, grad_x, V):
    d, = x.shape
    J, dv = V.shape
    assert dv == d
    return batch_xi(x, grad_x, V).reshape(J*d)/jnp.sqrt(d*J)
    
@jit
def delta_fn(x, y, g_x, g_y, V):
    return jnp.dot(tau(x, g_x, V), tau(y, g_y, V))



# @jit
def FSSD_O2(sgld_samples, sgld_grads, V):
    """
    FSSD with imq kernel
    """
    N = sgld_samples.shape[0]
    
    batch_delta_fun_rows = jit(vmap(delta_fn, in_axes=(None,0,None,0, None)))

    def body_ksd(le_sum, x):
        my_sample, my_grad = x
        le_sum += jnp.sum(batch_delta_fun_rows(my_sample, sgld_samples, my_grad, sgld_grads, V))
        return le_sum, None

    le_sum, _ = lax.scan(body_ksd, 0., (sgld_samples, sgld_grads))
    # remove diagonal
    le_sum -= jnp.sum(vmap(delta_fn, in_axes=(0,0,0,0, None))(sgld_samples, sgld_samples, sgld_grads, sgld_grads, V))
    return jnp.sqrt(le_sum)/N




In [15]:
FSSD_O2(samples, grads, V)

DeviceArray(0.0289157, dtype=float32)

In [40]:

key = random.PRNGKey(0)
N = 100 # num samples
d = 5 # dimension

samples = random.normal(key, shape=(N, d))
grads = -samples

J = 10
key = random.PRNGKey(1)
V = random.normal(key, shape=(J, d)) + 10

In [18]:
sgld_samples = samples
sgld_grads = grads

N = sgld_samples.shape[0]
d = sgld_samples[0].shape[0]
J = V.shape[0]

def imq_kernel(x, y):
    c, beta = 1., -0.5
    return (c + jnp.dot(x-y, x-y))**beta

k_dx = grad(imq_kernel, 0)

def xi(x, grad_x, v):
    return grad_x*imq_kernel(x,v) + k_dx(x,v)


def compute_xi_sum(v):
    batch_xi_sam = vmap(xi, in_axes=(0,0, None))
    all_xi = batch_xi_sam(sgld_samples, sgld_grads, v)
    return jnp.square(all_xi.sum(axis=0)) - jnp.square(all_xi).sum(axis=0)

batch_compute_xi_sum = vmap(compute_xi_sum, in_axes=(0,))

lesum = jnp.sum(jnp.sum(batch_compute_xi_sum(V), axis=0))
le_est = lesum/(d*J*N*(N-1))

In [12]:
xi(sgld_samples[0], sgld_grads[0], V[0])

DeviceArray([-0.01490127, -0.02905262,  0.01897729,  0.0462897 ,
              0.0708947 ], dtype=float32)

In [15]:
batch_xi_sam = vmap(xi, in_axes=(0,0, None))
all_xi = batch_xi_sam(sgld_samples, sgld_grads, V[0])

In [23]:
jnp.square(all_xi.sum(axis=0)) - jnp.square(all_xi).sum(axis=0)

DeviceArray([-0.13566822, -0.18629603, -0.15436825,  0.11876743,
             -0.12753035], dtype=float32)

In [24]:
jnp.square(all_xi).sum(axis=0)

DeviceArray([0.15292655, 0.19166648, 0.19066176, 0.18643297, 0.20289975],            dtype=float32)

In [26]:
batch_compute_xi_sum(V).shape

(10, 5)

In [30]:
jnp.sum(batch_compute_xi_sum(V), axis=0).sum()

DeviceArray(-4.8881726, dtype=float32)

In [31]:
batch_compute_xi_sum(V).sum()

DeviceArray(-4.888172, dtype=float32)

## compare to O(N^2)

In [19]:
xi(sgld_samples[1], sgld_grads[0], V[0])

DeviceArray([-0.01649672, -0.03237094,  0.02123744,  0.05179383,
              0.07955745], dtype=float32)

In [53]:
@jit
def FSSD_O2(samples, grads, V):
    "FSSD with naive O(N^2) implementation"
    
    def kernel_xi(x, dx, y, dy, v):
        return jnp.dot(xi(x,dx,v), xi(y,dy,v))
    
    def sum_gram_mat(v):
        "Sum gram matrix minus the diagonal"
        le_sum = jnp.sum(vmap(lambda samples2, grads2:
                          vmap(lambda samples1, grads1:
                               kernel_xi(samples1, grads1, samples2, grads2, v))(sgld_samples, sgld_grads))(sgld_samples, sgld_grads))

        le_sum -= jnp.sum(vmap(kernel_xi, in_axes=(0,0,0,0,None))(sgld_samples, sgld_grads,sgld_samples, sgld_grads, v))
        return le_sum
    
    return jnp.sum(vmap(sum_gram_mat)(V))/(N*(N-1)*d*J)


In [54]:
FSSD(sgld_samples, sgld_grads, V)

Buffer(-9.875096e-06, dtype=float32)

In [50]:
FSSD_O2(sgld_samples, sgld_grads, V)

Buffer(-9.875097e-06, dtype=float32)

In [22]:
kernel_xi(sgld_samples[0], sgld_grads[0], sgld_samples[5], sgld_grads[5], V[0])

DeviceArray(-0.00062116, dtype=float32)

In [32]:
def sum_gram_mat(v):
    le_sum = jnp.sum(vmap(lambda samples2, grads2:
                      vmap(lambda samples1, grads1:
                           kernel_xi(samples1, grads1, samples2, grads2, v))(sgld_samples, sgld_grads))(sgld_samples, sgld_grads))

    le_sum -= jnp.sum(vmap(kernel_xi, in_axes=(0,0,0,0,None))(sgld_samples, sgld_grads,sgld_samples, sgld_grads, v))
    return le_sum



In [36]:

sum_gram_mat(V[7])

DeviceArray(-0.5442765, dtype=float32)

In [39]:
jnp.sum(vmap(sum_gram_mat)(V))/(N*(N-1)*d*J)

DeviceArray(-9.875096e-06, dtype=float32)

### for a single dimension

In [34]:
sam_l = sgld_samples[:,0]
grad_l = sgld_grads[:,0]

def my_kernel_1d(x, y, dx, dy, v):
    return xi(x,dx,v)[0] * xi(y,dy,v)[0]

In [37]:
my_kernel_1d(sgld_samples[0], sgld_samples[1], sgld_grads[0], sgld_grads[1], V[0])

DeviceArray(0.00020165, dtype=float32)

In [46]:
mysum = 0.

for i in range(N):
    print(f"it {i}/{N}")
    for j in range(N):
        mysum += my_kernel_1d(sgld_samples[i], sgld_samples[j], sgld_grads[i], sgld_grads[j], V[0])
        
print(mysum)  

it 0/100
it 1/100
it 2/100
it 3/100
it 4/100
it 5/100
it 6/100
it 7/100


KeyboardInterrupt: 

In [102]:
le_sum = jnp.sum(vmap(lambda samples2, grads2:
                  vmap(lambda samples1, grads1:
                       my_kernel_1d(samples1, samples2, grads1, grads2, V[0]))(sgld_samples, sgld_grads))(sgld_samples, sgld_grads))

# le_sum -= jnp.sum(vmap(my_kernel_1d, in_axes=(0,0,0,0,None))(sgld_samples, sgld_grads,sgld_samples, sgld_grads, V[0]))



In [103]:
mymat = vmap(lambda samples2, grads2:
                  vmap(lambda samples1, grads1:
                       my_kernel_1d(samples1, samples2, grads1, grads2, V[0]))(sgld_samples, sgld_grads))(sgld_samples, sgld_grads)

jnp.diag(mymat).sum()

DeviceArray(0.15292656, dtype=float32)

In [104]:
le_sum

DeviceArray(0.01725833, dtype=float32)

In [None]:
my_kernel_1d(sgld_samples)

In [90]:
sum_xil = vmap(xi, in_axes=(0,0,None))(sgld_samples, sgld_grads, V[0])[:,0].sum()
# print(sum_xil)

sq_sum_xil = jnp.square(sum_xil)
print(sq_sum_xil)

0.01725833


In [78]:
sum_sq_xil = jnp.sum(jnp.square(vmap(xi, in_axes=(0,0,None))(sgld_samples, sgld_grads, V[0])[:,0]))
print(sum_sq_xil)

0.15292656


In [101]:
jnp.sum(vmap(my_kernel_1d, in_axes=(0,0,0,0,None))(sgld_samples, sgld_samples, sgld_grads, sgld_grads, V[0]))

DeviceArray(0.15292656, dtype=float32)

In [79]:
sq_sum_xil - sum_sq_xil

DeviceArray(-0.13566823, dtype=float32)

In [8]:
le_est

DeviceArray(-9.875096e-06, dtype=float32)

In [11]:
FSSD(sgld_samples, sgld_grads, V)

DeviceArray(-9.875096e-06, dtype=float32)

In [39]:
%time imq_KSD(samples, grads).block_until_ready()

CPU times: user 1.81 ms, sys: 888 µs, total: 2.7 ms
Wall time: 6.38 ms


Buffer(0.29847407, dtype=float32)

In [108]:
%time FSSD_O2(samples, grads, V).block_until_ready()

CPU times: user 900 µs, sys: 369 µs, total: 1.27 ms
Wall time: 598 µs


Buffer(0.00523681, dtype=float32)

In [109]:
FSSD(samples, grads, V).block_until_ready()

Buffer(-9.875096e-06, dtype=float32)

### linear version

In [42]:

@jit
def FSSD(sgld_samples, sgld_grads, V):
    """
    FSSD with imq kernel
    """
    N = sgld_samples.shape[0]
    d = sgld_samples[0].shape[0]
    J = V.shape[0]
    
    def imq_kernel(x, y):
        c, beta = 1., -0.5
        return (c + jnp.dot(x-y, x-y))**beta

    k_dx = grad(imq_kernel, 0)

    def xi(x, grad_x, v):
        return grad_x*imq_kernel(x,v) + k_dx(x,v)

    
    def compute_xi_sum(v):
        batch_xi_sam = vmap(xi, in_axes=(0,0, None))
        all_xi = batch_xi_sam(sgld_samples, sgld_grads, v)
        return jnp.square(all_xi.sum(axis=0)) - jnp.square(all_xi).sum(axis=0)

    batch_compute_xi_sum = vmap(compute_xi_sum, in_axes=(0,))

    lesum = jnp.sum(jnp.sum(batch_compute_xi_sum(V), axis=0))
    return lesum/(d*J*N*(N-1))


from jax.experimental.optimizers import adam
from functools import partial

@partial(jit, static_argnums=(3,))
def FSSD_opt(sgld_samples, sgld_grads, V, Niter):
    
    def opt_fssd_fn(v):
        return -FSSD(samples, grads, v)

    grad_opt_fn = jit(grad(opt_fssd_fn))
    init_fn, update, get_params = adam(1e-3)
    state = init_fn(V)

    def body(state, i):
        return update(i, grad_opt_fn(get_params(state)), state), None

    state, _ = lax.scan(body, state, jnp.arange(Niter))
    V_opt = get_params(state)
    return FSSD(sgld_samples, sgld_grads, V_opt)


    

In [2]:
J = 10
key = random.PRNGKey(6)
V = random.normal(key, shape=(J, d))

# FSSD(samples, grads, V)

NameError: name 'd' is not defined

In [3]:
# FSSD_opt(samples, grads, V, 400)

In [194]:
grad_FSSD(samples, grads, V)

DeviceArray([[-1.4949895e-05,  5.7517100e-06, -2.0329546e-05,
              -1.8899555e-05,  2.1406895e-06],
             [ 8.3543937e-06,  9.9315366e-07, -4.1955186e-06,
              -6.2971958e-07,  8.3952291e-06],
             [ 6.8008540e-06,  2.4068449e-05,  5.8266869e-07,
              -9.1672609e-07, -2.6546986e-06],
             [-1.8636318e-05,  9.6201402e-06, -2.4999619e-05,
               9.0209614e-06, -5.2451242e-06],
             [-8.1619528e-06,  1.1443335e-05,  5.7691977e-06,
              -3.8049802e-06, -2.2124510e-05],
             [-2.2405948e-06,  3.6173231e-05, -1.1673906e-05,
              -4.1901985e-06, -4.0308441e-06],
             [ 1.7546281e-06,  4.4181550e-05, -3.9278366e-05,
               4.0121464e-05, -9.9971294e-06],
             [ 3.5164259e-05,  8.9908526e-06, -1.5863337e-05,
               8.8084562e-06, -2.8339928e-06],
             [ 3.3414828e-05, -2.3344233e-05, -1.2204182e-05,
               9.7794109e-06,  7.4116906e-07],
             [ 9.73

In [195]:
# FSSD_O2(samples, grads, V)

FSSD(samples, grads, V)


Buffer(-0.00034945, dtype=float32)

In [196]:
V.shape

(10, 5)

In [206]:
update

<function jax.experimental.optimizers.adam.<locals>.update(i, g, state)>

In [228]:
from jax.scipy.optimize import minimize
from jax.experimental.optimizers import adam

def opt_fssd_fn(v):
    return -FSSD(samples, grads, v)

grad_opt_fn = jit(grad(opt_fssd_fn))

init_fn, update, get_params = adam(1e-3)



In [261]:
state = init_fn(V)

def body(state, i):
    return update(i, grad_opt_fn(get_params(state)), state), None

state, _ = lax.scan(body, state, jnp.arange(400))

# for i in range(400):
#     state = update(i, grad_opt_fn(get_params(state)), state)

In [262]:
print(opt_fssd_fn(V), opt_fssd_fn(get_params(state)))

0.00034944955 6.0526898e-05


In [27]:
plt.hist(get_params(state).flatten(), alpha=0.7, label="final")
plt.hist(V.flatten(), alpha=0.7, label="initial")

plt.legend()

NameError: name 'get_params' is not defined