In [1]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np

import jax.numpy as jnp
from jax import jit, random, vmap

from sgmcmcjax.kernels import build_sgld_kernel, build_sghmc_kernel, build_sgnht_kernel
from sgmcmcjax.kernels import build_sgldCV_kernel, build_sghmcCV_kernel, build_sgnhtCV_kernel
from sgmcmcjax.optimizer import build_adam_optimizer
from sgmcmcjax.util import build_grad_log_post

import context
from models.logistic_regression.logistic_regression_model import gen_data, loglikelihood, logprior
from tuning.mamba import run_MAMBA
from tuning.ksd import imq_KSD




# Logistic regression - MAMBA

In [2]:


# Problem size for the synthetic logistic-regression dataset.
# (Ndata = 100_000 is a lighter setting for quick experiments.)
dim = 10
Ndata = 1_000_000
key = random.PRNGKey(42)

# gen_data's result is unpacked as (theta_true, X, y_data).
theta_true, X, y_data = gen_data(key, dim, Ndata)

# 80/20 train/test split: the first 80% of rows (in their existing order)
# are used for training, the remainder for testing.
num_train = int(Ndata*0.8)
X_train, X_test = X[:num_train], X[num_train:]
y_train, y_test = y_data[:num_train], y_data[num_train:]
print(X_train.shape, X_test.shape)

# Training pair handed to the optimizer and to every kernel builder below.
data = (X_train, y_train)


generating data, with N=1,000,000 and dim=10
(800000, 10) (200000, 10)


In [3]:
# Get the MAP estimate with Adam. The result doubles as the centering value
# for the control-variate kernels and as the initial condition of every sampler.
batch_size = int(0.01*X_train.shape[0])  # 1% of the training rows
key = random.PRNGKey(0)
run_adam = build_adam_optimizer(1e-2, loglikelihood, logprior, data, batch_size)
centering_value, logpost_array = run_adam(key, 5000, jnp.zeros(dim))
params_IC = centering_value

R = 1 # running time of longest sampler


def error_fn(x, y):
    """Metric used to rank arms: forwards both arguments to `imq_KSD`."""
    return imq_KSD(x, y)


# Search grid: log_dt takes the values -1, -1.5, ..., -7.5 and the batch size
# takes 100%, 10%, 1% and 0.1% of the training set.
logdt_range = -jnp.arange(1., 8., 0.5)
batch_size_range = [int(10**(-exponent)*X_train.shape[0]) for exponent in range(4)]

print(batch_size_range)

[800000, 80000, 8000, 800]


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=5000.0), HTML(value='')))

In [4]:
# Gradient of the log-posterior built from `data`, i.e. the whole training set.
grad_log_post_fb = build_grad_log_post(loglikelihood, logprior, data)


def grad_lp_LR_fb(theta):
    """Evaluate the full-batch log-posterior gradient at a single `theta`."""
    return grad_log_post_fb(theta, X_train, y_train)


# Map the single-point gradient over a leading axis of parameters, then compile.
batch_grad_lp_LR_fb = jit(vmap(grad_lp_LR_fb))


@jit
def get_fb_grads(samples):
    """Thin `samples` and pair them with their full-batch gradients.

    Every 10th entry of `samples` is kept (starting with the first), converted
    to a JAX array, and passed through `batch_grad_lp_LR_fb`.

    Returns:
        A tuple `(thinned_samples, gradients)`.
    """
    stride = 10
    thinned = jnp.array(samples[::stride])
    return thinned, batch_grad_lp_LR_fb(thinned)

### Run MAMBA for each sampler (SGLD, SGHMC, SGNHT, each with and without control variates)

In [5]:
def build_kernel(dt, batch_size):
    """Build an SGLD kernel on the training data for the given `dt` and `batch_size`."""
    return build_sgld_kernel(dt, loglikelihood, logprior, data, batch_size)


# Hyperparameter grid searched by MAMBA for this sampler.
grid_params = {'log_dt': logdt_range, 'batch_size': batch_size_range}

best_arm = run_MAMBA(
    key, build_kernel, error_fn, R, params_IC,
    grid_params=grid_params,
    get_fb_grads=get_fb_grads,
)

# Report the winning configuration, its metric and how many samples it drew.
print(best_arm.hyperparameters, best_arm.metric, best_arm.samples.shape)




HBox(children=(HTML(value='Iteration 1/3, 56 arms, time budget = 0.08 sec'), FloatProgress(value=0.0, max=56.0…

Hyperparams: {'batch_size': 800000, 'dt': 0.1}., sample shape=(2, 10). metric: 313867
Hyperparams: {'batch_size': 800000, 'dt': 0.031622775}., sample shape=(2, 10). metric: 310463
Hyperparams: {'batch_size': 800000, 'dt': 0.01}., sample shape=(1, 10). metric: 599319
Hyperparams: {'batch_size': 800000, 'dt': 0.0031622776}., sample shape=(2, 10). metric: 244357
Hyperparams: {'batch_size': 800000, 'dt': 0.001}., sample shape=(1, 10). metric: 230270
Hyperparams: {'batch_size': 800000, 'dt': 0.00031622776}., sample shape=(2, 10). metric: 35102
Hyperparams: {'batch_size': 800000, 'dt': 1e-04}., sample shape=(1, 10). metric: 22171
Hyperparams: {'batch_size': 800000, 'dt': 3.1622778e-05}., sample shape=(1, 10). metric: 4516
Hyperparams: {'batch_size': 800000, 'dt': 1e-05}., sample shape=(1, 10). metric: 1008
Hyperparams: {'batch_size': 800000, 'dt': 3.1622776e-06}., sample shape=(1, 10). metric: 2652
Hyperparams: {'batch_size': 800000, 'dt': 1e-06}., sample shape=(2, 10). metric: 2127
Hyperpar

HBox(children=(HTML(value='Iteration 2/3, 18 arms, time budget = 0.24 sec'), FloatProgress(value=0.0, max=18.0…

Hyperparams: {'batch_size': 80000, 'dt': 1e-05}., sample shape=(12, 10). metric: 171
Hyperparams: {'batch_size': 800, 'dt': 1e-06}., sample shape=(179, 10). metric: 528
Hyperparams: {'batch_size': 800, 'dt': 3.1622776e-06}., sample shape=(178, 10). metric: 594
Hyperparams: {'batch_size': 8000, 'dt': 1e-06}., sample shape=(65, 10). metric: 328
Hyperparams: {'batch_size': 8000, 'dt': 3.1622776e-06}., sample shape=(68, 10). metric: 238
Hyperparams: {'batch_size': 800000, 'dt': 1e-05}., sample shape=(5, 10). metric: 422
Hyperparams: {'batch_size': 800, 'dt': 1e-07}., sample shape=(181, 10). metric: 470
Hyperparams: {'batch_size': 800, 'dt': 3.1622776e-07}., sample shape=(185, 10). metric: 371
Hyperparams: {'batch_size': 8000, 'dt': 3.1622776e-07}., sample shape=(67, 10). metric: 355
Hyperparams: {'batch_size': 80000, 'dt': 3.1622776e-06}., sample shape=(12, 10). metric: 519
Hyperparams: {'batch_size': 8000, 'dt': 1e-05}., sample shape=(69, 10). metric: 432
Hyperparams: {'batch_size': 80000

HBox(children=(HTML(value='Iteration 3/3, 6 arms, time budget = 0.72 sec'), FloatProgress(value=0.0, max=6.0),…

Hyperparams: {'batch_size': 80000, 'dt': 1e-05}., sample shape=(36, 10). metric: 374
Hyperparams: {'batch_size': 8000, 'dt': 3.1622776e-06}., sample shape=(201, 10). metric: 211
Hyperparams: {'batch_size': 80000, 'dt': 1e-06}., sample shape=(36, 10). metric: 353
Hyperparams: {'batch_size': 8000, 'dt': 1e-06}., sample shape=(218, 10). metric: 125
Hyperparams: {'batch_size': 8000, 'dt': 3.1622776e-07}., sample shape=(221, 10). metric: 236
Hyperparams: {'batch_size': 800, 'dt': 3.1622776e-07}., sample shape=(586, 10). metric: 248

Number of samples: [218, 201]
Running time: 122.3 sec
{'batch_size': 8000, 'dt': 1e-06} 125.33214 (218, 10)


In [6]:
def build_kernel(dt, batch_size):
    """Build an SGLD kernel with control variates centred on `centering_value`."""
    return build_sgldCV_kernel(dt, loglikelihood,
                               logprior, data, batch_size, centering_value)


# Define the search grid in this cell instead of reusing whatever `grid_params`
# was last bound in the kernel: the SGHMC cells rebind it with an extra "L"
# entry, which this two-argument kernel builder has no parameter for. Keeping
# the grid local makes the cell safe to re-run in any order.
grid_params = {'log_dt': logdt_range,
               'batch_size': batch_size_range
              }

best_arm = run_MAMBA(key, build_kernel, error_fn, R, params_IC,
                     grid_params=grid_params, get_fb_grads=get_fb_grads)

# Report the winning configuration, its metric and how many samples it drew.
print(best_arm.hyperparameters, best_arm.metric, best_arm.samples.shape)

HBox(children=(HTML(value='Iteration 1/3, 56 arms, time budget = 0.08 sec'), FloatProgress(value=0.0, max=56.0…

Hyperparams: {'batch_size': 800000, 'dt': 0.1}., sample shape=(1, 10). metric: 633948
Hyperparams: {'batch_size': 800000, 'dt': 0.031622775}., sample shape=(1, 10). metric: 625294
Hyperparams: {'batch_size': 800000, 'dt': 0.01}., sample shape=(1, 10). metric: 596300
Hyperparams: {'batch_size': 800000, 'dt': 0.0031622776}., sample shape=(1, 10). metric: 477570
Hyperparams: {'batch_size': 800000, 'dt': 0.001}., sample shape=(1, 10). metric: 217501
Hyperparams: {'batch_size': 800000, 'dt': 0.00031622776}., sample shape=(1, 10). metric: 73072
Hyperparams: {'batch_size': 800000, 'dt': 1e-04}., sample shape=(1, 10). metric: 19466
Hyperparams: {'batch_size': 800000, 'dt': 3.1622778e-05}., sample shape=(1, 10). metric: 3622
Hyperparams: {'batch_size': 800000, 'dt': 1e-05}., sample shape=(1, 10). metric: 930
Hyperparams: {'batch_size': 800000, 'dt': 3.1622776e-06}., sample shape=(1, 10). metric: 2590
Hyperparams: {'batch_size': 800000, 'dt': 1e-06}., sample shape=(1, 10). metric: 2878
Hyperpara

HBox(children=(HTML(value='Iteration 2/3, 18 arms, time budget = 0.24 sec'), FloatProgress(value=0.0, max=18.0…

Hyperparams: {'batch_size': 8000, 'dt': 1e-05}., sample shape=(58, 10). metric: 63
Hyperparams: {'batch_size': 800, 'dt': 3.1622776e-06}., sample shape=(175, 10). metric: 66
Hyperparams: {'batch_size': 800, 'dt': 1e-05}., sample shape=(167, 10). metric: 56
Hyperparams: {'batch_size': 800, 'dt': 1e-06}., sample shape=(171, 10). metric: 91
Hyperparams: {'batch_size': 8000, 'dt': 3.1622776e-06}., sample shape=(58, 10). metric: 125
Hyperparams: {'batch_size': 800, 'dt': 3.1622776e-07}., sample shape=(174, 10). metric: 167
Hyperparams: {'batch_size': 8000, 'dt': 1e-06}., sample shape=(59, 10). metric: 168
Hyperparams: {'batch_size': 80000, 'dt': 1e-05}., sample shape=(10, 10). metric: 236
Hyperparams: {'batch_size': 800000, 'dt': 1e-05}., sample shape=(2, 10). metric: 647
Hyperparams: {'batch_size': 8000, 'dt': 3.1622776e-07}., sample shape=(57, 10). metric: 414
Hyperparams: {'batch_size': 80000, 'dt': 3.1622776e-06}., sample shape=(10, 10). metric: 326
Hyperparams: {'batch_size': 800, 'dt'

HBox(children=(HTML(value='Iteration 3/3, 6 arms, time budget = 0.72 sec'), FloatProgress(value=0.0, max=6.0),…

Hyperparams: {'batch_size': 800, 'dt': 1e-05}., sample shape=(555, 10). metric: 37
Hyperparams: {'batch_size': 8000, 'dt': 1e-05}., sample shape=(187, 10). metric: 53
Hyperparams: {'batch_size': 800, 'dt': 3.1622776e-06}., sample shape=(574, 10). metric: 48
Hyperparams: {'batch_size': 800, 'dt': 1e-06}., sample shape=(522, 10). metric: 89
Hyperparams: {'batch_size': 8000, 'dt': 3.1622776e-06}., sample shape=(174, 10). metric: 54
Hyperparams: {'batch_size': 800, 'dt': 3.1622776e-07}., sample shape=(566, 10). metric: 108

Number of samples: [555, 574]
Running time: 140.9 sec
{'batch_size': 800, 'dt': 1e-05} 37.13933 (555, 10)


In [7]:
def build_kernel(dt, L, batch_size):
    """Build an SGHMC kernel on the training data for the given `dt`, `L` and `batch_size`."""
    return build_sghmc_kernel(dt, L, loglikelihood, logprior, data, batch_size)


# Same grid as for SGLD plus the SGHMC-specific `L`
# (presumably the number of leapfrog steps -- confirm against build_sghmc_kernel).
grid_params = {
    'log_dt': logdt_range,
    'batch_size': batch_size_range,
    "L": [5, 10],
}

best_arm = run_MAMBA(
    key, build_kernel, error_fn, R, params_IC,
    grid_params=grid_params,
    get_fb_grads=get_fb_grads,
)

# Report the winning configuration, its metric and how many samples it drew.
print(best_arm.hyperparameters, best_arm.metric, best_arm.samples.shape)

HBox(children=(HTML(value='Iteration 1/4, 112 arms, time budget = 0.03 sec'), FloatProgress(value=0.0, max=112…

Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 0.1}., sample shape=(1, 10). metric: 391962
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 0.031622775}., sample shape=(1, 10). metric: 315911
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 0.01}., sample shape=(1, 10). metric: 351372
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 0.0031622776}., sample shape=(1, 10). metric: 416450
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 0.001}., sample shape=(1, 10). metric: 436851
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 0.00031622776}., sample shape=(1, 10). metric: 463559
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 1e-04}., sample shape=(1, 10). metric: 347783
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 3.1622778e-05}., sample shape=(1, 10). metric: 15751
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 1e-05}., sample shape=(1, 10). metric: 10584
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 3.1622776e-06}., sample shape=(1, 10). metric: 2970
Hyperparam

Hyperparams: {'batch_size': 800, 'L': 5, 'dt': 0.1}., sample shape=(11, 10). metric: 68155
Hyperparams: {'batch_size': 800, 'L': 5, 'dt': 0.031622775}., sample shape=(11, 10). metric: 53257
Hyperparams: {'batch_size': 800, 'L': 5, 'dt': 0.01}., sample shape=(11, 10). metric: 42416
Hyperparams: {'batch_size': 800, 'L': 5, 'dt': 0.0031622776}., sample shape=(10, 10). metric: 49468
Hyperparams: {'batch_size': 800, 'L': 5, 'dt': 0.001}., sample shape=(11, 10). metric: 37818
Hyperparams: {'batch_size': 800, 'L': 5, 'dt': 0.00031622776}., sample shape=(11, 10). metric: 45739
Hyperparams: {'batch_size': 800, 'L': 5, 'dt': 1e-04}., sample shape=(11, 10). metric: 43477
Hyperparams: {'batch_size': 800, 'L': 5, 'dt': 3.1622778e-05}., sample shape=(11, 10). metric: 32026
Hyperparams: {'batch_size': 800, 'L': 5, 'dt': 1e-05}., sample shape=(10, 10). metric: 35832
Hyperparams: {'batch_size': 800, 'L': 5, 'dt': 3.1622776e-06}., sample shape=(11, 10). metric: 6698
Hyperparams: {'batch_size': 800, 'L':

HBox(children=(HTML(value='Iteration 2/4, 37 arms, time budget = 0.08 sec'), FloatProgress(value=0.0, max=37.0…

Hyperparams: {'batch_size': 800, 'L': 5, 'dt': 3.1622776e-07}., sample shape=(39, 10). metric: 584
Hyperparams: {'batch_size': 800, 'L': 10, 'dt': 3.1622776e-08}., sample shape=(28, 10). metric: 512
Hyperparams: {'batch_size': 800000, 'L': 10, 'dt': 3.1622776e-07}., sample shape=(2, 10). metric: 736
Hyperparams: {'batch_size': 8000, 'L': 10, 'dt': 3.1622776e-07}., sample shape=(4, 10). metric: 1048
Hyperparams: {'batch_size': 800, 'L': 5, 'dt': 1e-07}., sample shape=(39, 10). metric: 1078
Hyperparams: {'batch_size': 800, 'L': 10, 'dt': 1e-07}., sample shape=(28, 10). metric: 407
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 1e-06}., sample shape=(2, 10). metric: 340
Hyperparams: {'batch_size': 80000, 'L': 10, 'dt': 3.1622776e-07}., sample shape=(2, 10). metric: 945
Hyperparams: {'batch_size': 800, 'L': 5, 'dt': 3.1622776e-08}., sample shape=(41, 10). metric: 695
Hyperparams: {'batch_size': 8000, 'L': 5, 'dt': 3.1622776e-07}., sample shape=(7, 10). metric: 626
Hyperparams: {'batch_s

HBox(children=(HTML(value='Iteration 3/4, 12 arms, time budget = 0.23 sec'), FloatProgress(value=0.0, max=12.0…

Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 1e-06}., sample shape=(3, 10). metric: 602
Hyperparams: {'batch_size': 800, 'L': 10, 'dt': 1e-07}., sample shape=(93, 10). metric: 187
Hyperparams: {'batch_size': 800, 'L': 10, 'dt': 3.1622776e-08}., sample shape=(94, 10). metric: 188
Hyperparams: {'batch_size': 8000, 'L': 5, 'dt': 1e-07}., sample shape=(21, 10). metric: 351
Hyperparams: {'batch_size': 800, 'L': 5, 'dt': 3.1622776e-07}., sample shape=(140, 10). metric: 439
Hyperparams: {'batch_size': 8000, 'L': 10, 'dt': 1e-07}., sample shape=(11, 10). metric: 282
Hyperparams: {'batch_size': 8000, 'L': 5, 'dt': 3.1622776e-07}., sample shape=(21, 10). metric: 307
Hyperparams: {'batch_size': 800, 'L': 5, 'dt': 3.1622776e-08}., sample shape=(142, 10). metric: 225
Hyperparams: {'batch_size': 800000, 'L': 10, 'dt': 3.1622776e-07}., sample shape=(3, 10). metric: 606
Hyperparams: {'batch_size': 8000, 'L': 10, 'dt': 3.1622776e-08}., sample shape=(11, 10). metric: 534
Hyperparams: {'batch_size':

HBox(children=(HTML(value='Iteration 4/4, 4 arms, time budget = 0.70 sec'), FloatProgress(value=0.0, max=4.0),…

Hyperparams: {'batch_size': 800, 'L': 10, 'dt': 1e-07}., sample shape=(291, 10). metric: 195
Hyperparams: {'batch_size': 800, 'L': 10, 'dt': 3.1622776e-08}., sample shape=(294, 10). metric: 254
Hyperparams: {'batch_size': 800, 'L': 5, 'dt': 3.1622776e-08}., sample shape=(444, 10). metric: 195
Hyperparams: {'batch_size': 8000, 'L': 10, 'dt': 1e-07}., sample shape=(32, 10). metric: 280

Number of samples: [291]
Running time: 238.5 sec
{'batch_size': 800, 'L': 10, 'dt': 1e-07} 194.5676 (291, 10)


In [8]:
def build_kernel(dt, L, batch_size):
    """Build an SGHMC kernel with control variates centred on `centering_value`."""
    return build_sghmcCV_kernel(dt, L, loglikelihood,
                                logprior, data, batch_size, centering_value)


# Grid over log step size, minibatch size and the SGHMC-specific `L`
# (presumably the number of leapfrog steps -- confirm against build_sghmcCV_kernel).
grid_params = {
    'log_dt': logdt_range,
    'batch_size': batch_size_range,
    "L": [5, 10],
}

best_arm = run_MAMBA(
    key, build_kernel, error_fn, R, params_IC,
    grid_params=grid_params,
    get_fb_grads=get_fb_grads,
)

# Report the winning configuration, its metric and how many samples it drew.
print(best_arm.hyperparameters, best_arm.metric, best_arm.samples.shape)

HBox(children=(HTML(value='Iteration 1/4, 112 arms, time budget = 0.03 sec'), FloatProgress(value=0.0, max=112…

Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 0.1}., sample shape=(1, 10). metric: 390082
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 0.031622775}., sample shape=(1, 10). metric: 309232
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 0.01}., sample shape=(1, 10). metric: 346898
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 0.0031622776}., sample shape=(1, 10). metric: 416598
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 0.001}., sample shape=(1, 10). metric: 435713
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 0.00031622776}., sample shape=(1, 10). metric: 458591
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 1e-04}., sample shape=(1, 10). metric: 336311
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 3.1622778e-05}., sample shape=(1, 10). metric: 11153
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 1e-05}., sample shape=(1, 10). metric: 10747
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 3.1622776e-06}., sample shape=(1, 10). metric: 3178
Hyperparam

Hyperparams: {'batch_size': 800, 'L': 5, 'dt': 0.1}., sample shape=(9, 10). metric: 89581
Hyperparams: {'batch_size': 800, 'L': 5, 'dt': 0.031622775}., sample shape=(9, 10). metric: 60704
Hyperparams: {'batch_size': 800, 'L': 5, 'dt': 0.01}., sample shape=(9, 10). metric: 58946
Hyperparams: {'batch_size': 800, 'L': 5, 'dt': 0.0031622776}., sample shape=(9, 10). metric: 56423
Hyperparams: {'batch_size': 800, 'L': 5, 'dt': 0.001}., sample shape=(9, 10). metric: 55588
Hyperparams: {'batch_size': 800, 'L': 5, 'dt': 0.00031622776}., sample shape=(9, 10). metric: 58400
Hyperparams: {'batch_size': 800, 'L': 5, 'dt': 1e-04}., sample shape=(9, 10). metric: 47552
Hyperparams: {'batch_size': 800, 'L': 5, 'dt': 3.1622778e-05}., sample shape=(9, 10). metric: 30641
Hyperparams: {'batch_size': 800, 'L': 5, 'dt': 1e-05}., sample shape=(9, 10). metric: 41355
Hyperparams: {'batch_size': 800, 'L': 5, 'dt': 3.1622776e-06}., sample shape=(9, 10). metric: 16761
Hyperparams: {'batch_size': 800, 'L': 5, 'dt':

HBox(children=(HTML(value='Iteration 2/4, 37 arms, time budget = 0.08 sec'), FloatProgress(value=0.0, max=37.0…

Hyperparams: {'batch_size': 800, 'L': 10, 'dt': 3.1622776e-07}., sample shape=(21, 10). metric: 125
Hyperparams: {'batch_size': 800, 'L': 5, 'dt': 1e-06}., sample shape=(31, 10). metric: 157
Hyperparams: {'batch_size': 800, 'L': 5, 'dt': 3.1622776e-07}., sample shape=(32, 10). metric: 169
Hyperparams: {'batch_size': 800, 'L': 10, 'dt': 1e-07}., sample shape=(22, 10). metric: 254
Hyperparams: {'batch_size': 800, 'L': 10, 'dt': 3.1622776e-08}., sample shape=(23, 10). metric: 221
Hyperparams: {'batch_size': 800, 'L': 5, 'dt': 1e-07}., sample shape=(33, 10). metric: 185
Hyperparams: {'batch_size': 800000, 'L': 10, 'dt': 3.1622776e-07}., sample shape=(2, 10). metric: 551
Hyperparams: {'batch_size': 80000, 'L': 10, 'dt': 3.1622776e-07}., sample shape=(2, 10). metric: 713
Hyperparams: {'batch_size': 8000, 'L': 10, 'dt': 3.1622776e-07}., sample shape=(3, 10). metric: 337
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 1e-06}., sample shape=(2, 10). metric: 414
Hyperparams: {'batch_size': 800

HBox(children=(HTML(value='Iteration 3/4, 12 arms, time budget = 0.23 sec'), FloatProgress(value=0.0, max=12.0…

Hyperparams: {'batch_size': 800, 'L': 10, 'dt': 3.1622776e-07}., sample shape=(76, 10). metric: 101
Hyperparams: {'batch_size': 800, 'L': 5, 'dt': 1e-06}., sample shape=(112, 10). metric: 46
Hyperparams: {'batch_size': 800, 'L': 5, 'dt': 3.1622776e-07}., sample shape=(112, 10). metric: 80
Hyperparams: {'batch_size': 800, 'L': 5, 'dt': 1e-07}., sample shape=(116, 10). metric: 227
Hyperparams: {'batch_size': 800, 'L': 10, 'dt': 3.1622776e-08}., sample shape=(75, 10). metric: 206
Hyperparams: {'batch_size': 8000, 'L': 5, 'dt': 1e-06}., sample shape=(15, 10). metric: 202
Hyperparams: {'batch_size': 800, 'L': 10, 'dt': 1e-07}., sample shape=(77, 10). metric: 112
Hyperparams: {'batch_size': 8000, 'L': 10, 'dt': 3.1622776e-07}., sample shape=(9, 10). metric: 307
Hyperparams: {'batch_size': 8000, 'L': 5, 'dt': 3.1622776e-07}., sample shape=(17, 10). metric: 256
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 1e-06}., sample shape=(3, 10). metric: 645
Hyperparams: {'batch_size': 8000, 'L': 10

HBox(children=(HTML(value='Iteration 4/4, 4 arms, time budget = 0.70 sec'), FloatProgress(value=0.0, max=4.0),…

Hyperparams: {'batch_size': 800, 'L': 5, 'dt': 1e-06}., sample shape=(352, 10). metric: 51
Hyperparams: {'batch_size': 800, 'L': 5, 'dt': 3.1622776e-07}., sample shape=(310, 10). metric: 83
Hyperparams: {'batch_size': 800, 'L': 10, 'dt': 3.1622776e-07}., sample shape=(237, 10). metric: 37
Hyperparams: {'batch_size': 800, 'L': 10, 'dt': 1e-07}., sample shape=(235, 10). metric: 62

Number of samples: [237]
Running time: 302.7 sec
{'batch_size': 800, 'L': 10, 'dt': 3.1622776e-07} 36.54479 (237, 10)


In [9]:
def build_kernel(dt, batch_size):
    """Build an SGNHT kernel on the training data for the given `dt` and `batch_size`."""
    return build_sgnht_kernel(dt, loglikelihood, logprior, data, batch_size)


# Hyperparameter grid searched by MAMBA for this sampler.
grid_params = {'log_dt': logdt_range, 'batch_size': batch_size_range}

best_arm = run_MAMBA(
    key, build_kernel, error_fn, R, params_IC,
    grid_params=grid_params,
    get_fb_grads=get_fb_grads,
)

# Report the winning configuration, its metric and how many samples it drew.
print(best_arm.hyperparameters, best_arm.metric, best_arm.samples.shape)

HBox(children=(HTML(value='Iteration 1/3, 56 arms, time budget = 0.08 sec'), FloatProgress(value=0.0, max=56.0…

Hyperparams: {'batch_size': 800000, 'dt': 0.1}., sample shape=(1, 10). metric: 627259
Hyperparams: {'batch_size': 800000, 'dt': 0.031622775}., sample shape=(1, 10). metric: 621158
Hyperparams: {'batch_size': 800000, 'dt': 0.01}., sample shape=(1, 10). metric: 597799
Hyperparams: {'batch_size': 800000, 'dt': 0.0031622776}., sample shape=(1, 10). metric: 486930
Hyperparams: {'batch_size': 800000, 'dt': 0.001}., sample shape=(1, 10). metric: 237438
Hyperparams: {'batch_size': 800000, 'dt': 0.00031622776}., sample shape=(1, 10). metric: 66769
Hyperparams: {'batch_size': 800000, 'dt': 1e-04}., sample shape=(1, 10). metric: 23182
Hyperparams: {'batch_size': 800000, 'dt': 3.1622778e-05}., sample shape=(1, 10). metric: 5319
Hyperparams: {'batch_size': 800000, 'dt': 1e-05}., sample shape=(1, 10). metric: 1454
Hyperparams: {'batch_size': 800000, 'dt': 3.1622776e-06}., sample shape=(1, 10). metric: 2576
Hyperparams: {'batch_size': 800000, 'dt': 1e-06}., sample shape=(1, 10). metric: 3050
Hyperpar

HBox(children=(HTML(value='Iteration 2/3, 18 arms, time budget = 0.24 sec'), FloatProgress(value=0.0, max=18.0…

Hyperparams: {'batch_size': 80000, 'dt': 1e-06}., sample shape=(13, 10). metric: 213
Hyperparams: {'batch_size': 8000, 'dt': 1e-07}., sample shape=(62, 10). metric: 213
Hyperparams: {'batch_size': 80000, 'dt': 3.1622776e-06}., sample shape=(12, 10). metric: 841
Hyperparams: {'batch_size': 80000, 'dt': 3.1622776e-07}., sample shape=(13, 10). metric: 295
Hyperparams: {'batch_size': 800, 'dt': 1e-07}., sample shape=(123, 10). metric: 701
Hyperparams: {'batch_size': 8000, 'dt': 3.1622776e-07}., sample shape=(57, 10). metric: 458
Hyperparams: {'batch_size': 8000, 'dt': 3.1622776e-08}., sample shape=(61, 10). metric: 362
Hyperparams: {'batch_size': 80000, 'dt': 1e-07}., sample shape=(12, 10). metric: 432
Hyperparams: {'batch_size': 8000, 'dt': 1e-06}., sample shape=(60, 10). metric: 432
Hyperparams: {'batch_size': 800, 'dt': 3.1622776e-08}., sample shape=(147, 10). metric: 527
Hyperparams: {'batch_size': 800, 'dt': 1e-06}., sample shape=(143, 10). metric: 815
Hyperparams: {'batch_size': 800,

HBox(children=(HTML(value='Iteration 3/3, 6 arms, time budget = 0.72 sec'), FloatProgress(value=0.0, max=6.0),…

Hyperparams: {'batch_size': 80000, 'dt': 1e-06}., sample shape=(37, 10). metric: 236
Hyperparams: {'batch_size': 8000, 'dt': 1e-07}., sample shape=(210, 10). metric: 176
Hyperparams: {'batch_size': 80000, 'dt': 3.1622776e-07}., sample shape=(37, 10). metric: 163
Hyperparams: {'batch_size': 8000, 'dt': 3.1622776e-08}., sample shape=(210, 10). metric: 107
Hyperparams: {'batch_size': 80000, 'dt': 1e-07}., sample shape=(36, 10). metric: 139
Hyperparams: {'batch_size': 8000, 'dt': 1e-06}., sample shape=(209, 10). metric: 218

Number of samples: [210, 36]
Running time: 192.9 sec
{'batch_size': 8000, 'dt': 3.1622776e-08} 106.707985 (210, 10)


In [10]:
def build_kernel(dt, batch_size):
    """Build an SGNHT kernel with control variates centred on `centering_value`."""
    return build_sgnhtCV_kernel(dt, loglikelihood,
                                logprior, data, batch_size, centering_value)


# Hyperparameter grid searched by MAMBA for this sampler.
grid_params = {'log_dt': logdt_range, 'batch_size': batch_size_range}

best_arm = run_MAMBA(
    key, build_kernel, error_fn, R, params_IC,
    grid_params=grid_params,
    get_fb_grads=get_fb_grads,
)

# Report the winning configuration, its metric and how many samples it drew.
print(best_arm.hyperparameters, best_arm.metric, best_arm.samples.shape)

HBox(children=(HTML(value='Iteration 1/3, 56 arms, time budget = 0.08 sec'), FloatProgress(value=0.0, max=56.0…

Hyperparams: {'batch_size': 800000, 'dt': 0.1}., sample shape=(1, 10). metric: 633599
Hyperparams: {'batch_size': 800000, 'dt': 0.031622775}., sample shape=(1, 10). metric: 625572
Hyperparams: {'batch_size': 800000, 'dt': 0.01}., sample shape=(1, 10). metric: 594703
Hyperparams: {'batch_size': 800000, 'dt': 0.0031622776}., sample shape=(1, 10). metric: 476958
Hyperparams: {'batch_size': 800000, 'dt': 0.001}., sample shape=(1, 10). metric: 225221
Hyperparams: {'batch_size': 800000, 'dt': 0.00031622776}., sample shape=(1, 10). metric: 70078
Hyperparams: {'batch_size': 800000, 'dt': 1e-04}., sample shape=(1, 10). metric: 20501
Hyperparams: {'batch_size': 800000, 'dt': 3.1622778e-05}., sample shape=(1, 10). metric: 4429
Hyperparams: {'batch_size': 800000, 'dt': 1e-05}., sample shape=(1, 10). metric: 1291
Hyperparams: {'batch_size': 800000, 'dt': 3.1622776e-06}., sample shape=(1, 10). metric: 2516
Hyperparams: {'batch_size': 800000, 'dt': 1e-06}., sample shape=(1, 10). metric: 3030
Hyperpar

HBox(children=(HTML(value='Iteration 2/3, 18 arms, time budget = 0.24 sec'), FloatProgress(value=0.0, max=18.0…

Hyperparams: {'batch_size': 800, 'dt': 1e-06}., sample shape=(153, 10). metric: 32
Hyperparams: {'batch_size': 800, 'dt': 3.1622776e-07}., sample shape=(107, 10). metric: 33
Hyperparams: {'batch_size': 8000, 'dt': 1e-06}., sample shape=(51, 10). metric: 16
Hyperparams: {'batch_size': 800, 'dt': 1e-07}., sample shape=(122, 10). metric: 43
Hyperparams: {'batch_size': 8000, 'dt': 1e-05}., sample shape=(46, 10). metric: 41
Hyperparams: {'batch_size': 800, 'dt': 3.1622776e-08}., sample shape=(125, 10). metric: 99
Hyperparams: {'batch_size': 800, 'dt': 3.1622776e-06}., sample shape=(126, 10). metric: 61
Hyperparams: {'batch_size': 800, 'dt': 1e-05}., sample shape=(132, 10). metric: 152
Hyperparams: {'batch_size': 8000, 'dt': 3.1622778e-05}., sample shape=(56, 10). metric: 118
Hyperparams: {'batch_size': 8000, 'dt': 3.1622776e-07}., sample shape=(54, 10). metric: 59
Hyperparams: {'batch_size': 8000, 'dt': 3.1622776e-06}., sample shape=(57, 10). metric: 48
Hyperparams: {'batch_size': 8000, 'dt

HBox(children=(HTML(value='Iteration 3/3, 6 arms, time budget = 0.72 sec'), FloatProgress(value=0.0, max=6.0),…

Hyperparams: {'batch_size': 8000, 'dt': 1e-06}., sample shape=(185, 10). metric: 13
Hyperparams: {'batch_size': 800, 'dt': 1e-06}., sample shape=(527, 10). metric: 11
Hyperparams: {'batch_size': 800, 'dt': 3.1622776e-07}., sample shape=(451, 10). metric: 13
Hyperparams: {'batch_size': 8000, 'dt': 1e-05}., sample shape=(179, 10). metric: 51
Hyperparams: {'batch_size': 800, 'dt': 1e-07}., sample shape=(494, 10). metric: 25
Hyperparams: {'batch_size': 8000, 'dt': 3.1622776e-06}., sample shape=(179, 10). metric: 10

Number of samples: [179, 527]
Running time: 235.0 sec
{'batch_size': 8000, 'dt': 3.1622776e-06} 9.652608 (179, 10)
