Skip to content

Commit

Permalink
Merge branch 'master' into tests
Browse files Browse the repository at this point in the history
  • Loading branch information
d-ming committed Nov 2, 2017
2 parents f42034a + 9571770 commit e056706
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 0 deletions.
38 changes: 38 additions & 0 deletions artools/artools.py
Expand Up @@ -954,6 +954,40 @@ def isOdd(N):
return True


def gridPts(pts_per_axis, axis_lims):
'''
Generate a list of points spaced on a user-specified grid range.
Arguments
pts_per_axis: Number of points to generate.
axis_lims: An array of axis min-max pairs.
e.g. [xmin, xmax, ymin, ymax, zmin, zmax, etc.] where
d = len(axis_lims)/2
Returns
Ys: (pts_per_axis x d) numpy array of grid points.
'''

num_elements = len(axis_lims)
if isOdd(num_elements):
raise ValueError("axis_lims must have an even number of elements")

dim = int(num_elements/2)

AX = sp.reshape(axis_lims, (-1, 2))
D = sp.diag(AX[:, 1] - AX[:, 0])

# compute the Cartesian product for an n-D unit cube
spacing_list = [sp.linspace(0, 1, pts_per_axis) for i in range(AX.shape[0])]
Xs = sp.array(list(itertools.product(*spacing_list)))

# scale to axis limits
Ys = sp.dot(Xs, D) + AX[:, 0]

return Ys


def randPts(Npts, axis_lims):
'''
Generate a list of random points within a user-specified range.
Expand All @@ -973,6 +1007,9 @@ def randPts(Npts, axis_lims):
if isOdd(num_elements):
raise ValueError("axis_lims must have an even number of elements")

if type(Npts) != int:
raise TypeError("Npts must be an integer")

dim = int(num_elements/2)

Xs = sp.rand(Npts, dim)
Expand All @@ -982,6 +1019,7 @@ def randPts(Npts, axis_lims):
AX = sp.reshape(axis_lims, (-1, 2))
D = sp.diag(AX[:, 1] - AX[:, 0])

# scale to axis limits
Ys = sp.dot(Xs, D) + AX[:, 0]

return Ys
Expand Down
32 changes: 32 additions & 0 deletions artools/test/gridPts_test.py
@@ -0,0 +1,32 @@
import sys
sys.path.append('../')
from artools import gridPts, sameRows
import scipy as sp


class TestNormal:

def test_1(self):
xs = gridPts(2, [0., 1., 2., 3.5])

xs_ref = sp.array([[ 0. , 2. ],
[ 0. , 3.5],
[ 1. , 2. ],
[ 1. , 3.5]])

assert sameRows(xs, xs_ref) == True


def test_2(self):
xs = gridPts(2, [0., 1., 1., 2., 3., 4.])

xs_ref = sp.array([[ 0., 1., 3.],
[ 0., 1., 4.],
[ 0., 2., 3.],
[ 0., 2., 4.],
[ 1., 1., 3.],
[ 1., 1., 4.],
[ 1., 2., 3.],
[ 1., 2., 4.]])

assert sameRows(xs, xs_ref) == True

0 comments on commit e056706

Please sign in to comment.