In [1]:
%matplotlib inline
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import time

from collections import OrderedDict

import jax.numpy as jnp
from jax import value_and_grad, grad, jit, vmap, partial
from jax import random, ops, lax

import context 


from examples.bayesian_NN.NN_data import X_train, X_test, y_train, y_test
from examples.bayesian_NN.NN_model import init_network, predict, accuracy, log_post, grad_log_post, batch_loglik
from examples.bayesian_NN.sampler import sgld_NN, kernel_NN
from examples.bayesian_NN.util import flatten_NN_params, load_NN_MAP, flatten_NN_params_jaxscan
from examples.bayesian_NN.sampler import sgld_NN_time_budget, sgld_NN_time_budget_CV, scan_NN

from examples.bayesian_NN.svrg import sgld_NN_time_budget_svrg
from examples.bayesian_NN.sghmc import sghmc_NN_time_budget

from successive_halving.sh import Base_SGLD_SuccHalv_Time, run_SH_time_budget_keep_3
from successive_halving.sh import create_T_list_2params, create_T_list_svrg, create_T_list_sghmc
from ksd import imq_KSD
from util import wait_until_computed


from examples.bayesian_NN.NN_model import get_ECE_MCE



## SH - NN

In [13]:
# Baseline SGLD run with hand-picked settings, used as a reference for the
# successive-halving results further down.
time_budget = 2  # budget handed to sgld_NN_time_budget (presumably seconds)

key = random.PRNGKey(0)
dt = 10 ** (-4.75)
params_IC = load_NN_MAP()  # initial condition for the chain
batch_size = int(0.01 * X_train.shape[0])  # 1% of the training rows per minibatch

samples_NN_default, _ = sgld_NN_time_budget(
    key, time_budget, grad_log_post, dt, batch_size,
    params_IC, X_train, y_train, save_rate=10,
)

# Number of samples collected within the budget.
print(len(samples_NN_default))

866


In [14]:
# Convert the baseline samples with flatten_NN_params; the result is what the
# next cell slices and passes to get_ECE_MCE.
flat_samples_NN = flatten_NN_params(samples_NN_default)

# Evaluation on all samples, currently disabled.
# NOTE(review): the tuple recorded as this cell's output cannot be produced by
# the cell as written (an assignment plus a comment) — it appears to be left
# over from an execution in which the line below was enabled.
# get_ECE_MCE(flat_samples_NN, X_test, y_test, M=10, pbar=False)

(0.011109485626220703, DeviceArray(0.17217025, dtype=float32))

In [16]:
# Same evaluation on every second flattened sample (the [::2] slice), with the
# progress bar switched on; the returned tuple is displayed as the cell output.
get_ECE_MCE(flat_samples_NN[::2], X_test, y_test, M=10, pbar=True)

HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))




(0.01101431884765625, DeviceArray(0.17034847, dtype=float32))

In [17]:

class NN_succ_halving_ECE(Base_SGLD_SuccHalv_Time):
    """SGLD sampler for the NN example, scored by ECE for successive halving.

    ``get_ksd`` keeps the name expected by the successive-halving code but
    returns the first element of ``get_ECE_MCE`` evaluated on the test set,
    not a kernel Stein discrepancy.
    """

    # Score reported for runs that cannot be evaluated (NaN samples, NaN ECE,
    # or any error raised while scoring).
    _FAILED_SCORE = 1e10

    def __init__(self, key, logdt, batch_size_ratio, thin_step):
        """Store the configuration and initialise the base class.

        Args:
            key: JAX PRNG key, forwarded to the base class.
            logdt: base-10 logarithm of the step size (``dt = 10**logdt``).
            batch_size_ratio: minibatch size as a fraction of ``X_train`` rows.
            thin_step: forwarded to the sampler as ``save_rate``.
        """
        self.key = key
        self.logdt = logdt
        self.save_rate = thin_step
        self.batch_size_ratio = batch_size_ratio
        self.x_0 = load_NN_MAP()
        self.last_sample = None
        super(NN_succ_halving_ECE, self).__init__(self.key, self.logdt, self.x_0)

    def _run_sampler(self, key, time_budget, x_0):
        """Run SGLD from ``x_0`` within ``time_budget``.

        Returns:
            ``(flat_samples, flat_samples)``; the second slot is where the base
            class expects gradients, which the ECE score does not need.
            ``(None, None)`` when the sampler returned an empty list.
        """
        dt = 10**self.logdt
        batch_size = int(self.batch_size_ratio*X_train.shape[0])
        samples, _ = sgld_NN_time_budget(key, time_budget, grad_log_post,
                        dt, batch_size, x_0, X_train, y_train, save_rate=self.save_rate)
        if samples == []:
            return None, None
        try:
            self.last_sample = samples[-1]
        except (IndexError, KeyError, TypeError):
            # Best effort: keep the previous last_sample if the sampler output
            # cannot be indexed, but do not hide unrelated errors or interrupts.
            pass

        flat_samples = flatten_NN_params(samples)
        return flat_samples, flat_samples

    def get_ksd(self):
        """Return the ECE of the collected samples (despite the method name).

        The score is also stored in ``self.KSD``. Runs whose last sample
        contains NaN, whose ECE is NaN, or whose evaluation raises are given
        the sentinel ``1e10`` so that they rank last; the sentinel is stored
        in ``self.KSD`` too, so the attribute never holds a stale value.
        """
        try:
            # Vectorised NaN check instead of iterating element by element.
            if np.any(np.isnan(self.samples[-1])):
                score = self._FAILED_SCORE
            else:
                score = get_ECE_MCE(self.samples, X_test, y_test, M=10, pbar=False)[0]
                if np.isnan(score):
                    score = self._FAILED_SCORE
        except Exception:
            # Exception (not a bare except) so KeyboardInterrupt still stops a long run.
            score = self._FAILED_SCORE
        self.KSD = score
        return score

    def get_last_sample(self):
        """Return the last raw (unflattened) sample saved by ``_run_sampler``, or None."""
        return self.last_sample
    


### test SH class

In [18]:
# Smoke test of the SGLD wrapper: a single run(5), then report how many
# samples were kept and the ECE-based score.
key = random.PRNGKey(0)
my_sgld = NN_succ_halving_ECE(key, logdt=-5., batch_size_ratio=0.1, thin_step=10)

my_sgld.run(5)
print(my_sgld.len_samples)
print(my_sgld.get_ksd())

23
0.019670138549804688


In [19]:
# First leg of a split run (run(1) here, then a further run on the same object
# in the next cell).
# The seed declared here is now the one actually passed to the sampler;
# previously `key` was assigned PRNGKey(1) but PRNGKey(0) was hard-coded in the
# constructor call, so the declared seed was silently ignored.
key = random.PRNGKey(1)
my_sgld = NN_succ_halving_ECE(key, logdt=-5., batch_size_ratio=0.1, thin_step=10)

my_sgld.run(1)
print(my_sgld.len_samples)
print(my_sgld.get_ksd())

4
0.020552792358398436


In [20]:
# Call run() again on the sampler object created in the previous cell.
# NOTE(review): the recorded len_samples grows from 4 to 19, so run() presumably
# resumes and appends to the existing samples — confirm in Base_SGLD_SuccHalv_Time.
my_sgld.run(4)
print(my_sgld.len_samples)
print(my_sgld.get_ksd())

19
0.0191966064453125


### try different parameters

In [23]:
# Same step size as the smoke test, but with the full training set as the batch
# (batch_size_ratio=1).
key = random.PRNGKey(0)
my_sgld = NN_succ_halving_ECE(key, logdt=-5., batch_size_ratio=1, thin_step=10)

my_sgld.run(5)
print(my_sgld.len_samples)
print(my_sgld.get_ksd())

1
0.021060829162597657


In [24]:
# Larger step size (logdt=-3, i.e. dt = 1e-3) with 10% minibatches.
key = random.PRNGKey(0)
my_sgld = NN_succ_halving_ECE(key, logdt=-3., batch_size_ratio=0.1, thin_step=10)

my_sgld.run(5)
print(my_sgld.len_samples)
print(my_sgld.get_ksd())

16
0.4029949951171875


In [25]:
# Larger step size (logdt=-3) combined with the full training set as the batch.
key = random.PRNGKey(0)
my_sgld = NN_succ_halving_ECE(key, logdt=-3., batch_size_ratio=1, thin_step=10)

my_sgld.run(5)
print(my_sgld.len_samples)
print(my_sgld.get_ksd())

1
0.442045556640625


In [26]:
# Repeat of the smoke-test configuration (logdt=-5, 10% minibatches).
key = random.PRNGKey(0)
my_sgld = NN_succ_halving_ECE(key, logdt=-5., batch_size_ratio=0.1, thin_step=10)

my_sgld.run(5)
print(my_sgld.len_samples)
print(my_sgld.get_ksd())

16
0.019932244873046876


In [27]:
# Same step size with 1% minibatches (batch_size_ratio=0.01).
key = random.PRNGKey(0)
my_sgld = NN_succ_halving_ECE(key, logdt=-5., batch_size_ratio=0.01, thin_step=10)

my_sgld.run(5)
print(my_sgld.len_samples)
print(my_sgld.get_ksd())

101
0.018306248474121094



## SH-ECE for NN (SGLD)



In [33]:

# Successive-halving settings for SGLD.
R = 8            # total time budget for the longest runs
n = 90           # number of configurations tried in the first round
eta = 3          # cut the number of trials by 3 every round
thin_step = 10   # handed to the sampler class as its thin_step argument
r = R / 40       # time budget of one first-round run

print(f"r={r:.3f}")
key = random.PRNGKey(0)

sampler_dict_sgld = run_SH_time_budget_keep_3(
    key, r, n, eta, thin_step,
    NN_succ_halving_ECE, create_T_list_2params, X_train.shape[0],
)



r=0.200
Number of iterations: 4
Total budget: around 72 sec
=====

Iteration 0: 90 configurations with 0.200 seconds each
Combinations: [(-2.9291453, 0.93260336), (-6.4782596, 0.00811131), (-5.5711884, 0.024770763), (-4.7319803, 0.003764935), (-6.044933, 0.0053366995), (-4.836668, 0.015199109), (-3.7557585, 0.017475285), (-3.0864086, 0.869749), (-2.750773, 0.49770236), (-6.2133875, 0.0061359066), (-6.2843676, 0.020092327), (-5.797586, 0.0053366995), (-2.1424685, 0.4328761), (-3.8307626, 0.017475285), (-2.713562, 0.0013219408), (-3.0279179, 0.010722672), (-4.327333, 0.0010722672), (-5.3738036, 0.037649363), (-5.0438943, 0.16297507), (-5.0316076, 0.0010722672), (-2.7530785, 0.14174742), (-2.9914374, 0.040370174), (-3.1181636, 0.0028480361), (-5.4463973, 0.5722368), (-5.470988, 0.020092327), (-2.6559381, 0.0035111913), (-3.0144324, 0.13219412), (-6.1687517, 0.16297507), (-6.3341594, 0.057223674), (-4.1798353, 0.010722672), (-6.433, 0.869749), (-3.1092925, 0.0012328465), (-5.62562, 0.11497

In [35]:
# Seed: key = 0. Wall-clock time of the SH run: 13min 30sec.
# One line per surviving configuration: the configuration, len_samples scaled
# by thin_step, and the stored score (the KSD attribute holds the ECE here).
for k, v in sampler_dict_sgld.items():
    print(k, v.len_samples * thin_step, v.KSD)

(-5.0316076, 0.0010722672) 3100 0.013025357055664063
(-4.7319803, 0.003764935) 2360 0.013391537475585937
(-4.685914, 0.00811131) 1800 0.015119924926757813


In [22]:
# Seed: key = 1. Wall-clock time of the SH run: 4 min.
# NOTE(review): the SH cell in this notebook is shown with PRNGKey(0); this
# output comes from a re-run with the seed changed by hand.
for k, v in sampler_dict_sgld.items():
    print(k, v.len_samples * thin_step, v.KSD)

(-4.846458, 0.004328762) 4680 0.013626751708984376
(-4.857614, 0.04977024) 700 0.01661865997314453
(-4.8506765, 0.01629751) 1860 0.016694073486328124


In [24]:
# Seed: key = 2. Wall-clock time of the SH run: 6 min.
# NOTE(review): the SH cell in this notebook is shown with PRNGKey(0); this
# output comes from a re-run with the seed changed by hand.
for k, v in sampler_dict_sgld.items():
    print(k, v.len_samples * thin_step, v.KSD)

(-4.885751, 0.0012328465) 3170 0.011815101623535156
(-4.7569885, 0.001629751) 3580 0.014084030151367187
(-4.700827, 0.028480362) 710 0.01746497344970703


## SVRG

In [2]:
class NN_succ_halving_SVRG_ECE(Base_SGLD_SuccHalv_Time):
    """SVRG sampler for the NN example, scored by ECE for successive halving.

    ``get_ksd`` keeps the name expected by the successive-halving code but
    returns the first element of ``get_ECE_MCE`` evaluated on the test set,
    not a kernel Stein discrepancy.
    """

    # Score reported for runs that cannot be evaluated (NaN samples, NaN ECE,
    # or any error raised while scoring).
    _FAILED_SCORE = 1e10

    def __init__(self, key, logdt, batch_size_ratio, thin_step):
        """Store the configuration and initialise the base class.

        Args:
            key: JAX PRNG key, forwarded to the base class.
            logdt: base-10 logarithm of the step size (``dt = 10**logdt``).
            batch_size_ratio: minibatch size as a fraction of ``X_train`` rows.
            thin_step: forwarded to the sampler as ``save_rate``.
        """
        self.key = key
        self.logdt = logdt
        self.save_rate = thin_step
        self.batch_size_ratio = batch_size_ratio
        self.x_0 = load_NN_MAP()
        # Fixed value passed to sgld_NN_time_budget_svrg as its update-rate argument.
        self.update_rate = 1000
        self.last_sample = None
        super(NN_succ_halving_SVRG_ECE, self).__init__(self.key, self.logdt, self.x_0)

    def _run_sampler(self, key, time_budget, x_0):
        """Run the SVRG sampler from ``x_0`` within ``time_budget``.

        Returns:
            ``(flat_samples, flat_samples)``; the second slot is where the base
            class expects gradients, which the ECE score does not need.
            ``(None, None)`` when the sampler returned an empty list.
        """
        dt = 10**self.logdt
        batch_size = int(self.batch_size_ratio*X_train.shape[0])
        samples = sgld_NN_time_budget_svrg(key, time_budget, dt, batch_size,
                                           x_0, self.update_rate, save_rate=self.save_rate)
        if samples == []:
            return None, None
        try:
            self.last_sample = samples[-1]
        except (IndexError, KeyError, TypeError):
            # Best effort: keep the previous last_sample if the sampler output
            # cannot be indexed, but do not hide unrelated errors or interrupts.
            pass

        flat_samples = flatten_NN_params(samples)
        return flat_samples, flat_samples

    def get_ksd(self):
        """Return the ECE of the collected samples (despite the method name).

        The score is also stored in ``self.KSD``. Runs whose last sample
        contains NaN, whose ECE is NaN, or whose evaluation raises are given
        the sentinel ``1e10`` so that they rank last; the sentinel is stored
        in ``self.KSD`` too, so the attribute never holds a stale value.
        """
        try:
            # Vectorised NaN check instead of iterating element by element.
            if np.any(np.isnan(self.samples[-1])):
                score = self._FAILED_SCORE
            else:
                score = get_ECE_MCE(self.samples, X_test, y_test, M=10, pbar=False)[0]
                if np.isnan(score):
                    score = self._FAILED_SCORE
        except Exception:
            # Exception (not a bare except) so KeyboardInterrupt still stops a long run.
            score = self._FAILED_SCORE
        self.KSD = score
        return score

    def get_last_sample(self):
        """Return the last raw (unflattened) sample saved by ``_run_sampler``, or None."""
        return self.last_sample
    


### test SH class

In [3]:
# Smoke test of the SVRG wrapper: a single run(5), then report how many
# samples were kept and the ECE-based score.
key = random.PRNGKey(0)
my_svrg = NN_succ_halving_SVRG_ECE(key, logdt=-5., batch_size_ratio=0.1, thin_step=10)

my_svrg.run(5)
print(my_svrg.len_samples)
print(my_svrg.get_ksd())


16
0.019568865966796876


In [4]:
# First leg of a split run: run(2) here, then a further run on the same object
# in the next cell.
key = random.PRNGKey(0)
my_svrg = NN_succ_halving_SVRG_ECE(key, logdt=-5., batch_size_ratio=0.1, thin_step=10)

my_svrg.run(2)
print(my_svrg.len_samples)
print(my_svrg.get_ksd())


6
0.019401632690429688


In [5]:
# Call run() again on the sampler object created in the previous cell.
# NOTE(review): the recorded len_samples grows from 6 to 15, so run() presumably
# resumes and appends to the existing samples — confirm in Base_SGLD_SuccHalv_Time.
my_svrg.run(3)
print(my_svrg.len_samples)
print(my_svrg.get_ksd())



15
0.02037673645019531


In [4]:
# Scratch arithmetic. NOTE(review): what 15 and 50 stand for is not evident
# from this notebook — confirm, or delete the cell.
15/50

0.3

In [5]:
# Sum of four terms, each 3x the previous, starting at 0.3; the total is 12.
# NOTE(review): presumably the accumulated budget over the 4 SH rounds
# (r=0.3, eta=3), matching R = 12 in the SVRG run below — confirm.
0.3 + 3*0.3 + 9*0.3 + 27*0.3

12.0

In [6]:
# Scratch arithmetic, same expression as two cells above.
# NOTE(review): purpose of 15 and 50 not evident — confirm, or delete the cell.
15/50

0.3

In [7]:
# Same ratio as r = R/40 with R = 12, the values used in the SVRG SH run below.
12/40

0.3

#### run SH

In [6]:

# Successive-halving settings for SVRG.
R = 12           # total time budget for the longest runs
n = 90           # number of configurations tried in the first round
eta = 3          # cut the number of trials by 3 every round
thin_step = 10   # handed to the sampler class as its thin_step argument
r = R / 40       # time budget of one first-round run

print(f"r={r:.3f}")
key = random.PRNGKey(0)

sampler_dict_svrg = run_SH_time_budget_keep_3(
    key, r, n, eta, thin_step,
    NN_succ_halving_SVRG_ECE, create_T_list_2params, X_train.shape[0],
)



r=0.300
Number of iterations: 4
Total budget: around 108 sec
=====

Iteration 0: 90 configurations with 0.300 seconds each
Combinations: [(-2.9291453, 0.93260336), (-6.4782596, 0.00811131), (-5.5711884, 0.024770763), (-4.7319803, 0.003764935), (-6.044933, 0.0053366995), (-4.836668, 0.015199109), (-3.7557585, 0.017475285), (-3.0864086, 0.869749), (-2.750773, 0.49770236), (-6.2133875, 0.0061359066), (-6.2843676, 0.020092327), (-5.797586, 0.0053366995), (-2.1424685, 0.4328761), (-3.8307626, 0.017475285), (-2.713562, 0.0013219408), (-3.0279179, 0.010722672), (-4.327333, 0.0010722672), (-5.3738036, 0.037649363), (-5.0438943, 0.16297507), (-5.0316076, 0.0010722672), (-2.7530785, 0.14174742), (-2.9914374, 0.040370174), (-3.1181636, 0.0028480361), (-5.4463973, 0.5722368), (-5.470988, 0.020092327), (-2.6559381, 0.0035111913), (-3.0144324, 0.13219412), (-6.1687517, 0.16297507), (-6.3341594, 0.057223674), (-4.1798353, 0.010722672), (-6.433, 0.869749), (-3.1092925, 0.0012328465), (-5.62562, 0.1149

In [7]:
# Seed: key = 0. Wall-clock time of the SH run: 13 min.
# One line per surviving configuration: the configuration, len_samples scaled
# by thin_step, and the stored score (the KSD attribute holds the ECE here).
for k, v in sampler_dict_svrg.items():
    print(k, v.len_samples * thin_step, v.KSD)

(-4.7319803, 0.003764935) 1620 0.014001058959960937
(-5.0316076, 0.0010722672) 1840 0.014453311157226563
(-4.685914, 0.00811131) 1350 0.015924351501464845


In [9]:
# Seed: key = 1. Wall-clock time of the SH run: 10 min.
# NOTE(review): the SH cell in this notebook is shown with PRNGKey(0); this
# output comes from a re-run with the seed changed by hand.
for k, v in sampler_dict_svrg.items():
    print(k, v.len_samples * thin_step, v.KSD)

(-4.846458, 0.004328762) 1370 0.01698354797363281
(-4.8506765, 0.01629751) 810 0.01809205322265625
(-5.3111, 0.003274549) 1390 0.01884635009765625


In [11]:
# Seed: key = 2. Wall-clock time of the SH run: 8 min.
# NOTE(review): the SH cell in this notebook is shown with PRNGKey(0); this
# output comes from a re-run with the seed changed by hand.
for k, v in sampler_dict_svrg.items():
    print(k, v.len_samples * thin_step, v.KSD)

(-4.7569885, 0.001629751) 1740 0.012007249450683594
(-4.885751, 0.0012328465) 1770 0.013648966979980469
(-4.6608067, 0.061359067) 370 0.01791618194580078


## SG-HMC

In [32]:


class NN_succ_halving_HMC_L_ECE(Base_SGLD_SuccHalv_Time):
    """SG-HMC sampler for the NN example, scored by ECE for successive halving.

    ``get_ksd`` keeps the name expected by the successive-halving code but
    returns the first element of ``get_ECE_MCE`` evaluated on the test set,
    not a kernel Stein discrepancy.
    """

    # Score reported for runs that cannot be evaluated (NaN samples, NaN ECE,
    # or any error raised while scoring).
    _FAILED_SCORE = 1e10

    def __init__(self, key, logdt, batch_size_ratio, L, thin_step):
        """Store the configuration and initialise the base class.

        Args:
            key: JAX PRNG key, forwarded to the base class.
            logdt: base-10 logarithm of the step size (``dt = 10**logdt``).
            batch_size_ratio: minibatch size as a fraction of ``X_train`` rows.
            L: forwarded unchanged to ``sghmc_NN_time_budget``.
            thin_step: forwarded to the sampler as ``save_rate``.
        """
        self.key = key
        self.logdt = logdt
        self.batch_size_ratio = batch_size_ratio
        self.x_0 = load_NN_MAP()
        self.L = L
        self.thin_step = thin_step
        self.last_sample = None
        super(NN_succ_halving_HMC_L_ECE, self).__init__(self.key, self.logdt, self.x_0)

    def _run_sampler(self, key, time_budget, x_0):
        """Run SG-HMC from ``x_0`` within ``time_budget``.

        Returns:
            ``(flat_samples, flat_samples)``; the second slot is where the base
            class expects gradients, which the ECE score does not need.
            ``(None, None)`` when the sampler returned an empty list.
        """
        dt = 10**self.logdt
        batch_size = int(self.batch_size_ratio*X_train.shape[0])
        samples = sghmc_NN_time_budget(key, time_budget, dt, batch_size, x_0, self.L, save_rate=self.thin_step)
        if samples == []:
            return None, None
        try:
            self.last_sample = samples[-1]
        except (IndexError, KeyError, TypeError):
            # Best effort: keep the previous last_sample if the sampler output
            # cannot be indexed, but do not hide unrelated errors or interrupts.
            pass
        flat_samples = flatten_NN_params(samples)
        return flat_samples, flat_samples

    def get_ksd(self):
        """Return the ECE of the collected samples (despite the method name).

        The score is also stored in ``self.KSD``. Runs whose last sample
        contains NaN, whose ECE is NaN, or whose evaluation raises are given
        the sentinel ``1e10`` so that they rank last; the sentinel is stored
        in ``self.KSD`` too, so the attribute never holds a stale value.
        """
        try:
            # Vectorised NaN check instead of iterating element by element.
            if np.any(np.isnan(self.samples[-1])):
                score = self._FAILED_SCORE
            else:
                score = get_ECE_MCE(self.samples, X_test, y_test, M=10, pbar=False)[0]
                # In case the ECE itself is NaN.
                if np.isnan(score):
                    score = self._FAILED_SCORE
        except Exception:
            # Exception (not a bare except) so KeyboardInterrupt still stops a long run.
            score = self._FAILED_SCORE
        self.KSD = score
        return score

    def get_last_sample(self):
        """Return the last raw (unflattened) sample saved by ``_run_sampler``, or None."""
        return self.last_sample
    

    

### test SH class

In [33]:

# Smoke test of the SG-HMC wrapper: a single run(5), then report the ECE-based
# score and how many samples were kept.
key = random.PRNGKey(0)
nn_my_hmc = NN_succ_halving_HMC_L_ECE(
    key, logdt=-7., batch_size_ratio=0.1, L=10, thin_step=3
)

nn_my_hmc.run(5)
print(nn_my_hmc.get_ksd())
print(nn_my_hmc.len_samples)


0.020681869506835936
13


In [34]:

# First leg of a split run: run(2) here, then a further run on the same object
# in the next cell.
key = random.PRNGKey(0)
nn_my_hmc = NN_succ_halving_HMC_L_ECE(
    key, logdt=-7., batch_size_ratio=0.1, L=10, thin_step=3
)

nn_my_hmc.run(2)
print(nn_my_hmc.get_ksd())
print(nn_my_hmc.len_samples)


0.02057274627685547
6


In [35]:
# Call run() again on the sampler object created in the previous cell.
# NOTE(review): the recorded len_samples grows from 6 to 14, so run() presumably
# resumes and appends to the existing samples — confirm in Base_SGLD_SuccHalv_Time.
nn_my_hmc.run(3)
print(nn_my_hmc.get_ksd())
print(nn_my_hmc.len_samples)


0.019782666015625
14


In [36]:

# A configuration for which get_ksd() reports the failure sentinel (1e10) in
# the recorded output; the same (logdt, batch_size_ratio, L) triple appears in
# the SH combinations list further down.
key = random.PRNGKey(0)
nn_my_hmc = NN_succ_halving_HMC_L_ECE(
    key, logdt=-2.9268637, batch_size_ratio=0.30538556, L=16, thin_step=3
)

nn_my_hmc.run(2)
print(nn_my_hmc.get_ksd())
print(nn_my_hmc.len_samples)



10000000000.0
2


### run SH-ECE

In [37]:
# Successive-halving settings for SG-HMC.
R = 30           # total time budget for the longest runs
n = 120          # number of configurations tried in the first round
eta = 3          # cut the number of trials by 3 every round
thin_step = 5    # handed to the sampler class as its thin_step argument
r = R / 40       # time budget of one first-round run (0.750 s in the output)

key = random.PRNGKey(0)

sampler_dict_hmc = run_SH_time_budget_keep_3(
    key, r, n, eta, thin_step,
    NN_succ_halving_HMC_L_ECE, create_T_list_sghmc, X_train.shape[0],
)



Number of iterations: 4
Total budget: around 360 sec
=====

Iteration 0: 120 configurations with 0.750 seconds each
Combinations: [(-5.6016817, 0.0035111913, 14), (-2.9268637, 0.30538556, 16), (-4.8808064, 0.0013219408, 16), (-6.961109, 0.5722368, 9), (-2.39332, 0.0061359066, 10), (-6.3222876, 0.035111915, 4), (-5.015533, 0.93260336, 4), (-6.9114394, 0.6135907, 26), (-6.307716, 0.0040370165, 15), (-5.0523543, 0.0011497568, 15), (-6.4231853, 0.001, 7), (-3.9714003, 0.0010722672, 7), (-4.149378, 0.004328762, 9), (-6.2886944, 0.40370172, 9), (-2.4729414, 0.1, 19), (-5.1166677, 0.06579333, 19), (-6.677946, 0.061359067, 5), (-6.695702, 0.28480357, 8), (-5.365959, 0.037649363, 15), (-4.5360165, 0.053366996, 3), (-4.330592, 0.15199111, 29), (-4.8197184, 0.03274549, 18), (-4.1471167, 0.0070548006, 21), (-5.4690056, 0.869749, 12), (-4.171063, 0.49770236, 15), (-5.787445, 0.008697491, 29), (-4.201998, 0.0040370165, 10), (-5.719651, 0.004328762, 7), (-4.001929, 0.0020092328, 27), (-3.3212676, 0.1

In [38]:
# Seed: key = 0. Wall-clock time of the SH run: 28 min.
# One line per surviving configuration: the configuration, len_samples scaled
# by thin_step, and the stored score (the KSD attribute holds the ECE here).
for k, v in sampler_dict_hmc.items():
    print(k, v.len_samples * thin_step, v.KSD)

(-5.9781446, 0.028480362, 7) 605 0.011240739440917968
(-6.046593, 0.0017475284, 6) 2855 0.012382036590576172
(-6.4231853, 0.001, 7) 2785 0.014234141540527344


In [40]:
# Seed: key = 1. Wall-clock time of the SH run: 26 min.
# NOTE(review): the SH cell in this notebook is shown with PRNGKey(0); this
# output comes from a re-run with the seed changed by hand.
for k, v in sampler_dict_hmc.items():
    print(k, v.len_samples * thin_step, v.KSD)

(-5.459502, 0.0013219408, 3) 4515 0.0138968017578125
(-6.4526224, 0.035111915, 27) 140 0.014145718383789063
(-6.333365, 0.002310129, 9) 1490 0.014456602478027344


In [42]:
# Seed: key = 2. Wall-clock time of the SH run: 22 min.
# NOTE(review): the SH cell in this notebook is shown with PRNGKey(0); this
# output comes from a re-run with the seed changed by hand.
for k, v in sampler_dict_hmc.items():
    print(k, v.len_samples * thin_step, v.KSD)

(-6.5566435, 0.003274549, 14) 1470 0.013270433044433594
(-6.657807, 0.0053366995, 29) 580 0.014085118103027344
(-6.990305, 0.004328762, 17) 1075 0.01546607666015625


## old run

In [165]:
# Earlier run kept for reference. Seed: key = 0, thin=4. Wall-clock time: 24 min.
# NOTE(review): this output depends on the thin_step value that was live when
# the cell ran (4 per the note above, not the 5 set in the current SH cell).
for k, v in sampler_dict_hmc.items():
    print(k, v.len_samples * thin_step, v.KSD)

(-5.9781446, 0.028480362, 7) 676 0.011878512573242187
(-6.046593, 0.0017475284, 6) 3544 0.013024720764160157
(-6.0359597, 0.014174741, 16) 516 0.013271453857421876
