-
Notifications
You must be signed in to change notification settings - Fork 429
/
utils.py
34 lines (24 loc) · 923 Bytes
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
# Names exported by ``from <this module> import *``.
__all__ = ["sample_ball", "MH_proposal_axisaligned"]
import numpy as np
def sample_ball(p0, std, size=1):
    """
    Produce a ball of walkers around an initial parameter value.

    Each walker is ``p0 + std * z`` where ``z`` is a vector of independent
    standard-normal deviates drawn from NumPy's global random state.

    :param p0: The initial parameter value (a sequence of length ``ndim``).

    :param std: The axis-aligned standard deviation (a sequence with the
        same length as ``p0``).

    :param size: The number of samples to produce. ``0`` yields an empty
        ``(0, ndim)`` array.

    :returns: An array of shape ``(size, ndim)`` with one walker per row.

    :raises ValueError: If ``p0`` and ``std`` have different lengths.

    """
    ndim = len(p0)
    # Explicit check instead of ``assert`` so the validation still runs
    # when Python is started with ``-O``.
    if len(std) != ndim:
        raise ValueError(
            "p0 and std must have the same length "
            "(got {0} and {1})".format(ndim, len(std))
        )

    # Draw every deviate in one call; broadcasting applies p0 and std to
    # each row, so no Python-level loop or stacking is needed.
    deviates = np.random.normal(size=(size, ndim))
    return np.asarray(p0) + np.asarray(std) * deviates
class MH_proposal_axisaligned(object):
    """
    A Metropolis-Hastings proposal, with axis-aligned Gaussian steps,
    for convenient use as the ``mh_proposal`` option to
    :func:`EnsembleSampler.sample` .

    :param stdev: The per-parameter standard deviation of the Gaussian
        step (a sequence with one entry per parameter).

    """

    def __init__(self, stdev):
        self.stdev = stdev

    def __call__(self, X):
        """
        Propose new positions by adding Gaussian noise to ``X``.

        :param X: A two-dimensional array; its second axis must have
            ``len(stdev)`` entries.

        :returns: A new array with the same shape as ``X``; ``X`` itself
            is not modified.

        :raises ValueError: If the second axis of ``X`` does not match
            ``len(stdev)``.

        """
        _, npar = X.shape
        # Explicit check instead of ``assert`` so a wrong-length stdev is
        # still rejected (rather than silently broadcast) under ``-O``.
        if len(self.stdev) != npar:
            raise ValueError(
                "stdev has length {0} but X has {1} parameters".format(
                    len(self.stdev), npar
                )
            )
        return X + self.stdev * np.random.normal(size=X.shape)