Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DM-26591: Include instrument data ID value when provided on pipetask command-line or Pipeline yaml file #144

Merged
merged 2 commits into from
Sep 16, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
125 changes: 124 additions & 1 deletion python/lsst/pipe/base/graphBuilder.py
Expand Up @@ -35,6 +35,7 @@
from typing import Dict, Iterable, Iterator, List
import logging


# -----------------------------
# Imports for other modules --
# -----------------------------
Expand All @@ -50,6 +51,8 @@
NamedKeyDict,
Quantum,
)
from lsst.daf.butler.registry.queries.exprParser import ParseError, ParserYacc, TreeVisitor
from lsst.utils import doImport

# ----------------------------------
# Local non-exported definitions --
Expand Down Expand Up @@ -745,6 +748,75 @@ def makeQuantumGraph(self):
return graph


class _InstrumentFinder(TreeVisitor):
"""Implementation of TreeVisitor which looks for instrument name

Instrument should be specified as a boolean expression

instrument = 'string'
'string' = instrument

so we only need to find a binary operator where operator is "=",
one side is a string literal and other side is an identifier.
All visit methods return tuple of (type, value), non-useful nodes
return None for both type and value.
"""
def __init__(self):
self.instruments = []

def visitNumericLiteral(self, value, node):
# do not care about numbers
return (None, None)

def visitStringLiteral(self, value, node):
# return type and value
return ("str", value)

def visitTimeLiteral(self, value, node):
# do not care about these
return (None, None)

def visitRangeLiteral(self, start, stop, stride, node):
# do not care about these
return (None, None)

def visitIdentifier(self, name, node):
if name.lower() == "instrument":
return ("id", "instrument")
return (None, None)

def visitUnaryOp(self, operator, operand, node):
# do not care about these
return (None, None)

def visitBinaryOp(self, operator, lhs, rhs, node):
if operator == "=":
if lhs == ("id", "instrument") and rhs[0] == "str":
self.instruments.append(rhs[1])
elif rhs == ("id", "instrument") and lhs[0] == "str":
self.instruments.append(lhs[1])
return (None, None)

def visitIsIn(self, lhs, values, not_in, node):
# do not care about these
return (None, None)

def visitParens(self, expression, node):
# do not care about these
return (None, None)


def _findInstruments(queryStr):
parser = ParserYacc()
finder = _InstrumentFinder()
try:
tree = parser.parse(queryStr)
except ParseError as exc:
raise ValueError(f"failed to parse query expression: {queryStr}") from exc
tree.visit(finder)
return finder.instruments


# ------------------------
# Exported definitions --
# ------------------------
Expand Down Expand Up @@ -799,7 +871,7 @@ def makeGraph(self, pipeline, collections, run, userQuery):
Name of the `~lsst.daf.butler.CollectionType.RUN` collection for
output datasets, if it already exists.
userQuery : `str`
String which defunes user-defined selection for registry, should be
String which defines user-defined selection for registry, should be
empty or `None` if there is no restrictions on data selection.

Returns
Expand All @@ -817,7 +889,58 @@ def makeGraph(self, pipeline, collections, run, userQuery):
classes.
"""
scaffolding = _PipelineScaffolding(pipeline, registry=self.registry)

instrument = pipeline.getInstrument()
if isinstance(instrument, str):
instrument = doImport(instrument)
instrumentName = instrument.getName() if instrument else None
userQuery = self._verifyInstrumentRestriction(instrumentName, userQuery)

with scaffolding.connectDataIds(self.registry, collections, userQuery) as commonDataIds:
scaffolding.resolveDatasetRefs(self.registry, collections, run, commonDataIds,
skipExisting=self.skipExisting)
return scaffolding.makeQuantumGraph()

@staticmethod
def _verifyInstrumentRestriction(instrumentName, query):
"""Add an instrument restriction to the query if it does not have one,
and verify that if given an instrument name that there are no other
instrument restrictions in the query.

Parameters
----------
instrumentName : `str`
The name of the instrument that should appear in the query.
query : `str`
The query string.

Returns
-------
query : `str`
The query string with the instrument added to it if needed.

Raises
------
RuntimeError
If the pipeline names an instrument and the query contains more
than one instrument or the name of the instrument in the query does
not match the instrument named by the pipeline.
"""
if not instrumentName:
return query
queryInstruments = _findInstruments(query)
if len(queryInstruments) > 1:
raise RuntimeError(f"When the pipeline has an instrument (\"{instrumentName}\") the query must "
"have zero instruments or one instrument that matches the pipeline. "
f"Found these instruments in the query: {queryInstruments}.")
if not queryInstruments:
# There is not an instrument in the query, add it:
restriction = f"instrument = '{instrumentName}'"
_LOG.debug(f"Adding restriction \"{restriction}\" to query.")
query = f"{restriction} AND ({query})"
elif queryInstruments[0] != instrumentName:
# Since there is an instrument in the query, it should match
# the instrument in the pipeline.
raise RuntimeError(f"The instrument named in the query (\"{queryInstruments[0]}\") does not "
f"match the instrument named by the pipeline (\"{instrumentName}\")")
return query
12 changes: 12 additions & 0 deletions python/lsst/pipe/base/pipeline.py
Expand Up @@ -210,6 +210,18 @@ def addInstrument(self, instrument: Union[Instrument, str]):
instrument = f"{instrument.__module__}.{instrument.__qualname__}"
self._pipelineIR.instrument = instrument

def getInstrument(self):
"""Get the instrument from the pipeline.

Returns
-------
instrument : `~lsst.daf.butler.instrument.Instrument`, `str`, or None
A derived class object of a `lsst.daf.butler.instrument`, a string
corresponding to a fully qualified `lsst.daf.butler.instrument`
name, or None if the pipeline does not have an instrument.
"""
return self._pipelineIR.instrument

def addTask(self, task: Union[PipelineTask, str], label: str):
"""Add a new task to the pipeline, or replace a task that is already
associated with the supplied label.
Expand Down
82 changes: 82 additions & 0 deletions tests/test_graphBuilder.py
@@ -0,0 +1,82 @@
# This file is part of pipe_base.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

"""Tests of things related to the GraphBuilder class."""

import unittest

from lsst.pipe.base import GraphBuilder

import lsst.utils.tests


class HelperTestCase(unittest.TestCase):

def testAddInstrument(self):
"""Verify the pipeline instrument is added to the query."""
self.assertEqual(
GraphBuilder._verifyInstrumentRestriction("HSC", "tract = 42"),
"instrument = 'HSC' AND (tract = 42)")

def testQueryContainsInstrument(self):
"""Verify the instrument is found and no further action is taken."""
self.assertEqual(
GraphBuilder._verifyInstrumentRestriction("HSC", "'HSC' = instrument AND tract = 42"),
"'HSC' = instrument AND tract = 42")

def testQueryContainsInstrumentAltOrder(self):
"""Verify instrument is found in a different order, with no further
action."""
self.assertEqual(
GraphBuilder._verifyInstrumentRestriction("HSC", "tract=42 AND instrument='HSC'"),
"tract=42 AND instrument='HSC'")

def testQueryContainsSimilarKey(self):
"""Verify a key that contains "instrument" is not confused for the
actual "instrument" key."""
self.assertEqual(
GraphBuilder._verifyInstrumentRestriction("HSC", "notinstrument=42 AND instrument='HSC'"),
"notinstrument=42 AND instrument='HSC'")

def testNoPipelineInstrument(self):
"""Verify that when no pipeline instrument is passed that the query is
returned unchanged."""
self.assertEqual(
GraphBuilder._verifyInstrumentRestriction(None, "foo=bar"),
"foo=bar")

def testNoPipelineInstrumentTwoQueryInstruments(self):
"""Verify that when no pipeline instrument is passed that the query can
contain two instruments."""
self.assertEqual(
GraphBuilder._verifyInstrumentRestriction(None, "instrument = 'HSC' OR instrument = 'LSSTCam'"),
"instrument = 'HSC' OR instrument = 'LSSTCam'")

def testTwoQueryInstruments(self):
"""Verify that when a pipeline instrument is passed and the query
contains two instruments that a RuntimeError is raised."""
with self.assertRaises(RuntimeError):
GraphBuilder._verifyInstrumentRestriction("HSC", "instrument = 'HSC' OR instrument = 'LSSTCam'")


if __name__ == "__main__":
lsst.utils.tests.init()
unittest.main()