In [1]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np

import jax.numpy as jnp
from jax import jit, random, vmap

from sgmcmcjax.kernels import build_sgld_kernel, build_sghmc_kernel, build_sgnht_kernel
from sgmcmcjax.kernels import build_sgldCV_kernel, build_sghmcCV_kernel, build_sgnhtCV_kernel
from sgmcmcjax.optimizer import build_adam_optimizer
from sgmcmcjax.util import build_grad_log_post

import context
from models.logistic_regression.logistic_regression_model import gen_data, loglikelihood, logprior
from tuning.mamba import run_MAMBA
from tuning.ksd import imq_KSD




# Logistic regression - MAMBA

In [2]:


key = random.PRNGKey(42)
dim = 10
# Ndata = 100_000
Ndata = 1_000_000

theta_true, X, y_data = gen_data(key, dim, Ndata)

# testing and training data
num_train = int(Ndata*0.8)

X_train = X[:num_train]
X_test = X[num_train:]

y_train = y_data[:num_train]
y_test = y_data[num_train:]
print(X_train.shape, X_test.shape)
data = (X_train, y_train)


generating data, with N=1,000,000 and dim=10
(800000, 10) (200000, 10)


In [3]:
# get MAP
batch_size = int(0.01*X_train.shape[0])
run_adam = build_adam_optimizer(1e-2, loglikelihood, logprior, data, batch_size)
key = random.PRNGKey(0)
centering_value, logpost_array = run_adam(key, 5000, jnp.zeros(dim))
params_IC = centering_value

# T = 10
R = 1 # running time of longest sampler

error_fn = lambda x,y: imq_KSD(x, y)
logdt_range = -jnp.arange(1., 8., 0.5) 
batch_size_range = [int(10**(-elem)*X_train.shape[0]) for elem in range(0, 4)]

print(batch_size_range)

[800000, 80000, 8000, 800]


In [4]:
grad_log_post_fb = build_grad_log_post(loglikelihood, logprior, data)

def grad_lp_LR_fb(theta):
    return grad_log_post_fb(theta, X_train, y_train)

batch_grad_lp_LR_fb = jit(vmap(grad_lp_LR_fb))

@jit
def get_fb_grads(samples):
    thin = 10
    array_samples = jnp.array(samples[::thin])
    mygrads = batch_grad_lp_LR_fb(array_samples)
    return array_samples, mygrads

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=5000.0), HTML(value='')))




### run MAMBA

In [5]:
%%time

build_kernel = lambda dt, batch_size: build_sgld_kernel(dt, loglikelihood, logprior, data, batch_size)

grid_params = {'log_dt': logdt_range, 
               'batch_size': batch_size_range
              }
best_arm = run_MAMBA(key, build_kernel, error_fn, R, params_IC, 
                     grid_params=grid_params, get_fb_grads=get_fb_grads)

print(best_arm.hyperparameters, best_arm.metric, best_arm.samples.shape)

HBox(children=(HTML(value='Iteration 1/3, 56 arms, time budget = 0.08 sec'), FloatProgress(value=0.0, max=56.0…

Hyperparams: {'batch_size': 800000, 'dt': 0.1}., sample shape=(1, 10). metric: 627644
Hyperparams: {'batch_size': 800000, 'dt': 0.031622775}., sample shape=(1, 10). metric: 620876
Hyperparams: {'batch_size': 800000, 'dt': 0.01}., sample shape=(1, 10). metric: 599319
Hyperparams: {'batch_size': 800000, 'dt': 0.0031622776}., sample shape=(1, 10). metric: 487506
Hyperparams: {'batch_size': 800000, 'dt': 0.001}., sample shape=(1, 10). metric: 230270
Hyperparams: {'batch_size': 800000, 'dt': 0.00031622776}., sample shape=(1, 10). metric: 69859
Hyperparams: {'batch_size': 800000, 'dt': 1e-04}., sample shape=(1, 10). metric: 22171
Hyperparams: {'batch_size': 800000, 'dt': 3.1622778e-05}., sample shape=(1, 10). metric: 4516
Hyperparams: {'batch_size': 800000, 'dt': 1e-05}., sample shape=(1, 10). metric: 1008
Hyperparams: {'batch_size': 800000, 'dt': 3.1622776e-06}., sample shape=(1, 10). metric: 2652
Hyperparams: {'batch_size': 800000, 'dt': 1e-06}., sample shape=(1, 10). metric: 2898
Hyperpar

HBox(children=(HTML(value='Iteration 2/3, 18 arms, time budget = 0.24 sec'), FloatProgress(value=0.0, max=18.0…

Hyperparams: {'batch_size': 800, 'dt': 1e-06}., sample shape=(179, 10). metric: 594
Hyperparams: {'batch_size': 800, 'dt': 3.1622776e-06}., sample shape=(230, 10). metric: 532
Hyperparams: {'batch_size': 80000, 'dt': 1e-05}., sample shape=(12, 10). metric: 732
Hyperparams: {'batch_size': 8000, 'dt': 1e-06}., sample shape=(70, 10). metric: 315
Hyperparams: {'batch_size': 800, 'dt': 1e-07}., sample shape=(214, 10). metric: 542
Hyperparams: {'batch_size': 800, 'dt': 3.1622776e-07}., sample shape=(231, 10). metric: 376
Hyperparams: {'batch_size': 800000, 'dt': 1e-05}., sample shape=(5, 10). metric: 283
Hyperparams: {'batch_size': 8000, 'dt': 3.1622776e-06}., sample shape=(69, 10). metric: 197
Hyperparams: {'batch_size': 80000, 'dt': 3.1622776e-06}., sample shape=(13, 10). metric: 367
Hyperparams: {'batch_size': 8000, 'dt': 1e-05}., sample shape=(61, 10). metric: 278
Hyperparams: {'batch_size': 8000, 'dt': 3.1622776e-07}., sample shape=(68, 10). metric: 342
Hyperparams: {'batch_size': 80000

HBox(children=(HTML(value='Iteration 3/3, 6 arms, time budget = 0.72 sec'), FloatProgress(value=0.0, max=6.0),…

Hyperparams: {'batch_size': 8000, 'dt': 3.1622776e-06}., sample shape=(226, 10). metric: 258
Hyperparams: {'batch_size': 8000, 'dt': 1e-05}., sample shape=(218, 10). metric: 300
Hyperparams: {'batch_size': 800000, 'dt': 1e-05}., sample shape=(14, 10). metric: 221
Hyperparams: {'batch_size': 800000, 'dt': 3.1622776e-06}., sample shape=(15, 10). metric: 214
Hyperparams: {'batch_size': 8000, 'dt': 1e-06}., sample shape=(232, 10). metric: 240
Hyperparams: {'batch_size': 80000, 'dt': 1e-06}., sample shape=(37, 10). metric: 277

Number of samples: [15, 14]
Running time: 115.2 sec
{'batch_size': 800000, 'dt': 3.1622776e-06} 213.62427 (15, 10)
CPU times: user 2min 15s, sys: 36.1 s, total: 2min 51s
Wall time: 1min 55s


In [8]:
%%time

build_kernel = lambda dt, batch_size: build_sgldCV_kernel(dt, loglikelihood, 
                                                  logprior, data, batch_size, centering_value)

best_arm = run_MAMBA(key, build_kernel, error_fn, R, params_IC, 
                     grid_params=grid_params, get_fb_grads=get_fb_grads)

print(best_arm.hyperparameters, best_arm.metric, best_arm.samples.shape)

HBox(children=(HTML(value='Iteration 1/3, 56 arms, time budget = 0.08 sec'), FloatProgress(value=0.0, max=56.0…

Hyperparams: {'batch_size': 800000, 'dt': 0.1}., sample shape=(1, 10). metric: 633948
Hyperparams: {'batch_size': 800000, 'dt': 0.031622775}., sample shape=(1, 10). metric: 625294
Hyperparams: {'batch_size': 800000, 'dt': 0.01}., sample shape=(1, 10). metric: 596300
Hyperparams: {'batch_size': 800000, 'dt': 0.0031622776}., sample shape=(1, 10). metric: 477570
Hyperparams: {'batch_size': 800000, 'dt': 0.001}., sample shape=(1, 10). metric: 217501
Hyperparams: {'batch_size': 800000, 'dt': 0.00031622776}., sample shape=(1, 10). metric: 73072
Hyperparams: {'batch_size': 800000, 'dt': 1e-04}., sample shape=(1, 10). metric: 19466
Hyperparams: {'batch_size': 800000, 'dt': 3.1622778e-05}., sample shape=(1, 10). metric: 3622
Hyperparams: {'batch_size': 800000, 'dt': 1e-05}., sample shape=(1, 10). metric: 930
Hyperparams: {'batch_size': 800000, 'dt': 3.1622776e-06}., sample shape=(1, 10). metric: 2590
Hyperparams: {'batch_size': 800000, 'dt': 1e-06}., sample shape=(1, 10). metric: 2878
Hyperpara

HBox(children=(HTML(value='Iteration 2/3, 18 arms, time budget = 0.24 sec'), FloatProgress(value=0.0, max=18.0…

Hyperparams: {'batch_size': 800, 'dt': 3.1622776e-06}., sample shape=(224, 10). metric: 78
Hyperparams: {'batch_size': 800, 'dt': 1e-05}., sample shape=(216, 10). metric: 36
Hyperparams: {'batch_size': 8000, 'dt': 1e-05}., sample shape=(49, 10). metric: 113
Hyperparams: {'batch_size': 800, 'dt': 1e-06}., sample shape=(201, 10). metric: 119
Hyperparams: {'batch_size': 8000, 'dt': 3.1622776e-06}., sample shape=(53, 10). metric: 124
Hyperparams: {'batch_size': 800, 'dt': 3.1622776e-07}., sample shape=(197, 10). metric: 157
Hyperparams: {'batch_size': 80000, 'dt': 1e-05}., sample shape=(9, 10). metric: 188
Hyperparams: {'batch_size': 8000, 'dt': 1e-06}., sample shape=(57, 10). metric: 235
Hyperparams: {'batch_size': 800, 'dt': 1e-07}., sample shape=(220, 10). metric: 277
Hyperparams: {'batch_size': 800000, 'dt': 1e-05}., sample shape=(2, 10). metric: 644
Hyperparams: {'batch_size': 8000, 'dt': 3.1622776e-07}., sample shape=(58, 10). metric: 406
Hyperparams: {'batch_size': 80000, 'dt': 3.16

HBox(children=(HTML(value='Iteration 3/3, 6 arms, time budget = 0.72 sec'), FloatProgress(value=0.0, max=6.0),…

Hyperparams: {'batch_size': 800, 'dt': 1e-05}., sample shape=(709, 10). metric: 29
Hyperparams: {'batch_size': 800, 'dt': 3.1622776e-06}., sample shape=(467, 10). metric: 54
Hyperparams: {'batch_size': 8000, 'dt': 1e-05}., sample shape=(175, 10). metric: 56
Hyperparams: {'batch_size': 800, 'dt': 1e-06}., sample shape=(623, 10). metric: 87
Hyperparams: {'batch_size': 8000, 'dt': 3.1622776e-06}., sample shape=(167, 10). metric: 52
Hyperparams: {'batch_size': 800, 'dt': 3.1622776e-07}., sample shape=(600, 10). metric: 102

Number of samples: [709, 167]
Running time: 152.0 sec
{'batch_size': 800, 'dt': 1e-05} 29.480639 (709, 10)
CPU times: user 3min 1s, sys: 1min, total: 4min 2s
Wall time: 2min 40s


In [9]:
%%time 

build_kernel = lambda dt, L, batch_size: build_sghmc_kernel(dt, L, loglikelihood, 
                                                  logprior, data, batch_size)

grid_params = {'log_dt': logdt_range, 
               'batch_size': batch_size_range,
               "L": [5, 10]
              }
best_arm = run_MAMBA(key, build_kernel, error_fn, R, params_IC, 
                     grid_params=grid_params, get_fb_grads=get_fb_grads)

print(best_arm.hyperparameters, best_arm.metric, best_arm.samples.shape)

HBox(children=(HTML(value='Iteration 1/4, 112 arms, time budget = 0.03 sec'), FloatProgress(value=0.0, max=112…

Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 0.1}., sample shape=(1, 10). metric: 391962
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 0.031622775}., sample shape=(1, 10). metric: 315911
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 0.01}., sample shape=(1, 10). metric: 351372
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 0.0031622776}., sample shape=(1, 10). metric: 416450
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 0.001}., sample shape=(1, 10). metric: 436851
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 0.00031622776}., sample shape=(1, 10). metric: 463559
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 1e-04}., sample shape=(1, 10). metric: 347783
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 3.1622778e-05}., sample shape=(1, 10). metric: 15751
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 1e-05}., sample shape=(1, 10). metric: 10584
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 3.1622776e-06}., sample shape=(1, 10). metric: 2970
Hyperparam

Hyperparams: {'batch_size': 800, 'L': 5, 'dt': 0.1}., sample shape=(11, 10). metric: 68155
Hyperparams: {'batch_size': 800, 'L': 5, 'dt': 0.031622775}., sample shape=(12, 10). metric: 49789
Hyperparams: {'batch_size': 800, 'L': 5, 'dt': 0.01}., sample shape=(11, 10). metric: 42416
Hyperparams: {'batch_size': 800, 'L': 5, 'dt': 0.0031622776}., sample shape=(11, 10). metric: 45929
Hyperparams: {'batch_size': 800, 'L': 5, 'dt': 0.001}., sample shape=(11, 10). metric: 37818
Hyperparams: {'batch_size': 800, 'L': 5, 'dt': 0.00031622776}., sample shape=(12, 10). metric: 43111
Hyperparams: {'batch_size': 800, 'L': 5, 'dt': 1e-04}., sample shape=(11, 10). metric: 43477
Hyperparams: {'batch_size': 800, 'L': 5, 'dt': 3.1622778e-05}., sample shape=(11, 10). metric: 32026
Hyperparams: {'batch_size': 800, 'L': 5, 'dt': 1e-05}., sample shape=(12, 10). metric: 37742
Hyperparams: {'batch_size': 800, 'L': 5, 'dt': 3.1622776e-06}., sample shape=(11, 10). metric: 6698
Hyperparams: {'batch_size': 800, 'L':

HBox(children=(HTML(value='Iteration 2/4, 37 arms, time budget = 0.08 sec'), FloatProgress(value=0.0, max=37.0…

Hyperparams: {'batch_size': 800, 'L': 5, 'dt': 3.1622776e-07}., sample shape=(34, 10). metric: 544
Hyperparams: {'batch_size': 800, 'L': 10, 'dt': 3.1622776e-08}., sample shape=(28, 10). metric: 538
Hyperparams: {'batch_size': 800000, 'L': 10, 'dt': 3.1622776e-07}., sample shape=(2, 10). metric: 736
Hyperparams: {'batch_size': 8000, 'L': 10, 'dt': 3.1622776e-07}., sample shape=(4, 10). metric: 1104
Hyperparams: {'batch_size': 800, 'L': 5, 'dt': 1e-07}., sample shape=(41, 10). metric: 964
Hyperparams: {'batch_size': 800, 'L': 10, 'dt': 1e-07}., sample shape=(20, 10). metric: 498
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 1e-06}., sample shape=(2, 10). metric: 340
Hyperparams: {'batch_size': 80000, 'L': 10, 'dt': 3.1622776e-07}., sample shape=(2, 10). metric: 945
Hyperparams: {'batch_size': 800, 'L': 5, 'dt': 3.1622776e-08}., sample shape=(43, 10). metric: 611
Hyperparams: {'batch_size': 8000, 'L': 5, 'dt': 3.1622776e-07}., sample shape=(5, 10). metric: 613
Hyperparams: {'batch_si

HBox(children=(HTML(value='Iteration 3/4, 12 arms, time budget = 0.23 sec'), FloatProgress(value=0.0, max=12.0…

Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 1e-06}., sample shape=(3, 10). metric: 496
Hyperparams: {'batch_size': 800, 'L': 10, 'dt': 1e-07}., sample shape=(84, 10). metric: 169
Hyperparams: {'batch_size': 800, 'L': 10, 'dt': 3.1622776e-08}., sample shape=(91, 10). metric: 197
Hyperparams: {'batch_size': 800, 'L': 5, 'dt': 3.1622776e-07}., sample shape=(135, 10). metric: 151
Hyperparams: {'batch_size': 8000, 'L': 5, 'dt': 1e-07}., sample shape=(20, 10). metric: 614
Hyperparams: {'batch_size': 800, 'L': 5, 'dt': 3.1622776e-08}., sample shape=(141, 10). metric: 346
Hyperparams: {'batch_size': 8000, 'L': 5, 'dt': 3.1622776e-07}., sample shape=(19, 10). metric: 307
Hyperparams: {'batch_size': 800, 'L': 10, 'dt': 3.1622776e-07}., sample shape=(98, 10). metric: 738
Hyperparams: {'batch_size': 800000, 'L': 10, 'dt': 3.1622776e-07}., sample shape=(3, 10). metric: 606
Hyperparams: {'batch_size': 8000, 'L': 10, 'dt': 1e-07}., sample shape=(10, 10). metric: 330
Hyperparams: {'batch_size': 

HBox(children=(HTML(value='Iteration 4/4, 4 arms, time budget = 0.70 sec'), FloatProgress(value=0.0, max=4.0),…

Hyperparams: {'batch_size': 800, 'L': 5, 'dt': 3.1622776e-07}., sample shape=(465, 10). metric: 224
Hyperparams: {'batch_size': 800, 'L': 10, 'dt': 1e-07}., sample shape=(286, 10). metric: 325
Hyperparams: {'batch_size': 800, 'L': 10, 'dt': 3.1622776e-08}., sample shape=(300, 10). metric: 150
Hyperparams: {'batch_size': 8000, 'L': 5, 'dt': 3.1622776e-07}., sample shape=(60, 10). metric: 241

Number of samples: [300]
Running time: 239.0 sec
{'batch_size': 800, 'L': 10, 'dt': 3.1622776e-08} 150.34276 (300, 10)
CPU times: user 4min 8s, sys: 36.2 s, total: 4min 44s
Wall time: 3min 59s


In [10]:
%%time

build_kernel = lambda dt, L, batch_size: build_sghmcCV_kernel(dt, L, loglikelihood, 
                                                  logprior, data, batch_size, centering_value)

grid_params = {'log_dt': logdt_range, 
               'batch_size': batch_size_range,
               "L": [5, 10]
              }
best_arm = run_MAMBA(key, build_kernel, error_fn, R, params_IC, 
                     grid_params=grid_params, get_fb_grads=get_fb_grads)

print(best_arm.hyperparameters, best_arm.metric, best_arm.samples.shape)

HBox(children=(HTML(value='Iteration 1/4, 112 arms, time budget = 0.03 sec'), FloatProgress(value=0.0, max=112…

Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 0.1}., sample shape=(1, 10). metric: 390082
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 0.031622775}., sample shape=(1, 10). metric: 309232
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 0.01}., sample shape=(1, 10). metric: 346898
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 0.0031622776}., sample shape=(1, 10). metric: 416598
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 0.001}., sample shape=(1, 10). metric: 435713
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 0.00031622776}., sample shape=(1, 10). metric: 458591
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 1e-04}., sample shape=(1, 10). metric: 336311
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 3.1622778e-05}., sample shape=(1, 10). metric: 11153
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 1e-05}., sample shape=(1, 10). metric: 10747
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 3.1622776e-06}., sample shape=(1, 10). metric: 3178
Hyperparam

Hyperparams: {'batch_size': 800, 'L': 5, 'dt': 0.1}., sample shape=(10, 10). metric: 83090
Hyperparams: {'batch_size': 800, 'L': 5, 'dt': 0.031622775}., sample shape=(9, 10). metric: 60704
Hyperparams: {'batch_size': 800, 'L': 5, 'dt': 0.01}., sample shape=(9, 10). metric: 58946
Hyperparams: {'batch_size': 800, 'L': 5, 'dt': 0.0031622776}., sample shape=(10, 10). metric: 52544
Hyperparams: {'batch_size': 800, 'L': 5, 'dt': 0.001}., sample shape=(10, 10). metric: 51304
Hyperparams: {'batch_size': 800, 'L': 5, 'dt': 0.00031622776}., sample shape=(10, 10). metric: 53056
Hyperparams: {'batch_size': 800, 'L': 5, 'dt': 1e-04}., sample shape=(10, 10). metric: 44199
Hyperparams: {'batch_size': 800, 'L': 5, 'dt': 3.1622778e-05}., sample shape=(10, 10). metric: 32222
Hyperparams: {'batch_size': 800, 'L': 5, 'dt': 1e-05}., sample shape=(10, 10). metric: 39187
Hyperparams: {'batch_size': 800, 'L': 5, 'dt': 3.1622776e-06}., sample shape=(10, 10). metric: 16305
Hyperparams: {'batch_size': 800, 'L': 

HBox(children=(HTML(value='Iteration 2/4, 37 arms, time budget = 0.08 sec'), FloatProgress(value=0.0, max=37.0…

Hyperparams: {'batch_size': 800, 'L': 10, 'dt': 3.1622776e-07}., sample shape=(25, 10). metric: 116
Hyperparams: {'batch_size': 800, 'L': 5, 'dt': 1e-06}., sample shape=(40, 10). metric: 114
Hyperparams: {'batch_size': 800, 'L': 5, 'dt': 3.1622776e-07}., sample shape=(39, 10). metric: 150
Hyperparams: {'batch_size': 800, 'L': 10, 'dt': 1e-07}., sample shape=(24, 10). metric: 230
Hyperparams: {'batch_size': 800, 'L': 5, 'dt': 1e-07}., sample shape=(40, 10). metric: 183
Hyperparams: {'batch_size': 800, 'L': 10, 'dt': 3.1622776e-08}., sample shape=(24, 10). metric: 232
Hyperparams: {'batch_size': 800000, 'L': 10, 'dt': 3.1622776e-07}., sample shape=(2, 10). metric: 551
Hyperparams: {'batch_size': 80000, 'L': 10, 'dt': 3.1622776e-07}., sample shape=(2, 10). metric: 713
Hyperparams: {'batch_size': 8000, 'L': 10, 'dt': 3.1622776e-07}., sample shape=(3, 10). metric: 375
Hyperparams: {'batch_size': 8000, 'L': 5, 'dt': 1e-06}., sample shape=(6, 10). metric: 380
Hyperparams: {'batch_size': 80000

HBox(children=(HTML(value='Iteration 3/4, 12 arms, time budget = 0.23 sec'), FloatProgress(value=0.0, max=12.0…

Hyperparams: {'batch_size': 800, 'L': 5, 'dt': 1e-06}., sample shape=(131, 10). metric: 126
Hyperparams: {'batch_size': 800, 'L': 10, 'dt': 3.1622776e-07}., sample shape=(82, 10). metric: 80
Hyperparams: {'batch_size': 800, 'L': 5, 'dt': 3.1622776e-07}., sample shape=(133, 10). metric: 90
Hyperparams: {'batch_size': 800, 'L': 5, 'dt': 1e-07}., sample shape=(128, 10). metric: 218
Hyperparams: {'batch_size': 800, 'L': 10, 'dt': 1e-07}., sample shape=(81, 10). metric: 105
Hyperparams: {'batch_size': 800, 'L': 10, 'dt': 3.1622776e-08}., sample shape=(80, 10). metric: 125
Hyperparams: {'batch_size': 8000, 'L': 5, 'dt': 3.1622776e-07}., sample shape=(16, 10). metric: 275
Hyperparams: {'batch_size': 8000, 'L': 10, 'dt': 3.1622776e-07}., sample shape=(9, 10). metric: 307
Hyperparams: {'batch_size': 8000, 'L': 5, 'dt': 1e-06}., sample shape=(17, 10). metric: 175
Hyperparams: {'batch_size': 800, 'L': 5, 'dt': 3.1622776e-08}., sample shape=(128, 10). metric: 175
Hyperparams: {'batch_size': 8000, 

HBox(children=(HTML(value='Iteration 4/4, 4 arms, time budget = 0.70 sec'), FloatProgress(value=0.0, max=4.0),…

Hyperparams: {'batch_size': 800, 'L': 10, 'dt': 3.1622776e-07}., sample shape=(252, 10). metric: 67
Hyperparams: {'batch_size': 800, 'L': 5, 'dt': 3.1622776e-07}., sample shape=(398, 10). metric: 53
Hyperparams: {'batch_size': 800, 'L': 10, 'dt': 1e-07}., sample shape=(250, 10). metric: 48
Hyperparams: {'batch_size': 800, 'L': 10, 'dt': 3.1622776e-08}., sample shape=(253, 10). metric: 74

Number of samples: [250]
Running time: 270.2 sec
{'batch_size': 800, 'L': 10, 'dt': 1e-07} 48.145916 (250, 10)
CPU times: user 5min 18s, sys: 35.2 s, total: 5min 53s
Wall time: 4min 46s


In [11]:
%%time

build_kernel = lambda dt, batch_size: build_sgnht_kernel(dt, loglikelihood, logprior, data, batch_size)

grid_params = {'log_dt': logdt_range, 
               'batch_size': batch_size_range
              }
best_arm = run_MAMBA(key, build_kernel, error_fn, R, params_IC, 
                     grid_params=grid_params, get_fb_grads=get_fb_grads)

print(best_arm.hyperparameters, best_arm.metric, best_arm.samples.shape)

HBox(children=(HTML(value='Iteration 1/3, 56 arms, time budget = 0.08 sec'), FloatProgress(value=0.0, max=56.0…

Hyperparams: {'batch_size': 800000, 'dt': 0.1}., sample shape=(1, 10). metric: 627259
Hyperparams: {'batch_size': 800000, 'dt': 0.031622775}., sample shape=(1, 10). metric: 621158
Hyperparams: {'batch_size': 800000, 'dt': 0.01}., sample shape=(1, 10). metric: 597799
Hyperparams: {'batch_size': 800000, 'dt': 0.0031622776}., sample shape=(1, 10). metric: 486930
Hyperparams: {'batch_size': 800000, 'dt': 0.001}., sample shape=(1, 10). metric: 237438
Hyperparams: {'batch_size': 800000, 'dt': 0.00031622776}., sample shape=(1, 10). metric: 66769
Hyperparams: {'batch_size': 800000, 'dt': 1e-04}., sample shape=(1, 10). metric: 23182
Hyperparams: {'batch_size': 800000, 'dt': 3.1622778e-05}., sample shape=(1, 10). metric: 5319
Hyperparams: {'batch_size': 800000, 'dt': 1e-05}., sample shape=(1, 10). metric: 1454
Hyperparams: {'batch_size': 800000, 'dt': 3.1622776e-06}., sample shape=(1, 10). metric: 2576
Hyperparams: {'batch_size': 800000, 'dt': 1e-06}., sample shape=(1, 10). metric: 3050
Hyperpar

HBox(children=(HTML(value='Iteration 2/3, 18 arms, time budget = 0.24 sec'), FloatProgress(value=0.0, max=18.0…

Hyperparams: {'batch_size': 80000, 'dt': 1e-06}., sample shape=(13, 10). metric: 213
Hyperparams: {'batch_size': 80000, 'dt': 3.1622776e-06}., sample shape=(12, 10). metric: 729
Hyperparams: {'batch_size': 8000, 'dt': 1e-07}., sample shape=(56, 10). metric: 287
Hyperparams: {'batch_size': 80000, 'dt': 3.1622776e-07}., sample shape=(12, 10). metric: 208
Hyperparams: {'batch_size': 800, 'dt': 3.1622776e-08}., sample shape=(202, 10). metric: 544
Hyperparams: {'batch_size': 8000, 'dt': 1e-06}., sample shape=(62, 10). metric: 429
Hyperparams: {'batch_size': 8000, 'dt': 3.1622776e-07}., sample shape=(67, 10). metric: 383
Hyperparams: {'batch_size': 8000, 'dt': 3.1622776e-08}., sample shape=(66, 10). metric: 235
Hyperparams: {'batch_size': 800, 'dt': 1e-06}., sample shape=(184, 10). metric: 725
Hyperparams: {'batch_size': 800, 'dt': 3.1622776e-06}., sample shape=(201, 10). metric: 896
Hyperparams: {'batch_size': 800, 'dt': 1e-07}., sample shape=(188, 10). metric: 566
Hyperparams: {'batch_size

HBox(children=(HTML(value='Iteration 3/3, 6 arms, time budget = 0.72 sec'), FloatProgress(value=0.0, max=6.0),…

Hyperparams: {'batch_size': 80000, 'dt': 3.1622776e-07}., sample shape=(40, 10). metric: 206
Hyperparams: {'batch_size': 80000, 'dt': 1e-06}., sample shape=(37, 10). metric: 107
Hyperparams: {'batch_size': 8000, 'dt': 3.1622776e-08}., sample shape=(224, 10). metric: 135
Hyperparams: {'batch_size': 8000, 'dt': 1e-07}., sample shape=(206, 10). metric: 116
Hyperparams: {'batch_size': 80000, 'dt': 1e-07}., sample shape=(36, 10). metric: 132
Hyperparams: {'batch_size': 8000, 'dt': 3.1622776e-07}., sample shape=(227, 10). metric: 170

Number of samples: [37, 206]
Running time: 176.9 sec
{'batch_size': 80000, 'dt': 1e-06} 107.34649 (37, 10)
CPU times: user 3min 8s, sys: 39.3 s, total: 3min 47s
Wall time: 2min 57s


In [13]:
%%time

build_kernel = lambda dt, batch_size: build_sgnhtCV_kernel(dt, loglikelihood, 
                                                  logprior, data, batch_size, centering_value)
grid_params = {'log_dt': logdt_range, 
               'batch_size': batch_size_range
              }
best_arm = run_MAMBA(key, build_kernel, error_fn, R, params_IC, 
                     grid_params=grid_params, get_fb_grads=get_fb_grads)

print(best_arm.hyperparameters, best_arm.metric, best_arm.samples.shape)

HBox(children=(HTML(value='Iteration 1/3, 56 arms, time budget = 0.08 sec'), FloatProgress(value=0.0, max=56.0…

Hyperparams: {'batch_size': 800000, 'dt': 0.1}., sample shape=(1, 10). metric: 633599
Hyperparams: {'batch_size': 800000, 'dt': 0.031622775}., sample shape=(1, 10). metric: 625572
Hyperparams: {'batch_size': 800000, 'dt': 0.01}., sample shape=(1, 10). metric: 594703
Hyperparams: {'batch_size': 800000, 'dt': 0.0031622776}., sample shape=(1, 10). metric: 476958
Hyperparams: {'batch_size': 800000, 'dt': 0.001}., sample shape=(1, 10). metric: 225221
Hyperparams: {'batch_size': 800000, 'dt': 0.00031622776}., sample shape=(1, 10). metric: 70078
Hyperparams: {'batch_size': 800000, 'dt': 1e-04}., sample shape=(1, 10). metric: 20501
Hyperparams: {'batch_size': 800000, 'dt': 3.1622778e-05}., sample shape=(1, 10). metric: 4429
Hyperparams: {'batch_size': 800000, 'dt': 1e-05}., sample shape=(1, 10). metric: 1291
Hyperparams: {'batch_size': 800000, 'dt': 3.1622776e-06}., sample shape=(1, 10). metric: 2516
Hyperparams: {'batch_size': 800000, 'dt': 1e-06}., sample shape=(1, 10). metric: 3030
Hyperpar

HBox(children=(HTML(value='Iteration 2/3, 18 arms, time budget = 0.24 sec'), FloatProgress(value=0.0, max=18.0…

Hyperparams: {'batch_size': 800, 'dt': 1e-06}., sample shape=(140, 10). metric: 25
Hyperparams: {'batch_size': 800, 'dt': 1e-07}., sample shape=(141, 10). metric: 44
Hyperparams: {'batch_size': 8000, 'dt': 1e-05}., sample shape=(36, 10). metric: 98
Hyperparams: {'batch_size': 800, 'dt': 3.1622776e-06}., sample shape=(107, 10). metric: 51
Hyperparams: {'batch_size': 800, 'dt': 3.1622776e-07}., sample shape=(95, 10). metric: 46
Hyperparams: {'batch_size': 8000, 'dt': 3.1622776e-07}., sample shape=(49, 10). metric: 55
Hyperparams: {'batch_size': 800, 'dt': 1e-05}., sample shape=(143, 10). metric: 102
Hyperparams: {'batch_size': 8000, 'dt': 3.1622778e-05}., sample shape=(52, 10). metric: 146
Hyperparams: {'batch_size': 8000, 'dt': 1e-06}., sample shape=(53, 10). metric: 45
Hyperparams: {'batch_size': 800, 'dt': 3.1622776e-08}., sample shape=(165, 10). metric: 37
Hyperparams: {'batch_size': 8000, 'dt': 3.1622776e-06}., sample shape=(55, 10). metric: 35
Hyperparams: {'batch_size': 8000, 'dt'

HBox(children=(HTML(value='Iteration 3/3, 6 arms, time budget = 0.72 sec'), FloatProgress(value=0.0, max=6.0),…

Hyperparams: {'batch_size': 800, 'dt': 1e-06}., sample shape=(560, 10). metric: 12
Hyperparams: {'batch_size': 8000, 'dt': 3.1622776e-06}., sample shape=(149, 10). metric: 20
Hyperparams: {'batch_size': 800, 'dt': 3.1622776e-08}., sample shape=(563, 10). metric: 30
Hyperparams: {'batch_size': 800, 'dt': 1e-07}., sample shape=(523, 10). metric: 30
Hyperparams: {'batch_size': 8000, 'dt': 1e-06}., sample shape=(162, 10). metric: 20
Hyperparams: {'batch_size': 800, 'dt': 3.1622776e-07}., sample shape=(428, 10). metric: 18

Number of samples: [560, 428]
Running time: 243.1 sec
{'batch_size': 800, 'dt': 1e-06} 11.950268 (560, 10)
CPU times: user 4min 16s, sys: 57.9 s, total: 5min 14s
Wall time: 4min 11s
