Skip to content

Commit

Permalink
Use dask functions by duck typing. also add a unit test.
Browse files Browse the repository at this point in the history
  • Loading branch information
rainwoodman committed Oct 19, 2020
1 parent 25cf52e commit c1f1739
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 4 deletions.
27 changes: 27 additions & 0 deletions nbodykit/tests/test_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,3 +187,30 @@ def test_constarray(comm):

a = ConstantArray([1.0, 1.0], 3, chunks=1000)
assert a.shape == (3, 2)

@MPITest([1, 4])
def test_vector_projection(comm):
cosmo = cosmology.Planck15

# make source
s = UniformCatalog(nbar=1e-5, BoxSize=1380., seed=42, comm=comm)

x = transform.VectorProjection(s['Position'], [1, 0, 0])
y = transform.VectorProjection(s['Position'], [0, 1, 0])
z = transform.VectorProjection(s['Position'], [0, 0, 1])
d = transform.VectorProjection(s['Position'], [1, 1, 1])

nx = transform.VectorProjection(s['Position'], [-2, 0, 0])
ny = transform.VectorProjection(s['Position'], [0, -2, 0])
nz = transform.VectorProjection(s['Position'], [0, 0, -2])
nd = transform.VectorProjection(s['Position'], [-2, -2, -2])

numpy.testing.assert_allclose(x, s['Position'] * [1, 0, 0], rtol=1e-3)
numpy.testing.assert_allclose(y, s['Position'] * [0, 1, 0], rtol=1e-3)
numpy.testing.assert_allclose(z, s['Position'] * [0, 0, 1], rtol=1e-3)
numpy.testing.assert_allclose(d[:, 0], s['Position'].sum(axis=-1) / 3., rtol=1e-3)

numpy.testing.assert_allclose(nx, s['Position'] * [1, 0, 0], rtol=1e-3)
numpy.testing.assert_allclose(ny, s['Position'] * [0, 1, 0], rtol=1e-3)
numpy.testing.assert_allclose(nz, s['Position'] * [0, 0, 1], rtol=1e-3)
numpy.testing.assert_allclose(nd[:, 0], s['Position'].sum(axis=-1) / 3., rtol=1e-3)
8 changes: 4 additions & 4 deletions nbodykit/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -508,10 +508,10 @@ def VectorProjection(vector, direction):
projection : array_like, (..., D)
vector component of the given vector in the given direction
"""
direction = numpy.asarray(direction)
direction = direction / numpy.sqrt(direction ** 2)
projection = numpy.dot(vector, direction)[..., None]
projection = projection * direction
direction = numpy.asarray(direction, dtype='f8')
direction = direction / (direction ** 2).sum() ** 0.5
projection = (vector * direction).sum(axis=-1)
projection = projection[:, None] * direction[None, :]

return projection

Expand Down

0 comments on commit c1f1739

Please sign in to comment.