Skip to content

Commit

Permalink
Test UnitNormMap transform of center
Browse files Browse the repository at this point in the history
Ehnance unit test of UnitNormMap to test transforming position = center.
  • Loading branch information
r-owen committed Aug 1, 2017
1 parent 402fd3b commit 49c7edf
Showing 1 changed file with 24 additions and 12 deletions.
36 changes: 24 additions & 12 deletions tests/test_unitNormMap.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import unittest

import numpy as np
from numpy.testing import assert_allclose
from numpy.testing import assert_allclose, assert_equal

import astshim
from astshim.test import MappingTestCase
Expand All @@ -13,8 +13,17 @@ class TestUnitNormMap(MappingTestCase):
def test_UnitNormMapBasics(self):
"""Test basics of UnitNormMap including applyForward
"""
# `full_` variables contain data for 3 axes; the variables without the `full_` prefix
# are a subset containing the number of axes being tested
full_center = np.array([-1, 1, 2], dtype=float)
full_indata = np.array([
[full_center[0], 1.0, 2.0, -6.0, 30.0, 1.0],
[full_center[1], 3.0, 99.0, -5.0, 21.0, 0.0],
[full_center[2], -5.0, 3.0, -7.0, 37.0, 0.0],
], dtype=float)
for nin in (1, 2, 3):
center = np.array([-1, 1, 2][0:nin], dtype=float)
center = full_center[0:nin]
indata = full_indata[0:nin]
unitnormmap = astshim.UnitNormMap(center)
self.assertEqual(unitnormmap.className, "UnitNormMap")
self.assertEqual(unitnormmap.nIn, nin)
Expand All @@ -25,23 +34,26 @@ def test_UnitNormMapBasics(self):
self.checkCopy(unitnormmap)
self.checkPersistence(unitnormmap)

indata = np.array([
[1.0, 2.0, -6.0, 30.0, 1.0],
[3.0, 99.0, -5.0, 21.0, 0.0],
[-5.0, 3.0, -7.0, 37.0, 0.0],
[7.0, -23.0, -3.0, 45.0, 0.0],
], dtype=float)[0:nin]
self.checkRoundTrip(unitnormmap, indata)

outdata = unitnormmap.applyForward(indata)
norm = outdata[-1]

relindata = (indata.T - center).T
pred_norm = np.linalg.norm(relindata, axis=0)
# the first input point is at the center, so the expected output is
# [Nan, Nan, ..., Nan, 0]
pred_out_at_center = [np.nan]*nin + [0]
assert_equal(outdata[:, 0], pred_out_at_center)

relative_indata = (indata.T - center).T
pred_norm = np.linalg.norm(relative_indata, axis=0)
assert_allclose(norm, pred_norm)

pred_relindata = outdata[0:nin] * norm
assert_allclose(relindata, pred_relindata)
pred_relative_indata = outdata[0:nin] * norm
# the first input point is at the center, so the output is
# [NaN, NaN, ..., NaN, 0], (as checked above),
# but the expected value after scaling by the norm is 0s, so...
pred_relative_indata[:, 0] = [0]*nin
assert_allclose(relative_indata, pred_relative_indata)

# UnitNormMap must have at least one input
with self.assertRaises(Exception):
Expand Down

0 comments on commit 49c7edf

Please sign in to comment.