In [1]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np

import jax.numpy as jnp
from jax import jit, random, vmap

from sgmcmcjax.kernels import build_sgld_kernel, build_sghmc_kernel, build_sgnht_kernel
from sgmcmcjax.kernels import build_sgldCV_kernel, build_sghmcCV_kernel, build_sgnhtCV_kernel
from sgmcmcjax.optimizer import build_adam_optimizer
from sgmcmcjax.util import build_grad_log_post

import context
from models.logistic_regression.logistic_regression_model import gen_data, loglikelihood, logprior
from tuning.mamba import run_MAMBA
from tuning.ksd import FSSD_opt, imq_KSD, get_test_locations, linear_imq_KSD




# Logistic regression - MAMBA

In [2]:


key = random.PRNGKey(42)
dim = 10
# Ndata = 100_000
Ndata = 1_000_000

theta_true, X, y_data = gen_data(key, dim, Ndata)

# testing and training data
num_train = int(Ndata*0.8)

X_train = X[:num_train]
X_test = X[num_train:]

y_train = y_data[:num_train]
y_test = y_data[num_train:]
print(X_train.shape, X_test.shape)
data = (X_train, y_train)


generating data, with N=1,000,000 and dim=10
(800000, 10) (200000, 10)


In [5]:
# get MAP
batch_size = int(0.01*X_train.shape[0])
run_adam = build_adam_optimizer(1e-2, loglikelihood, logprior, data, batch_size)
key = random.PRNGKey(0)
centering_value, logpost_array = run_adam(key, 5000, jnp.zeros(dim))
params_IC = centering_value

fixed_bs = int(0.1*X_train.shape[0])

# T = 10
R = 2 # running time of longest sampler

error_fn = lambda x,y: FSSD_opt(x, y, get_test_locations(x), 100)
# error_fn = lambda x,y: imq_KSD(x, y)
# error_fn = lambda x,y: linear_imq_KSD(x, y)

logdt_range = -jnp.arange(1., 8., 0.5) 
# batch_size_range = [int(10**(-elem)*X_train.shape[0]) for elem in range(0, 5)]
batch_size_range = [int(10**(-elem)*X_train.shape[0]) for elem in range(0, 3)]


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=5000.0), HTML(value='')))




In [6]:
batch_size_range

[800000, 80000, 8000]

In [7]:
grad_log_post_fb = build_grad_log_post(loglikelihood, logprior, data)

def grad_lp_LR_fb(theta):
    return grad_log_post_fb(theta, X_train, y_train)

batch_grad_lp_LR_fb = jit(vmap(grad_lp_LR_fb))

@jit
def get_fb_grads(samples):
    thin = 10
    array_samples = jnp.array(samples[::thin])
    mygrads = batch_grad_lp_LR_fb(array_samples)
    return array_samples, mygrads

### run MAMBA

In [10]:
build_kernel = lambda dt, batch_size: build_sgld_kernel(dt, loglikelihood, logprior, data, batch_size)

grid_params = {'log_dt': logdt_range, 
               'batch_size': batch_size_range
              }

# build_kernel = lambda dt: build_sgld_kernel(dt, loglikelihood, logprior, data, fixed_bs)
# grid_params = {'log_dt': logdt_range}
best_arm = run_MAMBA(key, build_kernel, error_fn, R, params_IC, grid_params=grid_params, get_fb_grads=get_fb_grads)

print(best_arm.hyperparameters, best_arm.metric, best_arm.samples.shape)

HBox(children=(HTML(value='Iteration 1/3, 42 arms, time budget = 0.15 sec'), FloatProgress(value=0.0, max=42.0…

Hyperparams: {'batch_size': 800000, 'dt': 0.1}., sample shape=(2, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'dt': 0.031622775}., sample shape=(2, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'dt': 0.01}., sample shape=(2, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'dt': 0.0031622776}., sample shape=(2, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'dt': 0.001}., sample shape=(2, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'dt': 0.00031622776}., sample shape=(2, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'dt': 1e-04}., sample shape=(2, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'dt': 3.1622778e-05}., sample shape=(2, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'dt': 1e-05}., sample shape=(2, 10). metric: 59201
Hyperparams: {'batch_size': 800000, 'dt': 3.1622776e-06}., sample shape=(2, 10). metric: -87406
Hyperparams: {'batch_size': 800000, 'dt': 1e-06}., sample shape=(2, 10). metric: 390315
Hyperparams: {'batch_si

HBox(children=(HTML(value='Iteration 2/3, 14 arms, time budget = 0.46 sec'), FloatProgress(value=0.0, max=14.0…

Hyperparams: {'batch_size': 800000, 'dt': 3.1622776e-06}., sample shape=(8, 10). metric: 12121
Hyperparams: {'batch_size': 80000, 'dt': 1e-05}., sample shape=(22, 10). metric: 1045
Hyperparams: {'batch_size': 8000, 'dt': 0.1}., sample shape=(129, 10). metric: 409
Hyperparams: {'batch_size': 8000, 'dt': 0.031622775}., sample shape=(125, 10). metric: 3455
Hyperparams: {'batch_size': 80000, 'dt': 3.1622776e-06}., sample shape=(22, 10). metric: 3257
Hyperparams: {'batch_size': 8000, 'dt': 0.01}., sample shape=(128, 10). metric: 16690
Hyperparams: {'batch_size': 8000, 'dt': 1e-06}., sample shape=(132, 10). metric: 4945
Hyperparams: {'batch_size': 8000, 'dt': 1e-05}., sample shape=(132, 10). metric: -2941
Hyperparams: {'batch_size': 8000, 'dt': 3.1622776e-06}., sample shape=(127, 10). metric: 3197
Hyperparams: {'batch_size': 8000, 'dt': 3.1622776e-07}., sample shape=(128, 10). metric: 11510
Hyperparams: {'batch_size': 800000, 'dt': 1e-05}., sample shape=(8, 10). metric: 1837
Hyperparams: {'b

HBox(children=(HTML(value='Iteration 3/3, 4 arms, time budget = 1.62 sec'), FloatProgress(value=0.0, max=4.0),…

Hyperparams: {'batch_size': 8000, 'dt': 1e-05}., sample shape=(485, 10). metric: 451
Hyperparams: {'batch_size': 8000, 'dt': 0.1}., sample shape=(475, 10). metric: 489
Hyperparams: {'batch_size': 80000, 'dt': 1e-05}., sample shape=(75, 10). metric: -1233
Hyperparams: {'batch_size': 800000, 'dt': 1e-05}., sample shape=(29, 10). metric: -1314

Number of samples: [29]
Running time: 117.0 sec
{'batch_size': 800000, 'dt': 1e-05} -1313.6257 (29, 10)


In [14]:
build_kernel = lambda dt, batch_size: build_sgldCV_kernel(dt, loglikelihood, 
                                                  logprior, data, batch_size, centering_value)

# build_kernel = lambda dt: build_sgldCV_kernel(dt, loglikelihood, 
#                                           logprior, data, fixed_bs, centering_value)
# grid_params = {'log_dt': logdt_range}

best_arm = run_MAMBA(key, build_kernel, error_fn, R, params_IC, 
                     grid_params=grid_params, get_fb_grads=get_fb_grads)

print(best_arm.hyperparameters, best_arm.metric, best_arm.samples.shape)

HBox(children=(HTML(value='Iteration 1/3, 42 arms, time budget = 0.15 sec'), FloatProgress(value=0.0, max=42.0…

Hyperparams: {'batch_size': 800000, 'dt': 0.1}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'dt': 0.031622775}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'dt': 0.01}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'dt': 0.0031622776}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'dt': 0.001}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'dt': 0.00031622776}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'dt': 1e-04}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'dt': 3.1622778e-05}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'dt': 1e-05}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'dt': 3.1622776e-06}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'dt': 1e-06}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 800

HBox(children=(HTML(value='Iteration 2/3, 14 arms, time budget = 0.46 sec'), FloatProgress(value=0.0, max=14.0…

Hyperparams: {'batch_size': 80000, 'dt': 1e-05}., sample shape=(18, 10). metric: 986
Hyperparams: {'batch_size': 8000, 'dt': 1e-05}., sample shape=(116, 10). metric: 244
Hyperparams: {'batch_size': 8000, 'dt': 0.1}., sample shape=(110, 10). metric: 610
Hyperparams: {'batch_size': 8000, 'dt': 3.1622776e-06}., sample shape=(113, 10). metric: 329
Hyperparams: {'batch_size': 8000, 'dt': 0.031622775}., sample shape=(115, 10). metric: 4000
Hyperparams: {'batch_size': 8000, 'dt': 1e-06}., sample shape=(113, 10). metric: 3027
Hyperparams: {'batch_size': 8000, 'dt': 0.01}., sample shape=(117, 10). metric: 43135
Hyperparams: {'batch_size': 8000, 'dt': 3.1622776e-07}., sample shape=(112, 10). metric: 8560
Hyperparams: {'batch_size': 80000, 'dt': 3.1622776e-06}., sample shape=(18, 10). metric: 1542
Hyperparams: {'batch_size': 80000, 'dt': 1e-06}., sample shape=(21, 10). metric: 1838
Hyperparams: {'batch_size': 8000, 'dt': 1e-07}., sample shape=(116, 10). metric: 2294
Hyperparams: {'batch_size': 80

HBox(children=(HTML(value='Iteration 3/3, 4 arms, time budget = 1.62 sec'), FloatProgress(value=0.0, max=4.0),…

Hyperparams: {'batch_size': 8000, 'dt': 1e-05}., sample shape=(428, 10). metric: -30
Hyperparams: {'batch_size': 8000, 'dt': 3.1622776e-06}., sample shape=(421, 10). metric: 152
Hyperparams: {'batch_size': 8000, 'dt': 0.1}., sample shape=(418, 10). metric: 659
Hyperparams: {'batch_size': 80000, 'dt': 1e-05}., sample shape=(66, 10). metric: -205

Number of samples: [66]
Running time: 110.0 sec
{'batch_size': 80000, 'dt': 1e-05} -204.73564 (66, 10)


In [16]:
build_kernel = lambda dt, L, batch_size: build_sghmc_kernel(dt, L, loglikelihood, 
                                                  logprior, data, batch_size)
grid_params = {'log_dt': logdt_range, 
               'batch_size': batch_size_range,
               "L": [5, 10]
              }

# build_kernel = lambda dt, L: build_sghmc_kernel(dt, L, loglikelihood, 
#                                                   logprior, data, fixed_bs)
# grid_params = {'log_dt': logdt_range, 
#                "L": [5, 10]
#               }
best_arm = run_MAMBA(key, build_kernel, error_fn, R, params_IC, 
                     grid_params=grid_params, get_fb_grads=get_fb_grads)

print(best_arm.hyperparameters, best_arm.metric, best_arm.samples.shape)

HBox(children=(HTML(value='Iteration 1/4, 84 arms, time budget = 0.05 sec'), FloatProgress(value=0.0, max=84.0…

Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 0.1}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 0.031622775}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 0.01}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 0.0031622776}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 0.001}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 0.00031622776}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 1e-04}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 3.1622778e-05}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 1e-05}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 3.1622776e-06}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 800000, 

HBox(children=(HTML(value='Iteration 2/4, 28 arms, time budget = 0.15 sec'), FloatProgress(value=0.0, max=28.0…

Hyperparams: {'batch_size': 8000, 'L': 10, 'dt': 1e-07}., sample shape=(7, 10). metric: 854
Hyperparams: {'batch_size': 8000, 'L': 5, 'dt': 3.1622776e-07}., sample shape=(12, 10). metric: 16339
Hyperparams: {'batch_size': 8000, 'L': 10, 'dt': 3.1622776e-07}., sample shape=(7, 10). metric: 68222
Hyperparams: {'batch_size': 8000, 'L': 5, 'dt': 1e-06}., sample shape=(12, 10). metric: -17310
Hyperparams: {'batch_size': 8000, 'L': 10, 'dt': 3.1622776e-08}., sample shape=(7, 10). metric: 13594
Hyperparams: {'batch_size': 8000, 'L': 5, 'dt': 1e-07}., sample shape=(12, 10). metric: 24180
Hyperparams: {'batch_size': 8000, 'L': 5, 'dt': 3.1622776e-08}., sample shape=(12, 10). metric: 54896
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 0.1}., sample shape=(2, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 0.031622775}., sample shape=(2, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 0.01}., sample shape=(2, 10). metric: inf
Hyperparams: {'batch_size': 8

HBox(children=(HTML(value='Iteration 3/4, 9 arms, time budget = 0.47 sec'), FloatProgress(value=0.0, max=9.0),…

Hyperparams: {'batch_size': 8000, 'L': 5, 'dt': 1e-06}., sample shape=(38, 10). metric: -7961
Hyperparams: {'batch_size': 8000, 'L': 10, 'dt': 1e-07}., sample shape=(21, 10). metric: 2740
Hyperparams: {'batch_size': 8000, 'L': 10, 'dt': 3.1622776e-08}., sample shape=(21, 10). metric: 10702
Hyperparams: {'batch_size': 8000, 'L': 5, 'dt': 3.1622776e-07}., sample shape=(38, 10). metric: 7042
Hyperparams: {'batch_size': 8000, 'L': 5, 'dt': 1e-07}., sample shape=(38, 10). metric: 13704
Hyperparams: {'batch_size': 8000, 'L': 5, 'dt': 3.1622776e-08}., sample shape=(37, 10). metric: 19157
Hyperparams: {'batch_size': 8000, 'L': 10, 'dt': 3.1622776e-07}., sample shape=(21, 10). metric: -6979
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 0.1}., sample shape=(4, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 0.031622775}., sample shape=(4, 10). metric: inf

Number of samples: [38, 21, 21]


HBox(children=(HTML(value='Iteration 4/4, 3 arms, time budget = 1.40 sec'), FloatProgress(value=0.0, max=3.0),…

Hyperparams: {'batch_size': 8000, 'L': 5, 'dt': 1e-06}., sample shape=(114, 10). metric: -907
Hyperparams: {'batch_size': 8000, 'L': 10, 'dt': 3.1622776e-07}., sample shape=(62, 10). metric: 1184
Hyperparams: {'batch_size': 8000, 'L': 10, 'dt': 1e-07}., sample shape=(62, 10). metric: 1651

Number of samples: [114]
Running time: 174.5 sec
{'batch_size': 8000, 'L': 5, 'dt': 1e-06} -906.7637 (114, 10)


In [17]:
build_kernel = lambda dt, L, batch_size: build_sghmcCV_kernel(dt, L, loglikelihood, 
                                                  logprior, data, batch_size, centering_value)
grid_params = {'log_dt': logdt_range, 
               'batch_size': batch_size_range,
               "L": [5, 10]
              }

# build_kernel = lambda dt, L: build_sghmcCV_kernel(dt, L, loglikelihood, 
#                                                   logprior, data, fixed_bs, centering_value)
# grid_params = {'log_dt': logdt_range, 
#                "L": [5, 10]
#               }
best_arm = run_MAMBA(key, build_kernel, error_fn, R, params_IC, 
                     grid_params=grid_params, get_fb_grads=get_fb_grads)

print(best_arm.hyperparameters, best_arm.metric, best_arm.samples.shape)

HBox(children=(HTML(value='Iteration 1/4, 84 arms, time budget = 0.05 sec'), FloatProgress(value=0.0, max=84.0…

Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 0.1}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 0.031622775}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 0.01}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 0.0031622776}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 0.001}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 0.00031622776}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 1e-04}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 3.1622778e-05}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 1e-05}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 3.1622776e-06}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 800000, 

HBox(children=(HTML(value='Iteration 2/4, 28 arms, time budget = 0.15 sec'), FloatProgress(value=0.0, max=28.0…

Hyperparams: {'batch_size': 8000, 'L': 10, 'dt': 1e-07}., sample shape=(6, 10). metric: 992
Hyperparams: {'batch_size': 8000, 'L': 10, 'dt': 3.1622776e-07}., sample shape=(6, 10). metric: -67
Hyperparams: {'batch_size': 8000, 'L': 5, 'dt': 3.1622776e-07}., sample shape=(10, 10). metric: 5152
Hyperparams: {'batch_size': 8000, 'L': 5, 'dt': 1e-06}., sample shape=(11, 10). metric: 859
Hyperparams: {'batch_size': 8000, 'L': 5, 'dt': 1e-07}., sample shape=(10, 10). metric: 5894
Hyperparams: {'batch_size': 8000, 'L': 10, 'dt': 3.1622776e-08}., sample shape=(6, 10). metric: 50396
Hyperparams: {'batch_size': 8000, 'L': 5, 'dt': 3.1622776e-08}., sample shape=(10, 10). metric: 104862
Hyperparams: {'batch_size': 8000, 'L': 5, 'dt': 3.1622776e-06}., sample shape=(10, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 0.1}., sample shape=(2, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 0.031622775}., sample shape=(2, 10). metric: inf
Hyperparams: {'batch_size':

HBox(children=(HTML(value='Iteration 3/4, 9 arms, time budget = 0.47 sec'), FloatProgress(value=0.0, max=9.0),…

Hyperparams: {'batch_size': 8000, 'L': 10, 'dt': 3.1622776e-07}., sample shape=(18, 10). metric: -233
Hyperparams: {'batch_size': 8000, 'L': 5, 'dt': 1e-06}., sample shape=(34, 10). metric: -780
Hyperparams: {'batch_size': 8000, 'L': 10, 'dt': 1e-07}., sample shape=(18, 10). metric: 1646
Hyperparams: {'batch_size': 8000, 'L': 5, 'dt': 3.1622776e-07}., sample shape=(32, 10). metric: 4149
Hyperparams: {'batch_size': 8000, 'L': 5, 'dt': 1e-07}., sample shape=(32, 10). metric: 7156
Hyperparams: {'batch_size': 8000, 'L': 10, 'dt': 3.1622776e-08}., sample shape=(18, 10). metric: 5643
Hyperparams: {'batch_size': 8000, 'L': 5, 'dt': 3.1622776e-08}., sample shape=(32, 10). metric: 21566
Hyperparams: {'batch_size': 8000, 'L': 5, 'dt': 3.1622776e-06}., sample shape=(33, 10). metric: 48600852
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 0.1}., sample shape=(3, 10). metric: inf

Number of samples: [34, 18, 18]


HBox(children=(HTML(value='Iteration 4/4, 3 arms, time budget = 1.40 sec'), FloatProgress(value=0.0, max=3.0),…

Hyperparams: {'batch_size': 8000, 'L': 5, 'dt': 1e-06}., sample shape=(101, 10). metric: 295
Hyperparams: {'batch_size': 8000, 'L': 10, 'dt': 3.1622776e-07}., sample shape=(53, 10). metric: 558
Hyperparams: {'batch_size': 8000, 'L': 10, 'dt': 1e-07}., sample shape=(52, 10). metric: 1268

Number of samples: [101]
Running time: 212.3 sec
{'batch_size': 8000, 'L': 5, 'dt': 1e-06} 295.35696 (101, 10)


In [13]:
# build_kernel = lambda dt, batch_size: build_sgnht_kernel(dt, loglikelihood, logprior, data, batch_size)
# grid_params = {'log_dt': logdt_range, 
#                'batch_size': batch_size_range
#               }

build_kernel = lambda dt: build_sgnht_kernel(dt, loglikelihood, logprior, data, fixed_bs)
grid_params = {'log_dt': logdt_range}
best_arm = run_MAMBA(key, build_kernel, error_fn, R, params_IC, 
                     grid_params=grid_params)#, get_fb_grads=get_fb_grads)

print(best_arm.hyperparameters, best_arm.metric, best_arm.samples.shape)

HBox(children=(HTML(value='Iteration 1/2, 14 arms, time budget = 0.25 sec'), FloatProgress(value=0.0, max=14.0…

Hyperparams: {'dt': 0.1}., sample shape=(97, 10). metric: inf
Hyperparams: {'dt': 0.031622775}., sample shape=(97, 10). metric: inf
Hyperparams: {'dt': 0.01}., sample shape=(97, 10). metric: inf
Hyperparams: {'dt': 0.0031622776}., sample shape=(97, 10). metric: inf
Hyperparams: {'dt': 0.001}., sample shape=(99, 10). metric: inf
Hyperparams: {'dt': 0.00031622776}., sample shape=(98, 10). metric: inf
Hyperparams: {'dt': 1e-04}., sample shape=(97, 10). metric: inf
Hyperparams: {'dt': 3.1622778e-05}., sample shape=(99, 10). metric: 313
Hyperparams: {'dt': 1e-05}., sample shape=(97, 10). metric: 243
Hyperparams: {'dt': 3.1622776e-06}., sample shape=(99, 10). metric: 189
Hyperparams: {'dt': 1e-06}., sample shape=(97, 10). metric: 165
Hyperparams: {'dt': 3.1622776e-07}., sample shape=(98, 10). metric: 170
Hyperparams: {'dt': 1e-07}., sample shape=(95, 10). metric: 325
Hyperparams: {'dt': 3.1622776e-08}., sample shape=(96, 10). metric: 497

Number of samples: [97, 98, 99, 97]


HBox(children=(HTML(value='Iteration 2/2, 4 arms, time budget = 0.88 sec'), FloatProgress(value=0.0, max=4.0),…

Hyperparams: {'dt': 1e-06}., sample shape=(386, 10). metric: 90
Hyperparams: {'dt': 3.1622776e-07}., sample shape=(382, 10). metric: 63
Hyperparams: {'dt': 3.1622776e-06}., sample shape=(383, 10). metric: 90
Hyperparams: {'dt': 1e-05}., sample shape=(384, 10). metric: 152

Number of samples: [382]
Running time: 43.1 sec
{'dt': 3.1622776e-07} 63.38139 (382, 10)


In [14]:
# build_kernel = lambda dt, batch_size: build_sgnhtCV_kernel(dt, loglikelihood, 
#                                                   logprior, data, batch_size, centering_value)
# grid_params = {'log_dt': logdt_range, 
#                'batch_size': batch_size_range
#               }

build_kernel = lambda dt: build_sgnhtCV_kernel(dt, loglikelihood, 
                                                  logprior, data, fixed_bs, centering_value)
grid_params = {'log_dt': logdt_range}
best_arm = run_MAMBA(key, build_kernel, error_fn, R, params_IC, 
                     grid_params=grid_params, get_fb_grads=get_fb_grads)

print(best_arm.hyperparameters, best_arm.metric, best_arm.samples.shape)

HBox(children=(HTML(value='Iteration 1/2, 14 arms, time budget = 0.25 sec'), FloatProgress(value=0.0, max=14.0…

Hyperparams: {'dt': 0.1}., sample shape=(8, 10). metric: inf
Hyperparams: {'dt': 0.031622775}., sample shape=(8, 10). metric: inf
Hyperparams: {'dt': 0.01}., sample shape=(8, 10). metric: inf
Hyperparams: {'dt': 0.0031622776}., sample shape=(7, 10). metric: inf
Hyperparams: {'dt': 0.001}., sample shape=(8, 10). metric: inf
Hyperparams: {'dt': 0.00031622776}., sample shape=(7, 10). metric: inf
Hyperparams: {'dt': 1e-04}., sample shape=(8, 10). metric: inf
Hyperparams: {'dt': 3.1622778e-05}., sample shape=(7, 10). metric: 286
Hyperparams: {'dt': 1e-05}., sample shape=(8, 10). metric: 219
Hyperparams: {'dt': 3.1622776e-06}., sample shape=(8, 10). metric: 280
Hyperparams: {'dt': 1e-06}., sample shape=(7, 10). metric: 173
Hyperparams: {'dt': 3.1622776e-07}., sample shape=(8, 10). metric: 204
Hyperparams: {'dt': 1e-07}., sample shape=(7, 10). metric: 359
Hyperparams: {'dt': 3.1622776e-08}., sample shape=(8, 10). metric: 549

Number of samples: [7, 8, 8, 8]


HBox(children=(HTML(value='Iteration 2/2, 4 arms, time budget = 0.88 sec'), FloatProgress(value=0.0, max=4.0),…

Hyperparams: {'dt': 1e-06}., sample shape=(32, 10). metric: 58
Hyperparams: {'dt': 3.1622776e-07}., sample shape=(33, 10). metric: 56
Hyperparams: {'dt': 1e-05}., sample shape=(32, 10). metric: 87
Hyperparams: {'dt': 3.1622776e-06}., sample shape=(33, 10). metric: 12

Number of samples: [33]
Running time: 50.8 sec
{'dt': 3.1622776e-06} 12.241525 (33, 10)
