-
Notifications
You must be signed in to change notification settings - Fork 3
/
bnn_regress.py
90 lines (73 loc) · 2.83 KB
/
bnn_regress.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import numpy as np
np.set_printoptions(suppress=True, precision=3)
# load BNN package
import np_bnn as bn
import scipy.stats

# Single source of truth for the random seed. It seeds NumPy's global RNG here
# and is also passed to get_data below (presumably controlling the shuffle /
# train-test split), so the two can no longer drift apart.
rseed = 1234
np.random.seed(rseed)

# Input files: feature table and regression labels.
features_file = "./example_files/data_features_reg.txt"
labels_file = "./example_files/data_lab_reg.txt"
dat = bn.get_data(features_file,
                  labels_file,
                  seed=rseed,
                  testsize=0.1,  # 10% test set
                  all_class_in_testset=0,
                  cv=0,  # cross-validation batch (0 = 1st batch; set to 1, 2, ... for subsequent batches)
                  header=True,  # input files have a header row
                  from_file=True,
                  instance_id=0,
                  randomize_order=True,
                  label_mode="regression")
# set up the BNN model
bnn_model = bn.npBNN(dat,
n_nodes = [6,2],
estimation_mode="regression",
actFun = bn.ActFun(fun="tanh"),
p_scale=1,
use_bias_node=0,
empirical_error=True)
# Configure the MCMC sampler: all tuning settings live in one mapping that is
# unpacked into bn.MCMC as keyword arguments.
mcmc_settings = dict(
    update_ws=[0.025, 0.025, 0.05],
    update_f=[0.005, 0.005, 0.05],
    n_iteration=20000,
    sampling_f=100,
    print_f=1000,
    n_post_samples=100,
    likelihood_tempering=1,
    adapt_f=0.3,
    estimate_error=False,
)
mcmc = bn.MCMC(bnn_model, **mcmc_settings)

# To experiment with a different update scheme, assign mcmc._update_n here
# (e.g. [1, 1, 1]) before it is printed below.
print(mcmc._update_n)
# NOTE(review): the value returned by _accuracy_lab_f is discarded; confirm
# whether it was meant to be printed or is only called for a side effect.
mcmc._accuracy_lab_f(mcmc._y, bnn_model._labels)

# Create the posterior logger (output file name base "testM"), then run the
# sampler with that logger attached.
logger = bn.postLogger(bnn_model, filename="testM", log_all_weights=0)
bn.run_mcmc(bnn_model, mcmc, logger)
import seaborn as sns
import matplotlib.pyplot as plt

# Plot true vs. estimated values for one label column, for both the test and
# the training set, with the y = x identity line as a reference.
# Set out_col to 0 to inspect the first label column instead.
out_col = 1

# Create the axes explicitly: the original referenced `ax` without ever
# defining it (its assignment was commented out), which raised NameError.
fig, ax = plt.subplots(figsize=(5, 5))
sns.regplot(x=dat['test_labels'][:, out_col].flatten(),
            y=mcmc._y_test[:, out_col],
            ax=ax, label="test")
sns.regplot(x=dat['labels'][:, out_col].flatten(),
            y=mcmc._y[:, out_col],
            ax=ax, label="train")
ax.set(xlabel='True values', ylabel='Estimated values')
ax.axline((0, 0), (1, 1), linewidth=1, color='k')
ax.legend()
# plt.show() runs the GUI event loop; Figure.show() does not, so in a plain
# script the window could flash briefly or not appear at all.
plt.show()
#### run predict
import copy

# Reload the objects pickled by the logger during the MCMC run.
bnn_obj, mcmc_obj, logger_obj = bn.load_obj(logger._pklfile)
post_samples = logger_obj._post_weight_samples
# load posterior weights and activation parameters (one entry per sample)
post_weights = [sample['weights'] for sample in post_samples]
post_alphas = [sample['alphas'] for sample in post_samples]
actFun = bnn_obj._act_fun
output_act_fun = bnn_obj._output_act_fun

# reset_prm is called only for its effect on the object (its return value is
# unused). The original `actFun_i = actFun` merely aliased the loaded model's
# activation, so each iteration overwrote bnn_obj._act_fun. A private copy,
# reset per sample, leaves the loaded model intact.
actFun_i = copy.deepcopy(actFun)

# One prediction per posterior sample. The name is kept from the original; for
# this regression model these are predicted values, not class probabilities.
post_cat_probs = []
for weights_i, alphas_i in zip(post_weights, post_alphas):
    actFun_i.reset_prm(alphas_i)
    pred = bn.RunPredict(bnn_obj._data, weights_i, actFun=actFun_i,
                         output_act_fun=output_act_fun)
    post_cat_probs.append(pred)