DM-35082 Save DimensionUniverse in QuantumGraph #259

Merged
merged 7 commits into from Jul 14, 2022
Changes from 4 commits
2 changes: 2 additions & 0 deletions doc/changes/DM-35082.feature.rst
@@ -0,0 +1,2 @@
QuantumGraph now saves the DimensionUniverse it was created with when it is persisted. This removes the need
to explicitly pass the DimensionUniverse when loading a saved graph.
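
For instance, a saved graph can now be reloaded without a universe in hand. A minimal sketch using the `loadUri` signature from this PR (file name hypothetical):

```python
from lsst.daf.butler import DimensionUniverse
from lsst.pipe.base.graph import QuantumGraph

# The universe argument is now optional: when omitted, the universe
# persisted in the graph file's header is reconstructed and used.
qgraph = QuantumGraph.loadUri("pipeline.qgraph")

# A universe may still be supplied; it is then checked for
# compatibility against the universe saved with the graph.
qgraph = QuantumGraph.loadUri("pipeline.qgraph", universe=DimensionUniverse())
```
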
30 changes: 18 additions & 12 deletions python/lsst/pipe/base/config.py
@@ -78,7 +78,7 @@ def _validateValue(self, value: Any) -> None:

def __set__(
self,
instance: pexConfig.Field,
instance: pexConfig.Config,
value: Any,
at: Optional[StackFrame] = None,
label: str = "assignment",
@@ -107,7 +107,11 @@ class PipelineTaskConfigMeta(pexConfig.ConfigMeta):
"""

def __new__(
cls: Type[_S], name: str, bases: Tuple[PipelineTaskConfig, ...], dct: Dict[str, Any], **kwargs: Any
cls: Type[_S],
name: str,
bases: Tuple[type[PipelineTaskConfig], ...],
dct: Dict[str, Any],
**kwargs: Any,
) -> _S:
if name != "PipelineTaskConfig":
# Verify that a connection class was specified and the argument is
@@ -125,10 +129,10 @@ def __new__

# Create all the fields that will be used in the newly created sub
# config (under the attribute name "connections")
configConnectionsNamespace = {}
configConnectionsNamespace: dict[str, pexConfig.Field] = {}
for fieldName, obj in connectionsClass.allConnections.items():
configConnectionsNamespace[fieldName] = pexConfig.Field(
dtype=str, doc=f"name for connection {fieldName}", default=obj.name
configConnectionsNamespace[fieldName] = pexConfig.Field[str](
doc=f"name for connection {fieldName}", default=obj.name
)
# If there are default templates also add them as fields to
# configure the template values
@@ -179,14 +183,17 @@ class to allow configuration of the connections class. This dynamically
`~lsst.pex.config.ConfigField` with the attribute name `connections`.
"""

saveMetadata = pexConfig.Field(
dtype=bool,
connections: pexConfig.ConfigField
"""Field which refers to a dynamically added configuration class which is
based on a PipelineTaskConnections class.
"""

saveMetadata = pexConfig.Field[bool](
default=True,
optional=False,
doc="Flag to enable/disable metadata saving for a task, enabled by default.",
)
saveLogOutput = pexConfig.Field(
dtype=bool,
saveLogOutput = pexConfig.Field[bool](
default=True,
optional=False,
doc="Flag to enable/disable saving of log output for a task, enabled by default.",
@@ -208,10 +215,9 @@ class ResourceConfig(pexConfig.Config):
estimates.
"""

minMemoryMB = pexConfig.Field(
dtype=int,
minMemoryMB = pexConfig.Field[int](
default=None,
optional=True,
doc="Minimal memory needed by task, can be None if estimate is unknown.",
)
minNumCores = pexConfig.Field(dtype=int, default=1, doc="Minimal number of cores needed by task.")
minNumCores = pexConfig.Field[int](default=1, doc="Minimal number of cores needed by task.")
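
These changes also move from the `dtype=` keyword to the newer subscripted `pexConfig.Field[...]` spelling. A minimal sketch of the two equivalent forms, assuming a version of `lsst.pex.config` that supports the generic syntax (`DemoConfig` is hypothetical):

```python
import lsst.pex.config as pexConfig

class DemoConfig(pexConfig.Config):
    # Older spelling: the field's Python type is passed via dtype.
    minNumCoresOld = pexConfig.Field(dtype=int, default=1, doc="Minimal number of cores.")
    # Newer spelling: the type is a subscript, so static type checkers
    # can infer the attribute type without a dtype keyword.
    minNumCoresNew = pexConfig.Field[int](default=1, doc="Minimal number of cores.")
```
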
10 changes: 8 additions & 2 deletions python/lsst/pipe/base/graph/_loadHelpers.py
@@ -204,7 +204,7 @@ def dumpHeader(cls, uriObject: Union[ResourcePath, BinaryIO], minimumVersion: in

def load(
self,
universe: DimensionUniverse,
universe: Optional[DimensionUniverse] = None,
nodes: Optional[Iterable[Union[UUID, str]]] = None,
graphID: Optional[str] = None,
) -> QuantumGraph:
@@ -216,9 +216,12 @@

Parameters
----------
universe: `~lsst.daf.butler.DimensionUniverse`
universe: `~lsst.daf.butler.DimensionUniverse` or None
DimensionUniverse instance, not used by the method itself but
needed to ensure that registry data structures are initialized.
If `None`, the universe saved with the graph is used; if one is
supplied, it is validated for compatibility with the universe
saved in the graph.
nodes : `Iterable` of `UUID` or `str`; or `None`
The nodes to load from the graph, loads all if value is None
(the default)
@@ -238,6 +241,9 @@
Raised if one or more of the nodes requested is not in the
`QuantumGraph` or if graphID parameter does not match the graph
being loaded.
RuntimeError
Raised if the supplied DimensionUniverse is not compatible with
the DimensionUniverse saved in the graph.
"""
# verify this is the expected graph
if graphID is not None and self.headerInfo._buildId != graphID:
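
As the docstring above notes, `nodes` can restrict what gets deserialized. A short usage sketch (path and node ID hypothetical):

```python
import uuid

from lsst.pipe.base.graph import QuantumGraph

# Node IDs would normally come from a previous inspection of the graph.
wanted = [uuid.UUID("00000000-0000-0000-0000-000000000000")]

# Only the requested quanta are loaded; a ValueError is raised if a
# requested node is missing or if graphID does not match.
qg = QuantumGraph.loadUri("pipeline.qgraph", nodes=wanted)
```
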
48 changes: 41 additions & 7 deletions python/lsst/pipe/base/graph/_versionDeserializers.py
@@ -34,10 +34,16 @@
from typing import TYPE_CHECKING, Callable, ClassVar, DefaultDict, Dict, Optional, Set, Tuple, Type

import networkx as nx
from lsst.daf.butler import DimensionRecord, DimensionUniverse, Quantum, SerializedDimensionRecord
from lsst.pex.config import Config
from lsst.daf.butler import (
DimensionConfig,
DimensionRecord,
DimensionUniverse,
Quantum,
SerializedDimensionRecord,
)
from lsst.utils import doImportType

from ..config import PipelineTaskConfig
from ..pipeline import TaskDef
from ..pipelineTask import PipelineTask
from ._implDetails import DatasetTypeName, _DatasetTracker
@@ -112,7 +118,10 @@ def readHeaderInfo(self, rawHeader: bytes) -> SimpleNamespace:
raise NotImplementedError("Base class does not implement this method")

def constructGraph(
self, nodes: set[uuid.UUID], _readBytes: Callable[[int, int], bytes], universe: DimensionUniverse
self,
nodes: set[uuid.UUID],
_readBytes: Callable[[int, int], bytes],
universe: Optional[DimensionUniverse] = None,
) -> QuantumGraph:
"""Constructs a graph from the deserialized information.

@@ -192,7 +201,10 @@ def unpackHeader(self, rawHeader: bytes) -> Optional[str]:
return None

def constructGraph(
self, nodes: set[uuid.UUID], _readBytes: Callable[[int, int], bytes], universe: DimensionUniverse
self,
nodes: set[uuid.UUID],
_readBytes: Callable[[int, int], bytes],
universe: Optional[DimensionUniverse] = None,
) -> QuantumGraph:
# need to import here to avoid cyclic imports
from . import QuantumGraph
@@ -322,7 +334,10 @@ def unpackHeader(self, rawHeader: bytes) -> Optional[str]:
return lzma.decompress(rawHeader).decode()

def constructGraph(
self, nodes: set[uuid.UUID], _readBytes: Callable[[int, int], bytes], universe: DimensionUniverse
self,
nodes: set[uuid.UUID],
_readBytes: Callable[[int, int], bytes],
universe: Optional[DimensionUniverse] = None,
) -> QuantumGraph:
# need to import here to avoid cyclic imports
from . import QuantumGraph
@@ -474,14 +489,25 @@ def readHeaderInfo(self, rawHeader: bytes) -> SimpleNamespace:
infoMappings.dimensionRecords = {}
for k, v in infoMap["DimensionRecords"].items():
infoMappings.dimensionRecords[int(k)] = SerializedDimensionRecord(**v)
# It is important that this is a .get call, so that graphs saved
# by versions that did not record a universe can still be loaded
# without changing the save format
if (universeConfig := infoMap.get("universe")) is not None:
universe = DimensionUniverse(config=DimensionConfig(universeConfig))
else:
universe = DimensionUniverse()
Member:
I worry a bit that we should be using a universe that was passed into the loader if such a universe exists, although I imagine this code path is never going to trigger because historically people haven't been persisting their graph so that they can load it again 6 months later.

infoMappings.universe = universe
self.infoMappings = infoMappings
return infoMappings

def unpackHeader(self, rawHeader: bytes) -> Optional[str]:
return lzma.decompress(rawHeader).decode()

def constructGraph(
self, nodes: set[uuid.UUID], _readBytes: Callable[[int, int], bytes], universe: DimensionUniverse
self,
nodes: set[uuid.UUID],
_readBytes: Callable[[int, int], bytes],
universe: Optional[DimensionUniverse] = None,
) -> QuantumGraph:
# need to import here to avoid cyclic imports
from . import QuantumGraph
@@ -493,6 +519,14 @@ def constructGraph(
taskToQuantumNode: DefaultDict[TaskDef, Set[QuantumNode]] = defaultdict(set)
recontitutedDimensions: Dict[int, Tuple[str, DimensionRecord]] = {}

if universe is not None:
if not universe.checkCompatibility(self.infoMappings.universe):
raise RuntimeError(
"The saved dimension universe is not compatible with the supplied universe"
)
else:
universe = self.infoMappings.universe

for node in nodes:
start, stop = self.infoMappings.map[node]["bytes"]
start, stop = start + self.headerSize, stop + self.headerSize
@@ -517,7 +551,7 @@ def constructGraph(
# bytes are compressed, so decompress them
taskDefDump = json.loads(lzma.decompress(_readBytes(start, stop)))
taskClass: Type[PipelineTask] = doImportType(taskDefDump["taskName"])
config: Config = taskClass.ConfigClass()
config: PipelineTaskConfig = taskClass.ConfigClass()
config.loadFromStream(taskDefDump["config"])
# Rebuild TaskDef
recreatedTaskDef = TaskDef(
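
The `.get`-based lookup above is what keeps older files loadable: graphs saved before this change simply have no "universe" key in their header. A standalone sketch of that fallback, with the header dict contents assumed (function name hypothetical; the surrounding I/O is elided):

```python
from lsst.daf.butler import DimensionConfig, DimensionUniverse

def universeFromHeader(infoMap: dict) -> DimensionUniverse:
    """Reconstruct the saved universe, falling back to the default
    universe for graph files that predate this change."""
    if (universeConfig := infoMap.get("universe")) is not None:
        return DimensionUniverse(config=DimensionConfig(universeConfig))
    return DimensionUniverse()
```
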
28 changes: 23 additions & 5 deletions python/lsst/pipe/base/graph/graph.py
@@ -125,8 +125,9 @@ def __init__(
quanta: Mapping[TaskDef, Set[Quantum]],
metadata: Optional[Mapping[str, Any]] = None,
pruneRefs: Optional[Iterable[DatasetRef]] = None,
universe: Optional[DimensionUniverse] = None,
):
self._buildGraphs(quanta, metadata=metadata, pruneRefs=pruneRefs)
self._buildGraphs(quanta, metadata=metadata, pruneRefs=pruneRefs, universe=universe)

def _buildGraphs(
self,
@@ -136,10 +137,14 @@ def _buildGraphs(
_buildId: Optional[BuildId] = None,
metadata: Optional[Mapping[str, Any]] = None,
pruneRefs: Optional[Iterable[DatasetRef]] = None,
universe: Optional[DimensionUniverse] = None,
) -> None:
"""Builds the graph that is used to store the relation between tasks,
and the graph that holds the relations between quanta
"""
if universe is None:
universe = DimensionUniverse()
self._universe = universe
self._metadata = metadata
self._buildId = _buildId if _buildId is not None else BuildId(f"{time.time()}-{os.getpid()}")
# Data structures used to identify relations between components;
@@ -740,7 +745,7 @@ def metadata(self) -> Optional[MappingProxyType[str, Any]]:
def loadUri(
cls,
uri: ResourcePathExpression,
universe: DimensionUniverse,
universe: Optional[DimensionUniverse] = None,
nodes: Optional[Iterable[uuid.UUID]] = None,
graphID: Optional[BuildId] = None,
minimumVersion: int = 3,
@@ -751,9 +756,12 @@
----------
uri : convertible to `ResourcePath`
URI from where to load the graph.
universe: `~lsst.daf.butler.DimensionUniverse`
universe: `~lsst.daf.butler.DimensionUniverse`, optional
DimensionUniverse instance, not used by the method itself but
needed to ensure that registry data structures are initialized.
If `None`, it is loaded from the saved QuantumGraph structure. If
supplied, it is validated for compatibility against the
DimensionUniverse saved with the loaded `QuantumGraph`.
nodes: iterable of `int` or None
Numbers that correspond to nodes in the graph. If specified, only
these nodes will be loaded. Defaults to None, in which case all
@@ -783,6 +791,9 @@
`QuantumGraph` or if graphID parameter does not match the graph
being loaded or if the supplied uri does not point at a valid
`QuantumGraph` save file.
RuntimeError
Raised if the supplied DimensionUniverse is not compatible with
the DimensionUniverse saved in the graph.


Notes
@@ -883,6 +894,10 @@ def _buildSaveObject(self, returnHeader: bool = False) -> Union[bytearray, Tuple
headerData["GraphBuildID"] = self.graphID
headerData["Metadata"] = self._metadata

# Store the universe this graph was created with
universeConfig = self._universe.dimensionConfig
headerData["universe"] = universeConfig.toDict()

# counter for the number of bytes processed thus far
count = 0
# serialize out the task Defs recording the start and end bytes of each
@@ -1011,7 +1026,7 @@ def _buildSaveObject(self, returnHeader: bool = False) -> Union[bytearray, Tuple
def load(
cls,
file: BinaryIO,
universe: DimensionUniverse,
universe: Optional[DimensionUniverse] = None,
nodes: Optional[Iterable[uuid.UUID]] = None,
graphID: Optional[BuildId] = None,
minimumVersion: int = 3,
@@ -1022,9 +1037,12 @@
----------
file : `io.IO` of bytes
File with pickle data open in binary mode.
universe: `~lsst.daf.butler.DimensionUniverse`
universe: `~lsst.daf.butler.DimensionUniverse`, optional
DimensionUniverse instance, not used by the method itself but
needed to ensure that registry data structures are initialized.
If `None`, it is loaded from the saved QuantumGraph structure. If
supplied, it is validated for compatibility against the
DimensionUniverse saved with the loaded `QuantumGraph`.
nodes: iterable of `int` or None
Numbers that correspond to nodes in the graph. If specified, only
these nodes will be loaded. Defaults to None, in which case all
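
Putting the save and load sides together, the round trip now looks roughly like the sketch below. It assumes an empty quanta mapping is an acceptable constructor argument and uses `saveUri`, the save-side counterpart of `loadUri`; real graphs are normally built by `pipetask qgraph`.

```python
from lsst.daf.butler import DimensionUniverse
from lsst.pipe.base.graph import QuantumGraph

# An empty graph is enough to demonstrate the round trip (assumption:
# QuantumGraph accepts an empty quanta mapping).
qg = QuantumGraph({}, universe=DimensionUniverse())
qg.saveUri("pipeline.qgraph")  # the header now records qg's universe

# Later, possibly in a fresh process, no universe needs to be passed:
restored = QuantumGraph.loadUri("pipeline.qgraph")
assert restored.graphID == qg.graphID
```
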
12 changes: 7 additions & 5 deletions python/lsst/pipe/base/pipeline.py
@@ -52,6 +52,7 @@
Tuple,
Type,
Union,
cast,
)

# -----------------------------
@@ -63,6 +64,7 @@

from . import pipelineIR, pipeTools
from ._task_metadata import TaskMetadata
from .config import PipelineTaskConfig
from .configOverrides import ConfigOverrides
from .connections import iterConnections
from .pipelineTask import PipelineTask
@@ -119,7 +121,7 @@ class TaskDef:
taskName : `str`, optional
The fully-qualified `PipelineTask` class name. If not provided,
``taskClass`` must be.
config : `lsst.pex.config.Config`, optional
config : `lsst.pipe.base.config.PipelineTaskConfig`, optional
Instance of the configuration class corresponding to this task class,
usually with all overrides applied. This config will be frozen. If
not provided, ``taskClass`` must be provided and
@@ -137,7 +139,7 @@
def __init__(
self,
taskName: Optional[str] = None,
config: Optional[Config] = None,
config: Optional[PipelineTaskConfig] = None,
taskClass: Optional[Type[PipelineTask]] = None,
label: Optional[str] = None,
):
@@ -203,7 +205,7 @@ def logOutputDatasetName(self) -> Optional[str]:
"""Name of a dataset type for log output from this task, `None` if
logs are not to be saved (`str`)
"""
if self.config.saveLogOutput:
if cast(PipelineTaskConfig, self.config).saveLogOutput:
return self.label + "_log"
else:
return None
@@ -227,7 +229,7 @@ def __hash__(self) -> int:
return hash((self.taskClass, self.label))

@classmethod
def _unreduce(cls, taskName: str, config: Config, label: str) -> TaskDef:
def _unreduce(cls, taskName: str, config: PipelineTaskConfig, label: str) -> TaskDef:
"""Custom callable for unpickling.

All arguments are forwarded directly to the constructor; this
@@ -236,7 +238,7 @@ def _unreduce(cls, taskName: str, config: Config, label: str) -> TaskDef:
"""
return cls(taskName=taskName, config=config, label=label)

def __reduce__(self) -> Tuple[Callable[[str, Config, str], TaskDef], Tuple[str, Config, str]]:
def __reduce__(self) -> Tuple[Callable[[str, PipelineTaskConfig, str], TaskDef], Tuple[str, PipelineTaskConfig, str]]:
return (self._unreduce, (self.taskName, self.config, self.label))


Expand Down