Skip to content

Commit

Permalink
add tests for GraphBuilder
Browse files Browse the repository at this point in the history
  • Loading branch information
n8pease committed Sep 23, 2020
1 parent a9010f5 commit b1b024e
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 10 deletions.
62 changes: 53 additions & 9 deletions python/lsst/pipe/base/tests/simpleQGraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,19 @@
_LOG = logging.getLogger(__name__)


# SimpleInstrument has an instrument-like API as needed for unit testing, but
# can not explicitly depend on Instrument because pipe_base does not explicitly
# depend on obs_base.
class SimpleInstrument:

@staticmethod
def getName():
return "SimpleInstrument"

def applyConfigOverrides(self, name, config):
pass


class AddTaskConnections(pipeBase.PipelineTaskConnections,
dimensions=("instrument", "detector"),
defaultTemplates={"in_tmpl": "_in", "out_tmpl": "_out"}):
Expand Down Expand Up @@ -153,7 +166,42 @@ def registerDatasetTypes(registry, pipeline):
registry.registerDatasetType(datasetType)


def makeSimpleQGraph(nQuanta=5, pipeline=None, butler=None, root=None, skipExisting=False, inMemory=True):
def makeSimplePipeline(nQuanta, instrument=None):
"""Make a simple Pipeline for tests.
This is called by ``makeSimpleQGraph`` if no pipeline is passed to that
function. It can also be used to customize the pipeline used by
``makeSimpleQGraph`` function by calling this first and passing the result
to it.
Parameters
----------
nQuanta : `int`
The number of quanta to add to the pipeline.
instrument : `str` or `None`, optional
The importable name of an instrument to be added to the pipeline or
if no instrument should be added then an empty string or `None`, by
default None
Returns
-------
pipeline : `~lsst.pipe.base.Pipeline`
The created pipeline object.
"""
pipeline = pipeBase.Pipeline("test pipeline")
# make a bunch of tasks that execute in well defined order (via data
# dependencies)
for lvl in range(nQuanta):
pipeline.addTask(AddTask, f"task{lvl}")
pipeline.addConfigOverride(f"task{lvl}", "connections.in_tmpl", f"{lvl}")
pipeline.addConfigOverride(f"task{lvl}", "connections.out_tmpl", f"{lvl+1}")
if instrument:
pipeline.addInstrument(instrument)
return pipeline


def makeSimpleQGraph(nQuanta=5, pipeline=None, butler=None, root=None, skipExisting=False, inMemory=True,
userQuery=""):
"""Make simple QuantumGraph for tests.
Makes simple one-task pipeline with AddTask, sets up in-memory
Expand All @@ -178,6 +226,8 @@ def makeSimpleQGraph(nQuanta=5, pipeline=None, butler=None, root=None, skipExist
already exist.
inMemory : `bool`, optional
If true make in-memory repository.
userQuery : `str`, optional
The user query to pass to ``makeGraph``, by default an empty string.
Returns
-------
Expand All @@ -188,13 +238,7 @@ def makeSimpleQGraph(nQuanta=5, pipeline=None, butler=None, root=None, skipExist
"""

if pipeline is None:
pipeline = pipeBase.Pipeline("test pipeline")
# make a bunch of tasks that execute in well defined order (via data
# dependencies)
for lvl in range(nQuanta):
pipeline.addTask(AddTask, f"task{lvl}")
pipeline.addConfigOverride(f"task{lvl}", "connections.in_tmpl", f"{lvl}")
pipeline.addConfigOverride(f"task{lvl}", "connections.out_tmpl", f"{lvl+1}")
pipeline = makeSimplePipeline(nQuanta=nQuanta)

if butler is None:

Expand Down Expand Up @@ -226,7 +270,7 @@ def makeSimpleQGraph(nQuanta=5, pipeline=None, butler=None, root=None, skipExist
pipeline,
collections=CollectionSearch.fromExpression(butler.run),
run=butler.run,
userQuery=""
userQuery=userQuery
)

return butler, qgraph
30 changes: 29 additions & 1 deletion tests/test_graphBuilder.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,15 @@

"""Tests of things related to the GraphBuilder class."""

import logging
import unittest

from lsst.pipe.base import GraphBuilder

from lsst.pipe.base.tests import simpleQGraph
import lsst.utils.tests
from lsst.utils.tests import temporaryDirectory

_LOG = logging.getLogger(__name__)


class VerifyInstrumentRestrictionTestCase(unittest.TestCase):
Expand Down Expand Up @@ -86,6 +90,30 @@ def testNoQueryNoInstruments(self):
self.assertEqual(GraphBuilder._verifyInstrumentRestriction("", ""), "")


class GraphBuilderTestCase(unittest.TestCase):

def testDefault(self):
"""Simple test to verify makeSimpleQGraph can be used to make a Quantum
Graph."""
with temporaryDirectory() as root:
# makeSimpleQGraph calls GraphBuilder.
butler, qgraph = simpleQGraph.makeSimpleQGraph(root=root)
# by default makeSimpleQGraph makes a graph with 5 nodes
self.assertEqual(len(qgraph), 5)

def testAddInstrumentMismatch(self):
"""Verify that a RuntimeError is raised if the instrument in the user
query does not match the instrument in the pipeline."""
with temporaryDirectory() as root:
pipeline = simpleQGraph.makeSimplePipeline(
nQuanta=5,
instrument="lsst.pipe.base.tests.simpleQGraph.SimpleInstrument")
with self.assertRaises(RuntimeError):
simpleQGraph.makeSimpleQGraph(root=root,
pipeline=pipeline,
userQuery="instrument = 'foo'")


if __name__ == "__main__":
lsst.utils.tests.init()
unittest.main()

0 comments on commit b1b024e

Please sign in to comment.