Skip to content

Commit

Permalink
adding a few more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
dfm committed Oct 5, 2017
1 parent a36bfac commit ae4c05c
Show file tree
Hide file tree
Showing 6 changed files with 20 additions and 255 deletions.
5 changes: 1 addition & 4 deletions emcee/ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,10 +350,7 @@ def compute_log_prob(self, coords=None):
this position or ``None`` if nothing was returned.
"""
if coords is None:
p = self.pos
else:
p = coords
p = coords

# Check that the parameters are in physical ranges.
if np.any(np.isinf(p)):
Expand Down
1 change: 0 additions & 1 deletion emcee/moves/move.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ def update(self,
accepted, subset=None):
if subset is None:
subset = np.ones(len(coords), dtype=bool)
inds = np.arange(len(coords))
m1 = subset & accepted
m2 = accepted[subset]
coords[m1] = new_coords[m2]
Expand Down
247 changes: 0 additions & 247 deletions emcee/tests.py

This file was deleted.

2 changes: 1 addition & 1 deletion tests/integration/test_proposal.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def _test_normal(proposal, ndim=1, nwalkers=32, nsteps=2000, seed=1234,
# standard deviation.
samps = sampler.get_chain(flat=True)
mu, sig = np.mean(samps, axis=0), np.std(samps, axis=0)
assert np.all(np.abs(mu) < 0.07), "Incorrect mean"
assert np.all(np.abs(mu) < 0.08), "Incorrect mean"
assert np.all(np.abs(sig - 1) < 0.05), "Incorrect standard deviation"

if ndim == 1:
Expand Down
12 changes: 12 additions & 0 deletions tests/unit/test_autocorr.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,3 +40,15 @@ def test_too_short(seed=1234, ndim=3, N=100):
with pytest.raises(AutocorrError):
integrated_time(x, low=100)
tau = integrated_time(x, quiet=True) # NOQA


def test_autocorr_multi_works():
np.random.seed(42)
xs = np.random.randn(16384, 2)

# This throws exception unconditionally in buggy impl's
acls_multi = integrated_time(xs)
acls_single = np.array([integrated_time(xs[:, i])
for i in range(xs.shape[1])])

assert np.all(np.abs(acls_multi - acls_single) < 2)
8 changes: 6 additions & 2 deletions tests/unit/test_stretch.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

from __future__ import division, print_function

import warnings

import pytest
import numpy as np

Expand All @@ -11,6 +13,8 @@


def test_live_dangerously(nwalkers=32, nsteps=3000, seed=1234):
warnings.filterwarnings("error")

# Set up the random number generator.
np.random.seed(seed)
coords = np.random.randn(nwalkers, 2 * nwalkers)
Expand All @@ -20,9 +24,9 @@ def test_live_dangerously(nwalkers=32, nsteps=3000, seed=1234):
# walkers.
with pytest.raises(RuntimeError):
proposal.propose(coords, np.random.randn(nwalkers), None,
lambda x: (np.zeros(nwalkers), None), np.random)
lambda x: (np.zeros(len(x)), None), np.random)

# Living dangerously...
proposal.live_dangerously = True
proposal.propose(coords, np.random.randn(nwalkers), None,
lambda x: (np.zeros(nwalkers), None), np.random)
lambda x: (np.zeros(len(x)), None), np.random)

0 comments on commit ae4c05c

Please sign in to comment.