Skip to content

Commit

Permalink
Merge branch 'tickets/DM-12549'
Browse files Browse the repository at this point in the history
  • Loading branch information
kfindeisen committed Mar 24, 2021
2 parents 0f3ea05 + db8fb0d commit 082ad50
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 4 deletions.
2 changes: 1 addition & 1 deletion python/lsst/ctrl/mpexec/execFixupDataId.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def fixupQuanta(self, graph: QuantumGraph) -> QuantumGraph:
taskDef = graph.findTaskDefByLabel(self.taskLabel)
if taskDef is None:
raise ValueError(f"Cannot find task with label {self.taskLabel}")
quanta = list(graph.quantaForTask(taskDef))
quanta = list(graph.getNodesForTask(taskDef))
keyQuanta = defaultdict(list)
for q in quanta:
key = self._key(q)
Expand Down
21 changes: 18 additions & 3 deletions tests/test_executors.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
import psutil
import sys
import time
from types import SimpleNamespace
import unittest
import warnings

Expand Down Expand Up @@ -64,13 +63,25 @@ def getDataIds(self, field):
return [quantum.dataId[field] for quantum in self.quanta]


class QuantumMock:
def __init__(self, dataId):
self.dataId = dataId

def __eq__(self, other):
return self.dataId == other.dataId

def __hash__(self):
# dict.__eq__ is order-insensitive
return hash(sorted(kv for kv in self.dataId.items()))


class QuantumIterDataMock:
"""Simple class to mock QuantumIterData.
"""
def __init__(self, index, taskDef, **dataId):
self.index = index
self.taskDef = taskDef
self.quantum = SimpleNamespace(dataId=dataId)
self.quantum = QuantumMock(dataId)
self.dependencies = set()
self.nodeId = NodeId(index, "DummyBuildString")

Expand All @@ -96,7 +107,11 @@ def findTaskDefByLabel(self, label):
if q.taskDef.label == label:
return q.taskDef

def quantaForTask(self, taskDef):
def getQuantaForTask(self, taskDef):
nodes = self.getNodesForTask(taskDef)
return {q.quantum for q in nodes}

def getNodesForTask(self, taskDef):
quanta = set()
for q in self:
if q.taskDef == taskDef:
Expand Down

0 comments on commit 082ad50

Please sign in to comment.