In [1]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np

import jax.numpy as jnp
from jax import jit, random

from sgmcmcjax.kernels import build_sgld_kernel, build_sghmc_kernel, build_sgnht_kernel
from sgmcmcjax.kernels import build_sgldCV_kernel, build_sghmcCV_kernel, build_sgnhtCV_kernel
from sgmcmcjax.optimizer import build_adam_optimizer
from sgmcmcjax.util import build_grad_log_post

import context

from models.bayesian_NN.NN_model import logprior, loglikelihood, init_network
from models.bayesian_NN.NN_data import X_train, y_train, X_test, y_test
from models.bayesian_NN.util import load_NN_MAP

from tuning.mamba import run_MAMBA
from tuning.ksd import imq_KSD




# Bayesian neural network: tuning SG-MCMC samplers with MAMBA

In [2]:
# Full training set; forwarded to every kernel builder and to the
# full-batch gradient helper in the next cell.
data = (X_train, y_train)

# A single PRNG key, reused unchanged by every run_MAMBA call below.
key = random.PRNGKey(0)

# Initial condition for all samplers, and the centering value used by the
# control-variate (CV) kernels. Both come from load_NN_MAP(); it is called
# twice so the two names are bound to separately loaded objects.
params_IC = load_NN_MAP()
centering_value = load_NN_MAP()

# The first MAMBA run scores arms with a KSD over only the first
# N_KSD_SAMPLES entries of each input, to keep the metric cheap
# (a later cell redefines err_fn without this cap).
N_KSD_SAMPLES = 5
err_fn = lambda x, y: imq_KSD(x[:N_KSD_SAMPLES], y[:N_KSD_SAMPLES])

# Passed to run_MAMBA. Judging by the logged output, each iteration keeps
# roughly 1/R of the arms (48 -> 16 -> 5) while the per-arm time budget
# grows about R-fold.
R = 3

In [3]:
# Full-batch gradient of the log-posterior, built over the whole `data` tuple.
grad_log_post_fb = build_grad_log_post(loglikelihood, logprior, data)


def get_fb_grads(samples, thin=5):
    """Thin a chain and evaluate full-batch log-posterior gradients on it.

    Handed to ``run_MAMBA`` as ``get_fb_grads`` so the KSD metric can use
    gradients computed on the full training set (``*data``).

    Args:
        samples: the chain. It must support ``samples[::thin]`` slicing;
            presumably an array of shape (n_samples, dim) -- TODO confirm.
        thin: keep every ``thin``-th sample. Defaults to 5, the value that
            was previously hard-coded.

    Returns:
        ``(thinned_samples, grads)``, where ``grads`` is a list holding one
        gradient per kept sample, in the same order.
    """
    # Slice once and reuse it for both the gradients and the return value.
    thinned_samples = samples[::thin]
    grads = [grad_log_post_fb(sample, *data) for sample in thinned_samples]
    return thinned_samples, grads

### Run MAMBA

In [4]:
def build_kernel(dt, batch_size):
    """Build the SGLD kernel for one (dt, batch_size) arm of the MAMBA grid."""
    return build_sgld_kernel(dt, loglikelihood, logprior, data, batch_size)

# Step sizes 10^-2 ... 10^-7.5 in half-decade steps, crossed with batch sizes
# N, N/10, N/100 and N/1000 (N = number of training points).
grid_params = dict(
    log_dt=-jnp.arange(2., 8., 0.5),
    batch_size=[int(X_train.shape[0] * 10**(-k)) for k in range(4)],
)
best_arm = run_MAMBA(key, build_kernel, err_fn, R, params_IC, grid_params=grid_params)

print(best_arm.hyperparameters, best_arm.metric, best_arm.samples.shape)

HBox(children=(HTML(value='Iteration 1/3, 48 arms, time budget = 0.23 sec'), FloatProgress(value=0.0, max=48.0…

Hyperparams: {'batch_size': 60000, 'dt': 0.01}. metric: 38431
Hyperparams: {'batch_size': 60000, 'dt': 0.0031622776}. metric: 36401
Hyperparams: {'batch_size': 60000, 'dt': 0.001}. metric: 33733
Hyperparams: {'batch_size': 60000, 'dt': 0.00031622776}. metric: 25788
Hyperparams: {'batch_size': 60000, 'dt': 1e-04}. metric: 23065
Hyperparams: {'batch_size': 60000, 'dt': 3.1622778e-05}. metric: 2650
Hyperparams: {'batch_size': 60000, 'dt': 1e-05}. metric: 704
Hyperparams: {'batch_size': 60000, 'dt': 3.1622776e-06}. metric: 895
Hyperparams: {'batch_size': 60000, 'dt': 1e-06}. metric: 1168
Hyperparams: {'batch_size': 60000, 'dt': 3.1622776e-07}. metric: 1323
Hyperparams: {'batch_size': 60000, 'dt': 1e-07}. metric: 1431
Hyperparams: {'batch_size': 60000, 'dt': 3.1622776e-08}. metric: 1403
Hyperparams: {'batch_size': 6000, 'dt': 0.01}. metric: 32400
Hyperparams: {'batch_size': 6000, 'dt': 0.0031622776}. metric: 27769
Hyperparams: {'batch_size': 6000, 'dt': 0.001}. metric: 24025
Hyperparams: {'

HBox(children=(HTML(value='Iteration 2/3, 16 arms, time budget = 0.69 sec'), FloatProgress(value=0.0, max=16.0…

Hyperparams: {'batch_size': 60000, 'dt': 1e-05}. metric: 632
Hyperparams: {'batch_size': 60000, 'dt': 3.1622776e-06}. metric: 846
Hyperparams: {'batch_size': 60000, 'dt': 1e-06}. metric: 953
Hyperparams: {'batch_size': 6000, 'dt': 1e-05}. metric: 1192
Hyperparams: {'batch_size': 6000, 'dt': 3.1622776e-06}. metric: 1562
Hyperparams: {'batch_size': 60000, 'dt': 3.1622776e-07}. metric: 1223
Hyperparams: {'batch_size': 60000, 'dt': 3.1622776e-08}. metric: 1366
Hyperparams: {'batch_size': 60000, 'dt': 1e-07}. metric: 1373
Hyperparams: {'batch_size': 6000, 'dt': 1e-07}. metric: 1800
Hyperparams: {'batch_size': 6000, 'dt': 3.1622776e-08}. metric: 1415
Hyperparams: {'batch_size': 6000, 'dt': 1e-06}. metric: 1474
Hyperparams: {'batch_size': 6000, 'dt': 3.1622776e-07}. metric: 1677
Hyperparams: {'batch_size': 60000, 'dt': 3.1622778e-05}. metric: 13816
Hyperparams: {'batch_size': 600, 'dt': 3.1622776e-06}. metric: 3378
Hyperparams: {'batch_size': 600, 'dt': 1e-06}. metric: 3333
Hyperparams: {'bat

HBox(children=(HTML(value='Iteration 3/3, 5 arms, time budget = 2.22 sec'), FloatProgress(value=0.0, max=5.0),…

Hyperparams: {'batch_size': 60000, 'dt': 1e-05}. metric: 682
Hyperparams: {'batch_size': 60000, 'dt': 3.1622776e-06}. metric: 832
Hyperparams: {'batch_size': 60000, 'dt': 1e-06}. metric: 900
Hyperparams: {'batch_size': 6000, 'dt': 1e-05}. metric: 1357
Hyperparams: {'batch_size': 60000, 'dt': 3.1622776e-07}. metric: 1274

Number of samples: [35]
Running time: 290.2 sec
{'batch_size': 60000, 'dt': 1e-05} 681.99176 (35, 79510)


In [6]:
# From here on err_fn hands its inputs to imq_KSD unsliced (full copies), and
# run_MAMBA receives full-batch gradients through `get_fb_grads`.
def err_fn(x, y):
    return imq_KSD(x[:], y[:])


def build_kernel(dt, batch_size):
    """SGLD kernel for one (dt, batch_size) arm."""
    return build_sgld_kernel(dt, loglikelihood, logprior, data, batch_size)


grid_params = dict(
    log_dt=-jnp.arange(2., 8., 0.5),
    batch_size=list(map(lambda k: int(10**(-k) * X_train.shape[0]), range(4))),
)
best_arm = run_MAMBA(
    key, build_kernel, err_fn, R, params_IC,
    grid_params=grid_params, get_fb_grads=get_fb_grads,
)

print(best_arm.hyperparameters, best_arm.metric, best_arm.samples.shape)

HBox(children=(HTML(value='Iteration 1/3, 48 arms, time budget = 0.23 sec'), FloatProgress(value=0.0, max=48.0…

Hyperparams: {'batch_size': 60000, 'dt': 0.01}. metric: 81625
Hyperparams: {'batch_size': 60000, 'dt': 0.0031622776}. metric: 108797
Hyperparams: {'batch_size': 60000, 'dt': 0.001}. metric: 83430
Hyperparams: {'batch_size': 60000, 'dt': 0.00031622776}. metric: 29712
Hyperparams: {'batch_size': 60000, 'dt': 1e-04}. metric: 8383
Hyperparams: {'batch_size': 60000, 'dt': 3.1622778e-05}. metric: 1920
Hyperparams: {'batch_size': 60000, 'dt': 1e-05}. metric: 1045
Hyperparams: {'batch_size': 60000, 'dt': 3.1622776e-06}. metric: 1288
Hyperparams: {'batch_size': 60000, 'dt': 1e-06}. metric: 1438
Hyperparams: {'batch_size': 60000, 'dt': 3.1622776e-07}. metric: 1362
Hyperparams: {'batch_size': 60000, 'dt': 1e-07}. metric: 1444
Hyperparams: {'batch_size': 60000, 'dt': 3.1622776e-08}. metric: 1421
Hyperparams: {'batch_size': 6000, 'dt': 0.01}. metric: 38681
Hyperparams: {'batch_size': 6000, 'dt': 0.0031622776}. metric: 35481
Hyperparams: {'batch_size': 6000, 'dt': 0.001}. metric: 29399
Hyperparams: 

HBox(children=(HTML(value='Iteration 2/3, 16 arms, time budget = 0.69 sec'), FloatProgress(value=0.0, max=16.0…

Hyperparams: {'batch_size': 600, 'dt': 3.1622776e-06}. metric: 494
Hyperparams: {'batch_size': 6000, 'dt': 1e-06}. metric: 663
Hyperparams: {'batch_size': 600, 'dt': 1e-06}. metric: 569
Hyperparams: {'batch_size': 6000, 'dt': 1e-05}. metric: 551
Hyperparams: {'batch_size': 600, 'dt': 3.1622776e-07}. metric: 629
Hyperparams: {'batch_size': 6000, 'dt': 3.1622776e-06}. metric: 627
Hyperparams: {'batch_size': 60, 'dt': 1e-07}. metric: 666
Hyperparams: {'batch_size': 600, 'dt': 1e-07}. metric: 659
Hyperparams: {'batch_size': 60, 'dt': 1e-06}. metric: 768
Hyperparams: {'batch_size': 60, 'dt': 3.1622776e-07}. metric: 724
Hyperparams: {'batch_size': 60000, 'dt': 1e-05}. metric: 637
Hyperparams: {'batch_size': 600, 'dt': 1e-05}. metric: 560
Hyperparams: {'batch_size': 600, 'dt': 3.1622776e-08}. metric: 1047
Hyperparams: {'batch_size': 60000, 'dt': 3.1622776e-06}. metric: 917
Hyperparams: {'batch_size': 60, 'dt': 3.1622776e-08}. metric: 856
Hyperparams: {'batch_size': 60, 'dt': 3.1622776e-06}. m

HBox(children=(HTML(value='Iteration 3/3, 5 arms, time budget = 2.22 sec'), FloatProgress(value=0.0, max=5.0),…

Hyperparams: {'batch_size': 600, 'dt': 3.1622776e-06}. metric: 261
Hyperparams: {'batch_size': 6000, 'dt': 1e-05}. metric: 331
Hyperparams: {'batch_size': 600, 'dt': 1e-05}. metric: 398
Hyperparams: {'batch_size': 600, 'dt': 1e-06}. metric: 337
Hyperparams: {'batch_size': 6000, 'dt': 3.1622776e-06}. metric: 378

Number of samples: [272]
Running time: 606.1 sec
{'batch_size': 600, 'dt': 3.1622776e-06} 261.01947 (272, 79510)


In [7]:
def build_kernel(dt, batch_size):
    """SGLD-CV kernel for one (dt, batch_size) arm, centred on `centering_value`."""
    return build_sgldCV_kernel(dt, loglikelihood,
                               logprior, data, batch_size, centering_value)

# Declare the grid here instead of relying on the value left over from the
# previous cell. The SGHMC cells further down rebind `grid_params` to include
# an "L" axis, which this build_kernel does not accept. run_MAMBA presumably
# passes every grid key to build_kernel, so re-running this cell after those
# would otherwise fail.
grid_params = {'log_dt': -jnp.arange(2., 8., 0.5),
               'batch_size': [int(10**(-elem)*X_train.shape[0]) for elem in range(0,4)]
              }
best_arm = run_MAMBA(key, build_kernel, err_fn, R, params_IC, grid_params=grid_params, get_fb_grads=get_fb_grads)

print(best_arm.hyperparameters, best_arm.metric, best_arm.samples.shape)

HBox(children=(HTML(value='Iteration 1/3, 48 arms, time budget = 0.23 sec'), FloatProgress(value=0.0, max=48.0…

Hyperparams: {'batch_size': 60000, 'dt': 0.01}. metric: 81039
Hyperparams: {'batch_size': 60000, 'dt': 0.0031622776}. metric: 107830
Hyperparams: {'batch_size': 60000, 'dt': 0.001}. metric: 82992
Hyperparams: {'batch_size': 60000, 'dt': 0.00031622776}. metric: 29394
Hyperparams: {'batch_size': 60000, 'dt': 1e-04}. metric: 6824
Hyperparams: {'batch_size': 60000, 'dt': 3.1622778e-05}. metric: 3284
Hyperparams: {'batch_size': 60000, 'dt': 1e-05}. metric: 1216
Hyperparams: {'batch_size': 60000, 'dt': 3.1622776e-06}. metric: 1247
Hyperparams: {'batch_size': 60000, 'dt': 1e-06}. metric: 1448
Hyperparams: {'batch_size': 60000, 'dt': 3.1622776e-07}. metric: 1361
Hyperparams: {'batch_size': 60000, 'dt': 1e-07}. metric: 1442
Hyperparams: {'batch_size': 60000, 'dt': 3.1622776e-08}. metric: 1422
Hyperparams: {'batch_size': 6000, 'dt': 0.01}. metric: 57313
Hyperparams: {'batch_size': 6000, 'dt': 0.0031622776}. metric: 53572
Hyperparams: {'batch_size': 6000, 'dt': 0.001}. metric: 43325
Hyperparams: 

HBox(children=(HTML(value='Iteration 2/3, 16 arms, time budget = 0.69 sec'), FloatProgress(value=0.0, max=16.0…

Hyperparams: {'batch_size': 60, 'dt': 3.1622776e-06}. metric: 366
Hyperparams: {'batch_size': 60, 'dt': 1e-06}. metric: 399
Hyperparams: {'batch_size': 600, 'dt': 1e-05}. metric: 331
Hyperparams: {'batch_size': 600, 'dt': 3.1622776e-06}. metric: 363
Hyperparams: {'batch_size': 600, 'dt': 1e-06}. metric: 477
Hyperparams: {'batch_size': 6000, 'dt': 1e-05}. metric: 529
Hyperparams: {'batch_size': 60, 'dt': 3.1622776e-07}. metric: 497
Hyperparams: {'batch_size': 6000, 'dt': 1e-06}. metric: 643
Hyperparams: {'batch_size': 600, 'dt': 3.1622776e-07}. metric: 562
Hyperparams: {'batch_size': 60, 'dt': 1e-05}. metric: 935
Hyperparams: {'batch_size': 6000, 'dt': 3.1622776e-06}. metric: 614
Hyperparams: {'batch_size': 60, 'dt': 1e-07}. metric: 619
Hyperparams: {'batch_size': 600, 'dt': 1e-07}. metric: 823
Hyperparams: {'batch_size': 60000, 'dt': 1e-05}. metric: 866
Hyperparams: {'batch_size': 60000, 'dt': 3.1622776e-06}. metric: 1029
Hyperparams: {'batch_size': 60, 'dt': 3.1622776e-08}. metric: 96

HBox(children=(HTML(value='Iteration 3/3, 5 arms, time budget = 2.22 sec'), FloatProgress(value=0.0, max=5.0),…

Hyperparams: {'batch_size': 600, 'dt': 1e-05}. metric: 253
Hyperparams: {'batch_size': 600, 'dt': 3.1622776e-06}. metric: 249
Hyperparams: {'batch_size': 60, 'dt': 3.1622776e-06}. metric: 273
Hyperparams: {'batch_size': 60, 'dt': 1e-06}. metric: 292
Hyperparams: {'batch_size': 600, 'dt': 1e-06}. metric: 339

Number of samples: [179]
Running time: 692.7 sec
{'batch_size': 600, 'dt': 3.1622776e-06} 248.52815 (179, 79510)


In [8]:
def build_kernel(dt, L, batch_size):
    """SGHMC kernel for one (dt, L, batch_size) arm.

    L is forwarded as build_sghmc_kernel's second argument (presumably the
    number of leapfrog steps -- TODO confirm).
    """
    return build_sghmc_kernel(dt, L, loglikelihood, logprior, data, batch_size)

# SGHMC adds a third grid axis, "L", which doubles the arm count
# (96 arms here versus 48 in the SGLD runs).
grid_params = {
    'log_dt': -jnp.arange(2., 8., 0.5),
    'batch_size': [int(10**(-e) * X_train.shape[0]) for e in range(4)],
    "L": [5, 10],
}
best_arm = run_MAMBA(
    key, build_kernel, err_fn, R, params_IC,
    grid_params=grid_params,
    get_fb_grads=get_fb_grads,
)

print(best_arm.hyperparameters, best_arm.metric, best_arm.samples.shape)

HBox(children=(HTML(value='Iteration 1/4, 96 arms, time budget = 0.07 sec'), FloatProgress(value=0.0, max=96.0…

Hyperparams: {'batch_size': 60000, 'L': 5, 'dt': 0.01}. metric: 44773
Hyperparams: {'batch_size': 60000, 'L': 5, 'dt': 0.0031622776}. metric: 67544
Hyperparams: {'batch_size': 60000, 'L': 5, 'dt': 0.001}. metric: 68994
Hyperparams: {'batch_size': 60000, 'L': 5, 'dt': 0.00031622776}. metric: 80529
Hyperparams: {'batch_size': 60000, 'L': 5, 'dt': 1e-04}. metric: 144456
Hyperparams: {'batch_size': 60000, 'L': 5, 'dt': 3.1622778e-05}. metric: 16597
Hyperparams: {'batch_size': 60000, 'L': 5, 'dt': 1e-05}. metric: 4157
Hyperparams: {'batch_size': 60000, 'L': 5, 'dt': 3.1622776e-06}. metric: 1561
Hyperparams: {'batch_size': 60000, 'L': 5, 'dt': 1e-06}. metric: 1046
Hyperparams: {'batch_size': 60000, 'L': 5, 'dt': 3.1622776e-07}. metric: 1071
Hyperparams: {'batch_size': 60000, 'L': 5, 'dt': 1e-07}. metric: 1382
Hyperparams: {'batch_size': 60000, 'L': 5, 'dt': 3.1622776e-08}. metric: 1324
Hyperparams: {'batch_size': 60000, 'L': 10, 'dt': 0.01}. metric: 82697
Hyperparams: {'batch_size': 60000, '

HBox(children=(HTML(value='Iteration 2/4, 32 arms, time budget = 0.22 sec'), FloatProgress(value=0.0, max=32.0…

Hyperparams: {'batch_size': 600, 'L': 5, 'dt': 3.1622776e-07}. metric: 1093
Hyperparams: {'batch_size': 600, 'L': 10, 'dt': 1e-07}. metric: 1106
Hyperparams: {'batch_size': 60000, 'L': 5, 'dt': 1e-06}. metric: 1357
Hyperparams: {'batch_size': 60000, 'L': 5, 'dt': 3.1622776e-07}. metric: 949
Hyperparams: {'batch_size': 6000, 'L': 10, 'dt': 3.1622776e-07}. metric: 952
Hyperparams: {'batch_size': 60000, 'L': 10, 'dt': 3.1622776e-07}. metric: 991
Hyperparams: {'batch_size': 60, 'L': 10, 'dt': 3.1622776e-08}. metric: 934
Hyperparams: {'batch_size': 6000, 'L': 5, 'dt': 3.1622776e-06}. metric: 3345
Hyperparams: {'batch_size': 600, 'L': 5, 'dt': 1e-07}. metric: 850
Hyperparams: {'batch_size': 600, 'L': 10, 'dt': 3.1622776e-08}. metric: 796
Hyperparams: {'batch_size': 60000, 'L': 10, 'dt': 1e-07}. metric: 1179
Hyperparams: {'batch_size': 6000, 'L': 5, 'dt': 1e-06}. metric: 1253
Hyperparams: {'batch_size': 6000, 'L': 10, 'dt': 1e-07}. metric: 1033
Hyperparams: {'batch_size': 60000, 'L': 5, 'dt':

HBox(children=(HTML(value='Iteration 3/4, 10 arms, time budget = 0.72 sec'), FloatProgress(value=0.0, max=10.0…

Hyperparams: {'batch_size': 600, 'L': 10, 'dt': 3.1622776e-08}. metric: 600
Hyperparams: {'batch_size': 600, 'L': 5, 'dt': 1e-07}. metric: 680
Hyperparams: {'batch_size': 60, 'L': 10, 'dt': 3.1622776e-08}. metric: 857
Hyperparams: {'batch_size': 60000, 'L': 5, 'dt': 3.1622776e-07}. metric: 1058
Hyperparams: {'batch_size': 6000, 'L': 10, 'dt': 3.1622776e-07}. metric: 676
Hyperparams: {'batch_size': 60000, 'L': 10, 'dt': 3.1622776e-07}. metric: 1056
Hyperparams: {'batch_size': 6000, 'L': 5, 'dt': 3.1622776e-07}. metric: 937
Hyperparams: {'batch_size': 6000, 'L': 10, 'dt': 1e-07}. metric: 652
Hyperparams: {'batch_size': 600, 'L': 5, 'dt': 3.1622776e-07}. metric: 526
Hyperparams: {'batch_size': 600, 'L': 10, 'dt': 1e-07}. metric: 635

Number of samples: [18, 10, 9]


HBox(children=(HTML(value='Iteration 4/4, 3 arms, time budget = 2.40 sec'), FloatProgress(value=0.0, max=3.0),…

Hyperparams: {'batch_size': 600, 'L': 5, 'dt': 3.1622776e-07}. metric: 344
Hyperparams: {'batch_size': 600, 'L': 10, 'dt': 3.1622776e-08}. metric: 562
Hyperparams: {'batch_size': 600, 'L': 10, 'dt': 1e-07}. metric: 422

Number of samples: [61]
Running time: 844.8 sec
{'batch_size': 600, 'L': 5, 'dt': 3.1622776e-07} 343.58423 (61, 79510)


In [9]:
def build_kernel(dt, L, batch_size):
    """SGHMC-CV kernel for one (dt, L, batch_size) arm, centred on `centering_value`."""
    return build_sghmcCV_kernel(
        dt, L, loglikelihood, logprior, data, batch_size, centering_value
    )

# Same three-axis grid as the plain SGHMC cell.
grid_params = dict(
    log_dt=-jnp.arange(2., 8., 0.5),
    batch_size=[int(10**(-power) * X_train.shape[0]) for power in range(4)],
    L=[5, 10],
)
best_arm = run_MAMBA(
    key, build_kernel, err_fn, R, params_IC,
    grid_params=grid_params, get_fb_grads=get_fb_grads,
)

print(best_arm.hyperparameters, best_arm.metric, best_arm.samples.shape)

HBox(children=(HTML(value='Iteration 1/4, 96 arms, time budget = 0.07 sec'), FloatProgress(value=0.0, max=96.0…

Hyperparams: {'batch_size': 60000, 'L': 5, 'dt': 0.01}. metric: 81607
Hyperparams: {'batch_size': 60000, 'L': 5, 'dt': 0.0031622776}. metric: 47814
Hyperparams: {'batch_size': 60000, 'L': 5, 'dt': 0.001}. metric: 65174
Hyperparams: {'batch_size': 60000, 'L': 5, 'dt': 0.00031622776}. metric: 80537
Hyperparams: {'batch_size': 60000, 'L': 5, 'dt': 1e-04}. metric: 136342
Hyperparams: {'batch_size': 60000, 'L': 5, 'dt': 3.1622778e-05}. metric: 29428
Hyperparams: {'batch_size': 60000, 'L': 5, 'dt': 1e-05}. metric: 3973
Hyperparams: {'batch_size': 60000, 'L': 5, 'dt': 3.1622776e-06}. metric: 1684
Hyperparams: {'batch_size': 60000, 'L': 5, 'dt': 1e-06}. metric: 1065
Hyperparams: {'batch_size': 60000, 'L': 5, 'dt': 3.1622776e-07}. metric: 1067
Hyperparams: {'batch_size': 60000, 'L': 5, 'dt': 1e-07}. metric: 1375
Hyperparams: {'batch_size': 60000, 'L': 5, 'dt': 3.1622776e-08}. metric: 1326
Hyperparams: {'batch_size': 60000, 'L': 10, 'dt': 0.01}. metric: 82677
Hyperparams: {'batch_size': 60000, '

HBox(children=(HTML(value='Iteration 2/4, 32 arms, time budget = 0.22 sec'), FloatProgress(value=0.0, max=32.0…

Hyperparams: {'batch_size': 60, 'L': 5, 'dt': 3.1622776e-07}. metric: 563
Hyperparams: {'batch_size': 60, 'L': 5, 'dt': 1e-06}. metric: 740
Hyperparams: {'batch_size': 600, 'L': 5, 'dt': 1e-06}. metric: 998
Hyperparams: {'batch_size': 60, 'L': 10, 'dt': 1e-07}. metric: 570
Hyperparams: {'batch_size': 6000, 'L': 5, 'dt': 3.1622776e-06}. metric: 1410
Hyperparams: {'batch_size': 60, 'L': 10, 'dt': 3.1622776e-07}. metric: 710
Hyperparams: {'batch_size': 6000, 'L': 5, 'dt': 1e-06}. metric: 928
Hyperparams: {'batch_size': 60, 'L': 5, 'dt': 1e-07}. metric: 695
Hyperparams: {'batch_size': 6000, 'L': 10, 'dt': 3.1622776e-07}. metric: 824
Hyperparams: {'batch_size': 600, 'L': 5, 'dt': 3.1622776e-07}. metric: 709
Hyperparams: {'batch_size': 60000, 'L': 5, 'dt': 1e-06}. metric: 1002
Hyperparams: {'batch_size': 60000, 'L': 5, 'dt': 3.1622776e-07}. metric: 1160
Hyperparams: {'batch_size': 60, 'L': 10, 'dt': 3.1622776e-08}. metric: 909
Hyperparams: {'batch_size': 600, 'L': 5, 'dt': 1e-07}. metric: 84

HBox(children=(HTML(value='Iteration 3/4, 10 arms, time budget = 0.72 sec'), FloatProgress(value=0.0, max=10.0…

Hyperparams: {'batch_size': 60, 'L': 5, 'dt': 3.1622776e-07}. metric: 444
Hyperparams: {'batch_size': 60, 'L': 10, 'dt': 1e-07}. metric: 477
Hyperparams: {'batch_size': 60, 'L': 5, 'dt': 1e-07}. metric: 559
Hyperparams: {'batch_size': 600, 'L': 5, 'dt': 3.1622776e-07}. metric: 544
Hyperparams: {'batch_size': 60, 'L': 10, 'dt': 3.1622776e-07}. metric: 569
Hyperparams: {'batch_size': 60, 'L': 5, 'dt': 1e-06}. metric: 552
Hyperparams: {'batch_size': 600, 'L': 10, 'dt': 3.1622776e-08}. metric: 704
Hyperparams: {'batch_size': 6000, 'L': 10, 'dt': 3.1622776e-07}. metric: 902
Hyperparams: {'batch_size': 600, 'L': 5, 'dt': 1e-07}. metric: 673
Hyperparams: {'batch_size': 6000, 'L': 10, 'dt': 1e-07}. metric: 878

Number of samples: [29, 18, 12]


HBox(children=(HTML(value='Iteration 4/4, 3 arms, time budget = 2.40 sec'), FloatProgress(value=0.0, max=3.0),…

Hyperparams: {'batch_size': 60, 'L': 5, 'dt': 3.1622776e-07}. metric: 307
Hyperparams: {'batch_size': 60, 'L': 10, 'dt': 1e-07}. metric: 323
Hyperparams: {'batch_size': 600, 'L': 5, 'dt': 3.1622776e-07}. metric: 372

Number of samples: [100]
Running time: 1099.1 sec
{'batch_size': 60, 'L': 5, 'dt': 3.1622776e-07} 307.215 (100, 79510)


In [10]:
def build_kernel(dt, batch_size):
    """SGNHT kernel for one (dt, batch_size) arm."""
    return build_sgnht_kernel(dt, loglikelihood, logprior, data, batch_size)

# Two-axis grid again (no "L"). In the logged run, large dt combined with
# minibatches gave an infinite metric.
grid_params = dict(
    log_dt=-jnp.arange(2., 8., 0.5),
    batch_size=list(map(lambda i: int(10**(-i) * X_train.shape[0]), range(4))),
)
best_arm = run_MAMBA(key, build_kernel, err_fn, R, params_IC,
                     grid_params=grid_params, get_fb_grads=get_fb_grads)

print(best_arm.hyperparameters, best_arm.metric, best_arm.samples.shape)

HBox(children=(HTML(value='Iteration 1/3, 48 arms, time budget = 0.23 sec'), FloatProgress(value=0.0, max=48.0…

Hyperparams: {'batch_size': 60000, 'dt': 0.01}. metric: 81348
Hyperparams: {'batch_size': 60000, 'dt': 0.0031622776}. metric: 105126
Hyperparams: {'batch_size': 60000, 'dt': 0.001}. metric: 81248
Hyperparams: {'batch_size': 60000, 'dt': 0.00031622776}. metric: 32771
Hyperparams: {'batch_size': 60000, 'dt': 1e-04}. metric: 8883
Hyperparams: {'batch_size': 60000, 'dt': 3.1622778e-05}. metric: 1058
Hyperparams: {'batch_size': 60000, 'dt': 1e-05}. metric: 859
Hyperparams: {'batch_size': 60000, 'dt': 3.1622776e-06}. metric: 1489
Hyperparams: {'batch_size': 60000, 'dt': 1e-06}. metric: 1418
Hyperparams: {'batch_size': 60000, 'dt': 3.1622776e-07}. metric: 1410
Hyperparams: {'batch_size': 60000, 'dt': 1e-07}. metric: 1448
Hyperparams: {'batch_size': 60000, 'dt': 3.1622776e-08}. metric: 1458
Hyperparams: {'batch_size': 6000, 'dt': 0.01}. metric: inf
Hyperparams: {'batch_size': 6000, 'dt': 0.0031622776}. metric: inf
Hyperparams: {'batch_size': 6000, 'dt': 0.001}. metric: inf
Hyperparams: {'batch

HBox(children=(HTML(value='Iteration 2/3, 16 arms, time budget = 0.69 sec'), FloatProgress(value=0.0, max=16.0…

Hyperparams: {'batch_size': 6000, 'dt': 1e-06}. metric: 641
Hyperparams: {'batch_size': 600, 'dt': 3.1622776e-08}. metric: 664
Hyperparams: {'batch_size': 60000, 'dt': 1e-05}. metric: 818
Hyperparams: {'batch_size': 6000, 'dt': 3.1622776e-07}. metric: 673
Hyperparams: {'batch_size': 6000, 'dt': 3.1622776e-06}. metric: 961
Hyperparams: {'batch_size': 600, 'dt': 1e-07}. metric: 966
Hyperparams: {'batch_size': 60000, 'dt': 3.1622778e-05}. metric: 1345
Hyperparams: {'batch_size': 6000, 'dt': 1e-05}. metric: 1418
Hyperparams: {'batch_size': 600, 'dt': 1e-06}. metric: 1521
Hyperparams: {'batch_size': 600, 'dt': 3.1622776e-07}. metric: 1110
Hyperparams: {'batch_size': 6000, 'dt': 1e-07}. metric: 776
Hyperparams: {'batch_size': 60, 'dt': 3.1622776e-08}. metric: 1517
Hyperparams: {'batch_size': 60000, 'dt': 3.1622776e-07}. metric: 1119
Hyperparams: {'batch_size': 6000, 'dt': 3.1622776e-08}. metric: 1216
Hyperparams: {'batch_size': 60000, 'dt': 1e-06}. metric: 744
Hyperparams: {'batch_size': 600

HBox(children=(HTML(value='Iteration 3/3, 5 arms, time budget = 2.22 sec'), FloatProgress(value=0.0, max=5.0),…

Hyperparams: {'batch_size': 6000, 'dt': 1e-06}. metric: 424
Hyperparams: {'batch_size': 600, 'dt': 3.1622776e-08}. metric: 439
Hyperparams: {'batch_size': 6000, 'dt': 3.1622776e-07}. metric: 444
Hyperparams: {'batch_size': 60000, 'dt': 1e-06}. metric: 550
Hyperparams: {'batch_size': 6000, 'dt': 1e-07}. metric: 494

Number of samples: [39]
Running time: 921.4 sec
{'batch_size': 6000, 'dt': 1e-06} 424.0574 (39, 79510)


In [11]:
def build_kernel(dt, batch_size):
    """SGNHT-CV kernel for one (dt, batch_size) arm, centred on `centering_value`."""
    return build_sgnhtCV_kernel(
        dt, loglikelihood, logprior, data, batch_size, centering_value
    )

# Same step-size x batch-size grid as the other two-axis samplers.
grid_params = {
    'log_dt': -jnp.arange(2., 8., 0.5),
    'batch_size': [int(10**(-j) * X_train.shape[0]) for j in range(4)],
}
best_arm = run_MAMBA(
    key, build_kernel, err_fn, R, params_IC,
    grid_params=grid_params,
    get_fb_grads=get_fb_grads,
)

print(best_arm.hyperparameters, best_arm.metric, best_arm.samples.shape)

HBox(children=(HTML(value='Iteration 1/3, 48 arms, time budget = 0.23 sec'), FloatProgress(value=0.0, max=48.0…

Hyperparams: {'batch_size': 60000, 'dt': 0.01}. metric: 81125
Hyperparams: {'batch_size': 60000, 'dt': 0.0031622776}. metric: 115621
Hyperparams: {'batch_size': 60000, 'dt': 0.001}. metric: 80055
Hyperparams: {'batch_size': 60000, 'dt': 0.00031622776}. metric: 32264
Hyperparams: {'batch_size': 60000, 'dt': 1e-04}. metric: 7548
Hyperparams: {'batch_size': 60000, 'dt': 3.1622778e-05}. metric: 1961
Hyperparams: {'batch_size': 60000, 'dt': 1e-05}. metric: 877
Hyperparams: {'batch_size': 60000, 'dt': 3.1622776e-06}. metric: 1441
Hyperparams: {'batch_size': 60000, 'dt': 1e-06}. metric: 1429
Hyperparams: {'batch_size': 60000, 'dt': 3.1622776e-07}. metric: 1408
Hyperparams: {'batch_size': 60000, 'dt': 1e-07}. metric: 1446
Hyperparams: {'batch_size': 60000, 'dt': 3.1622776e-08}. metric: 1459
Hyperparams: {'batch_size': 6000, 'dt': 0.01}. metric: inf
Hyperparams: {'batch_size': 6000, 'dt': 0.0031622776}. metric: inf
Hyperparams: {'batch_size': 6000, 'dt': 0.001}. metric: inf
Hyperparams: {'batch

HBox(children=(HTML(value='Iteration 2/3, 16 arms, time budget = 0.69 sec'), FloatProgress(value=0.0, max=16.0…

Hyperparams: {'batch_size': 600, 'dt': 1e-06}. metric: 526
Hyperparams: {'batch_size': 60, 'dt': 3.1622776e-08}. metric: 523
Hyperparams: {'batch_size': 600, 'dt': 3.1622776e-07}. metric: 472
Hyperparams: {'batch_size': 60, 'dt': 1e-07}. metric: 802
Hyperparams: {'batch_size': 600, 'dt': 1e-07}. metric: 494
Hyperparams: {'batch_size': 600, 'dt': 3.1622776e-08}. metric: 592
Hyperparams: {'batch_size': 6000, 'dt': 3.1622776e-06}. metric: 544
Hyperparams: {'batch_size': 60000, 'dt': 1e-05}. metric: 852
Hyperparams: {'batch_size': 6000, 'dt': 1e-06}. metric: 683
Hyperparams: {'batch_size': 6000, 'dt': 1e-05}. metric: 656
Hyperparams: {'batch_size': 60, 'dt': 3.1622776e-07}. metric: 1766
Hyperparams: {'batch_size': 600, 'dt': 3.1622776e-06}. metric: 1866
Hyperparams: {'batch_size': 6000, 'dt': 3.1622776e-07}. metric: 780
Hyperparams: {'batch_size': 6000, 'dt': 1e-07}. metric: 1037
Hyperparams: {'batch_size': 60000, 'dt': 3.1622776e-07}. metric: 1333
Hyperparams: {'batch_size': 6000, 'dt': 3

HBox(children=(HTML(value='Iteration 3/3, 5 arms, time budget = 2.22 sec'), FloatProgress(value=0.0, max=5.0),…

Hyperparams: {'batch_size': 600, 'dt': 3.1622776e-07}. metric: 430
Hyperparams: {'batch_size': 600, 'dt': 1e-07}. metric: 334
Hyperparams: {'batch_size': 60, 'dt': 3.1622776e-08}. metric: 408
Hyperparams: {'batch_size': 600, 'dt': 1e-06}. metric: 693
Hyperparams: {'batch_size': 6000, 'dt': 3.1622776e-06}. metric: 392

Number of samples: [147]
Running time: 1125.8 sec
{'batch_size': 600, 'dt': 1e-07} 334.35962 (147, 79510)
