In [1]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np

import jax.numpy as jnp
from jax import jit, random

from sgmcmcjax.kernels import build_sgld_kernel, build_sghmc_kernel, build_sgnht_kernel
from sgmcmcjax.kernels import build_sgldCV_kernel, build_sghmcCV_kernel, build_sgnhtCV_kernel
from sgmcmcjax.optimizer import build_adam_optimizer
from sgmcmcjax.util import build_grad_log_post

import context

from models.bayesian_NN.NN_model import logprior, loglikelihood, init_network, get_ECE_MCE
from models.bayesian_NN.NN_data import X_train, y_train, X_test, y_test
from models.bayesian_NN.util import load_NN_MAP
from util import flatten_param_list

from tuning.mamba import run_MAMBA
from tuning.ksd import FSSD_opt, FSSD, imq_KSD, get_test_locations, linear_imq_KSD




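# NN - MAMBA-ECE

_Note:_ despite the title, the runs below use the IMQ kernel Stein discrepancy (`imq_KSD`) as the MAMBA error function. The ECE-based error function appears in the next cell only as an unused alternative.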

In [2]:
# Full training set: passed to every kernel builder and to the full-batch gradient below.
data = (X_train, y_train)

# The same PRNG key is passed to every run_MAMBA call below.
key = random.PRNGKey(0)
# Initial parameters for every sampler (presumably the stored MAP estimate -- see load_NN_MAP).
params_IC = load_NN_MAP()
# Centering value for the control-variate (CV) kernels; loaded separately from params_IC.
centering_value = load_NN_MAP()

# Candidate MAMBA error functions, presumably called as err_fn(samples, grads)
# with the pair returned by get_fb_grads -- confirm in tuning.mamba.
# Select one by key instead of toggling commented-out lines.
ERR_FNS = {
    "ece": lambda x, y: get_ECE_MCE(x[::10], X_test, y_test, M=10, pbar=False)[0],
    "ksd": lambda x, y: imq_KSD(x, y),
    "fssd_opt": lambda x, y: FSSD_opt(x, y, get_test_locations(x), 100),
    "fssd": lambda x, y: FSSD(x, y, get_test_locations(x, J=50)),
}
err_fn = ERR_FNS["ksd"]

# Minibatch sizes to try: N, N/10, N/100.
# Integer division keeps the sizes exact instead of truncating a float product.
n_train = X_train.shape[0]
batch_size_range = [n_train // 10**k for k in range(3)]
print(batch_size_range)

# Passed to run_MAMBA as R (presumably its resource/budget parameter -- confirm in tuning.mamba).
R = 10

[60000, 6000, 600]


In [3]:
# Gradient of the log-posterior evaluated on the full training set `data`.
grad_log_post_fb = build_grad_log_post(loglikelihood, logprior, data)


def get_fb_grads(samples, thin=10):
    """Thin a chain and compute full-batch log-posterior gradients at the kept samples.

    Passed to ``run_MAMBA`` as ``get_fb_grads``. The returned pair is presumably
    what ``err_fn`` receives as ``(x, y)`` -- confirm in tuning.mamba.

    Args:
        samples: sliceable sequence of sampler states.
        thin: keep every ``thin``-th sample, starting from the first (default 10).
            Must be a positive integer.

    Returns:
        ``(samples[::thin], grads)``. ``grads`` is a list with one
        ``grad_log_post_fb(sample, *data)`` per kept sample.

    Raises:
        ValueError: if ``thin`` is smaller than 1.
    """
    if thin < 1:
        raise ValueError(f"thin must be a positive integer, got {thin}")
    thinned = samples[::thin]
    # Slice once and reuse it for both the gradients and the return value.
    grads = [grad_log_post_fb(sample, *data) for sample in thinned]
    return thinned, grads

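### Run MAMBA

Each cell below tunes one sampler with MAMBA, using `err_fn` as the metric. The samplers are SGLD, SGLD-CV, SGHMC, SGHMC-CV, SGNHT and SGNHT-CV. The grid covers the step size `dt` (10^-2 down to 10^-7.5) and the minibatch size, plus `L` for the SGHMC variants.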

In [4]:
# SGLD: search over log10(dt) in {-2.0, -2.5, ..., -7.5} and the minibatch sizes.
grid_params = {
    'log_dt': -jnp.arange(2., 8., 0.5),
    'batch_size': batch_size_range,
}


def build_kernel(dt, batch_size):
    """Return an SGLD kernel for step size ``dt`` and minibatch size ``batch_size``."""
    return build_sgld_kernel(dt, loglikelihood, logprior, data, batch_size)


best_arm = run_MAMBA(
    key, build_kernel, err_fn, R, params_IC,
    grid_params=grid_params,
    get_fb_grads=get_fb_grads,
)

print(best_arm.hyperparameters, best_arm.metric, best_arm.samples.shape)

HBox(children=(HTML(value='Iteration 1/3, 36 arms, time budget = 0.77 sec'), FloatProgress(value=0.0, max=36.0…

Hyperparams: {'batch_size': 60000, 'dt': 0.01}., sample shape=(1, 79510). metric: 81625
Hyperparams: {'batch_size': 60000, 'dt': 0.0031622776}., sample shape=(1, 79510). metric: 108797
Hyperparams: {'batch_size': 60000, 'dt': 0.001}., sample shape=(1, 79510). metric: 83430
Hyperparams: {'batch_size': 60000, 'dt': 0.00031622776}., sample shape=(1, 79510). metric: 29712
Hyperparams: {'batch_size': 60000, 'dt': 1e-04}., sample shape=(1, 79510). metric: 8383
Hyperparams: {'batch_size': 60000, 'dt': 3.1622778e-05}., sample shape=(1, 79510). metric: 1920
Hyperparams: {'batch_size': 60000, 'dt': 1e-05}., sample shape=(1, 79510). metric: 1045
Hyperparams: {'batch_size': 60000, 'dt': 3.1622776e-06}., sample shape=(1, 79510). metric: 1288
Hyperparams: {'batch_size': 60000, 'dt': 1e-06}., sample shape=(1, 79510). metric: 1438
Hyperparams: {'batch_size': 60000, 'dt': 3.1622776e-07}., sample shape=(1, 79510). metric: 1362
Hyperparams: {'batch_size': 60000, 'dt': 1e-07}., sample shape=(1, 79510). me

HBox(children=(HTML(value='Iteration 2/3, 12 arms, time budget = 2.31 sec'), FloatProgress(value=0.0, max=12.0…

Hyperparams: {'batch_size': 600, 'dt': 3.1622776e-06}., sample shape=(131, 79510). metric: 297
Hyperparams: {'batch_size': 600, 'dt': 1e-06}., sample shape=(129, 79510). metric: 326
Hyperparams: {'batch_size': 600, 'dt': 3.1622776e-07}., sample shape=(118, 79510). metric: 452
Hyperparams: {'batch_size': 6000, 'dt': 3.1622776e-06}., sample shape=(20, 79510). metric: 433
Hyperparams: {'batch_size': 6000, 'dt': 1e-06}., sample shape=(17, 79510). metric: 521
Hyperparams: {'batch_size': 600, 'dt': 1e-07}., sample shape=(120, 79510). metric: 509
Hyperparams: {'batch_size': 6000, 'dt': 1e-05}., sample shape=(20, 79510). metric: 467
Hyperparams: {'batch_size': 600, 'dt': 1e-05}., sample shape=(97, 79510). metric: 566
Hyperparams: {'batch_size': 600, 'dt': 3.1622776e-08}., sample shape=(121, 79510). metric: 756
Hyperparams: {'batch_size': 60000, 'dt': 1e-05}., sample shape=(4, 79510). metric: 587
Hyperparams: {'batch_size': 60000, 'dt': 3.1622776e-06}., sample shape=(3, 79510). metric: 735
Hype

HBox(children=(HTML(value='Iteration 3/3, 4 arms, time budget = 6.92 sec'), FloatProgress(value=0.0, max=4.0),…

Hyperparams: {'batch_size': 600, 'dt': 3.1622776e-06}., sample shape=(391, 79510). metric: 193
Hyperparams: {'batch_size': 600, 'dt': 1e-06}., sample shape=(386, 79510). metric: 237
Hyperparams: {'batch_size': 6000, 'dt': 3.1622776e-06}., sample shape=(56, 79510). metric: 318
Hyperparams: {'batch_size': 600, 'dt': 3.1622776e-07}., sample shape=(348, 79510). metric: 326

Number of samples: [391]
Running time: 594.9 sec
{'batch_size': 600, 'dt': 3.1622776e-06} 193.17734 (391, 79510)


In [5]:
# SGLD with control variates centred at `centering_value`; same dt / batch-size grid as SGLD.
grid_params = dict(
    log_dt=-jnp.arange(2., 8., 0.5),
    batch_size=batch_size_range,
)


def build_kernel(dt, batch_size):
    """Return an SGLD-CV kernel for step size ``dt`` and minibatch size ``batch_size``."""
    return build_sgldCV_kernel(
        dt, loglikelihood, logprior, data, batch_size, centering_value
    )


best_arm = run_MAMBA(key, build_kernel, err_fn, R, params_IC,
                     grid_params=grid_params, get_fb_grads=get_fb_grads)

print(best_arm.hyperparameters, best_arm.metric, best_arm.samples.shape)

HBox(children=(HTML(value='Iteration 1/3, 36 arms, time budget = 0.77 sec'), FloatProgress(value=0.0, max=36.0…

Hyperparams: {'batch_size': 60000, 'dt': 0.01}., sample shape=(1, 79510). metric: 81039
Hyperparams: {'batch_size': 60000, 'dt': 0.0031622776}., sample shape=(1, 79510). metric: 107830
Hyperparams: {'batch_size': 60000, 'dt': 0.001}., sample shape=(1, 79510). metric: 82992
Hyperparams: {'batch_size': 60000, 'dt': 0.00031622776}., sample shape=(1, 79510). metric: 29394
Hyperparams: {'batch_size': 60000, 'dt': 1e-04}., sample shape=(1, 79510). metric: 6824
Hyperparams: {'batch_size': 60000, 'dt': 3.1622778e-05}., sample shape=(1, 79510). metric: 3284
Hyperparams: {'batch_size': 60000, 'dt': 1e-05}., sample shape=(1, 79510). metric: 1216
Hyperparams: {'batch_size': 60000, 'dt': 3.1622776e-06}., sample shape=(1, 79510). metric: 1247
Hyperparams: {'batch_size': 60000, 'dt': 1e-06}., sample shape=(1, 79510). metric: 1448
Hyperparams: {'batch_size': 60000, 'dt': 3.1622776e-07}., sample shape=(1, 79510). metric: 1361
Hyperparams: {'batch_size': 60000, 'dt': 1e-07}., sample shape=(1, 79510). me

HBox(children=(HTML(value='Iteration 2/3, 12 arms, time budget = 2.31 sec'), FloatProgress(value=0.0, max=12.0…

Hyperparams: {'batch_size': 600, 'dt': 1e-05}., sample shape=(88, 79510). metric: 317
Hyperparams: {'batch_size': 600, 'dt': 3.1622776e-06}., sample shape=(57, 79510). metric: 316
Hyperparams: {'batch_size': 600, 'dt': 1e-06}., sample shape=(79, 79510). metric: 359
Hyperparams: {'batch_size': 6000, 'dt': 1e-05}., sample shape=(13, 79510). metric: 406
Hyperparams: {'batch_size': 600, 'dt': 3.1622776e-07}., sample shape=(82, 79510). metric: 452
Hyperparams: {'batch_size': 6000, 'dt': 3.1622776e-06}., sample shape=(12, 79510). metric: 465
Hyperparams: {'batch_size': 6000, 'dt': 1e-06}., sample shape=(12, 79510). metric: 589
Hyperparams: {'batch_size': 600, 'dt': 1e-07}., sample shape=(78, 79510). metric: 559
Hyperparams: {'batch_size': 600, 'dt': 3.1622776e-08}., sample shape=(83, 79510). metric: 872
Hyperparams: {'batch_size': 60000, 'dt': 1e-05}., sample shape=(2, 79510). metric: 868
Hyperparams: {'batch_size': 60000, 'dt': 3.1622776e-06}., sample shape=(2, 79510). metric: 1005
Hyperpar

HBox(children=(HTML(value='Iteration 3/3, 4 arms, time budget = 6.92 sec'), FloatProgress(value=0.0, max=4.0),…

Hyperparams: {'batch_size': 600, 'dt': 3.1622776e-06}., sample shape=(226, 79510). metric: 174
Hyperparams: {'batch_size': 600, 'dt': 1e-05}., sample shape=(260, 79510). metric: 230
Hyperparams: {'batch_size': 600, 'dt': 1e-06}., sample shape=(252, 79510). metric: 253
Hyperparams: {'batch_size': 6000, 'dt': 1e-05}., sample shape=(37, 79510). metric: 294

Number of samples: [226]
Running time: 496.2 sec
{'batch_size': 600, 'dt': 3.1622776e-06} 174.00894 (226, 79510)


In [6]:
# SGHMC: the dt / batch-size grid is additionally crossed with L in {5, 10}.
grid_params = dict(
    log_dt=-jnp.arange(2., 8., 0.5),
    batch_size=batch_size_range,
    L=[5, 10],
)


def build_kernel(dt, L, batch_size):
    """Return an SGHMC kernel for step size ``dt``, ``L`` and minibatch size ``batch_size``."""
    return build_sghmc_kernel(dt, L, loglikelihood, logprior, data, batch_size)


best_arm = run_MAMBA(
    key, build_kernel, err_fn, R, params_IC,
    grid_params=grid_params, get_fb_grads=get_fb_grads,
)

print(best_arm.hyperparameters, best_arm.metric, best_arm.samples.shape)

HBox(children=(HTML(value='Iteration 1/3, 72 arms, time budget = 0.77 sec'), FloatProgress(value=0.0, max=72.0…

Hyperparams: {'batch_size': 60000, 'L': 5, 'dt': 0.01}., sample shape=(1, 79510). metric: 44773
Hyperparams: {'batch_size': 60000, 'L': 5, 'dt': 0.0031622776}., sample shape=(1, 79510). metric: 67544
Hyperparams: {'batch_size': 60000, 'L': 5, 'dt': 0.001}., sample shape=(1, 79510). metric: 68994
Hyperparams: {'batch_size': 60000, 'L': 5, 'dt': 0.00031622776}., sample shape=(1, 79510). metric: 80529
Hyperparams: {'batch_size': 60000, 'L': 5, 'dt': 1e-04}., sample shape=(1, 79510). metric: 144456
Hyperparams: {'batch_size': 60000, 'L': 5, 'dt': 3.1622778e-05}., sample shape=(1, 79510). metric: 16597
Hyperparams: {'batch_size': 60000, 'L': 5, 'dt': 1e-05}., sample shape=(1, 79510). metric: 4157
Hyperparams: {'batch_size': 60000, 'L': 5, 'dt': 3.1622776e-06}., sample shape=(1, 79510). metric: 1561
Hyperparams: {'batch_size': 60000, 'L': 5, 'dt': 1e-06}., sample shape=(1, 79510). metric: 1046
Hyperparams: {'batch_size': 60000, 'L': 5, 'dt': 3.1622776e-07}., sample shape=(1, 79510). metric: 

HBox(children=(HTML(value='Iteration 2/3, 24 arms, time budget = 2.31 sec'), FloatProgress(value=0.0, max=24.0…

Hyperparams: {'batch_size': 600, 'L': 5, 'dt': 3.1622776e-07}., sample shape=(30, 79510). metric: 419
Hyperparams: {'batch_size': 600, 'L': 10, 'dt': 1e-07}., sample shape=(15, 79510). metric: 448
Hyperparams: {'batch_size': 600, 'L': 10, 'dt': 3.1622776e-08}., sample shape=(15, 79510). metric: 515
Hyperparams: {'batch_size': 600, 'L': 5, 'dt': 1e-07}., sample shape=(26, 79510). metric: 501
Hyperparams: {'batch_size': 600, 'L': 5, 'dt': 1e-06}., sample shape=(26, 79510). metric: 543
Hyperparams: {'batch_size': 6000, 'L': 5, 'dt': 1e-06}., sample shape=(5, 79510). metric: 730
Hyperparams: {'batch_size': 6000, 'L': 5, 'dt': 3.1622776e-07}., sample shape=(6, 79510). metric: 576
Hyperparams: {'batch_size': 600, 'L': 10, 'dt': 3.1622776e-07}., sample shape=(14, 79510). metric: 621
Hyperparams: {'batch_size': 60000, 'L': 5, 'dt': 1e-06}., sample shape=(2, 79510). metric: 832
Hyperparams: {'batch_size': 60000, 'L': 5, 'dt': 3.1622776e-07}., sample shape=(2, 79510). metric: 858
Hyperparams: {'

HBox(children=(HTML(value='Iteration 3/3, 8 arms, time budget = 6.92 sec'), FloatProgress(value=0.0, max=8.0),…

Hyperparams: {'batch_size': 600, 'L': 5, 'dt': 3.1622776e-07}., sample shape=(87, 79510). metric: 310
Hyperparams: {'batch_size': 600, 'L': 10, 'dt': 1e-07}., sample shape=(46, 79510). metric: 326
Hyperparams: {'batch_size': 600, 'L': 5, 'dt': 1e-07}., sample shape=(85, 79510). metric: 377
Hyperparams: {'batch_size': 600, 'L': 10, 'dt': 3.1622776e-08}., sample shape=(45, 79510). metric: 377
Hyperparams: {'batch_size': 600, 'L': 5, 'dt': 1e-06}., sample shape=(84, 79510). metric: 349
Hyperparams: {'batch_size': 6000, 'L': 5, 'dt': 3.1622776e-07}., sample shape=(16, 79510). metric: 530
Hyperparams: {'batch_size': 600, 'L': 10, 'dt': 3.1622776e-07}., sample shape=(44, 79510). metric: 401
Hyperparams: {'batch_size': 6000, 'L': 5, 'dt': 1e-07}., sample shape=(15, 79510). metric: 553

Number of samples: [87, 46]
Running time: 847.9 sec
{'batch_size': 600, 'L': 5, 'dt': 3.1622776e-07} 309.57224 (87, 79510)


In [7]:
# SGHMC with control variates centred at `centering_value`; same grid as SGHMC.
grid_params = {
    'log_dt': -jnp.arange(2., 8., 0.5),
    'batch_size': batch_size_range,
    "L": [5, 10],
}


def build_kernel(dt, L, batch_size):
    """Return an SGHMC-CV kernel for step size ``dt``, ``L`` and minibatch size ``batch_size``."""
    return build_sghmcCV_kernel(
        dt, L, loglikelihood, logprior, data, batch_size, centering_value
    )


best_arm = run_MAMBA(key, build_kernel, err_fn, R, params_IC,
                     grid_params=grid_params,
                     get_fb_grads=get_fb_grads)

print(best_arm.hyperparameters, best_arm.metric, best_arm.samples.shape)

HBox(children=(HTML(value='Iteration 1/3, 72 arms, time budget = 0.77 sec'), FloatProgress(value=0.0, max=72.0…

Hyperparams: {'batch_size': 60000, 'L': 5, 'dt': 0.01}., sample shape=(1, 79510). metric: 81607
Hyperparams: {'batch_size': 60000, 'L': 5, 'dt': 0.0031622776}., sample shape=(1, 79510). metric: 47814
Hyperparams: {'batch_size': 60000, 'L': 5, 'dt': 0.001}., sample shape=(1, 79510). metric: 65174
Hyperparams: {'batch_size': 60000, 'L': 5, 'dt': 0.00031622776}., sample shape=(1, 79510). metric: 80537
Hyperparams: {'batch_size': 60000, 'L': 5, 'dt': 1e-04}., sample shape=(1, 79510). metric: 136342
Hyperparams: {'batch_size': 60000, 'L': 5, 'dt': 3.1622778e-05}., sample shape=(1, 79510). metric: 29428
Hyperparams: {'batch_size': 60000, 'L': 5, 'dt': 1e-05}., sample shape=(1, 79510). metric: 3973
Hyperparams: {'batch_size': 60000, 'L': 5, 'dt': 3.1622776e-06}., sample shape=(1, 79510). metric: 1684
Hyperparams: {'batch_size': 60000, 'L': 5, 'dt': 1e-06}., sample shape=(1, 79510). metric: 1065
Hyperparams: {'batch_size': 60000, 'L': 5, 'dt': 3.1622776e-07}., sample shape=(1, 79510). metric: 

HBox(children=(HTML(value='Iteration 2/3, 24 arms, time budget = 2.31 sec'), FloatProgress(value=0.0, max=24.0…

Hyperparams: {'batch_size': 600, 'L': 5, 'dt': 1e-06}., sample shape=(19, 79510). metric: 373
Hyperparams: {'batch_size': 600, 'L': 5, 'dt': 3.1622776e-07}., sample shape=(19, 79510). metric: 391
Hyperparams: {'batch_size': 600, 'L': 10, 'dt': 1e-07}., sample shape=(10, 79510). metric: 518
Hyperparams: {'batch_size': 600, 'L': 10, 'dt': 3.1622776e-07}., sample shape=(10, 79510). metric: 422
Hyperparams: {'batch_size': 600, 'L': 5, 'dt': 1e-07}., sample shape=(18, 79510). metric: 513
Hyperparams: {'batch_size': 600, 'L': 10, 'dt': 3.1622776e-08}., sample shape=(10, 79510). metric: 566
Hyperparams: {'batch_size': 6000, 'L': 5, 'dt': 3.1622776e-06}., sample shape=(3, 79510). metric: 8588
Hyperparams: {'batch_size': 6000, 'L': 5, 'dt': 1e-06}., sample shape=(3, 79510). metric: 793
Hyperparams: {'batch_size': 6000, 'L': 10, 'dt': 3.1622776e-07}., sample shape=(2, 79510). metric: 811
Hyperparams: {'batch_size': 60000, 'L': 5, 'dt': 1e-06}., sample shape=(2, 79510). metric: 992
Hyperparams: {

HBox(children=(HTML(value='Iteration 3/3, 8 arms, time budget = 6.92 sec'), FloatProgress(value=0.0, max=8.0),…

Hyperparams: {'batch_size': 600, 'L': 5, 'dt': 1e-06}., sample shape=(58, 79510). metric: 258
Hyperparams: {'batch_size': 600, 'L': 5, 'dt': 3.1622776e-07}., sample shape=(58, 79510). metric: 300
Hyperparams: {'batch_size': 600, 'L': 10, 'dt': 3.1622776e-07}., sample shape=(30, 79510). metric: 311
Hyperparams: {'batch_size': 600, 'L': 5, 'dt': 1e-07}., sample shape=(57, 79510). metric: 379
Hyperparams: {'batch_size': 600, 'L': 10, 'dt': 1e-07}., sample shape=(30, 79510). metric: 339
Hyperparams: {'batch_size': 600, 'L': 10, 'dt': 3.1622776e-08}., sample shape=(30, 79510). metric: 409
Hyperparams: {'batch_size': 6000, 'L': 5, 'dt': 3.1622776e-07}., sample shape=(9, 79510). metric: 489
Hyperparams: {'batch_size': 6000, 'L': 5, 'dt': 1e-07}., sample shape=(9, 79510). metric: 603

Number of samples: [58, 58]
Running time: 996.2 sec
{'batch_size': 600, 'L': 5, 'dt': 1e-06} 258.42096 (58, 79510)


In [5]:
# SGNHT: same dt / batch-size grid as SGLD.
grid_params = dict(
    log_dt=-jnp.arange(2., 8., 0.5),
    batch_size=batch_size_range,
)


def build_kernel(dt, batch_size):
    """Return an SGNHT kernel for step size ``dt`` and minibatch size ``batch_size``."""
    return build_sgnht_kernel(dt, loglikelihood, logprior, data, batch_size)


best_arm = run_MAMBA(
    key, build_kernel, err_fn, R, params_IC,
    grid_params=grid_params,
    get_fb_grads=get_fb_grads,
)

print(best_arm.hyperparameters, best_arm.metric, best_arm.samples.shape)

HBox(children=(HTML(value='Iteration 1/3, 36 arms, time budget = 0.77 sec'), FloatProgress(value=0.0, max=36.0…

Hyperparams: {'batch_size': 60000, 'dt': 0.01}., sample shape=(1, 79510). metric: 81348
Hyperparams: {'batch_size': 60000, 'dt': 0.0031622776}., sample shape=(1, 79510). metric: 105126
Hyperparams: {'batch_size': 60000, 'dt': 0.001}., sample shape=(1, 79510). metric: 81248
Hyperparams: {'batch_size': 60000, 'dt': 0.00031622776}., sample shape=(1, 79510). metric: 32771
Hyperparams: {'batch_size': 60000, 'dt': 1e-04}., sample shape=(1, 79510). metric: 8883
Hyperparams: {'batch_size': 60000, 'dt': 3.1622778e-05}., sample shape=(1, 79510). metric: 1058
Hyperparams: {'batch_size': 60000, 'dt': 1e-05}., sample shape=(1, 79510). metric: 859
Hyperparams: {'batch_size': 60000, 'dt': 3.1622776e-06}., sample shape=(1, 79510). metric: 1489
Hyperparams: {'batch_size': 60000, 'dt': 1e-06}., sample shape=(1, 79510). metric: 1418
Hyperparams: {'batch_size': 60000, 'dt': 3.1622776e-07}., sample shape=(1, 79510). metric: 1410
Hyperparams: {'batch_size': 60000, 'dt': 1e-07}., sample shape=(1, 79510). met

HBox(children=(HTML(value='Iteration 2/3, 12 arms, time budget = 2.31 sec'), FloatProgress(value=0.0, max=12.0…

Hyperparams: {'batch_size': 600, 'dt': 3.1622776e-08}., sample shape=(122, 79510). metric: 429
Hyperparams: {'batch_size': 6000, 'dt': 3.1622776e-07}., sample shape=(21, 79510). metric: 470
Hyperparams: {'batch_size': 6000, 'dt': 1e-06}., sample shape=(18, 79510). metric: 634
Hyperparams: {'batch_size': 6000, 'dt': 1e-07}., sample shape=(19, 79510). metric: 513
Hyperparams: {'batch_size': 600, 'dt': 1e-07}., sample shape=(115, 79510). metric: 468
Hyperparams: {'batch_size': 600, 'dt': 3.1622776e-07}., sample shape=(125, 79510). metric: 569
Hyperparams: {'batch_size': 60000, 'dt': 1e-05}., sample shape=(4, 79510). metric: 670
Hyperparams: {'batch_size': 6000, 'dt': 3.1622776e-08}., sample shape=(19, 79510). metric: 601
Hyperparams: {'batch_size': 60000, 'dt': 3.1622778e-05}., sample shape=(4, 79510). metric: 1094
Hyperparams: {'batch_size': 600, 'dt': 3.1622778e-05}., sample shape=(106, 79510). metric: 251
Hyperparams: {'batch_size': 600, 'dt': 1e-06}., sample shape=(125, 79510). metric

HBox(children=(HTML(value='Iteration 3/3, 4 arms, time budget = 6.92 sec'), FloatProgress(value=0.0, max=4.0),…

Hyperparams: {'batch_size': 600, 'dt': 3.1622778e-05}., sample shape=(373, 79510). metric: 133
Hyperparams: {'batch_size': 600, 'dt': 1e-05}., sample shape=(402, 79510). metric: 145
Hyperparams: {'batch_size': 600, 'dt': 3.1622776e-08}., sample shape=(397, 79510). metric: 272
Hyperparams: {'batch_size': 600, 'dt': 1e-07}., sample shape=(374, 79510). metric: 299

Number of samples: [373]
Running time: 1003.7 sec
{'batch_size': 600, 'dt': 3.1622778e-05} 132.67917 (373, 79510)


In [8]:
# SGNHT with control variates centred at `centering_value`; same grid as SGNHT.
grid_params = {
    'log_dt': -jnp.arange(2., 8., 0.5),
    'batch_size': batch_size_range,
}


def build_kernel(dt, batch_size):
    """Return an SGNHT-CV kernel for step size ``dt`` and minibatch size ``batch_size``."""
    return build_sgnhtCV_kernel(
        dt, loglikelihood, logprior, data, batch_size, centering_value
    )


best_arm = run_MAMBA(key, build_kernel, err_fn, R, params_IC,
                     grid_params=grid_params, get_fb_grads=get_fb_grads)

print(best_arm.hyperparameters, best_arm.metric, best_arm.samples.shape)

HBox(children=(HTML(value='Iteration 1/3, 36 arms, time budget = 0.77 sec'), FloatProgress(value=0.0, max=36.0…

Hyperparams: {'batch_size': 60000, 'dt': 0.01}., sample shape=(1, 79510). metric: 81125
Hyperparams: {'batch_size': 60000, 'dt': 0.0031622776}., sample shape=(1, 79510). metric: 115621
Hyperparams: {'batch_size': 60000, 'dt': 0.001}., sample shape=(1, 79510). metric: 80055
Hyperparams: {'batch_size': 60000, 'dt': 0.00031622776}., sample shape=(1, 79510). metric: 32264
Hyperparams: {'batch_size': 60000, 'dt': 1e-04}., sample shape=(1, 79510). metric: 7548
Hyperparams: {'batch_size': 60000, 'dt': 3.1622778e-05}., sample shape=(1, 79510). metric: 1961
Hyperparams: {'batch_size': 60000, 'dt': 1e-05}., sample shape=(1, 79510). metric: 877
Hyperparams: {'batch_size': 60000, 'dt': 3.1622776e-06}., sample shape=(1, 79510). metric: 1441
Hyperparams: {'batch_size': 60000, 'dt': 1e-06}., sample shape=(1, 79510). metric: 1429
Hyperparams: {'batch_size': 60000, 'dt': 3.1622776e-07}., sample shape=(1, 79510). metric: 1408
Hyperparams: {'batch_size': 60000, 'dt': 1e-07}., sample shape=(1, 79510). met

HBox(children=(HTML(value='Iteration 2/3, 12 arms, time budget = 2.31 sec'), FloatProgress(value=0.0, max=12.0…

Hyperparams: {'batch_size': 600, 'dt': 3.1622776e-07}., sample shape=(86, 79510). metric: 384
Hyperparams: {'batch_size': 600, 'dt': 1e-07}., sample shape=(85, 79510). metric: 290
Hyperparams: {'batch_size': 600, 'dt': 3.1622776e-08}., sample shape=(85, 79510). metric: 356
Hyperparams: {'batch_size': 6000, 'dt': 1e-05}., sample shape=(13, 79510). metric: 1336
Hyperparams: {'batch_size': 6000, 'dt': 3.1622776e-06}., sample shape=(12, 79510). metric: 525
Hyperparams: {'batch_size': 6000, 'dt': 3.1622776e-07}., sample shape=(12, 79510). metric: 454
Hyperparams: {'batch_size': 600, 'dt': 1e-06}., sample shape=(78, 79510). metric: 718
Hyperparams: {'batch_size': 6000, 'dt': 1e-06}., sample shape=(12, 79510). metric: 416
Hyperparams: {'batch_size': 6000, 'dt': 1e-07}., sample shape=(12, 79510). metric: 578
Hyperparams: {'batch_size': 60000, 'dt': 1e-05}., sample shape=(2, 79510). metric: 858
Hyperparams: {'batch_size': 6000, 'dt': 3.1622776e-08}., sample shape=(12, 79510). metric: 765
Hyperp

HBox(children=(HTML(value='Iteration 3/3, 4 arms, time budget = 6.92 sec'), FloatProgress(value=0.0, max=4.0),…

Hyperparams: {'batch_size': 600, 'dt': 1e-07}., sample shape=(261, 79510). metric: 241
Hyperparams: {'batch_size': 600, 'dt': 3.1622776e-08}., sample shape=(268, 79510). metric: 225
Hyperparams: {'batch_size': 600, 'dt': 3.1622776e-07}., sample shape=(270, 79510). metric: 296
Hyperparams: {'batch_size': 6000, 'dt': 1e-06}., sample shape=(38, 79510). metric: 342

Number of samples: [268]
Running time: 818.8 sec
{'batch_size': 600, 'dt': 3.1622776e-08} 224.86502 (268, 79510)
