Skip to content

Commit

Permalink
Merge branch 'tickets/DM-34811'
Browse files Browse the repository at this point in the history
  • Loading branch information
natelust committed May 25, 2022
2 parents 963e35b + 475179c commit c3cc6a3
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 13 deletions.
1 change: 1 addition & 0 deletions doc/changes/DM-34811.bugfix.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fixed a bug where dot graphs of pipelines did not correctly render edges between composite and component dataset types.
52 changes: 43 additions & 9 deletions python/lsst/ctrl/mpexec/dotTools.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,12 @@
# -------------------------------
# Imports of standard modules --
# -------------------------------
import re

# -----------------------------
# Imports for other modules --
# -----------------------------
from lsst.daf.butler import DimensionUniverse
from lsst.daf.butler import DatasetType, DimensionUniverse
from lsst.pipe.base import Pipeline, iterConnections

# ----------------------------------
Expand Down Expand Up @@ -63,7 +64,7 @@ def _renderTaskNode(nodeName, taskDef, file, idx=None):
labels.append(f"index: {idx}")
if taskDef.connections:
# don't print collection of str directly to avoid visually noisy quotes
dimensions_str = ", ".join(taskDef.connections.dimensions)
dimensions_str = ", ".join(sorted(taskDef.connections.dimensions))
labels.append(f"dimensions: {dimensions_str}")
_renderNode(file, nodeName, "task", labels)

Expand All @@ -80,7 +81,7 @@ def _renderDSTypeNode(name, dimensions, file):
"""Render GV node for a dataset type"""
labels = [name]
if dimensions:
labels.append("Dimensions: " + ", ".join(dimensions))
labels.append("Dimensions: " + ", ".join(sorted(dimensions)))
_renderNode(file, name, "dsType", labels)


Expand Down Expand Up @@ -233,34 +234,67 @@ def expand_dimensions(dimensions):
allDatasets = set()
if isinstance(pipeline, Pipeline):
pipeline = pipeline.toExpandedPipeline()
for idx, taskDef in enumerate(pipeline):

# The next two lines are a workaround until DM-29658 at which time metadata
# connections should start working with the above code
labelToTaskName = {}
metadataNodesToLink = set()

for idx, taskDef in enumerate(sorted(pipeline, key=lambda x: x.label)):

# node for a task
taskNodeName = "task{}".format(idx)
_renderTaskNode(taskNodeName, taskDef, file, idx)

for attr in iterConnections(taskDef.connections, "inputs"):
# next line is workaround until DM-29658
labelToTaskName[taskDef.label] = taskNodeName

_renderTaskNode(taskNodeName, taskDef, file, None)

metadataRePattern = re.compile("^(.*)_metadata$")
for attr in sorted(iterConnections(taskDef.connections, "inputs"), key=lambda x: x.name):
if attr.name not in allDatasets:
dimensions = expand_dimensions(attr.dimensions)
_renderDSTypeNode(attr.name, dimensions, file)
allDatasets.add(attr.name)
nodeName, component = DatasetType.splitDatasetTypeName(attr.name)
_renderEdge(attr.name, taskNodeName, file)

for attr in iterConnections(taskDef.connections, "prerequisiteInputs"):
# connect component dataset types to the composite type that
# produced it
if component is not None and (nodeName, attr.name) not in allDatasets:
_renderEdge(nodeName, attr.name, file)
allDatasets.add((nodeName, attr.name))
if nodeName not in allDatasets:
dimensions = expand_dimensions(attr.dimensions)
_renderDSTypeNode(nodeName, dimensions, file)
# The next if block is a workaround until DM-29658 at which time
# metadata connections should start working with the above code
if (match := metadataRePattern.match(attr.name)) is not None:
matchTaskLabel = match.group(1)
metadataNodesToLink.add((matchTaskLabel, attr.name))

for attr in sorted(iterConnections(taskDef.connections, "prerequisiteInputs"), key=lambda x: x.name):
if attr.name not in allDatasets:
dimensions = expand_dimensions(attr.dimensions)
_renderDSTypeNode(attr.name, dimensions, file)
allDatasets.add(attr.name)
# use dashed line for prerequisite edges to distinguish them
_renderEdge(attr.name, taskNodeName, file, style="dashed")

for attr in iterConnections(taskDef.connections, "outputs"):
for attr in sorted(iterConnections(taskDef.connections, "outputs"), key=lambda x: x.name):
if attr.name not in allDatasets:
dimensions = expand_dimensions(attr.dimensions)
_renderDSTypeNode(attr.name, dimensions, file)
allDatasets.add(attr.name)
_renderEdge(taskNodeName, attr.name, file)

# This for loop is a workaround until DM-29658 at which time metadata
# connections should start working with the above code
for matchLabel, dsTypeName in metadataNodesToLink:
# only render an edge to metadata if the label is part of the current
# graph
if (result := labelToTaskName.get(matchLabel)) is not None:
_renderEdge(result, dsTypeName, file)

print("}", file=file)
if close:
file.close()
22 changes: 18 additions & 4 deletions tests/test_dotTools.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,17 +114,24 @@ class DotToolsTestCase(unittest.TestCase):
def testPipeline2dot(self):
"""Tests for dotTools.pipeline2dot method"""
pipeline = _makePipeline(
[("A", ("B", "C"), "task1"), ("C", "E", "task2"), ("B", "D", "task3"), (("D", "E"), "F", "task4")]
[
("A", ("B", "C"), "task0"),
("C", "E", "task1"),
("B", "D", "task2"),
(("D", "E"), "F", "task3"),
("D.C", "G", "task4"),
("task3_metadata", "H", "task5"),
]
)
file = io.StringIO()
pipeline2dot(pipeline, file)

# It's hard to validate complete output, just checking few basic
# things, even that is not terribly stable.
lines = file.getvalue().strip().split("\n")
ndatasets = 6
ntasks = 4
nedges = 10
ndatasets = 10
ntasks = 6
nedges = 16
nextra = 2 # graph header and closing
self.assertEqual(len(lines), ndatasets + ntasks + nedges + nextra)

Expand All @@ -144,6 +151,13 @@ def testPipeline2dot(self):
self.assertEqual(node[0] + node[-1], '""')
continue

# make sure components are connected appropriately
self.assertIn('"D" -> "D.C";', file.getvalue())

# make sure there is a connection created for metadata if someone
# tries to read it in
self.assertIn('"task3" -> "task3_metadata"', file.getvalue())


class MyMemoryTestCase(lsst.utils.tests.MemoryTestCase):
pass
Expand Down

0 comments on commit c3cc6a3

Please sign in to comment.