In [1]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np

import jax.numpy as jnp
from jax import jit, random, partial

from sgmcmcjax.kernels import build_sgld_kernel, build_sghmc_kernel, build_sgnht_kernel
from sgmcmcjax.kernels import build_sgldCV_kernel, build_sghmcCV_kernel, build_sgnhtCV_kernel
from sgmcmcjax.optimizer import build_adam_optimizer

import context
from models.logistic_regression.logistic_regression_model import gen_data, loglikelihood, logprior
from tuning.mamba import run_MAMBA
from tuning.ksd import imq_KSD


# Logistic regression - MAMBA

In [2]:


# Simulate the logistic-regression dataset from a fixed seed.
dim, Ndata = 10, 100000
key = random.PRNGKey(42)
theta_true, X, y_data = gen_data(key, dim, Ndata)

# The first 80% of the rows form the training set; the remainder is held out.
num_train = int(Ndata*0.8)
X_train, X_test = X[:num_train], X[num_train:]
y_train, y_test = y_data[:num_train], y_data[num_train:]
print(X_train.shape, X_test.shape)

# (features, labels) pair handed to the optimizer / kernel builders below.
data = (X_train, y_train)




generating data, with N=100000 and dim=10
(80000, 10) (20000, 10)


In [3]:
Niters = 5000
batch_size = int(0.1*X_train.shape[0])  # minibatch is 10% of the training rows

key = random.PRNGKey(0)
# Split once: `subkey` drives the Adam run below, while `key` is kept for the
# sampler cells, so the optimizer and the samplers never share a PRNG key.
key, subkey = random.split(key)

# Value passed to run_MAMBA as `params_IC` in later cells
# (presumably the sampler's initial state -- TODO confirm).
params_IC = theta_true

# get MAP: run Adam from the zero vector; the resulting parameters are reused
# as the centering value of the control-variate (CV) kernel.
# NOTE(review): 1e-2 is presumably the Adam step size -- confirm against
# build_adam_optimizer's signature.
run_adam = build_adam_optimizer(1e-2, loglikelihood, logprior, data, batch_size)
centering_value, logpost_array = run_adam(subkey, Niters, jnp.zeros(dim))

### Run MAMBA: tune the step size of SGLD and SGLD-CV

In [6]:
def build_kernel(dt):
    """Return an SGLD kernel for step size `dt` on the training minibatches."""
    return build_sgld_kernel(dt, loglikelihood, logprior, data, batch_size)


# NOTE(review): T looks like the total time budget handed to MAMBA -- confirm units.
T = 5
# Candidate log step sizes: -2.0, -2.5, ..., -7.5 (12 grid points).
grid_params = dict(log_dt=-jnp.arange(2., 8., 0.5))
best_arm = run_MAMBA(
    key, build_kernel, imq_KSD, T, params_IC,
    grid_params=grid_params, eta=3,
)

print(
    best_arm.hyperparameters,
    best_arm.metric,
    best_arm.samples.shape,
)

HBox(children=(HTML(value='Iteration 1/2, time budget = 0.2 sec'), FloatProgress(value=0.0, max=12.0), HTML(va…




HBox(children=(HTML(value='Iteration 2/2, time budget = 0.6 sec'), FloatProgress(value=0.0, max=4.0), HTML(val…


Running time: 44.1 sec
{'dt': 1e-04} 17.490345 (2195, 10)


In [7]:
def build_kernel(dt):
    """Return an SGLD-CV kernel for step size `dt`, centered at `centering_value`."""
    return build_sgldCV_kernel(
        dt, loglikelihood, logprior, data, batch_size, centering_value
    )


# Same key, time budget T and step-size grid as the plain SGLD run.
best_arm = run_MAMBA(
    key, build_kernel, imq_KSD, T, params_IC,
    grid_params=grid_params, eta=3,
)

print(
    best_arm.hyperparameters,
    best_arm.metric,
    best_arm.samples.shape,
)

HBox(children=(HTML(value='Iteration 1/2, time budget = 0.2 sec'), FloatProgress(value=0.0, max=12.0), HTML(va…




HBox(children=(HTML(value='Iteration 2/2, time budget = 0.6 sec'), FloatProgress(value=0.0, max=4.0), HTML(val…


Running time: 39.8 sec
{'dt': 3.1622778e-05} 17.102705 (1872, 10)
