Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DM-28668: PipelineTask unit test framework bypasses dimensions checks #182

Merged
merged 4 commits into from
Jun 9, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
26 changes: 26 additions & 0 deletions doc/lsst.pipe.base/testing-a-pipeline-task.rst
Original file line number Diff line number Diff line change
Expand Up @@ -234,3 +234,29 @@ There is currently no test framework for the use of init-inputs in task construc
task = OrTask(config=config)
run = testUtils.runTestQuantum(task, butler, quantum)
run.assert_called_once_with(cat=testCatalog)

.. _testing-a-pipeline-task-static-analysis:

Analyzing Connections Classes
=============================

Mistakes in creating pipeline connections classes can lead to hard-to-debug errors at run time.
The `lsst.pipe.base.testUtils.lintConnections` function analyzes a connections class for common errors.
The only errors currently tested are those involving inconsistencies between connection and quantum dimensions.

All tests done by `lintConnections` are heuristic, looking for common patterns of misuse.
Advanced users who are *deliberately* bending the usual rules can use keywords to turn off specific tests.

.. code-block:: py
:emphasize-lines: 9-10

class ListConnections(PipelineTaskConnections,
dimensions=("instrument", "visit", "detector")):
cat = connectionTypes.Input(
name="src",
storageClass="SourceCatalog",
dimensions=("instrument", "visit", "detector"),
multiple=True) # force a list of one catalog

lintConnections(ListConnections) # warns that cat always has one input
taranu marked this conversation as resolved.
Show resolved Hide resolved
lintConnections(ListConnections, checkUnnecessaryMultiple=False) # passes
4 changes: 4 additions & 0 deletions python/lsst/pipe/base/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,10 @@ class PipelineTaskConnections(metaclass=PipelineTaskConnectionsMetaclass):
A `PipelineTaskConfig` class instance whose class has been configured
to use this `PipelineTaskConnectionsClass`

See also
--------
iterConnections

Notes
-----
``PipelineTaskConnection`` classes are created by declaring class
Expand Down
135 changes: 123 additions & 12 deletions python/lsst/pipe/base/testUtils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
__all__ = ["assertValidInitOutput",
"assertValidOutput",
"getInitInputs",
"lintConnections",
"makeQuantum",
"runTestQuantum",
]
Expand All @@ -33,7 +34,8 @@
import itertools
import unittest.mock

from lsst.daf.butler import DataCoordinate, DatasetRef, DatasetType, Quantum, StorageClassFactory
from lsst.daf.butler import DataCoordinate, DatasetRef, DatasetType, Quantum, StorageClassFactory, \
SkyPixDimension
from lsst.pipe.base import ButlerQuantumContext


Expand Down Expand Up @@ -61,29 +63,89 @@ def makeQuantum(task, butler, dataId, ioDataIds):
connections = task.config.ConnectionsClass(config=task.config)

try:
inputs = defaultdict(list)
outputs = defaultdict(list)
for name in itertools.chain(connections.inputs, connections.prerequisiteInputs):
_checkDimensionsMatch(butler.registry.dimensions, connections.dimensions, dataId.keys())
except ValueError as e:
raise ValueError("Error in quantum dimensions.") from e

inputs = defaultdict(list)
outputs = defaultdict(list)
for name in itertools.chain(connections.inputs, connections.prerequisiteInputs):
try:
connection = connections.__getattribute__(name)
_checkDataIdMultiplicity(name, ioDataIds[name], connection.multiple)
ids = _normalizeDataIds(ioDataIds[name])
for id in ids:
ref = _refFromConnection(butler, connection, id)
inputs[ref.datasetType].append(ref)
for name in connections.outputs:
except (ValueError, KeyError) as e:
raise ValueError(f"Error in connection {name}.") from e
for name in connections.outputs:
try:
connection = connections.__getattribute__(name)
_checkDataIdMultiplicity(name, ioDataIds[name], connection.multiple)
ids = _normalizeDataIds(ioDataIds[name])
for id in ids:
ref = _refFromConnection(butler, connection, id)
outputs[ref.datasetType].append(ref)
quantum = Quantum(taskClass=type(task),
dataId=dataId,
inputs=inputs,
outputs=outputs)
return quantum
except KeyError as e:
raise ValueError("Mismatch in input data.") from e
except (ValueError, KeyError) as e:
raise ValueError(f"Error in connection {name}.") from e
quantum = Quantum(taskClass=type(task),
dataId=dataId,
inputs=inputs,
outputs=outputs)
return quantum


def _checkDimensionsMatch(universe, expected, actual):
"""Test whether two sets of dimensions agree after conversions.

Parameters
----------
universe : `lsst.daf.butler.DimensionUniverse`
The set of all known dimensions.
expected : `Set` [`str`] or `Set` [`~lsst.daf.butler.Dimension`]
The dimensions expected from a task specification.
actual : `Set` [`str`] or `Set` [`~lsst.daf.butler.Dimension`]
The dimensions provided by input.

Raises
------
ValueError
Raised if ``expected`` and ``actual`` cannot be reconciled.
"""
if _simplify(universe, expected) != _simplify(universe, actual):
raise ValueError(f"Mismatch in dimensions; expected {expected} but got {actual}.")


def _simplify(universe, dimensions):
"""Reduce a set of dimensions to a string-only form.

Parameters
----------
universe : `lsst.daf.butler.DimensionUniverse`
The set of all known dimensions.
dimensions : `Set` [`str`] or `Set` [`~lsst.daf.butler.Dimension`]
A set of dimensions to simplify.

Returns
-------
dimensions : `Set` [`str`]
A copy of ``dimensions`` reduced to string form, with all spatial
dimensions simplified to ``skypix``.
"""
simplified = set()
for dimension in dimensions:
# skypix not a real Dimension, handle it first
if dimension == "skypix":
simplified.add(dimension)
else:
# Need a Dimension to test spatialness
fullDimension = universe[dimension] if isinstance(dimension, str) else dimension
taranu marked this conversation as resolved.
Show resolved Hide resolved
if isinstance(fullDimension, SkyPixDimension):
simplified.add("skypix")
else:
simplified.add(fullDimension.name)
return simplified


def _checkDataIdMultiplicity(name, dataIds, multiple):
Expand Down Expand Up @@ -155,6 +217,8 @@ def _refFromConnection(butler, connection, dataId, **kwargs):
``dataId``, in the collection pointed to by ``butler``.
"""
universe = butler.registry.dimensions
# DatasetRef only tests if required dimension is missing, but not extras
_checkDimensionsMatch(universe, connection.dimensions, dataId.keys())
dataId = DataCoordinate.standardize(dataId, **kwargs, universe=universe)

# skypix is a PipelineTask alias for "some spatial index", Butler doesn't
Expand Down Expand Up @@ -362,3 +426,50 @@ def getInitInputs(butler, config):
initInputs[name] = butler.get(dsType)

return initInputs


def lintConnections(connections, *,
checkMissingMultiple=True,
checkUnnecessaryMultiple=True,
):
"""Inspect a connections class for common errors.

These tests are designed to detect misuse of connections features in
standard designs. An unusually designed connections class may trigger
alerts despite being correctly written; specific checks can be turned off
using keywords.

Parameters
----------
connections : `lsst.pipe.base.PipelineTaskConnections`-type
kfindeisen marked this conversation as resolved.
Show resolved Hide resolved
The connections class to test.
checkMissingMultiple : `bool`
Whether to test for single connections that would match multiple
datasets at run time.
checkUnnecessaryMultiple : `bool`
Whether to test for multiple connections that would only match
one dataset.

Raises
------
AssertionError
Raised if any of the selected checks fail for any connection.
"""
# Since all comparisons are inside the class, don't bother
# normalizing skypix.
quantumDimensions = connections.dimensions

errors = ""
# connectionTypes.DimensionedConnection is implementation detail,
# don't use it.
for name in itertools.chain(connections.inputs, connections.prerequisiteInputs, connections.outputs):
connection = connections.allConnections[name]
connDimensions = set(connection.dimensions)
if checkMissingMultiple and not connection.multiple and connDimensions > quantumDimensions:
kfindeisen marked this conversation as resolved.
Show resolved Hide resolved
errors += f"Connection {name} may be called with multiple values of " \
f"{connDimensions - quantumDimensions} but has multiple=False.\n"
if checkUnnecessaryMultiple and connection.multiple and connDimensions <= quantumDimensions:
errors += f"Connection {name} has multiple=True but can only be called with one " \
f"value of {connDimensions} for each {quantumDimensions}.\n"
if errors:
raise AssertionError(errors)
61 changes: 60 additions & 1 deletion tests/test_testUtils.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@

from lsst.pipe.base import Struct, PipelineTask, PipelineTaskConfig, PipelineTaskConnections, connectionTypes
from lsst.pipe.base.testUtils import runTestQuantum, makeQuantum, assertValidOutput, \
assertValidInitOutput, getInitInputs
assertValidInitOutput, getInitInputs, lintConnections


class VisitConnections(PipelineTaskConnections, dimensions={"instrument", "visit"}):
Expand Down Expand Up @@ -294,6 +294,7 @@ def testMakeQuantumInvalidDimension(self):
config.connections.a = "PatchA"
task = VisitTask(config=config)
dataIdV = {"instrument": "notACam", "visit": 102}
dataIdVExtra = {"instrument": "notACam", "visit": 102, "detector": 42}
dataIdP = {"skymap": "sky", "tract": 42, "patch": 0}

inA = [1, 2, 3]
Expand All @@ -302,13 +303,36 @@ def testMakeQuantumInvalidDimension(self):
self.butler.put(butlerTests.MetricsExample(data=inA), "PatchA", dataIdP)
self.butler.put(butlerTests.MetricsExample(data=inB), "VisitB", dataIdV)

# dataIdV is correct everywhere, dataIdP should error
with self.assertRaises(ValueError):
makeQuantum(task, self.butler, dataIdV, {
"a": dataIdP,
"b": dataIdV,
"outA": dataIdV,
"outB": dataIdV,
})
with self.assertRaises(ValueError):
makeQuantum(task, self.butler, dataIdP, {
"a": dataIdV,
"b": dataIdV,
"outA": dataIdV,
"outB": dataIdV,
})
# should not accept small changes, either
with self.assertRaises(ValueError):
makeQuantum(task, self.butler, dataIdV, {
"a": dataIdV,
"b": dataIdV,
"outA": dataIdVExtra,
"outB": dataIdV,
})
with self.assertRaises(ValueError):
makeQuantum(task, self.butler, dataIdVExtra, {
"a": dataIdV,
"b": dataIdV,
"outA": dataIdV,
"outB": dataIdV,
})

def testMakeQuantumMissingMultiple(self):
task = PatchTask()
Expand Down Expand Up @@ -574,6 +598,41 @@ def testSkypixHandling(self):
# PixA dataset should have been retrieved by runTestQuantum
run.assert_called_once_with(a=data)

def testLintConnectionsOk(self):
lintConnections(VisitConnections)
lintConnections(PatchConnections)
lintConnections(SkyPixConnections)

def testLintConnectionsMissingMultiple(self):
class BadConnections(PipelineTaskConnections,
dimensions={"tract", "patch", "skymap"}):
coadds = connectionTypes.Input(
name="coadd_calexp",
storageClass="ExposureF",
# Some authors use list rather than set; check that linter
# can handle it.
dimensions=["tract", "patch", "band", "skymap"],
)

with self.assertRaises(AssertionError):
lintConnections(BadConnections)
lintConnections(BadConnections, checkMissingMultiple=False)

def testLintConnectionsExtraMultiple(self):
class BadConnections(PipelineTaskConnections,
# Some authors use list rather than set.
dimensions=["tract", "patch", "band", "skymap"]):
kfindeisen marked this conversation as resolved.
Show resolved Hide resolved
coadds = connectionTypes.Input(
name="coadd_calexp",
storageClass="ExposureF",
multiple=True,
dimensions={"tract", "patch", "band", "skymap"},
)

with self.assertRaises(AssertionError):
lintConnections(BadConnections)
lintConnections(BadConnections, checkUnnecessaryMultiple=False)


class MyMemoryTestCase(lsst.utils.tests.MemoryTestCase):
pass
Expand Down