Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DM-12025: make Transform pickleable #278

Merged
merged 4 commits into from
Oct 12, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
27 changes: 27 additions & 0 deletions python/lsst/afw/geom/python/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,32 @@ def then(self, next):
% (self, next))


def unpickleTransform(cls, state):
"""Unpickle a Transform object

Parameters
----------
cls : `type`
A `Transform` class.
state : `str`
Pickled state.

Returns
-------
transform : `cls`
The unpickled Transform.
"""
return cls.readString(state)


def reduceTransform(transform):
"""Pickle a Transform object

This provides the `__reduce__` implementation for a Transform.
"""
return unpickleTransform, (type(transform), transform.writeString())


def addTransformMethods(cls):
"""Add pure python methods to the specified Transform class, and register
the class in `transformRegistry`
Expand All @@ -86,3 +112,4 @@ def addTransformMethods(cls):
transformRegistry[cls.__name__] = cls
cls.getJacobian = getJacobian
cls.then = then
cls.__reduce__ = reduceTransform
67 changes: 38 additions & 29 deletions python/lsst/afw/geom/testUtils.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import itertools
import math
import os
import pickle

import astshim
import numpy as np
Expand Down Expand Up @@ -970,6 +971,39 @@ def checkThen(self, fromName, midName, toName):
with self.assertRaises(InvalidParameterError):
transform = transform1.then(transform2)

def assertTransformsEqual(self, transform1, transform2):
"""Assert that two transforms are equal"""
self.assertEqual(type(transform1), type(transform2))
self.assertEqual(transform1.fromEndpoint, transform2.fromEndpoint)
self.assertEqual(transform1.toEndpoint, transform2.toEndpoint)
self.assertEqual(transform1.getFrameSet(), transform2.getFrameSet())

fromEndpoint = transform1.fromEndpoint
toEndpoint = transform1.toEndpoint
frameSet = transform1.getFrameSet()
nIn = frameSet.nIn
nOut = frameSet.nOut

if frameSet.hasForward:
nPoints = 7 # arbitrary
rawInArray = self.makeRawArrayData(nPoints, nIn)
inArray = fromEndpoint.arrayFromData(rawInArray)
outArray = transform1.applyForward(inArray)
outData = toEndpoint.dataFromArray(outArray)
outArrayRoundTrip = transform2.applyForward(inArray)
outDataRoundTrip = toEndpoint.dataFromArray(outArrayRoundTrip)
assert_allclose(outData, outDataRoundTrip)

if frameSet.hasInverse:
nPoints = 7 # arbitrary
rawOutArray = self.makeRawArrayData(nPoints, nOut)
outArray = toEndpoint.arrayFromData(rawOutArray)
inArray = transform1.applyInverse(outArray)
inData = fromEndpoint.dataFromArray(inArray)
inArrayRoundTrip = transform2.applyInverse(outArray)
inDataRoundTrip = fromEndpoint.dataFromArray(inArrayRoundTrip)
assert_allclose(inData, inDataRoundTrip)

def checkPersistence(self, transform):
"""Check persistence of a transform
"""
Expand All @@ -988,36 +1022,11 @@ def checkPersistence(self, transform):
with self.assertRaises(lsst.pex.exceptions.InvalidParameterError):
transform.readString(badStr2)
transformFromStr1 = transform.readString(transformStr)
self.assertEqual(type(transform), type(transformFromStr1))
self.assertEqual(transform.getFrameSet(), transformFromStr1.getFrameSet())
self.assertTransformsEqual(transform, transformFromStr1)

# check transformFromString
transformFromStr2 = afwGeom.transformFromString(transformStr)
self.assertEqual(type(transform), type(transformFromStr2))
self.assertEqual(transform.getFrameSet(), transformFromStr2.getFrameSet())

fromEndpoint = transform.fromEndpoint
toEndpoint = transform.toEndpoint
frameSet = transform.getFrameSet()
nIn = frameSet.nIn
nOut = frameSet.nOut
self.assertTransformsEqual(transform, transformFromStr2)

if frameSet.hasForward:
nPoints = 7 # arbitrary
rawInArray = self.makeRawArrayData(nPoints, nIn)
inArray = fromEndpoint.arrayFromData(rawInArray)
outArray = transform.applyForward(inArray)
outData = toEndpoint.dataFromArray(outArray)
outArrayRoundTrip = transformFromStr1.applyForward(inArray)
outDataRoundTrip = toEndpoint.dataFromArray(outArrayRoundTrip)
assert_allclose(outData, outDataRoundTrip)

if frameSet.hasInverse:
nPoints = 7 # arbitrary
rawOutArray = self.makeRawArrayData(nPoints, nOut)
outArray = toEndpoint.arrayFromData(rawOutArray)
inArray = transform.applyInverse(outArray)
inData = fromEndpoint.dataFromArray(inArray)
inArrayRoundTrip = transformFromStr1.applyInverse(outArray)
inDataRoundTrip = fromEndpoint.dataFromArray(inArrayRoundTrip)
assert_allclose(inData, inDataRoundTrip)
# Check pickling
self.assertTransformsEqual(transform, pickle.loads(pickle.dumps(transform)))
2 changes: 2 additions & 0 deletions python/lsst/afw/math/interpolate.cc
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,8 @@ PYBIND11_PLUGIN(_interpolate) {
"x"_a, "y"_a, "style"_a = Interpolate::AKIMA_SPLINE);

mod.def("stringToInterpStyle", stringToInterpStyle, "style"_a);
mod.def("lookupMaxInterpStyle", lookupMaxInterpStyle, "n"_a);
mod.def("lookupMinInterpPoints", lookupMinInterpPoints, "style"_a);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be much appreciated if you could add a few simple unit tests for these (including stringToInterpStyle).


return mod.ptr();
}
48 changes: 48 additions & 0 deletions tests/test_interpolate.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,54 @@ def testInvalidInputs(self):
afwMath.makeInterpolate(np.array([0], dtype=float), np.array([1], dtype=float),
afwMath.Interpolate.LINEAR)

def testLookupMaxInterpStyle(self):
for numPoints in range(1, 6):
maxInterpStyle = afwMath.lookupMaxInterpStyle(numPoints)
desiredMax = {
1: afwMath.Interpolate.Style.CONSTANT,
2: afwMath.Interpolate.Style.LINEAR,
3: afwMath.Interpolate.Style.CUBIC_SPLINE,
4: afwMath.Interpolate.Style.CUBIC_SPLINE,
}.get(numPoints, afwMath.Interpolate.Style.AKIMA_SPLINE)
self.assertEqual(maxInterpStyle, desiredMax)

for badNumPoints in (-5, -1, 0):
with self.assertRaises(pexExcept.InvalidParameterError):
afwMath.lookupMaxInterpStyle(badNumPoints)

def testLookupMinInterpPoints(self):
for style in afwMath.Interpolate.Style.__members__.values():
if style in (afwMath.Interpolate.Style.UNKNOWN, afwMath.Interpolate.Style.NUM_STYLES):
with self.assertRaises(pexExcept.OutOfRangeError):
afwMath.lookupMinInterpPoints(style)
else:
minPoints = afwMath.lookupMinInterpPoints(style)
desiredMin = {
afwMath.Interpolate.Style.CONSTANT: 1,
afwMath.Interpolate.Style.LINEAR: 2,
afwMath.Interpolate.Style.NATURAL_SPLINE: 3,
afwMath.Interpolate.Style.CUBIC_SPLINE: 3,
afwMath.Interpolate.Style.CUBIC_SPLINE_PERIODIC: 3,
afwMath.Interpolate.Style.AKIMA_SPLINE: 5,
afwMath.Interpolate.Style.AKIMA_SPLINE_PERIODIC: 5,
}.get(style, None)
if desiredMin is None:
self.fail("Unrecognized style: %s" % (style,))
self.assertEqual(minPoints, desiredMin)

def testStringToInterpStyle(self):
for name, desiredStyle in afwMath.Interpolate.Style.__members__.items():
if name in ("UNKNOWN", "NUM_STYLES"):
with self.assertRaises(pexExcept.InvalidParameterError):
afwMath.stringToInterpStyle(name)
else:
style = afwMath.stringToInterpStyle(name)
self.assertEqual(style, desiredStyle)

for badName in ("BOGUS", ""):
with self.assertRaises(pexExcept.InvalidParameterError):
afwMath.stringToInterpStyle(badName)


class TestMemory(lsst.utils.tests.MemoryTestCase):
pass
Expand Down