Skip to content

Commit

Permalink
Use enum to make PipelineGraph load options clearer.
Browse files Browse the repository at this point in the history
  • Loading branch information
TallJimbo committed Aug 3, 2023
1 parent 2a8a0c8 commit 49dc2cb
Show file tree
Hide file tree
Showing 4 changed files with 83 additions and 75 deletions.
102 changes: 41 additions & 61 deletions python/lsst/pipe/base/pipeline_graph/_pipeline_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
from ._mapping_views import DatasetTypeMappingView, TaskMappingView
from ._nodes import NodeKey, NodeType
from ._task_subsets import TaskSubset
from ._tasks import TaskInitNode, TaskNode, _TaskNodeImportedData
from ._tasks import TaskImportMode, TaskInitNode, TaskNode, _TaskNodeImportedData

if TYPE_CHECKING:
from ..config import PipelineTaskConfig
Expand Down Expand Up @@ -1002,11 +1002,7 @@ def make_dataset_type_xgraph(self, init: bool = False) -> networkx.DiGraph:

@classmethod
def _read_stream(
cls,
stream: BinaryIO,
import_and_configure: bool = True,
check_edges_unchanged: bool = False,
assume_edges_unchanged: bool = False,
cls, stream: BinaryIO, import_mode: TaskImportMode = TaskImportMode.REQUIRE_CONSISTENT_EDGES
) -> PipelineGraph:
"""Read a serialized `PipelineGraph` from a file-like object.
Expand All @@ -1015,15 +1011,11 @@ def _read_stream(
stream : `BinaryIO`
File-like object opened for binary reading, containing
gzip-compressed JSON.
import_and_configure : `bool`, optional
If `True`, import and configure all tasks immediately (see the
`import_and_configure` method). If `False`, some `TaskNode` and
`TaskInitNode` attributes will not be available, but reading may be
much faster.
check_edges_unchanged : `bool`, optional
Forwarded to `import_and_configure` after reading.
assume_edges_unchanged : `bool`, optional
Forwarded to `import_and_configure` after reading.
import_mode : `TaskImportMode`, optional
Whether to import tasks, and how to reconcile any differences
between the imported task's connections and those that were
persisted with the graph. Default is to check that they are the
same.
Returns
-------
Expand All @@ -1035,8 +1027,9 @@ def _read_stream(
PipelineGraphReadError
Raised if the serialized `PipelineGraph` is not self-consistent.
EdgesChangedError
Raised if ``check_edges_unchanged=True`` and the edges of a task do
change after import and reconfiguration.
Raised if ``import_mode`` is
`TaskImportMode.REQUIRE_CONSISTENT_EDGES` and the edges of a task
did change after import and reconfiguration.
Notes
-----
Expand All @@ -1049,19 +1042,13 @@ def _read_stream(
with gzip.open(stream, "rb") as uncompressed_stream:
data = json.load(uncompressed_stream)
serialized_graph = SerializedPipelineGraph.parse_obj(data)
return serialized_graph.deserialize(
import_and_configure=import_and_configure,
check_edges_unchanged=check_edges_unchanged,
assume_edges_unchanged=assume_edges_unchanged,
)
return serialized_graph.deserialize(import_mode)

@classmethod
def _read_uri(
cls,
uri: ResourcePathExpression,
import_and_configure: bool = True,
check_edges_unchanged: bool = False,
assume_edges_unchanged: bool = False,
import_mode: TaskImportMode = TaskImportMode.REQUIRE_CONSISTENT_EDGES,
) -> PipelineGraph:
"""Read a serialized `PipelineGraph` from a file at a URI.
Expand All @@ -1070,15 +1057,11 @@ def _read_uri(
uri : convertible to `lsst.resources.ResourcePath`
URI to a gzip-compressed JSON file containing a serialized pipeline
graph.
import_and_configure : `bool`, optional
If `True`, import and configure all tasks immediately (see
the `import_and_configure` method). If `False`, some `TaskNode`
and `TaskInitNode` attributes will not be available, but reading
may be much faster.
check_edges_unchanged : `bool`, optional
Forwarded to `import_and_configure` after reading.
assume_edges_unchanged : `bool`, optional
Forwarded to `import_and_configure` after reading.
import_mode : `TaskImportMode`, optional
Whether to import tasks, and how to reconcile any differences
between the imported task's connections and those that were
persisted with the graph. Default is to check that they are the
same.
Returns
-------
Expand All @@ -1090,8 +1073,9 @@ def _read_uri(
PipelineGraphReadError
Raised if the serialized `PipelineGraph` is not self-consistent.
EdgesChangedError
Raised if ``check_edges_unchanged=True`` and the edges of a task do
change after import and reconfiguration.
Raised if ``import_mode`` is
`TaskImportMode.REQUIRE_CONSISTENT_EDGES` and the edges of a task
did change after import and reconfiguration.
Notes
-----
Expand All @@ -1101,12 +1085,7 @@ def _read_uri(
"""
uri = ResourcePath(uri)
with uri.open("rb") as stream:
return cls._read_stream(
cast(BinaryIO, stream),
import_and_configure=import_and_configure,
check_edges_unchanged=check_edges_unchanged,
assume_edges_unchanged=assume_edges_unchanged,
)
return cls._read_stream(cast(BinaryIO, stream), import_mode=import_mode)

def _write_stream(self, stream: BinaryIO) -> None:
"""Write the pipeline to a file-like object.
Expand Down Expand Up @@ -1164,31 +1143,26 @@ def _write_uri(self, uri: ResourcePathExpression) -> None:
self._write_stream(cast(BinaryIO, stream))

def _import_and_configure(
self, check_edges_unchanged: bool = False, assume_edges_unchanged: bool = False
self, import_mode: TaskImportMode = TaskImportMode.REQUIRE_CONSISTENT_EDGES
) -> None:
"""Import the `PipelineTask` classes referenced by all task nodes and
update those nodes accordingly.
Parameters
----------
check_edges_unchanged : `bool`, optional
If `True`, require the edges (connections) of the modified tasks to
remain unchanged after importing and configuring each task, and
verify that this is the case.
assume_edges_unchanged : `bool`, optional
If `True`, the caller declares that the edges (connections) of the
modified tasks will remain unchanged importing and configuring each
task, and that it is unnecessary to check this.
import_mode : `TaskImportMode`, optional
Whether to import tasks, and how to reconcile any differences
between the imported task's connections and those that were
persisted with the graph. Default is to check that they are the
same. This method does nothing if this is
`TaskImportMode.DO_NOT_IMPORT`.
Raises
------
ValueError
Raised if ``assume_edges_unchanged`` and ``check_edges_unchanged``
are both `True`, or if a full config is provided for a task after
another full config or an override has already been provided.
EdgesChangedError
Raised if ``check_edges_unchanged=True`` and the edges of a task do
change.
Raised if ``import_mode`` is
`TaskImportMode.REQUIRE_CONSISTENT_EDGES` and the edges of a task
did change after import and reconfiguration.
Notes
-----
Expand All @@ -1202,13 +1176,19 @@ def _import_and_configure(
usually because the software used to read a serialized graph is newer
than the software used to write it (e.g. a new config option has been
added, or the task was moved to a new module with a forwarding alias
left behind). These changes are allowed by ``check=True``.
left behind). These changes are allowed by
`TaskImportMode.REQUIRE_CONSISTENT_EDGES`.
If importing and configuring a task causes its edges to change, any
dataset type nodes linked to those edges will be reset to the
unresolved state.
"""
rebuild = check_edges_unchanged or not assume_edges_unchanged
if import_mode is TaskImportMode.DO_NOT_IMPORT:
return
rebuild = (
import_mode is TaskImportMode.REQUIRE_CONSISTENT_EDGES
or import_mode is TaskImportMode.OVERRIDE_EDGES
)
updates: dict[str, TaskNode] = {}
node_key: NodeKey
for node_key, node_state in self._xgraph.nodes.items():
Expand All @@ -1219,8 +1199,8 @@ def _import_and_configure(
updates[task_node.label] = new_task_node
self._replace_task_nodes(
updates,
check_edges_unchanged=check_edges_unchanged,
assume_edges_unchanged=assume_edges_unchanged,
check_edges_unchanged=(import_mode is TaskImportMode.REQUIRE_CONSISTENT_EDGES),
assume_edges_unchanged=(import_mode is TaskImportMode.ASSUME_CONSISTENT_EDGES),
message_header=(
"In task with label {task_label!r}, persisted edges (A)"
"differ from imported and configured edges (B):"
Expand Down
35 changes: 34 additions & 1 deletion python/lsst/pipe/base/pipeline_graph/_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,10 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from __future__ import annotations

__all__ = ("TaskNode", "TaskInitNode")
__all__ = ("TaskNode", "TaskInitNode", "TaskImportMode")

import dataclasses
import enum
from collections.abc import Iterator, Mapping
from typing import TYPE_CHECKING, Any, cast

Expand All @@ -43,6 +44,38 @@
from ..pipelineTask import PipelineTask


class TaskImportMode(enum.Enum):
    """Strategies for importing task classes while a serialized
    `PipelineGraph` is being read.
    """

    DO_NOT_IMPORT = enum.auto()
    """Leave tasks unimported; their config and connection objects are not
    instantiated.
    """

    REQUIRE_CONSISTENT_EDGES = enum.auto()
    """Import each task, instantiate its config and connection objects, and
    verify that those connections still define the same edges.
    """

    ASSUME_CONSISTENT_EDGES = enum.auto()
    """Import each task and instantiate its config and connection objects,
    skipping the check that the connections still define the same edges.

    Only safe when the caller knows the task definitions are unchanged since
    the pipeline graph was persisted, e.g. when the graph was saved and loaded
    with the same pipeline version.
    """

    OVERRIDE_EDGES = enum.auto()
    """Import each task, instantiate its config and connection objects, and
    let the edges defined by those connections replace the ones stored in the
    persisted graph.

    Dataset type nodes may become unresolved as a result, because resolutions
    that were consistent with the original edges may no longer be valid.
    """


@dataclasses.dataclass(frozen=True)
class _TaskNodeImportedData:
"""An internal struct that holds `TaskNode` and `TaskInitNode` state that
Expand Down
12 changes: 3 additions & 9 deletions python/lsst/pipe/base/pipeline_graph/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
from ._nodes import NodeKey, NodeType
from ._pipeline_graph import PipelineGraph
from ._task_subsets import TaskSubset
from ._tasks import TaskInitNode, TaskNode
from ._tasks import TaskImportMode, TaskInitNode, TaskNode

_U = TypeVar("_U")

Expand Down Expand Up @@ -527,9 +527,7 @@ def serialize(cls, target: PipelineGraph) -> SerializedPipelineGraph:

def deserialize(
self,
import_and_configure: bool = True,
check_edges_unchanged: bool = False,
assume_edges_unchanged: bool = False,
import_mode: TaskImportMode,
) -> PipelineGraph:
"""Transform a `SerializedPipelineGraph` into a `PipelineGraph`."""
universe: DimensionUniverse | None = None
Expand Down Expand Up @@ -615,9 +613,5 @@ def deserialize(
universe=universe,
data_id=self.data_id,
)
if import_and_configure:
result._import_and_configure(
check_edges_unchanged=check_edges_unchanged,
assume_edges_unchanged=assume_edges_unchanged,
)
result._import_and_configure(import_mode)
return result
9 changes: 5 additions & 4 deletions tests/test_pipeline_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
NodeType,
PipelineGraph,
PipelineGraphError,
TaskImportMode,
UnresolvedGraphError,
)
from lsst.pipe.base.tests.mocks import (
Expand Down Expand Up @@ -154,12 +155,12 @@ def test_unresolved_deferred_import_io(self) -> None:
stream = io.BytesIO()
self.graph._write_stream(stream)
stream.seek(0)
roundtripped = PipelineGraph._read_stream(stream, import_and_configure=False)
roundtripped = PipelineGraph._read_stream(stream, import_mode=TaskImportMode.DO_NOT_IMPORT)
self.check_make_xgraph(roundtripped, resolved=False, imported_and_configured=False)
# Check that we can still resolve the graph without importing tasks.
roundtripped.resolve(MockRegistry(self.dimensions, {}))
self.check_make_xgraph(roundtripped, resolved=True, imported_and_configured=False)
roundtripped._import_and_configure(assume_edges_unchanged=True)
roundtripped._import_and_configure(TaskImportMode.ASSUME_CONSISTENT_EDGES)
self.check_make_xgraph(roundtripped, resolved=True, imported_and_configured=True)

def test_resolved_accessors(self) -> None:
Expand Down Expand Up @@ -221,9 +222,9 @@ def test_resolved_deferred_import_io(self) -> None:
stream = io.BytesIO()
self.graph._write_stream(stream)
stream.seek(0)
roundtripped = PipelineGraph._read_stream(stream, import_and_configure=False)
roundtripped = PipelineGraph._read_stream(stream, import_mode=TaskImportMode.DO_NOT_IMPORT)
self.check_make_xgraph(roundtripped, resolved=True, imported_and_configured=False)
roundtripped._import_and_configure(check_edges_unchanged=True)
roundtripped._import_and_configure(TaskImportMode.REQUIRE_CONSISTENT_EDGES)
self.check_make_xgraph(roundtripped, resolved=True, imported_and_configured=True)

def test_unresolved_copies(self) -> None:
Expand Down

0 comments on commit 49dc2cb

Please sign in to comment.