Skip to content

Commit

Permalink
add instrument to the user query
Browse files Browse the repository at this point in the history
  • Loading branch information
n8pease committed Sep 11, 2020
1 parent 75c765a commit aefa31f
Show file tree
Hide file tree
Showing 3 changed files with 130 additions and 0 deletions.
74 changes: 74 additions & 0 deletions python/lsst/pipe/base/graphBuilder.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from dataclasses import dataclass
from typing import Dict, Iterable, Iterator, List
import logging
import re

# -----------------------------
# Imports for other modules --
Expand All @@ -50,6 +51,7 @@
NamedKeyDict,
Quantum,
)
from lsst.obs.base.utils import getInstrument

# ----------------------------------
# Local non-exported definitions --
Expand Down Expand Up @@ -817,7 +819,79 @@ def makeGraph(self, pipeline, collections, run, userQuery):
classes.
"""
scaffolding = _PipelineScaffolding(pipeline, registry=self.registry)

userQuery = self._addInstrumentToQuery(pipeline, userQuery)

with scaffolding.connectDataIds(self.registry, collections, userQuery) as commonDataIds:
scaffolding.resolveDatasetRefs(self.registry, collections, run, commonDataIds,
skipExisting=self.skipExisting)
return scaffolding.makeQuantumGraph()

def _addInstrumentToQuery(self, pipeline, query):
"""Add an instrument restriction to the query if the pipeline has an
instrument and the query does not already have an instrument. If the
query does have an instrument verify that it matches the instrument in
the pipeline.
Parameters
----------
pipeline : `Pipeline`
Same as the pipeline in `makeGraph`
query : `str`
The initial query string.
Returns
-------
query : `str`
The query string with an instrument restriction checked or added
if needed.
Raises
------
RuntimeError
If the query contains an instrument restriction and the pipeline
has an instrument and the two do not match.
"""
instrument = pipeline.getInstrument()
if instrument is not None:
if isinstance(instrument, str):
instrument = getInstrument(pipeline._pipelineIR.instrument, self.registry)
queryInstrument = self._getInstrumentFromQuery(query)
if queryInstrument is None:
# There is not an instrument in the query, add it:
restriction = f"instrument = '{instrument.getName()}'"
_LOG.info(f"Adding restriction \"{restriction}\" to query.")
query = f"{restriction} AND ({query})"
elif queryInstrument != instrument.getName():
# Since there is an instrument in the query, it should match
# the instrument in the pipeline.
raise RuntimeError(f"The instrument named in the query (\"{queryInstrument}\") does not "
f"match the instrument named by the pipeline (\"{instrument.getName()}\")")
return query

@staticmethod
def _getInstrumentFromQuery(query):
"""Search the provided query string for "instrument=" and if found
return the value that follows.
Search is case insensitive, and allows whitespace after "instrument"
before the equals sign.
Strips any quotes from the name of the instrument.
Parameters
----------
query : `str`
The query string.
Returns
-------
instrumentName : `str` or `None`
The name of the instrument in the query string, or `None` if an
instrument is not named.
"""
instrumentRegex = r"instrument *="
match = re.search(instrumentRegex, query, re.IGNORECASE)
if match is None:
return None
return query[match.end():].split()[0].strip("\"'")
12 changes: 12 additions & 0 deletions python/lsst/pipe/base/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,18 @@ def addInstrument(self, instrument: Union[Instrument, str]):
instrument = f"{instrument.__module__}.{instrument.__qualname__}"
self._pipelineIR.instrument = instrument

def getInstrument(self):
"""Get the instrument from the pipeline.
Returns
-------
instrument : `~lsst.daf.butler.instrument.Instrument`, `str`, or None
A derived class object of a `lsst.daf.butler.instrument`, a string
corresponding to a fully qualified `lsst.daf.butler.instrument`
name, or None if the pipeline does not have an instrument.
"""
return self._pipelineIR.instrument

def addTask(self, task: Union[PipelineTask, str], label: str):
"""Add a new task to the pipeline, or replace a task that is already
associated with the supplied label.
Expand Down
44 changes: 44 additions & 0 deletions tests/test_graphBuilder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# This file is part of pipe_base.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

"""Tests of things related to the GraphBuilder class."""

import unittest

from lsst.pipe.base import GraphBuilder
import lsst.utils.tests


class HelperTestCase(unittest.TestCase):

def test_getInstrument(self):
"""Test getting the instrument from the query."""
queries = (("tract = 42 and instrument = 'HSC'", "HSC"),
("tract=42 and INSTRUMENT='HSC'", "HSC"),
('tract=42 and Instrument = "HSC"', "HSC"),
("tract=42", None))
for query, expected in queries:
self.assertEqual(GraphBuilder._getInstrumentFromQuery(query), expected)


if __name__ == "__main__":
lsst.utils.tests.init()
unittest.main()

0 comments on commit aefa31f

Please sign in to comment.