Skip to content

Commit

Permalink
Add method to get QuantumNodes by TaskDef
Browse files Browse the repository at this point in the history
Add a method to the QuantumGraph object that returns all the
QuantumNodes associated with a TaskDef.
  • Loading branch information
natelust committed Mar 22, 2021
1 parent 560fb7b commit 9d7063c
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 1 deletion.
18 changes: 18 additions & 0 deletions python/lsst/pipe/base/graph/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ def _buildGraphs(self,

nodeNumberGenerator = count()
self._nodeIdMap: Dict[NodeId, QuantumNode] = {}
self._taskToQuantumNode: DefaultDict[TaskDef, Set[QuantumNode]] = defaultdict(set)
self._count = 0
for taskDef, quantumSet in self._quanta.items():
connections = taskDef.connections
Expand Down Expand Up @@ -134,6 +135,7 @@ def _buildGraphs(self,
inits = quantum.initInputs.values()
inputs = quantum.inputs.values()
value = QuantumNode(quantum, taskDef, nodeId)
self._taskToQuantumNode[taskDef].add(value)
self._nodeIdMap[nodeId] = value

for dsRef in chain(inits, inputs):
Expand Down Expand Up @@ -268,6 +270,22 @@ def getQuantaForTask(self, taskDef: TaskDef) -> FrozenSet[Quantum]:
"""
return frozenset(self._quanta[taskDef])

def getNodesForTask(self, taskDef: TaskDef) -> FrozenSet[QuantumNode]:
"""Return all the `QuantumNodes` associated with a `TaskDef`.
Parameters
----------
taskDef : `TaskDef`
The `TaskDef` for which `Quantum` are to be queried
Returns
-------
frozenset of `QuantumNodes`
The `frozenset` of `QuantumNodes` that is associated with the
specified `TaskDef`.
"""
return frozenset(self._taskToQuantumNode[taskDef])

def findTasksWithInput(self, datasetTypeName: DatasetTypeName) -> Iterable[TaskDef]:
"""Find all tasks that have the specified dataset type name as an
input.
Expand Down
9 changes: 8 additions & 1 deletion tests/test_quantumGraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import os
import pickle
import tempfile
from typing import Iterable
import unittest
import random
from lsst.daf.butler import DimensionUniverse
Expand All @@ -32,7 +33,7 @@
import lsst.pipe.base.connectionTypes as cT
from lsst.daf.butler import Quantum, DatasetRef, DataCoordinate, DatasetType, Config
from lsst.pex.config import Field
from lsst.pipe.base.graph.quantumNode import NodeId, BuildId
from lsst.pipe.base.graph.quantumNode import NodeId, BuildId, QuantumNode
import lsst.utils.tests

try:
Expand Down Expand Up @@ -251,6 +252,12 @@ def testGetQuantaForTask(self):
for task in self.tasks:
self.assertEqual(self.qGraph.getQuantaForTask(task), self.quantumMap[task])

def testGetNodesForTask(self):
for task in self.tasks:
nodes: Iterable[QuantumNode] = self.qGraph.getNodesForTask(task)
quanta_in_node = set(n.quantum for n in nodes)
self.assertEqual(quanta_in_node, self.quantumMap[task])

def testFindTasksWithInput(self):
self.assertEqual(tuple(self.qGraph.findTasksWithInput(DatasetTypeName("Dummy1Output")))[0],
self.tasks[1])
Expand Down

0 comments on commit 9d7063c

Please sign in to comment.