In [1]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np

import jax.numpy as jnp
from jax import jit, random, vmap

from sgmcmcjax.kernels import build_sgld_kernel, build_sghmc_kernel, build_sgnht_kernel
from sgmcmcjax.kernels import build_sgldCV_kernel, build_sghmcCV_kernel, build_sgnhtCV_kernel
from sgmcmcjax.optimizer import build_adam_optimizer
from sgmcmcjax.util import build_grad_log_post

import context
from models.logistic_regression.logistic_regression_model import gen_data, loglikelihood, logprior
from tuning.mamba import run_MAMBA
from tuning.ksd import imq_KSD


# Logistic regression - tuning SG-MCMC samplers with MAMBA

In [2]:


# Simulate a logistic-regression dataset and split it into train/test sets.
key = random.PRNGKey(42)
dim = 10
Ndata = 1_000_000  # reduce (e.g. to 100_000) for a quicker run
TRAIN_FRAC = 0.8   # fraction of rows used for training; the rest is held out

theta_true, X, y_data = gen_data(key, dim, Ndata)

# testing and training data: first TRAIN_FRAC of the rows train, the remainder test
num_train = int(Ndata * TRAIN_FRAC)

X_train, X_test = X[:num_train], X[num_train:]
y_train, y_test = y_data[:num_train], y_data[num_train:]

# sanity checks on the split: no rows lost, one label per training row
assert X_train.shape[0] + X_test.shape[0] == Ndata
assert len(y_train) == len(X_train)

print(X_train.shape, X_test.shape)
data = (X_train, y_train)




generating data, with N=1,000,000 and dim=10
(800000, 10) (200000, 10)


In [4]:
# Get the MAP with Adam on 1% minibatches. The result is used as the starting
# point of every sampler (`params_IC`) and as the centering value of the
# control-variate (CV) samplers below.
n_train = X_train.shape[0]
# integer division is exact; int(0.01 * n) could truncate down under float rounding
batch_size = n_train // 100
# presumably 1e-2 is the Adam learning rate -- confirm against build_adam_optimizer
run_adam = build_adam_optimizer(1e-2, loglikelihood, logprior, data, batch_size)
# NOTE(review): this same `key` is also passed to every run_MAMBA call below, so
# the optimizer and the samplers share one PRNG key. Split it (random.split) if
# independent randomness is wanted -- doing so will change the recorded results.
key = random.PRNGKey(0)
# presumably 5000 is the number of Adam iterations; start from the origin
centering_value, logpost_array = run_adam(key, 5000, jnp.zeros(dim))
params_IC = centering_value

R = 1 # running time of longest sampler

# KSD (inverse multi-quadric kernel) is the metric MAMBA minimises
error_fn = imq_KSD
# log10 step sizes -1, -1.5, ..., -7.5  (dt from 1e-1 down to ~3.2e-8)
logdt_range = -jnp.arange(1., 8., 0.5)
# minibatch sizes N, N/10, ..., N/10^4 (exact integer division)
batch_size_range = [n_train // 10**elem for elem in range(0, 5)]

In [6]:
# Full-batch gradient of the log-posterior, used by MAMBA to evaluate the KSD.
grad_log_post_fb = build_grad_log_post(loglikelihood, logprior, data)


def grad_lp_LR_fb(theta):
    """Gradient of the log-posterior at one parameter vector, over the whole training set."""
    return grad_log_post_fb(theta, X_train, y_train)


# vectorised over a stack of parameter vectors, then compiled
batch_grad_lp_LR_fb = jit(vmap(grad_lp_LR_fb))


@jit
def get_fb_grads(samples):
    """Keep every 10th sample and return ``(thinned_samples, full_batch_grads)``."""
    thinned = jnp.array(samples[::10])
    return thinned, batch_grad_lp_LR_fb(thinned)

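### Run MAMBA

For each sampler below, MAMBA searches over the grid of step sizes and minibatch sizes (and leapfrog steps `L` for SGHMC), scoring each arm by its KSD.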

In [7]:
def build_kernel(dt, batch_size):
    """SGLD kernel for a given step size and minibatch size."""
    return build_sgld_kernel(dt, loglikelihood, logprior, data, batch_size)


grid_params = dict(
    log_dt=logdt_range,
    batch_size=batch_size_range,
)
best_arm = run_MAMBA(
    key, build_kernel, error_fn, R, params_IC,
    grid_params=grid_params,
    get_fb_grads=get_fb_grads,
)

print(best_arm.hyperparameters, best_arm.metric, best_arm.samples.shape)

HBox(children=(HTML(value='Iteration 1/3, 70 arms, time budget = 0.08 sec'), FloatProgress(value=0.0, max=70.0…

Hyperparams: {'batch_size': 800000, 'dt': 0.1}. metric: 313867
Hyperparams: {'batch_size': 800000, 'dt': 0.031622775}. metric: 310463
Hyperparams: {'batch_size': 800000, 'dt': 0.01}. metric: 299682
Hyperparams: {'batch_size': 800000, 'dt': 0.0031622776}. metric: 244357
Hyperparams: {'batch_size': 800000, 'dt': 0.001}. metric: 115542
Hyperparams: {'batch_size': 800000, 'dt': 0.00031622776}. metric: 35102
Hyperparams: {'batch_size': 800000, 'dt': 1e-04}. metric: 12068
Hyperparams: {'batch_size': 800000, 'dt': 3.1622778e-05}. metric: 39675
Hyperparams: {'batch_size': 800000, 'dt': 1e-05}. metric: 846
Hyperparams: {'batch_size': 800000, 'dt': 3.1622776e-06}. metric: 1159
Hyperparams: {'batch_size': 800000, 'dt': 1e-06}. metric: 2127
Hyperparams: {'batch_size': 800000, 'dt': 3.1622776e-07}. metric: 2993
Hyperparams: {'batch_size': 800000, 'dt': 1e-07}. metric: 3141
Hyperparams: {'batch_size': 800000, 'dt': 3.1622776e-08}. metric: 3174
Hyperparams: {'batch_size': 80000, 'dt': 0.1}. metric: 2

HBox(children=(HTML(value='Iteration 2/3, 23 arms, time budget = 0.23 sec'), FloatProgress(value=0.0, max=23.0…

Hyperparams: {'batch_size': 800, 'dt': 1e-06}. metric: 400
Hyperparams: {'batch_size': 80000, 'dt': 1e-05}. metric: 115
Hyperparams: {'batch_size': 800000, 'dt': 1e-05}. metric: 373
Hyperparams: {'batch_size': 8000, 'dt': 1e-06}. metric: 353
Hyperparams: {'batch_size': 800, 'dt': 1e-07}. metric: 287
Hyperparams: {'batch_size': 800, 'dt': 3.1622776e-06}. metric: 856
Hyperparams: {'batch_size': 8000, 'dt': 3.1622776e-06}. metric: 214
Hyperparams: {'batch_size': 800, 'dt': 3.1622776e-07}. metric: 317
Hyperparams: {'batch_size': 800000, 'dt': 3.1622776e-06}. metric: 354
Hyperparams: {'batch_size': 8000, 'dt': 3.1622776e-07}. metric: 412
Hyperparams: {'batch_size': 8000, 'dt': 1e-05}. metric: 634
Hyperparams: {'batch_size': 80000, 'dt': 3.1622776e-06}. metric: 386
Hyperparams: {'batch_size': 800, 'dt': 1e-05}. metric: 1214
Hyperparams: {'batch_size': 80000, 'dt': 1e-06}. metric: 399
Hyperparams: {'batch_size': 80, 'dt': 1e-06}. metric: 1237
Hyperparams: {'batch_size': 800000, 'dt': 1e-06}. 

HBox(children=(HTML(value='Iteration 3/3, 7 arms, time budget = 0.77 sec'), FloatProgress(value=0.0, max=7.0),…

Hyperparams: {'batch_size': 80000, 'dt': 1e-05}. metric: 193
Hyperparams: {'batch_size': 8000, 'dt': 3.1622776e-06}. metric: 252
Hyperparams: {'batch_size': 800, 'dt': 1e-07}. metric: 421
Hyperparams: {'batch_size': 800, 'dt': 3.1622776e-07}. metric: 331
Hyperparams: {'batch_size': 8000, 'dt': 1e-06}. metric: 212
Hyperparams: {'batch_size': 800000, 'dt': 3.1622776e-06}. metric: 178
Hyperparams: {'batch_size': 800000, 'dt': 1e-05}. metric: 358

Number of samples: [13, 42]
Running time: 161.0 sec
{'batch_size': 800000, 'dt': 3.1622776e-06} 177.81099 (13, 10)


In [8]:
def build_kernel(dt, batch_size):
    """SGLD-CV kernel (control variates centred at the Adam MAP `centering_value`)."""
    return build_sgldCV_kernel(dt, loglikelihood, logprior, data, batch_size,
                               centering_value)


# Define the grid here instead of reusing the one left over from the SGLD cell,
# so this cell gives the same result regardless of which cells ran before it.
grid_params = {'log_dt': logdt_range,
               'batch_size': batch_size_range
              }
best_arm = run_MAMBA(key, build_kernel, error_fn, R, params_IC,
                     grid_params=grid_params, get_fb_grads=get_fb_grads)

print(best_arm.hyperparameters, best_arm.metric, best_arm.samples.shape)

HBox(children=(HTML(value='Iteration 1/3, 70 arms, time budget = 0.08 sec'), FloatProgress(value=0.0, max=70.0…

Hyperparams: {'batch_size': 800000, 'dt': 0.1}. metric: 633948
Hyperparams: {'batch_size': 800000, 'dt': 0.031622775}. metric: 625294
Hyperparams: {'batch_size': 800000, 'dt': 0.01}. metric: 596300
Hyperparams: {'batch_size': 800000, 'dt': 0.0031622776}. metric: 477570
Hyperparams: {'batch_size': 800000, 'dt': 0.001}. metric: 217501
Hyperparams: {'batch_size': 800000, 'dt': 0.00031622776}. metric: 73072
Hyperparams: {'batch_size': 800000, 'dt': 1e-04}. metric: 19466
Hyperparams: {'batch_size': 800000, 'dt': 3.1622778e-05}. metric: 3622
Hyperparams: {'batch_size': 800000, 'dt': 1e-05}. metric: 930
Hyperparams: {'batch_size': 800000, 'dt': 3.1622776e-06}. metric: 2590
Hyperparams: {'batch_size': 800000, 'dt': 1e-06}. metric: 2878
Hyperparams: {'batch_size': 800000, 'dt': 3.1622776e-07}. metric: 3219
Hyperparams: {'batch_size': 800000, 'dt': 1e-07}. metric: 3166
Hyperparams: {'batch_size': 800000, 'dt': 3.1622776e-08}. metric: 3207
Hyperparams: {'batch_size': 80000, 'dt': 0.1}. metric: 21

HBox(children=(HTML(value='Iteration 2/3, 23 arms, time budget = 0.23 sec'), FloatProgress(value=0.0, max=23.0…

Hyperparams: {'batch_size': 800, 'dt': 1e-05}. metric: 68
Hyperparams: {'batch_size': 800, 'dt': 3.1622776e-06}. metric: 69
Hyperparams: {'batch_size': 8000, 'dt': 1e-05}. metric: 195
Hyperparams: {'batch_size': 80, 'dt': 1e-05}. metric: 156
Hyperparams: {'batch_size': 800, 'dt': 1e-06}. metric: 113
Hyperparams: {'batch_size': 80, 'dt': 3.1622776e-06}. metric: 70
Hyperparams: {'batch_size': 80, 'dt': 1e-06}. metric: 146
Hyperparams: {'batch_size': 8000, 'dt': 3.1622776e-06}. metric: 82
Hyperparams: {'batch_size': 80, 'dt': 3.1622776e-07}. metric: 142
Hyperparams: {'batch_size': 800, 'dt': 3.1622776e-07}. metric: 203
Hyperparams: {'batch_size': 8000, 'dt': 1e-06}. metric: 201
Hyperparams: {'batch_size': 80000, 'dt': 1e-05}. metric: 298
Hyperparams: {'batch_size': 800, 'dt': 1e-07}. metric: 354
Hyperparams: {'batch_size': 800000, 'dt': 1e-05}. metric: 543
Hyperparams: {'batch_size': 8000, 'dt': 3.1622776e-07}. metric: 323
Hyperparams: {'batch_size': 80000, 'dt': 3.1622776e-06}. metric: 3

HBox(children=(HTML(value='Iteration 3/3, 7 arms, time budget = 0.77 sec'), FloatProgress(value=0.0, max=7.0),…

Hyperparams: {'batch_size': 800, 'dt': 1e-05}. metric: 61
Hyperparams: {'batch_size': 800, 'dt': 3.1622776e-06}. metric: 30
Hyperparams: {'batch_size': 80, 'dt': 3.1622776e-06}. metric: 39
Hyperparams: {'batch_size': 8000, 'dt': 3.1622776e-06}. metric: 59
Hyperparams: {'batch_size': 800, 'dt': 1e-06}. metric: 58
Hyperparams: {'batch_size': 80, 'dt': 3.1622776e-07}. metric: 170
Hyperparams: {'batch_size': 80, 'dt': 1e-06}. metric: 71

Number of samples: [488, 500]
Running time: 171.8 sec
{'batch_size': 800, 'dt': 3.1622776e-06} 29.6894 (488, 10)


In [9]:
def build_kernel(dt, L, batch_size):
    """SGHMC kernel for a given step size, number of leapfrog steps and minibatch size."""
    return build_sghmc_kernel(dt, L, loglikelihood, logprior, data, batch_size)


grid_params = dict(
    log_dt=logdt_range,
    batch_size=batch_size_range,
    L=[5, 10],
)
best_arm = run_MAMBA(
    key, build_kernel, error_fn, R, params_IC,
    grid_params=grid_params,
    get_fb_grads=get_fb_grads,
)

print(best_arm.hyperparameters, best_arm.metric, best_arm.samples.shape)

HBox(children=(HTML(value='Iteration 1/4, 140 arms, time budget = 0.03 sec'), FloatProgress(value=0.0, max=140…

Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 0.1}. metric: 391962
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 0.031622775}. metric: 315911
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 0.01}. metric: 351372
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 0.0031622776}. metric: 416450
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 0.001}. metric: 436851
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 0.00031622776}. metric: 463559
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 1e-04}. metric: 347783
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 3.1622778e-05}. metric: 15751
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 1e-05}. metric: 10584
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 3.1622776e-06}. metric: 2970
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 1e-06}. metric: 1141
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 3.1622776e-07}. metric: 2790
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 1e-07}. metric: 2926
Hyperparams: {'batc

Hyperparams: {'batch_size': 800, 'L': 10, 'dt': 1e-07}. metric: 1110
Hyperparams: {'batch_size': 800, 'L': 10, 'dt': 3.1622776e-08}. metric: 676
Hyperparams: {'batch_size': 80, 'L': 5, 'dt': 0.1}. metric: 56658
Hyperparams: {'batch_size': 80, 'L': 5, 'dt': 0.031622775}. metric: 48415
Hyperparams: {'batch_size': 80, 'L': 5, 'dt': 0.01}. metric: 29002
Hyperparams: {'batch_size': 80, 'L': 5, 'dt': 0.0031622776}. metric: 39723
Hyperparams: {'batch_size': 80, 'L': 5, 'dt': 0.001}. metric: 32819
Hyperparams: {'batch_size': 80, 'L': 5, 'dt': 0.00031622776}. metric: 32871
Hyperparams: {'batch_size': 80, 'L': 5, 'dt': 1e-04}. metric: 36657
Hyperparams: {'batch_size': 80, 'L': 5, 'dt': 3.1622778e-05}. metric: 40703
Hyperparams: {'batch_size': 80, 'L': 5, 'dt': 1e-05}. metric: 20030
Hyperparams: {'batch_size': 80, 'L': 5, 'dt': 3.1622776e-06}. metric: 18711
Hyperparams: {'batch_size': 80, 'L': 5, 'dt': 1e-06}. metric: 5013
Hyperparams: {'batch_size': 80, 'L': 5, 'dt': 3.1622776e-07}. metric: 2569

HBox(children=(HTML(value='Iteration 2/4, 46 arms, time budget = 0.08 sec'), FloatProgress(value=0.0, max=46.0…

Hyperparams: {'batch_size': 800, 'L': 10, 'dt': 3.1622776e-08}. metric: 372
Hyperparams: {'batch_size': 800000, 'L': 10, 'dt': 3.1622776e-07}. metric: 616
Hyperparams: {'batch_size': 800, 'L': 5, 'dt': 3.1622776e-07}. metric: 513
Hyperparams: {'batch_size': 8000, 'L': 10, 'dt': 3.1622776e-07}. metric: 728
Hyperparams: {'batch_size': 800, 'L': 10, 'dt': 1e-07}. metric: 442
Hyperparams: {'batch_size': 800, 'L': 5, 'dt': 1e-07}. metric: 679
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 1e-06}. metric: 484
Hyperparams: {'batch_size': 80000, 'L': 10, 'dt': 3.1622776e-07}. metric: 880
Hyperparams: {'batch_size': 800, 'L': 5, 'dt': 1e-06}. metric: 1152
Hyperparams: {'batch_size': 8000, 'L': 5, 'dt': 3.1622776e-07}. metric: 616
Hyperparams: {'batch_size': 80000, 'L': 5, 'dt': 1e-06}. metric: 835
Hyperparams: {'batch_size': 8000, 'L': 5, 'dt': 1e-06}. metric: 749
Hyperparams: {'batch_size': 800, 'L': 5, 'dt': 3.1622776e-08}. metric: 558
Hyperparams: {'batch_size': 80, 'L': 5, 'dt': 3.162277

HBox(children=(HTML(value='Iteration 3/4, 15 arms, time budget = 0.23 sec'), FloatProgress(value=0.0, max=15.0…

Hyperparams: {'batch_size': 800, 'L': 10, 'dt': 3.1622776e-08}. metric: 186
Hyperparams: {'batch_size': 800, 'L': 10, 'dt': 1e-07}. metric: 408
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 1e-06}. metric: 746
Hyperparams: {'batch_size': 800, 'L': 5, 'dt': 3.1622776e-07}. metric: 534
Hyperparams: {'batch_size': 800, 'L': 5, 'dt': 3.1622776e-08}. metric: 191
Hyperparams: {'batch_size': 800000, 'L': 10, 'dt': 3.1622776e-07}. metric: 637
Hyperparams: {'batch_size': 8000, 'L': 5, 'dt': 3.1622776e-07}. metric: 414
Hyperparams: {'batch_size': 80, 'L': 5, 'dt': 1e-07}. metric: 1165
Hyperparams: {'batch_size': 800, 'L': 5, 'dt': 1e-07}. metric: 381
Hyperparams: {'batch_size': 8000, 'L': 10, 'dt': 3.1622776e-07}. metric: 384
Hyperparams: {'batch_size': 8000, 'L': 10, 'dt': 3.1622776e-08}. metric: 602
Hyperparams: {'batch_size': 8000, 'L': 5, 'dt': 1e-06}. metric: 410
Hyperparams: {'batch_size': 8000, 'L': 10, 'dt': 1e-07}. metric: 218
Hyperparams: {'batch_size': 80000, 'L': 5, 'dt': 1e-06}.

HBox(children=(HTML(value='Iteration 4/4, 5 arms, time budget = 0.70 sec'), FloatProgress(value=0.0, max=5.0),…

Hyperparams: {'batch_size': 800, 'L': 10, 'dt': 3.1622776e-08}. metric: 257
Hyperparams: {'batch_size': 800, 'L': 5, 'dt': 3.1622776e-08}. metric: 200
Hyperparams: {'batch_size': 8000, 'L': 10, 'dt': 1e-07}. metric: 339
Hyperparams: {'batch_size': 80000, 'L': 10, 'dt': 3.1622776e-07}. metric: 222
Hyperparams: {'batch_size': 800, 'L': 5, 'dt': 1e-07}. metric: 207

Number of samples: [417]
Running time: 319.8 sec
{'batch_size': 800, 'L': 5, 'dt': 3.1622776e-08} 199.65329 (417, 10)


In [10]:
def build_kernel(dt, L, batch_size):
    """SGHMC-CV kernel (control variates centred at the Adam MAP `centering_value`)."""
    return build_sghmcCV_kernel(dt, L, loglikelihood, logprior, data, batch_size,
                                centering_value)


grid_params = dict(
    log_dt=logdt_range,
    batch_size=batch_size_range,
    L=[5, 10],
)
best_arm = run_MAMBA(
    key, build_kernel, error_fn, R, params_IC,
    grid_params=grid_params,
    get_fb_grads=get_fb_grads,
)

print(best_arm.hyperparameters, best_arm.metric, best_arm.samples.shape)

HBox(children=(HTML(value='Iteration 1/4, 140 arms, time budget = 0.03 sec'), FloatProgress(value=0.0, max=140…

Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 0.1}. metric: 390082
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 0.031622775}. metric: 309232
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 0.01}. metric: 346898
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 0.0031622776}. metric: 416598
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 0.001}. metric: 435713
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 0.00031622776}. metric: 458591
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 1e-04}. metric: 336311
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 3.1622778e-05}. metric: 11153
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 1e-05}. metric: 10747
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 3.1622776e-06}. metric: 3178
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 1e-06}. metric: 1073
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 3.1622776e-07}. metric: 2778
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 1e-07}. metric: 2921
Hyperparams: {'batc

Hyperparams: {'batch_size': 800, 'L': 10, 'dt': 1e-07}. metric: 555
Hyperparams: {'batch_size': 800, 'L': 10, 'dt': 3.1622776e-08}. metric: 732
Hyperparams: {'batch_size': 80, 'L': 5, 'dt': 0.1}. metric: 61767
Hyperparams: {'batch_size': 80, 'L': 5, 'dt': 0.031622775}. metric: 49054
Hyperparams: {'batch_size': 80, 'L': 5, 'dt': 0.01}. metric: 37992
Hyperparams: {'batch_size': 80, 'L': 5, 'dt': 0.0031622776}. metric: 40190
Hyperparams: {'batch_size': 80, 'L': 5, 'dt': 0.001}. metric: 40891
Hyperparams: {'batch_size': 80, 'L': 5, 'dt': 0.00031622776}. metric: 37804
Hyperparams: {'batch_size': 80, 'L': 5, 'dt': 1e-04}. metric: 32807
Hyperparams: {'batch_size': 80, 'L': 5, 'dt': 3.1622778e-05}. metric: 22807
Hyperparams: {'batch_size': 80, 'L': 5, 'dt': 1e-05}. metric: 18565
Hyperparams: {'batch_size': 80, 'L': 5, 'dt': 3.1622776e-06}. metric: 6370
Hyperparams: {'batch_size': 80, 'L': 5, 'dt': 1e-06}. metric: 206
Hyperparams: {'batch_size': 80, 'L': 5, 'dt': 3.1622776e-07}. metric: 383
Hyp

HBox(children=(HTML(value='Iteration 2/4, 46 arms, time budget = 0.08 sec'), FloatProgress(value=0.0, max=46.0…

Hyperparams: {'batch_size': 800, 'L': 10, 'dt': 3.1622776e-07}. metric: 120
Hyperparams: {'batch_size': 80, 'L': 10, 'dt': 1e-07}. metric: 122
Hyperparams: {'batch_size': 80, 'L': 5, 'dt': 1e-06}. metric: 121
Hyperparams: {'batch_size': 800, 'L': 5, 'dt': 1e-06}. metric: 168
Hyperparams: {'batch_size': 80, 'L': 10, 'dt': 3.1622776e-07}. metric: 132
Hyperparams: {'batch_size': 800, 'L': 5, 'dt': 3.1622776e-07}. metric: 144
Hyperparams: {'batch_size': 80, 'L': 5, 'dt': 3.1622776e-07}. metric: 171
Hyperparams: {'batch_size': 80, 'L': 10, 'dt': 3.1622776e-08}. metric: 163
Hyperparams: {'batch_size': 80, 'L': 5, 'dt': 1e-07}. metric: 263
Hyperparams: {'batch_size': 800, 'L': 10, 'dt': 1e-07}. metric: 135
Hyperparams: {'batch_size': 800, 'L': 10, 'dt': 3.1622776e-08}. metric: 201
Hyperparams: {'batch_size': 800, 'L': 5, 'dt': 1e-07}. metric: 271
Hyperparams: {'batch_size': 800000, 'L': 10, 'dt': 3.1622776e-07}. metric: 644
Hyperparams: {'batch_size': 80000, 'L': 10, 'dt': 3.1622776e-07}. met

HBox(children=(HTML(value='Iteration 3/4, 15 arms, time budget = 0.23 sec'), FloatProgress(value=0.0, max=15.0…

Hyperparams: {'batch_size': 800, 'L': 10, 'dt': 3.1622776e-07}. metric: 104
Hyperparams: {'batch_size': 80, 'L': 5, 'dt': 1e-06}. metric: 65
Hyperparams: {'batch_size': 80, 'L': 10, 'dt': 1e-07}. metric: 58
Hyperparams: {'batch_size': 80, 'L': 10, 'dt': 3.1622776e-07}. metric: 100
Hyperparams: {'batch_size': 800, 'L': 10, 'dt': 1e-07}. metric: 76
Hyperparams: {'batch_size': 800, 'L': 5, 'dt': 3.1622776e-07}. metric: 128
Hyperparams: {'batch_size': 80, 'L': 10, 'dt': 3.1622776e-08}. metric: 114
Hyperparams: {'batch_size': 800, 'L': 5, 'dt': 1e-06}. metric: 70
Hyperparams: {'batch_size': 80, 'L': 5, 'dt': 3.1622776e-07}. metric: 93
Hyperparams: {'batch_size': 800, 'L': 10, 'dt': 3.1622776e-08}. metric: 160
Hyperparams: {'batch_size': 80, 'L': 5, 'dt': 3.1622776e-08}. metric: 203
Hyperparams: {'batch_size': 8000, 'L': 5, 'dt': 1e-06}. metric: 191
Hyperparams: {'batch_size': 80, 'L': 5, 'dt': 1e-07}. metric: 109
Hyperparams: {'batch_size': 800, 'L': 5, 'dt': 1e-07}. metric: 93
Hyperparams:

HBox(children=(HTML(value='Iteration 4/4, 5 arms, time budget = 0.70 sec'), FloatProgress(value=0.0, max=5.0),…

Hyperparams: {'batch_size': 80, 'L': 10, 'dt': 1e-07}. metric: 39
Hyperparams: {'batch_size': 80, 'L': 5, 'dt': 1e-06}. metric: 53
Hyperparams: {'batch_size': 800, 'L': 5, 'dt': 1e-06}. metric: 39
Hyperparams: {'batch_size': 800, 'L': 10, 'dt': 1e-07}. metric: 53
Hyperparams: {'batch_size': 800, 'L': 5, 'dt': 1e-07}. metric: 104

Number of samples: [583]
Running time: 367.5 sec
{'batch_size': 80, 'L': 10, 'dt': 1e-07} 38.839844 (583, 10)


In [11]:
def build_kernel(dt, batch_size):
    """SGNHT kernel for a given step size and minibatch size."""
    return build_sgnht_kernel(dt, loglikelihood, logprior, data, batch_size)


grid_params = dict(
    log_dt=logdt_range,
    batch_size=batch_size_range,
)
best_arm = run_MAMBA(
    key, build_kernel, error_fn, R, params_IC,
    grid_params=grid_params,
    get_fb_grads=get_fb_grads,
)

print(best_arm.hyperparameters, best_arm.metric, best_arm.samples.shape)

HBox(children=(HTML(value='Iteration 1/3, 70 arms, time budget = 0.08 sec'), FloatProgress(value=0.0, max=70.0…

Hyperparams: {'batch_size': 800000, 'dt': 0.1}. metric: 627259
Hyperparams: {'batch_size': 800000, 'dt': 0.031622775}. metric: 621158
Hyperparams: {'batch_size': 800000, 'dt': 0.01}. metric: 597799
Hyperparams: {'batch_size': 800000, 'dt': 0.0031622776}. metric: 486930
Hyperparams: {'batch_size': 800000, 'dt': 0.001}. metric: 237438
Hyperparams: {'batch_size': 800000, 'dt': 0.00031622776}. metric: 66769
Hyperparams: {'batch_size': 800000, 'dt': 1e-04}. metric: 23182
Hyperparams: {'batch_size': 800000, 'dt': 3.1622778e-05}. metric: 5319
Hyperparams: {'batch_size': 800000, 'dt': 1e-05}. metric: 1454
Hyperparams: {'batch_size': 800000, 'dt': 3.1622776e-06}. metric: 2576
Hyperparams: {'batch_size': 800000, 'dt': 1e-06}. metric: 3050
Hyperparams: {'batch_size': 800000, 'dt': 3.1622776e-07}. metric: 3183
Hyperparams: {'batch_size': 800000, 'dt': 1e-07}. metric: 3166
Hyperparams: {'batch_size': 800000, 'dt': 3.1622776e-08}. metric: 3182
Hyperparams: {'batch_size': 80000, 'dt': 0.1}. metric: i

HBox(children=(HTML(value='Iteration 2/3, 23 arms, time budget = 0.23 sec'), FloatProgress(value=0.0, max=23.0…

Hyperparams: {'batch_size': 80000, 'dt': 1e-06}. metric: 644
Hyperparams: {'batch_size': 80000, 'dt': 3.1622776e-06}. metric: 692
Hyperparams: {'batch_size': 8000, 'dt': 1e-07}. metric: 244
Hyperparams: {'batch_size': 80000, 'dt': 3.1622776e-07}. metric: 339
Hyperparams: {'batch_size': 8000, 'dt': 1e-06}. metric: 338
Hyperparams: {'batch_size': 800, 'dt': 1e-07}. metric: 752
Hyperparams: {'batch_size': 8000, 'dt': 3.1622776e-08}. metric: 389
Hyperparams: {'batch_size': 8000, 'dt': 3.1622776e-07}. metric: 204
Hyperparams: {'batch_size': 80000, 'dt': 1e-07}. metric: 353
Hyperparams: {'batch_size': 8000, 'dt': 3.1622776e-06}. metric: 456
Hyperparams: {'batch_size': 800, 'dt': 1e-06}. metric: 809
Hyperparams: {'batch_size': 800, 'dt': 3.1622776e-08}. metric: 725
Hyperparams: {'batch_size': 800, 'dt': 3.1622776e-06}. metric: 1383
Hyperparams: {'batch_size': 80000, 'dt': 1e-05}. metric: 1530
Hyperparams: {'batch_size': 800000, 'dt': 1e-05}. metric: 229
Hyperparams: {'batch_size': 800, 'dt': 

HBox(children=(HTML(value='Iteration 3/3, 7 arms, time budget = 0.77 sec'), FloatProgress(value=0.0, max=7.0),…

Hyperparams: {'batch_size': 800000, 'dt': 1e-06}. metric: 164
Hyperparams: {'batch_size': 800000, 'dt': 3.1622776e-06}. metric: 94
Hyperparams: {'batch_size': 8000, 'dt': 3.1622776e-07}. metric: 211
Hyperparams: {'batch_size': 800000, 'dt': 1e-05}. metric: 171
Hyperparams: {'batch_size': 8000, 'dt': 1e-07}. metric: 169
Hyperparams: {'batch_size': 8000, 'dt': 1e-06}. metric: 238
Hyperparams: {'batch_size': 80000, 'dt': 3.1622776e-07}. metric: 152

Number of samples: [15, 40]
Running time: 233.1 sec
{'batch_size': 800000, 'dt': 3.1622776e-06} 94.33203 (15, 10)


In [12]:
def build_kernel(dt, batch_size):
    """SGNHT-CV kernel (control variates centred at the Adam MAP `centering_value`)."""
    return build_sgnhtCV_kernel(dt, loglikelihood, logprior, data, batch_size,
                                centering_value)


grid_params = dict(
    log_dt=logdt_range,
    batch_size=batch_size_range,
)
best_arm = run_MAMBA(
    key, build_kernel, error_fn, R, params_IC,
    grid_params=grid_params,
    get_fb_grads=get_fb_grads,
)

print(best_arm.hyperparameters, best_arm.metric, best_arm.samples.shape)

HBox(children=(HTML(value='Iteration 1/3, 70 arms, time budget = 0.08 sec'), FloatProgress(value=0.0, max=70.0…

Hyperparams: {'batch_size': 800000, 'dt': 0.1}. metric: 633599
Hyperparams: {'batch_size': 800000, 'dt': 0.031622775}. metric: 625572
Hyperparams: {'batch_size': 800000, 'dt': 0.01}. metric: 594703
Hyperparams: {'batch_size': 800000, 'dt': 0.0031622776}. metric: 476958
Hyperparams: {'batch_size': 800000, 'dt': 0.001}. metric: 225221
Hyperparams: {'batch_size': 800000, 'dt': 0.00031622776}. metric: 70078
Hyperparams: {'batch_size': 800000, 'dt': 1e-04}. metric: 20501
Hyperparams: {'batch_size': 800000, 'dt': 3.1622778e-05}. metric: 4429
Hyperparams: {'batch_size': 800000, 'dt': 1e-05}. metric: 1291
Hyperparams: {'batch_size': 800000, 'dt': 3.1622776e-06}. metric: 2516
Hyperparams: {'batch_size': 800000, 'dt': 1e-06}. metric: 3030
Hyperparams: {'batch_size': 800000, 'dt': 3.1622776e-07}. metric: 3179
Hyperparams: {'batch_size': 800000, 'dt': 1e-07}. metric: 3165
Hyperparams: {'batch_size': 800000, 'dt': 3.1622776e-08}. metric: 3182
Hyperparams: {'batch_size': 80000, 'dt': 0.1}. metric: i

HBox(children=(HTML(value='Iteration 2/3, 23 arms, time budget = 0.23 sec'), FloatProgress(value=0.0, max=23.0…

Hyperparams: {'batch_size': 800, 'dt': 1e-06}. metric: 18
Hyperparams: {'batch_size': 800, 'dt': 1e-07}. metric: 52
Hyperparams: {'batch_size': 800, 'dt': 3.1622776e-07}. metric: 36
Hyperparams: {'batch_size': 800, 'dt': 3.1622776e-08}. metric: 78
Hyperparams: {'batch_size': 8000, 'dt': 1e-05}. metric: 54
Hyperparams: {'batch_size': 800, 'dt': 3.1622776e-06}. metric: 52
Hyperparams: {'batch_size': 80, 'dt': 3.1622776e-08}. metric: 106
Hyperparams: {'batch_size': 800, 'dt': 1e-05}. metric: 108
Hyperparams: {'batch_size': 80, 'dt': 1e-06}. metric: 88
Hyperparams: {'batch_size': 80, 'dt': 3.1622776e-07}. metric: 80
Hyperparams: {'batch_size': 8000, 'dt': 3.1622776e-07}. metric: 50
Hyperparams: {'batch_size': 8000, 'dt': 1e-06}. metric: 32
Hyperparams: {'batch_size': 8000, 'dt': 3.1622778e-05}. metric: 85
Hyperparams: {'batch_size': 80, 'dt': 1e-07}. metric: 98
Hyperparams: {'batch_size': 8000, 'dt': 3.1622776e-06}. metric: 50
Hyperparams: {'batch_size': 8000, 'dt': 1e-07}. metric: 91
Hype

HBox(children=(HTML(value='Iteration 3/3, 7 arms, time budget = 0.77 sec'), FloatProgress(value=0.0, max=7.0),…

Hyperparams: {'batch_size': 800, 'dt': 1e-06}. metric: 12
Hyperparams: {'batch_size': 8000, 'dt': 1e-06}. metric: 10
Hyperparams: {'batch_size': 800, 'dt': 3.1622776e-07}. metric: 11
Hyperparams: {'batch_size': 8000, 'dt': 3.1622776e-06}. metric: 17
Hyperparams: {'batch_size': 8000, 'dt': 3.1622776e-07}. metric: 33
Hyperparams: {'batch_size': 800, 'dt': 1e-07}. metric: 27
Hyperparams: {'batch_size': 800, 'dt': 3.1622776e-06}. metric: 16

Number of samples: [194, 534]
Running time: 270.3 sec
{'batch_size': 8000, 'dt': 1e-06} 9.666649 (194, 10)
