In [1]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np

import jax.numpy as jnp
from jax import jit, random

from sgmcmcjax.kernels import build_sgld_kernel, build_sghmc_kernel, build_sgnht_kernel
from sgmcmcjax.kernels import build_sgldCV_kernel, build_sghmcCV_kernel, build_sgnhtCV_kernel
from sgmcmcjax.optimizer import build_adam_optimizer

import context
from models.logistic_regression.logistic_regression_model import gen_data, loglikelihood, logprior
from tuning.mamba import run_MAMBA
from tuning.ksd import imq_KSD


# Logistic regression - MAMBA

In [2]:


# Problem size and the PRNG key handed to the data generator.
dim, Ndata = 10, 100000
key = random.PRNGKey(42)

# gen_data returns three values; by their names these are presumably the
# generating parameter vector, the covariates and the labels (verify in the model module).
theta_true, X, y_data = gen_data(key, dim, Ndata)

# Train/test split: the first 80% of the rows are for training, the rest for testing.
num_train = int(Ndata*0.8)
X_train, X_test = X[:num_train], X[num_train:]
y_train, y_test = y_data[:num_train], y_data[num_train:]

# The optimiser and the kernel builders in later cells receive the training set as this tuple.
data = (X_train, y_train)
print(X_train.shape, X_test.shape)




generating data, with N=100000 and dim=10
(80000, 10) (20000, 10)


In [3]:
# Starting point handed to run_MAMBA in the cells below.
params_IC = theta_true

# Minibatch size used by the optimiser: 10% of the training rows.
batch_size = int(0.1*len(X_train))

# New key; the same key is passed to Adam here and to run_MAMBA in later cells.
key = random.PRNGKey(0)

# Get the MAP with Adam, starting from the zero vector.
# 1e-2 and 5000 are presumably the learning rate and the number of iterations
# (TODO confirm against build_adam_optimizer).
# centering_value is later passed to the control-variate (CV) kernel builders.
run_adam = build_adam_optimizer(1e-2, loglikelihood, logprior, data, batch_size)
centering_value, logpost_array = run_adam(key, 5000, jnp.zeros(dim))

### Run MAMBA

In [4]:
def build_kernel(dt, batch_size):
    """Build an SGLD kernel on the training data for the given `dt` and `batch_size`."""
    return build_sgld_kernel(dt, loglikelihood, logprior, data, batch_size)


# Budget argument forwarded to run_MAMBA (presumably a time budget - TODO confirm units).
T = 10

# Search grid: 12 values of log_dt (-2.0, -2.5, ..., -7.5) crossed with 4 batch sizes
# (100%, 10%, 1% and 0.1% of the training rows), i.e. 48 arms.
grid_params = dict(
    log_dt=-jnp.arange(2., 8., 0.5),
    batch_size=[int(10**(-k)*X_train.shape[0]) for k in range(4)],
)

# Arms are scored with imq_KSD and all start from params_IC.
best_arm = run_MAMBA(key, build_kernel, imq_KSD, T, params_IC, grid_params=grid_params, eta=3)

print(best_arm.hyperparameters, best_arm.metric, best_arm.samples.shape)

HBox(children=(HTML(value='Iteration 1/3, time budget = 0.07 sec. 48 arms'), FloatProgress(value=0.0, max=48.0…




HBox(children=(HTML(value='Iteration 2/3, time budget = 0.21 sec. 16 arms'), FloatProgress(value=0.0, max=16.0…




HBox(children=(HTML(value='Iteration 3/3, time budget = 0.67 sec. 5 arms'), FloatProgress(value=0.0, max=5.0),…


Running time: 145.9 sec
{'batch_size': 8000, 'dt': 1e-04} 12.981178 (2505, 10)


In [5]:
def build_kernel(dt, batch_size):
    """Build an SGLD-CV kernel on the training data, centred at `centering_value`."""
    return build_sgldCV_kernel(
        dt, loglikelihood, logprior, data, batch_size, centering_value
    )


# Same budget T, grid and starting point as the plain SGLD run; only the kernel builder differs.
best_arm = run_MAMBA(
    key, build_kernel, imq_KSD, T, params_IC, grid_params=grid_params, eta=3
)

print(best_arm.hyperparameters, best_arm.metric, best_arm.samples.shape)

HBox(children=(HTML(value='Iteration 1/3, time budget = 0.07 sec. 48 arms'), FloatProgress(value=0.0, max=48.0…




HBox(children=(HTML(value='Iteration 2/3, time budget = 0.21 sec. 16 arms'), FloatProgress(value=0.0, max=16.0…




HBox(children=(HTML(value='Iteration 3/3, time budget = 0.67 sec. 5 arms'), FloatProgress(value=0.0, max=5.0),…


Running time: 160.2 sec
{'batch_size': 800, 'dt': 1e-04} 4.3017693 (5529, 10)


In [15]:
def build_kernel(dt, L, batch_size):
    """Build an SGHMC kernel on the training data for the given `dt`, `L` and `batch_size`."""
    return build_sghmc_kernel(dt, L, loglikelihood, logprior, data, batch_size)


# Define the budget here, as the other sampler cells do, so this cell does not
# depend on whichever earlier cell last assigned T.
T = 10

# Search grid: 12 values of log_dt x 4 batch sizes x 2 values of L = 96 arms.
grid_params = {'log_dt': -jnp.arange(2., 8., 0.5),
               'batch_size': [int(10**(-elem)*X_train.shape[0]) for elem in range(0,4)],
               "L": [5, 10]
              }
best_arm = run_MAMBA(key, build_kernel, imq_KSD, T, params_IC, grid_params=grid_params, eta=3)

print(best_arm.hyperparameters, best_arm.metric, best_arm.samples.shape)

HBox(children=(HTML(value='Iteration 1/4, 96 arms, time budget = 0.03 sec'), FloatProgress(value=0.0, max=96.0…




HBox(children=(HTML(value='Iteration 2/4, 32 arms, time budget = 0.08 sec'), FloatProgress(value=0.0, max=32.0…




HBox(children=(HTML(value='Iteration 3/4, 10 arms, time budget = 0.25 sec'), FloatProgress(value=0.0, max=10.0…




HBox(children=(HTML(value='Iteration 4/4, 3 arms, time budget = 0.83 sec'), FloatProgress(value=0.0, max=3.0),…


Running time: 256.1 sec
{'batch_size': 800, 'L': 5, 'dt': 3.1622776e-06} 11.845586 (6440, 10)


In [5]:
def build_kernel(dt, L, batch_size):
    """Build an SGHMC-CV kernel on the training data, centred at `centering_value`."""
    return build_sghmcCV_kernel(
        dt, L, loglikelihood, logprior, data, batch_size, centering_value
    )


# Budget argument forwarded to run_MAMBA.
T = 10

# Search grid: 12 values of log_dt x 4 batch sizes x 2 values of L = 96 arms.
grid_params = dict(
    log_dt=-jnp.arange(2., 8., 0.5),
    batch_size=[int(10**(-k)*X_train.shape[0]) for k in range(4)],
    L=[5, 10],
)

best_arm = run_MAMBA(
    key, build_kernel, imq_KSD, T, params_IC, grid_params=grid_params, eta=3
)

print(best_arm.hyperparameters, best_arm.metric, best_arm.samples.shape)

HBox(children=(HTML(value='Iteration 1/4, 96 arms, time budget = 0.03 sec'), FloatProgress(value=0.0, max=96.0…


Running time: 190.6 sec


HBox(children=(HTML(value='Iteration 2/4, 32 arms, time budget = 0.08 sec'), FloatProgress(value=0.0, max=32.0…


Running time: 211.7 sec


HBox(children=(HTML(value='Iteration 3/4, 10 arms, time budget = 0.25 sec'), FloatProgress(value=0.0, max=10.0…


Running time: 239.9 sec


HBox(children=(HTML(value='Iteration 4/4, 3 arms, time budget = 0.83 sec'), FloatProgress(value=0.0, max=3.0),…


Running time: 291.1 sec
Running time: 291.1 sec
{'batch_size': 800, 'L': 10, 'dt': 3.1622776e-06} 7.472683 (3460, 10)


In [5]:
def build_kernel(dt, batch_size):
    """Build an SGNHT kernel on the training data for the given `dt` and `batch_size`."""
    return build_sgnht_kernel(dt, loglikelihood, logprior, data, batch_size)


# Budget argument forwarded to run_MAMBA.
T = 10

# Search grid: 12 values of log_dt crossed with 4 batch sizes, i.e. 48 arms.
grid_params = dict(
    log_dt=-jnp.arange(2., 8., 0.5),
    batch_size=[int(10**(-k)*X_train.shape[0]) for k in range(4)],
)

best_arm = run_MAMBA(
    key, build_kernel, imq_KSD, T, params_IC, grid_params=grid_params, eta=3
)

print(best_arm.hyperparameters, best_arm.metric, best_arm.samples.shape)

HBox(children=(HTML(value='Iteration 1/3, 48 arms, time budget = 0.07 sec'), FloatProgress(value=0.0, max=48.0…




HBox(children=(HTML(value='Iteration 2/3, 16 arms, time budget = 0.21 sec'), FloatProgress(value=0.0, max=16.0…




HBox(children=(HTML(value='Iteration 3/3, 5 arms, time budget = 0.67 sec'), FloatProgress(value=0.0, max=5.0),…


Running time: 175.4 sec
{'batch_size': 80000, 'dt': 0.00031622776} 1.9816983 (968, 10)


In [4]:
def build_kernel(dt, batch_size):
    """Build an SGNHT-CV kernel on the training data, centred at `centering_value`."""
    return build_sgnhtCV_kernel(
        dt, loglikelihood, logprior, data, batch_size, centering_value
    )


# Budget argument forwarded to run_MAMBA.
T = 10

# Search grid: 12 values of log_dt crossed with 4 batch sizes, i.e. 48 arms.
grid_params = dict(
    log_dt=-jnp.arange(2., 8., 0.5),
    batch_size=[int(10**(-k)*X_train.shape[0]) for k in range(4)],
)

best_arm = run_MAMBA(
    key, build_kernel, imq_KSD, T, params_IC, grid_params=grid_params, eta=3
)

print(best_arm.hyperparameters, best_arm.metric, best_arm.samples.shape)

HBox(children=(HTML(value='Iteration 1/3, 48 arms, time budget = 0.07 sec'), FloatProgress(value=0.0, max=48.0…


Running time: 144.1 sec


HBox(children=(HTML(value='Iteration 2/3, 16 arms, time budget = 0.21 sec'), FloatProgress(value=0.0, max=16.0…


Running time: 176.0 sec


HBox(children=(HTML(value='Iteration 3/3, 5 arms, time budget = 0.67 sec'), FloatProgress(value=0.0, max=5.0),…


Running time: 228.9 sec
Running time: 228.9 sec
{'batch_size': 800, 'dt': 3.1622778e-05} 1.6534338 (4743, 10)
