In [1]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np

import jax.numpy as jnp
from jax import jit, random, vmap

from sgmcmcjax.kernels import build_sgld_kernel, build_sghmc_kernel, build_sgnht_kernel
from sgmcmcjax.kernels import build_sgldCV_kernel, build_sghmcCV_kernel, build_sgnhtCV_kernel
from sgmcmcjax.optimizer import build_adam_optimizer
from sgmcmcjax.util import build_grad_log_post

import context
from models.logistic_regression.logistic_regression_model import gen_data, loglikelihood, logprior
from tuning.mamba import run_MAMBA
from tuning.ksd import FSSD_opt, imq_KSD, get_test_locations, linear_imq_KSD




# Logistic regression - MAMBA

In [2]:


# Synthetic logistic-regression data set; the PRNG seed is fixed for reproducibility.
key = random.PRNGKey(42)
dim = 10
Ndata = 1_000_000  # alternative (smaller) setting tried previously: 100_000

theta_true, X, y_data = gen_data(key, dim, Ndata)

# First 80% of the rows are used for training, the remainder is held out.
num_train = int(Ndata*0.8)
X_train, X_test = X[:num_train], X[num_train:]
y_train, y_test = y_data[:num_train], y_data[num_train:]
print(X_train.shape, X_test.shape)

# (features, labels) tuple passed to the optimizer and the kernel builders in later cells.
data = (X_train, y_train)


generating data, with N=1,000,000 and dim=10
(800000, 10) (200000, 10)


In [3]:
# --- MAP estimate with Adam --------------------------------------------------
# Minibatch size is 1% of the training rows. The positional arguments 1e-2 and 5000
# are presumably the Adam step size and the number of iterations (the progress bar
# for this cell counts to 5000) -- confirm against sgmcmcjax.optimizer.
batch_size = int(0.01*X_train.shape[0])
run_adam = build_adam_optimizer(1e-2, loglikelihood, logprior, data, batch_size)
key = random.PRNGKey(0)  # NOTE: this same key is later passed to every run_MAMBA call
centering_value, logpost_array = run_adam(key, 5000, jnp.zeros(dim))
# The optimizer output is used both as the centering value of the *CV kernels
# and as the initial parameters handed to run_MAMBA.
params_IC = centering_value

# 10% of the training rows; only referenced by the commented-out fixed-batch
# variants in the SGNHT cells.
fixed_bs = int(0.1*X_train.shape[0])

# --- MAMBA settings ----------------------------------------------------------
R = 1  # running time of longest sampler


def error_fn(x, y):
    """Metric used to rank arms: FSSD_opt with test locations derived from `x`.

    NOTE(review): the meaning of the trailing 100 is defined by FSSD_opt -- confirm.
    Alternatives that were tried: imq_KSD(x, y) and linear_imq_KSD(x, y).
    """
    return FSSD_opt(x, y, get_test_locations(x), 100)


# Grid of log step sizes: -1.0, -1.5, ..., -7.5 (the run logs report dt = 10**log_dt).
logdt_range = -jnp.arange(1.0, 8.0, 0.5)
# Batch sizes: 100%, 10% and 1% of the training rows
# (a wider sweep over range(0, 5) was used previously).
batch_size_range = [int(10**(-elem)*X_train.shape[0]) for elem in range(3)]

print(batch_size_range)

[800000, 80000, 8000]


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=5000.0), HTML(value='')))

In [4]:
# Gradient function produced by build_grad_log_post for this model and data set.
# The "_fb" helpers below always call it on the complete (X_train, y_train) arrays
# rather than on a minibatch.
grad_log_post_fb = build_grad_log_post(loglikelihood, logprior, data)


def grad_lp_LR_fb(theta):
    """Evaluate `grad_log_post_fb` at `theta` on the full training set."""
    return grad_log_post_fb(theta, X_train, y_train)


# Map the single-parameter function over the leading axis of a stack of parameters.
batch_grad_lp_LR_fb = jit(vmap(grad_lp_LR_fb))


@jit
def get_fb_grads(samples):
    """Thin `samples` and evaluate the full-batch gradient at each retained sample.

    Every 10th entry of `samples` is kept (starting with the first), converted to a
    jnp array, and passed through `batch_grad_lp_LR_fb`.

    Returns:
        A pair `(kept_samples, grads)`: the thinned samples and the corresponding
        output of `batch_grad_lp_LR_fb`.
    """
    step = 10
    kept_samples = jnp.array(samples[::step])
    return kept_samples, batch_grad_lp_LR_fb(kept_samples)

### Run MAMBA

In [11]:
# SGLD: let MAMBA search jointly over the step size and the batch size.
def build_kernel(dt, batch_size):
    """Build an SGLD kernel on the training data for one (dt, batch_size) arm."""
    return build_sgld_kernel(dt, loglikelihood, logprior, data, batch_size)


grid_params = dict(
    log_dt=logdt_range,
    batch_size=batch_size_range,
)

best_arm = run_MAMBA(
    key, build_kernel, error_fn, R, params_IC,
    grid_params=grid_params, get_fb_grads=get_fb_grads,
)

print(best_arm.hyperparameters, best_arm.metric, best_arm.samples.shape)

HBox(children=(HTML(value='Iteration 1/3, 42 arms, time budget = 0.08 sec'), FloatProgress(value=0.0, max=42.0…

Hyperparams: {'batch_size': 800000, 'dt': 0.1}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'dt': 0.031622775}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'dt': 0.01}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'dt': 0.0031622776}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'dt': 0.001}., sample shape=(2, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'dt': 0.00031622776}., sample shape=(2, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'dt': 1e-04}., sample shape=(2, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'dt': 3.1622778e-05}., sample shape=(2, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'dt': 1e-05}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'dt': 3.1622776e-06}., sample shape=(2, 10). metric: -87406
Hyperparams: {'batch_size': 800000, 'dt': 1e-06}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 

HBox(children=(HTML(value='Iteration 2/3, 14 arms, time budget = 0.23 sec'), FloatProgress(value=0.0, max=14.0…

Hyperparams: {'batch_size': 800000, 'dt': 3.1622776e-06}., sample shape=(5, 10). metric: 16659
Hyperparams: {'batch_size': 80000, 'dt': 1e-05}., sample shape=(13, 10). metric: 6891
Hyperparams: {'batch_size': 8000, 'dt': 0.031622775}., sample shape=(68, 10). metric: 2241
Hyperparams: {'batch_size': 8000, 'dt': 0.1}., sample shape=(66, 10). metric: 376
Hyperparams: {'batch_size': 8000, 'dt': 0.01}., sample shape=(60, 10). metric: 23369
Hyperparams: {'batch_size': 8000, 'dt': 3.1622776e-06}., sample shape=(64, 10). metric: 295
Hyperparams: {'batch_size': 8000, 'dt': 1e-06}., sample shape=(62, 10). metric: 10421
Hyperparams: {'batch_size': 8000, 'dt': 3.1622776e-07}., sample shape=(66, 10). metric: 20598
Hyperparams: {'batch_size': 80000, 'dt': 3.1622776e-06}., sample shape=(13, 10). metric: 1482
Hyperparams: {'batch_size': 8000, 'dt': 0.0031622776}., sample shape=(64, 10). metric: 210919
Hyperparams: {'batch_size': 80000, 'dt': 1e-06}., sample shape=(11, 10). metric: 12621
Hyperparams: {

HBox(children=(HTML(value='Iteration 3/3, 4 arms, time budget = 0.81 sec'), FloatProgress(value=0.0, max=4.0),…

Hyperparams: {'batch_size': 8000, 'dt': 1e-05}., sample shape=(237, 10). metric: 438
Hyperparams: {'batch_size': 8000, 'dt': 3.1622776e-06}., sample shape=(229, 10). metric: 366
Hyperparams: {'batch_size': 8000, 'dt': 0.1}., sample shape=(231, 10). metric: 398
Hyperparams: {'batch_size': 80000, 'dt': 3.1622776e-06}., sample shape=(41, 10). metric: 567

Number of samples: [229]
Running time: 78.3 sec
{'batch_size': 8000, 'dt': 3.1622776e-06} 365.65576 (229, 10)


In [6]:
# SGLD with control variates (centred on the Adam output `centering_value`):
# let MAMBA search jointly over the step size and the batch size.
def build_kernel(dt, batch_size):
    """Build an SGLD-CV kernel on the training data for one (dt, batch_size) arm."""
    return build_sgldCV_kernel(dt, loglikelihood, logprior, data,
                               batch_size, centering_value)


# Define the grid in this cell instead of relying on whatever `grid_params` another
# cell left behind: the SGHMC cells add an 'L' entry, which this two-argument
# build_kernel does not accept, so re-running cells out of order would break.
grid_params = {'log_dt': logdt_range,
               'batch_size': batch_size_range
              }

best_arm = run_MAMBA(key, build_kernel, error_fn, R, params_IC,
                     grid_params=grid_params, get_fb_grads=get_fb_grads)

print(best_arm.hyperparameters, best_arm.metric, best_arm.samples.shape)

HBox(children=(HTML(value='Iteration 1/3, 42 arms, time budget = 0.08 sec'), FloatProgress(value=0.0, max=42.0…

Hyperparams: {'batch_size': 800000, 'dt': 0.1}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'dt': 0.031622775}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'dt': 0.01}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'dt': 0.0031622776}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'dt': 0.001}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'dt': 0.00031622776}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'dt': 1e-04}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'dt': 3.1622778e-05}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'dt': 1e-05}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'dt': 3.1622776e-06}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'dt': 1e-06}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 800

HBox(children=(HTML(value='Iteration 2/3, 14 arms, time budget = 0.23 sec'), FloatProgress(value=0.0, max=14.0…

Hyperparams: {'batch_size': 80000, 'dt': 1e-05}., sample shape=(9, 10). metric: 2520
Hyperparams: {'batch_size': 8000, 'dt': 1e-05}., sample shape=(40, 10). metric: -641
Hyperparams: {'batch_size': 8000, 'dt': 0.1}., sample shape=(42, 10). metric: 683
Hyperparams: {'batch_size': 8000, 'dt': 0.031622775}., sample shape=(37, 10). metric: 5603
Hyperparams: {'batch_size': 8000, 'dt': 3.1622776e-06}., sample shape=(35, 10). metric: 408
Hyperparams: {'batch_size': 80000, 'dt': 3.1622776e-06}., sample shape=(8, 10). metric: 2279
Hyperparams: {'batch_size': 8000, 'dt': 1e-06}., sample shape=(39, 10). metric: 3235
Hyperparams: {'batch_size': 8000, 'dt': 3.1622776e-07}., sample shape=(44, 10). metric: 24862
Hyperparams: {'batch_size': 80000, 'dt': 1e-06}., sample shape=(8, 10). metric: 12597
Hyperparams: {'batch_size': 8000, 'dt': 0.0031622776}., sample shape=(42, 10). metric: 480276
Hyperparams: {'batch_size': 8000, 'dt': 1e-07}., sample shape=(52, 10). metric: 46754
Hyperparams: {'batch_size':

HBox(children=(HTML(value='Iteration 3/3, 4 arms, time budget = 0.81 sec'), FloatProgress(value=0.0, max=4.0),…

Hyperparams: {'batch_size': 8000, 'dt': 1e-05}., sample shape=(178, 10). metric: -160
Hyperparams: {'batch_size': 8000, 'dt': 3.1622776e-06}., sample shape=(178, 10). metric: 289
Hyperparams: {'batch_size': 8000, 'dt': 0.1}., sample shape=(191, 10). metric: 684
Hyperparams: {'batch_size': 80000, 'dt': 3.1622776e-06}., sample shape=(32, 10). metric: -78

Number of samples: [178]
Running time: 114.3 sec
{'batch_size': 8000, 'dt': 1e-05} -160.46092 (178, 10)


In [7]:
# SGHMC: in addition to step size and batch size, the grid searches over L
# (candidate values 5 and 10; presumably the number of leapfrog steps --
# confirm against build_sghmc_kernel).
def build_kernel(dt, L, batch_size):
    """Build an SGHMC kernel on the training data for one (dt, L, batch_size) arm."""
    return build_sghmc_kernel(dt, L, loglikelihood, logprior, data, batch_size)


grid_params = dict(
    log_dt=logdt_range,
    batch_size=batch_size_range,
    L=[5, 10],
)

best_arm = run_MAMBA(
    key, build_kernel, error_fn, R, params_IC,
    grid_params=grid_params, get_fb_grads=get_fb_grads,
)

print(best_arm.hyperparameters, best_arm.metric, best_arm.samples.shape)

HBox(children=(HTML(value='Iteration 1/4, 84 arms, time budget = 0.03 sec'), FloatProgress(value=0.0, max=84.0…

Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 0.1}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 0.031622775}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 0.01}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 0.0031622776}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 0.001}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 0.00031622776}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 1e-04}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 3.1622778e-05}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 1e-05}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 3.1622776e-06}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 800000, 

HBox(children=(HTML(value='Iteration 2/4, 28 arms, time budget = 0.07 sec'), FloatProgress(value=0.0, max=28.0…

Hyperparams: {'batch_size': 8000, 'L': 5, 'dt': 1e-06}., sample shape=(6, 10). metric: -50459
Hyperparams: {'batch_size': 8000, 'L': 5, 'dt': 3.1622776e-08}., sample shape=(4, 10). metric: 570797
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 0.1}., sample shape=(2, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 0.031622775}., sample shape=(2, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 0.01}., sample shape=(2, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 0.0031622776}., sample shape=(2, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 0.001}., sample shape=(2, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 0.00031622776}., sample shape=(2, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 1e-04}., sample shape=(2, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 3.1622778e-05}., sample shape=(2, 10). metric: inf
Hyperparams: {'batch_size': 800000

HBox(children=(HTML(value='Iteration 3/4, 9 arms, time budget = 0.23 sec'), FloatProgress(value=0.0, max=9.0),…

Hyperparams: {'batch_size': 8000, 'L': 5, 'dt': 1e-06}., sample shape=(11, 10). metric: -48144
Hyperparams: {'batch_size': 8000, 'L': 5, 'dt': 3.1622776e-08}., sample shape=(15, 10). metric: 61036
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 0.1}., sample shape=(3, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 0.031622775}., sample shape=(3, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 0.01}., sample shape=(3, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 0.0031622776}., sample shape=(3, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 0.001}., sample shape=(3, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 0.00031622776}., sample shape=(3, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 1e-04}., sample shape=(3, 10). metric: inf

Number of samples: [11, 15, 3]


HBox(children=(HTML(value='Iteration 4/4, 3 arms, time budget = 0.70 sec'), FloatProgress(value=0.0, max=3.0),…

Hyperparams: {'batch_size': 8000, 'L': 5, 'dt': 1e-06}., sample shape=(42, 10). metric: -6139
Hyperparams: {'batch_size': 8000, 'L': 5, 'dt': 3.1622776e-08}., sample shape=(30, 10). metric: 22678
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 0.1}., sample shape=(5, 10). metric: inf

Number of samples: [42]
Running time: 187.7 sec
{'batch_size': 8000, 'L': 5, 'dt': 1e-06} -6139.0854 (42, 10)


In [8]:
# SGHMC with control variates (centred on the Adam output `centering_value`).
# The grid covers step size, batch size and L (candidate values 5 and 10).
def build_kernel(dt, L, batch_size):
    """Build an SGHMC-CV kernel on the training data for one (dt, L, batch_size) arm."""
    return build_sghmcCV_kernel(dt, L, loglikelihood, logprior, data,
                                batch_size, centering_value)


grid_params = dict(
    log_dt=logdt_range,
    batch_size=batch_size_range,
    L=[5, 10],
)

best_arm = run_MAMBA(
    key, build_kernel, error_fn, R, params_IC,
    grid_params=grid_params, get_fb_grads=get_fb_grads,
)

print(best_arm.hyperparameters, best_arm.metric, best_arm.samples.shape)

HBox(children=(HTML(value='Iteration 1/4, 84 arms, time budget = 0.03 sec'), FloatProgress(value=0.0, max=84.0…

Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 0.1}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 0.031622775}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 0.01}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 0.0031622776}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 0.001}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 0.00031622776}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 1e-04}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 3.1622778e-05}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 1e-05}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 3.1622776e-06}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 800000, 

HBox(children=(HTML(value='Iteration 2/4, 28 arms, time budget = 0.07 sec'), FloatProgress(value=0.0, max=28.0…

Hyperparams: {'batch_size': 8000, 'L': 5, 'dt': 1e-06}., sample shape=(5, 10). metric: 3697
Hyperparams: {'batch_size': 8000, 'L': 5, 'dt': 1e-07}., sample shape=(6, 10). metric: 45088
Hyperparams: {'batch_size': 8000, 'L': 5, 'dt': 3.1622776e-08}., sample shape=(5, 10). metric: 519547
Hyperparams: {'batch_size': 8000, 'L': 5, 'dt': 3.1622776e-06}., sample shape=(5, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 0.1}., sample shape=(2, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 0.031622775}., sample shape=(2, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 0.01}., sample shape=(2, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 0.0031622776}., sample shape=(2, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 0.001}., sample shape=(2, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 0.00031622776}., sample shape=(2, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'L

HBox(children=(HTML(value='Iteration 3/4, 9 arms, time budget = 0.23 sec'), FloatProgress(value=0.0, max=9.0),…

Hyperparams: {'batch_size': 8000, 'L': 5, 'dt': 1e-06}., sample shape=(15, 10). metric: -717
Hyperparams: {'batch_size': 8000, 'L': 5, 'dt': 1e-07}., sample shape=(15, 10). metric: 9555
Hyperparams: {'batch_size': 8000, 'L': 5, 'dt': 3.1622776e-08}., sample shape=(15, 10). metric: 17129
Hyperparams: {'batch_size': 8000, 'L': 5, 'dt': 3.1622776e-06}., sample shape=(14, 10). metric: 49535212
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 0.1}., sample shape=(3, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 0.031622775}., sample shape=(3, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 0.01}., sample shape=(3, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 0.0031622776}., sample shape=(3, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 0.001}., sample shape=(3, 10). metric: inf

Number of samples: [15, 15, 15]


HBox(children=(HTML(value='Iteration 4/4, 3 arms, time budget = 0.70 sec'), FloatProgress(value=0.0, max=3.0),…

Hyperparams: {'batch_size': 8000, 'L': 5, 'dt': 1e-06}., sample shape=(40, 10). metric: 1034
Hyperparams: {'batch_size': 8000, 'L': 5, 'dt': 1e-07}., sample shape=(33, 10). metric: 12625
Hyperparams: {'batch_size': 8000, 'L': 5, 'dt': 3.1622776e-08}., sample shape=(40, 10). metric: 16804

Number of samples: [40]
Running time: 252.1 sec
{'batch_size': 8000, 'L': 5, 'dt': 1e-06} 1034.2524 (40, 10)


In [9]:
# SGNHT: let MAMBA search jointly over the step size and the batch size.
def build_kernel(dt, batch_size):
    """Build an SGNHT kernel on the training data for one (dt, batch_size) arm."""
    return build_sgnht_kernel(dt, loglikelihood, logprior, data, batch_size)


grid_params = dict(
    log_dt=logdt_range,
    batch_size=batch_size_range,
)

# Variant that tunes only the step size at the fixed batch size `fixed_bs`:
# build_kernel = lambda dt: build_sgnht_kernel(dt, loglikelihood, logprior, data, fixed_bs)
# grid_params = {'log_dt': logdt_range}
best_arm = run_MAMBA(
    key, build_kernel, error_fn, R, params_IC,
    grid_params=grid_params, get_fb_grads=get_fb_grads,
)

print(best_arm.hyperparameters, best_arm.metric, best_arm.samples.shape)

HBox(children=(HTML(value='Iteration 1/3, 42 arms, time budget = 0.08 sec'), FloatProgress(value=0.0, max=42.0…

Hyperparams: {'batch_size': 800000, 'dt': 0.1}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'dt': 0.031622775}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'dt': 0.01}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'dt': 0.0031622776}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'dt': 0.001}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'dt': 0.00031622776}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'dt': 1e-04}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'dt': 3.1622778e-05}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'dt': 1e-05}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'dt': 3.1622776e-06}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'dt': 1e-06}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 800

HBox(children=(HTML(value='Iteration 2/3, 14 arms, time budget = 0.23 sec'), FloatProgress(value=0.0, max=14.0…

Hyperparams: {'batch_size': 80000, 'dt': 1e-06}., sample shape=(9, 10). metric: -169525
Hyperparams: {'batch_size': 80000, 'dt': 3.1622776e-07}., sample shape=(12, 10). metric: -14698
Hyperparams: {'batch_size': 8000, 'dt': 3.1622776e-06}., sample shape=(41, 10). metric: -26631
Hyperparams: {'batch_size': 80000, 'dt': 3.1622776e-06}., sample shape=(10, 10). metric: 11241
Hyperparams: {'batch_size': 8000, 'dt': 3.1622776e-07}., sample shape=(44, 10). metric: -31184
Hyperparams: {'batch_size': 80000, 'dt': 1e-07}., sample shape=(12, 10). metric: -30713
Hyperparams: {'batch_size': 8000, 'dt': 1e-06}., sample shape=(36, 10). metric: -88909
Hyperparams: {'batch_size': 80000, 'dt': 1e-05}., sample shape=(10, 10). metric: inf
Hyperparams: {'batch_size': 8000, 'dt': 3.1622776e-08}., sample shape=(48, 10). metric: 9157
Hyperparams: {'batch_size': 8000, 'dt': 1e-07}., sample shape=(40, 10). metric: -9938
Hyperparams: {'batch_size': 8000, 'dt': 1e-05}., sample shape=(37, 10). metric: 4952
Hyperpa

HBox(children=(HTML(value='Iteration 3/3, 4 arms, time budget = 0.81 sec'), FloatProgress(value=0.0, max=4.0),…

Hyperparams: {'batch_size': 80000, 'dt': 1e-06}., sample shape=(33, 10). metric: -12913
Hyperparams: {'batch_size': 8000, 'dt': 1e-06}., sample shape=(113, 10). metric: -25292
Hyperparams: {'batch_size': 8000, 'dt': 3.1622776e-07}., sample shape=(123, 10). metric: -11131
Hyperparams: {'batch_size': 80000, 'dt': 1e-07}., sample shape=(30, 10). metric: -2922

Number of samples: [113]
Running time: 178.0 sec
{'batch_size': 8000, 'dt': 1e-06} -25291.873 (113, 10)


In [10]:
# SGNHT with control variates (centred on the Adam output `centering_value`):
# let MAMBA search jointly over the step size and the batch size.
def build_kernel(dt, batch_size):
    """Build an SGNHT-CV kernel on the training data for one (dt, batch_size) arm."""
    return build_sgnhtCV_kernel(dt, loglikelihood, logprior, data,
                                batch_size, centering_value)


grid_params = dict(
    log_dt=logdt_range,
    batch_size=batch_size_range,
)

# Variant that tunes only the step size at the fixed batch size `fixed_bs`:
# build_kernel = lambda dt: build_sgnhtCV_kernel(dt, loglikelihood,
#                                                   logprior, data, fixed_bs, centering_value)
# grid_params = {'log_dt': logdt_range}
best_arm = run_MAMBA(
    key, build_kernel, error_fn, R, params_IC,
    grid_params=grid_params, get_fb_grads=get_fb_grads,
)

print(best_arm.hyperparameters, best_arm.metric, best_arm.samples.shape)

HBox(children=(HTML(value='Iteration 1/3, 42 arms, time budget = 0.08 sec'), FloatProgress(value=0.0, max=42.0…

Hyperparams: {'batch_size': 800000, 'dt': 0.1}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'dt': 0.031622775}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'dt': 0.01}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'dt': 0.0031622776}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'dt': 0.001}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'dt': 0.00031622776}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'dt': 1e-04}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'dt': 3.1622778e-05}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'dt': 1e-05}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'dt': 3.1622776e-06}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'dt': 1e-06}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 800

HBox(children=(HTML(value='Iteration 2/3, 14 arms, time budget = 0.23 sec'), FloatProgress(value=0.0, max=14.0…

Hyperparams: {'batch_size': 80000, 'dt': 3.1622778e-05}., sample shape=(8, 10). metric: -60347
Hyperparams: {'batch_size': 80000, 'dt': 1e-06}., sample shape=(9, 10). metric: -16301
Hyperparams: {'batch_size': 80000, 'dt': 3.1622776e-07}., sample shape=(8, 10). metric: -41505
Hyperparams: {'batch_size': 80000, 'dt': 1e-05}., sample shape=(8, 10). metric: -107778
Hyperparams: {'batch_size': 8000, 'dt': 3.1622778e-05}., sample shape=(39, 10). metric: -7920
Hyperparams: {'batch_size': 8000, 'dt': 1e-06}., sample shape=(34, 10). metric: -2460
Hyperparams: {'batch_size': 8000, 'dt': 3.1622776e-06}., sample shape=(36, 10). metric: -1824
Hyperparams: {'batch_size': 8000, 'dt': 1e-05}., sample shape=(39, 10). metric: -4188
Hyperparams: {'batch_size': 8000, 'dt': 1e-07}., sample shape=(41, 10). metric: 377
Hyperparams: {'batch_size': 8000, 'dt': 3.1622776e-07}., sample shape=(42, 10). metric: -921
Hyperparams: {'batch_size': 80000, 'dt': 3.1622776e-06}., sample shape=(8, 10). metric: -53716
Hyp

HBox(children=(HTML(value='Iteration 3/3, 4 arms, time budget = 0.81 sec'), FloatProgress(value=0.0, max=4.0),…

Hyperparams: {'batch_size': 80000, 'dt': 1e-05}., sample shape=(25, 10). metric: -10110
Hyperparams: {'batch_size': 80000, 'dt': 3.1622778e-05}., sample shape=(27, 10). metric: -8234
Hyperparams: {'batch_size': 80000, 'dt': 3.1622776e-06}., sample shape=(30, 10). metric: -1160
Hyperparams: {'batch_size': 80000, 'dt': 3.1622776e-07}., sample shape=(26, 10). metric: -3413

Number of samples: [25]
Running time: 176.9 sec
{'batch_size': 80000, 'dt': 1e-05} -10110.214 (25, 10)
