diff --git a/doc/changes/DM-53019.feature.md b/doc/changes/DM-53019.feature.md
new file mode 100644
index 000000000..9543ca6b3
--- /dev/null
+++ b/doc/changes/DM-53019.feature.md
@@ -0,0 +1,6 @@
+Improve provenance tracking for failed quanta and retries.
+
+By storing extra information in the log datasets written during execution, we
+can record caught exceptions, track which other quanta have already executed
+in the same process, and keep track of previous attempts to run the same
+quantum.
diff --git a/python/lsst/pipe/base/_status.py b/python/lsst/pipe/base/_status.py
index 6e555e4ab..0b4713209 100644
--- a/python/lsst/pipe/base/_status.py
+++ b/python/lsst/pipe/base/_status.py
@@ -27,28 +27,37 @@
 from __future__ import annotations
 
+__all__ = (
+    "AlgorithmError",
+    "AnnotatedPartialOutputsError",
+    "ExceptionInfo",
+    "InvalidQuantumError",
+    "NoWorkFound",
+    "QuantumAttemptStatus",
+    "QuantumSuccessCaveats",
+    "RepeatableQuantumError",
+    "UnprocessableDataError",
+    "UpstreamFailureNoWorkFound",
+)
+
 import abc
 import enum
 import logging
+import sys
 from typing import TYPE_CHECKING, Any, ClassVar, Protocol
 
+import pydantic
+
 from lsst.utils import introspection
+from lsst.utils.logging import LsstLogAdapter, getLogger
 
 from ._task_metadata import GetSetDictMetadata, NestedMetadataDict
 
 if TYPE_CHECKING:
-    from lsst.utils.logging import LsstLogAdapter
+    from ._task_metadata import TaskMetadata
 
-__all__ = (
-    "AlgorithmError",
-    "AnnotatedPartialOutputsError",
-    "InvalidQuantumError",
-    "NoWorkFound",
-    "QuantumSuccessCaveats",
-    "RepeatableQuantumError",
-    "UnprocessableDataError",
-    "UpstreamFailureNoWorkFound",
-)
+
+
+_LOG = getLogger(__name__)
 
 
 class QuantumSuccessCaveats(enum.Flag):
@@ -175,6 +184,142 @@ def legend() -> dict[str, str]:
         }
 
 
+class ExceptionInfo(pydantic.BaseModel):
+    """Information about an exception that was raised."""
+
+    type_name: str
+    """Fully-qualified Python type name for the exception raised."""
+
+    message: str
+    """String message included in the exception."""
+
+    metadata: dict[str, float | int | str | bool | None]
+    """Additional metadata included in the exception."""
+
+    @classmethod
+    def _from_metadata(cls, md: TaskMetadata) -> ExceptionInfo:
+        """Construct from task metadata.
+
+        Parameters
+        ----------
+        md : `TaskMetadata`
+            Metadata about the error, as written by
+            `AnnotatedPartialOutputsError`.
+
+        Returns
+        -------
+        info : `ExceptionInfo`
+            Information about the exception.
+        """
+        result = cls(type_name=md["type"], message=md["message"], metadata={})
+        if "metadata" in md:
+            raw_err_metadata = md["metadata"].to_dict()
+            for k, v in raw_err_metadata.items():
+                # Guard against error metadata we wouldn't be able to serialize
+                # later via Pydantic; don't want one weird value bringing down
+                # our ability to report on an entire run.
+                if isinstance(v, float | int | str | bool):
+                    result.metadata[k] = v
+                else:
+                    _LOG.debug(
+                        "Not propagating nested or JSON-incompatible exception metadata key %s=%r.", k, v
+                    )
+        return result
+
+    # Work around the fact that Sphinx chokes on Pydantic docstring formatting,
+    # when we inherit those docstrings in our public classes.
+ if "sphinx" in sys.modules and not TYPE_CHECKING: + + def copy(self, *args: Any, **kwargs: Any) -> Any: + """See `pydantic.BaseModel.copy`.""" + return super().copy(*args, **kwargs) + + def model_dump(self, *args: Any, **kwargs: Any) -> Any: + """See `pydantic.BaseModel.model_dump`.""" + return super().model_dump(*args, **kwargs) + + def model_dump_json(self, *args: Any, **kwargs: Any) -> Any: + """See `pydantic.BaseModel.model_dump_json`.""" + return super().model_dump(*args, **kwargs) + + def model_copy(self, *args: Any, **kwargs: Any) -> Any: + """See `pydantic.BaseModel.model_copy`.""" + return super().model_copy(*args, **kwargs) + + @classmethod + def model_construct(cls, *args: Any, **kwargs: Any) -> Any: # type: ignore[misc, override] + """See `pydantic.BaseModel.model_construct`.""" + return super().model_construct(*args, **kwargs) + + @classmethod + def model_json_schema(cls, *args: Any, **kwargs: Any) -> Any: + """See `pydantic.BaseModel.model_json_schema`.""" + return super().model_json_schema(*args, **kwargs) + + @classmethod + def model_validate(cls, *args: Any, **kwargs: Any) -> Any: + """See `pydantic.BaseModel.model_validate`.""" + return super().model_validate(*args, **kwargs) + + @classmethod + def model_validate_json(cls, *args: Any, **kwargs: Any) -> Any: + """See `pydantic.BaseModel.model_validate_json`.""" + return super().model_validate_json(*args, **kwargs) + + @classmethod + def model_validate_strings(cls, *args: Any, **kwargs: Any) -> Any: + """See `pydantic.BaseModel.model_validate_strings`.""" + return super().model_validate_strings(*args, **kwargs) + + +class QuantumAttemptStatus(enum.Enum): + """Enum summarizing an attempt to run a quantum.""" + + UNKNOWN = -3 + """The status of this attempt is unknown. + + This usually means no logs or metadata were written, and it at least could + not be determined whether the quantum was blocked by an upstream failure + (if it was definitely blocked, `BLOCKED` is set instead). + """ + + LOGS_MISSING = -2 + """Task metadata was written for this attempt but logs were not. + + This is a rare condition that requires a hard failure (i.e. the kind that + can prevent a ``finally`` block from running or I/O from being durable) at + a very precise time. + """ + + FAILED = -1 + """Execution of the quantum failed. + + This is always set if the task metadata dataset was not written but logs + were, as is the case when a Python exception is caught and handled by the + execution system. It may also be set in cases where logs were not written + either, but other information was available (e.g. from higher-level + orchestration tooling) to mark it as a failure. + """ + + BLOCKED = 0 + """This quantum was not executed because an upstream quantum failed. + + Upstream quanta with status `UNKNOWN` or `FAILED` are considered blockers; + `LOGS_MISSING` is not. + """ + + SUCCESSFUL = 1 + """This quantum was successfully executed. + + Quanta may be considered successful even if they do not write any outputs + or shortcut early by raising `NoWorkFound` or one of its variants. They + may even be considered successful if they raise + `AnnotatedPartialOutputsError` if the executor is configured to treat that + exception as a non-failure. See `QuantumSuccessCaveats` for details on how + these "successes with caveats" are reported. + """ + + class GetSetDictMetadataHolder(Protocol): """Protocol for objects that have a ``metadata`` attribute that satisfies `GetSetDictMetadata`. 
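The guard in `ExceptionInfo._from_metadata` above keeps only scalar, JSON-compatible values, so one exotic metadata entry cannot break reporting for an entire run. Below is a minimal, self-contained sketch of that filtering rule; the class and function names are illustrative stand-ins (a plain `dict` replaces `TaskMetadata`), not part of the package.

```python
from __future__ import annotations

import pydantic


class ExceptionInfoSketch(pydantic.BaseModel):
    """Simplified stand-in for ExceptionInfo, for illustration only."""

    type_name: str
    message: str
    metadata: dict[str, float | int | str | bool | None] = pydantic.Field(default_factory=dict)


def exception_info_from_mapping(type_name: str, message: str, raw: dict[str, object]) -> ExceptionInfoSketch:
    """Keep only scalar, JSON-compatible metadata values and drop the rest."""
    info = ExceptionInfoSketch(type_name=type_name, message=message)
    for key, value in raw.items():
        if value is None or isinstance(value, (float, int, str, bool)):
            info.metadata[key] = value
        # Nested mappings, arrays, and other complex values are skipped so a
        # single odd entry cannot prevent the whole report from serializing.
    return info


raw = {"n_sources": 0, "threshold": 5.0, "detector": "R22_S11", "history": {"nested": True}}
info = exception_info_from_mapping("lsst.pipe.base.NoWorkFound", "no sources detected", raw)
print(info.model_dump_json())  # "history" is dropped; the scalar entries survive
```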
diff --git a/python/lsst/pipe/base/log_capture.py b/python/lsst/pipe/base/log_capture.py index 01a1edd85..9b18fed44 100644 --- a/python/lsst/pipe/base/log_capture.py +++ b/python/lsst/pipe/base/log_capture.py @@ -29,28 +29,105 @@ __all__ = ["LogCapture"] +import dataclasses import logging import os import shutil import tempfile +import uuid from collections.abc import Iterator from contextlib import contextmanager, suppress from logging import FileHandler -from lsst.daf.butler import Butler, FileDataset, LimitedButler, Quantum -from lsst.daf.butler.logging import ButlerLogRecordHandler, ButlerLogRecords, ButlerMDC, JsonLogFormatter +import pydantic -from ._status import InvalidQuantumError +from lsst.daf.butler import Butler, FileDataset, LimitedButler, Quantum +from lsst.daf.butler.logging import ( + ButlerLogRecord, + ButlerLogRecordHandler, + ButlerLogRecords, + ButlerMDC, + JsonLogFormatter, +) + +from ._status import ExceptionInfo, InvalidQuantumError +from ._task_metadata import TaskMetadata from .automatic_connection_constants import METADATA_OUTPUT_TEMPLATE from .pipeline_graph import TaskNode _LOG = logging.getLogger(__name__) -class _LogCaptureFlag: - """Simple flag to enable/disable log-to-butler saving.""" +class _ExecutionLogRecordsExtra(pydantic.BaseModel): + """Extra information about a quantum's execution stored with logs. + + This middleware-private model includes information that is not directly + available via any public interface, as it is used exclusively for + provenance extraction and then made available through the provenance + quantum graph. + """ + + exception: ExceptionInfo | None = None + """Exception information for this quantum, if it failed. + """ + + metadata: TaskMetadata | None = None + """Metadata for this quantum, if it failed. + + Metadata datasets are written if and only if a quantum succeeds, but we + still want to capture metadata from failed attempts, so we store it in the + log dataset. This field is always `None` when the quantum succeeds, + because in that case the metadata is already stored separately. + """ + + previous_process_quanta: list[uuid.UUID] = pydantic.Field(default_factory=list) + """The IDs of other quanta previously executed in the same process as this + one. + """ + + logs: list[ButlerLogRecord] = pydantic.Field(default_factory=list) + """Logs for this attempt. + + This is always empty for the most recent attempt, because that stores logs + in the main section of the butler log records. + """ + + previous_attempts: list[_ExecutionLogRecordsExtra] = pydantic.Field(default_factory=list) + """Information about previous attempts to run this task within the same + `~lsst.daf.butler.CollectionType.RUN` collection. + + This is always empty for any attempt other than the most recent one, + as all previous attempts are flattened into one list. + """ + + def attach_previous_attempt(self, log_records: ButlerLogRecords) -> None: + """Attach logs from a previous attempt to this struct. + + Parameters + ---------- + log_records : `ButlerLogRecords` + Logs from a past attempt to run a quantum. + """ + previous = self.model_validate(log_records.extra) + previous.logs.extend(log_records) + self.previous_attempts.extend(previous.previous_attempts) + self.previous_attempts.append(previous) + previous.previous_attempts.clear() + + +@dataclasses.dataclass +class _LogCaptureContext: + """Controls for log capture returned by the `LogCapture.capture_logging` + context manager. 
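One subtlety in `attach_previous_attempt` above is worth calling out: the older attempt's own `previous_attempts` are hoisted into the current struct and then cleared, so the history stays a flat, chronological list rather than a nested chain, and the older attempt's log records move into its `logs` field because only the newest attempt keeps logs in the main dataset. A toy sketch of that bookkeeping (plain strings stand in for `ButlerLogRecord` objects, and the validation from `ButlerLogRecords.extra` is omitted):

```python
from __future__ import annotations

import pydantic


class AttemptSketch(pydantic.BaseModel):
    """Toy stand-in for _ExecutionLogRecordsExtra, tracking retries."""

    logs: list[str] = pydantic.Field(default_factory=list)
    previous_attempts: list[AttemptSketch] = pydantic.Field(default_factory=list)

    def attach_previous_attempt(self, old: AttemptSketch, old_log_lines: list[str]) -> None:
        # The superseded attempt keeps its own streamed log lines; the newest
        # attempt's logs stay in the main log dataset instead.
        old.logs.extend(old_log_lines)
        # Hoist the older history so the list stays flat and chronological,
        # then append the older attempt itself and clear its nested list.
        self.previous_attempts.extend(old.previous_attempts)
        self.previous_attempts.append(old)
        old.previous_attempts.clear()


first = AttemptSketch()
second = AttemptSketch()
second.attach_previous_attempt(first, ["attempt 1 traceback"])
third = AttemptSketch()
third.attach_previous_attempt(second, ["attempt 2 traceback"])
print([a.logs for a in third.previous_attempts])  # [['attempt 1 traceback'], ['attempt 2 traceback']]
```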
+ """ store: bool = True + """Whether to store logs at all.""" + + extra: _ExecutionLogRecordsExtra = dataclasses.field(default_factory=_ExecutionLogRecordsExtra) + """Extra information about the quantum's execution to store for provenance + extraction. + """ class LogCapture: @@ -88,7 +165,7 @@ def from_full(cls, butler: Butler) -> LogCapture: return cls(butler, butler) @contextmanager - def capture_logging(self, task_node: TaskNode, /, quantum: Quantum) -> Iterator[_LogCaptureFlag]: + def capture_logging(self, task_node: TaskNode, /, quantum: Quantum) -> Iterator[_LogCaptureContext]: """Configure logging system to capture logs for execution of this task. Parameters @@ -121,7 +198,7 @@ def capture_logging(self, task_node: TaskNode, /, quantum: Quantum) -> Iterator[ metadata_ref = quantum.outputs[METADATA_OUTPUT_TEMPLATE.format(label=task_node.label)][0] mdc["RUN"] = metadata_ref.run - ctx = _LogCaptureFlag() + ctx = _LogCaptureContext() log_dataset_name = ( task_node.log_output.dataset_type_name if task_node.log_output is not None else None ) @@ -154,6 +231,12 @@ def capture_logging(self, task_node: TaskNode, /, quantum: Quantum) -> Iterator[ # Ensure that the logs are stored in butler. logging.getLogger().removeHandler(log_handler_file) log_handler_file.close() + if ctx.extra: + with open(log_file, "a") as log_stream: + ButlerLogRecords.write_streaming_extra( + log_stream, + ctx.extra.model_dump_json(exclude_unset=True, exclude_defaults=True), + ) if ctx.store: self._ingest_log_records(quantum, log_dataset_name, log_file) shutil.rmtree(tmpdir, ignore_errors=True) @@ -165,7 +248,15 @@ def capture_logging(self, task_node: TaskNode, /, quantum: Quantum) -> Iterator[ try: with ButlerMDC.set_mdc(mdc): yield ctx + except: + raise + else: + # If the quantum succeeded, we don't need to save the + # metadata in the logs, because we'll have saved them in + # the metadata. + ctx.extra.metadata = None finally: + log_handler_memory.records.extra = ctx.extra.model_dump() # Ensure that the logs are stored in butler. logging.getLogger().removeHandler(log_handler_memory) if ctx.store: diff --git a/python/lsst/pipe/base/quantum_graph/_common.py b/python/lsst/pipe/base/quantum_graph/_common.py index 6594c72a3..d4ea47a15 100644 --- a/python/lsst/pipe/base/quantum_graph/_common.py +++ b/python/lsst/pipe/base/quantum_graph/_common.py @@ -60,6 +60,7 @@ import zstandard from lsst.daf.butler import DataCoordinate, DataIdValue +from lsst.daf.butler._rubin import generate_uuidv7 from lsst.resources import ResourcePath, ResourcePathExpression from ..pipeline_graph import DatasetTypeNode, Edge, PipelineGraph, TaskImportMode, TaskNode @@ -157,6 +158,11 @@ class HeaderModel(pydantic.BaseModel): quantum graph file). """ + provenance_dataset_id: uuid.UUID = pydantic.Field(default_factory=generate_uuidv7) + """The dataset ID for provenance quantum graph when it is ingested into + a butler repository. + """ + @classmethod def from_old_quantum_graph(cls, old_quantum_graph: QuantumGraph) -> HeaderModel: """Extract a header from an old `QuantumGraph` instance. 
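For orientation, the `capture_logging` changes above append a single extra JSON payload after the streamed log records (via `ButlerLogRecords.write_streaming_extra`) and, on the in-memory path, attach it as `records.extra`. The exact stream framing belongs to `daf_butler`; the sketch below only illustrates the general idea of a JSON-lines log file with a trailing non-record payload, using a made-up marker key.

```python
import json
from pathlib import Path
from tempfile import TemporaryDirectory

# Hypothetical marker key; the real streaming format is defined by daf_butler.
EXTRA_MARKER = "__extra__"


def append_extra(log_path: Path, extra_json: str) -> None:
    """Append a provenance payload after the streamed log records."""
    with open(log_path, "a") as stream:
        stream.write(json.dumps({EXTRA_MARKER: json.loads(extra_json)}) + "\n")


def read_records_and_extra(log_path: Path) -> tuple[list[dict], dict | None]:
    """Separate ordinary log records from the trailing extra payload."""
    records: list[dict] = []
    extra: dict | None = None
    for line in log_path.read_text().splitlines():
        obj = json.loads(line)
        if EXTRA_MARKER in obj:
            extra = obj[EXTRA_MARKER]
        else:
            records.append(obj)
    return records, extra


with TemporaryDirectory() as tmpdir:
    path = Path(tmpdir, "quantum.log.json")
    path.write_text(json.dumps({"levelname": "INFO", "message": "running quantum"}) + "\n")
    append_extra(path, json.dumps({"exception": None, "previous_process_quanta": []}))
    print(read_records_and_extra(path))
```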
diff --git a/python/lsst/pipe/base/quantum_graph/_predicted.py b/python/lsst/pipe/base/quantum_graph/_predicted.py index ac0dabeea..d0a2acb1c 100644 --- a/python/lsst/pipe/base/quantum_graph/_predicted.py +++ b/python/lsst/pipe/base/quantum_graph/_predicted.py @@ -1899,11 +1899,12 @@ def finish(self) -> PredictedQuantumGraph: """Construct a `PredictedQuantumGraph` instance from this reader.""" return self.components.assemble() - def read_all(self) -> PredictedQuantumGraphReader: + def read_all(self) -> None: """Read all components in full.""" - return self.read_thin_graph().read_execution_quanta() + self.read_thin_graph() + self.read_execution_quanta() - def read_thin_graph(self) -> PredictedQuantumGraphReader: + def read_thin_graph(self) -> None: """Read the thin graph. The thin graph is a quantum-quantum DAG with internal integer IDs for @@ -1918,17 +1919,15 @@ def read_thin_graph(self) -> PredictedQuantumGraphReader: self.components.quantum_indices.update( {row.key: row.index for row in self.address_reader.rows.values()} ) - return self - def read_init_quanta(self) -> PredictedQuantumGraphReader: + def read_init_quanta(self) -> None: """Read the list of special quanta that represent init-inputs and init-outputs. """ if not self.components.init_quanta.root: self.components.init_quanta = self._read_single_block("init_quanta", PredictedInitQuantaModel) - return self - def read_dimension_data(self) -> PredictedQuantumGraphReader: + def read_dimension_data(self) -> None: """Read all dimension records. Record data IDs will be immediately deserialized, while other fields @@ -1948,11 +1947,8 @@ def read_dimension_data(self) -> PredictedQuantumGraphReader: universe=self.components.pipeline_graph.universe, ), ) - return self - def read_quantum_datasets( - self, quantum_ids: Iterable[uuid.UUID] | None = None - ) -> PredictedQuantumGraphReader: + def read_quantum_datasets(self, quantum_ids: Iterable[uuid.UUID] | None = None) -> None: """Read information about all datasets produced and consumed by the given quantum IDs. @@ -1977,7 +1973,7 @@ def read_quantum_datasets( self.address_reader.read_all() for address_row in self.address_reader.rows.values(): self.components.quantum_indices[address_row.key] = address_row.index - return self + return with MultiblockReader.open_in_zip( self.zf, "quantum_datasets", int_size=self.components.header.int_size ) as mb_reader: @@ -1991,11 +1987,9 @@ def read_quantum_datasets( ) if quantum_datasets is not None: self.components.quantum_datasets[address_row.key] = quantum_datasets - return self + return - def read_execution_quanta( - self, quantum_ids: Iterable[uuid.UUID] | None = None - ) -> PredictedQuantumGraphReader: + def read_execution_quanta(self, quantum_ids: Iterable[uuid.UUID] | None = None) -> None: """Read all information needed to execute the given quanta. Parameters @@ -2004,4 +1998,6 @@ def read_execution_quanta( Iterable of quantum IDs to load. If not provided, all quanta will be loaded. The UUIDs of special init quanta will be ignored. 
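Because the `read_*` methods above now return `None`, callers that relied on method-chaining need one call per statement. A brief usage sketch, assuming `reader` is an already-opened `PredictedQuantumGraphReader`:

```python
# Before this change, the reads could be chained:
#     graph = reader.read_thin_graph().read_execution_quanta().finish()
# After it, each step is a separate statement:
reader.read_thin_graph()
reader.read_execution_quanta()  # optionally pass an iterable of quantum UUIDs
graph = reader.finish()  # assemble the PredictedQuantumGraph
```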
""" - return self.read_init_quanta().read_dimension_data().read_quantum_datasets(quantum_ids) + self.read_init_quanta() + self.read_dimension_data() + self.read_quantum_datasets(quantum_ids) diff --git a/python/lsst/pipe/base/quantum_graph/_provenance.py b/python/lsst/pipe/base/quantum_graph/_provenance.py index e4050c2e9..e09cdca11 100644 --- a/python/lsst/pipe/base/quantum_graph/_provenance.py +++ b/python/lsst/pipe/base/quantum_graph/_provenance.py @@ -32,10 +32,12 @@ "ProvenanceDatasetModel", "ProvenanceInitQuantumInfo", "ProvenanceInitQuantumModel", + "ProvenanceLogRecordsModel", "ProvenanceQuantumGraph", "ProvenanceQuantumGraphReader", "ProvenanceQuantumInfo", "ProvenanceQuantumModel", + "ProvenanceTaskMetadataModel", ) @@ -45,7 +47,7 @@ from collections import Counter from collections.abc import Iterable, Iterator, Mapping from contextlib import contextmanager -from typing import TYPE_CHECKING, Any, Self, TypedDict +from typing import TYPE_CHECKING, Any, Generic, TypeAlias, TypedDict, TypeVar import astropy.table import networkx @@ -53,12 +55,13 @@ import pydantic from lsst.daf.butler import DataCoordinate +from lsst.daf.butler.logging import ButlerLogRecord, ButlerLogRecords from lsst.resources import ResourcePathExpression from lsst.utils.packages import Packages -from .._status import QuantumSuccessCaveats +from .._status import ExceptionInfo, QuantumAttemptStatus, QuantumSuccessCaveats +from .._task_metadata import TaskMetadata from ..pipeline_graph import PipelineGraph, TaskImportMode, TaskInitNode -from ..quantum_provenance_graph import ExceptionInfo, QuantumRunStatus from ..resource_usage import QuantumResourceUsage from ._common import ( BaseQuantumGraph, @@ -76,12 +79,6 @@ from ._multiblock import AddressReader, MultiblockReader from ._predicted import PredictedDatasetModel, PredictedQuantumDatasetsModel -if TYPE_CHECKING: - from lsst.daf.butler.logging import ButlerLogRecords - - from .._task_metadata import TaskMetadata - - DATASET_ADDRESS_INDEX = 0 QUANTUM_ADDRESS_INDEX = 1 LOG_ADDRESS_INDEX = 2 @@ -92,6 +89,8 @@ LOG_MB_NAME = "logs" METADATA_MB_NAME = "metadata" +_I = TypeVar("_I", bound=uuid.UUID | int) + class ProvenanceDatasetInfo(DatasetInfo): """A typed dictionary that annotates the attributes of the NetworkX graph @@ -108,13 +107,13 @@ class ProvenanceDatasetInfo(DatasetInfo): dataset_id: uuid.UUID """Unique identifier for the dataset.""" - exists: bool - """Whether this dataset existed immediately after the quantum graph was - run. + produced: bool + """Whether this dataset was produced (vs. only predicted). This is always `True` for overall input datasets. It is also `True` for datasets that were produced and then removed before/during transfer back to - the central butler repository. + the central butler repository, so it may not reflect the continued + existence of the dataset. """ @@ -131,17 +130,38 @@ class ProvenanceQuantumInfo(QuantumInfo): `ProvenanceQuantumGraph.quantum_only_xgraph` """ - status: QuantumRunStatus - """Enumerated status for the quantum.""" + status: QuantumAttemptStatus + """Enumerated status for the quantum. + + This corresponds to the last attempt to run this quantum, or + `QuantumAttemptStatus.BLOCKED` if there were no attempts. + """ caveats: QuantumSuccessCaveats | None - """Flags indicating caveats on successful quanta.""" + """Flags indicating caveats on successful quanta. + + This corresponds to the last attempt to run this quantum. 
+ """ exception: ExceptionInfo | None - """Information about an exception raised when the quantum was executing.""" + """Information about an exception raised when the quantum was executing. + + This corresponds to the last attempt to run this quantum. + """ resource_usage: QuantumResourceUsage | None - """Resource usage information (timing, memory use) for this quantum.""" + """Resource usage information (timing, memory use) for this quantum. + + This corresponds to the last attempt to run this quantum. + """ + + attempts: list[ProvenanceQuantumAttemptModel] + """Information about each attempt to run this quantum. + + An entry is added merely if the quantum *should* have been attempted; an + empty `list` is used only for quanta that were blocked by an upstream + failure. + """ class ProvenanceInitQuantumInfo(TypedDict): @@ -173,13 +193,13 @@ class ProvenanceInitQuantumInfo(TypedDict): class ProvenanceDatasetModel(PredictedDatasetModel): """Data model for the datasets in a provenance quantum graph file.""" - exists: bool - """Whether this dataset existed immediately after the quantum graph was - run. + produced: bool + """Whether this dataset was produced (vs. only predicted). This is always `True` for overall input datasets. It is also `True` for datasets that were produced and then removed before/during transfer back to - the central butler repository. + the central butler repository, so it may not reflect the continued + existence of the dataset. """ producer: QuantumIndex | None = None @@ -225,7 +245,7 @@ def from_predicted( Notes ----- - This initializes `exists` to `True` when ``producer is None`` and + This initializes `produced` to `True` when ``producer is None`` and `False` otherwise, on the assumption that it will be updated later. """ return cls.model_construct( @@ -233,7 +253,7 @@ def from_predicted( dataset_type_name=predicted.dataset_type_name, data_coordinate=predicted.data_coordinate, run=predicted.run, - exists=(producer is None), # if it's not produced by this QG, it's an overall input + produced=(producer is None), # if it's not produced by this QG, it's an overall input producer=producer, consumers=list(consumers), ) @@ -268,7 +288,7 @@ def _add_to_graph(self, graph: ProvenanceQuantumGraph, address_reader: AddressRe dataset_type_name=self.dataset_type_name, pipeline_node=dataset_type_node, run=self.run, - exists=self.exists, + produced=self.produced, ) producer_id: uuid.UUID | None = None if self.producer is not None: @@ -327,24 +347,15 @@ def model_validate_strings(cls, *args: Any, **kwargs: Any) -> Any: return super().model_validate_strings(*args, **kwargs) -class ProvenanceQuantumModel(pydantic.BaseModel): - """Data model for the quanta in a provenance quantum graph file.""" - - quantum_id: uuid.UUID - """Unique identifier for the quantum.""" - - task_label: TaskLabel - """Name of the type of this dataset. - - This is always a parent dataset type name, not a component. - - Note that full dataset type definitions are stored in the pipeline graph. +class _GenericProvenanceQuantumAttemptModel(pydantic.BaseModel, Generic[_I]): + """Data model for a now-superseded attempt to run a quantum in a + provenance quantum graph file. 
""" - data_coordinate: DataCoordinateValues = pydantic.Field(default_factory=list) - """The full values (required and implied) of this dataset's data ID.""" + attempt: int = 0 + """Counter incremented for every attempt to execute this quantum.""" - status: QuantumRunStatus = QuantumRunStatus.METADATA_MISSING + status: QuantumAttemptStatus = QuantumAttemptStatus.UNKNOWN """Enumerated status for the quantum.""" caveats: QuantumSuccessCaveats | None = None @@ -353,6 +364,212 @@ class ProvenanceQuantumModel(pydantic.BaseModel): exception: ExceptionInfo | None = None """Information about an exception raised when the quantum was executing.""" + resource_usage: QuantumResourceUsage | None = None + """Resource usage information (timing, memory use) for this quantum.""" + + previous_process_quanta: list[_I] = pydantic.Field(default_factory=list) + """The IDs of other quanta previously executed in the same process as this + one. + """ + + def remap_uuids( + self: ProvenanceQuantumAttemptModel, indices: Mapping[uuid.UUID, QuantumIndex] + ) -> StorageProvenanceQuantumAttemptModel: + return StorageProvenanceQuantumAttemptModel( + attempt=self.attempt, + status=self.status, + caveats=self.caveats, + exception=self.exception, + resource_usage=self.resource_usage, + previous_process_quanta=[indices[q] for q in self.previous_process_quanta], + ) + + def remap_indices( + self: StorageProvenanceQuantumAttemptModel, address_reader: AddressReader + ) -> ProvenanceQuantumAttemptModel: + return ProvenanceQuantumAttemptModel( + attempt=self.attempt, + status=self.status, + caveats=self.caveats, + exception=self.exception, + resource_usage=self.resource_usage, + previous_process_quanta=[address_reader.find(q).key for q in self.previous_process_quanta], + ) + + # Work around the fact that Sphinx chokes on Pydantic docstring formatting, + # when we inherit those docstrings in our public classes. 
+ if "sphinx" in sys.modules and not TYPE_CHECKING: + + def copy(self, *args: Any, **kwargs: Any) -> Any: + """See `pydantic.BaseModel.copy`.""" + return super().copy(*args, **kwargs) + + def model_dump(self, *args: Any, **kwargs: Any) -> Any: + """See `pydantic.BaseModel.model_dump`.""" + return super().model_dump(*args, **kwargs) + + def model_dump_json(self, *args: Any, **kwargs: Any) -> Any: + """See `pydantic.BaseModel.model_dump_json`.""" + return super().model_dump(*args, **kwargs) + + def model_copy(self, *args: Any, **kwargs: Any) -> Any: + """See `pydantic.BaseModel.model_copy`.""" + return super().model_copy(*args, **kwargs) + + @classmethod + def model_construct(cls, *args: Any, **kwargs: Any) -> Any: # type: ignore[misc, override] + """See `pydantic.BaseModel.model_construct`.""" + return super().model_construct(*args, **kwargs) + + @classmethod + def model_json_schema(cls, *args: Any, **kwargs: Any) -> Any: + """See `pydantic.BaseModel.model_json_schema`.""" + return super().model_json_schema(*args, **kwargs) + + @classmethod + def model_validate(cls, *args: Any, **kwargs: Any) -> Any: + """See `pydantic.BaseModel.model_validate`.""" + return super().model_validate(*args, **kwargs) + + @classmethod + def model_validate_json(cls, *args: Any, **kwargs: Any) -> Any: + """See `pydantic.BaseModel.model_validate_json`.""" + return super().model_validate_json(*args, **kwargs) + + @classmethod + def model_validate_strings(cls, *args: Any, **kwargs: Any) -> Any: + """See `pydantic.BaseModel.model_validate_strings`.""" + return super().model_validate_strings(*args, **kwargs) + + +StorageProvenanceQuantumAttemptModel: TypeAlias = _GenericProvenanceQuantumAttemptModel[QuantumIndex] +ProvenanceQuantumAttemptModel: TypeAlias = _GenericProvenanceQuantumAttemptModel[uuid.UUID] + + +class ProvenanceLogRecordsModel(pydantic.BaseModel): + """Data model for storing execution logs in a provenance quantum graph + file. + """ + + attempts: list[list[ButlerLogRecord] | None] = pydantic.Field(default_factory=list) + """Logs from attempts to run this task, ordered chronologically from first + to last. + """ + + # Work around the fact that Sphinx chokes on Pydantic docstring formatting, + # when we inherit those docstrings in our public classes. 
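The generic parameter `_I` above lets one attempt model carry either compact integer indices (the storage form written to the file) or UUIDs (the user-facing form), with `remap_uuids` and `remap_indices` translating between them. The pattern in isolation looks roughly like the sketch below; the class and helper names are illustrative, not the module's.

```python
from __future__ import annotations

import uuid
from typing import Generic, TypeAlias, TypeVar

import pydantic

_I = TypeVar("_I", bound=uuid.UUID | int)


class AttemptRecord(pydantic.BaseModel, Generic[_I]):
    """Toy generic attempt record; `_I` is either a UUID or an integer index."""

    previous_process_quanta: list[_I] = pydantic.Field(default_factory=list)


StorageAttemptRecord: TypeAlias = AttemptRecord[int]  # compact on-disk form
PublicAttemptRecord: TypeAlias = AttemptRecord[uuid.UUID]  # in-memory, user-facing form


def to_storage(record: PublicAttemptRecord, indices: dict[uuid.UUID, int]) -> StorageAttemptRecord:
    """Replace UUIDs with small integer indices before writing."""
    return StorageAttemptRecord(previous_process_quanta=[indices[q] for q in record.previous_process_quanta])


def to_public(record: StorageAttemptRecord, uuids: dict[int, uuid.UUID]) -> PublicAttemptRecord:
    """Replace integer indices with UUIDs when reading back."""
    return PublicAttemptRecord(previous_process_quanta=[uuids[i] for i in record.previous_process_quanta])


q = uuid.uuid4()
stored = to_storage(PublicAttemptRecord(previous_process_quanta=[q]), {q: 0})
assert to_public(stored, {0: q}).previous_process_quanta == [q]
```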
+ if "sphinx" in sys.modules and not TYPE_CHECKING: + + def copy(self, *args: Any, **kwargs: Any) -> Any: + """See `pydantic.BaseModel.copy`.""" + return super().copy(*args, **kwargs) + + def model_dump(self, *args: Any, **kwargs: Any) -> Any: + """See `pydantic.BaseModel.model_dump`.""" + return super().model_dump(*args, **kwargs) + + def model_dump_json(self, *args: Any, **kwargs: Any) -> Any: + """See `pydantic.BaseModel.model_dump_json`.""" + return super().model_dump(*args, **kwargs) + + def model_copy(self, *args: Any, **kwargs: Any) -> Any: + """See `pydantic.BaseModel.model_copy`.""" + return super().model_copy(*args, **kwargs) + + @classmethod + def model_construct(cls, *args: Any, **kwargs: Any) -> Any: # type: ignore[misc, override] + """See `pydantic.BaseModel.model_construct`.""" + return super().model_construct(*args, **kwargs) + + @classmethod + def model_json_schema(cls, *args: Any, **kwargs: Any) -> Any: + """See `pydantic.BaseModel.model_json_schema`.""" + return super().model_json_schema(*args, **kwargs) + + @classmethod + def model_validate(cls, *args: Any, **kwargs: Any) -> Any: + """See `pydantic.BaseModel.model_validate`.""" + return super().model_validate(*args, **kwargs) + + @classmethod + def model_validate_json(cls, *args: Any, **kwargs: Any) -> Any: + """See `pydantic.BaseModel.model_validate_json`.""" + return super().model_validate_json(*args, **kwargs) + + @classmethod + def model_validate_strings(cls, *args: Any, **kwargs: Any) -> Any: + """See `pydantic.BaseModel.model_validate_strings`.""" + return super().model_validate_strings(*args, **kwargs) + + +class ProvenanceTaskMetadataModel(pydantic.BaseModel): + """Data model for storing task metadata in a provenance quantum graph + file. + """ + + attempts: list[TaskMetadata | None] = pydantic.Field(default_factory=list) + """Metadata from attempts to run this task, ordered chronologically from + first to last. + """ + + # Work around the fact that Sphinx chokes on Pydantic docstring formatting, + # when we inherit those docstrings in our public classes. 
+    if "sphinx" in sys.modules and not TYPE_CHECKING:
+
+        def copy(self, *args: Any, **kwargs: Any) -> Any:
+            """See `pydantic.BaseModel.copy`."""
+            return super().copy(*args, **kwargs)
+
+        def model_dump(self, *args: Any, **kwargs: Any) -> Any:
+            """See `pydantic.BaseModel.model_dump`."""
+            return super().model_dump(*args, **kwargs)
+
+        def model_dump_json(self, *args: Any, **kwargs: Any) -> Any:
+            """See `pydantic.BaseModel.model_dump_json`."""
+            return super().model_dump_json(*args, **kwargs)
+
+        def model_copy(self, *args: Any, **kwargs: Any) -> Any:
+            """See `pydantic.BaseModel.model_copy`."""
+            return super().model_copy(*args, **kwargs)
+
+        @classmethod
+        def model_construct(cls, *args: Any, **kwargs: Any) -> Any:  # type: ignore[misc, override]
+            """See `pydantic.BaseModel.model_construct`."""
+            return super().model_construct(*args, **kwargs)
+
+        @classmethod
+        def model_json_schema(cls, *args: Any, **kwargs: Any) -> Any:
+            """See `pydantic.BaseModel.model_json_schema`."""
+            return super().model_json_schema(*args, **kwargs)
+
+        @classmethod
+        def model_validate(cls, *args: Any, **kwargs: Any) -> Any:
+            """See `pydantic.BaseModel.model_validate`."""
+            return super().model_validate(*args, **kwargs)
+
+        @classmethod
+        def model_validate_json(cls, *args: Any, **kwargs: Any) -> Any:
+            """See `pydantic.BaseModel.model_validate_json`."""
+            return super().model_validate_json(*args, **kwargs)
+
+        @classmethod
+        def model_validate_strings(cls, *args: Any, **kwargs: Any) -> Any:
+            """See `pydantic.BaseModel.model_validate_strings`."""
+            return super().model_validate_strings(*args, **kwargs)
+
+
+class ProvenanceQuantumModel(pydantic.BaseModel):
+    """Data model for the quanta in a provenance quantum graph file."""
+
+    quantum_id: uuid.UUID
+    """Unique identifier for the quantum."""
+
+    task_label: TaskLabel
+    """Label of the task that this quantum runs."""
+
+    data_coordinate: DataCoordinateValues = pydantic.Field(default_factory=list)
+    """The full values (required and implied) of this quantum's data ID."""
+
     inputs: dict[ConnectionName, list[DatasetIndex]] = pydantic.Field(default_factory=dict)
     """Internal integer IDs of the datasets predicted to be consumed by this
     quantum, grouped by connection name.
     """
@@ -363,8 +580,14 @@ class ProvenanceQuantumModel(pydantic.BaseModel):
     quantum, grouped by connection name.
     """
 
-    resource_usage: QuantumResourceUsage | None = None
-    """Resource usage information (timing, memory use) for this quantum."""
+    attempts: list[StorageProvenanceQuantumAttemptModel] = pydantic.Field(default_factory=list)
+    """Provenance for all attempts to execute this quantum, ordered
+    chronologically from first to last.
+
+    An entry may be present merely because the quantum *should* have been
+    attempted, even with no evidence that it actually was; an empty `list` is
+    used only for quanta that were blocked by an upstream failure.
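When these attempts are loaded into the graph (see `_add_to_graph` below), the node-level status comes from the most recent attempt, and an empty attempts list maps to `BLOCKED`. A self-contained sketch of that reduction, using a simplified enum and dataclass rather than the real models:

```python
import enum
from dataclasses import dataclass


class AttemptStatus(enum.Enum):
    UNKNOWN = -3
    LOGS_MISSING = -2
    FAILED = -1
    BLOCKED = 0
    SUCCESSFUL = 1


@dataclass
class Attempt:
    status: AttemptStatus


def summary_status(attempts: list[Attempt]) -> AttemptStatus:
    """Summarize a quantum: the last attempt wins; no attempts means BLOCKED."""
    if not attempts:
        return AttemptStatus.BLOCKED
    return attempts[-1].status


assert summary_status([]) is AttemptStatus.BLOCKED
assert summary_status([Attempt(AttemptStatus.FAILED), Attempt(AttemptStatus.SUCCESSFUL)]) is AttemptStatus.SUCCESSFUL
```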
+ """ @property def node_id(self) -> uuid.UUID: @@ -429,15 +652,21 @@ def _add_to_graph(self, graph: ProvenanceQuantumGraph, address_reader: AddressRe """ task_node = graph.pipeline_graph.tasks[self.task_label] data_id = DataCoordinate.from_full_values(task_node.dimensions, tuple(self.data_coordinate)) + last_attempt = ( + self.attempts[-1] + if self.attempts + else StorageProvenanceQuantumAttemptModel(status=QuantumAttemptStatus.BLOCKED) + ) graph._bipartite_xgraph.add_node( self.quantum_id, data_id=data_id, task_label=self.task_label, pipeline_node=task_node, - status=self.status, - caveats=self.caveats, - exception=self.exception, - resource_usage=self.resource_usage, + status=last_attempt.status, + caveats=last_attempt.caveats, + exception=last_attempt.exception, + resource_usage=last_attempt.resource_usage, + attempts=[a.remap_indices(address_reader) for a in self.attempts], ) for connection_name, dataset_indices in self.inputs.items(): read_edge = task_node.get_input_edge(connection_name) @@ -881,7 +1110,7 @@ def make_quantum_table(self) -> astropy.table.Table: for task_label, quanta_for_task in self.quanta_by_task.items(): if not self.header.n_task_quanta[task_label]: continue - status_counts = Counter[QuantumRunStatus]( + status_counts = Counter[QuantumAttemptStatus]( self._quantum_only_xgraph.nodes[q]["status"] for q in quanta_for_task.values() ) caveat_counts = Counter[QuantumSuccessCaveats | None]( @@ -901,11 +1130,11 @@ def make_quantum_table(self) -> astropy.table.Table: rows.append( { "Task": task_label, - "Unknown": status_counts.get(QuantumRunStatus.METADATA_MISSING, 0), - "Successful": status_counts.get(QuantumRunStatus.SUCCESSFUL, 0), + "Unknown": status_counts.get(QuantumAttemptStatus.UNKNOWN, 0), + "Successful": status_counts.get(QuantumAttemptStatus.SUCCESSFUL, 0), "Caveats": caveats, - "Blocked": status_counts.get(QuantumRunStatus.BLOCKED, 0), - "Failed": status_counts.get(QuantumRunStatus.FAILED, 0), + "Blocked": status_counts.get(QuantumAttemptStatus.BLOCKED, 0), + "Failed": status_counts.get(QuantumAttemptStatus.FAILED, 0), "TOTAL": len(quanta_for_task), "EXPECTED": self.header.n_task_quanta[task_label], } @@ -988,7 +1217,7 @@ class ProvenanceQuantumGraphReader(BaseQuantumGraphReader): the `graph` attribute`. The various ``read_*`` methods in this class update the `graph` attribute - in place and return ``self``. + in place. """ graph: ProvenanceQuantumGraph = dataclasses.field(init=False) @@ -1037,30 +1266,19 @@ def open( def __post_init__(self) -> None: self.graph = ProvenanceQuantumGraph(self.header, self.pipeline_graph) - def read_init_quanta(self) -> Self: + def read_init_quanta(self) -> None: """Read the thin graph, with all edge information and categorization of quanta by task label. - - Returns - ------- - self : `ProvenanceQuantumGraphReader` - The reader (to permit method-chaining). """ init_quanta = self._read_single_block("init_quanta", ProvenanceInitQuantaModel) for init_quantum in init_quanta.root: self.graph._init_quanta[init_quantum.task_label] = init_quantum.quantum_id init_quanta._add_to_graph(self.graph, self.address_reader) - return self - def read_full_graph(self) -> Self: + def read_full_graph(self) -> None: """Read all bipartite edges and all quantum and dataset node attributes, fully populating the `graph` attribute. - Returns - ------- - self : `ProvenanceQuantumGraphReader` - The reader (to permit method-chaining). 
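The summary table now tallies `QuantumAttemptStatus` values per task. The sketch below shows the shape of that tally with plain strings standing in for the enum members; `astropy.table.Table` is used just as `make_quantum_table` does, but the data here are invented.

```python
from collections import Counter

import astropy.table

quanta_by_task = {
    "isr": ["SUCCESSFUL", "SUCCESSFUL", "FAILED"],
    "calibrate": ["BLOCKED", "SUCCESSFUL", "UNKNOWN"],
}

rows = []
for task_label, statuses in quanta_by_task.items():
    counts = Counter(statuses)
    rows.append(
        {
            "Task": task_label,
            "Unknown": counts.get("UNKNOWN", 0),
            "Successful": counts.get("SUCCESSFUL", 0),
            "Blocked": counts.get("BLOCKED", 0),
            "Failed": counts.get("FAILED", 0),
            "TOTAL": len(statuses),
        }
    )

print(astropy.table.Table(rows=rows))
```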
- Notes ----- This does not read logs, metadata, or packages ; those must always be @@ -1069,9 +1287,8 @@ def read_full_graph(self) -> Self: self.read_init_quanta() self.read_datasets() self.read_quanta() - return self - def read_datasets(self, datasets: Iterable[uuid.UUID | DatasetIndex] | None = None) -> Self: + def read_datasets(self, datasets: Iterable[uuid.UUID | DatasetIndex] | None = None) -> None: """Read information about the given datasets. Parameters @@ -1080,15 +1297,10 @@ def read_datasets(self, datasets: Iterable[uuid.UUID | DatasetIndex] | None = No Iterable of dataset IDs or indices to load. If not provided, all datasets will be loaded. The UUIDs and indices of quanta will be ignored. - - Return - ------- - self : `ProvenanceQuantumGraphReader` - The reader (to permit method-chaining). """ - return self._read_nodes(datasets, DATASET_ADDRESS_INDEX, DATASET_MB_NAME, ProvenanceDatasetModel) + self._read_nodes(datasets, DATASET_ADDRESS_INDEX, DATASET_MB_NAME, ProvenanceDatasetModel) - def read_quanta(self, quanta: Iterable[uuid.UUID | QuantumIndex] | None = None) -> Self: + def read_quanta(self, quanta: Iterable[uuid.UUID | QuantumIndex] | None = None) -> None: """Read information about the given quanta. Parameters @@ -1097,13 +1309,8 @@ def read_quanta(self, quanta: Iterable[uuid.UUID | QuantumIndex] | None = None) Iterable of quantum IDs or indices to load. If not provided, all quanta will be loaded. The UUIDs and indices of datasets and special init quanta will be ignored. - - Return - ------- - self : `ProvenanceQuantumGraphReader` - The reader (to permit method-chaining). """ - return self._read_nodes(quanta, QUANTUM_ADDRESS_INDEX, QUANTUM_MB_NAME, ProvenanceQuantumModel) + self._read_nodes(quanta, QUANTUM_ADDRESS_INDEX, QUANTUM_MB_NAME, ProvenanceQuantumModel) def _read_nodes( self, @@ -1111,7 +1318,7 @@ def _read_nodes( address_index: int, mb_name: str, model_type: type[ProvenanceDatasetModel] | type[ProvenanceQuantumModel], - ) -> Self: + ) -> None: node: ProvenanceDatasetModel | ProvenanceQuantumModel | None if nodes is None: self.address_reader.read_all() @@ -1129,6 +1336,7 @@ def _read_nodes( # also have other outstanding reference holders). continue node._add_to_graph(self.graph, self.address_reader) + return with MultiblockReader.open_in_zip(self.zf, mb_name, int_size=self.header.int_size) as mb_reader: for node_id_or_index in nodes: address_row = self.address_reader.find(node_id_or_index) @@ -1141,66 +1349,74 @@ def _read_nodes( ) if node is not None: node._add_to_graph(self.graph, self.address_reader) - return self def fetch_logs( self, nodes: Iterable[uuid.UUID | DatasetIndex | QuantumIndex] - ) -> dict[uuid.UUID | DatasetIndex | QuantumIndex, ButlerLogRecords]: + ) -> dict[uuid.UUID | DatasetIndex | QuantumIndex, list[ButlerLogRecords | None]]: """Fetch log datasets. Parameters ---------- nodes : `~collections.abc.Iterable` [ `uuid.UUID` ] - UUIDs of the log datasets themselves or of the quanta they - correspond to. + UUIDs or internal integer IDS of the log datasets themselves or of + the quanta they correspond to. Returns ------- - logs : `dict` [ `uuid.UUID`, `ButlerLogRecords`] - Logs for the given IDs. + logs : `dict` [ `uuid.UUID` or `int`, `list` [\ + `lsst.daf.butler.ButlerLogRecords` or `None`] ] + Logs for the given IDs. Each value is a list of + `lsst.daf.butler.ButlerLogRecords` instances representing different + execution attempts, ordered chronologically from first to last. + Attempts where logs were missing will have `None` in this list. 
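Since `fetch_logs` (and `fetch_metadata`, further below) now return one entry per execution attempt, callers must handle lists that may contain `None` placeholders. A hedged usage sketch, assuming `reader` is a `ProvenanceQuantumGraphReader` and `quantum_id` is a quantum UUID already known from the graph:

```python
logs_by_node = reader.fetch_logs([quantum_id])
metadata_by_node = reader.fetch_metadata([quantum_id])

for attempt_number, attempt_logs in enumerate(logs_by_node[quantum_id]):
    if attempt_logs is None:
        print(f"attempt {attempt_number}: logs missing")
    else:
        print(f"attempt {attempt_number}: {len(attempt_logs)} log records")

for attempt_number, attempt_metadata in enumerate(metadata_by_node[quantum_id]):
    if attempt_metadata is None:
        print(f"attempt {attempt_number}: no metadata (typically a failed attempt)")
```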
""" - from lsst.daf.butler.logging import ButlerLogRecords - - result: dict[uuid.UUID | DatasetIndex | QuantumIndex, ButlerLogRecords] = {} + result: dict[uuid.UUID | DatasetIndex | QuantumIndex, list[ButlerLogRecords | None]] = {} with MultiblockReader.open_in_zip(self.zf, LOG_MB_NAME, int_size=self.header.int_size) as mb_reader: for node_id_or_index in nodes: address_row = self.address_reader.find(node_id_or_index) - log = mb_reader.read_model( - address_row.addresses[LOG_ADDRESS_INDEX], ButlerLogRecords, self.decompressor + logs_by_attempt = mb_reader.read_model( + address_row.addresses[LOG_ADDRESS_INDEX], ProvenanceLogRecordsModel, self.decompressor ) - if log is not None: - result[node_id_or_index] = log + if logs_by_attempt is not None: + result[node_id_or_index] = [ + ButlerLogRecords.from_records(attempt_logs) if attempt_logs is not None else None + for attempt_logs in logs_by_attempt.attempts + ] return result def fetch_metadata( self, nodes: Iterable[uuid.UUID | DatasetIndex | QuantumIndex] - ) -> dict[uuid.UUID | DatasetIndex | QuantumIndex, TaskMetadata]: + ) -> dict[uuid.UUID | DatasetIndex | QuantumIndex, list[TaskMetadata | None]]: """Fetch metadata datasets. Parameters ---------- nodes : `~collections.abc.Iterable` [ `uuid.UUID` ] - UUIDs of the metadata datasets themselves or of the quanta they - correspond to. + UUIDs or internal integer IDs of the metadata datasets themselves + or of the quanta they correspond to. Returns ------- - metadata : `dict` [ `uuid.UUID`, `TaskMetadata`] - Metadata for the given IDs. + metadata : `dict` [ `uuid.UUID` or `int`, `list` [`.TaskMetadata`] ] + Metadata for the given IDs. Each value is a list of + `.TaskMetadata` instances representing different execution + attempts, ordered chronologically from first to last. Attempts + where metadata was missing (not written even in the fallback extra + provenance in the logs) will have `None` in this list. 
""" - from .._task_metadata import TaskMetadata - - result: dict[uuid.UUID | DatasetIndex | QuantumIndex, TaskMetadata] = {} + result: dict[uuid.UUID | DatasetIndex | QuantumIndex, list[TaskMetadata | None]] = {} with MultiblockReader.open_in_zip( self.zf, METADATA_MB_NAME, int_size=self.header.int_size ) as mb_reader: for node_id_or_index in nodes: address_row = self.address_reader.find(node_id_or_index) - metadata = mb_reader.read_model( - address_row.addresses[METADATA_ADDRESS_INDEX], TaskMetadata, self.decompressor + metadata_by_attempt = mb_reader.read_model( + address_row.addresses[METADATA_ADDRESS_INDEX], + ProvenanceTaskMetadataModel, + self.decompressor, ) - if metadata is not None: - result[node_id_or_index] = metadata + if metadata_by_attempt is not None: + result[node_id_or_index] = metadata_by_attempt.attempts return result def fetch_packages(self) -> Packages: diff --git a/python/lsst/pipe/base/quantum_graph/aggregator/_communicators.py b/python/lsst/pipe/base/quantum_graph/aggregator/_communicators.py index ce43cfe9c..2bd5a9223 100644 --- a/python/lsst/pipe/base/quantum_graph/aggregator/_communicators.py +++ b/python/lsst/pipe/base/quantum_graph/aggregator/_communicators.py @@ -59,7 +59,7 @@ from lsst.utils.logging import VERBOSE, LsstLogAdapter from ._config import AggregatorConfig -from ._progress import Progress, make_worker_log +from ._progress import ProgressManager, make_worker_log from ._structs import IngestRequest, ScanReport, ScanResult _T = TypeVar("_T") @@ -340,7 +340,7 @@ def __init__( config: AggregatorConfig, ) -> None: self.config = config - self.progress = Progress(log, config) + self.progress = ProgressManager(log, config) self.n_scanners = n_scanners # The supervisor sends scan requests to scanners on this queue. # When complete, the supervisor sends n_scanners sentinals and each @@ -406,13 +406,13 @@ def wait_for_workers_to_finish(self, already_failing: bool = False) -> None: pass case _Sentinel.INGESTER_DONE: self._ingester_done = True - self.progress.finish_ingests() + self.progress.quantum_ingests.close() case _Sentinel.SCANNER_DONE: self._n_scanners_done += 1 - self.progress.finish_scans() + self.progress.scans.close() case _Sentinel.WRITER_DONE: self._writer_done = True - self.progress.finish_writes() + self.progress.writes.close() case unexpected: raise AssertionError(f"Unexpected message {unexpected!r} to supervisor.") self.log.verbose( @@ -530,9 +530,9 @@ def _handle_progress_reports( if not already_failing: raise FatalWorkerError() case _IngestReport(n_producers=n_producers): - self.progress.report_ingests(n_producers) + self.progress.quantum_ingests.update(n_producers) case _Sentinel.WRITE_REPORT: - self.progress.report_write() + self.progress.writes.update(1) case _ProgressLog(message=message, level=level): self.progress.log.log(level, "%s [after %0.1fs]", message, self.progress.elapsed_time) case _: @@ -626,10 +626,10 @@ def log_progress(self, level: int, message: str) -> None: Parameters ---------- - message : `str` - Log message. level : `int` Log level. Should be ``VERBOSE`` or higher. + message : `str` + Log message. 
""" self._reports.put(_ProgressLog(message=message, level=level), block=False) diff --git a/python/lsst/pipe/base/quantum_graph/aggregator/_progress.py b/python/lsst/pipe/base/quantum_graph/aggregator/_progress.py index 7abf46b73..0c6221b50 100644 --- a/python/lsst/pipe/base/quantum_graph/aggregator/_progress.py +++ b/python/lsst/pipe/base/quantum_graph/aggregator/_progress.py @@ -27,20 +27,86 @@ from __future__ import annotations -__all__ = ("Progress", "make_worker_log") +__all__ = ("ProgressCounter", "ProgressManager", "make_worker_log") import logging import os import time from types import TracebackType -from typing import Self +from typing import Any, Self from lsst.utils.logging import TRACE, VERBOSE, LsstLogAdapter, PeriodicLogger, getLogger from ._config import AggregatorConfig -class Progress: +class ProgressCounter: + """A progress tracker for an individual aspect of the aggregation process. + + Parameters + ---------- + parent : `ProgressManager` + The parent progress manager object. + description : `str` + Human-readable description of this aspect. + unit : `str` + Unit (in plural form) for the items being counted. + total : `int`, optional + Expected total number of items. May be set later. + """ + + def __init__(self, parent: ProgressManager, description: str, unit: str, total: int | None = None): + self._parent = parent + self.total = total + self._description = description + self._current = 0 + self._unit = unit + self._bar: Any = None + + def update(self, n: int) -> None: + """Report that ``n`` new items have been processed. + + Parameters + ---------- + n : `int` + Number of new items processed. + """ + self._current += n + if self._parent.interactive: + if self._bar is None: + if n == self.total: + return + from tqdm import tqdm + + self._bar = tqdm(desc=self._description, total=self.total, leave=False, unit=f" {self._unit}") + else: + self._bar.update(n) + if self._current == self.total: + self._bar.close() + self._parent._log_status() + + def close(self) -> None: + """Close the counter, guaranteeing that `update` will not be called + again. + """ + if self._bar is not None: + self._bar.close() + self._bar = None + + def append_log_terms(self, msg: list[str]) -> None: + """Append a log message for this counter to a list if it is active. + + Parameters + ---------- + msg : `list` [ `str` ] + List of messages to concatenate into a single line and log + together, to be modified in-place. + """ + if self.total is not None and self._current > 0 and self._current < self.total: + msg.append(f"{self._description} ({self._current} of {self.total} {self._unit})") + + +class ProgressManager: """A helper class for the provenance aggregator that handles reporting progress to the user. @@ -66,10 +132,9 @@ def __init__(self, log: LsstLogAdapter, config: AggregatorConfig): self.log = log self.config = config self._periodic_log = PeriodicLogger(self.log, config.log_status_interval) - self._n_scanned: int = 0 - self._n_ingested: int = 0 - self._n_written: int = 0 - self._n_quanta: int | None = None + self.scans = ProgressCounter(self, "scanning", "quanta") + self.writes = ProgressCounter(self, "writing", "quanta") + self.quantum_ingests = ProgressCounter(self, "ingesting outputs", "quanta") self.interactive = config.interactive_status def __enter__(self) -> Self: @@ -90,29 +155,6 @@ def __exit__( self._logging_redirect.__exit__(exc_type, exc_value, traceback) return None - def set_n_quanta(self, n_quanta: int) -> None: - """Set the total number of quanta. 
- - Parameters - ---------- - n_quanta : `int` - Total number of quanta, including special "init" quanta. - - Notes - ----- - This method must be called before any of the ``report_*`` methods. - """ - self._n_quanta = n_quanta - if self.interactive: - from tqdm import tqdm - - self._scan_progress = tqdm(desc="Scanning", total=n_quanta, leave=False, unit="quanta") - self._ingest_progress = tqdm( - desc="Ingesting", total=n_quanta, leave=False, smoothing=0.1, unit="quanta" - ) - if self.config.output_path is not None: - self._write_progress = tqdm(desc="Writing", total=n_quanta, leave=False, unit="quanta") - @property def elapsed_time(self) -> float: """The time in seconds since the start of the aggregator.""" @@ -120,60 +162,11 @@ def elapsed_time(self) -> float: def _log_status(self) -> None: """Invoke the periodic logger with the current status.""" - self._periodic_log.log( - "%s quanta scanned, %s quantum outputs ingested, " - "%s provenance quanta written (of %s) after %0.1fs.", - self._n_scanned, - self._n_ingested, - self._n_written, - self._n_quanta, - self.elapsed_time, - ) - - def report_scan(self) -> None: - """Report that a quantum was scanned.""" - self._n_scanned += 1 - if self.interactive: - self._scan_progress.update(1) - else: - self._log_status() - - def finish_scans(self) -> None: - """Report that all scanning is done.""" - if self.interactive: - self._scan_progress.close() - - def report_ingests(self, n_quanta: int) -> None: - """Report that ingests for multiple quanta were completed. - - Parameters - ---------- - n_quanta : `int` - Number of quanta whose outputs were ingested. - """ - self._n_ingested += n_quanta - if self.interactive: - self._ingest_progress.update(n_quanta) - else: - self._log_status() - - def finish_ingests(self) -> None: - """Report that all ingests are done.""" - if self.interactive: - self._ingest_progress.close() - - def report_write(self) -> None: - """Report that a quantum's provenance was written.""" - self._n_written += 1 - if self.interactive: - self._write_progress.update() - else: - self._log_status() - - def finish_writes(self) -> None: - """Report that all writes are done.""" - if self.interactive: - self._write_progress.close() + log_terms: list[str] = [] + self.scans.append_log_terms(log_terms) + self.writes.append_log_terms(log_terms) + self.quantum_ingests.append_log_terms(log_terms) + self._periodic_log.log("Status after %0.1fs: %s.", self.elapsed_time, "; ".join(log_terms)) def make_worker_log(name: str, config: AggregatorConfig) -> LsstLogAdapter: diff --git a/python/lsst/pipe/base/quantum_graph/aggregator/_scanner.py b/python/lsst/pipe/base/quantum_graph/aggregator/_scanner.py index 612901038..5a58ac8b5 100644 --- a/python/lsst/pipe/base/quantum_graph/aggregator/_scanner.py +++ b/python/lsst/pipe/base/quantum_graph/aggregator/_scanner.py @@ -39,10 +39,10 @@ from lsst.utils.iteration import ensure_iterable from ... 
import automatic_connection_constants as acc -from ..._status import QuantumSuccessCaveats +from ..._status import ExceptionInfo, QuantumAttemptStatus, QuantumSuccessCaveats from ..._task_metadata import TaskMetadata +from ...log_capture import _ExecutionLogRecordsExtra from ...pipeline_graph import PipelineGraph, TaskImportMode -from ...quantum_provenance_graph import ExceptionInfo from ...resource_usage import QuantumResourceUsage from .._multiblock import Compressor from .._predicted import ( @@ -50,6 +50,7 @@ PredictedQuantumDatasetsModel, PredictedQuantumGraphReader, ) +from .._provenance import ProvenanceQuantumAttemptModel from ._communicators import ScannerCommunicator from ._structs import IngestRequest, ScanReport, ScanResult, ScanStatus @@ -179,7 +180,7 @@ def scan_dataset(self, predicted: PredictedDatasetModel) -> bool: Returns ------- exists : `bool`` - Whether the dataset exists + Whether the dataset exists. """ ref = self.reader.components.make_dataset_ref(predicted) return self.qbb.stored(ref) @@ -212,29 +213,67 @@ def scan_quantum(self, quantum_id: uuid.UUID) -> ScanResult: ) result = ScanResult(predicted_quantum.quantum_id, ScanStatus.INCOMPLETE) del self.reader.components.quantum_datasets[quantum_id] - log_id = self._read_and_compress_log(predicted_quantum, result) - if not self.comms.config.assume_complete and not result.log: + last_attempt = ProvenanceQuantumAttemptModel() + if not self._read_log(predicted_quantum, result, last_attempt): self.comms.log.debug("Abandoning scan for %s; no log dataset.", quantum_id) - result.status = ScanStatus.ABANDONED self.comms.report_scan(ScanReport(result.quantum_id, result.status)) return result - metadata_id = self._read_and_compress_metadata(predicted_quantum, result) - if result.metadata: - result.status = ScanStatus.SUCCESSFUL - result.existing_outputs.add(metadata_id) - elif self.comms.config.assume_complete: - result.status = ScanStatus.FAILED - else: + if not self._read_metadata(predicted_quantum, result, last_attempt): # We found the log dataset, but no metadata; this means the # quantum failed, but a retry might still happen that could # turn it into a success if we can't yet assume the run is # complete. self.comms.log.debug("Abandoning scan for %s.", quantum_id) - result.status = ScanStatus.ABANDONED self.comms.report_scan(ScanReport(result.quantum_id, result.status)) return result - if result.log: - result.existing_outputs.add(log_id) + last_attempt.attempt = len(result.attempts) + result.attempts.append(last_attempt) + assert result.status is not ScanStatus.INCOMPLETE + assert result.status is not ScanStatus.ABANDONED + assert result.log_model is not None, "Only set to None after converting to JSON." + assert result.metadata_model is not None, "Only set to None after converting to JSON." + + if len(result.log_model.attempts) < len(result.attempts): + # Logs were not found for this attempt; must have been a hard error + # that kept the `finally` block from running or otherwise + # interrupted the writing of the logs. + result.log_model.attempts.append(None) + if result.status is ScanStatus.SUCCESSFUL: + # But we found the metadata! Either that hard error happened + # at a very unlucky time (in between those two writes), or + # something even weirder happened. + result.attempts[-1].status = QuantumAttemptStatus.LOGS_MISSING + else: + result.attempts[-1].status = QuantumAttemptStatus.FAILED + if len(result.metadata_model.attempts) < len(result.attempts): + # Metadata missing usually just means a failure. 
In any case, the + # status will already be correct, either because it was set to a + # failure when we read the logs, or left at UNKNOWN if there were + # no logs. Note that scanners never process BLOCKED quanta at all. + result.metadata_model.attempts.append(None) + assert len(result.log_model.attempts) == len(result.attempts) or len( + result.metadata_model.attempts + ) == len(result.attempts), ( + "The only way we can add more than one quantum attempt is by " + "extracting info stored with the logs, and that always appends " + "a log attempt and a metadata attempt, so this must be a bug in " + "the scanner." + ) + # Now that we're done gathering the log and metadata information into + # models, dump them to JSON and delete the originals. + result.log_content = result.log_model.model_dump_json().encode() + result.log_model = None + result.metadata_content = result.metadata_model.model_dump_json().encode() + result.metadata_model = None + if self.compressor is not None: + if result.log_content is not None: + result.log_content = self.compressor.compress(result.log_content) + if result.metadata_content is not None: + result.metadata_content = self.compressor.compress(result.metadata_content) + result.is_compressed = True + # Scan for output dataset existence, skipping any the metadata reported + # as having been definitively written, as well as and the metadata and + # logs themselves (since we just checked those). for predicted_output in itertools.chain.from_iterable(predicted_quantum.outputs.values()): if predicted_output.dataset_id not in result.existing_outputs and self.scan_dataset( predicted_output @@ -242,8 +281,6 @@ def scan_quantum(self, quantum_id: uuid.UUID) -> ScanResult: result.existing_outputs.add(predicted_output.dataset_id) to_ingest = self._make_ingest_request(predicted_quantum, result) self.comms.report_scan(ScanReport(result.quantum_id, result.status)) - assert result.status is not ScanStatus.INCOMPLETE - assert result.status is not ScanStatus.ABANDONED if self.comms.config.output_path is not None: self.comms.request_write(result) self.comms.request_ingest(to_ingest) @@ -279,9 +316,12 @@ def _make_ingest_request( to_ingest_records = self.qbb._datastore.export_predicted_records(to_ingest_refs) return IngestRequest(result.quantum_id, to_ingest_predicted, to_ingest_records) - def _read_and_compress_metadata( - self, predicted_quantum: PredictedQuantumDatasetsModel, result: ScanResult - ) -> uuid.UUID: + def _read_metadata( + self, + predicted_quantum: PredictedQuantumDatasetsModel, + result: ScanResult, + last_attempt: ProvenanceQuantumAttemptModel, + ) -> bool: """Attempt to read the metadata dataset for a quantum to extract provenance information from it. @@ -291,53 +331,62 @@ def _read_and_compress_metadata( Information about the predicted quantum. result : `ScanResult` Result object to be modified in-place. + last_attempt : `ScanningProvenanceQuantumAttemptModel` + Structure to fill in with information about the last attempt to + run this quantum. Returns ------- - dataset_id : `uuid.UUID` - UUID of the metadata dataset. + complete : `bool` + Whether the quantum is complete. """ - assert not result.metadata, "We shouldn't be scanning again if we already read the metadata." (predicted_dataset,) = predicted_quantum.outputs[acc.METADATA_OUTPUT_CONNECTION_NAME] ref = self.reader.components.make_dataset_ref(predicted_dataset) try: # This assumes QBB metadata writes are atomic, which should be the # case. If it's not we'll probably get pydantic validation errors # here. 
- content: TaskMetadata = self.qbb.get(ref, storageClass="TaskMetadata") + metadata: TaskMetadata = self.qbb.get(ref, storageClass="TaskMetadata") except FileNotFoundError: - if not self.comms.config.assume_complete: - return ref.id + if self.comms.config.assume_complete: + result.status = ScanStatus.FAILED + else: + result.status = ScanStatus.ABANDONED + return False else: + result.status = ScanStatus.SUCCESSFUL + result.existing_outputs.add(ref.id) + last_attempt.status = QuantumAttemptStatus.SUCCESSFUL try: # Int conversion guards against spurious conversion to # float that can apparently sometimes happen in # TaskMetadata. - result.caveats = QuantumSuccessCaveats(int(content["quantum"]["caveats"])) + last_attempt.caveats = QuantumSuccessCaveats(int(metadata["quantum"]["caveats"])) except LookupError: pass try: - result.exception = ExceptionInfo._from_metadata( - content[predicted_quantum.task_label]["failure"] + last_attempt.exception = ExceptionInfo._from_metadata( + metadata[predicted_quantum.task_label]["failure"] ) except LookupError: pass try: - result.existing_outputs = { - uuid.UUID(id_str) for id_str in ensure_iterable(content["quantum"].getArray("outputs")) - } + result.existing_outputs.update( + uuid.UUID(id_str) for id_str in ensure_iterable(metadata["quantum"].getArray("outputs")) + ) except LookupError: pass - result.resource_usage = QuantumResourceUsage.from_task_metadata(content) - result.metadata = content.model_dump_json().encode() - if self.compressor is not None: - result.metadata = self.compressor.compress(result.metadata) - result.is_compressed = True - return ref.id - - def _read_and_compress_log( - self, predicted_quantum: PredictedQuantumDatasetsModel, result: ScanResult - ) -> uuid.UUID: + last_attempt.resource_usage = QuantumResourceUsage.from_task_metadata(metadata) + assert result.metadata_model is not None, "Only set to None after converting to JSON." + result.metadata_model.attempts.append(metadata) + return True + + def _read_log( + self, + predicted_quantum: PredictedQuantumDatasetsModel, + result: ScanResult, + last_attempt: ProvenanceQuantumAttemptModel, + ) -> bool: """Attempt to read the log dataset for a quantum to test for the quantum's completion (the log is always written last) and aggregate the log content in the provenance quantum graph. @@ -348,24 +397,76 @@ def _read_and_compress_log( Information about the predicted quantum. result : `ScanResult` Result object to be modified in-place. + last_attempt : `ScanningProvenanceQuantumAttemptModel` + Structure to fill in with information about the last attempt to + run this quantum. Returns ------- - dataset_id : `uuid.UUID` - UUID of the log dataset. + complete : `bool` + Whether the quantum is complete. """ (predicted_dataset,) = predicted_quantum.outputs[acc.LOG_OUTPUT_CONNECTION_NAME] ref = self.reader.components.make_dataset_ref(predicted_dataset) try: # This assumes QBB log writes are atomic, which should be the case. # If it's not we'll probably get pydantic validation errors here. 
- content: ButlerLogRecords = self.qbb.get(ref) + log_records: ButlerLogRecords = self.qbb.get(ref) except FileNotFoundError: - if not self.comms.config.assume_complete: - return ref.id + if self.comms.config.assume_complete: + result.status = ScanStatus.FAILED + else: + result.status = ScanStatus.ABANDONED + return False else: - result.log = content.model_dump_json().encode() - if self.compressor is not None: - result.log = self.compressor.compress(result.log) - result.is_compressed = True - return ref.id + # Set the attempt's run status to FAILED, since the default is + # UNKNOWN (i.e. logs *and* metadata are missing) and we now know + # the logs exist. This will usually get replaced by SUCCESSFUL + # when we look for metadata next. + last_attempt.status = QuantumAttemptStatus.FAILED + result.existing_outputs.add(ref.id) + if log_records.extra: + log_extra = _ExecutionLogRecordsExtra.model_validate(log_records.extra) + self._extract_from_log_extra(log_extra, result, last_attempt=last_attempt) + assert result.log_model is not None, "Only set to None after converting to JSON." + result.log_model.attempts.append(list(log_records)) + return True + + def _extract_from_log_extra( + self, + log_extra: _ExecutionLogRecordsExtra, + result: ScanResult, + last_attempt: ProvenanceQuantumAttemptModel | None, + ) -> None: + for previous_attempt_log_extra in log_extra.previous_attempts: + self._extract_from_log_extra(previous_attempt_log_extra, result, last_attempt=None) + quantum_attempt: ProvenanceQuantumAttemptModel + if last_attempt is None: + # This is not the last attempt, so it must be a failure. + quantum_attempt = ProvenanceQuantumAttemptModel( + attempt=len(result.attempts), status=QuantumAttemptStatus.FAILED + ) + # We also need to get the logs from this extra provenance, since + # they won't be the main section of the log records. + assert result.log_model is not None, "Only set to None after converting to JSON." + result.log_model.attempts.append(log_extra.logs) + # The special last attempt is only appended after we attempt to + # read metadata later, but we have to append this one now. + result.attempts.append(quantum_attempt) + else: + assert not log_extra.logs, "Logs for the last attempt should not be stored in the extra JSON." + quantum_attempt = last_attempt + if log_extra.exception is not None or log_extra.metadata is not None or last_attempt is None: + # We won't be getting a separate metadata dataset, so anything we + # might get from the metadata has to come from this extra + # provenance in the logs. + quantum_attempt.exception = log_extra.exception + assert result.metadata_model is not None, "Only set to None after converting to JSON." + if log_extra.metadata is not None: + quantum_attempt.resource_usage = QuantumResourceUsage.from_task_metadata(log_extra.metadata) + result.metadata_model.attempts.append(log_extra.metadata) + else: + result.metadata_model.attempts.append(None) + # Regardless of whether this is the last attempt or not, we can only + # get the previous_process_quanta from the log extra. 
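+ # (These quantum IDs are recorded by SingleQuantumExecutor as it executes
+ # quanta in the same process; see the executor changes later in this diff.)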
+ quantum_attempt.previous_process_quanta.extend(log_extra.previous_process_quanta) diff --git a/python/lsst/pipe/base/quantum_graph/aggregator/_structs.py b/python/lsst/pipe/base/quantum_graph/aggregator/_structs.py index 14eec4b85..b4688d04b 100644 --- a/python/lsst/pipe/base/quantum_graph/aggregator/_structs.py +++ b/python/lsst/pipe/base/quantum_graph/aggregator/_structs.py @@ -40,11 +40,13 @@ from lsst.daf.butler.datastore.record_data import DatastoreRecordData -from ..._status import QuantumSuccessCaveats -from ...quantum_provenance_graph import ExceptionInfo, QuantumRunStatus -from ...resource_usage import QuantumResourceUsage from .._common import DatastoreName from .._predicted import PredictedDatasetModel +from .._provenance import ( + ProvenanceLogRecordsModel, + ProvenanceQuantumAttemptModel, + ProvenanceTaskMetadataModel, +) class ScanStatus(enum.Enum): @@ -126,42 +128,33 @@ class ScanResult: status: ScanStatus """Combined status for the scan and the execution of the quantum.""" - caveats: QuantumSuccessCaveats | None = None - """Flags indicating caveats on successful quanta.""" - - exception: ExceptionInfo | None = None - """Information about an exception raised when the quantum was executing.""" - - resource_usage: QuantumResourceUsage | None = None - """Resource usage information (timing, memory use) for this quantum.""" + attempts: list[ProvenanceQuantumAttemptModel] = dataclasses.field(default_factory=list) + """Provenance information about each attempt to run the quantum.""" existing_outputs: set[uuid.UUID] = dataclasses.field(default_factory=set) """Unique IDs of the output datasets that were actually written.""" - metadata: bytes = b"" - """Raw content of the metadata dataset.""" + metadata_model: ProvenanceTaskMetadataModel | None = dataclasses.field( + default_factory=ProvenanceTaskMetadataModel + ) + """Task metadata information for each attempt. + + This is set to `None` to keep the pickle size small after it is saved + to `metadata_content`. + """ + + metadata_content: bytes = b"" + """Serialized form of `metadata_model`.""" + + log_model: ProvenanceLogRecordsModel | None = dataclasses.field(default_factory=ProvenanceLogRecordsModel) + """Log records for each attempt. - log: bytes = b"" - """Raw content of the log dataset.""" + This is set to `None` to keep the pickle size small after it is saved + to `log_content`. + """ + + log_content: bytes = b"" + """Serialized form of `logs_model`.""" is_compressed: bool = False """Whether the `metadata` and `log` attributes are compressed.""" - - def get_run_status(self) -> QuantumRunStatus: - """Translate the scan status and metadata/log presence into a run - status. - """ - if self.status is ScanStatus.BLOCKED: - return QuantumRunStatus.BLOCKED - if self.status is ScanStatus.INIT: - return QuantumRunStatus.SUCCESSFUL - if self.log: - if self.metadata: - return QuantumRunStatus.SUCCESSFUL - else: - return QuantumRunStatus.FAILED - else: - if self.metadata: - return QuantumRunStatus.LOGS_MISSING - else: - return QuantumRunStatus.METADATA_MISSING diff --git a/python/lsst/pipe/base/quantum_graph/aggregator/_supervisor.py b/python/lsst/pipe/base/quantum_graph/aggregator/_supervisor.py index 46fcfbf4f..c2ee99704 100644 --- a/python/lsst/pipe/base/quantum_graph/aggregator/_supervisor.py +++ b/python/lsst/pipe/base/quantum_graph/aggregator/_supervisor.py @@ -107,9 +107,10 @@ def loop(self) -> None: """Scan the outputs of the quantum graph to gather provenance and ingest outputs. 
""" - self.comms.progress.set_n_quanta( - self.predicted.header.n_quanta + len(self.predicted.init_quanta.root) - ) + n_quanta = self.predicted.header.n_quanta + len(self.predicted.init_quanta.root) + self.comms.progress.scans.total = n_quanta + self.comms.progress.writes.total = n_quanta + self.comms.progress.quantum_ingests.total = n_quanta ready_set: set[uuid.UUID] = set() for ready_quanta in self.walker: self.comms.log.debug("Sending %d new quanta to scan queue.", len(ready_quanta)) @@ -137,8 +138,8 @@ def handle_report(self, scan_report: ScanReport) -> None: for blocked_quantum_id in blocked_quanta: if self.comms.config.output_path is not None: self.comms.request_write(ScanResult(blocked_quantum_id, status=ScanStatus.BLOCKED)) - self.comms.progress.report_scan() - self.comms.progress.report_ingests(len(blocked_quanta)) + self.comms.progress.scans.update(1) + self.comms.progress.quantum_ingests.update(len(blocked_quanta)) case ScanStatus.ABANDONED: self.comms.log.debug("Abandoning scan for %s: quantum has not succeeded (yet).") self.walker.fail(scan_report.quantum_id) @@ -147,7 +148,7 @@ def handle_report(self, scan_report: ScanReport) -> None: raise AssertionError( f"Unexpected status {unexpected!r} in scanner loop for {scan_report.quantum_id}." ) - self.comms.progress.report_scan() + self.comms.progress.scans.update(1) def aggregate_graph(predicted_path: str, butler_path: str, config: AggregatorConfig) -> None: @@ -159,7 +160,7 @@ def aggregate_graph(predicted_path: str, butler_path: str, config: AggregatorCon Path to the predicted quantum graph. butler_path : `str` Path or alias to the central butler repository. - config: `AggregatorConfig` + config : `AggregatorConfig` Configuration for the aggregator. """ log = getLogger("lsst.pipe.base.quantum_graph.aggregator") diff --git a/python/lsst/pipe/base/quantum_graph/aggregator/_writer.py b/python/lsst/pipe/base/quantum_graph/aggregator/_writer.py index 1a2f1fd1d..6ba1c71a3 100644 --- a/python/lsst/pipe/base/quantum_graph/aggregator/_writer.py +++ b/python/lsst/pipe/base/quantum_graph/aggregator/_writer.py @@ -464,7 +464,7 @@ def write_init_outputs(self, data_writers: _DataWriters) -> None: producer=self.indices[predicted_init_quantum.quantum_id], consumers=self.xgraph.successors(dataset_index), ) - provenance_output.exists = predicted_output.dataset_id in existing_outputs + provenance_output.produced = predicted_output.dataset_id in existing_outputs data_writers.datasets.write_model( provenance_output.dataset_id, provenance_output, data_writers.compressor ) @@ -551,16 +551,13 @@ def make_scan_data(self, request: ScanResult) -> list[_ScanData]: producer=quantum_index, consumers=self.xgraph.successors(dataset_index), ) - provenance_output.exists = provenance_output.dataset_id in request.existing_outputs + provenance_output.produced = provenance_output.dataset_id in request.existing_outputs data.datasets[provenance_output.dataset_id] = provenance_output.model_dump_json().encode() provenance_quantum = ProvenanceQuantumModel.from_predicted(predicted_quantum, self.indices) - provenance_quantum.status = request.get_run_status() - provenance_quantum.caveats = request.caveats - provenance_quantum.exception = request.exception - provenance_quantum.resource_usage = request.resource_usage + provenance_quantum.attempts = [a.remap_uuids(self.indices) for a in request.attempts] data.quantum = provenance_quantum.model_dump_json().encode() - data.metadata = request.metadata - data.log = request.log + data.metadata = request.metadata_content + data.log 
= request.log_content return [data] def write_scan_data(self, scan_data: _ScanData, data_writers: _DataWriters) -> None: diff --git a/python/lsst/pipe/base/quantum_provenance_graph.py b/python/lsst/pipe/base/quantum_provenance_graph.py index df778fc44..9b0633578 100644 --- a/python/lsst/pipe/base/quantum_provenance_graph.py +++ b/python/lsst/pipe/base/quantum_provenance_graph.py @@ -49,7 +49,7 @@ import uuid from collections.abc import Callable, Iterator, Mapping, Sequence, Set from enum import Enum -from typing import TYPE_CHECKING, Any, ClassVar, Literal, TypedDict, cast +from typing import Any, ClassVar, Literal, TypedDict, cast import astropy.table import networkx @@ -72,7 +72,7 @@ from lsst.resources import ResourcePathExpression from lsst.utils.logging import PeriodicLogger, getLogger -from ._status import QuantumSuccessCaveats +from ._status import ExceptionInfo, QuantumSuccessCaveats from .automatic_connection_constants import ( LOG_OUTPUT_CONNECTION_NAME, LOG_OUTPUT_TEMPLATE, @@ -82,9 +82,6 @@ ) from .graph import QuantumGraph, QuantumNode -if TYPE_CHECKING: - from ._task_metadata import TaskMetadata - _LOG = getLogger(__name__) @@ -188,45 +185,6 @@ class QuantumRunStatus(Enum): SUCCESSFUL = 1 -class ExceptionInfo(pydantic.BaseModel): - """Information about an exception that was raised.""" - - type_name: str - """Fully-qualified Python type name for the exception raised.""" - - message: str - """String message included in the exception.""" - - metadata: dict[str, float | int | str | bool | None] - """Additional metadata included in the exception.""" - - @classmethod - def _from_metadata(cls, md: TaskMetadata) -> ExceptionInfo: - """Construct from task metadata. - - Parameters - ---------- - md : `TaskMetadata` - Metadata about the error, as written by - `AnnotatedPartialOutputsError`. - - Returns - ------- - info : `ExceptionInfo` - Information about the exception. - """ - result = cls(type_name=md["type"], message=md["message"], metadata={}) - if "metadata" in md: - raw_err_metadata = md["metadata"].to_dict() - for k, v in raw_err_metadata.items(): - # Guard against error metadata we couldn't serialize later - # via Pydantic; don't want one weird value bringing down our - # ability to report on an entire run. 
- if isinstance(v, float | int | str | bool): - result.metadata[k] = v - return result - - class QuantumRun(pydantic.BaseModel): """Information about a quantum in a given run collection.""" diff --git a/python/lsst/pipe/base/single_quantum_executor.py b/python/lsst/pipe/base/single_quantum_executor.py index cbaf8abd3..e04cc541a 100644 --- a/python/lsst/pipe/base/single_quantum_executor.py +++ b/python/lsst/pipe/base/single_quantum_executor.py @@ -44,12 +44,19 @@ NamedKeyDict, Quantum, ) +from lsst.utils.introspection import get_full_type_name from lsst.utils.timer import logInfo from ._quantumContext import ExecutionResources, QuantumContext -from ._status import AnnotatedPartialOutputsError, InvalidQuantumError, NoWorkFound, QuantumSuccessCaveats +from ._status import ( + AnnotatedPartialOutputsError, + ExceptionInfo, + InvalidQuantumError, + NoWorkFound, + QuantumSuccessCaveats, +) from .connections import AdjustQuantumHelper -from .log_capture import LogCapture +from .log_capture import LogCapture, _ExecutionLogRecordsExtra from .pipeline_graph import TaskNode from .pipelineTask import PipelineTask from .quantum_graph_executor import QuantumExecutor @@ -147,6 +154,7 @@ def __init__( self._skip_existing = self._butler.run in self._butler.collections.query( skip_existing_in, flatten_chains=True ) + self._previous_process_quanta: list[uuid.UUID] = [] def execute( self, task_node: TaskNode, /, quantum: Quantum, quantum_id: uuid.UUID | None = None @@ -196,7 +204,7 @@ def _execute( # or raises an exception do not try to store logs, as they may be # already in butler. captureLog.store = False - if self._check_existing_outputs(quantum, task_node, limited_butler): + if self._check_existing_outputs(quantum, task_node, limited_butler, captureLog.extra): _LOG.info( "Skipping already-successful quantum for label=%s dataId=%s.", task_node.label, @@ -205,6 +213,9 @@ def _execute( return quantum captureLog.store = True + captureLog.extra.previous_process_quanta.extend(self._previous_process_quanta) + if quantum_id is not None: + self._previous_process_quanta.append(quantum_id) try: quantum = self._updated_quantum_inputs(quantum, task_node, limited_butler) except NoWorkFound as exc: @@ -261,6 +272,11 @@ def _execute( e.__class__.__name__, str(e), ) + captureLog.extra.exception = ExceptionInfo( + type_name=get_full_type_name(e), + message=str(e), + metadata={}, + ) raise else: quantumMetadata["butler_metrics"] = butler_metrics.model_dump() @@ -268,11 +284,13 @@ def _execute( # Stringify the UUID for easier compatibility with # PropertyList. 
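+ # (Downstream, the aggregator's scanner uses this list of output UUIDs to
+ # skip per-dataset existence checks for outputs the metadata already
+ # reports as written.)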
quantumMetadata["outputs"] = [str(output) for output in outputsPut] - logInfo(None, "end", metadata=quantumMetadata) # type: ignore[arg-type] - fullMetadata = task.getFullMetadata() - fullMetadata["quantum"] = quantumMetadata - if self._job_metadata is not None: - fullMetadata["job"] = self._job_metadata + finally: + logInfo(None, "end", metadata=quantumMetadata) # type: ignore[arg-type] + fullMetadata = task.getFullMetadata() + fullMetadata["quantum"] = quantumMetadata + if self._job_metadata is not None: + fullMetadata["job"] = self._job_metadata + captureLog.extra.metadata = fullMetadata self._write_metadata(quantum, fullMetadata, task_node, limited_butler) stopTime = time.time() _LOG.info( @@ -284,7 +302,12 @@ def _execute( return quantum def _check_existing_outputs( - self, quantum: Quantum, task_node: TaskNode, /, limited_butler: LimitedButler + self, + quantum: Quantum, + task_node: TaskNode, + /, + limited_butler: LimitedButler, + log_extra: _ExecutionLogRecordsExtra, ) -> bool: """Decide whether this quantum needs to be executed. @@ -302,6 +325,8 @@ def _check_existing_outputs( Task definition structure. limited_butler : `~lsst.daf.butler.LimitedButler` Butler to use for querying and clobbering. + log_extra : `.log_capture.TaskLogRecordsExtra` + Extra information to attach to log records. Returns ------- @@ -337,6 +362,15 @@ def _check_existing_outputs( "Looking for existing outputs in the way for label=%s dataId=%s.", task_node.label, quantum.dataId ) ref_dict = limited_butler.stored_many(chain.from_iterable(quantum.outputs.values())) + if task_node.log_output is not None: + (log_ref,) = quantum.outputs[task_node.log_output.dataset_type_name] + if ref_dict[log_ref]: + _LOG.debug( + "Attaching logs from previous attempt on label=%s dataId=%s.", + task_node.label, + quantum.dataId, + ) + log_extra.attach_previous_attempt(limited_butler.get(log_ref)) existingRefs = [ref for ref, exists in ref_dict.items() if exists] missingRefs = [ref for ref, exists in ref_dict.items() if not exists] if existingRefs: diff --git a/python/lsst/pipe/base/tests/mocks/_pipeline_task.py b/python/lsst/pipe/base/tests/mocks/_pipeline_task.py index 3133eb400..be443a5f2 100644 --- a/python/lsst/pipe/base/tests/mocks/_pipeline_task.py +++ b/python/lsst/pipe/base/tests/mocks/_pipeline_task.py @@ -95,7 +95,7 @@ class ForcedFailure: memory_required: Quantity | None = None """If not `None`, this failure simulates an out-of-memory failure by - raising only if this value exceeds `ExecutionResources.max_mem`.f + raising only if this value exceeds `ExecutionResources.max_mem`. 
""" def set_config(self, config: MockPipelineTaskConfig) -> None: diff --git a/tests/test_aggregator.py b/tests/test_aggregator.py index a6c1c036c..6dbfabed1 100644 --- a/tests/test_aggregator.py +++ b/tests/test_aggregator.py @@ -39,7 +39,12 @@ import lsst.utils.tests from lsst.daf.butler import Butler, ButlerLogRecords, QuantumBackedButler -from lsst.pipe.base import AlgorithmError, QuantumSuccessCaveats, TaskMetadata +from lsst.pipe.base import ( + AlgorithmError, + QuantumAttemptStatus, + QuantumSuccessCaveats, + TaskMetadata, +) from lsst.pipe.base import automatic_connection_constants as acc from lsst.pipe.base.cli.cmd.commands import aggregate_graph as aggregate_graph_cli from lsst.pipe.base.graph_walker import GraphWalker @@ -54,7 +59,6 @@ ProvenanceQuantumInfo, ) from lsst.pipe.base.quantum_graph.aggregator import AggregatorConfig, FatalWorkerError, aggregate_graph -from lsst.pipe.base.quantum_provenance_graph import QuantumRunStatus from lsst.pipe.base.resource_usage import QuantumResourceUsage from lsst.pipe.base.single_quantum_executor import SingleQuantumExecutor from lsst.pipe.base.tests.mocks import ( @@ -231,7 +235,11 @@ def make_test_repo() -> Iterator[PrepInfo]: ) def iter_graph_execution( - self, repo: ResourcePath, qg: PredictedQuantumGraph, raise_on_partial_outputs: bool + self, + repo: ResourcePath, + qg: PredictedQuantumGraph, + raise_on_partial_outputs: bool, + is_retry: bool = False, ) -> Iterator[uuid.UUID]: """Return an iterator that executes and yields quanta one by one. @@ -245,12 +253,14 @@ def iter_graph_execution( raise_on_partial_outputs : `bool` Whether to raise on `lsst.pipe.base.AnnotatedPartialOutputsError` or treat it as a success with caveats. + is_retry : `bool`, optional + If `True`, this is a retry attempt and hence some outputs may + already be present; skip successes and reprocess failures. Returns ------- quanta : `~collections.abc.Iterator` [`uuid.UUID`] - An iterator over successful quantum IDs. Failed and blocked quanta - are not included. + An iterator over all executed quantum IDs (not blocked ones). """ qg.init_output_run(qg.make_init_qbb(repo)) sqe = SingleQuantumExecutor( @@ -259,7 +269,9 @@ def iter_graph_execution( quantum, qg.pipeline_graph.universe, ), - assume_no_existing_outputs=True, + assume_no_existing_outputs=not is_retry, + skip_existing=is_retry, + clobber_outputs=is_retry, raise_on_partial_outputs=raise_on_partial_outputs, ) qg.build_execution_quanta() @@ -274,7 +286,7 @@ def iter_graph_execution( walker.fail(quantum_id) else: walker.finish(quantum_id) - yield quantum_id + yield quantum_id def check_provenance_graph( self, @@ -283,7 +295,8 @@ def check_provenance_graph( butler: Butler, expect_failure: bool, start_time: float, - ) -> None: + expect_failures_retried: bool = False, + ) -> ProvenanceQuantumGraph: """Run a batter of tests on a provenance quantum graph produced by scanning the graph created by `make_test_repo`. @@ -302,6 +315,14 @@ def check_provenance_graph( start_time : `float` A POSIX timestamp that strictly precedes the start time of any quantum's execution. + expect_failures_retried : `bool`, optional + If `True`, expect an initial attempt with failures prior to the + most recent attempt. + + Returns + ------- + prov : `ProvenanceQuantumGraph` + The full provenance quantum graph. 
""" prov_reader.read_full_graph() prov = prov_reader.graph @@ -374,6 +395,14 @@ def check_provenance_graph( exception_type="lsst.pipe.base.tests.mocks.MockAlgorithmError", msg=msg, ) + if expect_failures_retried: + self.assertEqual(len(prov_qinfo["attempts"]), 2) + self.assertEqual( + prov_qinfo["attempts"][0].exception.type_name, + "lsst.pipe.base.tests.mocks.MockAlgorithmError", + ) + else: + self.assertEqual(len(prov_qinfo["attempts"]), 1) case "consolidate", {"visit": 2}: # This quantum will succeed (with one predicted input # missing) or be blocked. @@ -382,6 +411,9 @@ def check_provenance_graph( self._expect_blocked(prov_qinfo, existence, msg=msg) else: self._expect_successful(prov_qinfo, existence, msg=msg) + self.assertEqual( + len(prov_qinfo["attempts"]), expect_failures_retried or not expect_failure + ) case ( "resample" | "coadd", {"tract": 1, "patch": 1} | {"tract": 0, "patch": 5}, @@ -403,6 +435,9 @@ def check_provenance_graph( ), msg=msg, ) + self.assertEqual( + len(prov_qinfo["attempts"]), expect_failures_retried or not expect_failure + ) case ( "resample", {"tract": 0, "patch": 4, "visit": 2} | {"tract": 1, "patch": 0, "visit": 2}, @@ -414,6 +449,9 @@ def check_provenance_graph( self._expect_blocked(prov_qinfo, existence, msg=msg) else: self._expect_successful(prov_qinfo, existence, msg=msg) + self.assertEqual( + len(prov_qinfo["attempts"]), expect_failures_retried or not expect_failure + ) case ( "coadd", {"tract": 0, "patch": 4, "band": "r"} | {"tract": 1, "patch": 0, "band": "r"}, @@ -426,18 +464,22 @@ def check_provenance_graph( else: self._expect_all_exist(existence["input_image"], msg=msg) self._expect_successful(prov_qinfo, existence, msg=msg) + self.assertEqual( + len(prov_qinfo["attempts"]), expect_failures_retried or not expect_failure + ) case _: # All other quanta should succeed and have all inputs # present. 
for connection_name in prov_qinfo["pipeline_node"].inputs.keys(): self._expect_all_exist(existence[connection_name], msg=msg) self._expect_successful(prov_qinfo, existence, msg=msg) - if not checked_some_metadata and prov_qinfo["status"] is QuantumRunStatus.SUCCESSFUL: + self.assertEqual(len(prov_qinfo["attempts"]), 1) + if not checked_some_metadata and prov_qinfo["status"] is QuantumAttemptStatus.SUCCESSFUL: self.check_metadata(quantum_id, prov_reader, butler) checked_some_metadata = True if not checked_some_log and prov_qinfo["status"] in ( - QuantumRunStatus.SUCCESSFUL, - QuantumRunStatus.FAILED, + QuantumAttemptStatus.SUCCESSFUL, + QuantumAttemptStatus.FAILED, ): self.check_log(quantum_id, prov_reader, butler) checked_some_log = True @@ -449,6 +491,7 @@ def check_provenance_graph( self.check_packages(prov_reader) self.check_quantum_table(prov_reader.graph, expect_failure=expect_failure) self.check_exception_table(prov_reader.graph, expect_failure=expect_failure) + return prov def _expect_all_exist(self, existence: list[bool], msg: str) -> None: self.assertTrue(all(existence), msg=msg) @@ -468,7 +511,7 @@ def _expect_successful( *, msg: str, ) -> None: - self.assertEqual(info["status"], QuantumRunStatus.SUCCESSFUL, msg=msg) + self.assertEqual(info["status"], QuantumAttemptStatus.SUCCESSFUL, msg=msg) self.assertEqual(info["caveats"], caveats, msg=msg) if exception_type is None: self.assertIsNone(info["exception"], msg=msg) @@ -490,7 +533,8 @@ def _expect_successful( def _expect_failure( self, info: ProvenanceQuantumInfo, existence: dict[str, list[bool]], msg: str ) -> None: - self.assertEqual(info["status"], QuantumRunStatus.FAILED, msg=msg) + self.assertEqual(info["status"], QuantumAttemptStatus.FAILED, msg=msg) + self.assertEqual(info["exception"].type_name, "lsst.pipe.base.tests.mocks.MockAlgorithmError") self._expect_all_exist(existence[acc.LOG_OUTPUT_CONNECTION_NAME], msg=msg) self._expect_none_exist(existence[acc.METADATA_OUTPUT_CONNECTION_NAME], msg=msg) for connection_name in info["pipeline_node"].outputs.keys(): @@ -502,7 +546,8 @@ def _expect_blocked( existence: dict[str, list[bool]], msg: str, ) -> None: - self.assertEqual(info["status"], QuantumRunStatus.BLOCKED, msg=msg) + self.assertEqual(info["status"], QuantumAttemptStatus.BLOCKED, msg=msg) + self.assertEqual(info["attempts"], []) self._expect_none_exist(existence[acc.LOG_OUTPUT_CONNECTION_NAME], msg=msg) self._expect_none_exist(existence[acc.METADATA_OUTPUT_CONNECTION_NAME], msg=msg) for connection_name in info["pipeline_node"].outputs.keys(): @@ -540,7 +585,7 @@ def check_dataset( self.assertEqual(pred_info["dataset_type_name"], prov_info["dataset_type_name"]) self.assertEqual(pred_info["data_id"], prov_info["data_id"]) self.assertEqual(pred_info["run"], prov_info["run"]) - exists = prov_info["exists"] + exists = prov_info["produced"] dataset_type_name = prov_info["dataset_type_name"] # We can remove this guard when we ingest QG-backed metadata and logs. if not dataset_type_name.endswith("_metadata") and not dataset_type_name.endswith("_log"): @@ -571,14 +616,14 @@ def check_metadata( Client for the data repository. """ # Try reading metadata through the quantum ID. 
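+ # fetch_metadata now returns a sequence with one entry per attempt for
+ # each requested ID, hence the extra level of unpacking below.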
- (metadata1,) = provenance_reader.fetch_metadata([quantum_id]).values() + ((metadata1,),) = provenance_reader.fetch_metadata([quantum_id]).values() self.assertIsInstance(metadata1, TaskMetadata) for _, dataset_id, pipeline_edges in provenance_reader.graph.bipartite_xgraph.out_edges( quantum_id, data="pipeline_edges" ): if pipeline_edges[0].connection_name == acc.METADATA_OUTPUT_CONNECTION_NAME: # Also try reading metadata through the dataset ID. - (metadata2,) = provenance_reader.fetch_metadata([dataset_id]).values() + ((metadata2,),) = provenance_reader.fetch_metadata([dataset_id]).values() break else: raise AssertionError("No metadata connection found.") @@ -601,14 +646,14 @@ def check_log( Client for the data repository. """ # Try reading log through the quantum ID. - (log1,) = provenance_reader.fetch_logs([quantum_id]).values() + ((log1,),) = provenance_reader.fetch_logs([quantum_id]).values() self.assertIsInstance(log1, ButlerLogRecords) for _, dataset_id, pipeline_edges in provenance_reader.graph.bipartite_xgraph.out_edges( quantum_id, data="pipeline_edges" ): if pipeline_edges[0].connection_name == acc.LOG_OUTPUT_CONNECTION_NAME: # Also try reading log through the dataset ID. - (log2,) = provenance_reader.fetch_logs([dataset_id]).values() + ((log2,),) = provenance_reader.fetch_logs([dataset_id]).values() break else: raise AssertionError("No log connection found.") @@ -644,10 +689,7 @@ def check_resource_usage_table( quantum's execution. """ tbl = prov.make_task_resource_usage_table("calibrate", include_data_ids=True) - if expect_failure: - self.assertEqual(len(tbl), prov.header.n_task_quanta["calibrate"] - 1) - else: - self.assertEqual(len(tbl), prov.header.n_task_quanta["calibrate"]) + self.assertEqual(len(tbl), prov.header.n_task_quanta["calibrate"]) self.assertCountEqual( tbl.colnames, ["quantum_id"] @@ -729,12 +771,9 @@ def check_exception_table(self, prov: ProvenanceQuantumGraph, expect_failure: bo succeed without writing anything (`False`). 
""" t = prov.make_exception_table() - if expect_failure: - self.assertEqual(len(t), 0) - else: - self.assertEqual(list(t["Task"]), ["calibrate"]) - self.assertEqual(list(t["Exception"]), ["lsst.pipe.base.tests.mocks.MockAlgorithmError"]) - self.assertEqual(list(t["Count"]), [1]) + self.assertEqual(list(t["Task"]), ["calibrate"]) + self.assertEqual(list(t["Exception"]), ["lsst.pipe.base.tests.mocks.MockAlgorithmError"]) + self.assertEqual(list(t["Count"]), [1]) def test_all_successful(self) -> None: """Test running a full graph with no failures, and then scanning the @@ -745,15 +784,22 @@ def test_all_successful(self) -> None: with self.make_test_repo() as prep: prep.config.assume_complete = False start_time = time.time() - executed_quanta = list( + attempted_quanta = list( self.iter_graph_execution(prep.butler_path, prep.predicted, raise_on_partial_outputs=False) ) - self.assertCountEqual(executed_quanta, prep.predicted.quantum_only_xgraph.nodes.keys()) + self.assertCountEqual(attempted_quanta, prep.predicted.quantum_only_xgraph.nodes.keys()) aggregate_graph(prep.predicted_path, prep.butler_path, prep.config) with ProvenanceQuantumGraphReader.open(prep.config.output_path) as reader: - self.check_provenance_graph( - prep.predicted, reader, prep.butler, expect_failure=False, start_time=start_time + prov = self.check_provenance_graph( + prep.predicted, + reader, + prep.butler, + expect_failure=False, + start_time=start_time, ) + for i, quantum_id in enumerate(attempted_quanta): + qinfo: ProvenanceQuantumInfo = prov.quantum_only_xgraph.nodes[quantum_id] + self.assertEqual(qinfo["attempts"][-1].previous_process_quanta, attempted_quanta[:i]) def test_all_successful_two_phase(self) -> None: """Test running some of a graph with no failures, scanning with @@ -764,23 +810,29 @@ def test_all_successful_two_phase(self) -> None: execution_iter = self.iter_graph_execution( prep.butler_path, prep.predicted, raise_on_partial_outputs=False ) - executed_quanta = list(itertools.islice(execution_iter, 9)) - self.assertEqual(len(executed_quanta), 9) + attempted_quanta = list(itertools.islice(execution_iter, 9)) + self.assertEqual(len(attempted_quanta), 9) # Run the scanner while telling it to assume failures might change, - # so it just waits for incomplete quanta to finish (and then times - # out). + # so it just abandons incomplete quanta. prep.config.assume_complete = False with self.assertRaises(RuntimeError): aggregate_graph(prep.predicted_path, prep.butler_path, prep.config) # Finish executing the quanta. - executed_quanta.extend(execution_iter) + attempted_quanta.extend(execution_iter) # Scan again, and write the provenance QG. aggregate_graph(prep.predicted_path, prep.butler_path, prep.config) # Run the scanner again. 
with ProvenanceQuantumGraphReader.open(prep.config.output_path) as reader: - self.check_provenance_graph( - prep.predicted, reader, prep.butler, expect_failure=False, start_time=start_time + prov = self.check_provenance_graph( + prep.predicted, + reader, + prep.butler, + expect_failure=False, + start_time=start_time, ) + for i, quantum_id in enumerate(attempted_quanta): + qinfo: ProvenanceQuantumInfo = prov.quantum_only_xgraph.nodes[quantum_id] + self.assertEqual(qinfo["attempts"][-1].previous_process_quanta, attempted_quanta[:i]) def test_some_failed(self) -> None: """Test running a full graph with some failures, and then scanning the @@ -789,15 +841,21 @@ def test_some_failed(self) -> None: with self.make_test_repo() as prep: prep.config.assume_complete = True start_time = time.time() - for _ in self.iter_graph_execution( - prep.butler_path, prep.predicted, raise_on_partial_outputs=True - ): - pass + attempted_quanta = list( + self.iter_graph_execution(prep.butler_path, prep.predicted, raise_on_partial_outputs=True) + ) aggregate_graph(prep.predicted_path, prep.butler_path, prep.config) with ProvenanceQuantumGraphReader.open(prep.config.output_path) as reader: - self.check_provenance_graph( - prep.predicted, reader, prep.butler, expect_failure=True, start_time=start_time + prov = self.check_provenance_graph( + prep.predicted, + reader, + prep.butler, + expect_failure=True, + start_time=start_time, ) + for i, quantum_id in enumerate(attempted_quanta): + qinfo: ProvenanceQuantumInfo = prov.quantum_only_xgraph.nodes[quantum_id] + self.assertEqual(qinfo["attempts"][-1].previous_process_quanta, attempted_quanta[:i]) def test_some_failed_two_phase(self) -> None: """Test running a full graph with some failures, then scanning the @@ -806,19 +864,68 @@ def test_some_failed_two_phase(self) -> None: """ with self.make_test_repo() as prep: start_time = time.time() - for _ in self.iter_graph_execution( - prep.butler_path, prep.predicted, raise_on_partial_outputs=True - ): - pass + attempted_quanta = list( + self.iter_graph_execution(prep.butler_path, prep.predicted, raise_on_partial_outputs=True) + ) prep.config.assume_complete = False with self.assertRaisesRegex(RuntimeError, "1 quantum abandoned"): aggregate_graph(prep.predicted_path, prep.butler_path, prep.config) prep.config.assume_complete = True aggregate_graph(prep.predicted_path, prep.butler_path, prep.config) with ProvenanceQuantumGraphReader.open(prep.config.output_path) as reader: - self.check_provenance_graph( - prep.predicted, reader, prep.butler, expect_failure=True, start_time=start_time + prov = self.check_provenance_graph( + prep.predicted, + reader, + prep.butler, + expect_failure=True, + start_time=start_time, ) + for i, quantum_id in enumerate(attempted_quanta): + qinfo: ProvenanceQuantumInfo = prov.quantum_only_xgraph.nodes[quantum_id] + self.assertEqual(qinfo["attempts"][-1].previous_process_quanta, attempted_quanta[:i]) + + def test_retry(self) -> None: + """Test running a full graph with some failures, rerunning the quanta + that failed or were blocked in the first attempt, and then scanning + for provenance. 
+ """ + with self.make_test_repo() as prep: + prep.config.assume_complete = True + start_time = time.time() + attempted_quanta_1 = list( + self.iter_graph_execution(prep.butler_path, prep.predicted, raise_on_partial_outputs=True) + ) + attempted_quanta_2 = list( + self.iter_graph_execution( + prep.butler_path, prep.predicted, raise_on_partial_outputs=False, is_retry=True + ) + ) + aggregate_graph(prep.predicted_path, prep.butler_path, prep.config) + with ProvenanceQuantumGraphReader.open(prep.config.output_path) as reader: + prov = self.check_provenance_graph( + prep.predicted, + reader, + prep.butler, + expect_failure=False, + start_time=start_time, + expect_failures_retried=True, + ) + for i, quantum_id in enumerate(attempted_quanta_1): + qinfo: ProvenanceQuantumInfo = prov.quantum_only_xgraph.nodes[quantum_id] + self.assertEqual(qinfo["attempts"][0].previous_process_quanta, attempted_quanta_1[:i]) + expected: list[uuid.UUID] = [] + for quantum_id in attempted_quanta_2: + qinfo: ProvenanceQuantumInfo = prov.quantum_only_xgraph.nodes[quantum_id] + if ( + quantum_id in attempted_quanta_1 + and qinfo["attempts"][0].status is QuantumAttemptStatus.SUCCESSFUL + ): + # These weren't actually attempted twice, since they + # were already successful in the first round. + self.assertEqual(len(qinfo["attempts"]), 1) + else: + self.assertEqual(qinfo["attempts"][-1].previous_process_quanta, expected) + expected.append(quantum_id) def test_worker_failures(self) -> None: """Test that if failures occur on (multiple) workers we shut down diff --git a/tests/test_predicted_qg.py b/tests/test_predicted_qg.py index 70d74746c..549698817 100644 --- a/tests/test_predicted_qg.py +++ b/tests/test_predicted_qg.py @@ -511,7 +511,8 @@ def test_io(self) -> None: components.write(tmpfile, zstd_dict_n_inputs=24) # enable dict compression code path # Test a full read with the new class. with PredictedQuantumGraph.open(tmpfile, page_size=four_row_page_size) as reader: - full_qg = reader.read_all().finish() + reader.read_all() + full_qg = reader.finish() self.check_quantum_graph(full_qg, dimension_data_deserialized=False) # Test a full read with the old class (uses new class and then # converts to old, and we convert back to new for the test). @@ -527,7 +528,8 @@ def test_io(self) -> None: ) # Test a thin but shallow read with the new class. with PredictedQuantumGraph.open(tmpfile, page_size=four_row_page_size) as reader: - thin_qg = reader.read_thin_graph().finish() + reader.read_thin_graph() + thin_qg = reader.finish() self.check_quantum_graph( thin_qg, dimension_data_deserialized=False, @@ -560,7 +562,8 @@ def test_no_compression_dict(self) -> None: tmpfile = os.path.join(tmpdir, "new.qg") components.write(tmpfile, zstd_dict_size=0) with PredictedQuantumGraph.open(tmpfile, page_size=three_row_page_size) as reader: - full_qg = reader.read_all().finish() + reader.read_all() + full_qg = reader.finish() self.check_quantum_graph(full_qg, dimension_data_deserialized=False) def test_dot(self) -> None: