Skip to content

Commit

Permalink
Save stack versions in QuantumGraph.
Browse files Browse the repository at this point in the history
  • Loading branch information
MichelleGower committed Mar 26, 2024
1 parent 7c47f2a commit 83de5ef
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 13 deletions.
1 change: 1 addition & 0 deletions doc/changes/DM-43225.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
QuantumGraph generation now saves software stack versions in the graph's metadata.
16 changes: 8 additions & 8 deletions python/lsst/pipe/base/graph/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
from lsst.daf.butler.persistence_context import PersistenceContextVars
from lsst.resources import ResourcePath, ResourcePathExpression
from lsst.utils.introspection import get_full_type_name
from lsst.utils.packages import Packages
from networkx.drawing.nx_agraph import write_dot

from ..connections import iterConnections
Expand Down Expand Up @@ -174,7 +175,10 @@ def _buildGraphs(
"""Build the graph that is used to store the relation between tasks,
and the graph that holds the relations between quanta
"""
self._metadata = metadata
# Save packages to metadata
self._metadata = dict(metadata) if metadata is not None else {}
self._metadata["packages"] = Packages.fromSystem()

self._buildId = _buildId if _buildId is not None else BuildId(f"{time.time()}-{os.getpid()}")
# Data structure used to identify relations between
# DatasetTypeName -> TaskDef.
Expand Down Expand Up @@ -757,8 +761,6 @@ def metadata(self) -> MappingProxyType[str, Any] | None:
The mapping is a dynamic view of this object's metadata. Values should
be able to be serialized in JSON.
"""
if self._metadata is None:
return None
return MappingProxyType(self._metadata)

def initInputRefs(self, taskDef: TaskDef) -> list[DatasetRef] | None:
Expand Down Expand Up @@ -1246,11 +1248,9 @@ def _update_intermediate_refs(
if update_graph_id:
self._buildId = BuildId(f"{time.time()}-{os.getpid()}")

# Update metadata if present.
if self._metadata is not None and metadata_key is not None:
metadata = dict(self._metadata)
metadata[metadata_key] = run
self._metadata = metadata
# Update run if given.
if metadata_key is not None:
self._metadata[metadata_key] = run

@property
def graphID(self) -> BuildId:
Expand Down
16 changes: 11 additions & 5 deletions tests/test_quantumGraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
from lsst.pipe.base.graph.quantumNode import BuildId, QuantumNode
from lsst.pipe.base.tests.util import check_output_run, get_output_refs
from lsst.utils.introspection import get_full_type_name
from lsst.utils.packages import Packages

METADATA = {"a": [1, 2, 3]}

Expand Down Expand Up @@ -498,7 +499,7 @@ def testSaveLoadUri(self) -> None:
uri = tmpFile.name
self.qGraph.saveUri(uri)
restore = QuantumGraph.loadUri(uri)
self.assertEqual(restore.metadata, METADATA)
self.assertEqual(restore.metadata, self.qGraph.metadata)
self.assertEqual(self.qGraph, restore)
nodeNumberId = random.randint(0, len(self.qGraph) - 1)
nodeNumber = [n.nodeId for n in self.qGraph][nodeNumberId]
Expand Down Expand Up @@ -580,18 +581,23 @@ def testUpdateRun(self) -> None:
self.assertTrue(set(ref.id for ref in output_refs).isdisjoint(set(ref.id for ref in output_refs2)))

# Also update metadata.
self.qGraph.updateRun("updated-run2", metadata_key="ouput_run")
self.qGraph.updateRun("updated-run2", metadata_key="output_run")
self.assertEqual(check_output_run(self.qGraph, "updated-run2"), [])
self.assertEqual(self.qGraph.graphID, graph_id)
assert self.qGraph.metadata is not None
self.assertIn("ouput_run", self.qGraph.metadata)
self.assertEqual(self.qGraph.metadata["ouput_run"], "updated-run2")
self.assertIn("output_run", self.qGraph.metadata)
self.assertEqual(self.qGraph.metadata["output_run"], "updated-run2")

# Update graph ID.
self.qGraph.updateRun("updated-run3", metadata_key="ouput_run", update_graph_id=True)
self.qGraph.updateRun("updated-run3", metadata_key="output_run", update_graph_id=True)
self.assertEqual(check_output_run(self.qGraph, "updated-run3"), [])
self.assertNotEqual(self.qGraph.graphID, graph_id)

def testMetadataPackage(self) -> None:
"""Test package versions added to QuantumGraph metadata."""
packages = Packages.fromSystem()
self.assertEqual(self.qGraph.metadata["packages"], packages)


class MyMemoryTestCase(lsst.utils.tests.MemoryTestCase):
"""Run file leak tests."""
Expand Down

0 comments on commit 83de5ef

Please sign in to comment.