# utils_stan.py
import pickle
from math import isinf, isnan

import numpy
import pystan


def check_div(fit):
    """Check transitions that ended with a divergence"""
    sampler_params = fit.get_sampler_params(inc_warmup=False)
    divergent = [x for y in sampler_params for x in y['divergent__']]
    n = sum(divergent)
    N = len(divergent)
    print('{} of {} iterations ended with a divergence ({}%)'.format(
        n, N, 100 * n / N))
    if n > 0:
        print('  Try running with larger adapt_delta to remove the divergences')
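
# A hedged usage sketch, not part of the original module: `sm` is assumed to
# be a compiled pystan.StanModel and `data` its data dict. `adapt_delta` is
# the step-size adaptation target that the warning above refers to.
#
#     fit = sm.sampling(data=data, control=dict(adapt_delta=0.95))
#     check_div(fit)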


def check_treedepth(fit, max_depth=10):
    """Check transitions that ended prematurely due to maximum tree depth limit"""
    sampler_params = fit.get_sampler_params(inc_warmup=False)
    depths = [x for y in sampler_params for x in y['treedepth__']]
    n = sum(1 for x in depths if x == max_depth)
    N = len(depths)
    print(('{} of {} iterations saturated the maximum tree depth of {}'
           + ' ({}%)').format(n, N, max_depth, 100 * n / N))
    if n > 0:
        print('  Run again with max_depth set to a larger value to avoid saturation')


def check_energy(fit):
    """Checks the energy Bayesian fraction of missing information (E-BFMI)"""
    sampler_params = fit.get_sampler_params(inc_warmup=False)
    no_warning = True
    for chain_num, s in enumerate(sampler_params):
        energies = s['energy__']
        # E-BFMI: mean squared change in energy between successive draws,
        # relative to the marginal variance of the energy
        numer = sum((energies[i] - energies[i - 1])**2
                    for i in range(1, len(energies))) / len(energies)
        denom = numpy.var(energies)
        if numer / denom < 0.2:
            print('Chain {}: E-BFMI = {}'.format(chain_num, numer / denom))
            no_warning = False
    if no_warning:
        print('E-BFMI indicated no pathological behavior')
    else:
        print('  E-BFMI below 0.2 indicates you may need to reparameterize your model')


def check_n_eff(fit):
    """Checks the effective sample size per iteration"""
    fit_summary = fit.summary(probs=[0.5])
    n_effs = [x[4] for x in fit_summary['summary']]
    names = fit_summary['summary_rownames']
    n_iter = len(fit.extract()['lp__'])

    no_warning = True
    for n_eff, name in zip(n_effs, names):
        ratio = n_eff / n_iter
        if ratio < 0.001:
            print('n_eff / iter for parameter {} is {}!'.format(name, ratio))
            no_warning = False
    if no_warning:
        print('n_eff / iter looks reasonable for all parameters')
    else:
        print('  n_eff / iter below 0.001 indicates that the effective sample size has likely been overestimated')


def check_rhat(fit):
    """Checks the potential scale reduction factors"""
    fit_summary = fit.summary(probs=[0.5])
    rhats = [x[5] for x in fit_summary['summary']]
    names = fit_summary['summary_rownames']

    no_warning = True
    for rhat, name in zip(rhats, names):
        if rhat > 1.1 or isnan(rhat) or isinf(rhat):
            print('Rhat for parameter {} is {}!'.format(name, rhat))
            no_warning = False
    if no_warning:
        print('Rhat looks reasonable for all parameters')
    else:
        print('  Rhat above 1.1 indicates that the chains very likely have not mixed')


def check_all_diagnostics(fit):
    """Checks all MCMC diagnostics"""
    check_n_eff(fit)
    check_rhat(fit)
    check_div(fit)
    check_treedepth(fit)
    check_energy(fit)
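
# A minimal sketch of the intended workflow, assuming `sm` is a compiled
# pystan.StanModel and `data` is the model's data dict:
#
#     fit = sm.sampling(data=data)
#     check_all_diagnostics(fit)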


def _by_chain(unpermuted_extraction):
    """Reorder an (iteration, chain, flat parameter) extraction into an
    array indexed by (chain, iteration, flat parameter)."""
    num_chains = len(unpermuted_extraction[0])
    result = [[] for _ in range(num_chains)]
    for c in range(num_chains):
        for i in range(len(unpermuted_extraction)):
            result[c].append(unpermuted_extraction[i][c])
    return numpy.array(result)


def _shaped_ordered_params(fit):
    """Extract posterior draws as a dict of correctly shaped arrays,
    pooled across chains but left unpermuted."""
    # flattened, unpermuted draws indexed by (iteration, chain)
    ef = fit.extract(permuted=False, inc_warmup=False)
    ef = _by_chain(ef)
    ef = ef.reshape(-1, len(ef[0][0]))
    ef = ef[:, 0:len(fit.flatnames)]  # drop lp__
    shaped = {}
    idx = 0
    for dim, param_name in zip(fit.par_dims, fit.extract().keys()):
        length = int(numpy.prod(dim))
        # reshape returns a new array, so the result must be reassigned
        shaped[param_name] = ef[:, idx:idx + length].reshape(*([-1] + dim))
        idx += length
    return shaped


def partition_div(fit):
    """Returns parameter arrays separated into divergent and non-divergent transitions"""
    sampler_params = fit.get_sampler_params(inc_warmup=False)
    div = numpy.concatenate([x['divergent__'] for x in sampler_params]).astype('int')
    params = _shaped_ordered_params(fit)
    nondiv_params = dict((key, params[key][div == 0]) for key in params)
    div_params = dict((key, params[key][div == 1]) for key in params)
    return nondiv_params, div_params
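
# A hedged example of the usual use of partition_div: overlaying divergent
# transitions on a scatter plot. The parameter names 'theta' and 'tau' are
# hypothetical, chosen only for illustration:
#
#     import matplotlib.pyplot as plt
#     nondiv, div = partition_div(fit)
#     plt.scatter(nondiv['theta'][:, 0], nondiv['tau'], color='gray', alpha=0.3)
#     plt.scatter(div['theta'][:, 0], div['tau'], color='red', alpha=0.8)
#     plt.show()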


def compile_model(filename, model_name=None, **kwargs):
    """This will automatically cache models - great if you're just running a
    script on the command line.
    See http://pystan.readthedocs.io/en/latest/avoiding_recompilation.html"""
    from hashlib import md5

    with open(filename) as f:
        model_code = f.read()
    # utf-8 so models containing non-ASCII characters still hash cleanly
    code_hash = md5(model_code.encode('utf-8')).hexdigest()
    if model_name is None:
        cache_fn = 'cached-model-{}.pkl'.format(code_hash)
    else:
        cache_fn = 'cached-{}-{}.pkl'.format(model_name, code_hash)
    try:
        with open(cache_fn, 'rb') as f:
            sm = pickle.load(f)
    except (IOError, EOFError, pickle.UnpicklingError):
        # cache miss: compile, forwarding any extra StanModel arguments
        sm = pystan.StanModel(model_code=model_code, **kwargs)
        with open(cache_fn, 'wb') as f:
            pickle.dump(sm, f)
    else:
        print("Using cached StanModel")
    return sm
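
# A hedged end-to-end sketch, assuming a hypothetical Stan file 'model.stan'
# and a data dict `data`. The first call compiles and pickles the model; later
# calls with unchanged code reuse the cache and print "Using cached StanModel".
#
#     sm = compile_model('model.stan')
#     fit = sm.sampling(data=data)
#     check_all_diagnostics(fit)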