Improve scheduler performance for large graphs (DM-28418)
The multiprocessing scheduler did not scale well with the number of
nodes in a quantum graph. Redesigned the data structures to avoid looping
through the full list of nodes on every iteration.
andy-slac committed Jan 20, 2021
1 parent 7f4067b commit 69df2a8
Showing 1 changed file with 104 additions and 76 deletions.
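In short: the old `_JobList` recomputed `pending()`, `running()`, `finishedNodes()` and `failedNodes()` by scanning every job on each call, and the polling loop in `_executeQuantaMP` makes those calls on every iteration, so the cost grew roughly quadratically with graph size. The new `_JobList` keeps these collections as attributes and moves each job between them once per state transition. A minimal sketch of the pattern, with simplified, illustrative names rather than the real LSST classes:

from enum import Enum


class State(Enum):
    PENDING = 1
    RUNNING = 2
    FINISHED = 3


class Job:
    def __init__(self, node):
        self.node = node
        self.state = State.PENDING


class ScanningJobList:
    """Old pattern: every query rescans all jobs, O(n) per polling iteration."""

    def __init__(self, nodes):
        self.jobs = [Job(n) for n in nodes]

    def pending(self):
        # recomputed from scratch on every call
        return [j for j in self.jobs if j.state is State.PENDING]


class IncrementalJobList:
    """New pattern: collections are maintained incrementally, so each job is
    moved between them once per state transition instead of being re-examined
    on every loop iteration."""

    def __init__(self, nodes):
        self.jobs = [Job(n) for n in nodes]
        self.pending = self.jobs[:]
        self.running = []
        self.finishedNodes = set()

    def setJobState(self, job, state):
        # drop the job from whichever collection tracks its current state
        if job.state is State.PENDING:
            self.pending.remove(job)
        elif job.state is State.RUNNING:
            self.running.remove(job)
        job.state = state
        # and register it under the new state
        if state is State.RUNNING:
            self.running.append(job)
        elif state is State.FINISHED:
            self.finishedNodes.add(job.node)

The actual change below also tracks failed and timed-out nodes and handles dependency failures; the sketch shows only the core bookkeeping.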
python/lsst/ctrl/mpexec/mpGraphExecutor.py
@@ -64,9 +64,14 @@ class _Job:
     def __init__(self, qnode):
         self.qnode = qnode
         self.process = None
-        self.state = JobState.PENDING
+        self._state = JobState.PENDING
         self.started = None

+    @property
+    def state(self):
+        """Job processing state (JobState)"""
+        return self._state
+
     def start(self, butler, quantumExecutor, startMethod=None):
         """Start process which runs the task.
@@ -95,7 +100,7 @@ def start(self, butler, quantumExecutor, startMethod=None):
         )
         self.process.start()
         self.started = time.time()
-        self.state = JobState.RUNNING
+        self._state = JobState.RUNNING

     @staticmethod
     def _executeJob(quantumExecutor, taskDef, quantum_pickle, butler_pickle, logConfigState):
@@ -157,57 +162,72 @@ class _JobList:
     """
     def __init__(self, iterable):
         self.jobs = [_Job(qnode) for qnode in iterable]
+        self.pending = self.jobs[:]
+        self.running = []
+        self.finishedNodes = set()
+        self.failedNodes = set()
+        self.timedOutNodes = set()

-    def pending(self):
-        """Return list of jobs that wait for execution.
-
-        Returns
-        -------
-        jobs : `list` [`_Job`]
-            List of jobs.
-        """
-        return [job for job in self.jobs if job.state == JobState.PENDING]
-
-    def running(self):
-        """Return list of jobs that are executing.
-
-        Returns
-        -------
-        jobs : `list` [`_Job`]
-            List of jobs.
-        """
-        return [job for job in self.jobs if job.state == JobState.RUNNING]
-
-    def finishedNodes(self):
-        """Return set of QuantumNodes that finished successfully (not failed).
-
-        Returns
-        -------
-        QuantumNodes : `set` [`~lsst.pipe.base.QuantumNode`]
-            Set of QuantumNodes that have successfully finished
-        """
-        return set(job.qnode for job in self.jobs if job.state == JobState.FINISHED)
-
-    def failedNodes(self):
-        """Return set of jobs IDs that failed for any reason.
-
-        Returns
-        -------
-        QuantumNodes : `set` [`~lsst.pipe.base.QuantumNode`]
-            Set of QuantumNodes that failed during processing
-        """
-        return set(job.qnode for job in self.jobs
-                   if job.state in (JobState.FAILED, JobState.FAILED_DEP, JobState.TIMED_OUT))
-
-    def timedOutIds(self):
-        """Return set of jobs IDs that timed out.
-
-        Returns
-        -------
-        jobsIds : `set` [`int`]
-            Set of integer job IDs.
-        """
-        return set(job.qnode for job in self.jobs if job.state == JobState.TIMED_OUT)
+    def submit(self, job, butler, quantumExecutor, startMethod=None):
+        """Submit one more job for execution.
+
+        Parameters
+        ----------
+        job : `_Job`
+            Job to submit.
+        butler : `lsst.daf.butler.Butler`
+            Data butler instance.
+        quantumExecutor : `QuantumExecutor`
+            Executor for single quantum.
+        startMethod : `str`, optional
+            Start method from `multiprocessing` module.
+        """
+        # this will raise if job is not in pending list
+        self.pending.remove(job)
+        job.start(butler, quantumExecutor, startMethod)
+        self.running.append(job)
+
+    def setJobState(self, job, state):
+        """Update job state.
+
+        Parameters
+        ----------
+        job : `_Job`
+            Job to update.
+        state : `JobState`
+            New job state; only FINISHED, FAILED, TIMED_OUT, or
+            FAILED_DEP states are acceptable.
+        """
+        allowedStates = (
+            JobState.FINISHED,
+            JobState.FAILED,
+            JobState.TIMED_OUT,
+            JobState.FAILED_DEP
+        )
+        assert state in allowedStates, f"State {state} not allowed here"
+
+        # remove job from pending/running lists
+        if job.state == JobState.PENDING:
+            self.pending.remove(job)
+        elif job.state == JobState.RUNNING:
+            self.running.remove(job)
+
+        qnode = job.qnode
+        # it should not be in any of these, but just in case
+        self.finishedNodes.discard(qnode)
+        self.failedNodes.discard(qnode)
+        self.timedOutNodes.discard(qnode)
+
+        job._state = state
+        if state == JobState.FINISHED:
+            self.finishedNodes.add(qnode)
+        elif state == JobState.FAILED:
+            self.failedNodes.add(qnode)
+        elif state == JobState.FAILED_DEP:
+            self.failedNodes.add(qnode)
+        elif state == JobState.TIMED_OUT:
+            self.failedNodes.add(qnode)
+            self.timedOutNodes.add(qnode)

     def cleanup(self):
         """Do periodic cleanup for jobs that did not finish correctly.
@@ -355,28 +375,28 @@ def _executeQuantaMP(self, graph, butler):
                 raise MPGraphExecutorError(f"Task {taskDef.taskName} does not support multiprocessing;"
                                            " use single process")

-        finished, failed = 0, 0
-        while jobs.pending() or jobs.running():
+        finishedCount, failedCount = 0, 0
+        while jobs.pending or jobs.running:

-            _LOG.debug("#pendingJobs: %s", len(jobs.pending()))
-            _LOG.debug("#runningJobs: %s", len(jobs.running()))
+            _LOG.debug("#pendingJobs: %s", len(jobs.pending))
+            _LOG.debug("#runningJobs: %s", len(jobs.running))

             # See if any jobs have finished
-            for job in jobs.running():
+            for job in jobs.running:
                 if not job.process.is_alive():
                     _LOG.debug("finished: %s", job)
                     # finished
                     exitcode = job.process.exitcode
                     if exitcode == 0:
-                        job.state = JobState.FINISHED
+                        jobs.setJobState(job, JobState.FINISHED)
                         job.cleanup()
                         _LOG.debug("success: %s took %.3f seconds", job, time.time() - job.started)
                     else:
-                        job.state = JobState.FAILED
+                        jobs.setJobState(job, JobState.FAILED)
                         job.cleanup()
                         _LOG.debug("failed: %s", job)
                         if self.failFast:
-                            for stopJob in jobs.running():
+                            for stopJob in jobs.running:
                                 if stopJob is not job:
                                     stopJob.stop()
                             raise MPGraphExecutorError(
@@ -390,7 +410,7 @@ def _executeQuantaMP(self, graph, butler):
                     # check for timeout
                     now = time.time()
                     if now - job.started > self.timeout:
-                        job.state = JobState.TIMED_OUT
+                        jobs.setJobState(job, JobState.TIMED_OUT)
                         _LOG.debug("Terminating job %s due to timeout", job)
                         job.stop()
                         job.cleanup()
@@ -402,46 +422,54 @@ def _executeQuantaMP(self, graph, butler):
                                 "for remaining tasks.", self.timeout, job
                             )

+            # Fail jobs whose inputs failed. This may need several iterations
+            # if the order is not right; remaining cases are handled on the
+            # next pass through the loop.
+            if jobs.failedNodes:
+                for job in jobs.pending:
+                    jobInputNodes = graph.determineInputsToQuantumNode(job.qnode)
+                    if jobInputNodes & jobs.failedNodes:
+                        jobs.setJobState(job, JobState.FAILED_DEP)
+                        _LOG.error("Upstream job failed for task %s, skipping this task.", job)
+
             # see if we can start more jobs
-            for job in jobs.pending():
-
-                # check all dependencies
-                if graph.determineInputsToQuantumNode(job.qnode) & jobs.failedNodes():
-                    # upstream job has failed, skipping this
-                    job.state = JobState.FAILED_DEP
-                    _LOG.error("Upstream job failed for task %s, skipping this task.", job)
-                elif graph.determineInputsToQuantumNode(job.qnode) <= jobs.finishedNodes():
-                    # all dependencies have completed, can start new job
-                    if len(jobs.running()) < self.numProc:
-                        _LOG.debug("Submitting %s", job)
-                        job.start(butler, self.quantumExecutor, self.startMethod)
+            if len(jobs.running) < self.numProc:
+                for job in jobs.pending:
+                    jobInputNodes = graph.determineInputsToQuantumNode(job.qnode)
+                    if jobInputNodes <= jobs.finishedNodes:
+                        # all dependencies have completed, can start new job
+                        if len(jobs.running) < self.numProc:
+                            _LOG.debug("Submitting %s", job)
+                            jobs.submit(job, butler, self.quantumExecutor, self.startMethod)
+                        if len(jobs.running) >= self.numProc:
+                            # cannot start any more jobs, wait until something finishes
+                            break

             # Do cleanup for timed out jobs if necessary.
             jobs.cleanup()

             # Print progress message if something changed.
-            newFinished, newFailed = len(jobs.finishedNodes()), len(jobs.failedNodes())
-            if (finished, failed) != (newFinished, newFailed):
-                finished, failed = newFinished, newFailed
+            newFinished, newFailed = len(jobs.finishedNodes), len(jobs.failedNodes)
+            if (finishedCount, failedCount) != (newFinished, newFailed):
+                finishedCount, failedCount = newFinished, newFailed
                 totalCount = len(jobs.jobs)
                 _LOG.info("Executed %d quanta successfully, %d failed and %d remain out of total %d quanta.",
-                          finished, failed, totalCount - finished - failed, totalCount)
+                          finishedCount, failedCount, totalCount - finishedCount - failedCount, totalCount)

             # Here we want to wait until one of the running jobs completes
             # but multiprocessing does not provide an API for that, for now
             # just sleep a little bit and go back to the loop.
-            if jobs.running():
+            if jobs.running:
                 time.sleep(0.1)

-        if jobs.failedNodes():
+        if jobs.failedNodes:
             # print list of failed jobs
             _LOG.error("Failed jobs:")
             for job in jobs.jobs:
                 if job.state != JobState.FINISHED:
-                    _LOG.error(" - %s: %s", job.state, job)
+                    _LOG.error(" - %s: %s", job.state.name, job)

             # if any job failed raise an exception
-            if jobs.failedNodes() == jobs.timedOutIds():
+            if jobs.failedNodes == jobs.timedOutNodes:
                 raise MPTimeoutError("One or more tasks timed out during execution.")
             else:
                 raise MPGraphExecutorError("One or more tasks failed or timed out during execution.")
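The scheduling loop above leans on two plain set operations. A toy example with invented node names shows what each test means:

finished = {"a", "b"}   # nodes whose quanta completed successfully
failed = {"c"}          # nodes whose quanta failed or timed out

inputs_of_d = {"a", "c"}
inputs_of_e = {"a", "b"}

# A non-empty intersection with the failed set means an upstream
# dependency failed, so the job is marked FAILED_DEP without running.
assert inputs_of_d & failed

# A subset test against the finished set means every input is ready,
# so the job can be submitted.
assert inputs_of_e <= finished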
