From 39b90f86a08c1a6c72eef6a8b59cb5dd2738f49c Mon Sep 17 00:00:00 2001 From: Tom McClintock Date: Fri, 23 Apr 2021 18:40:33 -0700 Subject: [PATCH 1/9] Added named parameter functionality to the ensemble sampler and added tests --- src/emcee/ensemble.py | 53 ++++++++++- src/emcee/tests/unit/test_ensemble.py | 129 ++++++++++++++++++++++++++ 2 files changed, 179 insertions(+), 3 deletions(-) create mode 100644 src/emcee/tests/unit/test_ensemble.py diff --git a/src/emcee/ensemble.py b/src/emcee/ensemble.py index 45cc4ba9..0bb48461 100644 --- a/src/emcee/ensemble.py +++ b/src/emcee/ensemble.py @@ -2,6 +2,8 @@ import warnings +from typing import Dict, List, Optional + import numpy as np from itertools import count @@ -76,6 +78,7 @@ def __init__( backend=None, vectorize=False, blobs_dtype=None, + parameter_names: Optional[List[str]] = None, # Deprecated... a=None, postargs=None, @@ -157,6 +160,28 @@ def __init__( # ``args`` and ``kwargs`` pickleable. self.log_prob_fn = _FunctionWrapper(log_prob_fn, args, kwargs) + # Save the parameter names + self.params_are_named: bool = parameter_names is not None + self.parameter_names: Optional[List[str]] = parameter_names + if self.params_are_named: + # Don't support vectorizing yet + msg = "named parameters with vectorization unsupported for now" + assert not self.vectorize, msg + + # Check for all named + msg = "name all parameters or set `parameter_names` to `None`" + assert len(parameter_names) == ndim, msg + + # Check for duplicate names + dupes = set() + uniq = [] + for name in parameter_names: + if name not in dupes: + uniq.append(name) + dupes.add(name) + msg = f"duplicate paramters: {dupes}" + assert len(uniq) == len(parameter_names), msg + @property def random_state(self): """ @@ -416,6 +441,10 @@ def compute_log_prob(self, coords): if np.any(np.isnan(p)): raise ValueError("At least one parameter value was NaN") + # If the parmaeters are named, then switch to dictionaries + if self.params_are_named: + p = 
ndarray_to_list_of_dicts(p, self.parameter_names) + # Run the log-probability calculations (optionally in parallel). if self.vectorize: results = self.log_prob_fn(p) @@ -428,7 +457,7 @@ def compute_log_prob(self, coords): else: map_func = map results = list( - map_func(self.log_prob_fn, (p[i] for i in range(len(p)))) + map_func(self.log_prob_fn, p) ) try: @@ -557,8 +586,8 @@ class _FunctionWrapper(object): def __init__(self, f, args, kwargs): self.f = f - self.args = [] if args is None else args - self.kwargs = {} if kwargs is None else kwargs + self.args = args or [] + self.kwargs = kwargs or {} def __call__(self, x): try: @@ -605,3 +634,21 @@ def _scaled_cond(a): return np.inf c = b / bsum return np.linalg.cond(c.astype(float)) + +def ndarray_to_list_of_dicts( + x: np.ndarray, + keys: List[str] +) -> List[Dict[str, np.number]]: + """ + A helper function to convert a ``np.ndarray`` into a list + of dictionaries of parameters. Used when parameters are named. + + Args: + x (np.ndarray): parameter array of shape ``(N, n_dim)``, where + ``N`` is an integer + keys (List[str]): names of the parameters to use as dictionary keys + + Returns: + list of dictionaries of parameters + """ + return [{key: xi[i] for i, key in enumerate(keys)} for xi in x] diff --git a/src/emcee/tests/unit/test_ensemble.py b/src/emcee/tests/unit/test_ensemble.py new file mode 100644 index 00000000..5bd145d4 --- /dev/null +++ b/src/emcee/tests/unit/test_ensemble.py @@ -0,0 +1,129 @@ +""" +Unit tests of some functionality in ensemble.py when the parameters are named +""" +import string +from unittest import TestCase + +import numpy as np +import pytest +from scipy.stats import norm + +from emcee.ensemble import EnsembleSampler, ndarray_to_list_of_dicts + + +class TestNP2ListOfDicts(TestCase): + def test_ndarray_to_list_of_dicts(self): + # Try different numbers of keys + for n_keys in [1, 2, 10, 26]: + keys = list(string.ascii_lowercase[:n_keys]) + key_set = set(keys) + # Try different number of 
walker/procs + for N in [1, 2, 3, 10, 100]: + x = np.random.rand(N, n_keys) + + LOD = ndarray_to_list_of_dicts(x, keys) + assert len(LOD) == N, "need 1 dict per row" + for i, dct in enumerate(LOD): + assert dct.keys() == key_set, "keys are missing" + for j, key in enumerate(keys): + assert dct[key] == x[i, j], f"wrong value at {(i, j)}" + + +class TestNamedParameters(TestCase): + """ + Test that a keyword-based log-probability function instead of + a positional. + """ + + # Keyword based lnpdf + def lnpdf(self, pars) -> np.float64: + mean = pars["mean"] + var = pars["var"] + if var <= 0: + return -np.inf + return norm.logpdf(self.x, loc=mean, scale=np.sqrt(var)).sum() + + def lnpdf_mixture(self, pars) -> np.float64: + mean1 = pars["mean1"] + var1 = pars["var1"] + mean2 = pars["mean2"] + var2 = pars["var2"] + if var1 <= 0 or var2 <= 0: + return -np.inf + return ( + norm.logpdf(self.x, loc=mean1, scale=np.sqrt(var1)).sum() + + norm.logpdf(self.x + 3, loc=mean2, scale=np.sqrt(var2)).sum() + ) + + def setUp(self): + # Draw some data from a unit Gaussian + self.x = norm.rvs(size=100) + self.names = ["mean", "var"] + + def test_named_parameters(self): + sampler = EnsembleSampler( + nwalkers=10, + ndim=len(self.names), + log_prob_fn=self.lnpdf, + parameter_names=self.names, + ) + assert sampler.params_are_named + assert sampler.parameter_names == self.names + + def test_asserts(self): + # ndim name mismatch + with pytest.raises(AssertionError): + _ = EnsembleSampler( + nwalkers=10, + ndim=len(self.names) - 1, + log_prob_fn=self.lnpdf, + parameter_names=self.names, + ) + + # duplicate names + with pytest.raises(AssertionError): + _ = EnsembleSampler( + nwalkers=10, + ndim=3, + log_prob_fn=self.lnpdf, + parameter_names=["a", "b", "a"], + ) + + # vectorize turned on + with pytest.raises(AssertionError): + _ = EnsembleSampler( + nwalkers=10, + ndim=len(self.names), + log_prob_fn=self.lnpdf, + parameter_names=self.names, + vectorize=True, + ) + + def 
test_compute_log_prob(self): + # Try different numbers of walkers + for N in [4, 8, 10]: + sampler = EnsembleSampler( + nwalkers=N, + ndim=len(self.names), + log_prob_fn=self.lnpdf, + parameter_names=self.names, + ) + coords = np.random.rand(N, len(self.names)) + lnps, _ = sampler.compute_log_prob(coords) + assert len(lnps) == N + assert lnps.dtype == np.float64 + + def test_compute_log_prob_mixture(self): + names = ["mean1", "var1", "mean2", "var2"] + # Try different numbers of walkers + for N in [8, 10, 20]: + sampler = EnsembleSampler( + nwalkers=N, + ndim=len(names), + log_prob_fn=self.lnpdf_mixture, + parameter_names=names, + ) + coords = np.random.rand(N, len(names)) + lnps, _ = sampler.compute_log_prob(coords) + assert len(lnps) == N + assert lnps.dtype == np.float64 From eb39ccc2647b8b23126bf9a7c2b07aca73e58992 Mon Sep 17 00:00:00 2001 From: Tom McClintock Date: Sat, 24 Apr 2021 15:02:21 -0700 Subject: [PATCH 2/9] Made an error message more informative --- src/emcee/ensemble.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/emcee/ensemble.py b/src/emcee/ensemble.py index 0bb48461..43140f76 100644 --- a/src/emcee/ensemble.py +++ b/src/emcee/ensemble.py @@ -276,8 +276,9 @@ def sample( raise ValueError("'store' must be False when 'iterations' is None") # Interpret the input as a walker state and check the dimensions. 
state = State(initial_state, copy=True) - if np.shape(state.coords) != (self.nwalkers, self.ndim): - raise ValueError("incompatible input dimensions") + state_shape = np.shape(state.coords) + if state_shape != (self.nwalkers, self.ndim): + raise ValueError(f"incompatible input dimensions {state_shape}") if (not skip_initial_state_check) and ( not walkers_independent(state.coords) ): From f9aa4adf96d656aeffb091c784209330602bc195 Mon Sep 17 00:00:00 2001 From: Tom McClintock Date: Sat, 24 Apr 2021 15:02:31 -0700 Subject: [PATCH 3/9] Made a test for run_mcmc --- src/emcee/tests/unit/test_ensemble.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/src/emcee/tests/unit/test_ensemble.py b/src/emcee/tests/unit/test_ensemble.py index 5bd145d4..493f99af 100644 --- a/src/emcee/tests/unit/test_ensemble.py +++ b/src/emcee/tests/unit/test_ensemble.py @@ -127,3 +127,19 @@ def test_compute_log_prob_mixture(self): lnps, _ = sampler.compute_log_prob(coords) assert len(lnps) == N assert lnps.dtype == np.float64 + + def test_run_mcmc(self): + # Sort of an integration test + n_walkers = 4 + sampler = EnsembleSampler( + nwalkers=n_walkers, + ndim=len(self.names), + log_prob_fn=self.lnpdf, + parameter_names=self.names, + ) + guess = np.random.rand(n_walkers, len(self.names)) + n_steps = 50 + results = sampler.run_mcmc(guess, n_steps) + assert results.coords.shape == (n_walkers, len(self.names)) + chain = sampler.chain + assert chain.shape == (n_walkers, n_steps, len(self.names)) From 20e67f0679f8b120cb4b53c5012338e17ba751e1 Mon Sep 17 00:00:00 2001 From: Tom McClintock Date: Mon, 26 Apr 2021 10:17:57 -0700 Subject: [PATCH 4/9] Removed scipy from the test_ensemble.py unittests --- src/emcee/tests/unit/test_ensemble.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/src/emcee/tests/unit/test_ensemble.py b/src/emcee/tests/unit/test_ensemble.py index 493f99af..9b9fe732 100644 --- a/src/emcee/tests/unit/test_ensemble.py +++ 
b/src/emcee/tests/unit/test_ensemble.py @@ -6,7 +6,6 @@ import numpy as np import pytest -from scipy.stats import norm from emcee.ensemble import EnsembleSampler, ndarray_to_list_of_dicts @@ -41,7 +40,7 @@ def lnpdf(self, pars) -> np.float64: var = pars["var"] if var <= 0: return -np.inf - return norm.logpdf(self.x, loc=mean, scale=np.sqrt(var)).sum() + return -0.5 * ((mean - self.x)**2 / var + np.log(2 * np.pi * var)).sum() def lnpdf_mixture(self, pars) -> np.float64: mean1 = pars["mean1"] @@ -50,14 +49,14 @@ def lnpdf_mixture(self, pars) -> np.float64: var2 = pars["var2"] if var1 <= 0 or var2 <= 0: return -np.inf - return ( - norm.logpdf(self.x, loc=mean1, scale=np.sqrt(var1)).sum() - + norm.logpdf(self.x + 3, loc=mean2, scale=np.sqrt(var2)).sum() - ) + return -0.5 * ( + (mean1 - self.x)**2 / var1 + np.log(2 * np.pi * var1) + + (mean2 - self.x - 3)**2 / var2 + np.log(2 * np.pi * var2) + ).sum() def setUp(self): # Draw some data from a unit Gaussian - self.x = norm.rvs(size=100) + self.x = np.random.randn(100) self.names = ["mean", "var"] def test_named_parameters(self): From 4739fbf4d91edfbd5ec84089d31c8b863beef477 Mon Sep 17 00:00:00 2001 From: Tom McClintock Date: Thu, 29 Apr 2021 08:12:52 -0700 Subject: [PATCH 5/9] Docstring and updated typing for parameter_names --- src/emcee/ensemble.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/emcee/ensemble.py b/src/emcee/ensemble.py index 43140f76..f498ea1c 100644 --- a/src/emcee/ensemble.py +++ b/src/emcee/ensemble.py @@ -2,7 +2,7 @@ import warnings -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Union import numpy as np from itertools import count @@ -63,6 +63,10 @@ class EnsembleSampler(object): to accept a list of position vectors instead of just one. Note that ``pool`` will be ignored if this is ``True``. 
(default: ``False``) + parameter_names (Optional[Union[List[str], Dict[str, List[int]]]]): + names of individual parameters or groups of parameters. If + specified, the ``log_prob_fn`` will receive a dictionary of + parameters, rather than a ``np.ndarray``. """ @@ -78,7 +82,7 @@ def __init__( backend=None, vectorize=False, blobs_dtype=None, - parameter_names: Optional[List[str]] = None, + parameter_names: Optional[Union[Dict[str, int], List[str]]] = None, # Deprecated... a=None, postargs=None, From 8ee878a40a4ebe19aa452548eceef63729278417 Mon Sep 17 00:00:00 2001 From: Tom McClintock Date: Thu, 29 Apr 2021 08:44:49 -0700 Subject: [PATCH 6/9] Working on separate logic for dict vs list parameter names --- src/emcee/ensemble.py | 33 +++++++++++++++++++-------------- 1 file changed, 19 insertions(+), 14 deletions(-) diff --git a/src/emcee/ensemble.py b/src/emcee/ensemble.py index f498ea1c..c3ed0ade 100644 --- a/src/emcee/ensemble.py +++ b/src/emcee/ensemble.py @@ -168,23 +168,28 @@ def __init__( self.params_are_named: bool = parameter_names is not None self.parameter_names: Optional[List[str]] = parameter_names if self.params_are_named: + assert isinstance(parameter_names, (list, dict)) + # Don't support vectorizing yet msg = "named parameters with vectorization unsupported for now" assert not self.vectorize, msg - - # Check for all named - msg = "name all parameters or set `parameter_names` to `None`" - assert len(parameter_names) == ndim, msg - - # Check for duplicate names - dupes = set() - uniq = [] - for name in parameter_names: - if name not in dupes: - uniq.append(name) - dupes.add(name) - msg = f"duplicate paramters: {dupes}" - assert len(uniq) == len(parameter_names), msg + + if isinstance(parameter_names, list): + # Check for all named + msg = "name all parameters or set `parameter_names` to `None`" + assert len(parameter_names) == ndim, msg + + # Check for duplicate names + dupes = set() + uniq = [] + for name in parameter_names: + if name not in dupes: + 
uniq.append(name) + dupes.add(name) + msg = f"duplicate paramters: {dupes}" + assert len(uniq) == len(parameter_names), msg + else: + pass # TODO @property def random_state(self): From 559a877d1e0dc31c808b961f26cc4f44b42c0054 Mon Sep 17 00:00:00 2001 From: Tom McClintock Date: Fri, 30 Apr 2021 17:38:29 -0700 Subject: [PATCH 7/9] Added functionality for the parameter_names to be a dictionary of either integers or lists of integers --- src/emcee/ensemble.py | 65 +++++++++++++++++++++++++++---------------- 1 file changed, 41 insertions(+), 24 deletions(-) diff --git a/src/emcee/ensemble.py b/src/emcee/ensemble.py index c3ed0ade..a5149121 100644 --- a/src/emcee/ensemble.py +++ b/src/emcee/ensemble.py @@ -166,7 +166,6 @@ def __init__( # Save the parameter names self.params_are_named: bool = parameter_names is not None - self.parameter_names: Optional[List[str]] = parameter_names if self.params_are_named: assert isinstance(parameter_names, (list, dict)) @@ -174,22 +173,39 @@ def __init__( msg = "named parameters with vectorization unsupported for now" assert not self.vectorize, msg + # Check for duplicate names + dupes = set() + uniq = [] + for name in parameter_names: + if name not in dupes: + uniq.append(name) + dupes.add(name) + msg = f"duplicate paramters: {dupes}" + assert len(uniq) == len(parameter_names), msg + if isinstance(parameter_names, list): # Check for all named msg = "name all parameters or set `parameter_names` to `None`" assert len(parameter_names) == ndim, msg - - # Check for duplicate names - dupes = set() - uniq = [] - for name in parameter_names: - if name not in dupes: - uniq.append(name) - dupes.add(name) - msg = f"duplicate paramters: {dupes}" - assert len(uniq) == len(parameter_names), msg - else: - pass # TODO + # Convert a list to a dict + parameter_names: Dict[str, int] = { + name: i for i, name in enumerate(parameter_names) + } + + # Check not too many names + msg = "too many names" + assert len(parameter_names) <= ndim, msg + + # Check all 
indices appear + values = [ + v if isinstance(v, list) else [v] + for v in parameter_names.values() + ] + values = [item for sublist in values for item in sublist] + values = set(values) + msg = f"not all values appear -- set should be 0 to {ndim-1}" + assert values == set(np.arange(ndim)), msg + self.parameter_names = parameter_names @property def random_state(self): @@ -466,9 +482,7 @@ def compute_log_prob(self, coords): map_func = self.pool.map else: map_func = map - results = list( - map_func(self.log_prob_fn, p) - ) + results = list(map_func(self.log_prob_fn, p)) try: log_prob = np.array([float(l[0]) for l in results]) @@ -483,8 +497,9 @@ def compute_log_prob(self, coords): else: try: with warnings.catch_warnings(record=True): - warnings.simplefilter("error", - np.VisibleDeprecationWarning) + warnings.simplefilter( + "error", np.VisibleDeprecationWarning + ) try: dt = np.atleast_1d(blob[0]).dtype except Warning: @@ -494,7 +509,8 @@ def compute_log_prob(self, coords): "placed in an object array. Numpy has " "deprecated this automatic detection, so " "please specify " - "blobs_dtype=np.dtype('object')") + "blobs_dtype=np.dtype('object')" + ) dt = np.dtype("object") except ValueError: dt = np.dtype("object") @@ -645,10 +661,11 @@ def _scaled_cond(a): c = b / bsum return np.linalg.cond(c.astype(float)) + def ndarray_to_list_of_dicts( - x: np.ndarray, - keys: List[str] -) -> List[Dict[str, np.number]]: + x: np.ndarray, + key_map: Dict[str, Union[int, List[int]]], +) -> List[Dict[str, Union[np.number, np.ndarray]]]: """ A helper function to convert a ``np.ndarray`` into a list of dictionaries of parameters. Used when parameters are named. 
@@ -656,9 +673,9 @@ def ndarray_to_list_of_dicts( Args: x (np.ndarray): parameter array of shape ``(N, n_dim)``, where ``N`` is an integer - keys (List[str]): names of the parameters to use as dictionary keys + key_map (Dict[str, Union[int, List[int]]]): maps each parameter name to the index (or list of indices) it selects from each row of ``x`` Returns: list of dictionaries of parameters """ - return [{key: xi[i] for i, key in enumerate(keys)} for xi in x] + return [{key: xi[val] for key, val in key_map.items()} for xi in x] From a5b3f10841400d79514176fb7e1220b4951773f1 Mon Sep 17 00:00:00 2001 From: Tom McClintock Date: Fri, 30 Apr 2021 17:38:53 -0700 Subject: [PATCH 8/9] Added functionality for the parameter_names to be a dictionary of either integers or lists of integers. Also ran isort and black --- src/emcee/ensemble.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/emcee/ensemble.py b/src/emcee/ensemble.py index a5149121..24faba3c 100644 --- a/src/emcee/ensemble.py +++ b/src/emcee/ensemble.py @@ -1,11 +1,10 @@ # -*- coding: utf-8 -*- import warnings - +from itertools import count from typing import Dict, List, Optional, Union import numpy as np -from itertools import count from .backends import Backend from .model import Model From 36f72c7134b1e65df909d4357c348d68761a1dde Mon Sep 17 00:00:00 2001 From: Tom McClintock Date: Fri, 30 Apr 2021 17:39:26 -0700 Subject: [PATCH 9/9] Test of the dictionary parameter name functionality --- src/emcee/tests/unit/test_ensemble.py | 54 +++++++++++++++++++++++---- 1 file changed, 47 insertions(+), 7 deletions(-) diff --git a/src/emcee/tests/unit/test_ensemble.py b/src/emcee/tests/unit/test_ensemble.py index 9b9fe732..f6f5ad27 100644 --- a/src/emcee/tests/unit/test_ensemble.py +++ b/src/emcee/tests/unit/test_ensemble.py @@ -16,11 +16,12 @@ def test_ndarray_to_list_of_dicts(self): for n_keys in [1, 2, 10, 26]: keys = list(string.ascii_lowercase[:n_keys]) key_set = set(keys) + key_dict = {key: i for i, key in enumerate(keys)} # Try different number of walker/procs for N in [1, 2, 3, 10, 100]: 
x = np.random.rand(N, n_keys) - LOD = ndarray_to_list_of_dicts(x, keys) + LOD = ndarray_to_list_of_dicts(x, key_dict) assert len(LOD) == N, "need 1 dict per row" for i, dct in enumerate(LOD): assert dct.keys() == key_set, "keys are missing" @@ -40,7 +41,9 @@ def lnpdf(self, pars) -> np.float64: var = pars["var"] if var <= 0: return -np.inf - return -0.5 * ((mean - self.x)**2 / var + np.log(2 * np.pi * var)).sum() + return ( + -0.5 * ((mean - self.x) ** 2 / var + np.log(2 * np.pi * var)).sum() + ) def lnpdf_mixture(self, pars) -> np.float64: mean1 = pars["mean1"] @@ -49,10 +52,32 @@ def lnpdf_mixture(self, pars) -> np.float64: var2 = pars["var2"] if var1 <= 0 or var2 <= 0: return -np.inf - return -0.5 * ( - (mean1 - self.x)**2 / var1 + np.log(2 * np.pi * var1) - + (mean2 - self.x - 3)**2 / var2 + np.log(2 * np.pi * var2) - ).sum() + return ( + -0.5 + * ( + (mean1 - self.x) ** 2 / var1 + + np.log(2 * np.pi * var1) + + (mean2 - self.x - 3) ** 2 / var2 + + np.log(2 * np.pi * var2) + ).sum() + ) + + def lnpdf_mixture_grouped(self, pars) -> np.float64: + mean1, mean2 = pars["means"] + var1, var2 = pars["vars"] + const = pars["constant"] + if var1 <= 0 or var2 <= 0: + return -np.inf + return ( + -0.5 + * ( + (mean1 - self.x) ** 2 / var1 + + np.log(2 * np.pi * var1) + + (mean2 - self.x - 3) ** 2 / var2 + + np.log(2 * np.pi * var2) + ).sum() + + const + ) def setUp(self): # Draw some data from a unit Gaussian @@ -67,7 +92,7 @@ def test_named_parameters(self): parameter_names=self.names, ) assert sampler.params_are_named - assert sampler.parameter_names == self.names + assert list(sampler.parameter_names.keys()) == self.names def test_asserts(self): # ndim name mismatch @@ -127,6 +152,21 @@ def test_compute_log_prob_mixture(self): assert len(lnps) == N assert lnps.dtype == np.float64 + def test_compute_log_prob_mixture_grouped(self): + names = {"means": [0, 1], "vars": [2, 3], "constant": 4} + # Try different numbers of walkers + for N in [8, 10, 20]: + sampler = 
EnsembleSampler( + nwalkers=N, + ndim=5, + log_prob_fn=self.lnpdf_mixture_grouped, + parameter_names=names, + ) + coords = np.random.rand(N, 5) + lnps, _ = sampler.compute_log_prob(coords) + assert len(lnps) == N + assert lnps.dtype == np.float64 + def test_run_mcmc(self): # Sort of an integration test n_walkers = 4