Skip to content

Commit

Permalink
Merge pull request #369 from lsst/tickets/DM-40392
Browse files Browse the repository at this point in the history
DM-40392: Re-implement QuantumGraph.updateRun method
  • Loading branch information
andy-slac committed Aug 19, 2023
2 parents 33ad370 + bfe6f61 commit 870df52
Show file tree
Hide file tree
Showing 5 changed files with 144 additions and 19 deletions.
1 change: 1 addition & 0 deletions doc/changes/DM-40392.bugfix.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
The `QuantumGraph.updateRun()` method is fixed so that it also updates the dataset IDs of references whose run collection has changed.
90 changes: 72 additions & 18 deletions python/lsst/pipe/base/graph/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,20 @@
import time
import uuid
from collections import defaultdict, deque
from collections.abc import Generator, Iterable, Mapping, MutableMapping
from collections.abc import Generator, Iterable, Iterator, Mapping, MutableMapping
from itertools import chain
from types import MappingProxyType
from typing import Any, BinaryIO, TypeVar

import networkx as nx
from lsst.daf.butler import DatasetRef, DatasetType, DimensionRecordsAccumulator, DimensionUniverse, Quantum
from lsst.daf.butler import (
DatasetId,
DatasetRef,
DatasetType,
DimensionRecordsAccumulator,
DimensionUniverse,
Quantum,
)
from lsst.resources import ResourcePath, ResourcePathExpression
from lsst.utils.introspection import get_full_type_name
from networkx.drawing.nx_agraph import write_dot
Expand Down Expand Up @@ -1229,32 +1236,79 @@ def updateRun(self, run: str, *, metadata_key: str | None = None, update_graph_i
update_graph_id : `bool`, optional
If `True` then also update graph ID with a new unique value.
"""
dataset_id_map: dict[DatasetId, DatasetId] = {}

def _update_refs_in_place(refs: list[DatasetRef], run: str) -> None:
"""Update list of `~lsst.daf.butler.DatasetRef` with new run and
dataset IDs.
def _update_output_refs(
refs: Iterable[DatasetRef], run: str, dataset_id_map: MutableMapping[DatasetId, DatasetId]
) -> Iterator[DatasetRef]:
"""Update a collection of `~lsst.daf.butler.DatasetRef` with new
run and dataset IDs.
"""
for ref in refs:
# hack the run to be replaced explicitly
object.__setattr__(ref, "run", run)
new_ref = ref.replace(run=run)
dataset_id_map[ref.id] = new_ref.id
yield new_ref

def _update_intermediate_refs(
refs: Iterable[DatasetRef], run: str, dataset_id_map: Mapping[DatasetId, DatasetId]
) -> Iterator[DatasetRef]:
"""Update intermediate references with new run and IDs. Only the
references that appear in ``dataset_id_map`` are updated, others
are returned unchanged.
"""
for ref in refs:
if dataset_id := dataset_id_map.get(ref.id):
ref = ref.replace(run=run, id=dataset_id)
yield ref

# Loop through all outputs and update their datasets.
# Replace quantum output refs first.
for node in self._connectedQuanta:
for refs in node.quantum.outputs.values():
_update_refs_in_place(refs, run)

for refs in self._initOutputRefs.values():
_update_refs_in_place(refs, run)
quantum = node.quantum
outputs = {
dataset_type: tuple(_update_output_refs(refs, run, dataset_id_map))
for dataset_type, refs in quantum.outputs.items()
}
updated_quantum = Quantum(
taskName=quantum.taskName,
dataId=quantum.dataId,
initInputs=quantum.initInputs,
inputs=quantum.inputs,
outputs=outputs,
datastore_records=quantum.datastore_records,
)
node._replace_quantum(updated_quantum)

_update_refs_in_place(self._globalInitOutputRefs, run)
self._initOutputRefs = {
task_def: list(_update_output_refs(refs, run, dataset_id_map))
for task_def, refs in self._initOutputRefs.items()
}
self._globalInitOutputRefs = list(
_update_output_refs(self._globalInitOutputRefs, run, dataset_id_map)
)

# Update all intermediates from their matching outputs.
for node in self._connectedQuanta:
for refs in node.quantum.inputs.values():
_update_refs_in_place(refs, run)
quantum = node.quantum
inputs = {
dataset_type: tuple(_update_intermediate_refs(refs, run, dataset_id_map))
for dataset_type, refs in quantum.inputs.items()
}
initInputs = list(_update_intermediate_refs(quantum.initInputs.values(), run, dataset_id_map))

updated_quantum = Quantum(
taskName=quantum.taskName,
dataId=quantum.dataId,
initInputs=initInputs,
inputs=inputs,
outputs=quantum.outputs,
datastore_records=quantum.datastore_records,
)
node._replace_quantum(updated_quantum)

for refs in self._initInputRefs.values():
_update_refs_in_place(refs, run)
self._initInputRefs = {
task_def: list(_update_intermediate_refs(refs, run, dataset_id_map))
for task_def, refs in self._initInputRefs.items()
}

if update_graph_id:
self._buildId = BuildId(f"{time.time()}-{os.getpid()}")
Expand Down
27 changes: 27 additions & 0 deletions python/lsst/pipe/base/graph/quantumNode.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,33 @@ def from_simple(
nodeId=simple.nodeId,
)

def _replace_quantum(self, quantum: Quantum) -> None:
    """Swap the Quantum held by this node for a hash-equivalent one.

    Parameters
    ----------
    quantum : `Quantum`
        Quantum instance that takes the place of the current one.

    Raises
    ------
    ValueError
        Raised if the new quantum hashes differently from the quantum
        currently stored in this node.

    Notes
    -----
    This class is immutable and hashable, so a replacement is accepted
    only when its hash equals the hash of the quantum it replaces. This
    method is supposed to be used only by the `QuantumGraph` class as an
    implementation detail, which is why its name starts with an
    underscore.
    """
    current = self.quantum
    if hash(quantum) == hash(current):
        # Assign through ``object.__setattr__`` rather than a plain
        # attribute assignment on ``self``.
        object.__setattr__(self, "quantum", quantum)
        return
    raise ValueError(
        f"Hash of the new quantum {quantum} does not match hash of existing quantum {current}"
    )


_fields_set = {"quantum", "taskLabel", "nodeId"}

Expand Down
37 changes: 37 additions & 0 deletions python/lsst/pipe/base/tests/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,3 +68,40 @@ def check_output_run(graph: QuantumGraph, run: str) -> list[DatasetRef]:
newRefs += [ref for ref in intermediates if ref.run != run]

return newRefs


def get_output_refs(graph: QuantumGraph) -> list[DatasetRef]:
    """Collect every output and intermediate reference found in a graph.

    Parameters
    ----------
    graph : `QuantumGraph`
        Quantum graph.

    Returns
    -------
    refs : `list` [ `~lsst.daf.butler.DatasetRef` ]
        All output/intermediate dataset references. Each distinct output
        is listed once, followed by one extra entry for every time such a
        reference is also found among quantum inputs or task init-inputs,
        so intermediates appear more than once.
    """
    # First pass: gather the distinct references produced by the graph
    # (quantum outputs, per-task init-outputs and global init-outputs).
    produced: set[DatasetRef] = set()
    for quantum_node in graph:
        for out_refs in quantum_node.quantum.outputs.values():
            produced.update(out_refs)
    for task_def in graph.iterTaskGraph():
        if task_init_outputs := graph.initOutputRefs(task_def):
            produced.update(task_init_outputs)
    produced.update(graph.globalInitOutputRefs())

    collected = list(produced)

    # Second pass: append inputs that match a produced reference; these are
    # the intermediates, and they are deliberately not de-duplicated.
    for quantum_node in graph:
        for in_refs in quantum_node.quantum.inputs.values():
            collected.extend(ref for ref in in_refs if ref in produced)
    for task_def in graph.iterTaskGraph():
        if task_init_inputs := graph.initInputRefs(task_def):
            collected.extend(ref for ref in task_init_inputs if ref in produced)

    return collected
8 changes: 7 additions & 1 deletion tests/test_quantumGraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
TaskDef,
)
from lsst.pipe.base.graph.quantumNode import BuildId, QuantumNode
from lsst.pipe.base.tests.util import check_output_run
from lsst.pipe.base.tests.util import check_output_run, get_output_refs
from lsst.utils.introspection import get_full_type_name

METADATA = {"a": [1, 2, 3]}
Expand Down Expand Up @@ -561,11 +561,17 @@ def testDimensionUniverseInSave(self) -> None:
def testUpdateRun(self) -> None:
"""Test for QuantumGraph.updateRun method."""
self.assertEqual(check_output_run(self.qGraph, self.output_run), [])
output_refs = get_output_refs(self.qGraph)
self.assertGreater(len(output_refs), 0)
graph_id = self.qGraph.graphID

self.qGraph.updateRun("updated-run")
self.assertEqual(check_output_run(self.qGraph, "updated-run"), [])
self.assertEqual(self.qGraph.graphID, graph_id)
output_refs2 = get_output_refs(self.qGraph)
self.assertEqual(len(output_refs2), len(output_refs))
# All output dataset IDs must be updated.
self.assertTrue(set(ref.id for ref in output_refs).isdisjoint(set(ref.id for ref in output_refs2)))

# Also update metadata.
self.qGraph.updateRun("updated-run2", metadata_key="ouput_run")
Expand Down

0 comments on commit 870df52

Please sign in to comment.