Merge pull request #425 from lsst/tickets/DM-27418
DM-27418: Use safe_load/dump for YAML
timj committed Nov 5, 2020
2 parents 39b7668 + 477dcae commit f286447
Showing 4 changed files with 98 additions and 40 deletions.
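
Note (not part of this commit): switching from yaml.load/yaml.dump to yaml.safe_load/yaml.safe_dump restricts (de)serialization to the safe subset of YAML tags, so stored files can no longer construct arbitrary Python objects when read back. A minimal standalone sketch of the behaviour difference, using only PyYAML:

import yaml

# A document carrying a full Python object tag. The UnsafeLoader path that
# _readFile previously used would import and construct the object;
# safe_load refuses such tags outright.
doc = "!!python/object/apply:os.getcwd []"

try:
    yaml.safe_load(doc)
except yaml.constructor.ConstructorError as exc:
    print("safe_load rejected the python tag:", exc.problem)

# Plain mappings, sequences and scalars round-trip unchanged under the
# safe loader/dumper pair.
data = {"format": "yaml", "unsafe_dump": True}
assert yaml.safe_load(yaml.safe_dump(data)) == data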
15 changes: 12 additions & 3 deletions python/lsst/daf/butler/configs/datastores/formatters.yaml
@@ -43,11 +43,20 @@ Packages:
   formatter: lsst.obs.base.formatters.packages.PackagesFormatter
   parameters:
     format: yaml
-PropertyList: lsst.daf.butler.formatters.yaml.YamlFormatter
-PropertySet: lsst.daf.butler.formatters.yaml.YamlFormatter
+PropertyList:
+  formatter: lsst.daf.butler.formatters.yaml.YamlFormatter
+  parameters:
+    unsafe_dump: true
+PropertySet:
+  formatter: lsst.daf.butler.formatters.yaml.YamlFormatter
+  parameters:
+    unsafe_dump: true
 NumpyArray: lsst.daf.butler.formatters.pickle.PickleFormatter
 Plot: lsst.daf.butler.formatters.matplotlib.MatplotlibFormatter
-MetricValue: lsst.daf.butler.formatters.yaml.YamlFormatter
+MetricValue:
+  formatter: lsst.daf.butler.formatters.yaml.YamlFormatter
+  parameters:
+    unsafe_dump: true
 BrighterFatterKernel: lsst.daf.butler.formatters.pickle.PickleFormatter
 StructuredDataDict: lsst.daf.butler.formatters.yaml.YamlFormatter
 Filter: lsst.obs.base.formatters.filter.FilterFormatter
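
For context on the new unsafe_dump write parameter (not part of the diff): yaml.safe_dump only understands plain Python types, while the types kept on the full dumper here (PropertyList, PropertySet, MetricValue) presumably register their own representers, the exact case the supportedWriteParameters docstring below describes. A toy sketch, with a hypothetical Angle class standing in for such a type:

import yaml


class Angle:
    # Toy stand-in for a class that registers its own YAML representer.
    def __init__(self, degrees):
        self.degrees = degrees


# add_representer attaches to yaml.Dumper by default, not yaml.SafeDumper.
yaml.add_representer(
    Angle,
    lambda dumper, obj: dumper.represent_mapping("!Angle", {"degrees": obj.degrees}),
)

print(yaml.dump(Angle(45.0)))    # uses the registered representer

try:
    yaml.safe_dump(Angle(45.0))  # SafeDumper knows nothing about Angle
except yaml.representer.RepresenterError as exc:
    print("safe_dump failed:", exc)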
16 changes: 6 additions & 10 deletions python/lsst/daf/butler/formatters/json.py
@@ -60,14 +60,10 @@ def _readFile(self, path: str, pytype: Optional[Type[Any]] = None) -> Any:
         Returns
         -------
         data : `object`
-            Either data as Python object read from JSON file, or None
-            if the file could not be opened.
+            Data as Python object read from JSON file.
         """
-        try:
-            with open(path, "rb") as fd:
-                data = self._fromBytes(fd.read(), pytype)
-        except FileNotFoundError:
-            data = None
+        with open(path, "rb") as fd:
+            data = self._fromBytes(fd.read(), pytype)
 
         return data
 
@@ -88,8 +84,6 @@ def _writeFile(self, inMemoryDataset: Any) -> None:
             The file could not be written.
         """
         with open(self.fileDescriptor.location.path, "wb") as fd:
-            if hasattr(inMemoryDataset, "_asdict"):
-                inMemoryDataset = inMemoryDataset._asdict()
             fd.write(self._toBytes(inMemoryDataset))
 
     def _fromBytes(self, serializedDataset: bytes, pytype: Optional[Type[Any]] = None) -> Any:
@@ -133,6 +127,8 @@ def _toBytes(self, inMemoryDataset: Any) -> bytes:
         Exception
             The object could not be serialized.
         """
+        if hasattr(inMemoryDataset, "_asdict"):
+            inMemoryDataset = inMemoryDataset._asdict()
         return json.dumps(inMemoryDataset, ensure_ascii=False).encode()
 
     def _coerceType(self, inMemoryDataset: Any, storageClass: StorageClass,
@@ -153,7 +149,7 @@ def _coerceType(self, inMemoryDataset: Any, storageClass: StorageClass,
         inMemoryDataset : `object`
             Object of expected type `pytype`.
         """
-        if pytype is not None and not hasattr(builtins, pytype.__name__):
+        if inMemoryDataset is not None and pytype is not None and not hasattr(builtins, pytype.__name__):
             if storageClass.isComposite():
                 inMemoryDataset = storageClass.delegate().assemble(inMemoryDataset, pytype=pytype)
             elif not isinstance(inMemoryDataset, pytype):
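
Aside (not from the diff): moving the _asdict() handling from _writeFile into _toBytes means the namedtuple-style conversion happens on every serialization path, not only when writing to a file. A small illustration of why the conversion matters for JSON, using a plain namedtuple (Measurement is a made-up example type):

import json
from collections import namedtuple

Measurement = namedtuple("Measurement", ["name", "value"])
m = Measurement("flux", 1.25)

# Without the conversion a namedtuple serializes as a bare JSON array
# and the field names are lost.
print(json.dumps(m))                            # ["flux", 1.25]

# With the _asdict() conversion the field names survive the round trip.
payload = json.dumps(m._asdict(), ensure_ascii=False).encode()
print(json.loads(payload))                      # {'name': 'flux', 'value': 1.25}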
48 changes: 38 additions & 10 deletions python/lsst/daf/butler/formatters/yaml.py
@@ -47,6 +47,10 @@ class YamlFormatter(FileFormatter):
     unsupportedParameters = None
     """This formatter does not support any parameters"""
 
+    supportedWriteParameters = frozenset({"unsafe_dump"})
+    """Allow the normal yaml.dump to be used to write the YAML. Use this
+    if you know that your class has registered representers."""
+
     def _readFile(self, path: str, pytype: Type[Any] = None) -> Any:
         """Read a file from the path in YAML format.
@@ -65,7 +69,7 @@ def _readFile(self, path: str, pytype: Type[Any] = None) -> Any:
 
         Notes
         -----
-        The `~yaml.UnsafeLoader` is used when parsing the YAML file.
+        The `~yaml.SafeLoader` is used when parsing the YAML file.
         """
         try:
             with open(path, "rb") as fd:
@@ -90,11 +94,13 @@ def _fromBytes(self, serializedDataset: bytes, pytype: Optional[Type[Any]] = Non
         inMemoryDataset : `object`
             The requested data as an object, or None if the string could
             not be read.
+
+        Notes
+        -----
+        The `~yaml.SafeLoader` is used when parsing the YAML.
         """
-        try:
-            data = yaml.load(serializedDataset, Loader=yaml.FullLoader)
-        except yaml.YAMLError:
-            data = None
+        data = yaml.safe_load(serializedDataset)
 
         try:
             data = data.exportAsDict()
         except AttributeError:
@@ -105,7 +111,8 @@ def _writeFile(self, inMemoryDataset: Any) -> None:
         """Write the in memory dataset to file on disk.
 
         Will look for `_asdict()` method to aid YAML serialization, following
-        the approach of the simplejson module.
+        the approach of the simplejson module. The `dict` will be passed
+        to the relevant constructor on read.
 
         Parameters
         ----------
@@ -116,15 +123,23 @@ def _writeFile(self, inMemoryDataset: Any) -> None:
         ------
         Exception
             The file could not be written.
+
+        Notes
+        -----
+        The `~yaml.SafeDumper` is used when generating the YAML serialization.
+        This will fail for data structures that have complex python classes
+        without a registered YAML representer.
         """
         with open(self.fileDescriptor.location.path, "wb") as fd:
-            if hasattr(inMemoryDataset, "_asdict"):
-                inMemoryDataset = inMemoryDataset._asdict()
             fd.write(self._toBytes(inMemoryDataset))
 
     def _toBytes(self, inMemoryDataset: Any) -> bytes:
         """Write the in memory dataset to a bytestring.
+
+        Will look for `_asdict()` method to aid YAML serialization, following
+        the approach of the simplejson module. The `dict` will be passed
+        to the relevant constructor on read.
 
         Parameters
         ----------
         inMemoryDataset : `object`
@@ -139,8 +154,21 @@ def _toBytes(self, inMemoryDataset: Any) -> bytes:
         ------
         Exception
             The object could not be serialized.
+
+        Notes
+        -----
+        The `~yaml.SafeDumper` is used when generating the YAML serialization.
+        This will fail for data structures that have complex python classes
+        without a registered YAML representer.
         """
-        return yaml.dump(inMemoryDataset).encode()
+        if hasattr(inMemoryDataset, "_asdict"):
+            inMemoryDataset = inMemoryDataset._asdict()
+        unsafe_dump = self.writeParameters.get("unsafe_dump", False)
+        if unsafe_dump:
+            serialized = yaml.dump(inMemoryDataset)
+        else:
+            serialized = yaml.safe_dump(inMemoryDataset)
+        return serialized.encode()
 
     def _coerceType(self, inMemoryDataset: Any, storageClass: StorageClass,
                     pytype: Optional[Type[Any]] = None) -> Any:
@@ -160,7 +188,7 @@ def _coerceType(self, inMemoryDataset: Any, storageClass: StorageClass,
         inMemoryDataset : `object`
             Object of expected type `pytype`.
         """
-        if pytype is not None and not hasattr(builtins, pytype.__name__):
+        if inMemoryDataset is not None and pytype is not None and not hasattr(builtins, pytype.__name__):
             if storageClass.isComposite():
                 inMemoryDataset = storageClass.delegate().assemble(inMemoryDataset, pytype=pytype)
             elif not isinstance(inMemoryDataset, pytype):
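
Aside (not from the diff): the new _toBytes keys its choice of dumper off the unsafe_dump write parameter. A simplified standalone sketch of that toggle, assuming an ordinary dict of write parameters instead of the formatter's writeParameters property (to_yaml_bytes is a hypothetical helper):

import yaml


def to_yaml_bytes(obj, write_parameters=None):
    # Hypothetical helper mirroring the safe-by-default policy above.
    write_parameters = write_parameters or {}
    if hasattr(obj, "_asdict"):
        obj = obj._asdict()
    if write_parameters.get("unsafe_dump", False):
        # Full dumper: needed when the type registers its own representers.
        serialized = yaml.dump(obj)
    else:
        # SafeDumper: plain types only, no arbitrary Python objects.
        serialized = yaml.safe_dump(obj)
    return serialized.encode()


print(to_yaml_bytes({"a": 1}))                         # safe path (default)
print(to_yaml_bytes({"a": 1}, {"unsafe_dump": True}))  # full dumper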
59 changes: 42 additions & 17 deletions tests/test_butler.py
@@ -129,14 +129,14 @@ def setUpClass(cls):
         cls.storageClassFactory = StorageClassFactory()
         cls.storageClassFactory.addFromConfig(cls.configFile)
 
-    def assertGetComponents(self, butler, datasetRef, components, reference):
+    def assertGetComponents(self, butler, datasetRef, components, reference, collections=None):
         datasetType = datasetRef.datasetType
         dataId = datasetRef.dataId
         deferred = butler.getDirectDeferred(datasetRef)
 
         for component in components:
             compTypeName = datasetType.componentTypeName(component)
-            result = butler.get(compTypeName, dataId)
+            result = butler.get(compTypeName, dataId, collections=collections)
             self.assertEqual(result, getattr(reference, component))
             result_deferred = deferred.get(component=component)
             self.assertEqual(result_deferred, result)
@@ -194,25 +194,41 @@ def runPutGetTest(self, storageClass, datasetTypeName):
 
         # Put and remove the dataset once as a DatasetRef, once as a dataId,
         # and once with a DatasetType
+
+        # Keep track of any collections we add and do not clean up
+        expected_collections = {run, tag}
+
+        counter = 0
         for args in ((refIn,), (datasetTypeName, dataId), (datasetType, dataId)):
+            # Since we are using subTest we can get cascading failures
+            # here with the first attempt failing and the others failing
+            # immediately because the dataset already exists. Work around
+            # this by using a distinct run collection each time
+            counter += 1
+            this_run = f"put_run_{counter}"
+            this_tag = f"put_tag_{counter}"
+            butler.registry.registerCollection(this_run, type=CollectionType.RUN)
+            butler.registry.registerCollection(this_tag, type=CollectionType.TAGGED)
+            expected_collections.update({this_run, this_tag})
+
             with self.subTest(args=args):
-                ref = butler.put(metric, *args)
+                ref = butler.put(metric, *args, run=this_run, tags=[this_tag])
                 self.assertIsInstance(ref, DatasetRef)
 
                 # Test getDirect
                 metricOut = butler.getDirect(ref)
                 self.assertEqual(metric, metricOut)
                 # Test get
-                metricOut = butler.get(ref.datasetType.name, dataId)
+                metricOut = butler.get(ref.datasetType.name, dataId, collections=this_run)
                 self.assertEqual(metric, metricOut)
                 # Test get with a datasetRef
-                metricOut = butler.get(ref)
+                metricOut = butler.get(ref, collections=this_run)
                 self.assertEqual(metric, metricOut)
                 # Test getDeferred with dataId
-                metricOut = butler.getDeferred(ref.datasetType.name, dataId).get()
+                metricOut = butler.getDeferred(ref.datasetType.name, dataId, collections=this_run).get()
                 self.assertEqual(metric, metricOut)
                 # Test getDeferred with a datasetRef
-                metricOut = butler.getDeferred(ref).get()
+                metricOut = butler.getDeferred(ref, collections=this_run).get()
                 self.assertEqual(metric, metricOut)
                 # and deferred direct with ref
                 metricOut = butler.getDirectDeferred(ref).get()
@@ -221,13 +237,14 @@ def runPutGetTest(self, storageClass, datasetTypeName):
                 # Check we can get components
                 if storageClass.isComposite():
                     self.assertGetComponents(butler, ref,
-                                             ("summary", "data", "output"), metric)
+                                             ("summary", "data", "output"), metric,
+                                             collections=this_run)
 
                 # Remove from the tagged collection only; after that we
                 # shouldn't be able to find it unless we use the dataset_id.
-                butler.pruneDatasets([ref])
+                butler.pruneDatasets([ref], tags=[this_tag])
                 with self.assertRaises(LookupError):
-                    butler.datasetExists(*args)
+                    butler.datasetExists(*args, collections=this_tag)
                 # Registry still knows about it, if we use the dataset_id.
                 self.assertEqual(butler.registry.getDataset(ref.id), ref)
                 # If we use the output ref with the dataset_id, we should
@@ -236,29 +253,37 @@ def runPutGetTest(self, storageClass, datasetTypeName):
 
                 # Reinsert into collection, then delete from Datastore *and*
                 # remove from collection.
-                butler.registry.associate(tag, [ref])
-                butler.pruneDatasets([ref], unstore=True)
+                butler.registry.associate(this_tag, [ref])
+                butler.pruneDatasets([ref], unstore=True, tags=[this_tag])
                 # Lookup with original args should still fail.
                 with self.assertRaises(LookupError):
-                    butler.datasetExists(*args)
+                    butler.datasetExists(*args, collections=this_tag)
                 # Now getDirect() should fail, too.
                 with self.assertRaises(FileNotFoundError, msg=f"Checking ref {ref} not found"):
                     butler.getDirect(ref)
                 # Registry still knows about it, if we use the dataset_id.
                 self.assertEqual(butler.registry.getDataset(ref.id), ref)
 
                 # Now remove the dataset completely.
-                butler.pruneDatasets([ref], purge=True, unstore=True)
+                butler.pruneDatasets([ref], purge=True, unstore=True, tags=[this_tag], run=this_run)
                 # Lookup with original args should still fail.
                 with self.assertRaises(LookupError):
-                    butler.datasetExists(*args)
+                    butler.datasetExists(*args, collections=this_run)
                 # getDirect() should still fail.
                 with self.assertRaises(FileNotFoundError):
                     butler.getDirect(ref)
                 # Registry shouldn't be able to find it by dataset_id anymore.
                 self.assertIsNone(butler.registry.getDataset(ref.id))
 
-        # Put the dataset again, since the last thing we did was remove it.
+            # Cleanup
+            for coll in (this_run, this_tag):
+                # Do explicit registry removal since we know they are
+                # empty
+                butler.registry.removeCollection(coll)
+                expected_collections.remove(coll)
+
+        # Put the dataset again, since the last thing we did was remove it
+        # and we want to use the default collection.
         ref = butler.put(metric, refIn)
 
         # Get with parameters
@@ -324,7 +349,7 @@ def runPutGetTest(self, storageClass, datasetTypeName):
 
         # Check we have a collection
         collections = set(butler.registry.queryCollections())
-        self.assertEqual(collections, {run, tag})
+        self.assertEqual(collections, expected_collections)
 
         # Clean up to check that we can remove something that may have
         # already had a component removed
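
Aside (not from the diff): the per-iteration put_run_N / put_tag_N collections avoid subTest cascades, where one failed attempt leaves a dataset behind and every later attempt then fails with "already exists" instead of its real error. The same isolation pattern in stripped-down form, with a hypothetical in-memory store standing in for collection registration:

import unittest


class PutGetPatternTest(unittest.TestCase):

    def test_variants(self):
        counter = 0
        for args in (("a",), ("b", 1), ("c", 2)):
            # A distinct, freshly named resource per attempt keeps one
            # failure from cascading into later subTests.
            counter += 1
            store = {"name": f"put_run_{counter}", "contents": set()}
            with self.subTest(args=args):
                store["contents"].add(args)
                self.assertIn(args, store["contents"])


if __name__ == "__main__":
    unittest.main()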
