Skip to content

Commit

Permalink
add instrument to the user query
Browse files Browse the repository at this point in the history
  • Loading branch information
n8pease committed Sep 16, 2020
1 parent 3e73913 commit acf6da9
Show file tree
Hide file tree
Showing 3 changed files with 217 additions and 0 deletions.
123 changes: 123 additions & 0 deletions python/lsst/pipe/base/graphBuilder.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from typing import Dict, Iterable, Iterator, List
import logging


# -----------------------------
# Imports for other modules --
# -----------------------------
Expand All @@ -50,6 +51,8 @@
NamedKeyDict,
Quantum,
)
from lsst.daf.butler.registry.queries.exprParser import ParseError, ParserYacc, TreeVisitor
from lsst.utils import doImport

# ----------------------------------
# Local non-exported definitions --
Expand Down Expand Up @@ -745,6 +748,75 @@ def makeQuantumGraph(self):
return graph


class _InstrumentFinder(TreeVisitor):
    """Implementation of TreeVisitor which looks for instrument names.

    Instrument should be specified as a boolean expression in one of the
    forms::

        instrument = 'string'
        'string' = instrument

    so we only need to find a binary operator where the operator is "=",
    one side is a string literal and the other side is the identifier
    ``instrument`` (matched case-insensitively).

    All visit methods return a tuple of ``(type, value)``; non-useful nodes
    return `None` for both type and value.

    Attributes
    ----------
    instruments : `list`
        String literal values found in ``instrument = '...'`` comparisons,
        appended as the matching binary operators are visited.
    """

    def __init__(self):
        self.instruments = []

    def visitNumericLiteral(self, value, node):
        # Numbers can never name an instrument.
        return (None, None)

    def visitStringLiteral(self, value, node):
        # Report the literal so an enclosing "=" can pair it with the
        # instrument identifier.
        return ("str", value)

    def visitTimeLiteral(self, value, node):
        # Time literals are irrelevant to the search.
        return (None, None)

    def visitRangeLiteral(self, start, stop, stride, node):
        # Ranges are irrelevant to the search.
        return (None, None)

    def visitIdentifier(self, name, node):
        # Only the "instrument" identifier is interesting; the comparison is
        # case-insensitive and the reported value is always lower case.
        if name.lower() == "instrument":
            return ("id", "instrument")
        return (None, None)

    def visitUnaryOp(self, operator, operand, node):
        # Unary operators never produce an instrument restriction.
        return (None, None)

    def visitBinaryOp(self, operator, lhs, rhs, node):
        # Record the string side of "instrument = 'x'" / "'x' = instrument".
        if operator == "=":
            if lhs == ("id", "instrument") and rhs[0] == "str":
                self.instruments.append(rhs[1])
            elif rhs == ("id", "instrument") and lhs[0] == "str":
                self.instruments.append(lhs[1])
        # The comparison itself is not an operand of interest to its parent.
        return (None, None)

    def visitIsIn(self, lhs, values, not_in, node):
        # IN / NOT IN expressions are not treated as instrument restrictions.
        return (None, None)

    def visitParens(self, expression, node):
        # Parentheses do not change what an operand is, so pass the inner
        # (type, value) result through; this lets "instrument = ('HSC')" and
        # "(instrument) = 'HSC'" be recognized too.  Anything that is not a
        # (type, value) tuple is reported as a non-useful node so that
        # visitBinaryOp can always index its operands safely.
        if isinstance(expression, tuple) and len(expression) == 2:
            return expression
        return (None, None)


def _findInstruments(queryStr):
    """Get the names of any instruments named in a query expression.

    The expression is parsed and searched for ``instrument = '<value>'``
    (or ``'<value>' = instrument``) comparisons.

    Parameters
    ----------
    queryStr : `str` or `None`
        The query expression to search.

    Returns
    -------
    instruments : `list`
        The string values compared against ``instrument`` in the query;
        empty if there are none or if the query is empty.

    Raises
    ------
    ValueError
        Raised if the query expression cannot be parsed.
    """
    # An empty (or missing) query cannot name an instrument; do not hand it
    # to the parser at all.
    if not queryStr:
        return []
    parser = ParserYacc()
    try:
        tree = parser.parse(queryStr)
    except ParseError as exc:
        raise ValueError(f"failed to parse query expression: {queryStr}") from exc
    # NOTE(review): the parser presumably yields no tree for blank input
    # (e.g. whitespace only); guard so we never call ``visit`` on `None`.
    if tree is None:
        return []
    finder = _InstrumentFinder()
    tree.visit(finder)
    return finder.instruments


# ------------------------
# Exported definitions --
# ------------------------
Expand Down Expand Up @@ -817,7 +889,58 @@ def makeGraph(self, pipeline, collections, run, userQuery):
classes.
"""
scaffolding = _PipelineScaffolding(pipeline, registry=self.registry)

instrument = pipeline.getInstrument()
if isinstance(instrument, str):
instrument = doImport(instrument)
instrumentName = instrument.getName() if instrument else None
userQuery = self._verifyInstrumentRestriction(instrumentName, userQuery)

with scaffolding.connectDataIds(self.registry, collections, userQuery) as commonDataIds:
scaffolding.resolveDatasetRefs(self.registry, collections, run, commonDataIds,
skipExisting=self.skipExisting)
return scaffolding.makeQuantumGraph()

@staticmethod
def _verifyInstrumentRestriction(instrumentName, query):
    """Add an instrument restriction to the query if it does not have one,
    and verify that if given an instrument name that there are no other
    instrument restrictions in the query.

    Parameters
    ----------
    instrumentName : `str` or `None`
        The name of the instrument that should appear in the query. If
        empty or `None` the query is returned unchanged.
    query : `str` or `None`
        The query string. May be empty or `None`, in which case the
        returned query consists of the instrument restriction alone.

    Returns
    -------
    query : `str`
        The query string with the instrument added to it if needed.

    Raises
    ------
    RuntimeError
        If the pipeline names an instrument and the query contains more
        than one instrument or the name of the instrument in the query does
        not match the instrument named by the pipeline.
    """
    if not instrumentName:
        return query
    restriction = f"instrument = '{instrumentName}'"
    if query is None or not query.strip():
        # There is no user query at all: wrapping it as "AND ()" would
        # produce an invalid expression, so the restriction is the query.
        _LOG.debug(f"Adding restriction \"{restriction}\" to query.")
        return restriction
    queryInstruments = _findInstruments(query)
    if len(queryInstruments) > 1:
        raise RuntimeError(f"When the pipeline has an instrument (\"{instrumentName}\") the query must "
                           "have zero instruments or one instrument that matches the pipeline. "
                           f"Found these instruments in the query: {queryInstruments}.")
    if not queryInstruments:
        # There is not an instrument in the query, add it. The original
        # query is parenthesized so that the restriction applies to all of
        # it regardless of operator precedence.
        _LOG.debug(f"Adding restriction \"{restriction}\" to query.")
        query = f"{restriction} AND ({query})"
    elif queryInstruments[0] != instrumentName:
        # Since there is an instrument in the query, it should match
        # the instrument in the pipeline.
        raise RuntimeError(f"The instrument named in the query (\"{queryInstruments[0]}\") does not "
                           f"match the instrument named by the pipeline (\"{instrumentName}\")")
    return query
12 changes: 12 additions & 0 deletions python/lsst/pipe/base/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,18 @@ def addInstrument(self, instrument: Union[Instrument, str]):
instrument = f"{instrument.__module__}.{instrument.__qualname__}"
self._pipelineIR.instrument = instrument

def getInstrument(self):
    """Return the instrument recorded for this pipeline.

    Returns
    -------
    instrument : `~lsst.daf.butler.instrument.Instrument`, `str`, or None
        Whatever is stored as the instrument on the pipeline's internal
        representation: a derived class object of a
        `lsst.daf.butler.instrument`, a string corresponding to a fully
        qualified `lsst.daf.butler.instrument` name, or None if the
        pipeline does not have an instrument.
    """
    pipelineIR = self._pipelineIR
    return pipelineIR.instrument

def addTask(self, task: Union[PipelineTask, str], label: str):
"""Add a new task to the pipeline, or replace a task that is already
associated with the supplied label.
Expand Down
82 changes: 82 additions & 0 deletions tests/test_graphBuilder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
# This file is part of pipe_base.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

"""Tests of things related to the GraphBuilder class."""

import unittest

from lsst.pipe.base import GraphBuilder

import lsst.utils.tests


class HelperTestCase(unittest.TestCase):
    """Checks of `GraphBuilder._verifyInstrumentRestriction`."""

    def testAddInstrument(self):
        """A query with no instrument gets the pipeline instrument added."""
        result = GraphBuilder._verifyInstrumentRestriction("HSC", "tract = 42")
        self.assertEqual(result, "instrument = 'HSC' AND (tract = 42)")

    def testQueryContainsInstrument(self):
        """A query already naming the instrument is returned untouched."""
        query = "'HSC' = instrument AND tract = 42"
        result = GraphBuilder._verifyInstrumentRestriction("HSC", query)
        self.assertEqual(result, "'HSC' = instrument AND tract = 42")

    def testQueryContainsInstrumentAltOrder(self):
        """The instrument is also found when it is not the first term, and
        the query is returned untouched."""
        query = "tract=42 AND instrument='HSC'"
        result = GraphBuilder._verifyInstrumentRestriction("HSC", query)
        self.assertEqual(result, "tract=42 AND instrument='HSC'")

    def testQueryContainsSimilarKey(self):
        """A key that merely contains "instrument" is not mistaken for the
        actual "instrument" key."""
        query = "notinstrument=42 AND instrument='HSC'"
        result = GraphBuilder._verifyInstrumentRestriction("HSC", query)
        self.assertEqual(result, "notinstrument=42 AND instrument='HSC'")

    def testNoPipelineInstrument(self):
        """Without a pipeline instrument the query comes back unchanged."""
        result = GraphBuilder._verifyInstrumentRestriction(None, "foo=bar")
        self.assertEqual(result, "foo=bar")

    def testNoPipelineInstrumentTwoQueryInstruments(self):
        """Without a pipeline instrument the query is allowed to name two
        instruments."""
        query = "instrument = 'HSC' OR instrument = 'LSSTCam'"
        result = GraphBuilder._verifyInstrumentRestriction(None, query)
        self.assertEqual(result, "instrument = 'HSC' OR instrument = 'LSSTCam'")

    def testTwoQueryInstruments(self):
        """With a pipeline instrument, a query naming two instruments
        raises RuntimeError."""
        query = "instrument = 'HSC' OR instrument = 'LSSTCam'"
        with self.assertRaises(RuntimeError):
            GraphBuilder._verifyInstrumentRestriction("HSC", query)


# Script entry point: when this file is executed directly, call
# lsst.utils.tests.init() and then hand control to unittest's runner.
if __name__ == "__main__":
lsst.utils.tests.init()
unittest.main()

0 comments on commit acf6da9

Please sign in to comment.