"""PyMC3 back-end for fitting bambi models."""
import numpy as np
import pymc3 as pm
import theano
from arviz import from_pymc3
from bambi.priors import Prior
from .base import BackEnd
class PyMC3BackEnd(BackEnd):
    """PyMC3 model-fitting back-end."""

    # Link functions mapping the linear predictor to the response mean
    # (i.e., these are the *inverse* link functions: e.g. for the "logit"
    # link, the sigmoid is applied to obtain the mean).
    links = {
        "identity": lambda x: x,
        "logit": theano.tensor.nnet.sigmoid,
        "inverse": theano.tensor.inv,
        "inverse_squared": lambda x: theano.tensor.inv(theano.tensor.sqrt(x)),
        "log": theano.tensor.exp,
    }

    # Distributions not provided directly by PyMC3.
    dists = {"HalfFlat": pm.Bound(pm.Flat, lower=0)}

    def __init__(self):
        self.reset()
        # Attributes defined elsewhere
        self.mu = None  # build()
        self.spec = None  # build()
        self.trace = None  # run()
        self.advi_params = None  # run()

    def reset(self):
        """Reset PyMC3 model and all tracked distributions and parameters."""
        self.model = pm.Model()
        self.mu = None
        self.par_groups = {}

    def _build_dist(self, spec, label, dist, **kwargs):
        """Build and return a PyMC3 Distribution.

        Parameters
        ----------
        spec : Bambi model
            The abstract model specification (used for the ``noncentered`` flag).
        label : str
            Name under which the PyMC3 variable is registered.
        dist : str or PyMC3 distribution class
            Either a distribution class, the name of a PyMC3 distribution, or a
            key in ``self.dists``.
        **kwargs
            Distribution arguments; any ``Prior`` value is expanded recursively
            into its own named PyMC3 variable (hyperprior).

        Raises
        ------
        ValueError
            If ``dist`` is a string naming no known distribution.
        """
        if isinstance(dist, str):
            if hasattr(pm, dist):
                dist = getattr(pm, dist)
            elif dist in self.dists:
                dist = self.dists[dist]
            else:
                raise ValueError(
                    f"The Distribution {dist} was not found in PyMC3 or the PyMC3BackEnd."
                )

        # Inspect all args in case we have hyperparameters
        def _expand_args(key, value, label):
            if isinstance(value, Prior):
                label = f"{label}_{key}"
                return self._build_dist(spec, label, value.name, **value.args)
            return value

        kwargs = {k: _expand_args(k, v, label) for (k, v) in kwargs.items()}

        # Non-centered parameterization for hyperpriors: sample a standard
        # normal offset and scale it deterministically by sigma. Only applies
        # to unobserved variables whose sigma is itself a (transformed) RV.
        if (
            spec.noncentered
            and "sigma" in kwargs
            and "observed" not in kwargs
            and isinstance(kwargs["sigma"], pm.model.TransformedRV)
        ):
            old_sigma = kwargs["sigma"]
            _offset = pm.Normal(label + "_offset", mu=0, sigma=1, shape=kwargs["shape"])
            return pm.Deterministic(label, _offset * old_sigma)

        return dist(label, **kwargs)

    def build(self, spec, reset=True):  # pylint: disable=arguments-differ
        """Compile the PyMC3 model from an abstract model specification.

        Parameters
        ----------
        spec : Bambi model
            A bambi Model instance containing the abstract specification of the model to compile.
        reset : Bool
            If True (default), resets the PyMC3BackEnd instance before compiling.
        """
        if reset:
            self.reset()

        with self.model:
            self.mu = 0.0
            # Accumulate each term's contribution to the linear predictor.
            for t in spec.terms.values():
                data = t.data
                label = t.name
                dist_name = t.prior.name
                dist_args = t.prior.args
                n_cols = t.data.shape[1]
                coef = self._build_dist(spec, label, dist_name, shape=n_cols, **dist_args)
                if t.random:
                    # Group-specific term: index coefficients by group.
                    self.mu += coef[t.group_index][:, None] * t.predictor
                else:
                    self.mu += pm.math.dot(data, coef)[:, None]

            # Response: apply the (inverse) link to mu and attach observed data.
            y = spec.y.data
            y_prior = spec.family.prior
            link_f = spec.family.link
            if isinstance(link_f, str):
                link_f = self.links[link_f]
            y_prior.args[spec.family.parent] = link_f(self.mu)
            y_prior.args["observed"] = y
            self._build_dist(spec, spec.y.name, y_prior.name, **y_prior.args)

            self.spec = spec

    # pylint: disable=arguments-differ
    def run(self, start=None, method="mcmc", init="auto", n_init=50000, **kwargs):
        """Run the PyMC3 MCMC sampler.

        Parameters
        ----------
        start: dict, or array of dict
            Starting parameter values to pass to sampler; see ``'pm.sample()'`` for details.
        method: str
            The method to use for fitting the model. By default, 'mcmc', in which case the
            PyMC3 sampler will be used. Alternatively, 'advi', in which case the model will be
            fitted using automatic differentiation variational inference as implemented in PyMC3.
            Finally, 'laplace', in which case a laplace approximation is used, 'laplace' is not
            recommended other than for pedagogical use.
        init: str
            Initialization method (see PyMC3 sampler documentation). Currently, this is
            ``'jitter+adapt_diag'``, but this can change in the future.
        n_init: int
            Number of initialization iterations if init = 'advi' or 'nuts'. Default is kind of in
            PyMC3 for the kinds of models we expect to see run with bambi, so we lower it
            considerably.

        Returns
        -------
        An ArviZ InferenceData instance.

        Raises
        ------
        ValueError
            If ``method`` is not one of 'mcmc', 'advi' or 'laplace'.
        """
        model = self.model
        if method.lower() == "mcmc":
            samples = kwargs.pop("samples", 1000)
            with model:
                self.trace = pm.sample(samples, start=start, init=init, n_init=n_init, **kwargs)
            return from_pymc3(self.trace, model=model)
        elif method.lower() == "advi":
            with model:
                self.advi_params = pm.variational.ADVI(start, **kwargs)
            return (
                self.advi_params
            )  # this should return an InferenceData object (once arviz adds support for VI)
        elif method.lower() == "laplace":
            return _laplace(model)
        else:
            # Fail loudly instead of silently returning None for a typo'd method.
            raise ValueError(
                f"Unknown method '{method}'; expected one of 'mcmc', 'advi' or 'laplace'."
            )
def _laplace(model):
    """Fit a model using a Laplace approximation.

    Mainly for pedagogical use. ``mcmc`` and ``advi`` are better approximations.

    Parameters
    ----------
    model: PyMC3 model

    Returns
    -------
    Dictionary, the keys are the names of the variables and the values tuples of modes and
    standard deviations.

    Raises
    ------
    numpy.linalg.LinAlgError
        If the Hessian at the MAP estimate is singular.
    """
    with model:
        # Only keep untransformed variables; PyMC3 also tracks transformed copies.
        varis = [v for v in model.unobserved_RVs if not pm.util.is_transformed_name(v.name)]
        maps = pm.find_MAP(start=model.test_point, vars=varis)
        hessian = pm.find_hessian(maps, vars=varis)
        if np.linalg.det(hessian) == 0:
            raise np.linalg.LinAlgError("Singular matrix. Use mcmc or advi method")
        # Take the square root of the diagonal only. An elementwise sqrt of the
        # full inverse Hessian (the previous ``np.diag(inv ** 0.5)``) yields the
        # same diagonal but produces NaNs/warnings for negative off-diagonals.
        stds = np.sqrt(np.diag(np.linalg.inv(hessian)))
        maps = [v for (k, v) in maps.items() if not pm.util.is_transformed_name(k)]
        modes = [v.item() if v.size == 1 else v for v in maps]
        names = [v.name for v in varis]
        shapes = [np.atleast_1d(mode).shape for mode in modes]

        # Slice the flat stds vector back into each variable's shape. The element
        # count of a variable is the *product* of its dimensions; the previous
        # ``sum(shape)`` was only correct for 1-d shapes.
        stds_reshaped = []
        idx0 = 0
        for shape in shapes:
            idx1 = idx0 + int(np.prod(shape))
            stds_reshaped.append(np.reshape(stds[idx0:idx1], shape))
            idx0 = idx1
    return dict(zip(names, zip(modes, stds_reshaped)))