Skip to content

Commit

Permalink
Merge pull request #43 from lsst/tickets/DM-25377
Browse files Browse the repository at this point in the history
DM-25377: Rewrite serialization to allow access to byte form
  • Loading branch information
timj committed Jun 12, 2020
2 parents 04f1aa2 + cd52125 commit 082faa5
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 15 deletions.
73 changes: 58 additions & 15 deletions python/lsst/base/packages.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,10 @@ class Packages:
This is essentially a wrapper around a dict with some conveniences.
"""

formats = {".pkl": "pickle",
".pickle": "pickle",
".yaml": "yaml"}

def __init__(self, packages):
assert isinstance(packages, Mapping)
self._packages = packages
Expand All @@ -279,6 +283,27 @@ def fromSystem(cls):
packages.update(getEnvironmentPackages()) # Should be last, to override products with LOCAL versions
return cls(packages)

@classmethod
def fromBytes(cls, data, format):
"""Construct the object from a byte representation.
Parameters
----------
data : `bytes`
The serialized form of this object in bytes.
format : `str`
The format of those bytes. Can be ``yaml`` or ``pickle``.
"""
if format == "pickle":
new = pickle.loads(data)
elif format == "yaml":
new = yaml.load(data, Loader=yaml.SafeLoader)
else:
raise ValueError(f"Unexpected serialization format given: {format}")
if not isinstance(new, cls):
raise TypeError(f"Extracted object of class '{type(new)}' but expected '{cls}'")
return new

@classmethod
def read(cls, filename):
"""Read packages from filename.
Expand All @@ -295,14 +320,34 @@ def read(cls, filename):
packages : `Packages`
"""
_, ext = os.path.splitext(filename)
if ext in (".pickle", ".pkl"):
with open(filename, "rb") as ff:
return pickle.load(ff)
elif ext == ".yaml":
with open(filename, "r") as ff:
return yaml.load(ff, Loader=yaml.SafeLoader)
if ext not in cls.formats:
raise ValueError(f"Format from {ext} extension in file {filename} not recognized")
with open(filename, "rb") as ff:
# We assume that these classes are tiny so there is no
# substantive memory impact by reading the entire file up front
data = ff.read()
return cls.fromBytes(data, cls.formats[ext])

def toBytes(self, format):
"""Convert the object to a serialized bytes form using the
specified format.
Parameters
----------
format : `str`
Format to use when serializing. Can be ``yaml`` or ``pickle``.
Returns
-------
data : `bytes`
Byte string representing the serialized object.
"""
if format == "pickle":
return pickle.dumps(self)
elif format == "yaml":
return yaml.dump(self).encode("utf-8")
else:
raise ValueError(f"Unable to determine how to read file {filename} from extension {ext}")
raise ValueError(f"Unexpected serialization format requested: {format}")

def write(self, filename):
"""Write to file.
Expand All @@ -315,14 +360,12 @@ def write(self, filename):
``.pickle`` and ``.yaml``
"""
_, ext = os.path.splitext(filename)
if ext in (".pickle", ".pkl"):
with open(filename, "wb") as ff:
pickle.dump(self, ff)
elif ext == ".yaml":
with open(filename, "w") as ff:
yaml.dump(self, ff)
else:
raise ValueError(f"Unexpected file format requested: {ext}")
if ext not in self.formats:
raise ValueError(f"Format from {ext} extension in file {filename} not recognized")
with open(filename, "wb") as ff:
# Assumes that the bytes serialization of this object is
# relatively small.
ff.write(self.toBytes(self.formats[ext]))

def __len__(self):
return len(self._packages)
Expand Down
16 changes: 16 additions & 0 deletions tests/test_packages.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,22 @@ def testPackages(self):
self.assertDictEqual(packages.extra(new), {})
self.assertEqual(len(packages), len(new))

# Serialize via bytes
for format in ("pickle", "yaml"):
asbytes = new.toBytes(format)
from_bytes = lsst.base.Packages.fromBytes(asbytes, format)
self.assertEqual(from_bytes, new)

with self.assertRaises(ValueError):
new.toBytes("unknown_format")

with self.assertRaises(ValueError):
lsst.base.Packages.fromBytes(from_bytes, "unknown_format")

with self.assertRaises(TypeError):
some_yaml = b"list: [1, 2]"
lsst.base.Packages.fromBytes(some_yaml, "yaml")


if __name__ == "__main__":
unittest.main()

0 comments on commit 082faa5

Please sign in to comment.