In [1]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np

import jax.numpy as jnp
from jax import jit, random, vmap

from sgmcmcjax.kernels import build_sgld_kernel, build_sghmc_kernel, build_sgnht_kernel
from sgmcmcjax.kernels import build_sgldCV_kernel, build_sghmcCV_kernel, build_sgnhtCV_kernel
from sgmcmcjax.optimizer import build_adam_optimizer
from sgmcmcjax.util import build_grad_log_post

import context
from models.logistic_regression.logistic_regression_model import gen_data, loglikelihood, logprior
from tuning.mamba import run_MAMBA
from tuning.ksd import FSSD_opt, imq_KSD, get_test_locations, linear_imq_KSD




# Logistic regression - MAMBA

FSSD with a long running time (R=10 rather than 3)

In [2]:


key = random.PRNGKey(42)
dim = 10
# Ndata = 100_000
Ndata = 1_000_000

theta_true, X, y_data = gen_data(key, dim, Ndata)

# testing and training data
num_train = int(Ndata*0.8)

X_train = X[:num_train]
X_test = X[num_train:]

y_train = y_data[:num_train]
y_test = y_data[num_train:]
print(X_train.shape, X_test.shape)
data = (X_train, y_train)


generating data, with N=1,000,000 and dim=10
(800000, 10) (200000, 10)


In [3]:
# get MAP
batch_size = int(0.01*X_train.shape[0])
run_adam = build_adam_optimizer(1e-2, loglikelihood, logprior, data, batch_size)
key = random.PRNGKey(0)
centering_value, logpost_array = run_adam(key, 5000, jnp.zeros(dim))
params_IC = centering_value

fixed_bs = int(0.1*X_train.shape[0])

# T = 10
R = 5 # running time of longest sampler

error_fn = lambda x,y: FSSD_opt(x, y, get_test_locations(x), 100)

logdt_range = -jnp.arange(1., 8., 0.5) 
# batch_size_range = [int(10**(-elem)*X_train.shape[0]) for elem in range(0, 5)]
batch_size_range = [int(10**(-elem)*X_train.shape[0]) for elem in range(0, 3)]

print(batch_size_range)

[800000, 80000, 8000]


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=5000.0), HTML(value='')))

In [4]:
grad_log_post_fb = build_grad_log_post(loglikelihood, logprior, data)

def grad_lp_LR_fb(theta):
    return grad_log_post_fb(theta, X_train, y_train)

batch_grad_lp_LR_fb = jit(vmap(grad_lp_LR_fb))

@jit
def get_fb_grads(samples):
    thin = 10
    array_samples = jnp.array(samples[::thin])
    mygrads = batch_grad_lp_LR_fb(array_samples)
    return array_samples, mygrads

### run MAMBA

In [5]:
# diag sds, another run
build_kernel = lambda dt, batch_size: build_sgld_kernel(dt, loglikelihood, logprior, data, batch_size)

grid_params = {'log_dt': logdt_range, 
               'batch_size': batch_size_range
              }

best_arm = run_MAMBA(key, build_kernel, error_fn, R, params_IC, grid_params=grid_params, get_fb_grads=get_fb_grads)

print(best_arm.hyperparameters, best_arm.metric, best_arm.samples.shape)




HBox(children=(HTML(value='Iteration 1/3, 42 arms, time budget = 0.38 sec'), FloatProgress(value=0.0, max=42.0…

Hyperparams: {'batch_size': 800000, 'dt': 0.1}., sample shape=(5, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'dt': 0.031622775}., sample shape=(5, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'dt': 0.01}., sample shape=(5, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'dt': 0.0031622776}., sample shape=(5, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'dt': 0.001}., sample shape=(5, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'dt': 0.00031622776}., sample shape=(5, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'dt': 1e-04}., sample shape=(5, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'dt': 3.1622778e-05}., sample shape=(5, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'dt': 1e-05}., sample shape=(5, 10). metric: 21218
Hyperparams: {'batch_size': 800000, 'dt': 3.1622776e-06}., sample shape=(5, 10). metric: -13107
Hyperparams: {'batch_size': 800000, 'dt': 1e-06}., sample shape=(5, 10). metric: 125779
Hyperparams: {'batch_si

HBox(children=(HTML(value='Iteration 2/3, 14 arms, time budget = 1.15 sec'), FloatProgress(value=0.0, max=14.0…

Hyperparams: {'batch_size': 800000, 'dt': 3.1622776e-06}., sample shape=(20, 10). metric: 2509
Hyperparams: {'batch_size': 8000, 'dt': 1e-05}., sample shape=(329, 10). metric: -3597
Hyperparams: {'batch_size': 80000, 'dt': 3.1622776e-06}., sample shape=(54, 10). metric: -196
Hyperparams: {'batch_size': 80000, 'dt': 0.1}., sample shape=(51, 10). metric: 5991
Hyperparams: {'batch_size': 8000, 'dt': 0.1}., sample shape=(322, 10). metric: 472
Hyperparams: {'batch_size': 8000, 'dt': 0.031622775}., sample shape=(323, 10). metric: 3352
Hyperparams: {'batch_size': 80000, 'dt': 1e-05}., sample shape=(54, 10). metric: -1446
Hyperparams: {'batch_size': 80000, 'dt': 1e-06}., sample shape=(54, 10). metric: 5109
Hyperparams: {'batch_size': 8000, 'dt': 1e-06}., sample shape=(336, 10). metric: 2071
Hyperparams: {'batch_size': 8000, 'dt': 3.1622776e-06}., sample shape=(332, 10). metric: -416
Hyperparams: {'batch_size': 8000, 'dt': 0.01}., sample shape=(297, 10). metric: 24503
Hyperparams: {'batch_size'

HBox(children=(HTML(value='Iteration 3/3, 4 arms, time budget = 4.04 sec'), FloatProgress(value=0.0, max=4.0),…

Hyperparams: {'batch_size': 8000, 'dt': 1e-05}., sample shape=(1175, 10). metric: -359
Hyperparams: {'batch_size': 800000, 'dt': 1e-05}., sample shape=(70, 10). metric: 187
Hyperparams: {'batch_size': 80000, 'dt': 1e-05}., sample shape=(180, 10). metric: 631
Hyperparams: {'batch_size': 8000, 'dt': 3.1622776e-06}., sample shape=(1023, 10). metric: 433

Number of samples: [1175]
Running time: 215.9 sec
{'batch_size': 8000, 'dt': 1e-05} -358.98062 (1175, 10)


In [7]:
build_kernel = lambda dt, batch_size: build_sgldCV_kernel(dt, loglikelihood, 
                                                  logprior, data, batch_size, centering_value)

best_arm = run_MAMBA(key, build_kernel, error_fn, R, params_IC, 
                     grid_params=grid_params, get_fb_grads=get_fb_grads)

print(best_arm.hyperparameters, best_arm.metric, best_arm.samples.shape)

HBox(children=(HTML(value='Iteration 1/3, 42 arms, time budget = 0.38 sec'), FloatProgress(value=0.0, max=42.0…

Hyperparams: {'batch_size': 800000, 'dt': 0.1}., sample shape=(2, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'dt': 0.031622775}., sample shape=(2, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'dt': 0.01}., sample shape=(2, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'dt': 0.0031622776}., sample shape=(2, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'dt': 0.001}., sample shape=(2, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'dt': 0.00031622776}., sample shape=(2, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'dt': 1e-04}., sample shape=(2, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'dt': 3.1622778e-05}., sample shape=(2, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'dt': 1e-05}., sample shape=(2, 10). metric: 46300
Hyperparams: {'batch_size': 800000, 'dt': 3.1622776e-06}., sample shape=(2, 10). metric: -84846
Hyperparams: {'batch_size': 800000, 'dt': 1e-06}., sample shape=(2, 10). metric: 385410
Hyperparams: {'batch_si

HBox(children=(HTML(value='Iteration 2/3, 14 arms, time budget = 1.15 sec'), FloatProgress(value=0.0, max=14.0…

Hyperparams: {'batch_size': 800000, 'dt': 3.1622776e-06}., sample shape=(7, 10). metric: 14374
Hyperparams: {'batch_size': 80000, 'dt': 1e-05}., sample shape=(46, 10). metric: -490
Hyperparams: {'batch_size': 80000, 'dt': 0.1}., sample shape=(43, 10). metric: 7443
Hyperparams: {'batch_size': 8000, 'dt': 1e-05}., sample shape=(280, 10). metric: -103
Hyperparams: {'batch_size': 8000, 'dt': 0.1}., sample shape=(286, 10). metric: 885
Hyperparams: {'batch_size': 8000, 'dt': 3.1622776e-06}., sample shape=(293, 10). metric: 235
Hyperparams: {'batch_size': 80000, 'dt': 0.031622775}., sample shape=(46, 10). metric: 51450
Hyperparams: {'batch_size': 80000, 'dt': 3.1622776e-06}., sample shape=(45, 10). metric: 2307
Hyperparams: {'batch_size': 8000, 'dt': 0.031622775}., sample shape=(287, 10). metric: 4636
Hyperparams: {'batch_size': 8000, 'dt': 1e-06}., sample shape=(294, 10). metric: 575
Hyperparams: {'batch_size': 80000, 'dt': 0.01}., sample shape=(46, 10). metric: 588004
Hyperparams: {'batch_s

HBox(children=(HTML(value='Iteration 3/3, 4 arms, time budget = 4.04 sec'), FloatProgress(value=0.0, max=4.0),…

Hyperparams: {'batch_size': 80000, 'dt': 1e-05}., sample shape=(158, 10). metric: -199
Hyperparams: {'batch_size': 8000, 'dt': 1e-05}., sample shape=(1010, 10). metric: 39
Hyperparams: {'batch_size': 8000, 'dt': 3.1622776e-06}., sample shape=(1057, 10). metric: 22
Hyperparams: {'batch_size': 8000, 'dt': 1e-06}., sample shape=(1051, 10). metric: 89

Number of samples: [158]
Running time: 196.1 sec
{'batch_size': 80000, 'dt': 1e-05} -199.04362 (158, 10)


In [8]:
build_kernel = lambda dt, L, batch_size: build_sghmc_kernel(dt, L, loglikelihood, 
                                                  logprior, data, batch_size)
grid_params = {'log_dt': logdt_range, 
               'batch_size': batch_size_range,
               "L": [5, 10]
              }

best_arm = run_MAMBA(key, build_kernel, error_fn, R, params_IC, 
                     grid_params=grid_params, get_fb_grads=get_fb_grads)

print(best_arm.hyperparameters, best_arm.metric, best_arm.samples.shape)

HBox(children=(HTML(value='Iteration 1/4, 84 arms, time budget = 0.12 sec'), FloatProgress(value=0.0, max=84.0…

Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 0.1}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 0.031622775}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 0.01}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 0.0031622776}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 0.001}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 0.00031622776}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 1e-04}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 3.1622778e-05}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 1e-05}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 3.1622776e-06}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 800000, 

HBox(children=(HTML(value='Iteration 2/4, 28 arms, time budget = 0.38 sec'), FloatProgress(value=0.0, max=28.0…

Hyperparams: {'batch_size': 80000, 'L': 5, 'dt': 1e-06}., sample shape=(6, 10). metric: -105
Hyperparams: {'batch_size': 8000, 'L': 10, 'dt': 3.1622776e-07}., sample shape=(16, 10). metric: -10921
Hyperparams: {'batch_size': 8000, 'L': 10, 'dt': 1e-07}., sample shape=(15, 10). metric: 8377
Hyperparams: {'batch_size': 8000, 'L': 5, 'dt': 3.1622776e-07}., sample shape=(29, 10). metric: 192
Hyperparams: {'batch_size': 8000, 'L': 5, 'dt': 1e-06}., sample shape=(28, 10). metric: -12404
Hyperparams: {'batch_size': 8000, 'L': 5, 'dt': 1e-07}., sample shape=(29, 10). metric: 8272
Hyperparams: {'batch_size': 8000, 'L': 10, 'dt': 3.1622776e-08}., sample shape=(15, 10). metric: 9397
Hyperparams: {'batch_size': 80000, 'L': 5, 'dt': 3.1622776e-07}., sample shape=(6, 10). metric: 31977
Hyperparams: {'batch_size': 8000, 'L': 5, 'dt': 3.1622776e-08}., sample shape=(28, 10). metric: 26094
Hyperparams: {'batch_size': 80000, 'L': 5, 'dt': 1e-07}., sample shape=(6, 10). metric: 98930
Hyperparams: {'batch_

HBox(children=(HTML(value='Iteration 3/4, 9 arms, time budget = 1.17 sec'), FloatProgress(value=0.0, max=9.0),…

Hyperparams: {'batch_size': 8000, 'L': 5, 'dt': 1e-06}., sample shape=(90, 10). metric: -2780
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 3.1622776e-07}., sample shape=(7, 10). metric: 13383
Hyperparams: {'batch_size': 8000, 'L': 10, 'dt': 3.1622776e-07}., sample shape=(48, 10). metric: 26623
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 1e-06}., sample shape=(7, 10). metric: -59
Hyperparams: {'batch_size': 80000, 'L': 5, 'dt': 1e-06}., sample shape=(17, 10). metric: 6101
Hyperparams: {'batch_size': 8000, 'L': 5, 'dt': 3.1622776e-07}., sample shape=(91, 10). metric: 534
Hyperparams: {'batch_size': 8000, 'L': 5, 'dt': 1e-07}., sample shape=(95, 10). metric: 1753
Hyperparams: {'batch_size': 8000, 'L': 10, 'dt': 1e-07}., sample shape=(47, 10). metric: 720
Hyperparams: {'batch_size': 8000, 'L': 10, 'dt': 3.1622776e-08}., sample shape=(46, 10). metric: 7133

Number of samples: [90, 7, 91]


HBox(children=(HTML(value='Iteration 4/4, 3 arms, time budget = 3.50 sec'), FloatProgress(value=0.0, max=3.0),…

Hyperparams: {'batch_size': 8000, 'L': 5, 'dt': 1e-06}., sample shape=(275, 10). metric: 673
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 1e-06}., sample shape=(18, 10). metric: 3364
Hyperparams: {'batch_size': 8000, 'L': 5, 'dt': 3.1622776e-07}., sample shape=(290, 10). metric: -364

Number of samples: [290]
Running time: 214.0 sec
{'batch_size': 8000, 'L': 5, 'dt': 3.1622776e-07} -363.97076 (290, 10)


In [9]:
build_kernel = lambda dt, L, batch_size: build_sghmcCV_kernel(dt, L, loglikelihood, 
                                                  logprior, data, batch_size, centering_value)
grid_params = {'log_dt': logdt_range, 
               'batch_size': batch_size_range,
               "L": [5, 10]
              }

best_arm = run_MAMBA(key, build_kernel, error_fn, R, params_IC, 
                     grid_params=grid_params, get_fb_grads=get_fb_grads)

print(best_arm.hyperparameters, best_arm.metric, best_arm.samples.shape)

HBox(children=(HTML(value='Iteration 1/4, 84 arms, time budget = 0.12 sec'), FloatProgress(value=0.0, max=84.0…

Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 0.1}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 0.031622775}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 0.01}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 0.0031622776}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 0.001}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 0.00031622776}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 1e-04}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 3.1622778e-05}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 1e-05}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 3.1622776e-06}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 800000, 

HBox(children=(HTML(value='Iteration 2/4, 28 arms, time budget = 0.38 sec'), FloatProgress(value=0.0, max=28.0…

Hyperparams: {'batch_size': 8000, 'L': 10, 'dt': 1e-07}., sample shape=(13, 10). metric: 8174
Hyperparams: {'batch_size': 8000, 'L': 5, 'dt': 1e-06}., sample shape=(24, 10). metric: -143
Hyperparams: {'batch_size': 8000, 'L': 10, 'dt': 3.1622776e-07}., sample shape=(14, 10). metric: -865
Hyperparams: {'batch_size': 8000, 'L': 5, 'dt': 3.1622776e-07}., sample shape=(25, 10). metric: 1943
Hyperparams: {'batch_size': 8000, 'L': 5, 'dt': 1e-07}., sample shape=(23, 10). metric: 5802
Hyperparams: {'batch_size': 8000, 'L': 10, 'dt': 3.1622776e-08}., sample shape=(14, 10). metric: 18578
Hyperparams: {'batch_size': 8000, 'L': 5, 'dt': 3.1622776e-08}., sample shape=(24, 10). metric: 19621
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 0.1}., sample shape=(2, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 0.031622775}., sample shape=(2, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 0.01}., sample shape=(2, 10). metric: inf
Hyperparams: {'batch_size': 80

HBox(children=(HTML(value='Iteration 3/4, 9 arms, time budget = 1.17 sec'), FloatProgress(value=0.0, max=9.0),…

Hyperparams: {'batch_size': 8000, 'L': 10, 'dt': 3.1622776e-07}., sample shape=(41, 10). metric: 131
Hyperparams: {'batch_size': 8000, 'L': 5, 'dt': 1e-06}., sample shape=(79, 10). metric: -337
Hyperparams: {'batch_size': 8000, 'L': 5, 'dt': 3.1622776e-07}., sample shape=(81, 10). metric: 493
Hyperparams: {'batch_size': 8000, 'L': 5, 'dt': 1e-07}., sample shape=(79, 10). metric: 5967
Hyperparams: {'batch_size': 8000, 'L': 10, 'dt': 1e-07}., sample shape=(42, 10). metric: 975
Hyperparams: {'batch_size': 8000, 'L': 10, 'dt': 3.1622776e-08}., sample shape=(42, 10). metric: 2607
Hyperparams: {'batch_size': 8000, 'L': 5, 'dt': 3.1622776e-08}., sample shape=(76, 10). metric: 8059
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 0.1}., sample shape=(3, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 0.031622775}., sample shape=(4, 10). metric: inf

Number of samples: [79, 41, 81]


HBox(children=(HTML(value='Iteration 4/4, 3 arms, time budget = 3.50 sec'), FloatProgress(value=0.0, max=3.0),…

Hyperparams: {'batch_size': 8000, 'L': 5, 'dt': 1e-06}., sample shape=(242, 10). metric: 335
Hyperparams: {'batch_size': 8000, 'L': 10, 'dt': 3.1622776e-07}., sample shape=(119, 10). metric: 35
Hyperparams: {'batch_size': 8000, 'L': 5, 'dt': 3.1622776e-07}., sample shape=(231, 10). metric: -27

Number of samples: [231]
Running time: 257.8 sec
{'batch_size': 8000, 'L': 5, 'dt': 3.1622776e-07} -26.962803 (231, 10)


In [10]:
build_kernel = lambda dt, batch_size: build_sgnht_kernel(dt, loglikelihood, logprior, data, batch_size)
grid_params = {'log_dt': logdt_range, 
               'batch_size': batch_size_range
              }

best_arm = run_MAMBA(key, build_kernel, error_fn, R, params_IC, 
                     grid_params=grid_params, get_fb_grads=get_fb_grads)

print(best_arm.hyperparameters, best_arm.metric, best_arm.samples.shape)

HBox(children=(HTML(value='Iteration 1/3, 42 arms, time budget = 0.38 sec'), FloatProgress(value=0.0, max=42.0…

Hyperparams: {'batch_size': 800000, 'dt': 0.1}., sample shape=(5, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'dt': 0.031622775}., sample shape=(5, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'dt': 0.01}., sample shape=(5, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'dt': 0.0031622776}., sample shape=(5, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'dt': 0.001}., sample shape=(5, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'dt': 0.00031622776}., sample shape=(5, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'dt': 1e-04}., sample shape=(5, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'dt': 3.1622778e-05}., sample shape=(5, 10). metric: -259258
Hyperparams: {'batch_size': 800000, 'dt': 1e-05}., sample shape=(5, 10). metric: -100993
Hyperparams: {'batch_size': 800000, 'dt': 3.1622776e-06}., sample shape=(5, 10). metric: -85543
Hyperparams: {'batch_size': 800000, 'dt': 1e-06}., sample shape=(5, 10). metric: -109971
Hyperparams: {'b

HBox(children=(HTML(value='Iteration 2/3, 14 arms, time budget = 1.15 sec'), FloatProgress(value=0.0, max=14.0…

Hyperparams: {'batch_size': 800000, 'dt': 3.1622778e-05}., sample shape=(19, 10). metric: -9396
Hyperparams: {'batch_size': 800000, 'dt': 1e-06}., sample shape=(20, 10). metric: -10340
Hyperparams: {'batch_size': 800000, 'dt': 1e-05}., sample shape=(20, 10). metric: -1626
Hyperparams: {'batch_size': 800000, 'dt': 3.1622776e-06}., sample shape=(20, 10). metric: -2800
Hyperparams: {'batch_size': 800000, 'dt': 1e-07}., sample shape=(20, 10). metric: -6206
Hyperparams: {'batch_size': 800000, 'dt': 3.1622776e-07}., sample shape=(20, 10). metric: -6215
Hyperparams: {'batch_size': 80000, 'dt': 3.1622776e-06}., sample shape=(53, 10). metric: -19145
Hyperparams: {'batch_size': 80000, 'dt': 1e-06}., sample shape=(54, 10). metric: -9956
Hyperparams: {'batch_size': 80000, 'dt': 3.1622776e-07}., sample shape=(53, 10). metric: -6273
Hyperparams: {'batch_size': 8000, 'dt': 1e-06}., sample shape=(301, 10). metric: -5244
Hyperparams: {'batch_size': 8000, 'dt': 3.1622776e-07}., sample shape=(294, 10). m

HBox(children=(HTML(value='Iteration 3/3, 4 arms, time budget = 4.04 sec'), FloatProgress(value=0.0, max=4.0),…

Hyperparams: {'batch_size': 80000, 'dt': 3.1622776e-06}., sample shape=(183, 10). metric: -3378
Hyperparams: {'batch_size': 800000, 'dt': 1e-06}., sample shape=(72, 10). metric: -596
Hyperparams: {'batch_size': 80000, 'dt': 1e-06}., sample shape=(198, 10). metric: -2789
Hyperparams: {'batch_size': 800000, 'dt': 3.1622778e-05}., sample shape=(71, 10). metric: -591

Number of samples: [183]
Running time: 202.7 sec
{'batch_size': 80000, 'dt': 3.1622776e-06} -3377.6277 (183, 10)


In [11]:
build_kernel = lambda dt, batch_size: build_sgnhtCV_kernel(dt, loglikelihood, 
                                                  logprior, data, batch_size, centering_value)
grid_params = {'log_dt': logdt_range, 
               'batch_size': batch_size_range
              }

best_arm = run_MAMBA(key, build_kernel, error_fn, R, params_IC, 
                     grid_params=grid_params, get_fb_grads=get_fb_grads)

print(best_arm.hyperparameters, best_arm.metric, best_arm.samples.shape)

HBox(children=(HTML(value='Iteration 1/3, 42 arms, time budget = 0.38 sec'), FloatProgress(value=0.0, max=42.0…

Hyperparams: {'batch_size': 800000, 'dt': 0.1}., sample shape=(2, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'dt': 0.031622775}., sample shape=(2, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'dt': 0.01}., sample shape=(2, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'dt': 0.0031622776}., sample shape=(2, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'dt': 0.001}., sample shape=(2, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'dt': 0.00031622776}., sample shape=(2, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'dt': 1e-04}., sample shape=(2, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'dt': 3.1622778e-05}., sample shape=(2, 10). metric: -522274
Hyperparams: {'batch_size': 800000, 'dt': 1e-05}., sample shape=(2, 10). metric: -162790
Hyperparams: {'batch_size': 800000, 'dt': 3.1622776e-06}., sample shape=(2, 10). metric: 436822
Hyperparams: {'batch_size': 800000, 'dt': 1e-06}., sample shape=(2, 10). metric: -704890
Hyperparams: {'b

HBox(children=(HTML(value='Iteration 2/3, 14 arms, time budget = 1.15 sec'), FloatProgress(value=0.0, max=14.0…

Hyperparams: {'batch_size': 800000, 'dt': 1e-06}., sample shape=(6, 10). metric: -123508
Hyperparams: {'batch_size': 800000, 'dt': 3.1622778e-05}., sample shape=(7, 10). metric: -334821
Hyperparams: {'batch_size': 800000, 'dt': 1e-05}., sample shape=(7, 10). metric: -108212
Hyperparams: {'batch_size': 80000, 'dt': 3.1622778e-05}., sample shape=(45, 10). metric: -2113
Hyperparams: {'batch_size': 80000, 'dt': 1e-06}., sample shape=(45, 10). metric: -1318
Hyperparams: {'batch_size': 80000, 'dt': 1e-05}., sample shape=(44, 10). metric: -1451
Hyperparams: {'batch_size': 80000, 'dt': 3.1622776e-07}., sample shape=(46, 10). metric: -1293
Hyperparams: {'batch_size': 80000, 'dt': 3.1622776e-06}., sample shape=(44, 10). metric: -2131
Hyperparams: {'batch_size': 80000, 'dt': 1e-07}., sample shape=(46, 10). metric: 140
Hyperparams: {'batch_size': 80000, 'dt': 3.1622776e-08}., sample shape=(46, 10). metric: -606
Hyperparams: {'batch_size': 8000, 'dt': 3.1622778e-05}., sample shape=(280, 10). metric

HBox(children=(HTML(value='Iteration 3/3, 4 arms, time budget = 4.04 sec'), FloatProgress(value=0.0, max=4.0),…

Hyperparams: {'batch_size': 800000, 'dt': 3.1622778e-05}., sample shape=(20, 10). metric: -27453
Hyperparams: {'batch_size': 800000, 'dt': 1e-06}., sample shape=(20, 10). metric: -1909
Hyperparams: {'batch_size': 800000, 'dt': 1e-05}., sample shape=(20, 10). metric: -15917
Hyperparams: {'batch_size': 80000, 'dt': 3.1622776e-06}., sample shape=(147, 10). metric: -231

Number of samples: [20]
Running time: 214.3 sec
{'batch_size': 800000, 'dt': 3.1622778e-05} -27453.271 (20, 10)
