DM-39582: Decrease memory usage and load times when reading graphs #348

Merged · 9 commits · Jul 4, 2023
3 changes: 3 additions & 0 deletions doc/changes/DM-39582.feature.md
@@ -0,0 +1,3 @@
The back-end for quantum graph loading has been optimized so that duplicate objects are no longer created in
memory; instead, shared references are used. This results in a large decrease in memory usage and a decrease
in load times.
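
A rough sketch of the idea behind this change (illustrative names only, not the actual loader code): each serialized payload is materialized once, and every later occurrence receives a reference to the same object instead of a fresh copy.

```python
# Illustrative only: the real loader relies on
# lsst.daf.butler.PersistenceContextVars (see _loadHelpers.py below).
_instances: dict[str, tuple] = {}

def deserialize_record(key: str, payload: tuple) -> tuple:
    """Return a shared instance for payloads that were already seen."""
    if key not in _instances:
        _instances[key] = payload  # built once
    return _instances[key]

a = deserialize_record("detector=42", ("detector", 42))
b = deserialize_record("detector=42", ("detector", 42))
assert a is b  # one object in memory, shared by reference
```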
1 change: 1 addition & 0 deletions doc/changes/DM-39582.removal.md
@@ -0,0 +1 @@
Deprecated the `reconstituteDimensions` argument from `QuantumNode.from_simple`.
2 changes: 1 addition & 1 deletion python/lsst/pipe/base/_instrument.py
@@ -668,7 +668,7 @@ class _DummyConfig(Config):

config = _DummyConfig()

return config.packer.apply(data_id, is_exposure=is_exposure)
return config.packer.apply(data_id, is_exposure=is_exposure) # type: ignore

@staticmethod
@final
10 changes: 5 additions & 5 deletions python/lsst/pipe/base/_quantumContext.py
@@ -270,7 +270,7 @@ def get(
n_connections = len(dataset)
n_retrieved = 0
for i, (name, ref) in enumerate(dataset):
if isinstance(ref, list):
if isinstance(ref, (list, tuple)):
val = []
n_refs = len(ref)
for j, r in enumerate(ref):
@@ -301,7 +301,7 @@ def get(
"Completed retrieval of %d datasets from %d connections", n_retrieved, n_connections
)
return retVal
elif isinstance(dataset, list):
elif isinstance(dataset, (list, tuple)):
n_datasets = len(dataset)
retrieved = []
for i, x in enumerate(dataset):
@@ -363,14 +363,14 @@ def put(
)
for name, refs in dataset:
valuesAttribute = getattr(values, name)
if isinstance(refs, list):
if isinstance(refs, (list, tuple)):
if len(refs) != len(valuesAttribute):
raise ValueError(f"There must be a object to put for every Dataset ref in {name}")
for i, ref in enumerate(refs):
self._put(valuesAttribute[i], ref)
else:
self._put(valuesAttribute, refs)
elif isinstance(dataset, list):
elif isinstance(dataset, (list, tuple)):
if not isinstance(values, Sequence):
raise ValueError("Values to put must be a sequence")
if len(dataset) != len(values):
@@ -401,7 +401,7 @@ def _checkMembership(self, ref: list[DatasetRef] | DatasetRef, inout: set) -> No
which may be important for Quanta with lots of
`~lsst.daf.butler.DatasetRef`.
"""
if not isinstance(ref, list):
if not isinstance(ref, (list, tuple)):
ref = [ref]
for r in ref:
if (r.datasetType, r.dataId) not in inout:
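
The `isinstance` checks above were broadened from `list` to `(list, tuple)` because connection values may now arrive as shared tuples of refs rather than lists. A minimal sketch of the accepted shapes (hypothetical helper; the real methods inline the same test):

```python
def normalize_refs(refs):
    """Accept a single ref, a list, or a tuple and return a list."""
    if isinstance(refs, (list, tuple)):
        return list(refs)
    return [refs]

assert normalize_refs(("a", "b")) == ["a", "b"]
assert normalize_refs(["a"]) == ["a"]
assert normalize_refs("a") == ["a"]
```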
15 changes: 8 additions & 7 deletions python/lsst/pipe/base/connections.py
@@ -39,7 +39,7 @@
import itertools
import string
from collections import UserDict
from collections.abc import Collection, Generator, Iterable, Mapping, Set
from collections.abc import Collection, Generator, Iterable, Mapping, Sequence, Set
from dataclasses import dataclass
from types import MappingProxyType, SimpleNamespace
from typing import TYPE_CHECKING, Any
@@ -934,12 +934,12 @@
connection-oriented mappings used inside `PipelineTaskConnections`.
"""

inputs: NamedKeyMapping[DatasetType, list[DatasetRef]]
inputs: NamedKeyMapping[DatasetType, Sequence[DatasetRef]]
"""Mapping of regular input and prerequisite input datasets, grouped by
`~lsst.daf.butler.DatasetType`.
"""

outputs: NamedKeyMapping[DatasetType, list[DatasetRef]]
outputs: NamedKeyMapping[DatasetType, Sequence[DatasetRef]]
"""Mapping of output datasets, grouped by `~lsst.daf.butler.DatasetType`.
"""

@@ -997,7 +997,7 @@
# Translate adjustments to DatasetType-keyed, Quantum-oriented form,
# installing new mappings in self if necessary.
if adjusted_inputs_by_connection:
adjusted_inputs = NamedKeyDict[DatasetType, list[DatasetRef]](self.inputs)
adjusted_inputs = NamedKeyDict[DatasetType, tuple[DatasetRef, ...]](self.inputs)

Codecov warning: added line connections.py#L1000 was not covered by tests.
for name, (connection, updated_refs) in adjusted_inputs_by_connection.items():
dataset_type_name = connection.name
if not set(updated_refs).issubset(self.inputs[dataset_type_name]):
@@ -1006,21 +1006,22 @@
f"({dataset_type_name}) input datasets that are not a subset of those "
f"it was given for data ID {data_id}."
)
adjusted_inputs[dataset_type_name] = list(updated_refs)
adjusted_inputs[dataset_type_name] = tuple(updated_refs)

Codecov warning: added line connections.py#L1009 was not covered by tests.
self.inputs = adjusted_inputs.freeze()
self.inputs_adjusted = True
else:
self.inputs_adjusted = False
if adjusted_outputs_by_connection:
adjusted_outputs = NamedKeyDict[DatasetType, list[DatasetRef]](self.outputs)
adjusted_outputs = NamedKeyDict[DatasetType, tuple[DatasetRef, ...]](self.outputs)

Codecov warning: added line connections.py#L1015 was not covered by tests.
for name, (connection, updated_refs) in adjusted_outputs_by_connection.items():
dataset_type_name = connection.name

Codecov warning: added line connections.py#L1017 was not covered by tests.
if not set(updated_refs).issubset(self.outputs[dataset_type_name]):
raise RuntimeError(
f"adjustQuantum implementation for task with label {label} returned {name} "
f"({dataset_type_name}) output datasets that are not a subset of those "
f"it was given for data ID {data_id}."
)
adjusted_outputs[dataset_type_name] = list(updated_refs)
adjusted_outputs[dataset_type_name] = tuple(updated_refs)

Codecov warning: added line connections.py#L1024 was not covered by tests.
self.outputs = adjusted_outputs.freeze()
self.outputs_adjusted = True
else:
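
Storing the adjusted refs as tuples rather than lists means a single instance can be shared safely: a tuple cannot be mutated in place, so no defensive copies are needed. A small sketch of that property (stand-in values, not real `DatasetRef` objects):

```python
refs = ("ref-1", "ref-2")            # stand-ins for DatasetRef instances
adjusted_inputs = {"calexp": refs}   # stand-in for the NamedKeyDict
assert adjusted_inputs["calexp"] is refs   # shared, not copied
# refs.append("ref-3") would raise AttributeError; a list would allow it.
```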
4 changes: 2 additions & 2 deletions python/lsst/pipe/base/executionButlerBuilder.py
@@ -162,7 +162,7 @@ def _accumulate(
for type, refs in attr.items():
# This if block is because init inputs has a different
# signature for its items
if not isinstance(refs, list):
if not isinstance(refs, (list, tuple)):
refs = [refs]
for ref in refs:
if ref.isComponent():
@@ -177,7 +177,7 @@
attr = getattr(quantum, attrName)

for type, refs in attr.items():
if not isinstance(refs, list):
if not isinstance(refs, (list, tuple)):
refs = [refs]
if type.component() is not None:
type = type.makeCompositeDatasetType()
6 changes: 4 additions & 2 deletions python/lsst/pipe/base/graph/_implDetails.py
@@ -313,13 +313,15 @@ def _pruner(
# from the graph.
try:
helper.adjust_in_place(node.taskDef.connections, node.taskDef.label, node.quantum.dataId)
# ignore the types because quantum really can take a sequence
# of inputs
newQuantum = Quantum(
taskName=node.quantum.taskName,
taskClass=node.quantum.taskClass,
dataId=node.quantum.dataId,
initInputs=node.quantum.initInputs,
inputs=helper.inputs,
outputs=helper.outputs,
inputs=helper.inputs, # type: ignore
outputs=helper.outputs, # type: ignore
)
# If the inputs or outputs were adjusted to something different
than what was supplied by the graph builder, disassociate
8 changes: 6 additions & 2 deletions python/lsst/pipe/base/graph/_loadHelpers.py
@@ -31,7 +31,7 @@
from typing import TYPE_CHECKING, BinaryIO
from uuid import UUID

from lsst.daf.butler import DimensionUniverse
from lsst.daf.butler import DimensionUniverse, PersistenceContextVars
from lsst.resources import ResourceHandleProtocol, ResourcePath

if TYPE_CHECKING:
@@ -219,7 +219,7 @@ def load(
_readBytes = self._readBytes
if universe is None:
universe = headerInfo.universe
return self.deserializer.constructGraph(nodeSet, _readBytes, universe)
# use the daf butler context vars to aid in ensuring deduplication in
# object instantiation.
runner = PersistenceContextVars()
graph = runner.run(self.deserializer.constructGraph, nodeSet, _readBytes, universe)
return graph

def _readBytes(self, start: int, stop: int) -> bytes:
"""Load the specified byte range from the ResourcePath object
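
Graph construction is now wrapped in `PersistenceContextVars.run`, the `lsst.daf.butler` mechanism that caches objects built during deserialization so that repeated data IDs and refs resolve to a single instance. A hedged usage sketch following the call pattern above (the callable and its arguments are hypothetical):

```python
from lsst.daf.butler import PersistenceContextVars

def construct_many(payloads):
    """Stand-in for constructGraph: builds objects that may repeat."""
    return [tuple(p) for p in payloads]

runner = PersistenceContextVars()
# Positional arguments after the callable are forwarded to it, mirroring
# runner.run(self.deserializer.constructGraph, nodeSet, _readBytes, universe).
result = runner.run(construct_many, [("a", 1), ("a", 1)])
```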
2 changes: 2 additions & 0 deletions python/lsst/pipe/base/graph/_versionDeserializers.py
@@ -557,6 +557,8 @@ def constructGraph(

# Turn the JSON back into the pydantic model
nodeDeserialized = SerializedQuantumNode.direct(**dump)
del dump

# attach the dictionary of dimension records to the pydantic model
# these are stored separately because they are stored over and over
# and this saves a lot of space and time.
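
Deleting `dump` as soon as the pydantic model exists keeps the raw dictionary and the deserialized node from being alive at the same time, which trims peak memory when loading large graphs. The pattern in miniature (hypothetical stand-ins):

```python
import json

raw = json.loads('{"taskLabel": "isr", "nodeId": "0"}')
node = dict(raw)   # stand-in for SerializedQuantumNode.direct(**raw)
del raw            # release the intermediate before building the next node
```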
33 changes: 8 additions & 25 deletions python/lsst/pipe/base/graph/graph.py
@@ -1276,49 +1276,32 @@ def updateRun(self, run: str, *, metadata_key: str | None = None, update_graph_i
update_graph_id : `bool`, optional
If `True` then also update graph ID with a new unique value.
"""
dataset_id_map = {}

def _update_output_refs_in_place(refs: list[DatasetRef], run: str) -> None:
def _update_refs_in_place(refs: list[DatasetRef], run: str) -> None:
"""Update list of `~lsst.daf.butler.DatasetRef` with new run and
dataset IDs.
"""
new_refs = []
for ref in refs:
new_ref = DatasetRef(ref.datasetType, ref.dataId, run=run, conform=False)
dataset_id_map[ref.id] = new_ref.id
new_refs.append(new_ref)
refs[:] = new_refs

def _update_input_refs_in_place(refs: list[DatasetRef], run: str) -> None:
"""Update list of `~lsst.daf.butler.DatasetRef` with IDs from
dataset_id_map.
"""
new_refs = []
for ref in refs:
if (new_id := dataset_id_map.get(ref.id)) is not None:
new_ref = DatasetRef(ref.datasetType, ref.dataId, id=new_id, run=run, conform=False)
new_refs.append(new_ref)
else:
new_refs.append(ref)
refs[:] = new_refs
# hack the run to be replaced explicitly
object.__setattr__(ref, "run", run)

# Loop through all outputs and update their datasets.
for node in self._connectedQuanta:
for refs in node.quantum.outputs.values():
_update_output_refs_in_place(refs, run)
_update_refs_in_place(refs, run)

for refs in self._initOutputRefs.values():
_update_output_refs_in_place(refs, run)
_update_refs_in_place(refs, run)

_update_output_refs_in_place(self._globalInitOutputRefs, run)
_update_refs_in_place(self._globalInitOutputRefs, run)

# Update all intermediates from their matching outputs.
for node in self._connectedQuanta:
for refs in node.quantum.inputs.values():
_update_input_refs_in_place(refs, run)
_update_refs_in_place(refs, run)

for refs in self._initInputRefs.values():
_update_input_refs_in_place(refs, run)
_update_refs_in_place(refs, run)

if update_graph_id:
self._buildId = BuildId(f"{time.time()}-{os.getpid()}")
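
Because refs are now shared between inputs and outputs, `updateRun` mutates each ref in place with `object.__setattr__` instead of building replacement `DatasetRef` objects; every container holding that ref then sees the new run automatically. A minimal sketch of the technique (a stand-in class, not the real `DatasetRef`):

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class Ref:
    """Stand-in for an immutable DatasetRef."""
    name: str
    run: str

ref = Ref("calexp", "old_run")
outputs = [ref]
inputs = [ref]                              # the same object, shared
object.__setattr__(ref, "run", "new_run")   # bypass the frozen guard
assert outputs[0].run == inputs[0].run == "new_run"
```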
19 changes: 15 additions & 4 deletions python/lsst/pipe/base/graph/quantumNode.py
@@ -23,6 +23,7 @@
__all__ = ("QuantumNode", "NodeId", "BuildId")

import uuid
import warnings
from dataclasses import dataclass
from typing import Any, NewType

@@ -34,6 +35,7 @@
Quantum,
SerializedQuantum,
)
from lsst.utils.introspection import find_outside_stacklevel
from pydantic import BaseModel

from ..pipeline import TaskDef
@@ -96,6 +98,8 @@ class QuantumNode:
creation.
"""

__slots__ = ("quantum", "taskDef", "nodeId", "_precomputedHash")

def __post_init__(self) -> None:
# use setattr here to preserve the frozenness of the QuantumNode
self._precomputedHash: int
@@ -135,15 +139,22 @@ def from_simple(
universe: DimensionUniverse,
recontitutedDimensions: dict[int, tuple[str, DimensionRecord]] | None = None,
) -> QuantumNode:
if recontitutedDimensions is not None:
warnings.warn(
"The recontitutedDimensions argument is now ignored and may be removed after v 27",
category=FutureWarning,
stacklevel=find_outside_stacklevel("lsst.pipe.base"),
)
return QuantumNode(
quantum=Quantum.from_simple(
simple.quantum, universe, reconstitutedDimensions=recontitutedDimensions
),
quantum=Quantum.from_simple(simple.quantum, universe),
taskDef=taskDefMap[simple.taskLabel],
nodeId=simple.nodeId,
)


_fields_set = {"quantum", "taskLabel", "nodeId"}


class SerializedQuantumNode(BaseModel):
quantum: SerializedQuantum
taskLabel: str
Expand All @@ -156,5 +167,5 @@ def direct(cls, *, quantum: dict[str, Any], taskLabel: str, nodeId: str) -> Seri
setter(node, "quantum", SerializedQuantum.direct(**quantum))
setter(node, "taskLabel", taskLabel)
setter(node, "nodeId", uuid.UUID(nodeId))
setter(node, "__fields_set__", {"quantum", "taskLabel", "nodeId"})
setter(node, "__fields_set__", _fields_set)
return node
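
`QuantumNode` now declares `__slots__`, so each instance stores its attributes in fixed slots instead of a per-instance `__dict__`; over the many thousands of nodes in a large graph that is a real memory saving. An illustrative comparison (not the real class):

```python
import sys

class Slotted:
    __slots__ = ("quantum", "taskDef", "nodeId")

    def __init__(self) -> None:
        self.quantum = self.taskDef = self.nodeId = None

class Plain:
    def __init__(self) -> None:
        self.quantum = self.taskDef = self.nodeId = None

assert not hasattr(Slotted(), "__dict__")   # no per-instance dict at all
print(sys.getsizeof(Plain().__dict__))      # the overhead that slots remove
```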
5 changes: 3 additions & 2 deletions python/lsst/pipe/base/graphBuilder.py
@@ -455,13 +455,14 @@ def makeQuantum(self, datastore_records: Mapping[str, DatastoreRecordData] | Non
matching_records = records.subset(input_ids)
if matching_records is not None:
quantum_records[datastore_name] = matching_records
# ignore the types because quantum really can take a sequence of inputs
return Quantum(
taskName=self.task.taskDef.taskName,
taskClass=self.task.taskDef.taskClass,
dataId=self.dataId,
initInputs=initInputs,
inputs=helper.inputs,
outputs=helper.outputs,
inputs=helper.inputs, # type: ignore
outputs=helper.outputs, # type: ignore
datastore_records=quantum_records,
)

8 changes: 4 additions & 4 deletions python/lsst/pipe/base/pipeline.py
@@ -192,7 +192,7 @@
"""Name of a dataset type for log output from this task, `None` if
logs are not to be saved (`str`)
"""
if cast(PipelineTaskConfig, self.config).saveLogOutput:
if self.config.saveLogOutput:
return acc.LOG_OUTPUT_TEMPLATE.format(label=self.label)
else:
return None
@@ -623,7 +623,7 @@
"""
instrument_class_name = self._pipelineIR.instrument
if instrument_class_name is not None:
instrument_class = doImportType(instrument_class_name)
instrument_class = cast(PipeBaseInstrument, doImportType(instrument_class_name))
if instrument_class is not None:
return DataCoordinate.standardize(instrument=instrument_class.getName(), universe=universe)
return DataCoordinate.makeEmpty(universe)
@@ -654,8 +654,8 @@
# be defined without label which is not acceptable, use task
# _DefaultName in that case
if isinstance(task, str):
task_class = doImportType(task)
label = task_class._DefaultName
task_class = cast(PipelineTask, doImportType(task))
label = task_class._DefaultName

Codecov warning: added lines pipeline.py#L657-L658 were not covered by tests.
self._pipelineIR.tasks[label] = pipelineIR.TaskIR(label, taskName)

def removeTask(self, label: str) -> None:
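
The added `cast(...)` calls only narrow types for the benefit of mypy; `typing.cast` returns its argument unchanged at runtime, so these edits do not change behaviour:

```python
from typing import cast

value: object = "lsst.obs.subaru.HyperSuprimeCam"  # hypothetical class path
name = cast(str, value)   # no conversion or validation happens here
assert name is value
```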
4 changes: 2 additions & 2 deletions python/lsst/pipe/base/script/transfer_from_graph.py
@@ -66,8 +66,8 @@
if refs := qgraph.initOutputRefs(task_def):
original_output_refs.update(refs)
for qnode in qgraph:
for refs in qnode.quantum.outputs.values():
original_output_refs.update(refs)
for otherRefs in qnode.quantum.outputs.values():
original_output_refs.update(otherRefs)

Codecov warning: added line transfer_from_graph.py#L70 was not covered by tests.

# Get data repository definitions from the QuantumGraph; these can have
# different storage classes than those in the quanta.
4 changes: 1 addition & 3 deletions python/lsst/pipe/base/tests/mocks/_pipeline_task.py
@@ -283,9 +283,7 @@ def runQuantum(

# store mock outputs
for name, refs in outputRefs:
if not isinstance(refs, list):
refs = [refs]
for ref in refs:
for ref in ensure_iterable(refs):
output = MockDataset(
ref=ref.to_simple(), quantum=mock_dataset_quantum, output_connection_name=name
)
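
The manual `isinstance(refs, list)` wrapping is replaced by `ensure_iterable` (presumably from `lsst.utils.iteration`), which yields a scalar as a single item and iterates lists and tuples as-is:

```python
from lsst.utils.iteration import ensure_iterable

assert list(ensure_iterable(42)) == [42]        # scalar becomes one item
assert list(ensure_iterable((1, 2))) == [1, 2]  # sequences iterate normally
```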
2 changes: 1 addition & 1 deletion python/lsst/pipe/base/tests/simpleQGraph.py
@@ -307,7 +307,7 @@ def populateButler(
instrument = pipeline.getInstrument()
if instrument is not None:
instrument_class = doImportType(instrument)
instrumentName = instrument_class.getName()
instrumentName = cast(Instrument, instrument_class).getName()
instrumentClass = get_full_type_name(instrument_class)
else:
instrumentName = "INSTR"