-
Notifications
You must be signed in to change notification settings - Fork 429
/
sampler.py
163 lines (124 loc) · 4.25 KB
/
sampler.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
The base sampler class implementing various helpful functions.
"""
from __future__ import (division, print_function, absolute_import,
unicode_literals)
__all__ = ["Sampler"]
import numpy as np
try:
import acor
acor = acor
except ImportError:
acor = None
class Sampler(object):
    """
    An abstract sampler object that implements various helper functions.

    This base class only keeps the bookkeeping counters and a private random
    number generator. It never assigns ``_chain`` or ``_lnprob`` itself, so
    the :attr:`chain`, :attr:`flatchain`, :attr:`lnprobability` and
    :attr:`acor` properties only work once a subclass has set those
    attributes. Subclasses must also implement :func:`sample`.

    :param dim:
        The number of dimensions in the parameter space.

    :param lnprobfn:
        A function that takes a vector in the parameter space as input and
        returns the natural logarithm of the posterior probability for that
        position.

    :param args: (optional)
        A list of extra arguments for ``lnprobfn``. ``lnprobfn`` will be
        called with the sequence ``lnprobfn(p, *args)``. When omitted, each
        sampler gets its own new empty list.

    """

    def __init__(self, dim, lnprobfn, args=None):
        self.dim = dim
        self.lnprobfn = lnprobfn

        # ``None`` stands in for "no extra arguments" so that instances never
        # share one mutable default list. A caller-supplied sequence is
        # stored as-is (not copied).
        self.args = [] if args is None else args

        # This is a random number generator that we can easily set the state
        # of without affecting the numpy-wide generator.
        self._random = np.random.mtrand.RandomState()

        self.reset()

    @property
    def random_state(self):
        """
        The state of the internal random number generator. In practice, it's
        the result of calling ``get_state()`` on a
        ``numpy.random.mtrand.RandomState`` object. You can try to set this
        property but be warned that if you do this and it fails, it will do
        so silently.

        """
        return self._random.get_state()

    @random_state.setter  # NOQA
    def random_state(self, state):
        """
        Try to set the state of the random number generator but fail silently
        if it doesn't work. Don't say I didn't warn you...

        """
        try:
            self._random.set_state(state)
        except Exception:
            # Deliberate best-effort: an unusable state (``None`` included)
            # leaves the generator untouched. ``Exception`` rather than a
            # bare ``except`` so KeyboardInterrupt/SystemExit still propagate.
            pass

    @property
    def acceptance_fraction(self):
        """
        The fraction of proposed steps that were accepted, computed as
        ``naccepted / iterations``. Both counters are zero right after
        :func:`reset`, so the value is only meaningful once at least one
        iteration has been recorded.

        """
        return self.naccepted / self.iterations

    @property
    def chain(self):
        """
        A pointer to the Markov chain (the ``_chain`` attribute).

        """
        return self._chain

    @property
    def flatchain(self):
        """
        Alias of ``chain`` provided for compatibility.

        """
        return self._chain

    @property
    def lnprobability(self):
        """
        A list of the log-probability values associated with each step in
        the chain (the ``_lnprob`` attribute).

        """
        return self._lnprob

    @property
    def acor(self):
        """
        The autocorrelation time of each parameter in the chain (length:
        ``dim``) as estimated by the ``acor`` module.

        :raises ImportError:
            If the optional ``acor`` module could not be imported.

        """
        # Inside this method ``acor`` is the module-level name (the imported
        # module, or ``None`` when the import failed), not this property.
        if acor is None:
            raise ImportError("acor")
        return acor.acor(self._chain.T)[0]

    def get_lnprob(self, p):
        """Return the log-probability at the given position."""
        return self.lnprobfn(p, *self.args)

    def reset(self):
        """
        Reset the bookkeeping parameters (``iterations`` and ``naccepted``)
        to zero. The base implementation does not touch the stored chain or
        log-probabilities.

        """
        self.iterations = 0
        self.naccepted = 0

    def clear_chain(self):
        """An alias for :func:`reset` kept for backwards compatibility."""
        return self.reset()

    def sample(self, *args, **kwargs):
        """
        Advance the chain. The base class has no sampling routine, so this
        always raises ``NotImplementedError``. :func:`run_mcmc` calls it as
        ``sample(pos0, lnprob0, rstate0, iterations=N, **kwargs)`` and
        iterates over the result.

        """
        raise NotImplementedError("The sampling routine must be implemented "
                                  "by subclasses")

    def run_mcmc(self, pos0, N, rstate0=None, lnprob0=None, **kwargs):
        """
        Iterate :func:`sample` for ``N`` iterations and return the result.

        :param pos0:
            The initial position vector.

        :param N:
            The number of steps to run.

        :param rstate0: (optional)
            The state of the random number generator. See the
            :func:`random_state` property for details.

        :param lnprob0: (optional)
            The log posterior probability at position ``pos0``. It is passed
            through to :func:`sample` unchanged.

        :param kwargs: (optional)
            Other parameters that are directly passed to :func:`sample`.

        :returns:
            The last value yielded by :func:`sample`, or ``None`` if it
            yielded nothing (for example when ``N`` is zero).

        """
        # Pre-bind the result so that an empty iteration returns ``None``
        # instead of failing with an unbound local name.
        results = None
        for results in self.sample(pos0, lnprob0, rstate0, iterations=N,
                                   **kwargs):
            pass
        return results