Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DM-25377: Rewrite serialization to allow access to byte form #43

Merged
merged 1 commit into from
Jun 12, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
73 changes: 58 additions & 15 deletions python/lsst/base/packages.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,10 @@ class Packages:
This is essentially a wrapper around a dict with some conveniences.
"""

formats = {".pkl": "pickle",
".pickle": "pickle",
".yaml": "yaml"}

def __init__(self, packages):
assert isinstance(packages, Mapping)
self._packages = packages
Expand All @@ -279,6 +283,27 @@ def fromSystem(cls):
packages.update(getEnvironmentPackages()) # Should be last, to override products with LOCAL versions
return cls(packages)

@classmethod
def fromBytes(cls, data, format):
"""Construct the object from a byte representation.

Parameters
----------
data : `bytes`
The serialized form of this object in bytes.
format : `str`
The format of those bytes. Can be ``yaml`` or ``pickle``.
"""
if format == "pickle":
new = pickle.loads(data)
elif format == "yaml":
new = yaml.load(data, Loader=yaml.SafeLoader)
else:
raise ValueError(f"Unexpected serialization format given: {format}")
if not isinstance(new, cls):
raise TypeError(f"Extracted object of class '{type(new)}' but expected '{cls}'")
return new

@classmethod
def read(cls, filename):
"""Read packages from filename.
Expand All @@ -295,14 +320,34 @@ def read(cls, filename):
packages : `Packages`
"""
_, ext = os.path.splitext(filename)
if ext in (".pickle", ".pkl"):
with open(filename, "rb") as ff:
return pickle.load(ff)
elif ext == ".yaml":
with open(filename, "r") as ff:
return yaml.load(ff, Loader=yaml.SafeLoader)
if ext not in cls.formats:
raise ValueError(f"Format from {ext} extension in file {filename} not recognized")
with open(filename, "rb") as ff:
# We assume that these classes are tiny so there is no
# substantive memory impact by reading the entire file up front
data = ff.read()
return cls.fromBytes(data, cls.formats[ext])

def toBytes(self, format):
"""Convert the object to a serialized bytes form using the
specified format.

Parameters
----------
format : `str`
Format to use when serializing. Can be ``yaml`` or ``pickle``.

Returns
-------
data : `bytes`
Byte string representing the serialized object.
"""
if format == "pickle":
return pickle.dumps(self)
elif format == "yaml":
return yaml.dump(self).encode("utf-8")
else:
raise ValueError(f"Unable to determine how to read file {filename} from extension {ext}")
raise ValueError(f"Unexpected serialization format requested: {format}")

def write(self, filename):
"""Write to file.
Expand All @@ -315,14 +360,12 @@ def write(self, filename):
``.pickle`` and ``.yaml``
"""
_, ext = os.path.splitext(filename)
if ext in (".pickle", ".pkl"):
with open(filename, "wb") as ff:
pickle.dump(self, ff)
elif ext == ".yaml":
with open(filename, "w") as ff:
yaml.dump(self, ff)
else:
raise ValueError(f"Unexpected file format requested: {ext}")
if ext not in self.formats:
raise ValueError(f"Format from {ext} extension in file {filename} not recognized")
with open(filename, "wb") as ff:
# Assumes that the bytes serialization of this object is
# relatively small.
ff.write(self.toBytes(self.formats[ext]))

def __len__(self):
return len(self._packages)
Expand Down
16 changes: 16 additions & 0 deletions tests/test_packages.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,22 @@ def testPackages(self):
self.assertDictEqual(packages.extra(new), {})
self.assertEqual(len(packages), len(new))

# Serialize via bytes
for format in ("pickle", "yaml"):
asbytes = new.toBytes(format)
from_bytes = lsst.base.Packages.fromBytes(asbytes, format)
self.assertEqual(from_bytes, new)

with self.assertRaises(ValueError):
new.toBytes("unknown_format")

with self.assertRaises(ValueError):
lsst.base.Packages.fromBytes(from_bytes, "unknown_format")

with self.assertRaises(TypeError):
some_yaml = b"list: [1, 2]"
lsst.base.Packages.fromBytes(some_yaml, "yaml")


if __name__ == "__main__":
unittest.main()