Skip to content

Commit

Permalink
Merge pull request #468 from lsst/tickets/DM-20019
Browse files Browse the repository at this point in the history
DM-20019: Fix pickling of String Fields
  • Loading branch information
parejkoj committed Jun 6, 2019
2 parents a1fda76 + 6ee9f7a commit 3287933
Show file tree
Hide file tree
Showing 3 changed files with 140 additions and 30 deletions.
9 changes: 9 additions & 0 deletions include/lsst/afw/table/detail/Access.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,15 @@ class Access final {
/// @internal Build a Key<Flag> from an offset and a bit index by calling the private Key constructor.
static Key<Flag> makeKey(int offset, int bit) {
    return Key<Flag>(offset, bit);
}

/// @internal Build a Key<std::string> from an offset and a size by calling the private Key constructor.
static Key<std::string> makeKeyString(int offset, int size) {
    return Key<std::string>(offset, size);
}

/// @internal Build a Key<Array<T>> from an offset and a size by calling the private Key constructor.
template <typename T>
static Key<Array<T>> makeKeyArray(int offset, int size) { return Key<Array<T>>(offset, size); }

/// @internal Add some padding to a schema without adding a field.
static void padSchema(Schema &schema, int bytes) {
schema._edit();
Expand Down
128 changes: 112 additions & 16 deletions python/lsst/afw/table/schema/schema.cc
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,55 @@ void declareFieldBaseSpecializations(PyFieldBase<std::string> &cls) {
cls.def("getSize", &FieldBase<std::string>::getSize);
}

// Specializations for Field

/**
 * Register pickle support on the Python wrapper of a Field<T>.
 *
 * The pickled state is the 3-tuple (name, doc, units). Unpickling checks the
 * length of the tuple and then rebuilds the Field from those three strings;
 * it throws std::runtime_error when the tuple does not hold exactly three
 * entries.
 *
 * @param cls  Python class wrapper for Field<T> that receives the pickle methods.
 */
template <typename T>
void declareFieldSpecializations(PyField<T> &cls) {
    cls.def(py::pickle(
            // Get-state: the tuple that encodes the Field.
            [](Field<T> const &field) {
                return py::make_tuple(field.getName(), field.getDoc(), field.getUnits());
            },
            // Set-state: validate the tuple length, then reconstruct the Field.
            [](py::tuple state) {
                int const expectedLength = 3;
                if (state.size() == expectedLength) {
                    auto const name = state[0].cast<std::string>();
                    auto const doc = state[1].cast<std::string>();
                    auto const units = state[2].cast<std::string>();
                    return Field<T>(name, doc, units);
                }
                std::ostringstream message;
                message << "Invalid number of parameters (" << state.size()
                        << ") when unpickling; expected " << expectedLength;
                throw std::runtime_error(message.str());
            }));
}

/**
 * Shared pickle implementation for Field types that carry a size
 * (used by both the Field<Array<T>> and Field<std::string> overloads).
 *
 * The pickled state is the 4-tuple (name, doc, units, size). Unpickling checks
 * the length of the tuple and then rebuilds the Field from those values; it
 * throws std::runtime_error when the tuple does not hold exactly four entries.
 *
 * @param cls  Python class wrapper for Field<T> that receives the pickle methods.
 */
template <typename T>
void _sequenceFieldSpecializations(PyField<T> &cls) {
    cls.def(py::pickle(
            // Get-state: the tuple that encodes the Field, including its size.
            [](Field<T> const &field) {
                return py::make_tuple(field.getName(), field.getDoc(), field.getUnits(), field.getSize());
            },
            // Set-state: validate the tuple length, then reconstruct the Field.
            [](py::tuple state) {
                int const expectedLength = 4;
                if (state.size() == expectedLength) {
                    auto const name = state[0].cast<std::string>();
                    auto const doc = state[1].cast<std::string>();
                    auto const units = state[2].cast<std::string>();
                    auto const size = state[3].cast<int>();
                    return Field<T>(name, doc, units, size);
                }
                std::ostringstream message;
                message << "Invalid number of parameters (" << state.size()
                        << ") when unpickling; expected " << expectedLength;
                throw std::runtime_error(message.str());
            }));
}

/// Array fields pickle through the shared sequence implementation (state includes the size).
template <typename T>
void declareFieldSpecializations(PyField<Array<T>> &cls) { _sequenceFieldSpecializations(cls); }

/// String fields pickle through the shared sequence implementation (state includes the size).
void declareFieldSpecializations(PyField<std::string> &cls) {
    _sequenceFieldSpecializations(cls);
}

// Specializations for KeyBase

template <typename T>
Expand Down Expand Up @@ -157,7 +206,13 @@ void declareKeySpecializations(PyKey<T> &cls) {
return py::make_tuple(self.getOffset());
},
[](py::tuple t) {
if (t.size() != 1) throw std::runtime_error("Invalid number of parameters when unpickling!");
int const NPARAMS = 1;
if (t.size() != NPARAMS) {
std::ostringstream os;
os << "Invalid number of parameters (" << t.size() << ") when unpickling; expected "
<< NPARAMS;
throw std::runtime_error(os.str());
}
return detail::Access::makeKey<T>(t[0].cast<int>());
}));
}
Expand All @@ -173,28 +228,70 @@ void declareKeySpecializations(PyKey<Flag> &cls) {
return py::make_tuple(self.getOffset(), self.getBit());
},
[](py::tuple t) {
if (t.size() != 2) throw std::runtime_error("Invalid number of parameters when unpickling!");
int const NPARAMS = 2;
if (t.size() != NPARAMS) {
std::ostringstream os;
os << "Invalid number of parameters (" << t.size() << ") when unpickling; expected "
<< NPARAMS;
throw std::runtime_error(os.str());
}
return detail::Access::makeKey(t[0].cast<int>(), t[1].cast<int>());
}));
}

/**
 * Declare the Python-visible members specific to an array-valued Key.
 *
 * Adds, on top of whatever declareKeyAccessors() registers:
 *  - a read-only `subfields` property: a tuple of the element indices
 *    0 .. getSize()-1;
 *  - a read-only `subkeys` property: a tuple holding `self[i]` for each of
 *    those indices;
 *  - pickle support whose state is the 2-tuple (offset, element count).
 *
 * Unpickling throws std::runtime_error when the state tuple does not hold
 * exactly two entries.
 *
 * @param cls  Python class wrapper for Key<Array<T>> that receives the members.
 */
template <typename T>
void declareKeySpecializations(PyKey<Array<T>> &cls) {
    declareKeyAccessors(cls);
    cls.def_property_readonly("subfields", [](Key<Array<T>> const &self) -> py::object {
        // Collect in a list first, then hand Python an immutable tuple.
        py::list result;
        for (int i = 0; i < self.getSize(); ++i) {
            result.append(py::cast(i));
        }
        return py::tuple(result);
    });
    cls.def_property_readonly("subkeys", [](Key<Array<T>> const &self) -> py::object {
        py::list result;
        for (int i = 0; i < self.getSize(); ++i) {
            result.append(py::cast(self[i]));
        }
        return py::tuple(result);
    });
    cls.def(py::pickle(
            [](Key<Array<T>> const &self) {
                /* Return a tuple that fully encodes the state of the object */
                return py::make_tuple(self.getOffset(), self.getElementCount());
            },
            [](py::tuple t) {
                int const NPARAMS = 2;
                if (t.size() != NPARAMS) {
                    std::ostringstream os;
                    os << "Invalid number of parameters (" << t.size() << ") when unpickling; expected "
                       << NPARAMS;
                    throw std::runtime_error(os.str());
                }
                // The Key constructor is private; go through the Access helper.
                return detail::Access::makeKeyArray<T>(t[0].cast<int>(), t[1].cast<int>());
            }));
}

/**
 * Declare the Python-visible members specific to a string-valued Key.
 *
 * After calling declareKeyAccessors(), this adds read-only `subfields` and
 * `subkeys` properties that are always None, plus pickle support whose state
 * is the 2-tuple (offset, element count).
 *
 * Unpickling throws std::runtime_error when the state tuple does not hold
 * exactly two entries.
 *
 * @param cls  Python class wrapper for Key<std::string> that receives the members.
 */
void declareKeySpecializations(PyKey<std::string> &cls) {
    declareKeyAccessors(cls);

    // String keys expose the sub-item properties, but both are always None.
    cls.def_property_readonly("subfields", [](py::object const &) { return py::none(); });
    cls.def_property_readonly("subkeys", [](py::object const &) { return py::none(); });

    cls.def(py::pickle(
            // Get-state: the tuple that encodes the Key.
            [](Key<std::string> const &key) {
                return py::make_tuple(key.getOffset(), key.getElementCount());
            },
            // Set-state: validate the tuple length, then rebuild through the Access helper.
            [](py::tuple state) {
                int const expectedLength = 2;
                if (state.size() == expectedLength) {
                    auto const offset = state[0].cast<int>();
                    auto const size = state[1].cast<int>();
                    return detail::Access::makeKeyString(offset, size);
                }
                std::ostringstream message;
                message << "Invalid number of parameters (" << state.size()
                        << ") when unpickling; expected " << expectedLength;
                throw std::runtime_error(message.str());
            }));
}

// Wrap all helper classes (FieldBase, KeyBase, Key, Field, SchemaItem) for a Schema field type.
Expand All @@ -217,6 +314,8 @@ void declareSchemaType(py::module &mod) {

// Field
PyField<T> clsField(mod, ("Field" + suffix).c_str());
declareFieldSpecializations(clsField);

mod.attr("_Field")[pySuffix] = clsField;
clsField.def(py::init([astropyUnit]( // capture by value to refcount in Python instead of dangle in C++
std::string const &name, std::string const &doc, py::str const &units,
Expand All @@ -240,15 +339,6 @@ void declareSchemaType(py::module &mod) {
clsField.def("copyRenamed", &Field<T>::copyRenamed);
utils::python::addOutputOp(clsField, "__str__");
utils::python::addOutputOp(clsField, "__repr__");
clsField.def(py::pickle(
[](Field<T> const &self) {
/* Return a tuple that fully encodes the state of the object */
return py::make_tuple(self.getName(), self.getDoc(), self.getUnits());
},
[](py::tuple t) {
if (t.size() != 3) throw std::runtime_error("Invalid number of parameters when unpickling!");
return Field<T>(t[0].cast<std::string>(), t[1].cast<std::string>(), t[2].cast<std::string>());
}));

// Key
PyKey<T> clsKey(mod, ("Key" + suffix).c_str());
Expand Down Expand Up @@ -306,7 +396,13 @@ void declareSchemaType(py::module &mod) {
return py::make_tuple(self.key, self.field);
},
[](py::tuple t) {
if (t.size() != 2) throw std::runtime_error("Invalid number of parameters when unpickling!");
int const NPARAMS = 2;
if (t.size() != NPARAMS) {
std::ostringstream os;
os << "Invalid number of parameters (" << t.size() << ") when unpickling; expected "
<< NPARAMS;
throw std::runtime_error(os.str());
}
return SchemaItem<T>(t[0].cast<Key<T>>(), t[1].cast<Field<T>>());
}));
} // namespace
Expand Down
33 changes: 19 additions & 14 deletions tests/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,22 @@ def _checkSchemaIdentical(schema1, schema2):
return schema1.compare(schema2, lsst.afw.table.Schema.IDENTICAL) == lsst.afw.table.Schema.IDENTICAL


def addTestFields(schema):
    """Add Fields to the schema to test operations on each Field type.

    Eleven fields are added, in a fixed order: two ``Angle`` fields, two
    ``D`` fields, one ``I``, one ``F``, one ``Flag``, a ``String`` field with
    ``size=42``, a ``String`` field with ``size=0`` (documented as
    variable-length), an ``ArrayF`` field with ``size=10`` and an ``ArrayF``
    field with ``size=0`` (documented as variable-length).

    Parameters
    ----------
    schema
        Object providing ``addField(name, type=..., doc=..., ...)``; the tests
        in this file pass an `lsst.afw.table.Schema`. It is modified in place.
    """
    schema.addField("ra", type="Angle", doc="coord_ra")
    schema.addField("dec", type="Angle", doc="coord_dec")
    schema.addField("x", type="D", doc="position_x", units="pixel")
    schema.addField("y", type="D", doc="position_y")
    schema.addField("i", type="I", doc="int")
    schema.addField("f", type="F", doc="float", units="m2")
    schema.addField("flag", type="Flag", doc="a flag")
    schema.addField("string", type="String", doc="A string field", size=42)
    schema.addField("variable_string", type="String", doc="A variable-length string field", size=0)
    schema.addField("array", type="ArrayF", doc="An array field", size=10)
    schema.addField("variable_array", type="ArrayF", doc="A variable-length array field", size=0)


class SchemaTestCase(unittest.TestCase):

def testSchema(self):
Expand Down Expand Up @@ -177,13 +193,8 @@ def testComparison(self):

def testPickle(self):
schema = lsst.afw.table.Schema()
schema.addField("ra", type="Angle", doc="coord_ra")
schema.addField("dec", type="Angle", doc="coord_dec")
schema.addField("x", type="D", doc="position_x", units="pixel")
schema.addField("y", type="D", doc="position_y")
schema.addField("i", type="I", doc="int")
schema.addField("f", type="F", doc="float", units="m2")
schema.addField("flag", type="Flag", doc="a flag")
addTestFields(schema)

pickled = pickle.dumps(schema, protocol=pickle.HIGHEST_PROTOCOL)
unpickled = pickle.loads(pickled)
self.assertEqual(schema, unpickled)
Expand Down Expand Up @@ -403,13 +414,7 @@ def testNotEqualMappingsSomeFieldsUnmapped(self):

def testPickle(self):
schema = lsst.afw.table.Schema()
schema.addField("ra", type="Angle", doc="coord_ra")
schema.addField("dec", type="Angle", doc="coord_dec")
schema.addField("x", type="D", doc="position_x", units="pixel")
schema.addField("y", type="D", doc="position_y")
schema.addField("i", type="I", doc="int")
schema.addField("f", type="F", doc="float", units="m2")
schema.addField("flag", type="Flag", doc="a flag")
addTestFields(schema)
mapper = lsst.afw.table.SchemaMapper(schema)
mapper.addMinimalSchema(schema)
inKey = schema.addField("bb", type=float)
Expand Down

0 comments on commit 3287933

Please sign in to comment.