In [1]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np

import jax.numpy as jnp
from jax import jit, random

from sgmcmcjax.kernels import build_sgld_kernel, build_sghmc_kernel, build_sgnht_kernel
from sgmcmcjax.kernels import build_sgldCV_kernel, build_sghmcCV_kernel, build_sgnhtCV_kernel
from sgmcmcjax.optimizer import build_adam_optimizer

import context

from models.bayesian_NN.NN_model import logprior, loglikelihood, init_network, accuracy_BNN
from models.bayesian_NN.NN_data import X_train, y_train, X_test, y_test
from models.bayesian_NN.util import load_NN_MAP, add_noise_NN_params

from tuning.mamba import run_MAMBA
from tuning.ksd import imq_KSD




# NN - MAMBA

In [2]:


# Training inputs and targets, bundled for the kernel builders used in later cells.
data = (X_train, y_train)


def error_fn(sam):
    """Return the negated `accuracy_BNN` score of every 20th entry of `sam` on the test split."""
    thinned = sam[::20]
    return -accuracy_BNN(thinned, X_test, y_test)


def err_fn(x, y):
    """Evaluate `imq_KSD` on every second entry of both arguments."""
    return imq_KSD(x[::2], y[::2])


Niters = 1000
batch_size = int(0.1*X_train.shape[0])

# Seed the PRNG, then split off a subkey so the noise draw below does not
# consume the key that the tuning cells pass to run_MAMBA.
key = random.PRNGKey(0)
key, subkey = random.split(key)

# Starting point: the stored MAP parameters passed through add_noise_NN_params.
params_IC = add_noise_NN_params(subkey, load_NN_MAP(), 1.)

# Un-noised MAP parameters, used by the control-variate (CV) kernel cells.
centering_value = load_NN_MAP()

### run MAMBA

In [3]:
def build_kernel(dt, batch_size):
    """Forward `dt` and `batch_size` to `build_sgld_kernel` with the notebook's model and data."""
    return build_sgld_kernel(dt, loglikelihood, logprior, data, batch_size)


# Value passed as the `T` argument of run_MAMBA.
T = 0.1

# Search grid: 12 values of log_dt (-2.0, -2.5, ..., -7.5) crossed with
# 4 batch sizes (the full training set down to 0.1% of it) -> 48 combinations.
grid_params = {'log_dt': -jnp.arange(2., 8., 0.5),
               'batch_size': [int(10**(-elem)*X_train.shape[0]) for elem in range(0,4)]
              }

# The grid above is passed explicitly. It used to be reset to None right after
# being built, which left the definition as dead code and made this cell
# inconsistent with the other sampler cells, which all pass their grid.
best_arm = run_MAMBA(key, build_kernel, err_fn, T, params_IC, grid_params=grid_params, eta=3)

print(best_arm.hyperparameters, best_arm.metric, best_arm.samples.shape)

HBox(children=(HTML(value='Iteration 1/3, 48 arms, time budget = 0.00 sec'), FloatProgress(value=0.0, max=48.0…




KeyboardInterrupt: 

In [None]:
def build_kernel(dt, batch_size):
    """Forward `dt` and `batch_size` to `build_sgldCV_kernel`, centred on `centering_value`."""
    return build_sgldCV_kernel(
        dt, loglikelihood, logprior, data, batch_size, centering_value
    )


# `T` and `grid_params` are not assigned in this cell; they keep whatever
# values an earlier cell left in the namespace.
best_arm = run_MAMBA(
    key, build_kernel, imq_KSD, T, params_IC,
    grid_params=grid_params, eta=3,
)

print(best_arm.hyperparameters, best_arm.metric, best_arm.samples.shape)

In [None]:
def build_kernel(dt, L, batch_size):
    """Forward `dt`, `L` and `batch_size` to `build_sghmc_kernel` with the notebook's model and data."""
    return build_sghmc_kernel(dt, L, loglikelihood, logprior, data, batch_size)


# Same log_dt / batch_size grid as the other sampler cells, plus two values of L.
grid_params = dict(
    log_dt=-jnp.arange(2., 8., 0.5),
    batch_size=[int(10**(-k)*X_train.shape[0]) for k in range(4)],
    L=[5, 10],
)

# `T` is not assigned in this cell; it keeps the value set by an earlier cell.
best_arm = run_MAMBA(
    key, build_kernel, imq_KSD, T, params_IC,
    grid_params=grid_params, eta=3,
)

print(best_arm.hyperparameters, best_arm.metric, best_arm.samples.shape)

In [None]:
def build_kernel(dt, L, batch_size):
    """Forward `dt`, `L` and `batch_size` to `build_sghmcCV_kernel`, centred on `centering_value`."""
    return build_sghmcCV_kernel(
        dt, L, loglikelihood, logprior, data, batch_size, centering_value
    )


# Value passed as the `T` argument of run_MAMBA for this sampler.
T = 10

# log_dt / batch_size grid plus two values of L.
grid_params = dict(
    log_dt=-jnp.arange(2., 8., 0.5),
    batch_size=[int(10**(-k)*X_train.shape[0]) for k in range(4)],
    L=[5, 10],
)

best_arm = run_MAMBA(
    key, build_kernel, imq_KSD, T, params_IC,
    grid_params=grid_params, eta=3,
)

print(best_arm.hyperparameters, best_arm.metric, best_arm.samples.shape)

In [None]:
def build_kernel(dt, batch_size):
    """Forward `dt` and `batch_size` to `build_sgnht_kernel` with the notebook's model and data."""
    return build_sgnht_kernel(dt, loglikelihood, logprior, data, batch_size)


# Value passed as the `T` argument of run_MAMBA for this sampler.
T = 10

# 12 log_dt values crossed with 4 batch sizes.
grid_params = dict(
    log_dt=-jnp.arange(2., 8., 0.5),
    batch_size=[int(10**(-k)*X_train.shape[0]) for k in range(4)],
)

best_arm = run_MAMBA(
    key, build_kernel, imq_KSD, T, params_IC,
    grid_params=grid_params, eta=3,
)

print(best_arm.hyperparameters, best_arm.metric, best_arm.samples.shape)

In [None]:
def build_kernel(dt, batch_size):
    """Forward `dt` and `batch_size` to `build_sgnhtCV_kernel`, centred on `centering_value`."""
    return build_sgnhtCV_kernel(
        dt, loglikelihood, logprior, data, batch_size, centering_value
    )


# Value passed as the `T` argument of run_MAMBA for this sampler.
T = 10

# 12 log_dt values crossed with 4 batch sizes.
grid_params = dict(
    log_dt=-jnp.arange(2., 8., 0.5),
    batch_size=[int(10**(-k)*X_train.shape[0]) for k in range(4)],
)

best_arm = run_MAMBA(
    key, build_kernel, imq_KSD, T, params_IC,
    grid_params=grid_params, eta=3,
)

print(best_arm.hyperparameters, best_arm.metric, best_arm.samples.shape)