In [1]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np

import jax.numpy as jnp
from jax import jit, random

from sgmcmcjax.kernels import build_sgld_kernel, build_sghmc_kernel, build_sgnht_kernel
from sgmcmcjax.kernels import build_sgldCV_kernel, build_sghmcCV_kernel, build_sgnhtCV_kernel
from sgmcmcjax.optimizer import build_adam_optimizer

import context

from models.bayesian_NN.NN_model import logprior, loglikelihood, init_network
from models.bayesian_NN.NN_data import X_train, y_train, X_test, y_test
from models.bayesian_NN.util import load_NN_MAP

from tuning.mamba import run_MAMBA
from tuning.ksd import imq_KSD




# Bayesian neural network (NN): tuning SG-MCMC samplers with MAMBA

In [4]:
# Full training set, passed to every SG-MCMC kernel builder in the cells below.
data = (X_train, y_train)

# One fixed PRNG key, reused unchanged by every run_MAMBA call below.
key = random.PRNGKey(0)

# The MAP estimate is the initial parameters (`params_IC`) of every run_MAMBA call
# and the centering value of the control-variate (CV) kernels. Load it from disk
# once and share it: JAX code does not mutate arrays in place.
# (assumes load_NN_MAP returns JAX arrays -- TODO confirm)
params_IC = load_NN_MAP()
centering_value = params_IC


def err_fn(x, y):
    """Error metric that MAMBA uses to rank arms (lower is better): `imq_KSD`.

    Presumably `x` are the samples and `y` their log-posterior gradients.
    Only ``x[:2]`` and ``y[:2]`` are forwarded, i.e. the first two entries
    along axis 0.

    NOTE(review): MAMBA returns samples as ``(n_samples, dim)`` arrays (see the
    ``best_arm.samples.shape`` printouts), so this keeps the first two *samples*
    rather than the first two dimensions (``x[:, :2]``) -- confirm which is intended.
    """
    return imq_KSD(x[:2], y[:2])


# Budget argument forwarded to run_MAMBA as `R`; run_MAMBA's default time budget `T` is used.
R = 5

### Run MAMBA for each sampler (SGLD, SGHMC, SGNHT and their control-variate variants)

In [5]:
def build_kernel(dt, batch_size):
    """Return an SGLD kernel for the Bayesian NN.

    Args:
        dt: step size.
        batch_size: minibatch size used for the stochastic gradients.
    """
    return build_sgld_kernel(dt, loglikelihood, logprior, data, batch_size)


# 12 step sizes x 4 minibatch sizes = 48 arms.
grid_params = {
    'log_dt': -jnp.arange(2., 8., 0.5),  # dt = 10**-2, 10**-2.5, ..., 10**-7.5
    # N, N/10, N/100, N/1000 with N = X_train.shape[0]; integer division keeps the
    # sizes exact instead of truncating a floating-point product.
    'batch_size': [X_train.shape[0] // 10**elem for elem in range(4)],
}
best_arm = run_MAMBA(key, build_kernel, err_fn, R, params_IC, grid_params=grid_params, eta=3)

print(best_arm.hyperparameters, best_arm.metric, best_arm.samples.shape)

HBox(children=(HTML(value='Iteration 1/3, 48 arms, time budget = 0.38 sec'), FloatProgress(value=0.0, max=48.0…

Hyperparams: {'batch_size': 60000, 'dt': 0.01}. new KSD: 48154
Hyperparams: {'batch_size': 60000, 'dt': 0.0031622776}. new KSD: 64488
Hyperparams: {'batch_size': 60000, 'dt': 0.001}. new KSD: 57871
Hyperparams: {'batch_size': 60000, 'dt': 0.00031622776}. new KSD: 42231
Hyperparams: {'batch_size': 60000, 'dt': 1e-04}. new KSD: 21109
Hyperparams: {'batch_size': 60000, 'dt': 3.1622778e-05}. new KSD: 1608
Hyperparams: {'batch_size': 60000, 'dt': 1e-05}. new KSD: 963
Hyperparams: {'batch_size': 60000, 'dt': 3.1622776e-06}. new KSD: 1101
Hyperparams: {'batch_size': 60000, 'dt': 1e-06}. new KSD: 1314
Hyperparams: {'batch_size': 60000, 'dt': 3.1622776e-07}. new KSD: 1357
Hyperparams: {'batch_size': 60000, 'dt': 1e-07}. new KSD: 1435
Hyperparams: {'batch_size': 60000, 'dt': 3.1622776e-08}. new KSD: 1414
Hyperparams: {'batch_size': 6000, 'dt': 0.01}. new KSD: 43599
Hyperparams: {'batch_size': 6000, 'dt': 0.0031622776}. new KSD: 32554
Hyperparams: {'batch_size': 6000, 'dt': 0.001}. new KSD: 43764

HBox(children=(HTML(value='Iteration 2/3, 16 arms, time budget = 1.15 sec'), FloatProgress(value=0.0, max=16.0…

Hyperparams: {'batch_size': 60000, 'dt': 1e-05}. new KSD: 910
Hyperparams: {'batch_size': 60000, 'dt': 3.1622776e-06}. new KSD: 929
Hyperparams: {'batch_size': 60000, 'dt': 1e-06}. new KSD: 964
Hyperparams: {'batch_size': 60000, 'dt': 3.1622776e-07}. new KSD: 1252
Hyperparams: {'batch_size': 60000, 'dt': 3.1622776e-08}. new KSD: 1375
Hyperparams: {'batch_size': 60000, 'dt': 1e-07}. new KSD: 1376
Hyperparams: {'batch_size': 60000, 'dt': 3.1622778e-05}. new KSD: 18578
Hyperparams: {'batch_size': 6000, 'dt': 1e-07}. new KSD: 1957
Hyperparams: {'batch_size': 6000, 'dt': 3.1622776e-08}. new KSD: 2061
Hyperparams: {'batch_size': 6000, 'dt': 3.1622776e-07}. new KSD: 1725
Hyperparams: {'batch_size': 6000, 'dt': 1e-05}. new KSD: 2014
Hyperparams: {'batch_size': 6000, 'dt': 3.1622776e-06}. new KSD: 2198
Hyperparams: {'batch_size': 6000, 'dt': 1e-06}. new KSD: 1671
Hyperparams: {'batch_size': 6000, 'dt': 3.1622778e-05}. new KSD: 3151
Hyperparams: {'batch_size': 600, 'dt': 1e-07}. new KSD: 4800
Hy

HBox(children=(HTML(value='Iteration 3/3, 5 arms, time budget = 3.69 sec'), FloatProgress(value=0.0, max=5.0),…

Hyperparams: {'batch_size': 60000, 'dt': 1e-05}. new KSD: 872
Hyperparams: {'batch_size': 60000, 'dt': 3.1622776e-06}. new KSD: 754
Hyperparams: {'batch_size': 60000, 'dt': 1e-06}. new KSD: 868
Hyperparams: {'batch_size': 60000, 'dt': 3.1622776e-07}. new KSD: 964
Hyperparams: {'batch_size': 60000, 'dt': 3.1622776e-08}. new KSD: 1474

Number of samples: [50]
Running time: 343.6 sec
{'batch_size': 60000, 'dt': 3.1622776e-06} 753.5069 (50, 79510)


In [4]:
def build_kernel(dt, batch_size):
    """Return an SGLD-CV kernel (SGLD with control variates centred at `centering_value`).

    Args:
        dt: step size.
        batch_size: minibatch size used for the stochastic gradients.
    """
    return build_sgldCV_kernel(dt, loglikelihood, logprior, data, batch_size,
                               centering_value)


# Same 48-arm grid as the SGLD run (12 step sizes x 4 batch sizes). It is redefined here
# so this cell does not depend on whatever `grid_params` the previously executed cell
# left behind. For example, the SGHMC grid's extra 'L' key would presumably be forwarded
# to this two-argument builder and fail.
grid_params = {
    'log_dt': -jnp.arange(2., 8., 0.5),  # dt = 10**-2, 10**-2.5, ..., 10**-7.5
    # N, N/10, N/100, N/1000 with N = X_train.shape[0], using exact integer division.
    'batch_size': [X_train.shape[0] // 10**elem for elem in range(4)],
}
best_arm = run_MAMBA(key, build_kernel, err_fn, R, params_IC, grid_params=grid_params, eta=3)

print(best_arm.hyperparameters, best_arm.metric, best_arm.samples.shape)

HBox(children=(HTML(value='Iteration 1/3, 48 arms, time budget = 0.38 sec'), FloatProgress(value=0.0, max=48.0…


Number of samples: [17, 118, 3, 3, 17, 17, 299, 119, 297, 3, 118, 119, 296, 3, 114, 297]


HBox(children=(HTML(value='Iteration 2/3, 16 arms, time budget = 1.15 sec'), FloatProgress(value=0.0, max=16.0…


Number of samples: [60, 9, 9, 60, 57]


HBox(children=(HTML(value='Iteration 3/3, 5 arms, time budget = 3.69 sec'), FloatProgress(value=0.0, max=5.0),…


Number of samples: [25]
Running time: 507.1 sec
{'batch_size': 60000, 'dt': 1e-05} 667.156 (25, 79510)


In [5]:
def build_kernel(dt, L, batch_size):
    """Return an SGHMC kernel for the Bayesian NN.

    Args:
        dt: step size.
        L: passed straight through to build_sghmc_kernel (presumably the number
            of leapfrog steps per iteration).
        batch_size: minibatch size used for the stochastic gradients.
    """
    return build_sghmc_kernel(dt, L, loglikelihood, logprior, data, batch_size)


# 12 step sizes x 4 minibatch sizes x 2 values of L = 96 arms.
grid_params = {
    'log_dt': -jnp.arange(2., 8., 0.5),  # dt = 10**-2, 10**-2.5, ..., 10**-7.5
    # N, N/10, N/100, N/1000 with N = X_train.shape[0], using exact integer division.
    'batch_size': [X_train.shape[0] // 10**elem for elem in range(4)],
    'L': [5, 10],
}
best_arm = run_MAMBA(key, build_kernel, err_fn, R, params_IC, grid_params=grid_params, eta=3)

print(best_arm.hyperparameters, best_arm.metric, best_arm.samples.shape)

HBox(children=(HTML(value='Iteration 1/4, 96 arms, time budget = 0.12 sec'), FloatProgress(value=0.0, max=96.0…


Number of samples: [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 2, 2, 3, 3, 2, 3, 2, 2, 8, 13, 13, 12, 8, 13, 8, 14, 7, 2, 7, 15]


HBox(children=(HTML(value='Iteration 2/4, 32 arms, time budget = 0.38 sec'), FloatProgress(value=0.0, max=32.0…


Number of samples: [4, 4, 4, 4, 4, 6, 4, 9, 4, 4]


HBox(children=(HTML(value='Iteration 3/4, 10 arms, time budget = 1.20 sec'), FloatProgress(value=0.0, max=10.0…


Number of samples: [7, 8, 7]


HBox(children=(HTML(value='Iteration 4/4, 3 arms, time budget = 4.00 sec'), FloatProgress(value=0.0, max=3.0),…


Number of samples: [16]
Running time: 845.6 sec
{'batch_size': 60000, 'L': 5, 'dt': 1e-06} 594.93726 (16, 79510)


In [6]:
def build_kernel(dt, L, batch_size):
    """Return an SGHMC-CV kernel (SGHMC with control variates centred at `centering_value`).

    Args:
        dt: step size.
        L: passed straight through to build_sghmcCV_kernel (presumably the number
            of leapfrog steps per iteration).
        batch_size: minibatch size used for the stochastic gradients.
    """
    return build_sghmcCV_kernel(dt, L, loglikelihood, logprior, data, batch_size,
                                centering_value)


# 12 step sizes x 4 minibatch sizes x 2 values of L = 96 arms.
grid_params = {
    'log_dt': -jnp.arange(2., 8., 0.5),  # dt = 10**-2, 10**-2.5, ..., 10**-7.5
    # N, N/10, N/100, N/1000 with N = X_train.shape[0], using exact integer division.
    'batch_size': [X_train.shape[0] // 10**elem for elem in range(4)],
    'L': [5, 10],
}
best_arm = run_MAMBA(key, build_kernel, err_fn, R, params_IC, grid_params=grid_params, eta=3)

print(best_arm.hyperparameters, best_arm.metric, best_arm.samples.shape)

HBox(children=(HTML(value='Iteration 1/4, 96 arms, time budget = 0.12 sec'), FloatProgress(value=0.0, max=96.0…


Number of samples: [5, 9, 22, 3, 3, 2, 2, 9, 22, 2, 2, 13, 5, 9, 2, 22, 9, 13, 5, 2, 3, 2, 2, 2, 22, 3, 2, 9, 2, 2, 2, 2]


HBox(children=(HTML(value='Iteration 2/4, 32 arms, time budget = 0.38 sec'), FloatProgress(value=0.0, max=32.0…


Number of samples: [27, 24, 7, 16, 15, 5, 27, 7, 15, 4]


HBox(children=(HTML(value='Iteration 3/4, 10 arms, time budget = 1.20 sec'), FloatProgress(value=0.0, max=10.0…


Number of samples: [84, 95, 53]


HBox(children=(HTML(value='Iteration 4/4, 3 arms, time budget = 4.00 sec'), FloatProgress(value=0.0, max=3.0),…


Number of samples: [265]
Running time: 1143.4 sec
{'batch_size': 600, 'L': 5, 'dt': 1e-06} 339.09305 (265, 79510)


In [7]:
def build_kernel(dt, batch_size):
    """Return an SGNHT kernel for the Bayesian NN.

    Args:
        dt: step size.
        batch_size: minibatch size used for the stochastic gradients.
    """
    return build_sgnht_kernel(dt, loglikelihood, logprior, data, batch_size)


# 12 step sizes x 4 minibatch sizes = 48 arms.
grid_params = {
    'log_dt': -jnp.arange(2., 8., 0.5),  # dt = 10**-2, 10**-2.5, ..., 10**-7.5
    # N, N/10, N/100, N/1000 with N = X_train.shape[0], using exact integer division.
    'batch_size': [X_train.shape[0] // 10**elem for elem in range(4)],
}
best_arm = run_MAMBA(key, build_kernel, err_fn, R, params_IC, grid_params=grid_params, eta=3)

print(best_arm.hyperparameters, best_arm.metric, best_arm.samples.shape)

HBox(children=(HTML(value='Iteration 1/3, 48 arms, time budget = 0.38 sec'), FloatProgress(value=0.0, max=48.0…


Number of samples: [27, 5, 154, 26, 5, 26, 26, 5, 154, 156, 5, 155, 26, 154, 5, 155]


HBox(children=(HTML(value='Iteration 2/3, 16 arms, time budget = 1.15 sec'), FloatProgress(value=0.0, max=16.0…


Number of samples: [564, 567, 17, 17, 17]


HBox(children=(HTML(value='Iteration 3/3, 5 arms, time budget = 3.69 sec'), FloatProgress(value=0.0, max=5.0),…


Number of samples: [2031]
Running time: 3062.0 sec
{'batch_size': 600, 'dt': 3.1622778e-05} 146.16841 (2031, 79510)


In [None]:
def build_kernel(dt, batch_size):
    """Return an SGNHT-CV kernel (SGNHT with control variates centred at `centering_value`).

    Args:
        dt: step size.
        batch_size: minibatch size used for the stochastic gradients.
    """
    return build_sgnhtCV_kernel(dt, loglikelihood, logprior, data, batch_size,
                                centering_value)


# 12 step sizes x 4 minibatch sizes = 48 arms.
grid_params = {
    'log_dt': -jnp.arange(2., 8., 0.5),  # dt = 10**-2, 10**-2.5, ..., 10**-7.5
    # N, N/10, N/100, N/1000 with N = X_train.shape[0], using exact integer division.
    'batch_size': [X_train.shape[0] // 10**elem for elem in range(4)],
}
best_arm = run_MAMBA(key, build_kernel, err_fn, R, params_IC, grid_params=grid_params, eta=3)

print(best_arm.hyperparameters, best_arm.metric, best_arm.samples.shape)

HBox(children=(HTML(value='Iteration 1/3, 48 arms, time budget = 0.38 sec'), FloatProgress(value=0.0, max=48.0…