Skip to content

Commit

Permalink
Add pickle and __eq__ support to SchemaMapper
Browse files Browse the repository at this point in the history
Add pickle support for Key and SchemaItem, including Flags

It was much easier to implement == at the python level for SchemaMapper,
and it probably isn't needed in C++.
  • Loading branch information
parejkoj committed Mar 8, 2019
1 parent 0f85081 commit 0f415aa
Show file tree
Hide file tree
Showing 4 changed files with 228 additions and 7 deletions.
30 changes: 29 additions & 1 deletion python/lsst/afw/table/schema/schema.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
#include "lsst/utils/python.h"

#include "lsst/afw/fits.h"
#include "lsst/afw/table/detail/Access.h"
#include "lsst/afw/table/Schema.h"
#include "lsst/afw/table/BaseRecord.h"
#include "lsst/afw/table/SchemaMapper.h"
Expand Down Expand Up @@ -150,13 +151,31 @@ void declareKeySpecializations(PyKey<T> &cls) {
declareKeyAccessors(cls);
cls.def_property_readonly("subfields", [](py::object const &) { return py::none(); });
cls.def_property_readonly("subkeys", [](py::object const &) { return py::none(); });
cls.def(py::pickle(
[](Key<T> const &self) {
/* Return a tuple that fully encodes the state of the object */
return py::make_tuple(self.getOffset());
},
[](py::tuple t) {
if (t.size() != 1) throw std::runtime_error("Invalid number of parameters when unpickling!");
return detail::Access::makeKey<T>(t[0].cast<int>());
}));
}

void declareKeySpecializations(PyKey<Flag> &cls) {
declareKeyAccessors(cls);
cls.def_property_readonly("subfields", [](py::object const &) { return py::none(); });
cls.def_property_readonly("subkeys", [](py::object const &) { return py::none(); });
cls.def("getBit", &Key<Flag>::getBit);
cls.def(py::pickle(
[](Key<Flag> const &self) {
/* Return a tuple that fully encodes the state of the object */
return py::make_tuple(self.getOffset(), self.getBit());
},
[](py::tuple t) {
if (t.size() != 2) throw std::runtime_error("Invalid number of parameters when unpickling!");
return detail::Access::makeKey(t[0].cast<int>(), t[1].cast<int>());
}));
}

template <typename U>
Expand Down Expand Up @@ -281,7 +300,16 @@ void declareSchemaType(py::module &mod) {
clsSchemaItem.def("__repr__", [](py::object const &self) -> py::str {
return py::str("SchemaItem(key={0.key}, field={0.field})").format(self);
});
}
clsSchemaItem.def(py::pickle(
[](SchemaItem<T> const &self) {
/* Return a tuple that fully encodes the state of the object */
return py::make_tuple(self.key, self.field);
},
[](py::tuple t) {
if (t.size() != 2) throw std::runtime_error("Invalid number of parameters when unpickling!");
return SchemaItem<T>(t[0].cast<Key<T>>(), t[1].cast<Field<T>>());
}));
} // namespace

// Helper class for Schema::find(name, func) that converts the result to Python.
// In C++14, this should be converted to a universal lambda.
Expand Down
12 changes: 7 additions & 5 deletions python/lsst/afw/table/schemaMapper/schemaMapper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@ using PySchemaMapper = py::class_<SchemaMapper, std::shared_ptr<SchemaMapper>>;

template <typename T>
void declareSchemaMapperOverloads(PySchemaMapper &cls, std::string const &suffix) {
cls.def("getMapping", (Key<T> (SchemaMapper::*)(Key<T> const &) const) & SchemaMapper::getMapping);
cls.def("getMapping", (Key<T>(SchemaMapper::*)(Key<T> const &) const) & SchemaMapper::getMapping);
cls.def("isMapped", (bool (SchemaMapper::*)(Key<T> const &) const) & SchemaMapper::isMapped);
};

PYBIND11_MODULE(schemaMapper, mod) {
Expand All @@ -59,6 +60,7 @@ PYBIND11_MODULE(schemaMapper, mod) {
cls.def("addMinimalSchema", &SchemaMapper::addMinimalSchema, "minimal"_a, "doMap"_a = true);
cls.def_static("removeMinimalSchema", &SchemaMapper::removeMinimalSchema);
cls.def_static("join", &SchemaMapper::join, "inputs"_a, "prefixes"_a = std::vector<std::string>());

declareSchemaMapperOverloads<std::uint8_t>(cls, "B");
declareSchemaMapperOverloads<std::uint16_t>(cls, "U");
declareSchemaMapperOverloads<std::int32_t>(cls, "I");
Expand All @@ -74,7 +76,7 @@ PYBIND11_MODULE(schemaMapper, mod) {
declareSchemaMapperOverloads<lsst::afw::table::Array<float>>(cls, "ArrayF");
declareSchemaMapperOverloads<lsst::afw::table::Array<double>>(cls, "ArrayD");
}
}
}
}
} // namespace lsst::afw::table::<anonymous>
} // namespace
} // namespace table
} // namespace afw
} // namespace lsst
64 changes: 63 additions & 1 deletion python/lsst/afw/table/schemaMapper/schemaMapperContinued.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,10 @@

__all__ = [] # import only for the side effects

import lsst.pex.exceptions
from lsst.utils import continueClass

from ..schema import Field
from ..schema import Field, Schema
from .schemaMapper import SchemaMapper


Expand Down Expand Up @@ -85,3 +86,64 @@ def addMapping(self, input, output=None, doReplace=True):
doReplace = output
output = None
return input._addMappingTo(self, output, doReplace)

def __eq__(self, other):
"""SchemaMappers are equal if their respective input and output
schemas are identical, and they have the same mappings defined.
Note: It was simpler to implement equality in python than in C++.
"""
iSchema = self.getInputSchema()
oSchema = self.getOutputSchema()
if (not (iSchema.compare(other.getInputSchema(), Schema.IDENTICAL) == Schema.IDENTICAL and
oSchema.compare(other.getOutputSchema(), Schema.IDENTICAL) == Schema.IDENTICAL)):
return False

for item in iSchema:
if self.isMapped(item.key) and other.isMapped(item.key):
if (self.getMapping(item.key) == other.getMapping(item.key)):
continue
else:
return False
elif (not self.isMapped(item.key)) and (not other.isMapped(item.key)):
continue
else:
return False

return True

def __reduce__(self):
"""To support pickle."""
mappings = {}
for item in self.getInputSchema():
try:
key = self.getMapping(item.key)
except lsst.pex.exceptions.NotFoundError:
# Not all fields may be mapped, so just continue if a mapping is not found.
continue
mappings[item.key] = self.getOutputSchema().find(key).field
return (makeSchemaMapper, (self.getInputSchema(), self.getOutputSchema(), mappings))


def makeSchemaMapper(input, output, mappings):
"""Build a mapper from two Schemas and the mapping between them.
For pickle support.
Parameters
----------
input : `lsst.afw.table.Schema`
The input schema for the mapper.
output : `lsst.afw.table.Schema`
The output schema for the mapper.
mappings : `dict` [`lsst.afw.table.Key`, `lsst.afw.table.Key`]
The mappings to define between the input and output schema.
Returns
-------
mapper : `lsst.afw.table.SchemaMapper`
The constructed SchemaMapper.
"""
mapper = SchemaMapper(input, output)
for key, value in mappings.items():
mapper.addMapping(key, value)
return mapper
129 changes: 129 additions & 0 deletions tests/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@
import lsst.afw.table


def _checkSchemaIdentical(schema1, schema2):
return schema1.compare(schema2, lsst.afw.table.Schema.IDENTICAL) == lsst.afw.table.Schema.IDENTICAL


class SchemaTestCase(unittest.TestCase):

def testSchema(self):
Expand Down Expand Up @@ -170,6 +174,7 @@ def testPickle(self):
schema.addField("y", type="D", doc="position_y")
schema.addField("i", type="I", doc="int")
schema.addField("f", type="F", doc="float", units="m2")
schema.addField("flag", type="Flag", doc="a flag")
pickled = pickle.dumps(schema, protocol=pickle.HIGHEST_PROTOCOL)
unpickled = pickle.loads(pickled)
self.assertEqual(schema, unpickled)
Expand Down Expand Up @@ -291,6 +296,130 @@ def testJoin2(self):
self.assertEqual(s1.join("a", "b", "c"), "a_b_c")
self.assertEqual(s1.join("a", "b", "c", "d"), "a_b_c_d")

def _makeMapper(self, name1="a", name2="bb", name3="ccc", name4="dddd"):
"""Make a SchemaMapper for testing.
Parameters
----------
name1 : `str`, optional
Name of a field to "default map" from input to output.
name2 : `str`, optional
Name of a field to map from input to ``name3`` in output.
name3 : `str`, optional
Name of a field that is unmapped in input, and mapped from
``name2`` in output.
name4 : `str`, optional
Name of a field that is unmapped in output.
Returns
-------
mapper : `lsst.afw.table.SchemaMapper`
The created mapper.
"""
schema = lsst.afw.table.Schema()
schema.addField(name1, type=float)
schema.addField(name2, type=float)
schema.addField(name3, type=float)
schema.addField("asdf", type="Flag")
mapper = lsst.afw.table.SchemaMapper(schema)

# add a default mapping for the first field
mapper.addMapping(schema.find(name1).key)

# add a mapping to a new field for the second field
field = lsst.afw.table.Field[float](name3, "doc for thingy")
mapper.addMapping(schema.find(name2).key, field)

# add a totally separate field to the output
mapper.addOutputField(lsst.afw.table.Field[float](name4, 'docstring'))
return mapper

def testOperatorEquals(self):
mapper1 = self._makeMapper()
mapper2 = self._makeMapper()
self.assertEqual(mapper1, mapper2)

def testNotEqualInput(self):
"""Check that differing input schema compare not equal."""
mapper1 = self._makeMapper(name2="somethingelse")
mapper2 = self._makeMapper()
# output schema should still be equal
self.assertTrue(_checkSchemaIdentical(mapper1.getOutputSchema(), mapper2.getOutputSchema()))
self.assertNotEqual(mapper1, mapper2)

def testNotEqualOutput(self):
"""Check that differing output schema compare not equal."""
mapper1 = self._makeMapper(name4="another")
mapper2 = self._makeMapper()
# input schema should still be equal
self.assertTrue(_checkSchemaIdentical(mapper1.getInputSchema(), mapper2.getInputSchema()))
self.assertNotEqual(mapper1, mapper2)

def testNotEqualMappings(self):
"""Check that differing mappings but same schema compare not equal."""
schema = lsst.afw.table.Schema()
schema.addField('a', type=np.int32, doc="int")
schema.addField('b', type=np.int32, doc="int")
mapper1 = lsst.afw.table.SchemaMapper(schema)
mapper2 = lsst.afw.table.SchemaMapper(schema)
mapper1.addMapping(schema['a'].asKey(), 'c')
mapper1.addMapping(schema['b'].asKey(), 'd')
mapper2.addMapping(schema['b'].asKey(), 'c')
mapper2.addMapping(schema['a'].asKey(), 'd')

# input and output schemas should still be equal
self.assertTrue(_checkSchemaIdentical(mapper1.getInputSchema(), mapper2.getInputSchema()))
self.assertTrue(_checkSchemaIdentical(mapper1.getOutputSchema(), mapper2.getOutputSchema()))
self.assertNotEqual(mapper1, mapper2)

def testNotEqualMappingsSomeFieldsUnmapped(self):
"""Check that differing mappings, with some unmapped fields, but the
same input and output schema compare not equal.
"""
schema = lsst.afw.table.Schema()
schema.addField('a', type=np.int32, doc="int")
schema.addField('b', type=np.int32, doc="int")
mapper1 = lsst.afw.table.SchemaMapper(schema)
mapper2 = lsst.afw.table.SchemaMapper(schema)
mapper1.addMapping(schema['a'].asKey(), 'c')
mapper1.addMapping(schema['b'].asKey(), 'd')
mapper2.addMapping(schema['b'].asKey(), 'c')
# add an unmapped field to output of 2 to match 1
mapper2.addOutputField(lsst.afw.table.Field[np.int32]('d', doc="int"))

# input and output schemas should still be equal
self.assertTrue(_checkSchemaIdentical(mapper1.getInputSchema(), mapper2.getInputSchema()))
self.assertTrue(_checkSchemaIdentical(mapper1.getOutputSchema(), mapper2.getOutputSchema()))
self.assertNotEqual(mapper1, mapper2)

def testPickle(self):
schema = lsst.afw.table.Schema()
schema.addField("ra", type="Angle", doc="coord_ra")
schema.addField("dec", type="Angle", doc="coord_dec")
schema.addField("x", type="D", doc="position_x", units="pixel")
schema.addField("y", type="D", doc="position_y")
schema.addField("i", type="I", doc="int")
schema.addField("f", type="F", doc="float", units="m2")
schema.addField("flag", type="Flag", doc="a flag")
mapper = lsst.afw.table.SchemaMapper(schema)
mapper.addMinimalSchema(schema)
inKey = schema.addField("bb", type=float)
outField = lsst.afw.table.Field[float]("cc", "doc for bb->cc")
mapper.addMapping(inKey, outField, True)

pickled = pickle.dumps(mapper, protocol=pickle.HIGHEST_PROTOCOL)
unpickled = pickle.loads(pickled)
self.assertEqual(mapper, unpickled)

def testPickleMissingInput(self):
"""Test pickling with some fields not being mapped."""
mapper = self._makeMapper()

pickled = pickle.dumps(mapper, protocol=pickle.HIGHEST_PROTOCOL)
unpickled = pickle.loads(pickled)

self.assertEqual(mapper, unpickled)


class MemoryTester(lsst.utils.tests.MemoryTestCase):
pass
Expand Down

0 comments on commit 0f415aa

Please sign in to comment.