Skip to content

Commit

Permalink
Update code to support new QuantumGraph type
Browse files Browse the repository at this point in the history
  • Loading branch information
natelust committed Sep 11, 2020
1 parent 18fd799 commit 608d5ef
Show file tree
Hide file tree
Showing 7 changed files with 118 additions and 134 deletions.
9 changes: 5 additions & 4 deletions python/lsst/ctrl/mpexec/cmdLineFwk.py
Original file line number Diff line number Diff line change
Expand Up @@ -636,7 +636,8 @@ def makeGraph(self, pipeline, args):
qgraph.save(pickleFile)

if args.save_single_quanta:
for iq, sqgraph in enumerate(qgraph.quantaAsQgraph()):
for iq, quantumNode in enumerate(qgraph):
sqgraph = qgraph.subset(quantumNode)
filename = args.save_single_quanta.format(iq)
with open(filename, "wb") as pickleFile:
sqgraph.save(pickleFile)
Expand Down Expand Up @@ -849,10 +850,10 @@ def _showGraph(self, graph):
graph : `QuantumGraph`
Execution graph.
"""
for taskNodes in graph:
print(taskNodes.taskDef)
for taskNode in graph.taskGraph:
print(taskNode)

for iq, quantum in enumerate(taskNodes.quanta):
for iq, quantum in enumerate(graph.quantaForTask(taskNode)):
print(" Quantum {}:".format(iq))
print(" inputs:")
for key, refs in quantum.predictedInputs.items():
Expand Down
52 changes: 12 additions & 40 deletions python/lsst/ctrl/mpexec/execFixupDataId.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@

__all__ = ['ExecutionGraphFixup']

from typing import Any, Iterable, Sequence, Tuple, Union
from typing import Sequence, Union

from lsst.pipe.base import QuantumIterData
from lsst.pipe.base import QuantumGraph
from .executionGraphFixup import ExecutionGraphFixup


Expand Down Expand Up @@ -60,52 +60,24 @@ def assoc_fixup():
dimensions : `str` or sequence [`str`]
One or more dimension names, quanta execution will be ordered
according to values of these dimensions.
reverse : `bool`, optional
If `False` (default) then quanta with higher values of dimensions
will be executed after quanta with lower values, otherwise the order
is reversed.
"""

def __init__(self, taskLabel: str, dimensions: Union[str, Sequence[str]], reverse: bool = False):
self.taskLabel = taskLabel
self.dimensions = dimensions
self.reverse = reverse
if isinstance(self.dimensions, str):
self.dimensions = (self.dimensions, )
else:
self.dimensions = tuple(self.dimensions)

def _key(self, qdata: QuantumIterData) -> Tuple[Any, ...]:
"""Produce comparison key for quantum data.
Parameters
----------
qdata : `QuantumIterData`
Returns
-------
key : `tuple`
"""
dataId = qdata.quantum.dataId
key = tuple(dataId[dim] for dim in self.dimensions)
return key

def fixupQuanta(self, quanta: Iterable[QuantumIterData]) -> Iterable[QuantumIterData]:
# Docstring inherited from ExecutionGraphFixup.fixupQuanta
quanta = list(quanta)
# Index task quanta by the key
keyQuanta = {}
for qdata in quanta:
if qdata.taskDef.label == self.taskLabel:
key = self._key(qdata)
keyQuanta.setdefault(key, []).append(qdata)
if not keyQuanta:
def fixupQuanta(self, graph: QuantumGraph) -> QuantumGraph:
taskDef = graph.findTaskDefByLabel(self.taskLabel)
if taskDef is None:
raise ValueError(f"Cannot find task with label {self.taskLabel}")
# order keys
keys = sorted(keyQuanta.keys(), reverse=self.reverse)
# for each quanta in a key add dependency to all quanta in a preceding key
for prev_key, key in zip(keys, keys[1:]):
prev_indices = frozenset(qdata.index for qdata in keyQuanta[prev_key])
for qdata in keyQuanta[key]:
qdata.dependencies |= prev_indices
return quanta
quanta = list(graph.quantaForTask(taskDef))
previous = quanta[0]
networkGraph = graph.graph
for node in quanta[1:]:
networkGraph.add_edge(previous, node)
previous = node
return graph
19 changes: 8 additions & 11 deletions python/lsst/ctrl/mpexec/executionGraphFixup.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,8 @@
__all__ = ['ExecutionGraphFixup']

from abc import ABC, abstractmethod
from typing import Iterable

from lsst.pipe.base import QuantumIterData
from lsst.pipe.base import QuantumGraph


class ExecutionGraphFixup(ABC):
Expand All @@ -44,22 +43,20 @@ class ExecutionGraphFixup(ABC):
"""

@abstractmethod
def fixupQuanta(self, quanta: Iterable[QuantumIterData]) -> Iterable[QuantumIterData]:
def fixupQuanta(self, graph: QuantumGraph) -> QuantumGraph:
"""Update quanta in a graph.
Potentially anything in the graph could be changed if it does not
break executor assumptions. Returned quanta will be re-ordered by
executor, if modifications result in a dependency cycle the executor
will raise an exception.
break executor assumptions. If modifications result in a dependency
cycle the executor will raise an exception.
Parameters
----------
quanta : iterable [`~lsst.pipe.base.QuantumIterData`]
Iterable of topologically ordered quanta as returned from
`lsst.pipe.base.QuantumGraph.traverse` method.
graph : QuantumGraph
Quantum Graph that will be executed by the executor
Yields
Returns
------
quantum : `~lsst.pipe.base.QuantumIterData`
graph : QuantumGraph
"""
raise NotImplementedError
79 changes: 32 additions & 47 deletions python/lsst/ctrl/mpexec/mpGraphExecutor.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@
import pickle
import time

from lsst.pipe.base.graph.graph import QuantumGraph

# -----------------------------
# Imports for other modules --
# -----------------------------
Expand Down Expand Up @@ -57,11 +59,12 @@ class _Job:
qdata : `~lsst.pipe.base.QuantumIterData`
Quantum and some associated information.
"""
def __init__(self, qdata):
def __init__(self, qdata, index):
self.qdata = qdata
self.process = None
self.state = JobState.PENDING
self.started = None
self.index = index

def start(self, butler, quantumExecutor):
"""Start process which runs the task.
Expand All @@ -82,7 +85,7 @@ def start(self, butler, quantumExecutor):
self.process = multiprocessing.Process(
target=self._executeJob,
args=(quantumExecutor, taskDef, quantum, butler_pickle),
name=f"task-{self.qdata.index}"
name=f"task-{self.index}"
)
self.process.start()
self.started = time.time()
Expand Down Expand Up @@ -132,7 +135,7 @@ class _JobList:
task dependencies.
"""
def __init__(self, iterable):
self.jobs = [_Job(qdata) for qdata in iterable]
self.jobs = [_Job(qdata, i) for i, qdata in enumerate(iterable)]

def pending(self):
"""Return list of jobs that wait for execution.
Expand Down Expand Up @@ -227,24 +230,24 @@ def __init__(self, numProc, timeout, quantumExecutor, *, failFast=False, executi

def execute(self, graph, butler):
# Docstring inherited from QuantumGraphExecutor.execute
quantaIter = self._fixupQuanta(graph.traverse())
graph = self._fixupQuanta(graph)
if self.numProc > 1:
self._executeQuantaMP(quantaIter, butler)
self._executeQuantaMP(graph, butler)
else:
self._executeQuantaInProcess(quantaIter, butler)
self._executeQuantaInProcess(graph, butler)

def _fixupQuanta(self, quantaIter):
def _fixupQuanta(self, graph: QuantumGraph):
"""Call fixup code to modify execution graph.
Parameters
----------
quantaIter : iterable of `~lsst.pipe.base.QuantumIterData`
Quanta as originated from a quantum graph.
graph : `QuantumGraph`
`QuantumGraph` to modify
Returns
-------
quantaIter : iterable of `~lsst.pipe.base.QuantumIterData`
Possibly updated set of quanta, properly ordered for execution.
graph : `QuantumGraph`
Modified `QuantumGraph`, properly ordered for execution.
Raises
------
Expand All @@ -253,66 +256,48 @@ def _fixupQuanta(self, quantaIter):
i.e. it has dependency cycles.
"""
if not self.executionGraphFixup:
return quantaIter
return graph

_LOG.debug("Call execution graph fixup method")
quantaIter = self.executionGraphFixup.fixupQuanta(quantaIter)

# need it correctly ordered as dependencies may have changed
# after modification, so do topo-sort
updatedQuanta = list(quantaIter)
quanta = []
ids = set()
_LOG.debug("Re-ordering execution graph")
while updatedQuanta:
# find quantum that has all dependencies resolved already
for i, qdata in enumerate(updatedQuanta):
if ids.issuperset(qdata.dependencies):
_LOG.debug("Found next quanta to execute: %s", qdata)
del updatedQuanta[i]
ids.add(qdata.index)
# we could yield here but I want to detect cycles before
# returning anything from this method
quanta.append(qdata)
break
else:
# means remaining quanta have dependency cycle
raise MPGraphExecutorError(
"Updated execution graph has dependency clycle.")
graph = self.executionGraphFixup.fixupQuanta(graph)

# Detect if there is now a cycle created within the graph
if graph.find_cycle():
raise MPGraphExecutorError(
"Updated execution graph has dependency cycle.")

return quanta
return graph

def _executeQuantaInProcess(self, iterable, butler):
def _executeQuantaInProcess(self, graph, butler):
"""Execute all Quanta in current process.
Parameters
----------
iterable : iterable of `~lsst.pipe.base.QuantumIterData`
Sequence if Quanta to execute. It is guaranteed that re-requisites
for a given Quantum will always appear before that Quantum.
graph : `QuantumGraph`
`QuantumGraph` that is to be executed
butler : `lsst.daf.butler.Butler`
Data butler instance
"""
for qdata in iterable:
for qdata in graph:
_LOG.debug("Executing %s", qdata)
self.quantumExecutor.execute(qdata.taskDef, qdata.quantum, butler)
taskDef = graph.taskDefForNode(qdata)
self.quantumExecutor.execute(taskDef, qdata.quantum, butler)

def _executeQuantaMP(self, iterable, butler):
def _executeQuantaMP(self, graph, butler):
"""Execute all Quanta in separate processes.
Parameters
----------
iterable : iterable of `~lsst.pipe.base.QuantumIterData`
Sequence if Quanta to execute. It is guaranteed that re-requisites
for a given Quantum will always appear before that Quantum.
graph : `QuantumGraph`
`QuantumGraph` that is to be executed.
butler : `lsst.daf.butler.Butler`
Data butler instance
"""

disableImplicitThreading() # To prevent thread contention

# re-pack input quantum data into jobs list
jobs = _JobList(iterable)
jobs = _JobList(graph)

# check that all tasks can run in sub-process
for job in jobs.jobs:
Expand Down
8 changes: 3 additions & 5 deletions python/lsst/ctrl/mpexec/preExecInit.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def initializeDatasetTypes(self, graph, registerDatasetTypes=False):
Raised if ``registerDatasetTypes`` is ``False`` and DatasetType
does not exist in registry.
"""
pipeline = list(nodes.taskDef for nodes in graph)
pipeline = graph.taskGraph

# Make dataset types for configurations
configDatasetTypes = [DatasetType(taskDef.configDatasetName, {},
Expand Down Expand Up @@ -178,8 +178,7 @@ def saveInitOutputs(self, graph):
potentially introduce some extensible mechanism for that.
"""
_LOG.debug("Will save InitOutputs for all tasks")
for taskNodes in graph:
taskDef = taskNodes.taskDef
for taskDef in graph.taskGraph:
task = self.taskFactory.makeTask(taskDef.taskClass, taskDef.config, None, self.butler)
for name in taskDef.connections.initOutputs:
attribute = getattr(taskDef.connections, name)
Expand Down Expand Up @@ -228,8 +227,7 @@ def logConfigMismatch(msg):
_LOG.debug("Will save Configs for all tasks")
# start transaction to rollback any changes on exceptions
with self.butler.transaction():
for taskNodes in graph:
taskDef = taskNodes.taskDef
for taskDef in graph.taskGraph:
configName = taskDef.configDatasetName

oldConfig = None
Expand Down

0 comments on commit 608d5ef

Please sign in to comment.