Skip to content

Commit

Permalink
Add a test for mismatch connection and metadata
Browse files Browse the repository at this point in the history
This test demonstrates that we can have task metadata
output from one task as TaskMetadata and be read into
the next task as a dict.
  • Loading branch information
timj committed Feb 9, 2022
1 parent ea1fd46 commit 56862c2
Showing 1 changed file with 94 additions and 1 deletion.
95 changes: 94 additions & 1 deletion tests/test_simple_pipeline_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,23 @@
import shutil
import tempfile
import unittest
from typing import Any, Dict

import lsst.daf.butler
import lsst.utils.tests
from lsst.ctrl.mpexec import SimplePipelineExecutor
from lsst.pex.config import Field
from lsst.pipe.base import PipelineTaskConfig, PipelineTaskConnections, TaskDef, connectionTypes
from lsst.pipe.base import (
PipelineTask,
PipelineTaskConfig,
PipelineTaskConnections,
Struct,
TaskDef,
TaskMetadata,
connectionTypes,
)
from lsst.pipe.base.tests.no_dimensions import NoDimensionsTestTask
from lsst.utils.introspection import get_full_type_name

TESTDIR = os.path.abspath(os.path.dirname(__file__))

Expand All @@ -49,6 +59,64 @@ class NoDimensionsTestConfig2(PipelineTaskConfig, pipelineConnections=NoDimensio
outputSC = Field(dtype=str, doc="Output storage class requested", default="dict")


class NoDimensionsMetadataTestConnections(PipelineTaskConnections, dimensions=set()):
    """Connections for a dimensionless test task that reads a dict-like
    input plus the ``a_metadata`` dataset and writes a dict-like output.
    """

    input = connectionTypes.Input(
        name="input",
        doc="some dict-y input data for testing",
        storageClass="StructuredDataDict",
    )
    # The storage class below is intentionally different from the
    # TaskMetadata storage class that task metadata uses by default.
    meta = connectionTypes.Input(
        name="a_metadata",
        doc="Metadata from previous task",
        storageClass="StructuredDataDict",
    )
    output = connectionTypes.Output(
        name="output",
        doc="some dict-y output data for testing",
        storageClass="StructuredDataDict",
    )


class NoDimensionsMetadataTestConfig(
    PipelineTaskConfig,
    pipelineConnections=NoDimensionsMetadataTestConnections,
):
    """Configuration whose connections are defined by
    `NoDimensionsMetadataTestConnections`.
    """

    key = Field(
        dtype=str,
        doc="String key for the dict entry the task sets.",
        default="one",
    )
    value = Field(
        dtype=int,
        doc="Integer value for the dict entry the task sets.",
        default=1,
    )
    outputSC = Field(
        dtype=str,
        doc="Output storage class requested",
        default="dict",
    )


class NoDimensionsMetadataTestTask(PipelineTask):
    """A simple pipeline task that can take a metadata as input."""

    ConfigClass = NoDimensionsMetadataTestConfig
    _DefaultName = "noDimensionsMetadataTest"

    def run(self, input: Dict[str, int], meta: Dict[str, Any]) -> Struct:
        """Run the task, adding the configured key-value pair to the input
        argument and returning it as the output.

        Parameters
        ----------
        input : `dict`
            Dictionary to copy, update and return.
        meta : `dict`
            Task metadata received through the ``meta`` connection. It is
            only logged (type and content); it is not modified and does not
            affect the output.

        Returns
        -------
        result : `lsst.pipe.base.Struct`
            Struct with a single ``output`` attribute.
        """
        self.log.info("Run metadata method given data of type: %s", get_full_type_name(input))
        # Work on a copy so that the caller's object is left untouched.
        output = input.copy()
        output[self.config.key] = self.config.value

        # The test inspects this log message to learn which Python type the
        # metadata was delivered as.
        self.log.info("Received task metadata (%s): %s", get_full_type_name(meta), meta)

        # Can change the return type via configuration.
        if "TaskMetadata" in self.config.outputSC:
            output = TaskMetadata.from_dict(output)
        elif isinstance(output, TaskMetadata):
            # Want the output to be a dict
            output = output.to_dict()
        self.log.info("Run method returns data of type: %s", get_full_type_name(output))
        return Struct(output=output)


class SimplePipelineExecutorTests(lsst.utils.tests.TestCase):
"""Test the SimplePipelineExecutor API with a trivial task."""

Expand Down Expand Up @@ -226,6 +294,31 @@ def test_from_pipeline_incompatible(self):
):
executor.run(register_dataset_types=True)

def test_from_pipeline_metadata(self):
    """Test two tasks where the output uses metadata from input.

    Task "a" runs first; task "b" reads both the "intermediate" dataset
    written by "a" and the ``a_metadata`` dataset through a connection
    declared with the ``StructuredDataDict`` storage class. The test
    checks that "b" logs the metadata it received as a plain `dict`.
    """
    # Must configure a special pipeline for this test.
    config_a = NoDimensionsTestTask.ConfigClass()
    config_a.connections.output = "intermediate"
    config_b = NoDimensionsMetadataTestTask.ConfigClass()
    config_b.connections.input = "intermediate"
    config_b.key = "two"
    config_b.value = 2
    task_defs = [
        TaskDef(label="a", taskClass=NoDimensionsTestTask, config=config_a),
        TaskDef(label="b", taskClass=NoDimensionsMetadataTestTask, config=config_b),
    ]
    executor = SimplePipelineExecutor.from_pipeline(task_defs, butler=self.butler)

    # Capture the log output so the type reported by task "b" for the
    # metadata it received can be inspected.
    with self.assertLogs("test_simple_pipeline_executor", level="INFO") as cm:
        quanta = executor.run(register_dataset_types=True)
    self.assertIn(f"Received task metadata ({get_full_type_name(dict)})", "".join(cm.output))

    self.assertEqual(len(quanta), 2)
    self.assertEqual(self.butler.get("intermediate"), {"zero": 0, "one": 1})
    self.assertEqual(self.butler.get("output"), {"zero": 0, "one": 1, "two": 2})

def test_from_pipeline_file(self):
"""Test executing a two quanta from different configurations of the
same task, with an executor created by the `from_pipeline_filename`
Expand Down

0 comments on commit 56862c2

Please sign in to comment.