In [1]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np

import jax.numpy as jnp
from jax import jit, random

from sgmcmcjax.kernels import build_sgld_kernel, build_sghmc_kernel, build_sgnht_kernel
from sgmcmcjax.kernels import build_sgldCV_kernel, build_sghmcCV_kernel, build_sgnhtCV_kernel
from sgmcmcjax.optimizer import build_adam_optimizer
from sgmcmcjax.util import build_grad_log_post

import context

from models.PMF.pmf_model import logprior, loglikelihood
from models.PMF.util import R_train, R_test, load_PMF_MAP

from tuning.mamba import run_MAMBA
from tuning.ksd import imq_KSD




# PMF - MAMBA

In [18]:
# Training data wrapped in a 1-tuple; it is unpacked with `*data` when the
# gradient function is called.
data = (R_train,)

# PRNG key with a fixed seed, passed to every run_MAMBA call.
key = random.PRNGKey(0)

# Two separate loads of the stored PMF parameters (presumably the MAP
# estimate -- confirm in models.PMF.util): one is handed to the
# control-variate kernels as their centering value, the other to run_MAMBA
# (presumably as the initial condition of the chains).
centering_value = load_PMF_MAP()
params_IC = load_PMF_MAP()


def err_fn(x, y):
    """Metric handed to run_MAMBA: `imq_KSD` applied to `x[:]` and `y[:]`.

    Both arguments go through a full-range slice, so no entries are dropped.
    """
    return imq_KSD(x[:], y[:])


# Gradient of the log-posterior assembled from the model's log-likelihood,
# log-prior and `data`.
grad_log_post_fb = build_grad_log_post(loglikelihood, logprior, data)

def get_fb_grads(samples, thin=10):
    """Thin a run of samples and evaluate `grad_log_post_fb` at each kept sample.

    Parameters
    ----------
    samples
        Sliceable collection of parameter samples; it is thinned with
        `samples[::thin]`.
    thin : int, optional
        Stride of the thinning slice. Defaults to 10, the value that used to
        be hard-coded, so existing one-argument callers are unaffected.

    Returns
    -------
    tuple
        `(thinned_samples, grads)`, where `grads` is a list holding
        `grad_log_post_fb(sample, *data)` for every kept sample, in order.
    """
    # Slice once and reuse it, so the returned samples and the gradients are
    # guaranteed to line up entry for entry.
    thinned_samples = samples[::thin]
    grads = [grad_log_post_fb(sample, *data) for sample in thinned_samples]
    return thinned_samples, grads

# Fourth positional argument of every run_MAMBA call in this notebook.
# NOTE(review): presumably the tuning budget -- confirm the meaning and units
# against tuning.mamba.run_MAMBA.
R = 10

### Run MAMBA for each sampler (SGLD, SGHMC and SGNHT, with and without control variates)

In [19]:
def build_kernel(dt, batch_size):
    """Forward `dt` and `batch_size` to `build_sgld_kernel`, fixing the model and data."""
    return build_sgld_kernel(dt, loglikelihood, logprior, data, batch_size)


# Hyperparameter grid searched by run_MAMBA:
#   log_dt     : -2.0, -2.5, ..., -7.5 (12 values)
#   batch_size : the row count of R_train scaled by 10**0, 10**-1, 10**-2, 10**-3
grid_params = dict(
    log_dt=-jnp.arange(2., 8., 0.5),
    batch_size=[int(10**(-power) * R_train.shape[0]) for power in range(0, 4)],
)

best_arm = run_MAMBA(
    key, build_kernel, err_fn, R, params_IC,
    grid_params=grid_params, get_fb_grads=get_fb_grads,
)

# Report the winning configuration, its metric and the shape of its samples.
print(best_arm.hyperparameters, best_arm.metric, best_arm.samples.shape)


HBox(children=(HTML(value='Iteration 1/3, 48 arms, time budget = 0.77 sec'), FloatProgress(value=0.0, max=48.0…

Hyperparams: {'batch_size': 80000, 'dt': 0.01}. metric: inf
Hyperparams: {'batch_size': 80000, 'dt': 0.0031622776}. metric: inf
Hyperparams: {'batch_size': 80000, 'dt': 0.001}. metric: inf
Hyperparams: {'batch_size': 80000, 'dt': 0.00031622776}. metric: 1990
Hyperparams: {'batch_size': 80000, 'dt': 1e-04}. metric: 1883
Hyperparams: {'batch_size': 80000, 'dt': 3.1622778e-05}. metric: 8944
Hyperparams: {'batch_size': 80000, 'dt': 1e-05}. metric: 930
Hyperparams: {'batch_size': 80000, 'dt': 3.1622776e-06}. metric: 863
Hyperparams: {'batch_size': 80000, 'dt': 1e-06}. metric: 827
Hyperparams: {'batch_size': 80000, 'dt': 3.1622776e-07}. metric: 720
Hyperparams: {'batch_size': 80000, 'dt': 1e-07}. metric: 833
Hyperparams: {'batch_size': 80000, 'dt': 3.1622776e-08}. metric: 787
Hyperparams: {'batch_size': 8000, 'dt': 0.01}. metric: inf
Hyperparams: {'batch_size': 8000, 'dt': 0.0031622776}. metric: inf
Hyperparams: {'batch_size': 8000, 'dt': 0.001}. metric: inf
Hyperparams: {'batch_size': 8000,

HBox(children=(HTML(value='Iteration 2/3, 16 arms, time budget = 2.31 sec'), FloatProgress(value=0.0, max=16.0…

Hyperparams: {'batch_size': 8000, 'dt': 3.1622776e-06}. metric: 637
Hyperparams: {'batch_size': 800, 'dt': 1e-05}. metric: 452
Hyperparams: {'batch_size': 800, 'dt': 1e-06}. metric: 699
Hyperparams: {'batch_size': 800, 'dt': 3.1622776e-06}. metric: 591
Hyperparams: {'batch_size': 8000, 'dt': 1e-06}. metric: 711
Hyperparams: {'batch_size': 8000, 'dt': 1e-05}. metric: 496
Hyperparams: {'batch_size': 800, 'dt': 3.1622776e-07}. metric: 707
Hyperparams: {'batch_size': 80, 'dt': 3.1622776e-07}. metric: 762
Hyperparams: {'batch_size': 80, 'dt': 1e-06}. metric: 776
Hyperparams: {'batch_size': 8000, 'dt': 3.1622776e-07}. metric: 737
Hyperparams: {'batch_size': 80000, 'dt': 3.1622776e-07}. metric: 895
Hyperparams: {'batch_size': 80, 'dt': 1e-07}. metric: 801
Hyperparams: {'batch_size': 80000, 'dt': 3.1622776e-08}. metric: 1243
Hyperparams: {'batch_size': 800, 'dt': 3.1622776e-08}. metric: 837
Hyperparams: {'batch_size': 8000, 'dt': 1e-07}. metric: 700
Hyperparams: {'batch_size': 80000, 'dt': 1e-

HBox(children=(HTML(value='Iteration 3/3, 5 arms, time budget = 7.38 sec'), FloatProgress(value=0.0, max=5.0),…

Hyperparams: {'batch_size': 800, 'dt': 1e-05}. metric: 270
Hyperparams: {'batch_size': 8000, 'dt': 1e-05}. metric: 319
Hyperparams: {'batch_size': 800, 'dt': 3.1622776e-06}. metric: 406
Hyperparams: {'batch_size': 8000, 'dt': 3.1622776e-06}. metric: 471
Hyperparams: {'batch_size': 800, 'dt': 1e-06}. metric: 574

Number of samples: [795]
Running time: 899.9 sec
{'batch_size': 800, 'dt': 1e-05} 269.85458 (795, 52580)


In [20]:
def build_kernel(dt, batch_size):
    """Forward `dt` and `batch_size` to `build_sgldCV_kernel`.

    The model, the data and the control-variate centering value are fixed.
    """
    return build_sgldCV_kernel(dt, loglikelihood, logprior, data,
                               batch_size, centering_value)


# Define the grid in this cell rather than inheriting it from whichever cell
# ran last: the SGHMC cells rebind `grid_params` with an extra "L" entry, and
# re-running this cell after them would pair that grid with a builder that
# only takes (dt, batch_size).
grid_params = {'log_dt': -jnp.arange(2., 8., 0.5),
               'batch_size': [int(10**(-elem)*R_train.shape[0]) for elem in range(0,4)]
              }

# Score with `err_fn`, the metric wrapper from the setup cell that the SGLD
# run also uses, so the samplers are compared through one definition.
best_arm = run_MAMBA(key, build_kernel, err_fn, R, params_IC,
                     grid_params=grid_params, get_fb_grads=get_fb_grads)

# Report the winning configuration, its metric and the shape of its samples.
print(best_arm.hyperparameters, best_arm.metric, best_arm.samples.shape)


HBox(children=(HTML(value='Iteration 1/3, 48 arms, time budget = 0.77 sec'), FloatProgress(value=0.0, max=48.0…

Hyperparams: {'batch_size': 80000, 'dt': 0.01}. metric: inf
Hyperparams: {'batch_size': 80000, 'dt': 0.0031622776}. metric: inf
Hyperparams: {'batch_size': 80000, 'dt': 0.001}. metric: inf
Hyperparams: {'batch_size': 80000, 'dt': 0.00031622776}. metric: 2889
Hyperparams: {'batch_size': 80000, 'dt': 1e-04}. metric: 2033
Hyperparams: {'batch_size': 80000, 'dt': 3.1622778e-05}. metric: 15554
Hyperparams: {'batch_size': 80000, 'dt': 1e-05}. metric: 1169
Hyperparams: {'batch_size': 80000, 'dt': 3.1622776e-06}. metric: 1001
Hyperparams: {'batch_size': 80000, 'dt': 1e-06}. metric: 935
Hyperparams: {'batch_size': 80000, 'dt': 3.1622776e-07}. metric: 780
Hyperparams: {'batch_size': 80000, 'dt': 1e-07}. metric: 774
Hyperparams: {'batch_size': 80000, 'dt': 3.1622776e-08}. metric: 710
Hyperparams: {'batch_size': 8000, 'dt': 0.01}. metric: inf
Hyperparams: {'batch_size': 8000, 'dt': 0.0031622776}. metric: inf
Hyperparams: {'batch_size': 8000, 'dt': 0.001}. metric: inf
Hyperparams: {'batch_size': 80

HBox(children=(HTML(value='Iteration 2/3, 16 arms, time budget = 2.31 sec'), FloatProgress(value=0.0, max=16.0…

Hyperparams: {'batch_size': 800, 'dt': 3.1622776e-06}. metric: 619
Hyperparams: {'batch_size': 80, 'dt': 1e-06}. metric: 693
Hyperparams: {'batch_size': 800, 'dt': 1e-05}. metric: 471
Hyperparams: {'batch_size': 80, 'dt': 3.1622776e-06}. metric: 634
Hyperparams: {'batch_size': 800, 'dt': 1e-06}. metric: 709
Hyperparams: {'batch_size': 8000, 'dt': 3.1622776e-06}. metric: 669
Hyperparams: {'batch_size': 8000, 'dt': 1e-06}. metric: 720
Hyperparams: {'batch_size': 80, 'dt': 1e-05}. metric: 548
Hyperparams: {'batch_size': 80, 'dt': 3.1622776e-07}. metric: 725
Hyperparams: {'batch_size': 8000, 'dt': 1e-05}. metric: 564
Hyperparams: {'batch_size': 800, 'dt': 3.1622776e-07}. metric: 720
Hyperparams: {'batch_size': 80000, 'dt': 3.1622776e-08}. metric: 1201
Hyperparams: {'batch_size': 8000, 'dt': 3.1622776e-07}. metric: 746
Hyperparams: {'batch_size': 80, 'dt': 1e-07}. metric: 746
Hyperparams: {'batch_size': 80000, 'dt': 1e-07}. metric: 1043
Hyperparams: {'batch_size': 80000, 'dt': 3.1622776e-07

HBox(children=(HTML(value='Iteration 3/3, 5 arms, time budget = 7.38 sec'), FloatProgress(value=0.0, max=5.0),…

Hyperparams: {'batch_size': 800, 'dt': 1e-05}. metric: 291
Hyperparams: {'batch_size': 80, 'dt': 1e-05}. metric: 344
Hyperparams: {'batch_size': 8000, 'dt': 1e-05}. metric: 361
Hyperparams: {'batch_size': 800, 'dt': 3.1622776e-06}. metric: 443
Hyperparams: {'batch_size': 80, 'dt': 3.1622776e-06}. metric: 451

Number of samples: [600]
Running time: 862.9 sec
{'batch_size': 800, 'dt': 1e-05} 290.94128 (600, 52580)


In [21]:
def build_kernel(dt, L, batch_size):
    """Forward `dt`, `L` and `batch_size` to `build_sghmc_kernel`, fixing the model and data."""
    return build_sghmc_kernel(dt, L, loglikelihood, logprior, data, batch_size)


# Same step-size and batch-size grid as the SGLD runs, plus the extra "L"
# hyperparameter that the SGHMC builder takes.
grid_params = {'log_dt': -jnp.arange(2., 8., 0.5),
               'batch_size': [int(10**(-elem)*R_train.shape[0]) for elem in range(0,4)],
               "L": [5, 10]
              }

# Score with `err_fn`, the metric wrapper from the setup cell that the SGLD
# run also uses, so the samplers are compared through one definition.
best_arm = run_MAMBA(key, build_kernel, err_fn, R, params_IC,
                     grid_params=grid_params, get_fb_grads=get_fb_grads)

# Report the winning configuration, its metric and the shape of its samples.
print(best_arm.hyperparameters, best_arm.metric, best_arm.samples.shape)


HBox(children=(HTML(value='Iteration 1/4, 96 arms, time budget = 0.25 sec'), FloatProgress(value=0.0, max=96.0…

Hyperparams: {'batch_size': 80000, 'L': 5, 'dt': 0.01}. metric: 7061963264
Hyperparams: {'batch_size': 80000, 'L': 5, 'dt': 0.0031622776}. metric: 541318
Hyperparams: {'batch_size': 80000, 'L': 5, 'dt': 0.001}. metric: 793388
Hyperparams: {'batch_size': 80000, 'L': 5, 'dt': 0.00031622776}. metric: 2198029
Hyperparams: {'batch_size': 80000, 'L': 5, 'dt': 1e-04}. metric: 111425
Hyperparams: {'batch_size': 80000, 'L': 5, 'dt': 3.1622778e-05}. metric: 59012
Hyperparams: {'batch_size': 80000, 'L': 5, 'dt': 1e-05}. metric: 6364
Hyperparams: {'batch_size': 80000, 'L': 5, 'dt': 3.1622776e-06}. metric: 1406
Hyperparams: {'batch_size': 80000, 'L': 5, 'dt': 1e-06}. metric: 1567
Hyperparams: {'batch_size': 80000, 'L': 5, 'dt': 3.1622776e-07}. metric: 1335
Hyperparams: {'batch_size': 80000, 'L': 5, 'dt': 1e-07}. metric: 867
Hyperparams: {'batch_size': 80000, 'L': 5, 'dt': 3.1622776e-08}. metric: 649
Hyperparams: {'batch_size': 80000, 'L': 10, 'dt': 0.01}. metric: inf
Hyperparams: {'batch_size': 800

HBox(children=(HTML(value='Iteration 2/4, 32 arms, time budget = 0.75 sec'), FloatProgress(value=0.0, max=32.0…

Hyperparams: {'batch_size': 80000, 'L': 5, 'dt': 3.1622776e-08}. metric: 1238
Hyperparams: {'batch_size': 80000, 'L': 10, 'dt': 3.1622776e-08}. metric: 1276
Hyperparams: {'batch_size': 8000, 'L': 5, 'dt': 3.1622776e-08}. metric: 1219
Hyperparams: {'batch_size': 80000, 'L': 5, 'dt': 1e-07}. metric: 1175
Hyperparams: {'batch_size': 800, 'L': 5, 'dt': 1e-07}. metric: 762
Hyperparams: {'batch_size': 8000, 'L': 5, 'dt': 3.1622776e-07}. metric: 853
Hyperparams: {'batch_size': 80, 'L': 5, 'dt': 1e-07}. metric: 856
Hyperparams: {'batch_size': 80, 'L': 5, 'dt': 3.1622776e-07}. metric: 967
Hyperparams: {'batch_size': 800, 'L': 10, 'dt': 3.1622776e-08}. metric: 862
Hyperparams: {'batch_size': 800, 'L': 5, 'dt': 3.1622776e-07}. metric: 844
Hyperparams: {'batch_size': 800, 'L': 5, 'dt': 3.1622776e-08}. metric: 965
Hyperparams: {'batch_size': 80, 'L': 10, 'dt': 3.1622776e-08}. metric: 915
Hyperparams: {'batch_size': 8000, 'L': 10, 'dt': 1e-07}. metric: 1017
Hyperparams: {'batch_size': 8000, 'L': 5, 

HBox(children=(HTML(value='Iteration 3/4, 10 arms, time budget = 2.40 sec'), FloatProgress(value=0.0, max=10.0…

Hyperparams: {'batch_size': 800, 'L': 5, 'dt': 1e-07}. metric: 800
Hyperparams: {'batch_size': 800, 'L': 5, 'dt': 3.1622776e-07}. metric: 813
Hyperparams: {'batch_size': 8000, 'L': 5, 'dt': 3.1622776e-07}. metric: 824
Hyperparams: {'batch_size': 80, 'L': 5, 'dt': 1e-07}. metric: 845
Hyperparams: {'batch_size': 8000, 'L': 5, 'dt': 1e-07}. metric: 772
Hyperparams: {'batch_size': 800, 'L': 10, 'dt': 3.1622776e-08}. metric: 775
Hyperparams: {'batch_size': 800, 'L': 10, 'dt': 1e-07}. metric: 813
Hyperparams: {'batch_size': 8000, 'L': 10, 'dt': 3.1622776e-08}. metric: 789
Hyperparams: {'batch_size': 80, 'L': 10, 'dt': 3.1622776e-08}. metric: 842
Hyperparams: {'batch_size': 80000, 'L': 5, 'dt': 3.1622776e-07}. metric: 916

Number of samples: [38, 34, 21]


HBox(children=(HTML(value='Iteration 4/4, 3 arms, time budget = 8.00 sec'), FloatProgress(value=0.0, max=3.0),…

Hyperparams: {'batch_size': 8000, 'L': 5, 'dt': 1e-07}. metric: 780
Hyperparams: {'batch_size': 800, 'L': 10, 'dt': 3.1622776e-08}. metric: 770
Hyperparams: {'batch_size': 8000, 'L': 10, 'dt': 3.1622776e-08}. metric: 759

Number of samples: [67]
Running time: 962.4 sec
{'batch_size': 8000, 'L': 10, 'dt': 3.1622776e-08} 759.1282 (67, 52580)


In [22]:
def build_kernel(dt, L, batch_size):
    """Forward `dt`, `L` and `batch_size` to `build_sghmcCV_kernel`.

    The model, the data and the control-variate centering value are fixed.
    """
    return build_sghmcCV_kernel(dt, L, loglikelihood, logprior, data,
                                batch_size, centering_value)


# Same grid as the plain SGHMC run: step size, batch size and "L".
grid_params = {'log_dt': -jnp.arange(2., 8., 0.5),
               'batch_size': [int(10**(-elem)*R_train.shape[0]) for elem in range(0,4)],
               "L": [5, 10]
              }

# Score with `err_fn`, the metric wrapper from the setup cell that the SGLD
# run also uses, so the samplers are compared through one definition.
best_arm = run_MAMBA(key, build_kernel, err_fn, R, params_IC,
                     grid_params=grid_params, get_fb_grads=get_fb_grads)

# Report the winning configuration, its metric and the shape of its samples.
print(best_arm.hyperparameters, best_arm.metric, best_arm.samples.shape)


HBox(children=(HTML(value='Iteration 1/4, 96 arms, time budget = 0.25 sec'), FloatProgress(value=0.0, max=96.0…

Hyperparams: {'batch_size': 80000, 'L': 5, 'dt': 0.01}. metric: 2301742592
Hyperparams: {'batch_size': 80000, 'L': 5, 'dt': 0.0031622776}. metric: 278386
Hyperparams: {'batch_size': 80000, 'L': 5, 'dt': 0.001}. metric: 808773
Hyperparams: {'batch_size': 80000, 'L': 5, 'dt': 0.00031622776}. metric: 2202052
Hyperparams: {'batch_size': 80000, 'L': 5, 'dt': 1e-04}. metric: 111436
Hyperparams: {'batch_size': 80000, 'L': 5, 'dt': 3.1622778e-05}. metric: 59006
Hyperparams: {'batch_size': 80000, 'L': 5, 'dt': 1e-05}. metric: 6363
Hyperparams: {'batch_size': 80000, 'L': 5, 'dt': 3.1622776e-06}. metric: 1406
Hyperparams: {'batch_size': 80000, 'L': 5, 'dt': 1e-06}. metric: 1567
Hyperparams: {'batch_size': 80000, 'L': 5, 'dt': 3.1622776e-07}. metric: 1335
Hyperparams: {'batch_size': 80000, 'L': 5, 'dt': 1e-07}. metric: 867
Hyperparams: {'batch_size': 80000, 'L': 5, 'dt': 3.1622776e-08}. metric: 649
Hyperparams: {'batch_size': 80000, 'L': 10, 'dt': 0.01}. metric: inf
Hyperparams: {'batch_size': 800

HBox(children=(HTML(value='Iteration 2/4, 32 arms, time budget = 0.75 sec'), FloatProgress(value=0.0, max=32.0…

Hyperparams: {'batch_size': 80000, 'L': 5, 'dt': 3.1622776e-08}. metric: 1070
Hyperparams: {'batch_size': 80000, 'L': 10, 'dt': 3.1622776e-08}. metric: 1198
Hyperparams: {'batch_size': 80, 'L': 5, 'dt': 3.1622776e-07}. metric: 832
Hyperparams: {'batch_size': 8000, 'L': 5, 'dt': 3.1622776e-08}. metric: 959
Hyperparams: {'batch_size': 800, 'L': 5, 'dt': 1e-07}. metric: 793
Hyperparams: {'batch_size': 80000, 'L': 5, 'dt': 1e-07}. metric: 1121
Hyperparams: {'batch_size': 800, 'L': 5, 'dt': 3.1622776e-08}. metric: 1065
Hyperparams: {'batch_size': 8000, 'L': 5, 'dt': 3.1622776e-07}. metric: 973
Hyperparams: {'batch_size': 80, 'L': 5, 'dt': 1e-07}. metric: 822
Hyperparams: {'batch_size': 800, 'L': 10, 'dt': 3.1622776e-08}. metric: 929
Hyperparams: {'batch_size': 800, 'L': 5, 'dt': 3.1622776e-07}. metric: 835
Hyperparams: {'batch_size': 80, 'L': 5, 'dt': 1e-06}. metric: 936
Hyperparams: {'batch_size': 80, 'L': 10, 'dt': 3.1622776e-08}. metric: 858
Hyperparams: {'batch_size': 8000, 'L': 10, 'dt

HBox(children=(HTML(value='Iteration 3/4, 10 arms, time budget = 2.40 sec'), FloatProgress(value=0.0, max=10.0…

Hyperparams: {'batch_size': 800, 'L': 5, 'dt': 1e-07}. metric: 792
Hyperparams: {'batch_size': 80, 'L': 5, 'dt': 1e-07}. metric: 814
Hyperparams: {'batch_size': 80, 'L': 5, 'dt': 3.1622776e-07}. metric: 799
Hyperparams: {'batch_size': 800, 'L': 5, 'dt': 3.1622776e-07}. metric: 820
Hyperparams: {'batch_size': 80, 'L': 10, 'dt': 1e-07}. metric: 810
Hyperparams: {'batch_size': 80, 'L': 10, 'dt': 3.1622776e-08}. metric: 771
Hyperparams: {'batch_size': 800, 'L': 10, 'dt': 1e-07}. metric: 810
Hyperparams: {'batch_size': 800, 'L': 10, 'dt': 3.1622776e-08}. metric: 765
Hyperparams: {'batch_size': 8000, 'L': 10, 'dt': 3.1622776e-08}. metric: 819
Hyperparams: {'batch_size': 80, 'L': 5, 'dt': 1e-06}. metric: 769

Number of samples: [30, 56, 32]


HBox(children=(HTML(value='Iteration 4/4, 3 arms, time budget = 8.00 sec'), FloatProgress(value=0.0, max=3.0),…

Hyperparams: {'batch_size': 800, 'L': 10, 'dt': 3.1622776e-08}. metric: 749
Hyperparams: {'batch_size': 80, 'L': 5, 'dt': 1e-06}. metric: 529
Hyperparams: {'batch_size': 80, 'L': 10, 'dt': 3.1622776e-08}. metric: 759

Number of samples: [179]
Running time: 1183.9 sec
{'batch_size': 80, 'L': 5, 'dt': 1e-06} 528.87866 (179, 52580)


In [23]:
def build_kernel(dt, batch_size):
    """Forward `dt` and `batch_size` to `build_sgnht_kernel`, fixing the model and data."""
    return build_sgnht_kernel(dt, loglikelihood, logprior, data, batch_size)


# Two-hyperparameter grid (step size, batch size). It is redefined here so
# this cell does not pick up the "L" entry left behind by the SGHMC cells.
grid_params = {'log_dt': -jnp.arange(2., 8., 0.5),
               'batch_size': [int(10**(-elem)*R_train.shape[0]) for elem in range(0,4)]
              }

# Score with `err_fn`, the metric wrapper from the setup cell that the SGLD
# run also uses, so the samplers are compared through one definition.
best_arm = run_MAMBA(key, build_kernel, err_fn, R, params_IC,
                     grid_params=grid_params, get_fb_grads=get_fb_grads)

# Report the winning configuration, its metric and the shape of its samples.
print(best_arm.hyperparameters, best_arm.metric, best_arm.samples.shape)


HBox(children=(HTML(value='Iteration 1/3, 48 arms, time budget = 0.77 sec'), FloatProgress(value=0.0, max=48.0…

Hyperparams: {'batch_size': 80000, 'dt': 0.01}. metric: inf
Hyperparams: {'batch_size': 80000, 'dt': 0.0031622776}. metric: inf
Hyperparams: {'batch_size': 80000, 'dt': 0.001}. metric: inf
Hyperparams: {'batch_size': 80000, 'dt': 0.00031622776}. metric: inf
Hyperparams: {'batch_size': 80000, 'dt': 1e-04}. metric: inf
Hyperparams: {'batch_size': 80000, 'dt': 3.1622778e-05}. metric: 5502
Hyperparams: {'batch_size': 80000, 'dt': 1e-05}. metric: 841
Hyperparams: {'batch_size': 80000, 'dt': 3.1622776e-06}. metric: 985
Hyperparams: {'batch_size': 80000, 'dt': 1e-06}. metric: 794
Hyperparams: {'batch_size': 80000, 'dt': 3.1622776e-07}. metric: 715
Hyperparams: {'batch_size': 80000, 'dt': 1e-07}. metric: 722
Hyperparams: {'batch_size': 80000, 'dt': 3.1622776e-08}. metric: 926
Hyperparams: {'batch_size': 8000, 'dt': 0.01}. metric: inf
Hyperparams: {'batch_size': 8000, 'dt': 0.0031622776}. metric: inf
Hyperparams: {'batch_size': 8000, 'dt': 0.001}. metric: inf
Hyperparams: {'batch_size': 8000, '

HBox(children=(HTML(value='Iteration 2/3, 16 arms, time budget = 2.31 sec'), FloatProgress(value=0.0, max=16.0…

Hyperparams: {'batch_size': 8000, 'dt': 3.1622776e-06}. metric: 337
Hyperparams: {'batch_size': 8000, 'dt': 1e-06}. metric: 411
Hyperparams: {'batch_size': 8000, 'dt': 1e-05}. metric: 340
Hyperparams: {'batch_size': 80000, 'dt': 3.1622776e-07}. metric: 875
Hyperparams: {'batch_size': 80000, 'dt': 1e-07}. metric: 848
Hyperparams: {'batch_size': 8000, 'dt': 3.1622776e-07}. metric: 523
Hyperparams: {'batch_size': 8000, 'dt': 3.1622776e-08}. metric: 804
Hyperparams: {'batch_size': 800, 'dt': 3.1622776e-08}. metric: 773
Hyperparams: {'batch_size': 80000, 'dt': 1e-06}. metric: 746
Hyperparams: {'batch_size': 8000, 'dt': 1e-07}. metric: 691
Hyperparams: {'batch_size': 800, 'dt': 1e-07}. metric: 652
Hyperparams: {'batch_size': 8000, 'dt': 3.1622778e-05}. metric: 242
Hyperparams: {'batch_size': 80000, 'dt': 1e-05}. metric: 492
Hyperparams: {'batch_size': 800, 'dt': 3.1622776e-07}. metric: 544
Hyperparams: {'batch_size': 800, 'dt': 1e-06}. metric: 454
Hyperparams: {'batch_size': 800, 'dt': 3.162

HBox(children=(HTML(value='Iteration 3/3, 5 arms, time budget = 7.38 sec'), FloatProgress(value=0.0, max=5.0),…

Hyperparams: {'batch_size': 8000, 'dt': 3.1622778e-05}. metric: 128
Hyperparams: {'batch_size': 800, 'dt': 3.1622776e-06}. metric: 172
Hyperparams: {'batch_size': 8000, 'dt': 3.1622776e-06}. metric: 162
Hyperparams: {'batch_size': 8000, 'dt': 1e-05}. metric: 130
Hyperparams: {'batch_size': 8000, 'dt': 1e-06}. metric: 218

Number of samples: [451]
Running time: 1311.3 sec
{'batch_size': 8000, 'dt': 3.1622778e-05} 128.48126 (451, 52580)


In [24]:
def build_kernel(dt, batch_size):
    """Forward `dt` and `batch_size` to `build_sgnhtCV_kernel`.

    The model, the data and the control-variate centering value are fixed.
    """
    return build_sgnhtCV_kernel(dt, loglikelihood, logprior, data,
                                batch_size, centering_value)


# Two-hyperparameter grid (step size, batch size), defined locally so the
# cell does not depend on which cell last assigned `grid_params`.
grid_params = {'log_dt': -jnp.arange(2., 8., 0.5),
               'batch_size': [int(10**(-elem)*R_train.shape[0]) for elem in range(0,4)]
              }

# Score with `err_fn`, the metric wrapper from the setup cell that the SGLD
# run also uses, so the samplers are compared through one definition.
best_arm = run_MAMBA(key, build_kernel, err_fn, R, params_IC,
                     grid_params=grid_params, get_fb_grads=get_fb_grads)

# Report the winning configuration, its metric and the shape of its samples.
print(best_arm.hyperparameters, best_arm.metric, best_arm.samples.shape)


HBox(children=(HTML(value='Iteration 1/3, 48 arms, time budget = 0.77 sec'), FloatProgress(value=0.0, max=48.0…

Hyperparams: {'batch_size': 80000, 'dt': 0.01}. metric: inf
Hyperparams: {'batch_size': 80000, 'dt': 0.0031622776}. metric: inf
Hyperparams: {'batch_size': 80000, 'dt': 0.001}. metric: inf
Hyperparams: {'batch_size': 80000, 'dt': 0.00031622776}. metric: inf
Hyperparams: {'batch_size': 80000, 'dt': 1e-04}. metric: inf
Hyperparams: {'batch_size': 80000, 'dt': 3.1622778e-05}. metric: 9396
Hyperparams: {'batch_size': 80000, 'dt': 1e-05}. metric: 1056
Hyperparams: {'batch_size': 80000, 'dt': 3.1622776e-06}. metric: 1062
Hyperparams: {'batch_size': 80000, 'dt': 1e-06}. metric: 747
Hyperparams: {'batch_size': 80000, 'dt': 3.1622776e-07}. metric: 677
Hyperparams: {'batch_size': 80000, 'dt': 1e-07}. metric: 834
Hyperparams: {'batch_size': 80000, 'dt': 3.1622776e-08}. metric: 983
Hyperparams: {'batch_size': 8000, 'dt': 0.01}. metric: inf
Hyperparams: {'batch_size': 8000, 'dt': 0.0031622776}. metric: inf
Hyperparams: {'batch_size': 8000, 'dt': 0.001}. metric: inf
Hyperparams: {'batch_size': 8000,

HBox(children=(HTML(value='Iteration 2/3, 16 arms, time budget = 2.31 sec'), FloatProgress(value=0.0, max=16.0…

Hyperparams: {'batch_size': 8000, 'dt': 1e-05}. metric: 348
Hyperparams: {'batch_size': 8000, 'dt': 3.1622776e-06}. metric: 381
Hyperparams: {'batch_size': 800, 'dt': 1e-06}. metric: 441
Hyperparams: {'batch_size': 80000, 'dt': 3.1622776e-07}. metric: 890
Hyperparams: {'batch_size': 8000, 'dt': 1e-06}. metric: 458
Hyperparams: {'batch_size': 800, 'dt': 3.1622776e-07}. metric: 520
Hyperparams: {'batch_size': 8000, 'dt': 3.1622776e-08}. metric: 838
Hyperparams: {'batch_size': 80, 'dt': 3.1622776e-08}. metric: 803
Hyperparams: {'batch_size': 80000, 'dt': 1e-06}. metric: 898
Hyperparams: {'batch_size': 800, 'dt': 3.1622776e-08}. metric: 786
Hyperparams: {'batch_size': 800, 'dt': 3.1622776e-06}. metric: 429
Hyperparams: {'batch_size': 800, 'dt': 1e-07}. metric: 660
Hyperparams: {'batch_size': 8000, 'dt': 3.1622776e-07}. metric: 607
Hyperparams: {'batch_size': 8000, 'dt': 1e-07}. metric: 759
Hyperparams: {'batch_size': 80, 'dt': 1e-07}. metric: 778
Hyperparams: {'batch_size': 80000, 'dt': 1e

HBox(children=(HTML(value='Iteration 3/3, 5 arms, time budget = 7.38 sec'), FloatProgress(value=0.0, max=5.0),…

Hyperparams: {'batch_size': 8000, 'dt': 1e-05}. metric: 158
Hyperparams: {'batch_size': 8000, 'dt': 3.1622776e-06}. metric: 201
Hyperparams: {'batch_size': 800, 'dt': 3.1622776e-06}. metric: 239
Hyperparams: {'batch_size': 800, 'dt': 1e-06}. metric: 250
Hyperparams: {'batch_size': 8000, 'dt': 1e-06}. metric: 248

Number of samples: [321]
Running time: 1449.6 sec
{'batch_size': 8000, 'dt': 1e-05} 157.72705 (321, 52580)
