Skip to content

Commit

Permalink
[local-worker-mgr] Use cloudpickle (#306)
Browse files Browse the repository at this point in the history
Serializes the worker class with cloudpickle before handing it to the worker
process. This makes it easier to work with local classes, like the ones
defined inside functions.
  • Loading branch information
mtrofin committed Oct 17, 2023
1 parent 1677344 commit 6960b47
Showing 1 changed file with 11 additions and 6 deletions.
17 changes: 11 additions & 6 deletions compiler_opt/distributed/local/local_worker_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
local thread pool, or, if the task is 'urgent', it executes it promptly.
"""
import concurrent.futures
import cloudpickle
import dataclasses
import functools
import multiprocessing
Expand Down Expand Up @@ -59,8 +60,11 @@ class TaskResult:
value: Any


def _run_impl(pipe: connection.Connection, worker_class: 'type[worker.Worker]',
*args, **kwargs):
SerializedClass = bytes


def _run_impl(pipe: connection.Connection, worker_class: SerializedClass, *args,
**kwargs):
"""Worker process entrypoint."""

# A setting of 1 does not inhibit the while loop below from running since
Expand All @@ -69,7 +73,7 @@ def _run_impl(pipe: connection.Connection, worker_class: 'type[worker.Worker]',
# spawned at a time which execute given tasks. In the typical clang-spawning
# jobs, this effectively limits the number of clang instances spawned.
pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
obj = worker_class(*args, **kwargs)
obj = cloudpickle.loads(worker_class)(*args, **kwargs)

# Pipes are not thread safe
pipe_lock = threading.Lock()
Expand Down Expand Up @@ -103,8 +107,8 @@ def on_done(f: concurrent.futures.Future):
pool.submit(application).add_done_callback(make_ondone(task.msgid))


def _run(pipe: connection.Connection, worker_class: 'type[worker.Worker]',
*args, **kwargs):
def _run(pipe: connection.Connection, worker_class: SerializedClass, *args,
**kwargs):
try:
_run_impl(pipe, worker_class, *args, **kwargs)
except BaseException as e:
Expand All @@ -127,7 +131,8 @@ def __init__(self):
# to handle high priority requests. The expectation is that the user
# achieves concurrency through multiprocessing, not multithreading.
self._process = multiprocessing.get_context().Process(
target=functools.partial(_run, child_pipe, cls, *args, **kwargs))
target=functools.partial(_run, child_pipe, cloudpickle.dumps(cls), *
args, **kwargs))
# lock for the msgid -> reply future map. The map will be set to None
# when we stop.
self._lock = threading.Lock()
Expand Down

0 comments on commit 6960b47

Please sign in to comment.