In [1]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np

import jax.numpy as jnp
from jax import jit, random, vmap

from sgmcmcjax.kernels import build_sgld_kernel, build_sghmc_kernel, build_sgnht_kernel
from sgmcmcjax.kernels import build_sgldCV_kernel, build_sghmcCV_kernel, build_sgnhtCV_kernel
from sgmcmcjax.optimizer import build_adam_optimizer
from sgmcmcjax.util import build_grad_log_post

import context
from models.logistic_regression.logistic_regression_model import gen_data, loglikelihood, logprior
from tuning.mamba import run_MAMBA
from tuning.ksd import FSSD_opt, imq_KSD, get_test_locations, linear_imq_KSD




# Logistic regression - MAMBA

In [2]:


# Synthetic logistic-regression dataset: `dim` features, `Ndata` observations.
key = random.PRNGKey(42)
dim = 10
Ndata = 1_000_000  # a smaller value (e.g. 100_000) gives a quicker run

theta_true, X, y_data = gen_data(key, dim, Ndata)

# Train on the first 80% of the rows and hold out the remaining 20% for testing.
num_train = int(Ndata*0.8)

X_train, X_test = X[:num_train], X[num_train:]
y_train, y_test = y_data[:num_train], y_data[num_train:]
print(X_train.shape, X_test.shape)

# (features, labels) pair handed as `data` to the sgmcmcjax builders in later cells.
data = (X_train, y_train)


generating data, with N=1,000,000 and dim=10
(800000, 10) (200000, 10)


In [3]:
# Get the MAP estimate with Adam. The result is used both as the starting
# point handed to run_MAMBA (params_IC) and as the centering value of the
# control-variate (CV) kernels built in later cells.
batch_size = int(0.01*X_train.shape[0])
run_adam = build_adam_optimizer(1e-2, loglikelihood, logprior, data, batch_size)
# A JAX PRNG key should never be consumed twice. Split the seed so the
# optimizer gets its own subkey and `key` stays unused for the MAMBA runs below.
key, adam_key = random.split(random.PRNGKey(0))
centering_value, logpost_array = run_adam(adam_key, 5000, jnp.zeros(dim))
params_IC = centering_value

# Only referenced by the commented-out fixed-batch-size alternatives further down.
fixed_bs = int(0.1*X_train.shape[0])

R = 3 # running time of longest sampler


def error_fn(x, y):
    """Metric that run_MAMBA uses to score an arm.

    Forwards `x` and `y` to FSSD_opt, with test locations derived from `x`.
    NOTE(review): `x` / `y` are presumably the thinned samples and their
    full-batch gradients as returned by `get_fb_grads`, and the trailing 100
    presumably an iteration count for FSSD_opt -- confirm against tuning.ksd.
    """
    return FSSD_opt(x, y, get_test_locations(x), 100)

# Alternative metrics:
# error_fn = lambda x,y: imq_KSD(x, y)
# error_fn = lambda x,y: linear_imq_KSD(x, y)

# Grid over the step size: log_dt = -1.0, -1.5, ..., -7.5 (14 values).
logdt_range = -jnp.arange(1., 8., 0.5)
# Grid over the batch size: 100%, 10% and 1% of the training set.
# (use range(0, 5) to also include 0.1% and 0.01%)
batch_size_range = [int(10**(-elem)*X_train.shape[0]) for elem in range(0, 3)]

print(batch_size_range)

[800000, 80000, 8000]


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=5000.0), HTML(value='')))

In [4]:
# Gradient of the log-posterior, built from the model's log-likelihood and log-prior.
grad_log_post_fb = build_grad_log_post(loglikelihood, logprior, data)


def grad_lp_LR_fb(theta):
    """Return the log-posterior gradient at `theta` evaluated on the full training set."""
    full_batch_grad = grad_log_post_fb(theta, X_train, y_train)
    return full_batch_grad


# Same gradient, mapped over the leading axis of a stack of parameter vectors.
batch_grad_lp_LR_fb = jit(vmap(grad_lp_LR_fb, in_axes=0))

@jit
def get_fb_grads(samples):
    """Thin `samples` and compute the full-batch gradient at each retained sample.

    Every 10th entry of `samples` is kept (starting with the first one).

    Returns a pair `(thinned_samples, gradients)` where `gradients` is the
    output of `batch_grad_lp_LR_fb` applied to the thinned samples.
    """
    stride = 10
    thinned = jnp.array(samples[::stride])
    return thinned, batch_grad_lp_LR_fb(thinned)

### run MAMBA

In [5]:
# SGLD: tune step size and batch size jointly (original note: "diag sds, another run").
def build_kernel(dt, batch_size):
    """Build an SGLD kernel for the given step size and batch size."""
    return build_sgld_kernel(dt, loglikelihood, logprior, data, batch_size)


grid_params = dict(log_dt=logdt_range, batch_size=batch_size_range)

best_arm = run_MAMBA(
    key, build_kernel, error_fn, R, params_IC,
    grid_params=grid_params, get_fb_grads=get_fb_grads,
)

print(best_arm.hyperparameters, best_arm.metric, best_arm.samples.shape)




HBox(children=(HTML(value='Iteration 1/3, 42 arms, time budget = 0.23 sec'), FloatProgress(value=0.0, max=42.0…

Hyperparams: {'batch_size': 800000, 'dt': 0.1}., sample shape=(3, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'dt': 0.031622775}., sample shape=(3, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'dt': 0.01}., sample shape=(3, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'dt': 0.0031622776}., sample shape=(3, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'dt': 0.001}., sample shape=(3, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'dt': 0.00031622776}., sample shape=(3, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'dt': 1e-04}., sample shape=(3, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'dt': 3.1622778e-05}., sample shape=(3, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'dt': 1e-05}., sample shape=(3, 10). metric: 23671
Hyperparams: {'batch_size': 800000, 'dt': 3.1622776e-06}., sample shape=(3, 10). metric: -26979
Hyperparams: {'batch_size': 800000, 'dt': 1e-06}., sample shape=(3, 10). metric: 228163
Hyperparams: {'batch_si

HBox(children=(HTML(value='Iteration 2/3, 14 arms, time budget = 0.69 sec'), FloatProgress(value=0.0, max=14.0…

Hyperparams: {'batch_size': 800000, 'dt': 3.1622776e-06}., sample shape=(12, 10). metric: 2227
Hyperparams: {'batch_size': 80000, 'dt': 1e-05}., sample shape=(32, 10). metric: 1984
Hyperparams: {'batch_size': 80000, 'dt': 3.1622776e-06}., sample shape=(32, 10). metric: 1439
Hyperparams: {'batch_size': 8000, 'dt': 0.1}., sample shape=(203, 10). metric: 485
Hyperparams: {'batch_size': 8000, 'dt': 0.031622775}., sample shape=(199, 10). metric: 3659
Hyperparams: {'batch_size': 8000, 'dt': 1e-05}., sample shape=(198, 10). metric: -2661
Hyperparams: {'batch_size': 8000, 'dt': 1e-06}., sample shape=(204, 10). metric: 6218
Hyperparams: {'batch_size': 8000, 'dt': 3.1622776e-06}., sample shape=(199, 10). metric: 762
Hyperparams: {'batch_size': 8000, 'dt': 0.01}., sample shape=(201, 10). metric: 24610
Hyperparams: {'batch_size': 800000, 'dt': 1e-05}., sample shape=(12, 10). metric: -305
Hyperparams: {'batch_size': 80000, 'dt': 1e-06}., sample shape=(37, 10). metric: 996
Hyperparams: {'batch_size'

HBox(children=(HTML(value='Iteration 3/3, 4 arms, time budget = 2.42 sec'), FloatProgress(value=0.0, max=4.0),…

Hyperparams: {'batch_size': 8000, 'dt': 1e-05}., sample shape=(728, 10). metric: -837
Hyperparams: {'batch_size': 800000, 'dt': 1e-05}., sample shape=(43, 10). metric: 186
Hyperparams: {'batch_size': 8000, 'dt': 0.1}., sample shape=(744, 10). metric: 440
Hyperparams: {'batch_size': 8000, 'dt': 3.1622776e-06}., sample shape=(714, 10). metric: 31

Number of samples: [728]
Running time: 138.1 sec
{'batch_size': 8000, 'dt': 1e-05} -836.56055 (728, 10)


In [6]:
# SGLD again, same grid as the run above (original note: "diag sds").
def build_kernel(dt, batch_size):
    """Return the SGLD kernel for one (step size, batch size) arm."""
    sgld_kernel = build_sgld_kernel(dt, loglikelihood, logprior, data, batch_size)
    return sgld_kernel


grid_params = {
    'log_dt': logdt_range,
    'batch_size': batch_size_range,
}

best_arm = run_MAMBA(key, build_kernel, error_fn, R, params_IC,
                     grid_params=grid_params, get_fb_grads=get_fb_grads)

print(best_arm.hyperparameters, best_arm.metric, best_arm.samples.shape)

HBox(children=(HTML(value='Iteration 1/3, 42 arms, time budget = 0.23 sec'), FloatProgress(value=0.0, max=42.0…

Hyperparams: {'batch_size': 800000, 'dt': 0.1}., sample shape=(3, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'dt': 0.031622775}., sample shape=(3, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'dt': 0.01}., sample shape=(3, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'dt': 0.0031622776}., sample shape=(3, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'dt': 0.001}., sample shape=(3, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'dt': 0.00031622776}., sample shape=(3, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'dt': 1e-04}., sample shape=(3, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'dt': 3.1622778e-05}., sample shape=(3, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'dt': 1e-05}., sample shape=(3, 10). metric: 23671
Hyperparams: {'batch_size': 800000, 'dt': 3.1622776e-06}., sample shape=(3, 10). metric: -26979
Hyperparams: {'batch_size': 800000, 'dt': 1e-06}., sample shape=(3, 10). metric: 228163
Hyperparams: {'batch_si

HBox(children=(HTML(value='Iteration 2/3, 14 arms, time budget = 0.69 sec'), FloatProgress(value=0.0, max=14.0…

Hyperparams: {'batch_size': 800000, 'dt': 3.1622776e-06}., sample shape=(12, 10). metric: 2227
Hyperparams: {'batch_size': 8000, 'dt': 1e-05}., sample shape=(209, 10). metric: -5264
Hyperparams: {'batch_size': 80000, 'dt': 1e-05}., sample shape=(32, 10). metric: -1447
Hyperparams: {'batch_size': 80000, 'dt': 3.1622776e-06}., sample shape=(32, 10). metric: 758
Hyperparams: {'batch_size': 8000, 'dt': 0.1}., sample shape=(198, 10). metric: 503
Hyperparams: {'batch_size': 8000, 'dt': 0.031622775}., sample shape=(200, 10). metric: 2438
Hyperparams: {'batch_size': 8000, 'dt': 1e-06}., sample shape=(202, 10). metric: 6589
Hyperparams: {'batch_size': 8000, 'dt': 3.1622776e-06}., sample shape=(204, 10). metric: 1181
Hyperparams: {'batch_size': 8000, 'dt': 0.01}., sample shape=(201, 10). metric: 23449
Hyperparams: {'batch_size': 80000, 'dt': 1e-06}., sample shape=(33, 10). metric: 6431
Hyperparams: {'batch_size': 800000, 'dt': 1e-05}., sample shape=(12, 10). metric: -276
Hyperparams: {'batch_siz

HBox(children=(HTML(value='Iteration 3/3, 4 arms, time budget = 2.42 sec'), FloatProgress(value=0.0, max=4.0),…

Hyperparams: {'batch_size': 8000, 'dt': 1e-05}., sample shape=(662, 10). metric: 1391
Hyperparams: {'batch_size': 80000, 'dt': 1e-05}., sample shape=(111, 10). metric: -958
Hyperparams: {'batch_size': 800000, 'dt': 1e-05}., sample shape=(43, 10). metric: -347
Hyperparams: {'batch_size': 8000, 'dt': 0.1}., sample shape=(707, 10). metric: 415

Number of samples: [111]
Running time: 121.0 sec
{'batch_size': 80000, 'dt': 1e-05} -958.38495 (111, 10)


In [7]:
# SGLD with control variates, centred on the MAP estimate `centering_value`.
def build_kernel(dt, batch_size):
    """Build an SGLD-CV kernel for the given step size and batch size."""
    return build_sgldCV_kernel(dt, loglikelihood,
                               logprior, data, batch_size, centering_value)


# Define the grid here rather than relying on whatever `grid_params` the
# previously executed cell left behind. A leftover SGHMC grid would add an
# "L" hyperparameter that this two-argument kernel builder does not accept.
grid_params = {'log_dt': logdt_range,
               'batch_size': batch_size_range
              }

best_arm = run_MAMBA(key, build_kernel, error_fn, R, params_IC,
                     grid_params=grid_params, get_fb_grads=get_fb_grads)

print(best_arm.hyperparameters, best_arm.metric, best_arm.samples.shape)

HBox(children=(HTML(value='Iteration 1/3, 42 arms, time budget = 0.23 sec'), FloatProgress(value=0.0, max=42.0…

Hyperparams: {'batch_size': 800000, 'dt': 0.1}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'dt': 0.031622775}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'dt': 0.01}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'dt': 0.0031622776}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'dt': 0.001}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'dt': 0.00031622776}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'dt': 1e-04}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'dt': 3.1622778e-05}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'dt': 1e-05}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'dt': 3.1622776e-06}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'dt': 1e-06}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 800

HBox(children=(HTML(value='Iteration 2/3, 14 arms, time budget = 0.69 sec'), FloatProgress(value=0.0, max=14.0…

Hyperparams: {'batch_size': 80000, 'dt': 1e-05}., sample shape=(27, 10). metric: 1119
Hyperparams: {'batch_size': 8000, 'dt': 1e-05}., sample shape=(178, 10). metric: -156
Hyperparams: {'batch_size': 8000, 'dt': 0.1}., sample shape=(164, 10). metric: 694
Hyperparams: {'batch_size': 8000, 'dt': 3.1622776e-06}., sample shape=(154, 10). metric: 100
Hyperparams: {'batch_size': 8000, 'dt': 0.031622775}., sample shape=(175, 10). metric: 6070
Hyperparams: {'batch_size': 8000, 'dt': 1e-06}., sample shape=(178, 10). metric: 1969
Hyperparams: {'batch_size': 80000, 'dt': 3.1622776e-06}., sample shape=(28, 10). metric: -226
Hyperparams: {'batch_size': 8000, 'dt': 0.01}., sample shape=(171, 10). metric: 42510
Hyperparams: {'batch_size': 8000, 'dt': 3.1622776e-07}., sample shape=(173, 10). metric: 2389
Hyperparams: {'batch_size': 80000, 'dt': 1e-06}., sample shape=(26, 10). metric: 3044
Hyperparams: {'batch_size': 8000, 'dt': 1e-07}., sample shape=(172, 10). metric: 4744
Hyperparams: {'batch_size': 

HBox(children=(HTML(value='Iteration 3/3, 4 arms, time budget = 2.42 sec'), FloatProgress(value=0.0, max=4.0),…

Hyperparams: {'batch_size': 80000, 'dt': 3.1622776e-06}., sample shape=(97, 10). metric: 734
Hyperparams: {'batch_size': 8000, 'dt': 1e-05}., sample shape=(633, 10). metric: -17
Hyperparams: {'batch_size': 8000, 'dt': 3.1622776e-06}., sample shape=(597, 10). metric: 62
Hyperparams: {'batch_size': 8000, 'dt': 0.1}., sample shape=(593, 10). metric: 661

Number of samples: [633]
Running time: 149.8 sec
{'batch_size': 8000, 'dt': 1e-05} -16.761747 (633, 10)


In [8]:
# SGHMC: the grid additionally covers L.
# NOTE(review): L is presumably the number of leapfrog steps -- confirm against sgmcmcjax.
def build_kernel(dt, L, batch_size):
    """Build an SGHMC kernel for one (dt, L, batch_size) arm."""
    return build_sghmc_kernel(dt, L, loglikelihood, logprior, data, batch_size)


grid_params = {
    'log_dt': logdt_range,
    'batch_size': batch_size_range,
    "L": [5, 10],
}

best_arm = run_MAMBA(
    key, build_kernel, error_fn, R, params_IC,
    grid_params=grid_params, get_fb_grads=get_fb_grads,
)

print(best_arm.hyperparameters, best_arm.metric, best_arm.samples.shape)

HBox(children=(HTML(value='Iteration 1/4, 84 arms, time budget = 0.07 sec'), FloatProgress(value=0.0, max=84.0…

Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 0.1}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 0.031622775}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 0.01}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 0.0031622776}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 0.001}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 0.00031622776}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 1e-04}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 3.1622778e-05}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 1e-05}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 3.1622776e-06}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 800000, 

HBox(children=(HTML(value='Iteration 2/4, 28 arms, time budget = 0.23 sec'), FloatProgress(value=0.0, max=28.0…

Hyperparams: {'batch_size': 8000, 'L': 10, 'dt': 1e-07}., sample shape=(10, 10). metric: 7811
Hyperparams: {'batch_size': 8000, 'L': 5, 'dt': 3.1622776e-07}., sample shape=(18, 10). metric: 10074
Hyperparams: {'batch_size': 8000, 'L': 10, 'dt': 3.1622776e-07}., sample shape=(10, 10). metric: 15892
Hyperparams: {'batch_size': 8000, 'L': 5, 'dt': 1e-06}., sample shape=(17, 10). metric: -9215
Hyperparams: {'batch_size': 8000, 'L': 5, 'dt': 1e-07}., sample shape=(18, 10). metric: 21056
Hyperparams: {'batch_size': 8000, 'L': 10, 'dt': 3.1622776e-08}., sample shape=(10, 10). metric: 57683
Hyperparams: {'batch_size': 8000, 'L': 5, 'dt': 3.1622776e-08}., sample shape=(18, 10). metric: 35595
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 0.1}., sample shape=(2, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 0.031622775}., sample shape=(2, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 0.01}., sample shape=(2, 10). metric: inf
Hyperparams: {'batch_size'

HBox(children=(HTML(value='Iteration 3/4, 9 arms, time budget = 0.70 sec'), FloatProgress(value=0.0, max=9.0),…

Hyperparams: {'batch_size': 8000, 'L': 5, 'dt': 1e-06}., sample shape=(57, 10). metric: -2342
Hyperparams: {'batch_size': 8000, 'L': 10, 'dt': 1e-07}., sample shape=(31, 10). metric: -333
Hyperparams: {'batch_size': 8000, 'L': 5, 'dt': 3.1622776e-07}., sample shape=(58, 10). metric: 1830
Hyperparams: {'batch_size': 8000, 'L': 10, 'dt': 3.1622776e-07}., sample shape=(31, 10). metric: 5990
Hyperparams: {'batch_size': 8000, 'L': 5, 'dt': 1e-07}., sample shape=(56, 10). metric: 15733
Hyperparams: {'batch_size': 8000, 'L': 5, 'dt': 3.1622776e-08}., sample shape=(57, 10). metric: 15200
Hyperparams: {'batch_size': 8000, 'L': 10, 'dt': 3.1622776e-08}., sample shape=(28, 10). metric: 8549
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 0.1}., sample shape=(5, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 0.031622775}., sample shape=(5, 10). metric: inf

Number of samples: [57, 31, 58]


HBox(children=(HTML(value='Iteration 4/4, 3 arms, time budget = 2.10 sec'), FloatProgress(value=0.0, max=3.0),…

Hyperparams: {'batch_size': 8000, 'L': 5, 'dt': 1e-06}., sample shape=(172, 10). metric: -1263
Hyperparams: {'batch_size': 8000, 'L': 10, 'dt': 1e-07}., sample shape=(91, 10). metric: 910
Hyperparams: {'batch_size': 8000, 'L': 5, 'dt': 3.1622776e-07}., sample shape=(176, 10). metric: 243

Number of samples: [172]
Running time: 193.7 sec
{'batch_size': 8000, 'L': 5, 'dt': 1e-06} -1263.4249 (172, 10)


In [9]:
# SGHMC with control variates, centred on the MAP estimate `centering_value`.
def build_kernel(dt, L, batch_size):
    """Build an SGHMC-CV kernel for one (dt, L, batch_size) arm."""
    return build_sghmcCV_kernel(
        dt, L, loglikelihood, logprior, data, batch_size, centering_value
    )


# Same grid as for plain SGHMC: step size, batch size and L.
grid_params = dict(log_dt=logdt_range, batch_size=batch_size_range, L=[5, 10])

best_arm = run_MAMBA(
    key, build_kernel, error_fn, R, params_IC,
    grid_params=grid_params, get_fb_grads=get_fb_grads,
)

print(best_arm.hyperparameters, best_arm.metric, best_arm.samples.shape)

HBox(children=(HTML(value='Iteration 1/4, 84 arms, time budget = 0.07 sec'), FloatProgress(value=0.0, max=84.0…

Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 0.1}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 0.031622775}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 0.01}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 0.0031622776}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 0.001}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 0.00031622776}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 1e-04}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 3.1622778e-05}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 1e-05}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 3.1622776e-06}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 800000, 

HBox(children=(HTML(value='Iteration 2/4, 28 arms, time budget = 0.23 sec'), FloatProgress(value=0.0, max=28.0…

Hyperparams: {'batch_size': 8000, 'L': 10, 'dt': 1e-07}., sample shape=(8, 10). metric: 5729
Hyperparams: {'batch_size': 8000, 'L': 10, 'dt': 3.1622776e-07}., sample shape=(8, 10). metric: 727
Hyperparams: {'batch_size': 8000, 'L': 5, 'dt': 3.1622776e-07}., sample shape=(15, 10). metric: 6893
Hyperparams: {'batch_size': 8000, 'L': 5, 'dt': 1e-06}., sample shape=(15, 10). metric: 1502
Hyperparams: {'batch_size': 8000, 'L': 5, 'dt': 1e-07}., sample shape=(15, 10). metric: 7680
Hyperparams: {'batch_size': 8000, 'L': 10, 'dt': 3.1622776e-08}., sample shape=(8, 10). metric: 35625
Hyperparams: {'batch_size': 8000, 'L': 5, 'dt': 3.1622776e-08}., sample shape=(15, 10). metric: 38757
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 0.1}., sample shape=(2, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 0.031622775}., sample shape=(2, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 0.01}., sample shape=(2, 10). metric: inf
Hyperparams: {'batch_size': 800000

HBox(children=(HTML(value='Iteration 3/4, 9 arms, time budget = 0.70 sec'), FloatProgress(value=0.0, max=9.0),…

Hyperparams: {'batch_size': 8000, 'L': 10, 'dt': 3.1622776e-07}., sample shape=(26, 10). metric: -441
Hyperparams: {'batch_size': 8000, 'L': 5, 'dt': 1e-06}., sample shape=(48, 10). metric: -293
Hyperparams: {'batch_size': 8000, 'L': 10, 'dt': 1e-07}., sample shape=(26, 10). metric: 1239
Hyperparams: {'batch_size': 8000, 'L': 5, 'dt': 3.1622776e-07}., sample shape=(49, 10). metric: 2709
Hyperparams: {'batch_size': 8000, 'L': 5, 'dt': 1e-07}., sample shape=(49, 10). metric: 7648
Hyperparams: {'batch_size': 8000, 'L': 10, 'dt': 3.1622776e-08}., sample shape=(26, 10). metric: 3397
Hyperparams: {'batch_size': 8000, 'L': 5, 'dt': 3.1622776e-08}., sample shape=(49, 10). metric: 12096
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 0.1}., sample shape=(3, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'L': 5, 'dt': 0.031622775}., sample shape=(3, 10). metric: inf

Number of samples: [26, 48, 26]


HBox(children=(HTML(value='Iteration 4/4, 3 arms, time budget = 2.10 sec'), FloatProgress(value=0.0, max=3.0),…

Hyperparams: {'batch_size': 8000, 'L': 10, 'dt': 3.1622776e-07}., sample shape=(78, 10). metric: 142
Hyperparams: {'batch_size': 8000, 'L': 5, 'dt': 1e-06}., sample shape=(147, 10). metric: 295
Hyperparams: {'batch_size': 8000, 'L': 10, 'dt': 1e-07}., sample shape=(78, 10). metric: 725

Number of samples: [78]
Running time: 222.6 sec
{'batch_size': 8000, 'L': 10, 'dt': 3.1622776e-07} 141.6383 (78, 10)


In [10]:
# SGNHT: tune step size and batch size jointly.
def build_kernel(dt, batch_size):
    """Build an SGNHT kernel for the given step size and batch size."""
    return build_sgnht_kernel(dt, loglikelihood, logprior, data, batch_size)


grid_params = dict(log_dt=logdt_range, batch_size=batch_size_range)

# Alternative: keep the batch size fixed at `fixed_bs` and tune the step size only:
#   build_kernel = lambda dt: build_sgnht_kernel(dt, loglikelihood, logprior, data, fixed_bs)
#   grid_params = {'log_dt': logdt_range}
best_arm = run_MAMBA(
    key, build_kernel, error_fn, R, params_IC,
    grid_params=grid_params, get_fb_grads=get_fb_grads,
)

print(best_arm.hyperparameters, best_arm.metric, best_arm.samples.shape)

HBox(children=(HTML(value='Iteration 1/3, 42 arms, time budget = 0.23 sec'), FloatProgress(value=0.0, max=42.0…

Hyperparams: {'batch_size': 800000, 'dt': 0.1}., sample shape=(3, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'dt': 0.031622775}., sample shape=(3, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'dt': 0.01}., sample shape=(3, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'dt': 0.0031622776}., sample shape=(3, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'dt': 0.001}., sample shape=(3, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'dt': 0.00031622776}., sample shape=(3, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'dt': 1e-04}., sample shape=(3, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'dt': 3.1622778e-05}., sample shape=(3, 10). metric: -716859
Hyperparams: {'batch_size': 800000, 'dt': 1e-05}., sample shape=(3, 10). metric: -162732
Hyperparams: {'batch_size': 800000, 'dt': 3.1622776e-06}., sample shape=(3, 10). metric: -41507
Hyperparams: {'batch_size': 800000, 'dt': 1e-06}., sample shape=(3, 10). metric: -248949
Hyperparams: {'b

HBox(children=(HTML(value='Iteration 2/3, 14 arms, time budget = 0.69 sec'), FloatProgress(value=0.0, max=14.0…

Hyperparams: {'batch_size': 800000, 'dt': 3.1622778e-05}., sample shape=(12, 10). metric: -15566
Hyperparams: {'batch_size': 800000, 'dt': 1e-06}., sample shape=(12, 10). metric: -2067
Hyperparams: {'batch_size': 800000, 'dt': 3.1622776e-07}., sample shape=(12, 10). metric: -8305
Hyperparams: {'batch_size': 800000, 'dt': 1e-05}., sample shape=(12, 10). metric: -8248
Hyperparams: {'batch_size': 80000, 'dt': 3.1622776e-06}., sample shape=(33, 10). metric: -29195
Hyperparams: {'batch_size': 80000, 'dt': 1e-06}., sample shape=(33, 10). metric: -20903
Hyperparams: {'batch_size': 800000, 'dt': 3.1622776e-06}., sample shape=(12, 10). metric: -5632
Hyperparams: {'batch_size': 80000, 'dt': 3.1622776e-07}., sample shape=(45, 10). metric: -6093
Hyperparams: {'batch_size': 8000, 'dt': 1e-06}., sample shape=(198, 10). metric: -10460
Hyperparams: {'batch_size': 80000, 'dt': 1e-07}., sample shape=(34, 10). metric: 857
Hyperparams: {'batch_size': 8000, 'dt': 3.1622776e-07}., sample shape=(199, 10). me

HBox(children=(HTML(value='Iteration 3/3, 4 arms, time budget = 2.42 sec'), FloatProgress(value=0.0, max=4.0),…

Hyperparams: {'batch_size': 80000, 'dt': 3.1622776e-06}., sample shape=(111, 10). metric: -9186
Hyperparams: {'batch_size': 80000, 'dt': 1e-06}., sample shape=(114, 10). metric: -5332
Hyperparams: {'batch_size': 800000, 'dt': 3.1622778e-05}., sample shape=(43, 10). metric: -1895
Hyperparams: {'batch_size': 8000, 'dt': 1e-06}., sample shape=(713, 10). metric: -1379

Number of samples: [111]
Running time: 165.4 sec
{'batch_size': 80000, 'dt': 3.1622776e-06} -9186.1875 (111, 10)


In [11]:
# SGNHT with control variates, centred on the MAP estimate `centering_value`.
def build_kernel(dt, batch_size):
    """Build an SGNHT-CV kernel for the given step size and batch size."""
    return build_sgnhtCV_kernel(
        dt, loglikelihood, logprior, data, batch_size, centering_value
    )


grid_params = {
    'log_dt': logdt_range,
    'batch_size': batch_size_range,
}

# Alternative: keep the batch size fixed at `fixed_bs` and tune the step size only:
#   build_kernel = lambda dt: build_sgnhtCV_kernel(dt, loglikelihood,
#                                                  logprior, data, fixed_bs, centering_value)
#   grid_params = {'log_dt': logdt_range}
best_arm = run_MAMBA(
    key, build_kernel, error_fn, R, params_IC,
    grid_params=grid_params, get_fb_grads=get_fb_grads,
)

print(best_arm.hyperparameters, best_arm.metric, best_arm.samples.shape)

HBox(children=(HTML(value='Iteration 1/3, 42 arms, time budget = 0.23 sec'), FloatProgress(value=0.0, max=42.0…

Hyperparams: {'batch_size': 800000, 'dt': 0.1}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'dt': 0.031622775}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'dt': 0.01}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'dt': 0.0031622776}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'dt': 0.001}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'dt': 0.00031622776}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'dt': 1e-04}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'dt': 3.1622778e-05}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'dt': 1e-05}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'dt': 3.1622776e-06}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 800000, 'dt': 1e-06}., sample shape=(1, 10). metric: inf
Hyperparams: {'batch_size': 800

HBox(children=(HTML(value='Iteration 2/3, 14 arms, time budget = 0.69 sec'), FloatProgress(value=0.0, max=14.0…

Hyperparams: {'batch_size': 80000, 'dt': 3.1622778e-05}., sample shape=(28, 10). metric: -5569
Hyperparams: {'batch_size': 80000, 'dt': 1e-06}., sample shape=(28, 10). metric: -1485
Hyperparams: {'batch_size': 80000, 'dt': 3.1622776e-07}., sample shape=(27, 10). metric: -3255
Hyperparams: {'batch_size': 80000, 'dt': 1e-05}., sample shape=(28, 10). metric: -1572
Hyperparams: {'batch_size': 80000, 'dt': 3.1622776e-06}., sample shape=(28, 10). metric: -3757
Hyperparams: {'batch_size': 80000, 'dt': 1e-07}., sample shape=(28, 10). metric: -1372
Hyperparams: {'batch_size': 80000, 'dt': 3.1622776e-08}., sample shape=(27, 10). metric: -2989
Hyperparams: {'batch_size': 8000, 'dt': 3.1622778e-05}., sample shape=(174, 10). metric: 33
Hyperparams: {'batch_size': 8000, 'dt': 1e-05}., sample shape=(174, 10). metric: -516
Hyperparams: {'batch_size': 8000, 'dt': 3.1622776e-06}., sample shape=(174, 10). metric: -339
Hyperparams: {'batch_size': 8000, 'dt': 1e-06}., sample shape=(174, 10). metric: -252
H

HBox(children=(HTML(value='Iteration 3/3, 4 arms, time budget = 2.42 sec'), FloatProgress(value=0.0, max=4.0),…

Hyperparams: {'batch_size': 80000, 'dt': 3.1622778e-05}., sample shape=(97, 10). metric: -751
Hyperparams: {'batch_size': 80000, 'dt': 3.1622776e-06}., sample shape=(98, 10). metric: -378
Hyperparams: {'batch_size': 80000, 'dt': 3.1622776e-07}., sample shape=(94, 10). metric: -376
Hyperparams: {'batch_size': 80000, 'dt': 3.1622776e-08}., sample shape=(97, 10). metric: -78

Number of samples: [97]
Running time: 163.0 sec
{'batch_size': 80000, 'dt': 3.1622778e-05} -751.2123 (97, 10)
