Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for named parameters #386

Merged
merged 9 commits into from
Jun 23, 2021
95 changes: 84 additions & 11 deletions src/emcee/ensemble.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
# -*- coding: utf-8 -*-

import warnings
from itertools import count
from typing import Dict, List, Optional, Union

import numpy as np
from itertools import count

from .backends import Backend
from .model import Model
Expand Down Expand Up @@ -61,6 +62,10 @@ class EnsembleSampler(object):
to accept a list of position vectors instead of just one. Note
that ``pool`` will be ignored if this is ``True``.
(default: ``False``)
parameter_names (Optional[Union[List[str], Dict[str, List[int]]]]):
names of individual parameters or groups of parameters. If
        specified, the ``log_prob_fn`` will receive a dictionary of
parameters, rather than a ``np.ndarray``.

"""

Expand All @@ -76,6 +81,7 @@ def __init__(
backend=None,
vectorize=False,
blobs_dtype=None,
parameter_names: Optional[Union[Dict[str, int], List[str]]] = None,
# Deprecated...
a=None,
postargs=None,
Expand Down Expand Up @@ -157,6 +163,49 @@ def __init__(
# ``args`` and ``kwargs`` pickleable.
self.log_prob_fn = _FunctionWrapper(log_prob_fn, args, kwargs)

# Save the parameter names
self.params_are_named: bool = parameter_names is not None
if self.params_are_named:
assert isinstance(parameter_names, (list, dict))

# Don't support vectorizing yet
msg = "named parameters with vectorization unsupported for now"
assert not self.vectorize, msg

# Check for duplicate names
dupes = set()
uniq = []
for name in parameter_names:
if name not in dupes:
uniq.append(name)
dupes.add(name)
msg = f"duplicate paramters: {dupes}"
assert len(uniq) == len(parameter_names), msg

if isinstance(parameter_names, list):
# Check for all named
msg = "name all parameters or set `parameter_names` to `None`"
assert len(parameter_names) == ndim, msg
# Convert a list to a dict
parameter_names: Dict[str, int] = {
name: i for i, name in enumerate(parameter_names)
}

# Check not too many names
msg = "too many names"
assert len(parameter_names) <= ndim, msg

# Check all indices appear
values = [
v if isinstance(v, list) else [v]
for v in parameter_names.values()
]
values = [item for sublist in values for item in sublist]
values = set(values)
msg = f"not all values appear -- set should be 0 to {ndim-1}"
assert values == set(np.arange(ndim)), msg
self.parameter_names = parameter_names

@property
def random_state(self):
"""
Expand Down Expand Up @@ -251,8 +300,9 @@ def sample(
raise ValueError("'store' must be False when 'iterations' is None")
# Interpret the input as a walker state and check the dimensions.
state = State(initial_state, copy=True)
if np.shape(state.coords) != (self.nwalkers, self.ndim):
raise ValueError("incompatible input dimensions")
state_shape = np.shape(state.coords)
if state_shape != (self.nwalkers, self.ndim):
raise ValueError(f"incompatible input dimensions {state_shape}")
if (not skip_initial_state_check) and (
not walkers_independent(state.coords)
):
Expand Down Expand Up @@ -416,6 +466,10 @@ def compute_log_prob(self, coords):
if np.any(np.isnan(p)):
raise ValueError("At least one parameter value was NaN")

        # If the parameters are named, then switch to dictionaries
if self.params_are_named:
p = ndarray_to_list_of_dicts(p, self.parameter_names)

# Run the log-probability calculations (optionally in parallel).
if self.vectorize:
results = self.log_prob_fn(p)
Expand All @@ -427,9 +481,7 @@ def compute_log_prob(self, coords):
map_func = self.pool.map
else:
map_func = map
results = list(
map_func(self.log_prob_fn, (p[i] for i in range(len(p))))
)
results = list(map_func(self.log_prob_fn, p))

try:
log_prob = np.array([float(l[0]) for l in results])
Expand All @@ -444,8 +496,9 @@ def compute_log_prob(self, coords):
else:
try:
with warnings.catch_warnings(record=True):
warnings.simplefilter("error",
np.VisibleDeprecationWarning)
warnings.simplefilter(
"error", np.VisibleDeprecationWarning
)
try:
dt = np.atleast_1d(blob[0]).dtype
except Warning:
Expand All @@ -455,7 +508,8 @@ def compute_log_prob(self, coords):
"placed in an object array. Numpy has "
"deprecated this automatic detection, so "
"please specify "
"blobs_dtype=np.dtype('object')")
"blobs_dtype=np.dtype('object')"
)
dt = np.dtype("object")
except ValueError:
dt = np.dtype("object")
Expand Down Expand Up @@ -557,8 +611,8 @@ class _FunctionWrapper(object):

def __init__(self, f, args, kwargs):
self.f = f
self.args = [] if args is None else args
self.kwargs = {} if kwargs is None else kwargs
self.args = args or []
self.kwargs = kwargs or {}

def __call__(self, x):
try:
Expand Down Expand Up @@ -605,3 +659,22 @@ def _scaled_cond(a):
return np.inf
c = b / bsum
return np.linalg.cond(c.astype(float))


def ndarray_to_list_of_dicts(
    x: np.ndarray,
    key_map: Dict[str, Union[int, List[int]]],
) -> List[Dict[str, Union[np.number, np.ndarray]]]:
    """
    Convert a parameter array into one dictionary of named parameters
    per row. Used when parameters are named.

    Args:
        x (np.ndarray): parameter array of shape ``(N, n_dim)``, where
            ``N`` is an integer
        key_map (Dict[str, Union[int, List[int]]]): mapping from a
            parameter name to the column index (or list of column
            indices, for grouped parameters) of ``x`` holding its value

    Returns:
        list of dictionaries of parameters, one per row of ``x``; a
        scalar for an ``int`` index, a sub-array for a list of indices
    """
    result = []
    for row in x:
        result.append({name: row[idx] for name, idx in key_map.items()})
    return result
184 changes: 184 additions & 0 deletions src/emcee/tests/unit/test_ensemble.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
"""
Unit tests of some functionality in ensemble.py when the parameters are named
"""
import string
from unittest import TestCase

import numpy as np
import pytest

from emcee.ensemble import EnsembleSampler, ndarray_to_list_of_dicts


class TestNP2ListOfDicts(TestCase):
    def test_ndarray_to_list_of_dicts(self):
        # Sweep over several key-set sizes, drawing names from the
        # lowercase alphabet so each run has distinct, ordered keys.
        for n_keys in [1, 2, 10, 26]:
            names = list(string.ascii_lowercase[:n_keys])
            expected_keys = set(names)
            mapping = {name: idx for idx, name in enumerate(names)}
            # Sweep over several row counts (walkers/processes).
            for N in [1, 2, 3, 10, 100]:
                data = np.random.rand(N, n_keys)

                dicts = ndarray_to_list_of_dicts(data, mapping)
                assert len(dicts) == N, "need 1 dict per row"
                for i, row_dict in enumerate(dicts):
                    assert row_dict.keys() == expected_keys, "keys are missing"
                    for j, name in enumerate(names):
                        assert row_dict[name] == data[i, j], f"wrong value at {(i, j)}"


class TestNamedParameters(TestCase):
    """
    Test that a keyword-based (named-parameter) log-probability function
    can be used in place of a positional one.
    """

    # Keyword based lnpdf: Gaussian log-likelihood of self.x, reading
    # parameters from a dict rather than a positional array.
    def lnpdf(self, pars) -> np.float64:
        mean = pars["mean"]
        var = pars["var"]
        # Reject non-positive variance.
        if var <= 0:
            return -np.inf
        return (
            -0.5 * ((mean - self.x) ** 2 / var + np.log(2 * np.pi * var)).sum()
        )

    # Same idea with four individually-named parameters (two Gaussian
    # components summed in log space).
    def lnpdf_mixture(self, pars) -> np.float64:
        mean1 = pars["mean1"]
        var1 = pars["var1"]
        mean2 = pars["mean2"]
        var2 = pars["var2"]
        if var1 <= 0 or var2 <= 0:
            return -np.inf
        return (
            -0.5
            * (
                (mean1 - self.x) ** 2 / var1
                + np.log(2 * np.pi * var1)
                + (mean2 - self.x - 3) ** 2 / var2
                + np.log(2 * np.pi * var2)
            ).sum()
        )

    # Grouped variant: "means" and "vars" each map to a list of indices,
    # so each dict value unpacks to two scalars; "constant" is a scalar.
    def lnpdf_mixture_grouped(self, pars) -> np.float64:
        mean1, mean2 = pars["means"]
        var1, var2 = pars["vars"]
        const = pars["constant"]
        if var1 <= 0 or var2 <= 0:
            return -np.inf
        return (
            -0.5
            * (
                (mean1 - self.x) ** 2 / var1
                + np.log(2 * np.pi * var1)
                + (mean2 - self.x - 3) ** 2 / var2
                + np.log(2 * np.pi * var2)
            ).sum()
            + const
        )

    def setUp(self):
        # Draw some data from a unit Gaussian
        self.x = np.random.randn(100)
        # Default parameter names used by most tests below.
        self.names = ["mean", "var"]

    def test_named_parameters(self):
        # A list of names should be accepted and stored as a dict
        # (name -> index) in insertion order.
        sampler = EnsembleSampler(
            nwalkers=10,
            ndim=len(self.names),
            log_prob_fn=self.lnpdf,
            parameter_names=self.names,
        )
        assert sampler.params_are_named
        assert list(sampler.parameter_names.keys()) == self.names

    def test_asserts(self):
        # ndim name mismatch
        with pytest.raises(AssertionError):
            _ = EnsembleSampler(
                nwalkers=10,
                ndim=len(self.names) - 1,
                log_prob_fn=self.lnpdf,
                parameter_names=self.names,
            )

        # duplicate names
        with pytest.raises(AssertionError):
            _ = EnsembleSampler(
                nwalkers=10,
                ndim=3,
                log_prob_fn=self.lnpdf,
                parameter_names=["a", "b", "a"],
            )

        # vectorize turned on (named parameters are documented as
        # unsupported together with vectorization)
        with pytest.raises(AssertionError):
            _ = EnsembleSampler(
                nwalkers=10,
                ndim=len(self.names),
                log_prob_fn=self.lnpdf,
                parameter_names=self.names,
                vectorize=True,
            )

    def test_compute_log_prob(self):
        # Try different numbers of walkers
        for N in [4, 8, 10]:
            sampler = EnsembleSampler(
                nwalkers=N,
                ndim=len(self.names),
                log_prob_fn=self.lnpdf,
                parameter_names=self.names,
            )
            coords = np.random.rand(N, len(self.names))
            lnps, _ = sampler.compute_log_prob(coords)
            # One finite float log-probability per walker.
            assert len(lnps) == N
            assert lnps.dtype == np.float64

    def test_compute_log_prob_mixture(self):
        # Same as above, but with four individually-named parameters.
        names = ["mean1", "var1", "mean2", "var2"]
        # Try different numbers of walkers
        for N in [8, 10, 20]:
            sampler = EnsembleSampler(
                nwalkers=N,
                ndim=len(names),
                log_prob_fn=self.lnpdf_mixture,
                parameter_names=names,
            )
            coords = np.random.rand(N, len(names))
            lnps, _ = sampler.compute_log_prob(coords)
            assert len(lnps) == N
            assert lnps.dtype == np.float64

    def test_compute_log_prob_mixture_grouped(self):
        # Grouped names: a dict mapping each name to one index or a
        # list of indices, which must together cover 0..ndim-1.
        names = {"means": [0, 1], "vars": [2, 3], "constant": 4}
        # Try different numbers of walkers
        for N in [8, 10, 20]:
            sampler = EnsembleSampler(
                nwalkers=N,
                ndim=5,
                log_prob_fn=self.lnpdf_mixture_grouped,
                parameter_names=names,
            )
            coords = np.random.rand(N, 5)
            lnps, _ = sampler.compute_log_prob(coords)
            assert len(lnps) == N
            assert lnps.dtype == np.float64

    def test_run_mcmc(self):
        # Sort of an integration test: run a short chain end-to-end with
        # named parameters and check the output shapes.
        n_walkers = 4
        sampler = EnsembleSampler(
            nwalkers=n_walkers,
            ndim=len(self.names),
            log_prob_fn=self.lnpdf,
            parameter_names=self.names,
        )
        guess = np.random.rand(n_walkers, len(self.names))
        n_steps = 50
        results = sampler.run_mcmc(guess, n_steps)
        assert results.coords.shape == (n_walkers, len(self.names))
        chain = sampler.chain
        assert chain.shape == (n_walkers, n_steps, len(self.names))