In [1]:
%matplotlib inline
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import time

from collections import OrderedDict

import jax.numpy as jnp
from jax import value_and_grad, grad, jit, vmap, partial
from jax import random, ops, lax

import context 


from examples.bayesian_NN.NN_data import X_train, X_test, y_train, y_test
from examples.bayesian_NN.NN_model import init_network, predict, accuracy, log_post, grad_log_post, batch_loglik
from examples.bayesian_NN.sampler import sgld_NN, kernel_NN
from examples.bayesian_NN.util import flatten_NN_params, load_NN_MAP, flatten_NN_params_jaxscan
from examples.bayesian_NN.sampler import sgld_NN_time_budget, sgld_NN_time_budget_CV, scan_NN

from examples.bayesian_NN.svrg import sgld_NN_time_budget_svrg

from successive_halving.sh import Base_SGLD_SuccHalv_Time, run_SH_time_budget_keep_3
from successive_halving.sh import create_T_list_sghmc, create_T_list_2params, create_T_list_svrg
from ksd import imq_KSD
from util import wait_until_computed




## SH - NN

In [2]:
time_budget = 3

key = random.PRNGKey(0)
dt = 10**(-4.75)
batch_size = int(X_train.shape[0]*0.1)
params_IC = load_NN_MAP()

samples_NN_default, _ = sgld_NN_time_budget(key, time_budget, grad_log_post, dt, 
                                            batch_size, params_IC, X_train, y_train, save_rate=10)

print(len(samples_NN_default))

14


In [3]:

class NN_succ_halving(Base_SGLD_SuccHalv_Time):
    
    def __init__(self, key, logdt, batch_size_ratio, thin_step):
        self.key = key
        self.logdt = logdt
        self.save_rate = thin_step
        self.batch_size_ratio = batch_size_ratio
        self.x_0 = load_NN_MAP()
        self.last_sample = None
        super(NN_succ_halving, self).__init__(self.key, self.logdt, self.x_0)
        
    def _run_sampler(self, key, time_budget, x_0):
        dt = 10**self.logdt
        batch_size = int(self.batch_size_ratio*X_train.shape[0])
        samples, grads = sgld_NN_time_budget(key, time_budget, grad_log_post, 
                        dt, batch_size, x_0, X_train, y_train, save_rate=self.save_rate)
        if samples == []:
            return None, None
        else:
            pass
        try:
            self.last_sample = samples[-1]
        except:
            pass
        
        flat_samples = flatten_NN_params(samples)
        full_batch_grads = [grad_log_post(sam, X_train, y_train) for sam in samples]
        flat_grads = flatten_NN_params(full_batch_grads)
        return flat_samples, flat_grads
    
    def get_ksd(self):
        try:
            if any(np.isnan(self.samples[-1])):
                return 1e10
            else: pass
            self.KSD = imq_KSD(self.samples, self.grads)
            return self.KSD
        except:
            return 1e10
    
    def get_last_sample(self):
        return self.last_sample
    


### test SH class

In [8]:
key = random.PRNGKey(0)
my_sgld = NN_succ_halving(random.PRNGKey(0), logdt=-5., batch_size_ratio=0.1, thin_step=20)


my_sgld.run(5)
print(my_sgld.len_samples)
print(my_sgld.get_ksd())

11
480.56396


In [9]:
key = random.PRNGKey(1)
my_sgld = NN_succ_halving(random.PRNGKey(0), logdt=-5., batch_size_ratio=0.1, thin_step=20)


my_sgld.run(1)
print(my_sgld.len_samples)
print(my_sgld.get_ksd())

2
1261.2231


In [10]:
my_sgld.run(4)
print(my_sgld.len_samples)
print(my_sgld.get_ksd())

13
422.81372



## SH-KSD for NN

Setup:     
- r = 2
- n = 30
- eta = 3
- key = random.PRNGKey(10)
- thin by 20

Optimal:

In [13]:

R = 8 # total time budget for longest runs
n = 90 # number of step sizes to try at first
r = R/40
thin_step = 10
eta = 3 # cut trials by 3 every time

print(f"r={r:.3f}")
key = random.PRNGKey(2)

sampler_dict_sgld = run_SH_time_budget_keep_3(key, r, n ,eta, thin_step, 
                                  NN_succ_halving, create_T_list_2params, X_train.shape[0])



r=0.200
Number of iterations: 4
Total budget: around 72 sec
=====

Iteration 0: 90 configurations with 0.200 seconds each
Combinations: [(-5.6809406, 0.035111915), (-4.8188696, 0.24770762), (-3.1054559, 0.046415884), (-2.8742738, 0.869749), (-6.927352, 1.0), (-2.0856218, 0.017475285), (-2.124608, 0.07054803), (-4.2174926, 0.018738173), (-2.953824, 0.010722672), (-6.1667485, 0.75646335), (-5.401692, 0.043287612), (-4.943814, 0.4328761), (-4.1000185, 0.16297507), (-3.4697778, 0.010722672), (-6.300881, 0.46415886), (-6.246524, 0.021544348), (-4.6608067, 0.061359067), (-6.8173018, 0.1), (-2.2085094, 0.001629751), (-5.4803066, 0.004977024), (-3.4399357, 0.0015199113), (-3.0746777, 0.0021544343), (-4.885751, 0.0012328465), (-3.1537957, 0.011497568), (-5.071318, 0.65793324), (-5.0046825, 1.0), (-6.1107597, 0.0075646355), (-6.770213, 0.1), (-4.605062, 0.40370172), (-5.664422, 0.13219412), (-2.3886995, 1.0), (-6.7092285, 0.04977024), (-4.154315, 0.014174741), (-2.297781, 0.0132194115), (-2.1183

In [14]:
# key = 2
# running time: 13min 30sec
for k,v in sampler_dict_sgld.items():
    print(k, v.len_samples*thin_step, v.KSD)

(-5.4803066, 0.004977024) 3700 233.77672
(-5.2038264, 0.035111915) 970 283.2232
(-5.3254724, 0.0021544343) 3980 292.94736


In [17]:
# key = 1
# running time:  12min
for k,v in sampler_dict_sgld.items():
    print(k, v.len_samples*thin_step, v.KSD)

(-5.3111, 0.003274549) 4290 244.93042
(-4.93334, 0.023101298) 1410 249.42813
(-5.5381775, 0.0020092328) 5230 261.9676


In [19]:
# key = 0
# running time: 16min  
for k,v in sampler_dict_sgld.items():
    print(k, v.len_samples*thin_step, v.KSD)

(-5.187175, 0.010722672) 3780 185.63173
(-5.480715, 0.0035111913) 5890 201.25934
(-5.6204176, 0.0026560882) 6460 220.57657


# SH for svrg

In [2]:
class NN_succ_halving_SVRG(Base_SGLD_SuccHalv_Time):
    
    def __init__(self, key, logdt, batch_size_ratio, thin_step):
        self.key = key
        self.logdt = logdt
        self.save_rate = thin_step
        self.batch_size_ratio = batch_size_ratio
        self.x_0 = load_NN_MAP()
        self.update_rate = 1000
        self.last_sample = None
        super(NN_succ_halving_SVRG, self).__init__(self.key, self.logdt, self.x_0)
        
    def _run_sampler(self, key, time_budget, x_0):
        dt = 10**self.logdt
        batch_size = int(self.batch_size_ratio*X_train.shape[0])
        samples = sgld_NN_time_budget_svrg(key, time_budget, dt, batch_size, 
                                           x_0, self.update_rate, save_rate=self.save_rate)
        if samples == []:
            return None, None
        else:
            pass
        try:
            self.last_sample = samples[-1]
        except:
            pass
        
        flat_samples = flatten_NN_params(samples)
        full_batch_grads = [grad_log_post(sam, X_train, y_train) for sam in samples]
        flat_grads = flatten_NN_params(full_batch_grads)
        return flat_samples, flat_grads
    
    def get_ksd(self):
        try:
            if any(np.isnan(self.samples[-1])):
                return 1e10
            else: pass
            self.KSD = imq_KSD(self.samples, self.grads)
            return self.KSD
        except:
            return 1e10
    
    def get_last_sample(self):
        return self.last_sample
    


### test SH class

In [3]:
key = random.PRNGKey(0)
my_sgld = NN_succ_halving_SVRG(random.PRNGKey(0), logdt=-5., batch_size_ratio=0.1, thin_step=20)


my_sgld.run(5)
print(my_sgld.len_samples)
print(my_sgld.get_ksd())

4
691.9497


In [4]:
key = random.PRNGKey(0)
my_sgld = NN_succ_halving_SVRG(random.PRNGKey(0), logdt=-5., batch_size_ratio=0.1, thin_step=20)

my_sgld.run(2)
print(my_sgld.len_samples)
print(my_sgld.get_ksd())

2
880.5888


In [5]:
my_sgld.run(3)
print(my_sgld.len_samples)
print(my_sgld.get_ksd())

4
635.0441


#### run SH

In [6]:

R = 12 # total time budget for longest runs
n = 90 # number of step sizes to try at first
r = R/40
thin_step = 10
eta = 3 # cut trials by 3 every time

print(f"r={r:.3f}")
key = random.PRNGKey(0)

sampler_dict_svrg = run_SH_time_budget_keep_3(key, r, n ,eta, thin_step, 
                                  NN_succ_halving_SVRG, create_T_list_2params, X_train.shape[0])



r=0.300
Number of iterations: 4
Total budget: around 108 sec
=====

Iteration 0: 90 configurations with 0.300 seconds each
Combinations: [(-2.9291453, 0.93260336), (-6.4782596, 0.00811131), (-5.5711884, 0.024770763), (-4.7319803, 0.003764935), (-6.044933, 0.0053366995), (-4.836668, 0.015199109), (-3.7557585, 0.017475285), (-3.0864086, 0.869749), (-2.750773, 0.49770236), (-6.2133875, 0.0061359066), (-6.2843676, 0.020092327), (-5.797586, 0.0053366995), (-2.1424685, 0.4328761), (-3.8307626, 0.017475285), (-2.713562, 0.0013219408), (-3.0279179, 0.010722672), (-4.327333, 0.0010722672), (-5.3738036, 0.037649363), (-5.0438943, 0.16297507), (-5.0316076, 0.0010722672), (-2.7530785, 0.14174742), (-2.9914374, 0.040370174), (-3.1181636, 0.0028480361), (-5.4463973, 0.5722368), (-5.470988, 0.020092327), (-2.6559381, 0.0035111913), (-3.0144324, 0.13219412), (-6.1687517, 0.16297507), (-6.3341594, 0.057223674), (-4.1798353, 0.010722672), (-6.433, 0.869749), (-3.1092925, 0.0012328465), (-5.62562, 0.1149

In [7]:
# key = 0
# running time: 10min
for k,v in sampler_dict_svrg.items():
    print(k, v.len_samples*thin_step, v.KSD)

(-5.187175, 0.010722672) 1440 178.37445
(-4.836668, 0.015199109) 1360 199.92183
(-4.7952933, 0.0132194115) 1300 204.08676


In [9]:
# key = 1
# running time: 8min
for k,v in sampler_dict_svrg.items():
    print(k, v.len_samples*thin_step, v.KSD)

(-5.3111, 0.003274549) 2120 182.78842
(-4.93334, 0.023101298) 920 190.37178
(-4.8867292, 0.01629751) 1200 190.56728


In [11]:
# key = 2
# running time: 
for k,v in sampler_dict_svrg.items():
    print(k, v.len_samples*thin_step, v.KSD)

(-4.951375, 0.0070548006) 1660 198.46674
(-5.3254724, 0.0021544343) 2150 205.87973
(-4.700827, 0.028480362) 860 245.19313


## SH for sg-hmc

In [7]:
from examples.bayesian_NN.sghmc import sghmc_NN_time_budget

class NN_succ_halving_HMC_L(Base_SGLD_SuccHalv_Time):
    
    def __init__(self, key, logdt, batch_size_ratio, L, thin_step):
        self.key = key
        self.logdt = logdt
        self.batch_size_ratio = batch_size_ratio
        self.x_0 = load_NN_MAP()
        self.L = L
        self.thin_step = thin_step
        self.last_sample = None
        super(NN_succ_halving_HMC_L, self).__init__(self.key, self.logdt, self.x_0)
        
    def _run_sampler(self, key, time_budget, x_0):
        dt = 10**self.logdt
        batch_size = int(self.batch_size_ratio*X_train.shape[0])
        samples = sghmc_NN_time_budget(key, time_budget, dt, batch_size, x_0, self.L, save_rate=self.thin_step)
        if samples == []:
            return None, None
        else:
            pass
        try:
            self.last_sample = samples[-1]
        except:
            pass    
        flat_samples = flatten_NN_params(samples)
        full_batch_grads = [grad_log_post(sam, X_train, y_train) for sam in samples]
        flat_grads = flatten_NN_params(full_batch_grads)
        
        return flat_samples, flat_grads
    
    def get_ksd(self):
        try:
            if any(np.isnan(self.samples[-1])):
                return 1e10
            else: pass
            self.KSD = imq_KSD(self.samples, self.grads)
            return self.KSD
        except:
            return 1e10
    
    def get_last_sample(self):
        return self.last_sample
    

    

### test SH class

In [10]:

key = random.PRNGKey(0)
nn_my_hmc = NN_succ_halving_HMC_L(key, -7., 0.1, 10, 10)

nn_my_hmc.run(5)
print(nn_my_hmc.get_ksd())
print(nn_my_hmc.len_samples)


1371.9117
4


In [11]:

key = random.PRNGKey(0)
nn_my_hmc = NN_succ_halving_HMC_L(key, -7., 0.1, 10, 10)

nn_my_hmc.run(1)
print(nn_my_hmc.get_ksd())
print(nn_my_hmc.len_samples)

nn_my_hmc.run(2)
print(nn_my_hmc.get_ksd())
print(nn_my_hmc.len_samples)


nn_my_hmc.run(2)
print(nn_my_hmc.get_ksd())
print(nn_my_hmc.len_samples)


4496.396
1
2598.093
3
2073.715
4


In [7]:
0.75 + 0.75*3 + 0.75*9 + 0.75*27

30.0

In [2]:
(300/4)/96

0.78125

In [16]:
R = 30
thin_step = 5
r = R/40 # number of samples per run
n = 120 # number of step sizes to try at first
eta = 3 # cut trials by 3 every time

key = random.PRNGKey(0)

sampler_dict_hmc = run_SH_time_budget_keep_3(key, r, n, eta, thin_step, 
                               NN_succ_halving_HMC_L, create_T_list_sghmc, X_train.shape[0])



Number of iterations: 4
Total budget: around 360 sec
=====

Iteration 0: 120 configurations with 0.750 seconds each
Combinations: [(-5.6016817, 0.0035111913, 14), (-2.9268637, 0.30538556, 16), (-4.8808064, 0.0013219408, 16), (-6.961109, 0.5722368, 9), (-2.39332, 0.0061359066, 10), (-6.3222876, 0.035111915, 4), (-5.015533, 0.93260336, 4), (-6.9114394, 0.6135907, 26), (-6.307716, 0.0040370165, 15), (-5.0523543, 0.0011497568, 15), (-6.4231853, 0.001, 7), (-3.9714003, 0.0010722672, 7), (-4.149378, 0.004328762, 9), (-6.2886944, 0.40370172, 9), (-2.4729414, 0.1, 19), (-5.1166677, 0.06579333, 19), (-6.677946, 0.061359067, 5), (-6.695702, 0.28480357, 8), (-5.365959, 0.037649363, 15), (-4.5360165, 0.053366996, 3), (-4.330592, 0.15199111, 29), (-4.8197184, 0.03274549, 18), (-4.1471167, 0.0070548006, 21), (-5.4690056, 0.869749, 12), (-4.171063, 0.49770236, 15), (-5.787445, 0.008697491, 29), (-4.201998, 0.0040370165, 10), (-5.719651, 0.004328762, 7), (-4.001929, 0.0020092328, 27), (-3.3212676, 0.1

In [17]:
# key = 0
# thin = 5, R=30
# running time: 41 minutes

for k,v in sampler_dict_hmc.items():
    print(k, v.len_samples*thin_step, v.KSD)

(-5.590087, 0.01629751, 3) 1335 211.60286
(-6.845521, 0.0061359066, 9) 785 215.57306
(-5.498445, 0.053366996, 3) 600 226.20459


In [19]:
# key = 1
# thin = 5, R=30
# running time: 31 minutes

for k,v in sampler_dict_hmc.items():
    print(k, v.len_samples*thin_step, v.KSD)

(-6.3645906, 0.012328468, 7) 825 200.55487
(-6.9202814, 0.006579331, 12) 645 206.98772
(-6.445711, 0.030538555, 8) 420 213.53226


In [21]:
# key = 2
# thin = 5, R=30
# running time: 40 minutes

for k,v in sampler_dict_hmc.items():
    print(k, v.len_samples*thin_step, v.KSD)

(-5.881216, 0.006579331, 3) 4215 147.20659
(-6.990305, 0.004328762, 17) 1110 193.81314
(-6.752221, 0.021544348, 6) 995 222.56189
