DM-27492: Add support for skip-existing-in option #197

Merged · 5 commits · Aug 12, 2021
3 changes: 3 additions & 0 deletions doc/changes/DM-27492.api.md
@@ -0,0 +1,3 @@
`GraphBuilder` constructor boolean argument `skipExisting` is replaced with
`skipExistingIn`, which accepts collections to check for existing quantum
outputs.
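A minimal usage sketch of this API change (editor's illustration, not part of the diff), assuming an existing `lsst.daf.butler.Registry` instance in `registry`; the collection names are hypothetical:

```python
from lsst.pipe.base import GraphBuilder

# Before this change: skip a quantum only when all of its outputs already
# exist in the output RUN collection.
# builder = GraphBuilder(registry, skipExisting=True)

# After this change: name the collections to search for existing outputs.
# Any expression accepted by lsst.daf.butler.CollectionSearch.fromExpression
# should work, e.g. a single collection name or a sequence of names.
builder = GraphBuilder(registry, skipExistingIn=["u/someone/prior-run", "shared/outputs"])

# Passing None (the default) disables skipping entirely.
builder_no_skip = GraphBuilder(registry, skipExistingIn=None)
```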
115 changes: 73 additions & 42 deletions python/lsst/pipe/base/graphBuilder.py
@@ -43,6 +43,8 @@
from .pipeline import PipelineDatasetTypes, TaskDatasetTypes, TaskDef, Pipeline
from .graph import QuantumGraph
from lsst.daf.butler import (
CollectionSearch,
CollectionType,
DataCoordinate,
DatasetRef,
DatasetType,
@@ -580,7 +582,7 @@ def connectDataIds(self, registry, collections, userQuery, externalDataId):
_LOG.debug("Finished processing %d rows from data ID query.", n)
yield commonDataIds

def resolveDatasetRefs(self, registry, collections, run, commonDataIds, *, skipExisting=True,
def resolveDatasetRefs(self, registry, collections, run, commonDataIds, *, skipExistingIn=None,
clobberOutputs=True):
"""Perform follow up queries for each dataset data ID produced in
`fillDataIds`.
@@ -602,53 +604,81 @@ def resolveDatasetRefs(self, registry, collections, run, commonDataIds, *, skipE
commonDataIds : \
`lsst.daf.butler.registry.queries.DataCoordinateQueryResults`
Result of a previous call to `connectDataIds`.
skipExisting : `bool`, optional
If `True` (default), a Quantum is not created if all its outputs
already exist in ``run``. Ignored if ``run`` is `None`.
skipExistingIn
Expressions representing the collections to search for existing
output datasets that should be skipped. May be any of the types
accepted by `lsst.daf.butler.CollectionSearch.fromExpression`.
`None` or empty string/sequence disables skipping.
clobberOutputs : `bool`, optional
If `True` (default), allow quanta to be created even if outputs exist;
this requires the same behavior to be enabled when
executing. If ``skipExisting`` is also `True`, completed quanta
executing. If ``skipExistingIn`` is not `None`, completed quanta
(those with metadata, or all outputs if there is no metadata
dataset configured) will be skipped rather than clobbered.

Raises
------
OutputExistsError
Raised if an output dataset already exists in the output run
and ``skipExisting`` is `False`, or if only some outputs are
present and ``clobberOutputs`` is `False`.
and ``skipExistingIn`` does not include the output run, or if only
some outputs are present and ``clobberOutputs`` is `False`.
"""
skipCollections: Optional[CollectionSearch] = None
skipExistingInRun = False
if skipExistingIn:
skipCollections = CollectionSearch.fromExpression(skipExistingIn)
if run:
# As an optimization, check the explicit list of names first.
skipExistingInRun = run in skipCollections.explicitNames()
if not skipExistingInRun:
# Need to flatten the expression and check again.
skipExistingInRun = run in registry.queryCollections(
skipExistingIn,
collectionTypes=CollectionType.RUN,
)

# Look up [init] intermediate and output datasets in the output
# collection, if there is an output collection.
if run is not None:
if run is not None or skipCollections is not None:
for datasetType, refs in itertools.chain(self.initIntermediates.items(),
self.initOutputs.items(),
self.intermediates.items(),
self.outputs.items()):
_LOG.debug("Resolving %d datasets for intermediate and/or output dataset %s.",
len(refs), datasetType.name)
isInit = datasetType in self.initIntermediates or datasetType in self.initOutputs
resolvedRefQueryResults = commonDataIds.subset(
datasetType.dimensions,
unique=True
).findDatasets(
datasetType,
collections=run,
findFirst=True
)
for resolvedRef in resolvedRefQueryResults:
# TODO: we could easily support per-DatasetType
# skipExisting and I could imagine that being useful - it's
# probably required in order to support writing initOutputs
# before QuantumGraph generation.
assert resolvedRef.dataId in refs
if skipExisting or isInit or clobberOutputs:
subset = commonDataIds.subset(datasetType.dimensions, unique=True)

# look at RUN collection first
if run is not None:
resolvedRefQueryResults = subset.findDatasets(
datasetType,
collections=run,
findFirst=True
)
for resolvedRef in resolvedRefQueryResults:
# TODO: we could easily support per-DatasetType
# skipExisting and I could imagine that being useful -
# it's probably required in order to support writing
# initOutputs before QuantumGraph generation.
assert resolvedRef.dataId in refs
if not (skipExistingInRun or isInit or clobberOutputs):
raise OutputExistsError(f"Output dataset {datasetType.name} already exists in "
f"output RUN collection '{run}' with data ID"
f" {resolvedRef.dataId}.")

# Also check skipExistingIn; the case where the RUN collection is
# part of it was handled above.
if skipCollections is not None:
resolvedRefQueryResults = subset.findDatasets(
datasetType,
collections=skipCollections,
findFirst=True
)
for resolvedRef in resolvedRefQueryResults:
assert resolvedRef.dataId in refs
refs[resolvedRef.dataId] = resolvedRef
else:
raise OutputExistsError(f"Output dataset {datasetType.name} already exists in "
f"output RUN collection '{run}' with data ID"
f" {resolvedRef.dataId}.")

# Look up input and initInput datasets in the input collection(s).
for datasetType, refs in itertools.chain(self.initInputs.items(), self.inputs.items()):
_LOG.debug("Resolving %d datasets for input dataset %s.", len(refs), datasetType.name)
@@ -689,13 +719,13 @@ def resolveDatasetRefs(self, registry, collections, run, commonDataIds, *, skipE
dataIdsFailed = []
dataIdsSucceeded = []
for quantum in task.quanta.values():
# Process outputs datasets only if there is a run to look for
# outputs in and skipExisting and/or clobberOutputs is True.
# Note that if skipExisting is False, any output datasets that
# already exist would have already caused an exception to be
# raised. We never update the DatasetRefs in the quantum
# because those should never be resolved.
if run is not None and (skipExisting or clobberOutputs):
# Process output datasets only if skipExistingIn is not None
# or there is a run to look for outputs in and clobberOutputs
# is True. Note that if skipExistingIn is None, any output
# datasets that already exist would have already caused an
# exception to be raised. We never update the DatasetRefs in
# the quantum because those should never be resolved.
if skipCollections is not None or (run is not None and clobberOutputs):
resolvedRefs = []
unresolvedRefs = []
haveMetadata = False
Expand All @@ -710,7 +740,7 @@ def resolveDatasetRefs(self, registry, collections, run, commonDataIds, *, skipE
if resolvedRefs:
if haveMetadata or not unresolvedRefs:
dataIdsSucceeded.append(quantum.dataId)
if skipExisting:
if skipCollections is not None:
continue
else:
dataIdsFailed.append(quantum.dataId)
@@ -764,7 +794,7 @@ def resolveDatasetRefs(self, registry, collections, run, commonDataIds, *, skipE
if ref is not None})
# Actually remove any quanta that we decided to skip above.
if dataIdsSucceeded:
if skipExisting:
if skipCollections is not None:
_LOG.debug("Pruning successful %d quanta for task with label '%s' because all of their "
"outputs exist or metadata was written successfully.",
len(dataIdsSucceeded), task.taskDef.label)
@@ -837,19 +867,20 @@ class GraphBuilder(object):
----------
registry : `~lsst.daf.butler.Registry`
Data butler instance.
skipExisting : `bool`, optional
If `True` (default), a Quantum is not created if all its outputs
already exist.
skipExistingIn
Expressions representing the collections to search for existing
output datasets that should be skipped. May be any of the types
accepted by `lsst.daf.butler.CollectionSearch.fromExpression`.
clobberOutputs : `bool`, optional
If `True` (default), allow quanta to be created even if partial outputs
exist; this requires the same behavior to be enabled when
executing.
"""

def __init__(self, registry, skipExisting=True, clobberOutputs=True):
def __init__(self, registry, skipExistingIn=None, clobberOutputs=True):
self.registry = registry
self.dimensions = registry.dimensions
self.skipExisting = skipExisting
self.skipExistingIn = skipExistingIn
self.clobberOutputs = clobberOutputs

def makeGraph(self, pipeline, collections, run, userQuery,
@@ -902,6 +933,6 @@ def makeGraph(self, pipeline, collections, run, userQuery,
dataId = DataCoordinate.makeEmpty(self.registry.dimensions)
with scaffolding.connectDataIds(self.registry, collections, userQuery, dataId) as commonDataIds:
scaffolding.resolveDatasetRefs(self.registry, collections, run, commonDataIds,
skipExisting=self.skipExisting,
skipExistingIn=self.skipExistingIn,
clobberOutputs=self.clobberOutputs)
return scaffolding.makeQuantumGraph(metadata=metadata)
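To illustrate the expression forms the `skipExistingIn` docstring refers to (anything `CollectionSearch.fromExpression` accepts), here is a hedged sketch with hypothetical collection names:

```python
from lsst.daf.butler import CollectionSearch

# A single collection name ...
skip = CollectionSearch.fromExpression("u/someone/step1-outputs")

# ... or an ordered sequence of names, searched in order.
skip = CollectionSearch.fromExpression(["u/someone/step1-outputs", "HSC/defaults"])

# Passing None or an empty expression as skipExistingIn disables skipping,
# per the resolveDatasetRefs docstring above.
```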
18 changes: 0 additions & 18 deletions python/lsst/pipe/base/pipeTools.py
@@ -39,24 +39,6 @@
# Local non-exported definitions --
# ----------------------------------


def _loadTaskClass(taskDef, taskFactory):
"""Import task class if necessary.

Raises
------
`ImportError` is raised when task class cannot be imported.
`MissingTaskFactoryError` is raised when TaskFactory is needed but not
provided.
"""
taskClass = taskDef.taskClass
if not taskClass:
if not taskFactory:
raise MissingTaskFactoryError("Task class is not defined but task "
"factory instance is not provided")
taskClass = taskFactory.loadTaskClass(taskDef.taskName)
return taskClass

# ------------------------
# Exported definitions --
# ------------------------
47 changes: 45 additions & 2 deletions python/lsst/pipe/base/pipeline.py
@@ -30,7 +30,8 @@
# -------------------------------
from dataclasses import dataclass
from types import MappingProxyType
from typing import Dict, Iterable, Mapping, Set, Union, Generator, TYPE_CHECKING, Optional, Tuple
from typing import (ClassVar, Dict, Iterable, Iterator, Mapping, Set, Union,
Generator, TYPE_CHECKING, Optional, Tuple)

import copy
import re
@@ -807,6 +808,10 @@ class PipelineDatasetTypes:
`Pipeline`.
"""

packagesDatasetName: ClassVar[str] = "packages"
"""Name of a dataset type used to save package versions.
"""

initInputs: NamedValueSet[DatasetType]
"""Dataset types that are needed as inputs in order to construct the Tasks
in this Pipeline.
@@ -915,7 +920,7 @@ def fromPipeline(
if include_packages:
allInitOutputs.add(
DatasetType(
"packages",
cls.packagesDatasetName,
registry.dimensions.empty,
storageClass="Packages",
)
@@ -990,3 +995,41 @@ def frozen(s: NamedValueSet) -> NamedValueSet:
prerequisites=frozen(prerequisites),
byTask=MappingProxyType(byTask), # MappingProxyType -> frozen view of dict for immutability
)

@classmethod
def initOutputNames(cls, pipeline: Union[Pipeline, Iterable[TaskDef]], *,
include_configs: bool = True, include_packages: bool = True) -> Iterator[str]:
"""Return the names of dataset types ot task initOutputs, Configs,
and package versions for a pipeline.

Parameters
----------
pipeline : `Pipeline` or `Iterable` [ `TaskDef` ]
A `Pipeline` instance or collection of `TaskDef` instances.
include_configs : `bool`, optional
If `True` (default) include config dataset types.
include_packages : `bool`, optional
If `True` (default) include the dataset type for package versions.

Yields
------
datasetTypeName : `str`
Name of the dataset type.
"""
if include_packages:
# Package versions dataset type
yield cls.packagesDatasetName

if isinstance(pipeline, Pipeline):
pipeline = pipeline.toExpandedPipeline()

for taskDef in pipeline:

# all task InitOutputs
for name in taskDef.connections.initOutputs:
attribute = getattr(taskDef.connections, name)
yield attribute.name

# config dataset name
if include_configs:
yield taskDef.configDatasetName
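A short, hedged example of the new `initOutputNames` helper; the pipeline file path is hypothetical, and `Pipeline.fromFile` is assumed to be available:

```python
from lsst.pipe.base import Pipeline, PipelineDatasetTypes

pipeline = Pipeline.fromFile("pipelines/example_pipeline.yaml")  # hypothetical path

# Collect the init-output dataset type names: the "packages" dataset,
# every task's connection initOutputs, and each task's config dataset.
names = set(
    PipelineDatasetTypes.initOutputNames(pipeline,
                                         include_configs=True,
                                         include_packages=True)
)
print(sorted(names))
```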