Skip to content

Commit

Permalink
Replace instrument check in QG gen with instrument-defaulting.
Browse files Browse the repository at this point in the history
Using a Pipeline that contains an instrument entry now automatically
includes that instrument as a constraint in the QG gen query, and
daf_butler will check that it's consistent with the instrument provided
in the user expression (which after this change is no longer required).

This is on the branch for this ticket only because the daf_butler
changes for this ticket reshuffled some modules, which broke the code
here that was previously using those (internal to daf_butler)
interfaces.  Happily, this is better behavior, better encapsulation,
and less code (well, not overall, but here).

On the testing side, we can remove a bunch of tests here (the
functionality we delegate to in daf_butler is already tested there),
but we also need to make the testing utility code a bit less cavalier,
and in particular make the pipeline's instrument agree with the
registry's instrument.
  • Loading branch information
TallJimbo committed Dec 17, 2020
1 parent 647f735 commit 059c297
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 203 deletions.
154 changes: 14 additions & 140 deletions python/lsst/pipe/base/graphBuilder.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@
NamedKeyDict,
Quantum,
)
from lsst.daf.butler.registry.queries.exprParser import ParseError, ParserYacc, TreeVisitor
from lsst.utils import doImport

# ----------------------------------
Expand Down Expand Up @@ -480,7 +479,7 @@ def __repr__(self):
"""

@contextmanager
def connectDataIds(self, registry, collections, userQuery):
def connectDataIds(self, registry, collections, userQuery, externalDataId):
"""Query for the data IDs that connect nodes in the `QuantumGraph`.
This method populates `_TaskScaffolding.dataIds` and
Expand All @@ -494,8 +493,13 @@ def connectDataIds(self, registry, collections, userQuery):
Expressions representing the collections to search for input
datasets. May be any of the types accepted by
`lsst.daf.butler.CollectionSearch.fromExpression`.
userQuery : `str`, optional
userQuery : `str` or `None`
User-provided expression to limit the data IDs processed.
externalDataId : `DataCoordinate`
Externally-provided data ID that should be used to restrict the
results, just as if these constraints had been included via ``AND``
in ``userQuery``. This includes (at least) any instrument named
in the pipeline definition.
Returns
-------
Expand All @@ -522,6 +526,7 @@ def connectDataIds(self, registry, collections, userQuery):
datasets=list(self.inputs),
collections=collections,
where=userQuery,
dataId=externalDataId,
).materialize() as commonDataIds:
_LOG.debug("Expanding data IDs.")
commonDataIds = commonDataIds.expanded()
Expand Down Expand Up @@ -760,95 +765,6 @@ def makeQuantumGraph(self):
return graph


class _InstrumentFinder(TreeVisitor):
    """Expression-tree visitor that collects instrument names from a query.

    An instrument constraint is expected to be written as an equality
    between the ``instrument`` identifier and a string literal, in either
    order::

        instrument = 'string'
        'string' = instrument

    so the only node of interest is a binary ``=`` operator with the
    identifier on one side and a string literal on the other.

    Every ``visit*`` method returns a ``(type, value)`` tuple; nodes that
    carry no useful information return `None` for both elements.  The
    names found are appended to ``instruments`` as they are encountered.
    """

    # Result returned for every node kind that cannot take part in a match.
    _IGNORED = (None, None)

    # Result returned for the (case-insensitive) ``instrument`` identifier.
    _INSTRUMENT_ID = ("id", "instrument")

    def __init__(self):
        self.instruments = []

    def visitNumericLiteral(self, value, node):
        # Numbers never name an instrument.
        return self._IGNORED

    def visitStringLiteral(self, value, node):
        # Tag the literal so that visitBinaryOp can recognize it.
        return ("str", value)

    def visitTimeLiteral(self, value, node):
        # Time literals are irrelevant here.
        return self._IGNORED

    def visitRangeLiteral(self, start, stop, stride, node):
        # Range literals are irrelevant here.
        return self._IGNORED

    def visitIdentifier(self, name, node):
        # Only the ``instrument`` identifier matters; its case is ignored.
        return self._INSTRUMENT_ID if name.lower() == "instrument" else self._IGNORED

    def visitUnaryOp(self, operator, operand, node):
        # Unary operators are irrelevant here.
        return self._IGNORED

    def visitBinaryOp(self, operator, lhs, rhs, node):
        if operator == "=":
            # Try both operand orders; at most one name is recorded per node.
            for identifier, literal in ((lhs, rhs), (rhs, lhs)):
                if identifier == self._INSTRUMENT_ID and literal[0] == "str":
                    self.instruments.append(literal[1])
                    break
        return self._IGNORED

    def visitIsIn(self, lhs, values, not_in, node):
        # IN / NOT IN expressions are irrelevant here.
        return self._IGNORED

    def visitParens(self, expression, node):
        # Parenthesized expressions are irrelevant here.
        return self._IGNORED


def _findInstruments(queryStr):
"""Get the names of any instrument named in the query string by searching
for "instrument = <value>" and similar patterns.
Parameters
----------
queryStr : `str` or None
The query string to search, or None if there is no query.
Returns
-------
instruments : `list` [`str`]
The list of instrument names found in the query.
Raises
------
ValueError
If the query expression can not be parsed.
"""
if not queryStr:
return []
parser = ParserYacc()
finder = _InstrumentFinder()
try:
tree = parser.parse(queryStr)
except ParseError as exc:
raise ValueError(f"failed to parse query expression: {queryStr}") from exc
tree.visit(finder)
return finder.instruments


# ------------------------
# Exported definitions --
# ------------------------
Expand Down Expand Up @@ -927,54 +843,12 @@ def makeGraph(self, pipeline, collections, run, userQuery):
instrument = pipeline.getInstrument()
if isinstance(instrument, str):
instrument = doImport(instrument)
instrumentName = instrument.getName() if instrument else None
userQuery = self._verifyInstrumentRestriction(instrumentName, userQuery)

with scaffolding.connectDataIds(self.registry, collections, userQuery) as commonDataIds:
if instrument is not None:
dataId = DataCoordinate.standardize(instrument=instrument.getName(),
universe=self.registry.dimensions)
else:
dataId = DataCoordinate.makeEmpty(self.registry.dimensions)
with scaffolding.connectDataIds(self.registry, collections, userQuery, dataId) as commonDataIds:
scaffolding.resolveDatasetRefs(self.registry, collections, run, commonDataIds,
skipExisting=self.skipExisting)
return scaffolding.makeQuantumGraph()

@staticmethod
def _verifyInstrumentRestriction(instrumentName, query):
    """Reconcile the pipeline's instrument with the user query.

    If the pipeline names an instrument and the query does not restrict
    the instrument, a restriction is prepended to the query.  If the query
    already restricts the instrument, that restriction must agree with the
    pipeline.

    Parameters
    ----------
    instrumentName : `str`
        The name of the instrument that should appear in the query.  A
        false value (`None` or an empty string) disables all checks.
    query : `str`
        The query string.

    Returns
    -------
    query : `str`
        The query string with the instrument added to it if needed.

    Raises
    ------
    RuntimeError
        If the pipeline names an instrument and the query contains more
        than one instrument or the name of the instrument in the query does
        not match the instrument named by the pipeline.
    """
    # Without a pipeline instrument there is nothing to add or verify.
    if not instrumentName:
        return query

    found = _findInstruments(query)
    if len(found) > 1:
        raise RuntimeError(f"When the pipeline has an instrument (\"{instrumentName}\") the query must "
                           "have zero instruments or one instrument that matches the pipeline. "
                           f"Found these instruments in the query: {found}.")

    if found:
        # Exactly one instrument in the query: it has to be the pipeline's.
        named = found[0]
        if named != instrumentName:
            raise RuntimeError(f"The instrument named in the query (\"{named}\") does not "
                               f"match the instrument named by the pipeline (\"{instrumentName}\")")
        return query

    # No instrument in the query: add the pipeline's instrument to it.
    restriction = f"instrument = '{instrumentName}'"
    _LOG.debug(f"Adding restriction \"{restriction}\" to query.")
    if query:
        return f"{restriction} AND ({query})"
    # There may not be a query at all; the restriction then stands alone.
    return restriction
18 changes: 14 additions & 4 deletions python/lsst/pipe/base/tests/simpleQGraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from lsst.daf.butler import Butler, Config, DatasetType
import lsst.daf.butler.tests as butlerTests
import lsst.pex.config as pexConfig
from lsst.utils import doImport
from ... import base as pipeBase
from .. import connectionTypes as cT

Expand All @@ -44,7 +45,7 @@ class SimpleInstrument:

@staticmethod
def getName():
return "SimpleInstrument"
return "INSTRU"

def applyConfigOverrides(self, name, config):
pass
Expand Down Expand Up @@ -256,13 +257,22 @@ def makeSimpleQGraph(nQuanta=5, pipeline=None, butler=None, root=None, skipExist
# Add dataset types to registry
registerDatasetTypes(butler.registry, pipeline.toExpandedPipeline())

instrument = pipeline.getInstrument()
if instrument is not None:
if isinstance(instrument, str):
instrument = doImport(instrument)
instrumentName = instrument.getName()
else:
instrumentName = "INSTR"

# Add all needed dimensions to registry
butler.registry.insertDimensionData("instrument", dict(name="INSTR"))
butler.registry.insertDimensionData("detector", dict(instrument="INSTR", id=0, full_name="det0"))
butler.registry.insertDimensionData("instrument", dict(name=instrumentName))
butler.registry.insertDimensionData("detector", dict(instrument=instrumentName, id=0,
full_name="det0"))

# Add inputs to butler
data = numpy.array([0., 1., 2., 5.])
butler.put(data, "add_dataset0", instrument="INSTR", detector=0)
butler.put(data, "add_dataset0", instrument=instrumentName, detector=0)

# Make the graph
builder = pipeBase.GraphBuilder(registry=butler.registry, skipExisting=skipExisting)
Expand Down
59 changes: 0 additions & 59 deletions tests/test_graphBuilder.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,72 +24,13 @@
import logging
import unittest

from lsst.pipe.base import GraphBuilder
from lsst.pipe.base.tests import simpleQGraph
import lsst.utils.tests
from lsst.utils.tests import temporaryDirectory

_LOG = logging.getLogger(__name__)


class VerifyInstrumentRestrictionTestCase(unittest.TestCase):
    """Tests for `GraphBuilder._verifyInstrumentRestriction`."""

    def _check(self, instrumentName, query, expected):
        """Assert that verifying ``query`` against ``instrumentName`` returns
        ``expected``."""
        result = GraphBuilder._verifyInstrumentRestriction(instrumentName, query)
        self.assertEqual(result, expected)

    def testAddInstrument(self):
        """The pipeline instrument is prepended to a query lacking one."""
        self._check("HSC", "tract = 42", "instrument = 'HSC' AND (tract = 42)")

    def testQueryContainsInstrument(self):
        """A query that already names the instrument is left untouched."""
        query = "'HSC' = instrument AND tract = 42"
        self._check("HSC", query, query)

    def testQueryContainsInstrumentAltOrder(self):
        """The instrument is found regardless of operand and clause order,
        and the query is left untouched."""
        query = "tract=42 AND instrument='HSC'"
        self._check("HSC", query, query)

    def testQueryContainsSimilarKey(self):
        """A key that merely contains "instrument" is not mistaken for the
        actual "instrument" key."""
        query = "notinstrument=42 AND instrument='HSC'"
        self._check("HSC", query, query)

    def testNoPipelineInstrument(self):
        """Without a pipeline instrument the query is returned unchanged."""
        self._check(None, "foo=bar", "foo=bar")

    def testNoPipelineInstrumentTwoQueryInstruments(self):
        """Without a pipeline instrument the query may name two
        instruments."""
        query = "instrument = 'HSC' OR instrument = 'LSSTCam'"
        self._check(None, query, query)

    def testTwoQueryInstruments(self):
        """With a pipeline instrument, a query naming two instruments raises
        RuntimeError."""
        query = "instrument = 'HSC' OR instrument = 'LSSTCam'"
        with self.assertRaises(RuntimeError):
            GraphBuilder._verifyInstrumentRestriction("HSC", query)

    def testNoQuery(self):
        """An empty query becomes just the instrument restriction."""
        self._check("HSC", "", "instrument = 'HSC'")

    def testNoQueryNoInstruments(self):
        """No instrument and no query yields an empty query."""
        self._check("", "", "")


class GraphBuilderTestCase(unittest.TestCase):

def testDefault(self):
Expand Down

0 comments on commit 059c297

Please sign in to comment.