Skip to content

Commit

Permalink
Fix regular grid for d=1; add tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
Matt Hoffman committed Jan 5, 2015
1 parent c107ca3 commit 9c0e373
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 6 deletions.
10 changes: 7 additions & 3 deletions mwhutils/random/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def sobol(bounds, n, rng=None):
return X


def grid(bounds, n, rng=None):
def grid(bounds, n):
"""
Generate a regular grid within the specified region, given by `bounds`,
a list of [(lo,hi), ..] bounds in each dimension. `n` represents the number
Expand All @@ -97,7 +97,11 @@ def grid(bounds, n, rng=None):
bounds = np.array(bounds, ndmin=2, copy=False)
d = len(bounds)

X = np.meshgrid(*(np.linspace(a, b, n) for a, b in bounds))
X = np.reshape(X, (d, -1)).T
if d == 1:
X = np.linspace(bounds[0, 0], bounds[0, 1], n)
X = np.reshape(X, (-1, 1))
else:
X = np.meshgrid(*(np.linspace(a, b, n) for a, b in bounds))
X = np.reshape(X, (d, -1)).T

return X
22 changes: 19 additions & 3 deletions tests/test_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import numpy.testing as nt

from mwhutils.random import rstate
from mwhutils.random import uniform, latin, sobol
from mwhutils.random import uniform, latin, sobol, grid


def test_rstate():
Expand All @@ -28,11 +28,27 @@ def check_random(method):
bounds = [(0, 1), (3, 4)]
sample = method(bounds, 10)
assert sample.shape == (10, 2)
assert all(sample[:, 0] > 0) and all(sample[:, 0] < 1)
assert all(sample[:, 1] > 3) and all(sample[:, 1] < 4)
assert all(sample[:, 0] >= 0) and all(sample[:, 0] <= 1)
assert all(sample[:, 1] >= 3) and all(sample[:, 1] <= 4)

sample = grid((0, 1), 10)
assert sample.shape == (10, 1)
assert all(sample[:, 0] >= 0) and all(sample[:, 0] <= 1)


def test_random():
"""Test all the random generators."""
for method in [uniform, latin, sobol]:
yield check_random, method


def test_grid():
"""Test the non-random grid "sampler"."""
sample = grid([(0, 1), (3, 4)], 10)
assert sample.shape == (100, 2)
assert all(sample[:, 0] >= 0) and all(sample[:, 0] <= 1)
assert all(sample[:, 1] >= 3) and all(sample[:, 1] <= 4)

sample = grid((0, 1), 10)
assert sample.shape == (10, 1)
assert all(sample[:, 0] >= 0) and all(sample[:, 0] <= 1)

0 comments on commit 9c0e373

Please sign in to comment.