In [1]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np

import jax.numpy as jnp
from jax import jit, random

from sgmcmcjax.kernels import build_sgld_kernel, build_sghmc_kernel, build_sgnht_kernel
from sgmcmcjax.kernels import build_sgldCV_kernel, build_sghmcCV_kernel, build_sgnhtCV_kernel
from sgmcmcjax.optimizer import build_adam_optimizer
from sgmcmcjax.util import build_grad_log_post

import context

from models.bayesian_NN.NN_model import logprior, loglikelihood, init_network, get_ECE_MCE
from models.bayesian_NN.NN_data import X_train, y_train, X_test, y_test
from models.bayesian_NN.util import load_NN_MAP
from util import flatten_param_list

from tuning.mamba import run_MAMBA
from tuning.ksd import FSSD_opt, FSSD, imq_KSD, get_test_locations, linear_imq_KSD




# NN - MAMBA-ECE (note: the metric currently active below is FSSD; the ECE metric is commented out)

In [2]:
# Training set as an (inputs, labels) tuple; handed to the kernel builders and
# to the full-batch gradient function in the cells below.
data = (X_train, y_train)

key = random.PRNGKey(0)
# load_NN_MAP() is called twice so that the chain start (params_IC) and the
# centering value used by the *CV kernels are separate objects.
# NOTE(review): presumably this loads a pre-computed MAP estimate - confirm.
params_IC = load_NN_MAP()
centering_value = load_NN_MAP()

# Alternative metrics, kept for reference (disabled):
# err_fn = lambda x,y: get_ECE_MCE(x[::10], X_test, y_test, M=10, pbar=False)[0]
# err_fn = lambda x,y: imq_KSD(x, y)
# err_fn = lambda x,y: FSSD_opt(x, y, get_test_locations(x), 100)


def err_fn(x, y):
    """Metric handed to run_MAMBA: FSSD of (x, y) with test locations derived from x.

    NOTE(review): x and y are presumably the samples and their gradients (the
    pair returned by get_fb_grads) - confirm against run_MAMBA.
    """
    return FSSD(x, y, get_test_locations(x))


# get_fb_grads = None

# Batch sizes N, N/10, N/100 of the training set. Integer floor division keeps
# this exact; the former int(10**(-k) * N) went through floats (10**-1 == 0.1)
# and could truncate to one below the intended value for some N.
batch_size_range = [X_train.shape[0] // 10**k for k in range(3)]
print(batch_size_range)

# Forwarded to run_MAMBA as its R argument (meaning defined there).
R = 10

In [3]:
# Gradient function built from the model's log-likelihood / log-prior and the
# full training set; called below with the whole dataset.
grad_log_post_fb = build_grad_log_post(loglikelihood, logprior, data)


def get_fb_grads(samples, thin=10):
    """Thin `samples` and evaluate the full-batch gradient at each kept sample.

    Parameters
    ----------
    samples : sequence of parameter values, sliceable with ``[::thin]``.
    thin : int, optional
        Keep every `thin`-th sample (default 10, the previously hard-coded value).

    Returns
    -------
    (thinned_samples, grads) : the thinned samples and a list holding
        ``grad_log_post_fb(sample, *data)`` for each of them, in order.
    """
    # Slice once and reuse, so the samples and gradients are guaranteed to line up.
    thinned = samples[::thin]
    fb_grads = [grad_log_post_fb(sample, *data) for sample in thinned]
    return thinned, fb_grads

### run MAMBA

In [4]:
def build_kernel(dt, batch_size):
    """SGLD kernel on the training data for the given `dt` and `batch_size`."""
    return build_sgld_kernel(dt, loglikelihood, logprior, data, batch_size)


# Search grid handed to run_MAMBA. The grid is given as log_dt; the recorded
# output reports the corresponding dt values (e.g. 0.01 for log_dt = -2).
grid_params = dict(
    log_dt=-jnp.arange(2., 8., 0.5),
    batch_size=batch_size_range,
)

best_arm = run_MAMBA(
    key,
    build_kernel,
    err_fn,
    R,
    params_IC,
    grid_params=grid_params,
    get_fb_grads=get_fb_grads,
)

print(best_arm.hyperparameters, best_arm.metric, best_arm.samples.shape)

HBox(children=(HTML(value='Iteration 1/3, 36 arms, time budget = 0.77 sec'), FloatProgress(value=0.0, max=36.0…

Hyperparams: {'batch_size': 60000, 'dt': 0.01}., sample shape=(1, 79510). metric: inf
Hyperparams: {'batch_size': 60000, 'dt': 0.0031622776}., sample shape=(1, 79510). metric: inf
Hyperparams: {'batch_size': 60000, 'dt': 0.001}., sample shape=(1, 79510). metric: inf
Hyperparams: {'batch_size': 60000, 'dt': 0.00031622776}., sample shape=(1, 79510). metric: inf
Hyperparams: {'batch_size': 60000, 'dt': 1e-04}., sample shape=(1, 79510). metric: inf
Hyperparams: {'batch_size': 60000, 'dt': 3.1622778e-05}., sample shape=(1, 79510). metric: inf
Hyperparams: {'batch_size': 60000, 'dt': 1e-05}., sample shape=(1, 79510). metric: inf
Hyperparams: {'batch_size': 60000, 'dt': 3.1622776e-06}., sample shape=(1, 79510). metric: inf
Hyperparams: {'batch_size': 60000, 'dt': 1e-06}., sample shape=(1, 79510). metric: inf
Hyperparams: {'batch_size': 60000, 'dt': 3.1622776e-07}., sample shape=(1, 79510). metric: inf
Hyperparams: {'batch_size': 60000, 'dt': 1e-07}., sample shape=(1, 79510). metric: inf
Hyper

HBox(children=(HTML(value='Iteration 2/3, 12 arms, time budget = 2.31 sec'), FloatProgress(value=0.0, max=12.0…

Hyperparams: {'batch_size': 6000, 'dt': 0.0031622776}., sample shape=(20, 79510). metric: 0
Hyperparams: {'batch_size': 600, 'dt': 0.01}., sample shape=(114, 79510). metric: 0
Hyperparams: {'batch_size': 600, 'dt': 0.001}., sample shape=(113, 79510). metric: -0
Hyperparams: {'batch_size': 6000, 'dt': 1e-05}., sample shape=(19, 79510). metric: 0
Hyperparams: {'batch_size': 600, 'dt': 0.0031622776}., sample shape=(124, 79510). metric: 0
Hyperparams: {'batch_size': 6000, 'dt': 0.01}., sample shape=(19, 79510). metric: 0
Hyperparams: {'batch_size': 600, 'dt': 1e-05}., sample shape=(129, 79510). metric: 0
Hyperparams: {'batch_size': 600, 'dt': 3.1622778e-05}., sample shape=(120, 79510). metric: 0
Hyperparams: {'batch_size': 600, 'dt': 1e-04}., sample shape=(118, 79510). metric: -0
Hyperparams: {'batch_size': 6000, 'dt': 0.001}., sample shape=(19, 79510). metric: -0
Hyperparams: {'batch_size': 600, 'dt': 3.1622776e-06}., sample shape=(123, 79510). metric: 0
Hyperparams: {'batch_size': 600, '

HBox(children=(HTML(value='Iteration 3/3, 4 arms, time budget = 6.92 sec'), FloatProgress(value=0.0, max=4.0),…

Hyperparams: {'batch_size': 6000, 'dt': 0.001}., sample shape=(59, 79510). metric: 0
Hyperparams: {'batch_size': 600, 'dt': 0.001}., sample shape=(336, 79510). metric: 0
Hyperparams: {'batch_size': 600, 'dt': 1e-04}., sample shape=(405, 79510). metric: 0
Hyperparams: {'batch_size': 600, 'dt': 0.01}., sample shape=(415, 79510). metric: 0

Number of samples: [405]
Running time: 529.8 sec
{'batch_size': 600, 'dt': 1e-04} 3.1700256e-05 (405, 79510)


In [None]:
def build_kernel(dt, batch_size):
    """SGLD-CV kernel on the training data, centred at `centering_value`."""
    return build_sgldCV_kernel(
        dt, loglikelihood, logprior, data, batch_size, centering_value
    )


# Define the grid in this cell instead of relying on whatever `grid_params` a
# previously executed cell left behind: the SGHMC cells add an "L" entry that
# this two-argument build_kernel does not accept.
grid_params = {
    'log_dt': -jnp.arange(2., 8., 0.5),
    'batch_size': batch_size_range,
}

best_arm = run_MAMBA(
    key,
    build_kernel,
    err_fn,
    R,
    params_IC,
    grid_params=grid_params,
    get_fb_grads=get_fb_grads,
)

print(best_arm.hyperparameters, best_arm.metric, best_arm.samples.shape)

In [None]:
def build_kernel(dt, L, batch_size):
    """SGHMC kernel on the training data for the given `dt`, `L` and `batch_size`."""
    return build_sghmc_kernel(dt, L, loglikelihood, logprior, data, batch_size)


# Same step-size / batch-size grid as the other samplers, plus the SGHMC
# hyperparameter L.
grid_params = dict(
    log_dt=-jnp.arange(2., 8., 0.5),
    batch_size=batch_size_range,
    L=[5, 10],
)

best_arm = run_MAMBA(
    key,
    build_kernel,
    err_fn,
    R,
    params_IC,
    grid_params=grid_params,
    get_fb_grads=get_fb_grads,
)

print(best_arm.hyperparameters, best_arm.metric, best_arm.samples.shape)

In [None]:
def build_kernel(dt, L, batch_size):
    """SGHMC-CV kernel on the training data, centred at `centering_value`."""
    return build_sghmcCV_kernel(
        dt, L, loglikelihood, logprior, data, batch_size, centering_value
    )


# Step-size / batch-size grid plus the SGHMC hyperparameter L.
grid_params = dict(
    log_dt=-jnp.arange(2., 8., 0.5),
    batch_size=batch_size_range,
    L=[5, 10],
)

best_arm = run_MAMBA(
    key,
    build_kernel,
    err_fn,
    R,
    params_IC,
    grid_params=grid_params,
    get_fb_grads=get_fb_grads,
)

print(best_arm.hyperparameters, best_arm.metric, best_arm.samples.shape)

In [None]:
def build_kernel(dt, batch_size):
    """SGNHT kernel on the training data for the given `dt` and `batch_size`."""
    return build_sgnht_kernel(dt, loglikelihood, logprior, data, batch_size)


# Step-size / batch-size grid handed to run_MAMBA.
grid_params = dict(
    log_dt=-jnp.arange(2., 8., 0.5),
    batch_size=batch_size_range,
)

best_arm = run_MAMBA(
    key,
    build_kernel,
    err_fn,
    R,
    params_IC,
    grid_params=grid_params,
    get_fb_grads=get_fb_grads,
)

print(best_arm.hyperparameters, best_arm.metric, best_arm.samples.shape)

In [None]:
def build_kernel(dt, batch_size):
    """SGNHT-CV kernel on the training data, centred at `centering_value`."""
    return build_sgnhtCV_kernel(
        dt, loglikelihood, logprior, data, batch_size, centering_value
    )


# Step-size / batch-size grid handed to run_MAMBA.
grid_params = dict(
    log_dt=-jnp.arange(2., 8., 0.5),
    batch_size=batch_size_range,
)

best_arm = run_MAMBA(
    key,
    build_kernel,
    err_fn,
    R,
    params_IC,
    grid_params=grid_params,
    get_fb_grads=get_fb_grads,
)

print(best_arm.hyperparameters, best_arm.metric, best_arm.samples.shape)