In [10]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np

import jax.numpy as jnp
from jax import jit, random, vmap

from sgmcmcjax.kernels import build_sgld_kernel, build_sghmc_kernel, build_sgnht_kernel
from sgmcmcjax.kernels import build_sgldCV_kernel, build_sghmcCV_kernel, build_sgnhtCV_kernel
from sgmcmcjax.optimizer import build_adam_optimizer
from sgmcmcjax.util import build_grad_log_post

import context
from models.logistic_regression.logistic_regression_model import gen_data, loglikelihood, logprior
from tuning.mamba import run_MAMBA
from tuning.ksd import imq_KSD


# Logistic regression - MAMBA

In [11]:


# Simulate a logistic-regression dataset from a fixed seed.
key = random.PRNGKey(42)
dim = 10
Ndata = 1_000_000  # the author also left a smaller setting, Ndata = 100_000, commented out

theta_true, X, y_data = gen_data(key, dim, Ndata)

# Split by position: the first 80% of the rows are used for training,
# the rows from `num_train` onwards are kept for testing.
num_train = int(Ndata*0.8)
X_train, X_test = X[:num_train], X[num_train:]
y_train, y_test = y_data[:num_train], y_data[num_train:]

# Training pair handed to the optimiser and to the kernel builders in later cells.
data = (X_train, y_train)
print(X_train.shape, X_test.shape)


generating data, with N=1,000,000 and dim=10
(800000, 10) (200000, 10)


In [12]:
# Estimate the MAP with Adam, using minibatches of 1% of the training rows
# and starting from the zero vector.
batch_size = X_train.shape[0] // 100
run_adam = build_adam_optimizer(1e-2, loglikelihood, logprior, data, batch_size)
key = random.PRNGKey(0)
centering_value, logpost_array = run_adam(key, 5000, jnp.zeros(dim))
# The Adam output doubles as the initial state passed to run_MAMBA and as the
# centering value given to the control-variate (CV) kernel builders.
params_IC = centering_value

R = 1  # running time of longest sampler


def error_fn(x, y):
    """Return imq_KSD evaluated on every 5th entry (first axis) of `x` and `y`."""
    return imq_KSD(x[::5], y[::5])


# Hyperparameter grid: log_dt values -1, -1.5, ..., -7.5 and batch sizes equal to
# the training-set size divided by 1, 10, 100, 1000 and 10000.
logdt_range = -jnp.arange(1., 8., 0.5)
batch_size_range = [X_train.shape[0] // 10**power for power in range(5)]

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=5000.0), HTML(value='')))




In [20]:
# Gradient function produced by sgmcmcjax's build_grad_log_post for this model and data.
grad_log_post_fb = build_grad_log_post(loglikelihood, logprior, data)


def grad_lp_LR_fb(theta):
    """Evaluate `grad_log_post_fb` at `theta` on the whole training set."""
    return grad_log_post_fb(theta, X_train, y_train)


# Vectorised over the leading axis of a stack of parameter vectors, then compiled.
batch_grad_lp_LR_fb = jit(vmap(grad_lp_LR_fb))


@jit
def get_fb_grads(samples):
    """Thin `samples` and compute the full-training-set gradient at each kept sample.

    Every 10th entry of `samples` is kept and converted to a single array.

    Returns:
        A pair `(thinned_samples, gradients)`, where `gradients` is the result of
        `batch_grad_lp_LR_fb` applied to the thinned samples.
    """
    thinned = jnp.array(samples[::10])
    return thinned, batch_grad_lp_LR_fb(thinned)

### Run MAMBA for each sampler (SGLD, SGLD-CV, SGHMC, SGHMC-CV, SGNHT, SGNHT-CV)

In [21]:
def build_kernel(dt, batch_size):
    """Build an SGLD kernel for the given `dt` and `batch_size`."""
    return build_sgld_kernel(dt, loglikelihood, logprior, data, batch_size)


# Grid handed to MAMBA: candidate log_dt values and candidate batch sizes.
grid_params = dict(log_dt=logdt_range, batch_size=batch_size_range)

best_arm = run_MAMBA(
    key, build_kernel, error_fn, R, params_IC,
    grid_params=grid_params, get_fb_grads=get_fb_grads,
)

# Report the selected hyperparameters, their metric and the shape of the samples kept.
print(best_arm.hyperparameters, best_arm.metric, best_arm.samples.shape)

HBox(children=(HTML(value='Iteration 1/3, 70 arms, time budget = 0.08 sec'), FloatProgress(value=0.0, max=70.0…

Hyperparams: {'batch_size': 800000, 'dt': 0.1}. metric: 627646
Hyperparams: {'batch_size': 800000, 'dt': 0.031622775}. metric: 620876
Hyperparams: {'batch_size': 800000, 'dt': 0.01}. metric: 599319
Hyperparams: {'batch_size': 800000, 'dt': 0.0031622776}. metric: 487506
Hyperparams: {'batch_size': 800000, 'dt': 0.001}. metric: 230270
Hyperparams: {'batch_size': 800000, 'dt': 0.00031622776}. metric: 69859
Hyperparams: {'batch_size': 800000, 'dt': 1e-04}. metric: 22171
Hyperparams: {'batch_size': 800000, 'dt': 3.1622778e-05}. metric: 4516
Hyperparams: {'batch_size': 800000, 'dt': 1e-05}. metric: 1008
Hyperparams: {'batch_size': 800000, 'dt': 3.1622776e-06}. metric: 2652
Hyperparams: {'batch_size': 800000, 'dt': 1e-06}. metric: 2898
Hyperparams: {'batch_size': 800000, 'dt': 3.1622776e-07}. metric: 3223
Hyperparams: {'batch_size': 800000, 'dt': 1e-07}. metric: 3167
Hyperparams: {'batch_size': 800000, 'dt': 3.1622776e-08}. metric: 3208
Hyperparams: {'batch_size': 80000, 'dt': 0.1}. metric: 6

HBox(children=(HTML(value='Iteration 2/3, 23 arms, time budget = 0.23 sec'), FloatProgress(value=0.0, max=23.0…

Hyperparams: {'batch_size': 800, 'dt': 1e-06}. metric: 521
Hyperparams: {'batch_size': 8000, 'dt': 1e-06}. metric: 357
Hyperparams: {'batch_size': 800000, 'dt': 1e-05}. metric: 673
Hyperparams: {'batch_size': 800, 'dt': 1e-07}. metric: 666
Hyperparams: {'batch_size': 80000, 'dt': 1e-05}. metric: 716
Hyperparams: {'batch_size': 800, 'dt': 3.1622776e-07}. metric: 506
Hyperparams: {'batch_size': 8000, 'dt': 3.1622776e-07}. metric: 401
Hyperparams: {'batch_size': 8000, 'dt': 3.1622776e-06}. metric: 546
Hyperparams: {'batch_size': 8000, 'dt': 1e-05}. metric: 1326
Hyperparams: {'batch_size': 800, 'dt': 3.1622776e-06}. metric: 749
Hyperparams: {'batch_size': 8000, 'dt': 1e-07}. metric: 406
Hyperparams: {'batch_size': 80, 'dt': 1e-07}. metric: 1447
Hyperparams: {'batch_size': 80000, 'dt': 3.1622776e-06}. metric: 425
Hyperparams: {'batch_size': 800, 'dt': 3.1622776e-08}. metric: 696
Hyperparams: {'batch_size': 800000, 'dt': 3.1622776e-06}. metric: 560
Hyperparams: {'batch_size': 80, 'dt': 1e-06

HBox(children=(HTML(value='Iteration 3/3, 7 arms, time budget = 0.77 sec'), FloatProgress(value=0.0, max=7.0),…

Hyperparams: {'batch_size': 8000, 'dt': 1e-06}. metric: 111
Hyperparams: {'batch_size': 8000, 'dt': 3.1622776e-07}. metric: 242
Hyperparams: {'batch_size': 8000, 'dt': 1e-07}. metric: 238
Hyperparams: {'batch_size': 80000, 'dt': 3.1622776e-06}. metric: 286
Hyperparams: {'batch_size': 800, 'dt': 3.1622776e-07}. metric: 345
Hyperparams: {'batch_size': 80000, 'dt': 1e-06}. metric: 228
Hyperparams: {'batch_size': 800, 'dt': 1e-06}. metric: 229

Number of samples: [235, 39]
Running time: 157.7 sec
{'batch_size': 8000, 'dt': 1e-06} 110.581345 (235, 10)


In [None]:
def build_kernel(dt, batch_size):
    """Build an SGLD-CV kernel for the given `dt` and `batch_size`.

    The control variate is centred at `centering_value` (the Adam estimate).
    """
    return build_sgldCV_kernel(dt, loglikelihood, logprior, data,
                               batch_size, centering_value)


# Define the grid in this cell rather than relying on whichever `grid_params`
# a previously executed cell left in the kernel namespace. The SGHMC cells add
# an "L" entry, which has no matching parameter in this kernel builder.
grid_params = {'log_dt': logdt_range,
               'batch_size': batch_size_range
              }
best_arm = run_MAMBA(key, build_kernel, error_fn, R, params_IC,
                     grid_params=grid_params, get_fb_grads=get_fb_grads)

# Report the selected hyperparameters, their metric and the shape of the samples kept.
print(best_arm.hyperparameters, best_arm.metric, best_arm.samples.shape)

In [None]:
def build_kernel(dt, L, batch_size):
    """Build an SGHMC kernel for the given `dt`, `L` and `batch_size`."""
    return build_sghmc_kernel(dt, L, loglikelihood, logprior, data, batch_size)


# Same log_dt and batch-size candidates as before, plus two candidate values for `L`.
grid_params = dict(log_dt=logdt_range, batch_size=batch_size_range, L=[5, 10])

best_arm = run_MAMBA(
    key, build_kernel, error_fn, R, params_IC,
    grid_params=grid_params, get_fb_grads=get_fb_grads,
)

# Report the selected hyperparameters, their metric and the shape of the samples kept.
print(best_arm.hyperparameters, best_arm.metric, best_arm.samples.shape)

In [None]:
def build_kernel(dt, L, batch_size):
    """Build an SGHMC-CV kernel for the given `dt`, `L` and `batch_size`.

    The control variate is centred at `centering_value` (the Adam estimate).
    """
    return build_sghmcCV_kernel(dt, L, loglikelihood, logprior, data,
                                batch_size, centering_value)


# Candidate log_dt values, batch sizes and values of `L` handed to MAMBA.
grid_params = dict(log_dt=logdt_range, batch_size=batch_size_range, L=[5, 10])

best_arm = run_MAMBA(
    key, build_kernel, error_fn, R, params_IC,
    grid_params=grid_params, get_fb_grads=get_fb_grads,
)

# Report the selected hyperparameters, their metric and the shape of the samples kept.
print(best_arm.hyperparameters, best_arm.metric, best_arm.samples.shape)

In [None]:
def build_kernel(dt, batch_size):
    """Build an SGNHT kernel for the given `dt` and `batch_size`."""
    return build_sgnht_kernel(dt, loglikelihood, logprior, data, batch_size)


# Candidate log_dt values and batch sizes handed to MAMBA.
grid_params = dict(log_dt=logdt_range, batch_size=batch_size_range)

best_arm = run_MAMBA(
    key, build_kernel, error_fn, R, params_IC,
    grid_params=grid_params, get_fb_grads=get_fb_grads,
)

# Report the selected hyperparameters, their metric and the shape of the samples kept.
print(best_arm.hyperparameters, best_arm.metric, best_arm.samples.shape)

In [None]:
def build_kernel(dt, batch_size):
    """Build an SGNHT-CV kernel for the given `dt` and `batch_size`.

    The control variate is centred at `centering_value` (the Adam estimate).
    """
    return build_sgnhtCV_kernel(dt, loglikelihood, logprior, data,
                                batch_size, centering_value)


# Candidate log_dt values and batch sizes handed to MAMBA.
grid_params = dict(log_dt=logdt_range, batch_size=batch_size_range)

best_arm = run_MAMBA(
    key, build_kernel, error_fn, R, params_IC,
    grid_params=grid_params, get_fb_grads=get_fb_grads,
)

# Report the selected hyperparameters, their metric and the shape of the samples kept.
print(best_arm.hyperparameters, best_arm.metric, best_arm.samples.shape)