Skip to content

Commit

Permalink
SMC uses ModelPrior, ensure validity of proposals (#224)
Browse files Browse the repository at this point in the history
* SMC uses ModelPrior, ensure validity of proposals

* Address comments

* Change behaviour when size=None

* Address comment
  • Loading branch information
vuolleko committed Sep 6, 2017
1 parent 6d38633 commit 6346b7c
Show file tree
Hide file tree
Showing 6 changed files with 86 additions and 36 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ dev
- Renamed elfi.set_current_model to elfi.set_default_model
- Renamed elfi.get_current_model to elfi.get_default_model
- Improved performance when rerunning inference using stored data
- Changed SMC to use ModelPrior, which is also used to immediately reject invalid proposals

0.6.1 (2017-07-21)
------------------
Expand Down
2 changes: 1 addition & 1 deletion elfi/examples/bignk.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import scipy.stats as ss

import elfi
from elfi.examples.gnk import ss_order, ss_robust, ss_octile, euclidean_multidim
from elfi.examples.gnk import euclidean_multidim, ss_octile, ss_order, ss_robust

EPS = np.finfo(float).eps

Expand Down
24 changes: 9 additions & 15 deletions elfi/methods/parameter_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@

import elfi.client
import elfi.methods.mcmc as mcmc
import elfi.model.augmenter as augmenter
import elfi.visualization.interactive as visin
import elfi.visualization.visualization as vis
from elfi.loader import get_sub_seed
Expand Down Expand Up @@ -281,12 +280,12 @@ def iterate(self):
# Submit new batches if allowed
while self._allow_submit(self.batches.next_index):
next_batch = self.prepare_new_batch(self.batches.next_index)
logger.info("Submitting batch %d" % self.batches.next_index)
logger.debug("Submitting batch %d" % self.batches.next_index)
self.batches.submit(next_batch)

# Handle the next ready batch in succession
batch, batch_index = self.batches.wait_next()
logger.info('Received batch %d' % batch_index)
logger.debug('Received batch %d' % batch_index)
self.update(batch, batch_index)

@property
Expand Down Expand Up @@ -616,17 +615,10 @@ def __init__(self, model, discrepancy_name=None, output_names=None, **kwargs):
"""
model, discrepancy_name = self._resolve_model(model, discrepancy_name)

# Add the prior pdf nodes to the model
model = model.copy()
logpdf_name = augmenter.add_pdf_nodes(model, log=True)[0]

output_names = [discrepancy_name] + model.parameter_names + [logpdf_name] + \
(output_names or [])

super(SMC, self).__init__(model, output_names, **kwargs)

self._prior = ModelPrior(self.model)
self.discrepancy_name = discrepancy_name
self.prior_logpdf = logpdf_name
self.state['round'] = 0
self._populations = []
self._rejection = None
Expand Down Expand Up @@ -701,9 +693,10 @@ def prepare_new_batch(self, batch_index):
# Use the actual prior
return

# Sample from the proposal
params = GMDistribution.rvs(
*self._gm_params, size=self.batch_size, random_state=self._round_random_state)
# Sample from the proposal, condition on actual prior
params = GMDistribution.rvs(*self._gm_params, size=self.batch_size,
prior_logpdf=self._prior.logpdf,
random_state=self._round_random_state)

batch = arr2d_to_batch(params, self.parameter_names)
return batch
Expand Down Expand Up @@ -743,7 +736,8 @@ def _compute_weights_and_cov(self, pop):

if self._populations:
q_logpdf = GMDistribution.logpdf(params, *self._gm_params)
w = np.exp(pop.outputs[self.prior_logpdf] - q_logpdf)
p_logpdf = self._prior.logpdf(params)
w = np.exp(p_logpdf - q_logpdf)
else:
w = np.ones(pop.n_samples)

Expand Down
78 changes: 59 additions & 19 deletions elfi/methods/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,14 +135,14 @@ def pdf(cls, x, means, cov=1, weights=None):
Parameters
----------
x : array_like
scalar, 1d or 2d array of points where to evaluate, observations in rows
Scalar, 1d or 2d array of points where to evaluate, observations in rows
means : array_like
means of the Gaussian mixture components. It is assumed that means[0] contains
Means of the Gaussian mixture components. It is assumed that means[0] contains
the mean of the first gaussian component.
weights : array_like
1d array of weights of the gaussian mixture components
cov : array_like, float
a shared covariance matrix for the mixture components
A shared covariance matrix for the mixture components
"""
means, weights = cls._normalize_params(means, weights)
Expand Down Expand Up @@ -170,42 +170,82 @@ def logpdf(cls, x, means, cov=1, weights=None):
Parameters
----------
x : array_like
scalar, 1d or 2d array of points where to evaluate, observations in rows
Scalar, 1d or 2d array of points where to evaluate, observations in rows
means : array_like
means of the Gaussian mixture components. It is assumed that means[0] contains
Means of the Gaussian mixture components. It is assumed that means[0] contains
the mean of the first gaussian component.
weights : array_like
1d array of weights of the gaussian mixture components
cov : array_like, float
a shared covariance matrix for the mixture components
A shared covariance matrix for the mixture components
"""
return np.log(cls.pdf(x, means=means, cov=cov, weights=weights))

@classmethod
def rvs(cls, means, cov=1, weights=None, size=1, prior_logpdf=None, random_state=None):
    """Draw random variates from the distribution.

    Candidates are drawn by picking a mixture component according to `weights`
    and perturbing its mean with zero-mean Gaussian noise of covariance `cov`.
    If `prior_logpdf` is given, candidates for which it returns a non-finite
    value are discarded and redrawn until enough valid samples are collected.

    Parameters
    ----------
    means : array_like
        Means of the Gaussian mixture components
    cov : array_like, optional
        A shared covariance matrix for the mixture components
    weights : array_like, optional
        1d array of weights of the gaussian mixture components
    size : int or tuple or None, optional
        Number or shape of samples to draw (a single sample has the shape of `means`).
        If None, return one sample without an enclosing array.
    prior_logpdf : callable, optional
        Can be used to check validity of random variable. It is called with the
        array of candidate samples (observations in rows) and must return one
        value per candidate; candidates with a non-finite value are rejected.
    random_state : np.random.RandomState, optional

    Returns
    -------
    np.ndarray
        Array of shape ``size + means.shape[1:]`` (with an int `size` treated as
        a 1-tuple), or a single sample of shape ``means.shape[1:]`` if `size` is
        None.

    Notes
    -----
    The sampler never gives up: if `prior_logpdf` rejects every candidate the
    loop runs forever. A warning is logged once after 100 rounds.

    """
    random_state = random_state or np.random
    # NOTE(review): _normalize_params is defined elsewhere; it is assumed to
    # return `means` as an ndarray and `weights` normalised to sum to one.
    means, weights = cls._normalize_params(means, weights)

    # Resolve the requested shape; rejection sampling works on a flat batch of
    # n_total samples that is reshaped only when returning.
    if size is None:
        shape = (1,)
    elif np.isscalar(size):
        shape = (size,)
    else:
        shape = tuple(size)
    n_total = int(np.prod(shape))
    sample_shape = means.shape[1:]

    output = np.empty((n_total,) + sample_shape)

    n_accepted = 0
    trials = 0

    while n_accepted < n_total:
        n_left = n_total - n_accepted

        inds = random_state.choice(len(means), size=n_left, p=weights)
        centers = means[inds]
        perturb = ss.multivariate_normal.rvs(mean=means[0] * 0,
                                             cov=cov,
                                             random_state=random_state,
                                             size=n_left)
        # scipy squeezes singleton axes from its output (e.g. (n, 1) -> (n,)),
        # which would broadcast against (n, 1) centers into an (n, n) array.
        # Restore the shape of the centers before adding.
        x = centers + np.reshape(perturb, centers.shape)

        # check validity of x
        if prior_logpdf is not None:
            x = x[np.isfinite(prior_logpdf(x))]

        n_new = len(x)
        output[n_accepted:n_accepted + n_new] = x
        n_accepted += n_new

        trials += 1
        if trials == 100:
            logger.warning("SMC: It appears to be difficult to find enough valid proposals "
                           "with prior pdf > 0. ELFI will keep trying, but you may wish "
                           "to kill the process and adjust the model priors.")

    logger.debug('Needed %i trials to find %i valid samples.', trials, n_total)
    if size is None:
        return output[0]
    return output.reshape(shape + sample_shape)

@staticmethod
def _normalize_params(means, weights):
Expand Down
7 changes: 6 additions & 1 deletion tests/unit/test_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import pytest

import elfi
import elfi.examples.ma2 as exma2
from elfi.methods.parameter_inference import ParameterInference


Expand All @@ -18,7 +19,7 @@ def test_smc(ma2):
N = 1000
smc = elfi.SMC(ma2['d'], batch_size=20000)
res = smc.sample(N, thresholds=thresholds)
dens = res.populations[0].outputs[smc.prior_logpdf]
dens = smc._prior.logpdf(res.samples_array)
# Test that the density is uniform
assert np.allclose(dens, dens[0])

Expand All @@ -38,6 +39,10 @@ def test_smc(ma2):
res.sample_means_summary()
res.sample_means_summary(all=True)

# Ensure prior pdf > 0 for samples
assert np.all(exma2.CustomPrior1.pdf(samples[:, 0], 2) > 0)
assert np.all(exma2.CustomPrior2.pdf(samples[:, 1], samples[:, 0], 1) > 0)


# A superficial test to compensate for test_inference.test_BOLFI not being run on Travis
@pytest.mark.usefixtures('with_all_clients')
Expand Down
10 changes: 10 additions & 0 deletions tests/unit/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,16 @@ def test_rvs(self):
# Test that the mean of the second mode is correct
assert np.abs(np.mean(rvs[:, 1]) + 3) < .1

def test_rvs_prior_ok(self):
    """Check that every sample drawn with a prior filter has finite prior log-density."""
    # Uniform prior on [0, 1]: its logpdf is -inf outside the support.
    uniform_logpdf = ss.uniform(0, 1).logpdf
    n_samples = 10000
    component_means = [0.8, 0.5]
    component_weights = [.3, .7]

    samples = GMDistribution.rvs(component_means,
                                 weights=component_weights,
                                 size=n_samples,
                                 prior_logpdf=uniform_logpdf)

    # Ensure prior pdf > 0 for all samples
    log_densities = uniform_logpdf(samples)
    assert np.all(np.isfinite(log_densities))



def test_numgrad():
assert np.allclose(numgrad(lambda x: np.log(x), 3), [1 / 3])
Expand Down

0 comments on commit 6346b7c

Please sign in to comment.