Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DM-26774: Instrument-finding code incorrectly requires a data query #145

Merged
merged 8 commits into from
Sep 24, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
22 changes: 21 additions & 1 deletion python/lsst/pipe/base/graphBuilder.py
Original file line number Diff line number Diff line change
Expand Up @@ -805,6 +805,26 @@ def visitParens(self, expression, node):


def _findInstruments(queryStr):
"""Get the names of any instrument named in the query string by searching
for "instrument = <value>" and similar patterns.

Parameters
----------
queryStr : `str` or None
The query string to search, or None if there is no query.

Returns
-------
instruments : `list` [`str`]
The list of instrument names found in the query.

Raises
------
ValueError
If the query expression can not be parsed.
"""
if not queryStr:
return []
parser = ParserYacc()
finder = _InstrumentFinder()
try:
Expand Down Expand Up @@ -935,7 +955,7 @@ def _verifyInstrumentRestriction(instrumentName, query):
# There is not an instrument in the query, add it:
restriction = f"instrument = '{instrumentName}'"
_LOG.debug(f"Adding restriction \"{restriction}\" to query.")
query = f"{restriction} AND ({query})"
query = f"{restriction} AND ({query})" if query else restriction # (there may not be a query)
elif queryInstruments[0] != instrumentName:
# Since there is an instrument in the query, it should match
# the instrument in the pipeline.
Expand Down
Empty file.
276 changes: 276 additions & 0 deletions python/lsst/pipe/base/tests/simpleQGraph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,276 @@
# This file is part of pipe_base.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

"""Bunch of common classes and methods for use in unit tests.
"""

__all__ = ["AddTaskConfig", "AddTask", "AddTaskFactoryMock"]

import itertools
import logging
import numpy

from lsst.daf.butler import (Butler, Config, DatasetType, CollectionSearch)
import lsst.daf.butler.tests as butlerTests
import lsst.pex.config as pexConfig
from ... import base as pipeBase
from .. import connectionTypes as cT

_LOG = logging.getLogger(__name__)


# SimpleInstrument has an instrument-like API as needed for unit testing, but
# it cannot explicitly depend on Instrument because pipe_base does not
# explicitly depend on obs_base.
class SimpleInstrument:
    """Minimal stand-in offering the instrument-like methods used in tests."""

    @staticmethod
    def getName():
        """Return the fixed name of this instrument.

        Returns
        -------
        name : `str`
            Always ``"SimpleInstrument"``.
        """
        return "SimpleInstrument"

    def applyConfigOverrides(self, name, config):
        """Accept a name and a config and leave both untouched.

        Parameters
        ----------
        name : `str`
            Ignored.
        config : `object`
            Ignored; it is not modified.
        """
        return None


class AddTaskConnections(
    pipeBase.PipelineTaskConnections,
    dimensions=("instrument", "detector"),
    defaultTemplates={"in_tmpl": "_in", "out_tmpl": "_out"},
):
    """Connections used by AddTask: a single input, two regular outputs and
    one init-output.

    The dataset type names are built from the ``in_tmpl`` and ``out_tmpl``
    templates, which default to ``"_in"`` and ``"_out"``.
    """

    input = cT.Input(
        name="add_dataset{in_tmpl}",
        dimensions=["instrument", "detector"],
        storageClass="NumpyArray",
        doc="Input dataset type for this task",
    )
    output = cT.Output(
        name="add_dataset{out_tmpl}",
        dimensions=["instrument", "detector"],
        storageClass="NumpyArray",
        doc="Output dataset type for this task",
    )
    output2 = cT.Output(
        name="add2_dataset{out_tmpl}",
        dimensions=["instrument", "detector"],
        storageClass="NumpyArray",
        doc="Output dataset type for this task",
    )
    initout = cT.InitOutput(
        name="add_init_output{out_tmpl}",
        storageClass="NumpyArray",
        doc="Init Output dataset type for this task",
    )


class AddTaskConfig(
    pipeBase.PipelineTaskConfig,
    pipelineConnections=AddTaskConnections,
):
    """Configuration for AddTask; it holds the integer added by the task."""

    addend = pexConfig.Field(
        doc="amount to add",
        dtype=int,
        default=3,
    )


class AddTask(pipeBase.PipelineTask):
    """Trivial PipelineTask used in tests; it adds a configured constant to
    its input and carries a few extras that specific unit tests rely on.
    """

    ConfigClass = AddTaskConfig
    _DefaultName = "add_task"

    initout = numpy.array([999])
    """InitOutputs for this task"""

    taskFactory = None
    """Factory that makes instances"""

    def run(self, input):
        """Add the configured addend to the input once and then twice.

        Parameters
        ----------
        input : `object`
            Value supporting ``+`` with the configured integer addend.

        Returns
        -------
        result : `~lsst.pipe.base.Struct`
            Struct with ``output`` (``input + addend``) and ``output2``
            (``output + addend``).

        Raises
        ------
        RuntimeError
            Raised if a task factory is attached and its ``stopAt`` equals
            its current ``countExec``.
        """
        factory = self.taskFactory
        if factory:
            # Bookkeeping for the tests: fail on the requested call,
            # otherwise count this execution.
            if factory.stopAt == factory.countExec:
                raise RuntimeError("pretend something bad happened")
            factory.countExec += 1

        addend = self.config.addend
        self.metadata.add("add", addend)
        first = input + addend
        second = first + addend
        _LOG.info("input = %s, output = %s, output2 = %s", input, first, second)
        return pipeBase.Struct(output=first, output2=second)


class AddTaskFactoryMock(pipeBase.TaskFactory):
    """Task factory for tests that only knows how to build AddTask.

    It also keeps the counters that AddTask updates so that unit tests can
    follow progress.

    Parameters
    ----------
    stopAt : `int`, optional
        Value of ``countExec`` at which ``AddTask.run`` raises an exception.
    """

    def __init__(self, stopAt=-1):
        # AddTask raises an exception at this call to run()
        self.stopAt = stopAt
        # incremented by AddTask
        self.countExec = 0

    def loadTaskClass(self, taskName):
        """Return ``(AddTask, "AddTask")`` for the name ``"AddTask"`` and
        `None` for any other name.
        """
        if taskName != "AddTask":
            return None
        return AddTask, "AddTask"

    def makeTask(self, taskClass, config, overrides, butler):
        """Instantiate ``taskClass`` and attach this factory to the instance.

        When ``config`` is `None` a default ``taskClass.ConfigClass()`` is
        created and ``overrides``, if any, are applied to it. ``butler`` is
        not used.
        """
        if config is None:
            config = taskClass.ConfigClass()
            if overrides:
                overrides.applyTo(config)
        instance = taskClass(config=config, initInputs=None)
        instance.taskFactory = self
        return instance


def registerDatasetTypes(registry, pipeline):
    """Register every dataset type used by the tasks of a pipeline.

    Copied and modified from `PreExecInit.initializeDatasetTypes`.

    Parameters
    ----------
    registry : `~lsst.daf.butler.Registry`
        Registry instance that receives the dataset types.
    pipeline : `typing.Iterable` of `TaskDef`
        Iterable of TaskDef instances, likely the output of the method
        toExpandedPipeline on a `~lsst.pipe.base.Pipeline` object.
    """
    for taskDef in pipeline:
        # Per-task config dataset plus the shared "packages" dataset.
        extraTypes = [
            DatasetType(taskDef.configDatasetName, {},
                        storageClass="Config",
                        universe=registry.dimensions),
            DatasetType("packages", {},
                        storageClass="Packages",
                        universe=registry.dimensions),
        ]
        taskTypes = pipeBase.TaskDatasetTypes.fromTaskDef(taskDef, registry=registry)
        everyType = itertools.chain(
            taskTypes.initInputs,
            taskTypes.initOutputs,
            taskTypes.inputs,
            taskTypes.outputs,
            taskTypes.prerequisites,
            extraTypes,
        )
        for datasetType in everyType:
            _LOG.info("Registering %s with registry", datasetType)
            # Components must be skipped.
            if datasetType.isComponent():
                continue
            # Registration is a no-op if the type already exists and is
            # consistent, and it raises if it is inconsistent.
            registry.registerDatasetType(datasetType)


def makeSimplePipeline(nQuanta, instrument=None):
    """Build a simple Pipeline for use in tests.

    ``makeSimpleQGraph`` calls this when it is not given a pipeline. It can
    also be called directly to customize the pipeline, whose result is then
    passed to ``makeSimpleQGraph``.

    Parameters
    ----------
    nQuanta : `int`
        Number of ``AddTask`` tasks added to the pipeline.
    instrument : `str` or `None`, optional
        The importable name of an instrument to be added to the pipeline;
        pass an empty string or `None` (the default) to add no instrument.

    Returns
    -------
    pipeline : `~lsst.pipe.base.Pipeline`
        The created pipeline object.
    """
    pipeline = pipeBase.Pipeline("test pipeline")
    # Chain the tasks through their dataset names (output template of task N
    # equals the input template of task N+1) so they execute in a well
    # defined order.
    for level in range(nQuanta):
        label = f"task{level}"
        pipeline.addTask(AddTask, label)
        pipeline.addConfigOverride(label, "connections.in_tmpl", str(level))
        pipeline.addConfigOverride(label, "connections.out_tmpl", str(level + 1))
    if instrument:
        pipeline.addInstrument(instrument)
    return pipeline


def makeSimpleQGraph(nQuanta=5, pipeline=None, butler=None, root=None, skipExisting=False, inMemory=True,
                     userQuery=""):
    """Make simple QuantumGraph for tests.

    Makes a simple pipeline of ``AddTask`` tasks (unless one is given), sets
    up a test repository and butler (unless one is given), fills them with
    minimal data, and generates a QuantumGraph with all of that.

    Parameters
    ----------
    nQuanta : `int`, optional
        Number of tasks in the default pipeline; only used if ``pipeline``
        is `None`. Defaults to 5.
    pipeline : `~lsst.pipe.base.Pipeline`, optional
        If `None` then a pipeline is made with ``makeSimplePipeline`` using
        ``nQuanta`` `AddTask` tasks and the default `AddTaskConfig`.
    butler : `~lsst.daf.butler.Butler`, optional
        Data butler instance, this should be an instance returned from a
        previous call to this method. If `None` a new repository is created
        and populated.
    root : `str`, optional
        Path or URI to the root location of the new repository. Only used
        (and then required) if ``butler`` is `None`.
    skipExisting : `bool`, optional
        If `True`, a Quantum is not created if all its outputs already
        exist. Defaults to `False`.
    inMemory : `bool`, optional
        If `True` (default) the repository is made with the default test
        configuration; if `False` an SQLite registry under ``root`` and a
        POSIX datastore are configured instead.
    userQuery : `str`, optional
        The user query to pass to ``makeGraph``, by default an empty string.

    Returns
    -------
    butler : `~lsst.daf.butler.Butler`
        Butler instance
    qgraph : `~lsst.pipe.base.QuantumGraph`
        Quantum graph instance

    Raises
    ------
    ValueError
        Raised if both ``butler`` and ``root`` are `None`.
    """

    if pipeline is None:
        pipeline = makeSimplePipeline(nQuanta=nQuanta)

    if butler is None:

        if root is None:
            raise ValueError("Must provide `root` when `butler` is None")

        config = Config()
        if not inMemory:
            config["registry", "db"] = f"sqlite:///{root}/gen3.sqlite"
            config["datastore", "cls"] = "lsst.daf.butler.datastores.posixDatastore.PosixDatastore"
        repo = butlerTests.makeTestRepo(root, {}, config=config)
        collection = "test"
        butler = Butler(butler=repo, run=collection)

        # Add dataset types to registry
        registerDatasetTypes(butler.registry, pipeline.toExpandedPipeline())

        # Add all needed dimensions to registry
        butler.registry.insertDimensionData("instrument", dict(name="INSTR"))
        butler.registry.insertDimensionData("detector", dict(instrument="INSTR", id=0, full_name="det0"))

        # Add inputs to butler
        data = numpy.array([0., 1., 2., 5.])
        butler.put(data, "add_dataset0", instrument="INSTR", detector=0)

    # Make the graph
    builder = pipeBase.GraphBuilder(registry=butler.registry, skipExisting=skipExisting)
    qgraph = builder.makeGraph(
        pipeline,
        collections=CollectionSearch.fromExpression(butler.run),
        run=butler.run,
        userQuery=userQuery
    )

    return butler, qgraph
41 changes: 39 additions & 2 deletions tests/test_graphBuilder.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,18 @@

"""Tests of things related to the GraphBuilder class."""

import logging
import unittest

from lsst.pipe.base import GraphBuilder

from lsst.pipe.base.tests import simpleQGraph
import lsst.utils.tests
from lsst.utils.tests import temporaryDirectory

_LOG = logging.getLogger(__name__)


class HelperTestCase(unittest.TestCase):
class VerifyInstrumentRestrictionTestCase(unittest.TestCase):

def testAddInstrument(self):
"""Verify the pipeline instrument is added to the query."""
Expand Down Expand Up @@ -76,6 +80,39 @@ def testTwoQueryInstruments(self):
with self.assertRaises(RuntimeError):
GraphBuilder._verifyInstrumentRestriction("HSC", "instrument = 'HSC' OR instrument = 'LSSTCam'")

def testNoQuery(self):
    """An empty query becomes just the instrument restriction."""
    restricted = GraphBuilder._verifyInstrumentRestriction("HSC", "")
    self.assertEqual(restricted, "instrument = 'HSC'")

def testNoQueryNoInstruments(self):
    """With neither an instrument nor a query the result is an empty
    query."""
    restricted = GraphBuilder._verifyInstrumentRestriction("", "")
    self.assertEqual(restricted, "")


class GraphBuilderTestCase(unittest.TestCase):
    """Tests that run GraphBuilder through the makeSimpleQGraph helper."""

    def testDefault(self):
        """Check that makeSimpleQGraph produces a Quantum Graph with its
        default arguments."""
        with temporaryDirectory() as root:
            # GraphBuilder is invoked inside makeSimpleQGraph.
            _, qgraph = simpleQGraph.makeSimpleQGraph(root=root)
            # The default graph has 5 nodes.
            self.assertEqual(len(qgraph), 5)

    def testAddInstrumentMismatch(self):
        """Check that a RuntimeError is raised when the user query names an
        instrument different from the one in the pipeline."""
        instrumentName = "lsst.pipe.base.tests.simpleQGraph.SimpleInstrument"
        with temporaryDirectory() as root:
            pipeline = simpleQGraph.makeSimplePipeline(nQuanta=5, instrument=instrumentName)
            with self.assertRaises(RuntimeError):
                simpleQGraph.makeSimpleQGraph(
                    root=root,
                    pipeline=pipeline,
                    userQuery="instrument = 'foo'",
                )


if __name__ == "__main__":
lsst.utils.tests.init()
Expand Down