Skip to content

Commit

Permalink
Make SerializedPipelineGraph compatible with pydantic v2.
Browse files Browse the repository at this point in the history
  • Loading branch information
TallJimbo committed Jul 24, 2023
1 parent 60f68d0 commit 1253e4f
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 14 deletions.
2 changes: 1 addition & 1 deletion python/lsst/pipe/base/pipeline_graph/_pipeline_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -1130,7 +1130,7 @@ def _write_stream(self, stream: BinaryIO) -> None:

with gzip.open(stream, mode="wb") as compressed_stream:
compressed_stream.write(
SerializedPipelineGraph.serialize(self).json(exclude_defaults=True, indent=2).encode("utf-8")
SerializedPipelineGraph.serialize(self).json(exclude_defaults=True).encode("utf-8")
)

def _write_uri(self, uri: ResourcePathExpression) -> None:
Expand Down
27 changes: 14 additions & 13 deletions python/lsst/pipe/base/pipeline_graph/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import networkx
import pydantic
from lsst.daf.butler import DatasetType, DimensionConfig, DimensionGraph, DimensionUniverse
from lsst.daf.butler._compat import _BaseModelCompat

from .. import automatic_connection_constants as acc
from ._dataset_types import DatasetTypeNode
Expand Down Expand Up @@ -78,7 +79,7 @@ def expect_not_none(value: _U | None, msg: str) -> _U:
return value


class SerializedEdge(pydantic.BaseModel):
class SerializedEdge(_BaseModelCompat):
"""Struct used to represent a serialized `Edge` in a `PipelineGraph`.
All `ReadEdge` and `WriteEdge` state not included here is instead
Expand Down Expand Up @@ -107,7 +108,7 @@ class SerializedEdge(pydantic.BaseModel):
@classmethod
def serialize(cls, target: Edge) -> SerializedEdge:
"""Transform an `Edge` to a `SerializedEdge`."""
return SerializedEdge.construct(
return SerializedEdge.model_construct(
storage_class=target.storage_class_name,
dataset_type_name=target.dataset_type_name,
raw_dimensions=sorted(target.raw_dimensions),
Expand Down Expand Up @@ -153,7 +154,7 @@ def deserialize_write_edge(
)


class SerializedTaskInitNode(pydantic.BaseModel):
class SerializedTaskInitNode(_BaseModelCompat):
"""Struct used to represent a serialized `TaskInitNode` in a
`PipelineGraph`.
Expand Down Expand Up @@ -182,7 +183,7 @@ class SerializedTaskInitNode(pydantic.BaseModel):
@classmethod
def serialize(cls, target: TaskInitNode) -> SerializedTaskInitNode:
"""Transform a `TaskInitNode` to a `SerializedTaskInitNode`."""
return cls.construct(
return cls.model_construct(
inputs={
connection_name: SerializedEdge.serialize(edge)
for connection_name, edge in sorted(target.inputs.items())
Expand Down Expand Up @@ -224,7 +225,7 @@ def deserialize(
)


class SerializedTaskNode(pydantic.BaseModel):
class SerializedTaskNode(_BaseModelCompat):
"""Struct used to represent a serialized `TaskNode` in a `PipelineGraph`.
The task label is serialized by the context in which a
Expand Down Expand Up @@ -271,7 +272,7 @@ class SerializedTaskNode(pydantic.BaseModel):
@classmethod
def serialize(cls, target: TaskNode) -> SerializedTaskNode:
"""Transform a `TaskNode` to a `SerializedTaskNode`."""
return cls.construct(
return cls.model_construct(
task_class=target.task_class_name,
init=SerializedTaskInitNode.serialize(target.init),
config_str=target.get_config_str(),
Expand Down Expand Up @@ -350,7 +351,7 @@ def deserialize(
)


class SerializedDatasetTypeNode(pydantic.BaseModel):
class SerializedDatasetTypeNode(_BaseModelCompat):
"""Struct used to represent a serialized `DatasetTypeNode` in a
`PipelineGraph`.
Expand Down Expand Up @@ -391,8 +392,8 @@ class SerializedDatasetTypeNode(pydantic.BaseModel):
def serialize(cls, target: DatasetTypeNode | None) -> SerializedDatasetTypeNode:
"""Transform a `DatasetTypeNode` to a `SerializedDatasetTypeNode`."""
if target is None:
return cls.construct()
return cls.construct(
return cls.model_construct()
return cls.model_construct(
dimensions=list(target.dataset_type.dimensions.names),
storage_class=target.dataset_type.storageClass_name,
is_calibration=target.dataset_type.isCalibration(),
Expand Down Expand Up @@ -445,7 +446,7 @@ def deserialize(
return None


class SerializedTaskSubset(pydantic.BaseModel):
class SerializedTaskSubset(_BaseModelCompat):
"""Struct used to represent a serialized `TaskSubset` in a `PipelineGraph`.
The subsetlabel is serialized by the context in which a
Expand All @@ -464,15 +465,15 @@ class SerializedTaskSubset(pydantic.BaseModel):
@classmethod
def serialize(cls, target: TaskSubset) -> SerializedTaskSubset:
"""Transform a `TaskSubset` into a `SerializedTaskSubset`."""
return cls.construct(description=target._description, tasks=list(sorted(target)))
return cls.model_construct(description=target._description, tasks=list(sorted(target)))

def deserialize_task_subset(self, label: str, xgraph: networkx.MultiDiGraph) -> TaskSubset:
"""Transform a `SerializedTaskSubset` into a `TaskSubset`."""
members = set(self.tasks)
return TaskSubset(xgraph, label, members, self.description)


class SerializedPipelineGraph(pydantic.BaseModel):
class SerializedPipelineGraph(_BaseModelCompat):
"""Struct used to represent a serialized `PipelineGraph`."""

version: str = ".".join(str(v) for v in _IO_VERSION_INFO)
Expand Down Expand Up @@ -500,7 +501,7 @@ class SerializedPipelineGraph(pydantic.BaseModel):
@classmethod
def serialize(cls, target: PipelineGraph) -> SerializedPipelineGraph:
"""Transform a `PipelineGraph` into a `SerializedPipelineGraph`."""
result = SerializedPipelineGraph.construct(
result = SerializedPipelineGraph.model_construct(
description=target.description,
tasks={label: SerializedTaskNode.serialize(node) for label, node in target.tasks.items()},
dataset_types={
Expand Down

0 comments on commit 1253e4f

Please sign in to comment.