In [1]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np

import jax.numpy as jnp
from jax import jit, random

from sgmcmcjax.kernels import build_sgld_kernel, build_sghmc_kernel, build_sgnht_kernel
from sgmcmcjax.kernels import build_sgldCV_kernel, build_sghmcCV_kernel, build_sgnhtCV_kernel
from sgmcmcjax.optimizer import build_adam_optimizer
from sgmcmcjax.util import build_grad_log_post

import context

from models.PMF.pmf_model import logprior, loglikelihood
from models.PMF.util import R_train, R_test, load_PMF_MAP

from tuning.mamba import run_MAMBA
from tuning.ksd import imq_KSD




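# PMF - MAMBA

Tune six SG-MCMC samplers on probabilistic matrix factorisation (PMF) with MAMBA:
- SGLD, SGHMC and SGNHT;
- their CV (control-variate) versions.

Each run:
- searches over the step size `dt` and the minibatch size, plus `L` for SGHMC;
- starts every chain from the MAP estimate;
- scores the samples with the kernel Stein discrepancy (`imq_KSD`).

In the runs recorded below, SGNHT reached the lowest KSD (≈126, with `batch_size=8000` and `dt=1e-05`).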

In [2]:
# Training ratings, wrapped in the tuple that the kernel builders take as `data`
# (and that the full-batch gradient below unpacks as `*data`).
data = (R_train,)

# One fixed PRNG key, reused for every MAMBA run in this notebook.
key = random.PRNGKey(0)

# The MAP estimate plays two roles: centering value for the CV kernels and
# initial condition for every chain.
centering_value = load_PMF_MAP()
params_IC = load_PMF_MAP()


def err_fn(x, y):
    """Error metric handed to MAMBA: the KSD `imq_KSD` evaluated on (x, y).

    NOTE(review): presumably called by `run_MAMBA` with the (thinned samples,
    full-batch gradients) pair produced by `get_fb_grads` -- TODO confirm.
    """
    return imq_KSD(x[:], y[:])


# Full-batch gradient of the log-posterior over all of `data`.
grad_log_post_fb = build_grad_log_post(loglikelihood, logprior, data)

def get_fb_grads(samples, thin=10):
    """Thin a chain and evaluate full-batch log-posterior gradients on what is kept.

    Passed to `run_MAMBA` as `get_fb_grads`. NOTE(review): its output is
    presumably what MAMBA hands to the error metric -- TODO confirm.

    Args:
        samples: indexable chain of samples (sliceable with `[::thin]`).
        thin: keep every `thin`-th sample. Defaults to 10, the previously
            hard-coded value. Must be >= 1.

    Returns:
        (thinned_samples, grads): the thinned samples, and a list holding
        `grad_log_post_fb(sample, *data)` for each of them, in the same order.

    Raises:
        ValueError: if `thin` < 1. A zero step is not a valid slice, and a
            negative one would silently reverse the chain.
    """
    if thin < 1:
        raise ValueError(f"thin must be >= 1, got {thin}")
    # Slice once so the returned samples and the gradients line up one-to-one.
    thinned = samples[::thin]
    grads = [grad_log_post_fb(sample, *data) for sample in thinned]
    return thinned, grads

# Budget passed to run_MAMBA as `R`. The per-iteration time budgets it reports
# in the outputs below (0.77 + 2.31 + 6.92 sec) add up to this value.
R = 10

# Minibatch sizes: the full training set, then 1/10 and 1/100 of it,
# computed with exact integer floor division.
n_train = int(R_train.shape[0])
batch_size_range = [n_train // 10**k for k in range(3)]
print(batch_size_range)

[80000, 8000, 800]


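### Run MAMBA

One cell per sampler. The CV samplers are centred at the MAP estimate (`centering_value`), and the SGHMC samplers additionally tune `L`.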

In [3]:
def build_kernel(dt, batch_size):
    """SGLD kernel on `data` for one (step size, minibatch size) arm."""
    return build_sgld_kernel(dt, loglikelihood, logprior, data, batch_size)


# Search space: log10 step sizes -2, -2.5, ..., -7.5 (dt = 1e-2 ... ~3.2e-8)
# crossed with the minibatch sizes.
grid_params = dict(
    log_dt=-jnp.arange(2., 8., 0.5),
    batch_size=batch_size_range,
)

best_arm = run_MAMBA(
    key, build_kernel, err_fn, R, params_IC,
    grid_params=grid_params,
    get_fb_grads=get_fb_grads,
)

print(best_arm.hyperparameters, best_arm.metric, best_arm.samples.shape)


HBox(children=(HTML(value='Iteration 1/3, 36 arms, time budget = 0.77 sec'), FloatProgress(value=0.0, max=36.0…

Hyperparams: {'batch_size': 80000, 'dt': 0.01}., sample shape=(10, 52580). metric: inf
Hyperparams: {'batch_size': 80000, 'dt': 0.0031622776}., sample shape=(10, 52580). metric: inf
Hyperparams: {'batch_size': 80000, 'dt': 0.001}., sample shape=(9, 52580). metric: inf
Hyperparams: {'batch_size': 80000, 'dt': 0.00031622776}., sample shape=(9, 52580). metric: 1692
Hyperparams: {'batch_size': 80000, 'dt': 1e-04}., sample shape=(10, 52580). metric: 1881
Hyperparams: {'batch_size': 80000, 'dt': 3.1622778e-05}., sample shape=(10, 52580). metric: 6315
Hyperparams: {'batch_size': 80000, 'dt': 1e-05}., sample shape=(9, 52580). metric: 871
Hyperparams: {'batch_size': 80000, 'dt': 3.1622776e-06}., sample shape=(9, 52580). metric: 803
Hyperparams: {'batch_size': 80000, 'dt': 1e-06}., sample shape=(9, 52580). metric: 752
Hyperparams: {'batch_size': 80000, 'dt': 3.1622776e-07}., sample shape=(9, 52580). metric: 706
Hyperparams: {'batch_size': 80000, 'dt': 1e-07}., sample shape=(9, 52580). metric: 89

HBox(children=(HTML(value='Iteration 2/3, 12 arms, time budget = 2.31 sec'), FloatProgress(value=0.0, max=12.0…

Hyperparams: {'batch_size': 800, 'dt': 1e-05}., sample shape=(225, 52580). metric: 467
Hyperparams: {'batch_size': 8000, 'dt': 3.1622776e-06}., sample shape=(156, 52580). metric: 639
Hyperparams: {'batch_size': 800, 'dt': 3.1622776e-06}., sample shape=(246, 52580). metric: 584
Hyperparams: {'batch_size': 800, 'dt': 1e-06}., sample shape=(240, 52580). metric: 692
Hyperparams: {'batch_size': 8000, 'dt': 1e-05}., sample shape=(158, 52580). metric: 487
Hyperparams: {'batch_size': 8000, 'dt': 1e-06}., sample shape=(157, 52580). metric: 707
Hyperparams: {'batch_size': 800, 'dt': 3.1622776e-07}., sample shape=(247, 52580). metric: 709
Hyperparams: {'batch_size': 80000, 'dt': 3.1622776e-07}., sample shape=(36, 52580). metric: 760
Hyperparams: {'batch_size': 8000, 'dt': 3.1622776e-07}., sample shape=(157, 52580). metric: 717
Hyperparams: {'batch_size': 80000, 'dt': 1e-06}., sample shape=(36, 52580). metric: 733
Hyperparams: {'batch_size': 8000, 'dt': 1e-07}., sample shape=(158, 52580). metric: 

HBox(children=(HTML(value='Iteration 3/3, 4 arms, time budget = 6.92 sec'), FloatProgress(value=0.0, max=4.0),…

Hyperparams: {'batch_size': 800, 'dt': 1e-05}., sample shape=(763, 52580). metric: 274
Hyperparams: {'batch_size': 8000, 'dt': 1e-05}., sample shape=(510, 52580). metric: 312
Hyperparams: {'batch_size': 800, 'dt': 3.1622776e-06}., sample shape=(775, 52580). metric: 414
Hyperparams: {'batch_size': 8000, 'dt': 3.1622776e-06}., sample shape=(486, 52580). metric: 482

Number of samples: [763]
Running time: 681.7 sec
{'batch_size': 800, 'dt': 1e-05} 274.2214 (763, 52580)


In [4]:
# SGLD-CV: same search space as plain SGLD; the kernel is centred at the MAP
# estimate (`centering_value`).
build_kernel = lambda dt, batch_size: build_sgldCV_kernel(
    dt, loglikelihood, logprior, data, batch_size, centering_value
)

# log10 step sizes -2, -2.5, ..., -7.5 crossed with the minibatch sizes.
grid_params = {'log_dt': -jnp.arange(2., 8., 0.5),
               'batch_size': batch_size_range
              }
# Score with the shared `err_fn` (not `imq_KSD` directly) so every sampler in
# this notebook is tuned, and compared, against the same metric definition.
best_arm = run_MAMBA(key, build_kernel, err_fn, R, params_IC,
                     grid_params=grid_params, get_fb_grads=get_fb_grads)

print(best_arm.hyperparameters, best_arm.metric, best_arm.samples.shape)


HBox(children=(HTML(value='Iteration 1/3, 36 arms, time budget = 0.77 sec'), FloatProgress(value=0.0, max=36.0…

Hyperparams: {'batch_size': 80000, 'dt': 0.01}., sample shape=(5, 52580). metric: inf
Hyperparams: {'batch_size': 80000, 'dt': 0.0031622776}., sample shape=(5, 52580). metric: inf
Hyperparams: {'batch_size': 80000, 'dt': 0.001}., sample shape=(5, 52580). metric: inf
Hyperparams: {'batch_size': 80000, 'dt': 0.00031622776}., sample shape=(5, 52580). metric: 2502
Hyperparams: {'batch_size': 80000, 'dt': 1e-04}., sample shape=(5, 52580). metric: 1855
Hyperparams: {'batch_size': 80000, 'dt': 3.1622778e-05}., sample shape=(5, 52580). metric: 12469
Hyperparams: {'batch_size': 80000, 'dt': 1e-05}., sample shape=(5, 52580). metric: 1051
Hyperparams: {'batch_size': 80000, 'dt': 3.1622776e-06}., sample shape=(5, 52580). metric: 929
Hyperparams: {'batch_size': 80000, 'dt': 1e-06}., sample shape=(5, 52580). metric: 909
Hyperparams: {'batch_size': 80000, 'dt': 3.1622776e-07}., sample shape=(5, 52580). metric: 796
Hyperparams: {'batch_size': 80000, 'dt': 1e-07}., sample shape=(5, 52580). metric: 808


HBox(children=(HTML(value='Iteration 2/3, 12 arms, time budget = 2.31 sec'), FloatProgress(value=0.0, max=12.0…

Hyperparams: {'batch_size': 800, 'dt': 3.1622776e-06}., sample shape=(179, 52580). metric: 606
Hyperparams: {'batch_size': 800, 'dt': 1e-05}., sample shape=(173, 52580). metric: 479
Hyperparams: {'batch_size': 800, 'dt': 1e-06}., sample shape=(181, 52580). metric: 707
Hyperparams: {'batch_size': 8000, 'dt': 3.1622776e-06}., sample shape=(103, 52580). metric: 665
Hyperparams: {'batch_size': 8000, 'dt': 1e-06}., sample shape=(97, 52580). metric: 693
Hyperparams: {'batch_size': 8000, 'dt': 1e-05}., sample shape=(104, 52580). metric: 567
Hyperparams: {'batch_size': 800, 'dt': 3.1622776e-07}., sample shape=(178, 52580). metric: 713
Hyperparams: {'batch_size': 8000, 'dt': 3.1622776e-07}., sample shape=(103, 52580). metric: 712
Hyperparams: {'batch_size': 80000, 'dt': 3.1622776e-08}., sample shape=(19, 52580). metric: 1209
Hyperparams: {'batch_size': 80000, 'dt': 3.1622776e-07}., sample shape=(19, 52580). metric: 942
Hyperparams: {'batch_size': 80000, 'dt': 1e-07}., sample shape=(19, 52580). 

HBox(children=(HTML(value='Iteration 3/3, 4 arms, time budget = 6.92 sec'), FloatProgress(value=0.0, max=4.0),…

Hyperparams: {'batch_size': 800, 'dt': 1e-05}., sample shape=(561, 52580). metric: 301
Hyperparams: {'batch_size': 8000, 'dt': 1e-05}., sample shape=(342, 52580). metric: 366
Hyperparams: {'batch_size': 800, 'dt': 3.1622776e-06}., sample shape=(593, 52580). metric: 445
Hyperparams: {'batch_size': 8000, 'dt': 3.1622776e-06}., sample shape=(343, 52580). metric: 529

Number of samples: [561]
Running time: 578.6 sec
{'batch_size': 800, 'dt': 1e-05} 301.14493 (561, 52580)


In [5]:
# SGHMC: tune the step size, the minibatch size and `L`
# (NOTE(review): presumably the number of leapfrog steps -- TODO confirm).
build_kernel = lambda dt, L, batch_size: build_sghmc_kernel(
    dt, L, loglikelihood, logprior, data, batch_size
)

# log10 step sizes -2, -2.5, ..., -7.5 crossed with the minibatch sizes and L.
grid_params = {'log_dt': -jnp.arange(2., 8., 0.5),
               'batch_size': batch_size_range,
               "L": [5, 10]
              }
# Score with the shared `err_fn` (not `imq_KSD` directly) so every sampler in
# this notebook is tuned, and compared, against the same metric definition.
best_arm = run_MAMBA(key, build_kernel, err_fn, R, params_IC,
                     grid_params=grid_params, get_fb_grads=get_fb_grads)

print(best_arm.hyperparameters, best_arm.metric, best_arm.samples.shape)


HBox(children=(HTML(value='Iteration 1/3, 72 arms, time budget = 0.77 sec'), FloatProgress(value=0.0, max=72.0…

Hyperparams: {'batch_size': 80000, 'L': 5, 'dt': 0.01}., sample shape=(3, 52580). metric: inf
Hyperparams: {'batch_size': 80000, 'L': 5, 'dt': 0.0031622776}., sample shape=(3, 52580). metric: inf
Hyperparams: {'batch_size': 80000, 'L': 5, 'dt': 0.001}., sample shape=(3, 52580). metric: inf
Hyperparams: {'batch_size': 80000, 'L': 5, 'dt': 0.00031622776}., sample shape=(3, 52580). metric: inf
Hyperparams: {'batch_size': 80000, 'L': 5, 'dt': 1e-04}., sample shape=(3, 52580). metric: inf
Hyperparams: {'batch_size': 80000, 'L': 5, 'dt': 3.1622778e-05}., sample shape=(3, 52580). metric: 19748
Hyperparams: {'batch_size': 80000, 'L': 5, 'dt': 1e-05}., sample shape=(3, 52580). metric: 2949
Hyperparams: {'batch_size': 80000, 'L': 5, 'dt': 3.1622776e-06}., sample shape=(3, 52580). metric: 39849
Hyperparams: {'batch_size': 80000, 'L': 5, 'dt': 1e-06}., sample shape=(3, 52580). metric: 1208
Hyperparams: {'batch_size': 80000, 'L': 5, 'dt': 3.1622776e-07}., sample shape=(3, 52580). metric: 933
Hyperp

HBox(children=(HTML(value='Iteration 2/3, 24 arms, time budget = 2.31 sec'), FloatProgress(value=0.0, max=24.0…

Hyperparams: {'batch_size': 800, 'L': 10, 'dt': 3.1622776e-08}., sample shape=(31, 52580). metric: 779
Hyperparams: {'batch_size': 800, 'L': 5, 'dt': 3.1622776e-07}., sample shape=(55, 52580). metric: 800
Hyperparams: {'batch_size': 8000, 'L': 5, 'dt': 3.1622776e-07}., sample shape=(35, 52580). metric: 807
Hyperparams: {'batch_size': 8000, 'L': 10, 'dt': 1e-07}., sample shape=(19, 52580). metric: 816
Hyperparams: {'batch_size': 800, 'L': 5, 'dt': 1e-06}., sample shape=(53, 52580). metric: 749
Hyperparams: {'batch_size': 800, 'L': 10, 'dt': 1e-07}., sample shape=(30, 52580). metric: 799
Hyperparams: {'batch_size': 8000, 'L': 10, 'dt': 3.1622776e-08}., sample shape=(19, 52580). metric: 785
Hyperparams: {'batch_size': 80000, 'L': 10, 'dt': 3.1622776e-08}., sample shape=(6, 52580). metric: 960
Hyperparams: {'batch_size': 800, 'L': 5, 'dt': 1e-07}., sample shape=(53, 52580). metric: 797
Hyperparams: {'batch_size': 8000, 'L': 5, 'dt': 1e-07}., sample shape=(34, 52580). metric: 761
Hyperparam

HBox(children=(HTML(value='Iteration 3/3, 8 arms, time budget = 6.92 sec'), FloatProgress(value=0.0, max=8.0),…

Hyperparams: {'batch_size': 800, 'L': 5, 'dt': 1e-06}., sample shape=(169, 52580). metric: 529
Hyperparams: {'batch_size': 8000, 'L': 5, 'dt': 1e-07}., sample shape=(108, 52580). metric: 788
Hyperparams: {'batch_size': 800, 'L': 5, 'dt': 3.1622776e-08}., sample shape=(168, 52580). metric: 795
Hyperparams: {'batch_size': 800, 'L': 10, 'dt': 3.1622776e-08}., sample shape=(95, 52580). metric: 782
Hyperparams: {'batch_size': 8000, 'L': 10, 'dt': 3.1622776e-08}., sample shape=(59, 52580). metric: 762
Hyperparams: {'batch_size': 800, 'L': 5, 'dt': 1e-07}., sample shape=(167, 52580). metric: 797
Hyperparams: {'batch_size': 800, 'L': 10, 'dt': 1e-07}., sample shape=(94, 52580). metric: 712
Hyperparams: {'batch_size': 800, 'L': 5, 'dt': 3.1622776e-07}., sample shape=(169, 52580). metric: 690

Number of samples: [169, 169]
Running time: 831.0 sec
{'batch_size': 800, 'L': 5, 'dt': 1e-06} 529.09717 (169, 52580)


In [4]:
# SGHMC-CV: same search space as plain SGHMC; the kernel is centred at the
# MAP estimate (`centering_value`).
build_kernel = lambda dt, L, batch_size: build_sghmcCV_kernel(
    dt, L, loglikelihood, logprior, data, batch_size, centering_value
)

# log10 step sizes -2, -2.5, ..., -7.5 crossed with the minibatch sizes and L.
grid_params = {'log_dt': -jnp.arange(2., 8., 0.5),
               'batch_size': batch_size_range,
               "L": [5, 10]
              }
# Score with the shared `err_fn` (not `imq_KSD` directly) so every sampler in
# this notebook is tuned, and compared, against the same metric definition.
best_arm = run_MAMBA(key, build_kernel, err_fn, R, params_IC,
                     grid_params=grid_params, get_fb_grads=get_fb_grads)

print(best_arm.hyperparameters, best_arm.metric, best_arm.samples.shape)


HBox(children=(HTML(value='Iteration 1/3, 72 arms, time budget = 0.77 sec'), FloatProgress(value=0.0, max=72.0…

Hyperparams: {'batch_size': 80000, 'L': 5, 'dt': 0.01}., sample shape=(2, 52580). metric: inf
Hyperparams: {'batch_size': 80000, 'L': 5, 'dt': 0.0031622776}., sample shape=(2, 52580). metric: inf
Hyperparams: {'batch_size': 80000, 'L': 5, 'dt': 0.001}., sample shape=(2, 52580). metric: inf
Hyperparams: {'batch_size': 80000, 'L': 5, 'dt': 0.00031622776}., sample shape=(2, 52580). metric: inf
Hyperparams: {'batch_size': 80000, 'L': 5, 'dt': 1e-04}., sample shape=(2, 52580). metric: 33468765437952
Hyperparams: {'batch_size': 80000, 'L': 5, 'dt': 3.1622778e-05}., sample shape=(2, 52580). metric: 29561
Hyperparams: {'batch_size': 80000, 'L': 5, 'dt': 1e-05}., sample shape=(2, 52580). metric: 3949
Hyperparams: {'batch_size': 80000, 'L': 5, 'dt': 3.1622776e-06}., sample shape=(2, 52580). metric: 59709
Hyperparams: {'batch_size': 80000, 'L': 5, 'dt': 1e-06}., sample shape=(2, 52580). metric: 1277
Hyperparams: {'batch_size': 80000, 'L': 5, 'dt': 3.1622776e-07}., sample shape=(2, 52580). metric:

HBox(children=(HTML(value='Iteration 2/3, 24 arms, time budget = 2.31 sec'), FloatProgress(value=0.0, max=24.0…

Hyperparams: {'batch_size': 800, 'L': 10, 'dt': 3.1622776e-08}., sample shape=(27, 52580). metric: 766
Hyperparams: {'batch_size': 800, 'L': 5, 'dt': 3.1622776e-07}., sample shape=(41, 52580). metric: 840
Hyperparams: {'batch_size': 800, 'L': 5, 'dt': 1e-06}., sample shape=(36, 52580). metric: 849
Hyperparams: {'batch_size': 8000, 'L': 5, 'dt': 3.1622776e-07}., sample shape=(25, 52580). metric: 829
Hyperparams: {'batch_size': 800, 'L': 10, 'dt': 1e-07}., sample shape=(27, 52580). metric: 789
Hyperparams: {'batch_size': 8000, 'L': 10, 'dt': 1e-07}., sample shape=(14, 52580). metric: 874
Hyperparams: {'batch_size': 80000, 'L': 10, 'dt': 3.1622776e-08}., sample shape=(3, 52580). metric: 1179
Hyperparams: {'batch_size': 80000, 'L': 5, 'dt': 3.1622776e-08}., sample shape=(6, 52580). metric: 1349
Hyperparams: {'batch_size': 800, 'L': 10, 'dt': 3.1622776e-07}., sample shape=(25, 52580). metric: 817
Hyperparams: {'batch_size': 800, 'L': 5, 'dt': 1e-07}., sample shape=(38, 52580). metric: 803
H

HBox(children=(HTML(value='Iteration 3/3, 8 arms, time budget = 6.92 sec'), FloatProgress(value=0.0, max=8.0),…

Hyperparams: {'batch_size': 800, 'L': 10, 'dt': 3.1622776e-08}., sample shape=(77, 52580). metric: 786
Hyperparams: {'batch_size': 800, 'L': 10, 'dt': 1e-07}., sample shape=(80, 52580). metric: 737
Hyperparams: {'batch_size': 800, 'L': 5, 'dt': 1e-07}., sample shape=(127, 52580). metric: 781
Hyperparams: {'batch_size': 800, 'L': 10, 'dt': 3.1622776e-07}., sample shape=(79, 52580). metric: 602
Hyperparams: {'batch_size': 8000, 'L': 5, 'dt': 3.1622776e-07}., sample shape=(78, 52580). metric: 769
Hyperparams: {'batch_size': 800, 'L': 5, 'dt': 3.1622776e-07}., sample shape=(132, 52580). metric: 725
Hyperparams: {'batch_size': 800, 'L': 5, 'dt': 1e-06}., sample shape=(134, 52580). metric: 550
Hyperparams: {'batch_size': 8000, 'L': 5, 'dt': 3.1622776e-08}., sample shape=(75, 52580). metric: 833

Number of samples: [134, 79]
Running time: 959.1 sec
{'batch_size': 800, 'L': 5, 'dt': 1e-06} 549.52606 (134, 52580)


In [6]:
# SGNHT: tune the step size and the minibatch size.
build_kernel = lambda dt, batch_size: build_sgnht_kernel(
    dt, loglikelihood, logprior, data, batch_size
)

# log10 step sizes -2, -2.5, ..., -7.5 crossed with the minibatch sizes.
grid_params = {'log_dt': -jnp.arange(2., 8., 0.5),
               'batch_size': batch_size_range
              }
# Score with the shared `err_fn` (not `imq_KSD` directly) so every sampler in
# this notebook is tuned, and compared, against the same metric definition.
best_arm = run_MAMBA(key, build_kernel, err_fn, R, params_IC,
                     grid_params=grid_params, get_fb_grads=get_fb_grads)

print(best_arm.hyperparameters, best_arm.metric, best_arm.samples.shape)


HBox(children=(HTML(value='Iteration 1/3, 36 arms, time budget = 0.77 sec'), FloatProgress(value=0.0, max=36.0…

Hyperparams: {'batch_size': 80000, 'dt': 0.01}., sample shape=(9, 52580). metric: inf
Hyperparams: {'batch_size': 80000, 'dt': 0.0031622776}., sample shape=(9, 52580). metric: inf
Hyperparams: {'batch_size': 80000, 'dt': 0.001}., sample shape=(9, 52580). metric: inf
Hyperparams: {'batch_size': 80000, 'dt': 0.00031622776}., sample shape=(9, 52580). metric: inf
Hyperparams: {'batch_size': 80000, 'dt': 1e-04}., sample shape=(9, 52580). metric: inf
Hyperparams: {'batch_size': 80000, 'dt': 3.1622778e-05}., sample shape=(9, 52580). metric: 4291
Hyperparams: {'batch_size': 80000, 'dt': 1e-05}., sample shape=(9, 52580). metric: 756
Hyperparams: {'batch_size': 80000, 'dt': 3.1622776e-06}., sample shape=(9, 52580). metric: 891
Hyperparams: {'batch_size': 80000, 'dt': 1e-06}., sample shape=(9, 52580). metric: 806
Hyperparams: {'batch_size': 80000, 'dt': 3.1622776e-07}., sample shape=(9, 52580). metric: 740
Hyperparams: {'batch_size': 80000, 'dt': 1e-07}., sample shape=(9, 52580). metric: 733
Hype

HBox(children=(HTML(value='Iteration 2/3, 12 arms, time budget = 2.31 sec'), FloatProgress(value=0.0, max=12.0…

Hyperparams: {'batch_size': 8000, 'dt': 3.1622776e-06}., sample shape=(149, 52580). metric: 321
Hyperparams: {'batch_size': 8000, 'dt': 1e-05}., sample shape=(147, 52580). metric: 247
Hyperparams: {'batch_size': 8000, 'dt': 1e-06}., sample shape=(147, 52580). metric: 386
Hyperparams: {'batch_size': 80000, 'dt': 1e-07}., sample shape=(36, 52580). metric: 884
Hyperparams: {'batch_size': 80000, 'dt': 3.1622776e-07}., sample shape=(36, 52580). metric: 829
Hyperparams: {'batch_size': 8000, 'dt': 3.1622776e-07}., sample shape=(145, 52580). metric: 511
Hyperparams: {'batch_size': 8000, 'dt': 3.1622778e-05}., sample shape=(144, 52580). metric: 234
Hyperparams: {'batch_size': 80000, 'dt': 1e-05}., sample shape=(36, 52580). metric: 479
Hyperparams: {'batch_size': 8000, 'dt': 3.1622776e-08}., sample shape=(145, 52580). metric: 796
Hyperparams: {'batch_size': 800, 'dt': 3.1622776e-08}., sample shape=(220, 52580). metric: 757
Hyperparams: {'batch_size': 8000, 'dt': 1e-07}., sample shape=(146, 52580

HBox(children=(HTML(value='Iteration 3/3, 4 arms, time budget = 6.92 sec'), FloatProgress(value=0.0, max=4.0),…

Hyperparams: {'batch_size': 8000, 'dt': 3.1622778e-05}., sample shape=(470, 52580). metric: 127
Hyperparams: {'batch_size': 8000, 'dt': 1e-05}., sample shape=(469, 52580). metric: 126
Hyperparams: {'batch_size': 8000, 'dt': 3.1622776e-06}., sample shape=(483, 52580). metric: 154
Hyperparams: {'batch_size': 8000, 'dt': 1e-06}., sample shape=(475, 52580). metric: 210

Number of samples: [469]
Running time: 942.5 sec
{'batch_size': 8000, 'dt': 1e-05} 126.28391 (469, 52580)


In [7]:
# SGNHT-CV: same search space as plain SGNHT; the kernel is centred at the
# MAP estimate (`centering_value`).
build_kernel = lambda dt, batch_size: build_sgnhtCV_kernel(
    dt, loglikelihood, logprior, data, batch_size, centering_value
)

# log10 step sizes -2, -2.5, ..., -7.5 crossed with the minibatch sizes.
grid_params = {'log_dt': -jnp.arange(2., 8., 0.5),
               'batch_size': batch_size_range
              }
# Score with the shared `err_fn` (not `imq_KSD` directly) so every sampler in
# this notebook is tuned, and compared, against the same metric definition.
best_arm = run_MAMBA(key, build_kernel, err_fn, R, params_IC,
                     grid_params=grid_params, get_fb_grads=get_fb_grads)

print(best_arm.hyperparameters, best_arm.metric, best_arm.samples.shape)


HBox(children=(HTML(value='Iteration 1/3, 36 arms, time budget = 0.77 sec'), FloatProgress(value=0.0, max=36.0…

Hyperparams: {'batch_size': 80000, 'dt': 0.01}., sample shape=(5, 52580). metric: inf
Hyperparams: {'batch_size': 80000, 'dt': 0.0031622776}., sample shape=(5, 52580). metric: inf
Hyperparams: {'batch_size': 80000, 'dt': 0.001}., sample shape=(5, 52580). metric: inf
Hyperparams: {'batch_size': 80000, 'dt': 0.00031622776}., sample shape=(5, 52580). metric: inf
Hyperparams: {'batch_size': 80000, 'dt': 1e-04}., sample shape=(5, 52580). metric: inf
Hyperparams: {'batch_size': 80000, 'dt': 3.1622778e-05}., sample shape=(5, 52580). metric: 7530
Hyperparams: {'batch_size': 80000, 'dt': 1e-05}., sample shape=(5, 52580). metric: 973
Hyperparams: {'batch_size': 80000, 'dt': 3.1622776e-06}., sample shape=(5, 52580). metric: 1041
Hyperparams: {'batch_size': 80000, 'dt': 1e-06}., sample shape=(5, 52580). metric: 766
Hyperparams: {'batch_size': 80000, 'dt': 3.1622776e-07}., sample shape=(5, 52580). metric: 691
Hyperparams: {'batch_size': 80000, 'dt': 1e-07}., sample shape=(5, 52580). metric: 776
Hyp

HBox(children=(HTML(value='Iteration 2/3, 12 arms, time budget = 2.31 sec'), FloatProgress(value=0.0, max=12.0…

Hyperparams: {'batch_size': 8000, 'dt': 1e-05}., sample shape=(98, 52580). metric: 344
Hyperparams: {'batch_size': 8000, 'dt': 3.1622776e-06}., sample shape=(96, 52580). metric: 372
Hyperparams: {'batch_size': 800, 'dt': 1e-06}., sample shape=(154, 52580). metric: 443
Hyperparams: {'batch_size': 8000, 'dt': 1e-06}., sample shape=(90, 52580). metric: 476
Hyperparams: {'batch_size': 80000, 'dt': 3.1622776e-07}., sample shape=(19, 52580). metric: 887
Hyperparams: {'batch_size': 800, 'dt': 3.1622776e-07}., sample shape=(163, 52580). metric: 500
Hyperparams: {'batch_size': 8000, 'dt': 3.1622776e-08}., sample shape=(98, 52580). metric: 853
Hyperparams: {'batch_size': 800, 'dt': 3.1622776e-08}., sample shape=(162, 52580). metric: 782
Hyperparams: {'batch_size': 800, 'dt': 3.1622776e-06}., sample shape=(161, 52580). metric: 397
Hyperparams: {'batch_size': 80000, 'dt': 1e-06}., sample shape=(19, 52580). metric: 880
Hyperparams: {'batch_size': 800, 'dt': 1e-07}., sample shape=(163, 52580). metri

HBox(children=(HTML(value='Iteration 3/3, 4 arms, time budget = 6.92 sec'), FloatProgress(value=0.0, max=4.0),…

Hyperparams: {'batch_size': 8000, 'dt': 1e-05}., sample shape=(319, 52580). metric: 158
Hyperparams: {'batch_size': 8000, 'dt': 3.1622776e-06}., sample shape=(315, 52580). metric: 197
Hyperparams: {'batch_size': 800, 'dt': 3.1622776e-06}., sample shape=(531, 52580). metric: 234
Hyperparams: {'batch_size': 800, 'dt': 1e-06}., sample shape=(524, 52580). metric: 226

Number of samples: [319]
Running time: 1074.8 sec
{'batch_size': 8000, 'dt': 1e-05} 158.40613 (319, 52580)
