Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DM-9595: Allow Transform to return its inverse #192

Merged
merged 4 commits into from
Mar 27, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
48 changes: 40 additions & 8 deletions include/lsst/afw/geom/Transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,14 @@ and an astshim::FrameSet or astshim::Mapping to specify the transformation.
In the case of a FrameSet the transformation is from the `BASE` frame to the `CURRENT` frame.
The endpoints convert the data between the LSST Form (e.g. Point2D) and the form used by astshim.

Depending on the astshim::FrameSet or astshim::Mapping used to define it, a Transform may
provide either a forward transform, an inverse transform, or both. In particular, the
@ref getInverse "inverse" of a forward-only transform is an inverse-only transform. The
@ref hasForward and @ref hasInverse methods can be used to check which transforms are available.

Unless otherwise stated, all constructors and methods may throw `std::runtime_error` to indicate
internal errors within AST.

@note You gain some safety by constructing a Transform from an astshim::FrameSet,
since the base and current frames in the FrameSet can be checked against by the appropriate endpoint.

Expand All @@ -61,8 +69,8 @@ class Transform {

Transform(Transform const &) = delete;
Transform(Transform &&) = default;
Transform & operator=(Transform const &) = delete;
Transform & operator=(Transform &&) = default;
Transform &operator=(Transform const &) = delete;
Transform &operator=(Transform &&) = default;

/**
Construct a Transform from a deep copy of an ast::Mapping
Expand All @@ -74,7 +82,7 @@ class Transform {
@param[in] simplify Simplify the mapping? This combines component mappings
and removes redundant components where possible.
*/
explicit Transform(ast::Mapping const &mapping, bool simplify=true);
explicit Transform(ast::Mapping const &mapping, bool simplify = true);

/**
Constructor a Transform from a deep copy of a FrameSet.
Expand All @@ -96,10 +104,24 @@ class Transform {
redundant components where possible. However it
does not remove any frames.
*/
explicit Transform(ast::FrameSet const & frameSet, bool simplify=true);
explicit Transform(ast::FrameSet const &frameSet, bool simplify = true);

~Transform(){};

/**
* Test if this method has a forward transform.
*
* @exceptsafe Provides basic exception safety.
*/
bool hasForward() const { return _frameSet->hasForward(); }

/**
* Test if this method has an inverse transform.
*
* @exceptsafe Provides basic exception safety.
*/
bool hasInverse() const { return _frameSet->hasInverse(); }

/**
Get the "from" endpoint
*/
Expand All @@ -118,7 +140,7 @@ class Transform {
/**
Transform one point in the forward direction ("from" to "to")
*/
ToPoint tranForward(FromPoint const & point) const;
ToPoint tranForward(FromPoint const &point) const;

/**
Transform an array of points in the forward direction ("from" to "to")
Expand All @@ -128,12 +150,22 @@ class Transform {
/**
Transform one point in the inverse direction ("to" to "from")
*/
FromPoint tranInverse(ToPoint const & point) const;
FromPoint tranInverse(ToPoint const &point) const;

/**
Transform an array of points in the inverse direction ("to" to "from")
*/
FromArray tranInverse(ToArray const & array) const;
FromArray tranInverse(ToArray const &array) const;

/**
* The inverse of this Transform.
*
* @returns a Transform whose `tranForward` is equivalent to this Transform's
* `tranInverse`, and vice versa.
*
* @exceptsafe Provides basic exception safety.
*/
Transform<ToEndpoint, FromEndpoint> getInverse() const;

private:
FromEndpoint const _fromEndpoint;
Expand All @@ -149,7 +181,7 @@ where _fromEndpoint_ and _toEndpoint_ are the appropriate endpoint printed to th
for example "Transform<GenericEndpoint(4), Point3Endpoint()>"
*/
template <typename FromEndpoint, typename ToEndpoint>
std::ostream & operator<<(std::ostream & os, Transform<FromEndpoint, ToEndpoint> const & transform);
std::ostream &operator<<(std::ostream &os, Transform<FromEndpoint, ToEndpoint> const &transform);

} // geom
} // afw
Expand Down
27 changes: 26 additions & 1 deletion python/lsst/afw/geom/endpoint.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include <ostream>
#include <memory>
#include <string>
#include <typeinfo>

#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
Expand Down Expand Up @@ -82,12 +83,35 @@ template <typename PyClass>
void addMakeFrame(PyClass& cls) {
using Class = typename PyClass::type; // C++ class associated with pybind11 wrapper class
// return a deep copy so Python cannot modify the internal state
cls.def("makeFrame", [](Class const & self) {
cls.def("makeFrame", [](Class const& self) {
auto frame = self.makeFrame();
return frame->copy();
});
}

// Comparison of different Endpoints useful in Python but counterproductive
// in C++: point2Endpoint == spherePoint should not compile instead of
// returning `false`. Therefore, implemented only on the Python side.
// Two Endpoints are defined to be equal if and only if they have the same
// implementation type and the same number of dimensions
template <typename SelfClass, typename OtherClass, typename PyClass>
void addEquals(PyClass& cls) {
auto pyEquals = [](SelfClass const& self, OtherClass const& other) {
return self.getNAxes() == other.getNAxes() && typeid(self) == typeid(other);
};
cls.def("__eq__", pyEquals);
cls.def("__ne__",
[pyEquals](SelfClass const& self, OtherClass const& other) { return !pyEquals(self, other); });
}

template <typename SelfClass, typename PyClass>
void addAllEquals(PyClass& cls) {
addEquals<SelfClass, GenericEndpoint>(cls);
addEquals<SelfClass, PointEndpoint<2>>(cls);
addEquals<SelfClass, PointEndpoint<3>>(cls);
addEquals<SelfClass, SpherePointEndpoint>(cls);
}

/*
* Declare BaseVectorEndpoint<Point, Array>;
* this is meant to be called by other `declare...` functions;
Expand All @@ -103,6 +127,7 @@ void declareBaseEndpoint(py::module& mod, std::string const& suffix) {
addDataConverters(cls);
addMakeFrame(cls);
cls.def("normalizeFrame", &Class::normalizeFrame);
addAllEquals<Class>(cls);
}

// Declare BaseVectorEndpoint and all subclasses (the corresponding BaseEndpoint)
Expand Down
31 changes: 16 additions & 15 deletions python/lsst/afw/geom/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@ namespace {

// Return a string consisting of "_pythonClassName_[_fromNAxes_->_toNAxes_]",
// for example "TransformGenericToPoint3[4->3]"
template<typename Class>
std::string formatStr(Class const & self, std::string const & pyClassName) {
template <typename Class>
std::string formatStr(Class const &self, std::string const &pyClassName) {
std::ostringstream os;
os << pyClassName;
auto const frameSet = self.getFrameSet();
Expand All @@ -55,7 +55,7 @@ std::string formatStr(Class const & self, std::string const & pyClassName) {
// where <X> and <Y> are the name of the from endpoint and to endpoint class, respectively,
// for example TransformFromGenericToPoint3
template <typename FromEndpoint, typename ToEndpoint>
void declareTransform(py::module& mod, std::string const & fromName, std::string const & toName) {
void declareTransform(py::module &mod, std::string const &fromName, std::string const &toName) {
using Class = Transform<FromEndpoint, ToEndpoint>;
using ToPoint = typename ToEndpoint::Point;
using ToArray = typename ToEndpoint::Array;
Expand All @@ -66,25 +66,26 @@ void declareTransform(py::module& mod, std::string const & fromName, std::string

py::class_<Class, std::shared_ptr<Class>> cls(mod, pyClassName.c_str());

cls.def(py::init<ast::FrameSet const &, bool>(), "frameSet"_a, "simplify"_a=true);
cls.def(py::init<ast::Mapping const &, bool>(), "mapping"_a, "simplify"_a=true);
cls.def(py::init<ast::FrameSet const &, bool>(), "frameSet"_a, "simplify"_a = true);
cls.def(py::init<ast::Mapping const &, bool>(), "mapping"_a, "simplify"_a = true);

cls.def("hasForward", &Class::hasForward);
cls.def("hasInverse", &Class::hasInverse);

cls.def("getFromEndpoint", &Class::getFromEndpoint);
cls.def("getFrameSet", &Class::getFrameSet);
cls.def("getToEndpoint", &Class::getToEndpoint);

cls.def("tranForward", (ToArray (Class::*)(FromArray const &) const) &Class::tranForward, "array"_a);
cls.def("tranForward", (ToPoint (Class::*)(FromPoint const &) const) &Class::tranForward, "point"_a);
cls.def("tranInverse", (FromArray (Class::*)(ToArray const &) const) &Class::tranInverse, "array"_a);
cls.def("tranInverse", (FromPoint (Class::*)(ToPoint const &) const) &Class::tranInverse, "point"_a);
cls.def("tranForward", (ToArray (Class::*)(FromArray const &) const) & Class::tranForward, "array"_a);
cls.def("tranForward", (ToPoint (Class::*)(FromPoint const &) const) & Class::tranForward, "point"_a);
cls.def("tranInverse", (FromArray (Class::*)(ToArray const &) const) & Class::tranInverse, "array"_a);
cls.def("tranInverse", (FromPoint (Class::*)(ToPoint const &) const) & Class::tranInverse, "point"_a);
cls.def("getInverse", &Class::getInverse);
// str(self) = "<Python class name>[<nIn>-><nOut>]"
cls.def("__str__", [pyClassName](Class const & self) {
return formatStr(self, pyClassName);
});
cls.def("__str__", [pyClassName](Class const &self) { return formatStr(self, pyClassName); });
// repr(self) = "lsst.afw.geom.<Python class name>[<nIn>-><nOut>]"
cls.def("__repr__", [pyClassName](Class const & self) {
return "lsst.afw.geom." + formatStr(self, pyClassName);
});
cls.def("__repr__",
[pyClassName](Class const &self) { return "lsst.afw.geom." + formatStr(self, pyClassName); });
}

PYBIND11_PLUGIN(transform) {
Expand Down
18 changes: 17 additions & 1 deletion python/lsst/afw/geom/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,11 @@

import lsst.utils.tests
from .angle import arcseconds
from .endpoint import GenericEndpoint, Point2Endpoint, Point3Endpoint, SpherePointEndpoint


__all__ = ["assertAnglesNearlyEqual", "assertPairsNearlyEqual", "assertBoxesNearlyEqual"]
__all__ = ["assertAnglesNearlyEqual", "assertPairsNearlyEqual",
"assertBoxesNearlyEqual", "makeEndpoints"]

@lsst.utils.tests.inTestCase
def assertAnglesNearlyEqual(testCase, ang0, ang1, maxDiff=0.001*arcseconds,
Expand Down Expand Up @@ -99,3 +101,17 @@ def assertBoxesNearlyEqual(testCase, box0, box1, maxDiff=1e-7, msg="Boxes differ
"""
assertPairsNearlyEqual(testCase, box0.getMin(), box1.getMin(), maxDiff=maxDiff, msg=msg + ": min")
assertPairsNearlyEqual(testCase, box0.getMax(), box1.getMax(), maxDiff=maxDiff, msg=msg + ": max")


@lsst.utils.tests.inTestCase
def makeEndpoints(testCase):
"""Generate a representative sample of Endpoints.

Returns
-------
x : `list`
List of endpoints with enough diversity to exercise Endpoint-related
code. Each invocation of this method shall return independent objects.
"""
return [GenericEndpoint(n) for n in range(1, 6)] + \
[Point2Endpoint(), Point3Endpoint(), SpherePointEndpoint()]
25 changes: 19 additions & 6 deletions src/geom/Transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

#include <memory>
#include <ostream>
#include <sstream>
#include <vector>

#include "astshim.h"
Expand Down Expand Up @@ -54,8 +55,8 @@ Transform<FromEndpoint, ToEndpoint>::Transform(ast::FrameSet const &frameSet, bo
// and normalize the frame set as a frame (i.e. normalize the frame "in situ").
// The obvious alternative of normalizing a shallow copy of the frame does not work;
// the frame is altered but not the associated mapping!
auto frameSetCopy = simplify ? std::dynamic_pointer_cast<ast::FrameSet>(frameSet.simplify())
: frameSet.copy();
auto frameSetCopy =
simplify ? std::dynamic_pointer_cast<ast::FrameSet>(frameSet.simplify()) : frameSet.copy();

// Normalize the current frame by normalizing the frameset as a frame
_toEndpoint.normalizeFrame(frameSetCopy);
Expand Down Expand Up @@ -104,17 +105,29 @@ typename FromEndpoint::Array Transform<FromEndpoint, ToEndpoint>::tranInverse(
return _fromEndpoint.arrayFromData(rawToData);
}

template <typename FromEndpoint, typename ToEndpoint>
Transform<ToEndpoint, FromEndpoint> Transform<FromEndpoint, ToEndpoint>::getInverse() const {
auto inverse = std::dynamic_pointer_cast<ast::FrameSet>(_frameSet->getInverse());
if (!inverse) {
// don't throw std::bad_cast because it doesn't let you provide debugging info
std::ostringstream buffer;
buffer << "FrameSet.getInverse() does not return a FrameSet. Called from: " << _frameSet;
throw std::logic_error(buffer.str());
}
return Transform<ToEndpoint, FromEndpoint>(*inverse);
}

template <typename FromEndpoint, typename ToEndpoint>
std::ostream &operator<<(std::ostream &os, Transform<FromEndpoint, ToEndpoint> const &transform) {
auto const frameSet = transform.getFrameSet();
os << "Transform<" << transform.getFromEndpoint() << ", " << transform.getToEndpoint() << ">";
return os;
};

#define INSTANTIATE_TRANSFORM(FromEndpoint, ToEndpoint) \
template class Transform<FromEndpoint, ToEndpoint>; \
template std::ostream &operator<< <FromEndpoint, ToEndpoint> \
(std::ostream &os, Transform<FromEndpoint, ToEndpoint> const &transform);
#define INSTANTIATE_TRANSFORM(FromEndpoint, ToEndpoint) \
template class Transform<FromEndpoint, ToEndpoint>; \
template std::ostream &operator<<<FromEndpoint, ToEndpoint>( \
std::ostream &os, Transform<FromEndpoint, ToEndpoint> const &transform);

// explicit instantiations
INSTANTIATE_TRANSFORM(GenericEndpoint, GenericEndpoint);
Expand Down
12 changes: 12 additions & 0 deletions tests/test_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,18 @@ def checkEndpointBasics(self, endpoint, pointType, nAxes):
pointDataRoundTrip = endpoint.dataFromPoint(point)
assert_allclose(pointData, pointDataRoundTrip, err_msg=baseMsg)

def testEndpointEquals(self):
"""Test Endpoint == Endpoint
"""
for i1, point1 in enumerate(self.makeEndpoints()):
for i2, point2 in enumerate(self.makeEndpoints()):
if i1 == i2:
self.assertTrue(point1 == point2)
self.assertFalse(point1 != point2)
else:
self.assertFalse(point1 == point2)
self.assertTrue(point1 != point2)


class MemoryTester(lsst.utils.tests.MemoryTestCase):
pass
Expand Down