In [1]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np

import jax.numpy as jnp
from jax import jit, random

from sgmcmcjax.kernels import build_sgld_kernel, build_sghmc_kernel, build_sgnht_kernel
from sgmcmcjax.kernels import build_sgldCV_kernel, build_sghmcCV_kernel, build_sgnhtCV_kernel
from sgmcmcjax.optimizer import build_adam_optimizer
from sgmcmcjax.util import build_grad_log_post

import context

from models.bayesian_NN.NN_model import logprior, loglikelihood, init_network
from models.bayesian_NN.NN_data import X_train, y_train, X_test, y_test
from models.bayesian_NN.util import load_NN_MAP

from tuning.mamba import run_MAMBA
from tuning.ksd import imq_KSD




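# Bayesian NN: tuning SG-MCMC hyperparameters with MAMBA

SGLD and SGLD-CV tune only the step size `dt`; SGHMC(-CV) additionally tunes `L` and `batch_size`, and SGNHT(-CV) additionally tunes `batch_size`.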

In [2]:
# Full training set; the SG-MCMC kernels below subsample it with a minibatch size.
data = (X_train, y_train)

# Single PRNG key reused by every run_MAMBA call below, so all samplers share the same seed.
SEED = 0
key = random.PRNGKey(SEED)

# Initial condition and control-variate centering value both come from the stored MAP
# estimate; they are obtained from two separate load_NN_MAP() calls.
params_IC = load_NN_MAP()
centering_value = load_NN_MAP()


def err_fn(x, y):
    """Error metric that MAMBA minimises: delegates to `imq_KSD(x, y)`.

    Presumably the IMQ-kernel Stein discrepancy (per the name); `x`/`y` are likely the
    (thinned samples, full-batch gradients) pair returned by `get_fb_grads` —
    TODO confirm the argument order against `tuning.mamba.run_MAMBA`.
    """
    return imq_KSD(x, y)


# NOTE(review): R is passed straight to run_MAMBA; the printed per-iteration time
# budgets scale with it — confirm its exact meaning in tuning.mamba.
R = 10

# Fixed minibatch size for the SGLD / SGLD-CV runs: 10% of the training set.
BATCH_FRACTION = 0.1
batch_size = int(BATCH_FRACTION * X_train.shape[0])

In [3]:
# Full-batch gradient of the log-posterior, used to score sampler output.
grad_log_post_fb = build_grad_log_post(loglikelihood, logprior, data)


def get_fb_grads(samples, thin=10):
    """Thin a chain and evaluate the full-batch log-posterior gradient at each kept sample.

    Args:
        samples: sampler output; sliced as ``samples[::thin]`` and iterated over its
            first axis (one sample per row).
        thin: keep every ``thin``-th sample. Defaults to 10, the previously hard-coded value.

    Returns:
        ``(thinned_samples, grads)``: the thinned samples and a list holding one
        full-batch gradient per thinned sample, in the same order.
    """
    thinned = samples[::thin]  # slice once and reuse for both the gradients and the return value
    grads = [grad_log_post_fb(sample, *data) for sample in thinned]
    return thinned, grads

### Run MAMBA for each sampler

In [4]:
def build_kernel(dt):
    """SGLD kernel with step size `dt`; the minibatch size is the global `batch_size`."""
    return build_sgld_kernel(dt, loglikelihood, logprior, data, batch_size)


# 12 step sizes: log10(dt) from -2 down to -7.5 in steps of 0.5.
grid_params = {'log_dt': -jnp.arange(2., 8., 0.5)}

best_arm = run_MAMBA(
    key,
    build_kernel,
    err_fn,
    R,
    params_IC,
    grid_params=grid_params,
    get_fb_grads=get_fb_grads,
)

print(best_arm.hyperparameters, best_arm.metric, best_arm.samples.shape)

HBox(children=(HTML(value='Iteration 1/2, 12 arms, time budget = 2.50 sec'), FloatProgress(value=0.0, max=12.0…

Hyperparams: {'dt': 0.01}., sample shape=(18, 79510). metric: 19124
Hyperparams: {'dt': 0.0031622776}., sample shape=(18, 79510). metric: 14300
Hyperparams: {'dt': 0.001}., sample shape=(17, 79510). metric: 12889
Hyperparams: {'dt': 0.00031622776}., sample shape=(17, 79510). metric: 5881
Hyperparams: {'dt': 1e-04}., sample shape=(17, 79510). metric: 3083
Hyperparams: {'dt': 3.1622778e-05}., sample shape=(17, 79510). metric: 3529
Hyperparams: {'dt': 1e-05}., sample shape=(17, 79510). metric: 433
Hyperparams: {'dt': 3.1622776e-06}., sample shape=(17, 79510). metric: 392
Hyperparams: {'dt': 1e-06}., sample shape=(17, 79510). metric: 478
Hyperparams: {'dt': 3.1622776e-07}., sample shape=(17, 79510). metric: 696
Hyperparams: {'dt': 1e-07}., sample shape=(17, 79510). metric: 959
Hyperparams: {'dt': 3.1622776e-08}., sample shape=(17, 79510). metric: 1147

Number of samples: [17, 17, 17, 17]


HBox(children=(HTML(value='Iteration 2/2, 4 arms, time budget = 7.50 sec'), FloatProgress(value=0.0, max=4.0),…

Hyperparams: {'dt': 3.1622776e-06}., sample shape=(64, 79510). metric: 284
Hyperparams: {'dt': 1e-05}., sample shape=(55, 79510). metric: 296
Hyperparams: {'dt': 1e-06}., sample shape=(58, 79510). metric: 378
Hyperparams: {'dt': 3.1622776e-07}., sample shape=(62, 79510). metric: 516

Number of samples: [64]
Running time: 168.6 sec
{'dt': 3.1622776e-06} 284.44485 (64, 79510)


In [6]:
def build_kernel(dt):
    """SGLD-CV kernel with step size `dt`, global `batch_size`, centred at the MAP estimate."""
    return build_sgldCV_kernel(dt, loglikelihood, logprior, data, batch_size, centering_value)


# Declare the grid here instead of reusing the global left by an earlier cell: the
# SGHMC/SGNHT cells overwrite `grid_params` with extra keys (`L`, `batch_size`) that
# this one-argument kernel does not accept, so re-running this cell later would break.
grid_params = {'log_dt': -jnp.arange(2., 8., 0.5)}

best_arm = run_MAMBA(key, build_kernel, err_fn, R, params_IC,
                     grid_params=grid_params, get_fb_grads=get_fb_grads)

print(best_arm.hyperparameters, best_arm.metric, best_arm.samples.shape)

HBox(children=(HTML(value='Iteration 1/2, 12 arms, time budget = 2.50 sec'), FloatProgress(value=0.0, max=12.0…

Hyperparams: {'dt': 0.01}., sample shape=(10, 79510). metric: 25600
Hyperparams: {'dt': 0.0031622776}., sample shape=(10, 79510). metric: 26539
Hyperparams: {'dt': 0.001}., sample shape=(11, 79510). metric: 17006
Hyperparams: {'dt': 0.00031622776}., sample shape=(10, 79510). metric: 8310
Hyperparams: {'dt': 1e-04}., sample shape=(10, 79510). metric: 3372
Hyperparams: {'dt': 3.1622778e-05}., sample shape=(10, 79510). metric: 5034
Hyperparams: {'dt': 1e-05}., sample shape=(10, 79510). metric: 383
Hyperparams: {'dt': 3.1622776e-06}., sample shape=(10, 79510). metric: 426
Hyperparams: {'dt': 1e-06}., sample shape=(10, 79510). metric: 568
Hyperparams: {'dt': 3.1622776e-07}., sample shape=(10, 79510). metric: 744
Hyperparams: {'dt': 1e-07}., sample shape=(10, 79510). metric: 1115
Hyperparams: {'dt': 3.1622776e-08}., sample shape=(10, 79510). metric: 1278

Number of samples: [10, 10, 10, 10]


HBox(children=(HTML(value='Iteration 2/2, 4 arms, time budget = 7.50 sec'), FloatProgress(value=0.0, max=4.0),…

Hyperparams: {'dt': 1e-05}., sample shape=(33, 79510). metric: 273
Hyperparams: {'dt': 3.1622776e-06}., sample shape=(37, 79510). metric: 330
Hyperparams: {'dt': 1e-06}., sample shape=(37, 79510). metric: 423
Hyperparams: {'dt': 3.1622776e-07}., sample shape=(38, 79510). metric: 559

Number of samples: [33]
Running time: 153.3 sec
{'dt': 1e-05} 272.99655 (33, 79510)


In [None]:
def build_kernel(dt, L, batch_size):
    """SGHMC kernel with step size `dt`, `L` (presumably leapfrog steps) and minibatch `batch_size`."""
    return build_sghmc_kernel(dt, L, loglikelihood, logprior, data, batch_size)


# Every argument of build_kernel needs a grid entry (assumes run_MAMBA builds each arm
# from the grid's keys, as the SGNHT cells rely on). `batch_size` was missing even
# though the kernel requires it. Batch sizes are 100%, 10%, 1% and 0.1% of the
# training set, matching the SGNHT cells.
grid_params = {
    'log_dt': -jnp.arange(2., 8., 0.5),
    'L': [5, 10],
    'batch_size': [int(10**(-elem) * X_train.shape[0]) for elem in range(0, 4)],
}

# NOTE(review): unlike the SGLD/SGNHT cells, `get_fb_grads` is not passed here
# (it was commented out in the original) — confirm this is intentional.
best_arm = run_MAMBA(key, build_kernel, err_fn, R, params_IC, grid_params=grid_params)

print(best_arm.hyperparameters, best_arm.metric, best_arm.samples.shape)

In [7]:
def build_kernel(dt, L, batch_size):
    """SGHMC-CV kernel: step size `dt`, `L` (presumably leapfrog steps), minibatch `batch_size`, centred at the MAP."""
    return build_sghmcCV_kernel(dt, L, loglikelihood, logprior, data, batch_size, centering_value)


# Fixes the stray `eams` token after `[5, 10]` (a SyntaxError) and restores the
# `batch_size` entry the kernel requires. The recorded output (96 arms = 12 dt x 2 L x 4
# batch sizes, best arm containing `batch_size`) shows this grid was the one actually run.
grid_params = {
    'log_dt': -jnp.arange(2., 8., 0.5),
    'L': [5, 10],
    'batch_size': [int(10**(-elem) * X_train.shape[0]) for elem in range(0, 4)],
}

# NOTE(review): unlike the SGLD/SGNHT cells, `get_fb_grads` is not passed here
# (it was commented out in the original) — confirm this is intentional.
best_arm = run_MAMBA(key, build_kernel, err_fn, R, params_IC, grid_params=grid_params)

print(best_arm.hyperparameters, best_arm.metric, best_arm.samples.shape)

HBox(children=(HTML(value='Iteration 1/4, 96 arms, time budget = 0.25 sec'), FloatProgress(value=0.0, max=96.0…


Number of samples: [2, 3, 4, 2, 4, 2, 1, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 4, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1]


HBox(children=(HTML(value='Iteration 2/4, 32 arms, time budget = 0.75 sec'), FloatProgress(value=0.0, max=32.0…


Number of samples: [10, 11, 6, 5, 5, 3, 3, 6, 5, 3]


HBox(children=(HTML(value='Iteration 3/4, 10 arms, time budget = 2.40 sec'), FloatProgress(value=0.0, max=10.0…


Number of samples: [39, 21, 16]


HBox(children=(HTML(value='Iteration 4/4, 3 arms, time budget = 8.00 sec'), FloatProgress(value=0.0, max=3.0),…


Number of samples: [138]
Running time: 1467.7 sec
{'batch_size': 60, 'L': 5, 'dt': 3.1622776e-07} 242.33775 (138, 79510)


In [8]:
def build_kernel(dt, batch_size):
    """SGNHT kernel with step size `dt` and minibatch size `batch_size`."""
    return build_sgnht_kernel(dt, loglikelihood, logprior, data, batch_size)


grid_params = {
    # 12 step sizes: log10(dt) from -2 down to -7.5.
    'log_dt': -jnp.arange(2., 8., 0.5),
    # Batch sizes of 10**0, 10**-1, 10**-2 and 10**-3 times the training-set size.
    'batch_size': [int(10**(-k) * X_train.shape[0]) for k in range(4)],
}

best_arm = run_MAMBA(
    key, build_kernel, err_fn, R, params_IC,
    grid_params=grid_params,
    get_fb_grads=get_fb_grads,
)

print(best_arm.hyperparameters, best_arm.metric, best_arm.samples.shape)

HBox(children=(HTML(value='Iteration 1/3, 48 arms, time budget = 0.77 sec'), FloatProgress(value=0.0, max=48.0…


Number of samples: [32, 6, 6, 6, 32, 33, 1, 6, 59, 57, 33, 1, 33, 32, 6, 6]


HBox(children=(HTML(value='Iteration 2/3, 16 arms, time budget = 2.31 sec'), FloatProgress(value=0.0, max=16.0…


Number of samples: [106, 161, 97, 109, 17]


HBox(children=(HTML(value='Iteration 3/3, 5 arms, time budget = 7.38 sec'), FloatProgress(value=0.0, max=5.0),…


Number of samples: [701]
Running time: 1546.8 sec
{'batch_size': 60, 'dt': 1e-05} 120.7024 (701, 79510)


In [9]:
def build_kernel(dt, batch_size):
    """SGNHT-CV kernel with step size `dt` and minibatch `batch_size`, centred at the MAP estimate."""
    return build_sgnhtCV_kernel(dt, loglikelihood, logprior, data, batch_size, centering_value)


# NOTE(review): the recorded best arm uses dt = 3.16e-08, the smallest value in this
# grid — consider extending the `log_dt` range below -7.5.
grid_params = {
    'log_dt': -jnp.arange(2., 8., 0.5),
    'batch_size': [int(10**(-k) * X_train.shape[0]) for k in range(4)],
}

best_arm = run_MAMBA(
    key, build_kernel, err_fn, R, params_IC,
    grid_params=grid_params,
    get_fb_grads=get_fb_grads,
)

print(best_arm.hyperparameters, best_arm.metric, best_arm.samples.shape)

HBox(children=(HTML(value='Iteration 1/3, 48 arms, time budget = 0.77 sec'), FloatProgress(value=0.0, max=48.0…


Number of samples: [21, 21, 49, 22, 4, 4, 17, 4, 4, 46, 4, 1, 4, 46, 1, 1]


HBox(children=(HTML(value='Iteration 2/3, 16 arms, time budget = 2.31 sec'), FloatProgress(value=0.0, max=16.0…


Number of samples: [56, 73, 62, 10, 11]


HBox(children=(HTML(value='Iteration 3/3, 5 arms, time budget = 7.38 sec'), FloatProgress(value=0.0, max=5.0),…


Number of samples: [230]
Running time: 1340.2 sec
{'batch_size': 600, 'dt': 3.1622776e-08} 230.5486 (230, 79510)
