In [None]:
#|default_exp execution_manager

In [None]:
#|hide
from nblite import nbl_export; nbl_export();

# ExecutionManager

The execution manager layer deals with the execution of functions inside the pools.

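It dispatches function calls to individual workers of those pools, and can allocate jobs across workers using a round-robin, random, or least-busy strategy.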

In [None]:
#|export
from contextlib import contextmanager
from typing import Any
from collections.abc import Callable, Awaitable
import datetime
import builtins
from netrun._iutils import get_timestamp_utc
from datetime import datetime
import asyncio
from enum import Enum
from dataclasses import dataclass
import importlib
import uuid
import pickle
import random
import functools
import threading
import io

from netrun.rpc.base import RPCChannel
from netrun.pool.thread import ThreadPool
from netrun.pool.multiprocess import MultiprocessPool
from netrun.pool.aio import SingleWorkerPool
from netrun.pool.remote import RemotePoolClient

## Thread-safe print capture

These helpers provide thread-safe print capture using thread-local storage.
The patch is applied to `builtins.print` and uses reference counting to support
multiple ExecutionManager instances.

In [None]:
#|exporti
import contextvars

_print_capture_lock = threading.Lock()
_print_capture_refcount = 0
_original_builtins_print = None
# The active capture callback. A ContextVar (rather than threading.local) is isolated per
# thread AND per asyncio task, so a job captured inside one task does not swallow prints made
# by other tasks that happen to run on the same event-loop thread while the job is awaiting.
_print_capture_callback: contextvars.ContextVar = contextvars.ContextVar("_print_capture_callback", default=None)

def _capturing_print(*args, **kwargs):
    """Replacement for `builtins.print` that diverts output to the active capture callback.

    When a callback is set for the current context (see `_capture_prints`) and the caller did
    not pass an explicit `file=` destination, the text `print` would have written is rendered
    into a string and handed to the callback instead. Prints to an explicit file (e.g.
    `file=sys.stderr` or an open file object) are never captured and go where they were sent.
    """
    callback = _print_capture_callback.get()
    if callback is None or kwargs.get("file") is not None:
        # Normal print
        _original_builtins_print(*args, **kwargs)
        return
    # `file=None` means "the default stream", which is exactly what is being captured; drop it
    # so it does not clash with the `file=` we pass below.
    kwargs.pop("file", None)
    output = io.StringIO()
    _original_builtins_print(*args, **kwargs, file=output)
    callback(output.getvalue())

def _enable_print_capture():
    """Patch `builtins.print` with `_capturing_print` (reference counted).

    Only the first call installs the patch; every call must be balanced by one
    `_disable_print_capture()` call.
    """
    global _print_capture_refcount, _original_builtins_print
    with _print_capture_lock:
        if _print_capture_refcount == 0:
            _original_builtins_print = builtins.print
            builtins.print = _capturing_print
        _print_capture_refcount += 1

def _disable_print_capture():
    """Drop one reference on the print patch; restore the original print when none remain.

    An unbalanced call (with no outstanding reference) is a no-op, so it cannot drive the
    refcount negative and leave the patch permanently broken.
    """
    global _print_capture_refcount
    with _print_capture_lock:
        if _print_capture_refcount <= 0:
            return
        _print_capture_refcount -= 1
        if _print_capture_refcount == 0:
            builtins.print = _original_builtins_print

@contextmanager
def _capture_prints(callback):
    """Route captured `print` output in the current thread/task to `callback` while active.

    `callback` receives the full text of each print call (including its `end`). Nested uses
    restore the enclosing callback on exit instead of disabling capture altogether.
    """
    token = _print_capture_callback.set(callback)
    try:
        yield
    finally:
        _print_capture_callback.reset(token)

## Function runners

Helpers for the execution manager to run functions.

In [None]:
#|exporti
async def _async_func_runner(
    channel: RPCChannel,
    func: Callable[..., Awaitable[Any]],
    send_channel: bool,
    print_callback: Callable[[str], None] | None,
    args: tuple,
    kwargs: dict,
) -> Any:
    """Call `func` from an async worker and return its result, awaiting it if needed.

    Args:
        channel: The worker's RPC channel; passed as the first positional argument when
            `send_channel` is True.
        func: The function to run: a plain function, a coroutine function, or any callable
            that returns an awaitable (e.g. an object with an `async def __call__`).
        send_channel: Whether to prepend `channel` to the positional arguments.
        print_callback: If given, `print` output produced while `func` runs (including while
            its result is awaited) is routed to this callback via `_capture_prints`.
        args: Positional arguments for `func`.
        kwargs: Keyword arguments for `func`.

    Returns:
        The return value of `func`, awaited first if it is awaitable.
    """
    import inspect
    from contextlib import nullcontext

    call_args = (channel, *args) if send_channel else args
    capture = _capture_prints(print_callback) if print_callback is not None else nullcontext()
    with capture:
        result = func(*call_args, **kwargs)
        # Deciding on the result (rather than on asyncio.iscoroutinefunction(func), which is
        # deprecated) also handles callables that return awaitables without being coroutine
        # functions themselves.
        if inspect.isawaitable(result):
            result = await result
        return result

In [None]:
#|exporti
def _func_runner(
    channel: RPCChannel,
    func: Callable[..., Any],
    send_channel: bool,
    print_callback: Callable[[str], None] | None,
    args: tuple,
    kwargs: dict,
    event_loop: asyncio.AbstractEventLoop,
) -> Any:
    """Call `func` from a synchronous (thread/process) worker and return its result.

    Args:
        channel: The worker's RPC channel; passed as the first positional argument when
            `send_channel` is True.
        func: The function to run: a plain function, a coroutine function, or any callable
            that returns an awaitable (e.g. an object with an `async def __call__`).
        send_channel: Whether to prepend `channel` to the positional arguments.
        print_callback: If given, `print` output produced while `func` runs is routed to this
            callback via `_capture_prints`.
        args: Positional arguments for `func`.
        kwargs: Keyword arguments for `func`.
        event_loop: Loop used with `run_until_complete` to drive an awaitable result.

    Returns:
        The return value of `func`, run to completion on `event_loop` if it is awaitable.
    """
    import inspect
    from contextlib import nullcontext

    call_args = (channel, *args) if send_channel else args
    capture = _capture_prints(print_callback) if print_callback is not None else nullcontext()
    with capture:
        result = func(*call_args, **kwargs)
        # Deciding on the result (rather than on asyncio.iscoroutinefunction(func), which is
        # deprecated) also handles callables that return awaitables without being coroutine
        # functions themselves.
        if inspect.isawaitable(result):
            result = event_loop.run_until_complete(result)
        return result

## Helpers

In [None]:
#|exporti
def _convert_to_str_if_not_serializable(obj: Any) -> tuple[bool, Any]:
    """Convert an object to string if it's not pickle-serializable.

    Used by workers whose results have to be pickled to be sent back to the manager. Any
    exception raised while pickling counts as "not serializable". Besides `PicklingError`,
    `TypeError` and `AttributeError`, pickling can also fail with e.g. `ValueError`,
    `RecursionError` (very deep nesting) or whatever a custom `__reduce__`/`__getstate__`
    raises. None of these should crash the worker; the string fallback exists for them.

    Args:
        obj: The value to check.

    Returns:
        Tuple of (was_converted, result): `(False, obj)` if `obj` pickles, otherwise
        `(True, str(obj))`.
    """
    try:
        pickle.dumps(obj)
    except Exception:
        return (True, str(obj))
    return (False, obj)

## Workers

Worker functions that run inside the execution manager's pools and implement its message protocol.

In [None]:
#|exporti
class ExecutionManagerProtocolKeys(Enum):
    """Message keys of the protocol between the ExecutionManager and its pool workers.

    Keys without the `UP_` prefix are sent by the manager to a worker; `UP_` keys are sent
    back by the worker. Every payload starts with the `msg_id` chosen by the manager, which
    the worker echoes so that replies can be routed to the request waiting for them.
    """

    RUN = "exec-manager:run"
    """
    Run a function.
    Args: msg_id, func_import_path_or_key, run_id, send_channel, args, kwargs
    """

    UP_RUN_STARTED = "exec-manager-up:run-started"
    """
    Notification from the worker that a run has been submitted and started.
    Args: msg_id, timestamp_utc_started
    """

    UP_RUN_RESPONSE = "exec-manager-up:run-response"
    """
    Response to RUN from the worker, sent once the function has returned.
    Args: msg_id, timestamp_utc_started, timestamp_utc_completed, converted_to_str, _res
    """

    SEND_FUNCTION = "exec-manager:send-function"
    """
    Used to send a function object to the worker, which can then be run using RUN.
    Args: msg_id, func_key, func
    """

    UP_SEND_FUNCTION_RESPONSE = "exec-manager-up:send-function-response"
    """
    Response to SEND_FUNCTION from the worker, to confirm that the function was received.
    Args: msg_id
    """

    UP_PRINT_BUFFER = "exec-manager-up:print-buffer"
    """
    Print output captured from the worker while a RUN executes.
    Sync (thread/multiprocess) workers send it when a print happens at least
    print_flush_interval seconds after the previous flush; every worker sends whatever is
    still buffered just before UP_RUN_RESPONSE.
    Args: msg_id, run_id, _buffer
    """

In [None]:
#|exporti
def _worker_func(is_in_main_process: bool, channel, worker_id, print_flush_interval: float = 0.1, capture_prints: bool = True):
    """Serve execution-manager protocol messages for one synchronous (thread/process) worker.

    Blocks on `channel.recv()` in an endless loop and handles two message kinds.

    - RUN: resolves the target, which is a key registered via SEND_FUNCTION or otherwise a
      dotted import path. It replies UP_RUN_STARTED, then runs the target through
      `_func_runner`; awaitable results are driven on this worker's private event loop.
      Finally it sends any remaining captured prints as UP_PRINT_BUFFER, followed by
      UP_RUN_RESPONSE.
    - SEND_FUNCTION: stores the function under its key and replies UP_SEND_FUNCTION_RESPONSE.

    The loop has no normal exit. It ends only when an exception escapes, whether from the
    channel, from the executed function, or from an unknown key. On the way out the worker
    releases its reference on the global print patch and closes its event loop. This way
    `builtins.print` is restored once no user of the patch remains.

    Args:
        is_in_main_process: If True, results don't need to be serializable and are returned
            as-is. If False, non-pickleable results are replaced by `str(result)`.
        channel: RPC channel for communication.
        worker_id: ID of this worker (not used by the loop).
        print_flush_interval: Minimum number of seconds between intermediate UP_PRINT_BUFFER
            messages. The check runs whenever the function prints, so buffered text is sent
            on the first print after the interval has elapsed. Anything left is always sent
            before UP_RUN_RESPONSE.
        capture_prints: If True, capture print statements and send them back. If False,
            prints go to stdout normally.

    Raises:
        ValueError: If a message with an unknown protocol key is received.
    """
    # Enable print capture if configured
    if capture_prints:
        _enable_print_capture()

    event_loop = asyncio.new_event_loop()
    registered_functions: dict[str, Callable[..., Any]] = {}

    try:
        while True:
            key, data = channel.recv()
            # RUN
            if key == ExecutionManagerProtocolKeys.RUN.value:
                msg_id, func_import_path_or_key, run_id, send_channel, args, kwargs = data
                if func_import_path_or_key in registered_functions:
                    func = registered_functions[func_import_path_or_key]
                else:
                    module_path, func_name = func_import_path_or_key.rsplit(".", 1)
                    module = importlib.import_module(module_path)
                    func = getattr(module, func_name)

                # Set up auto-flushing print buffer (only if capture_prints is enabled)
                print_buffer: list[tuple[datetime, str]] = []
                last_flush_time = get_timestamp_utc()

                if capture_prints:
                    def print_callback(text: str):
                        nonlocal last_flush_time
                        print_buffer.append((get_timestamp_utc(), text))
                        # Flush if enough time has passed since the previous flush
                        now = get_timestamp_utc()
                        elapsed = (now - last_flush_time).total_seconds()
                        if elapsed >= print_flush_interval and print_buffer:
                            channel.send(ExecutionManagerProtocolKeys.UP_PRINT_BUFFER.value, (msg_id, run_id, list(print_buffer)))
                            print_buffer.clear()
                            last_flush_time = now
                else:
                    print_callback = None

                timestamp_utc_started = get_timestamp_utc()
                channel.send(ExecutionManagerProtocolKeys.UP_RUN_STARTED.value, (msg_id, timestamp_utc_started))
                res = _func_runner(
                    channel=channel,
                    func=func,
                    send_channel=send_channel,
                    print_callback=print_callback,
                    args=args,
                    kwargs=kwargs,
                    event_loop=event_loop,
                )
                timestamp_utc_completed = get_timestamp_utc()
                if is_in_main_process:
                    converted_to_str, _res = False, res
                else:
                    # The result must survive pickling to reach the manager's process
                    converted_to_str, _res = _convert_to_str_if_not_serializable(res)

                # Flush any remaining print buffer before sending response
                if capture_prints and print_buffer:
                    channel.send(ExecutionManagerProtocolKeys.UP_PRINT_BUFFER.value, (msg_id, run_id, list(print_buffer)))

                channel.send(ExecutionManagerProtocolKeys.UP_RUN_RESPONSE.value, (msg_id, timestamp_utc_started, timestamp_utc_completed, converted_to_str, _res))
            # SEND_FUNCTION
            elif key == ExecutionManagerProtocolKeys.SEND_FUNCTION.value:
                msg_id, func_key, func = data
                registered_functions[func_key] = func
                channel.send(ExecutionManagerProtocolKeys.UP_SEND_FUNCTION_RESPONSE.value, (msg_id,))
            else:
                raise ValueError(f"Unknown execution manager protocol key: '{key}'.")
    finally:
        # Balance the _enable_print_capture() above so thread workers do not keep
        # builtins.print patched after their pool has shut down, then free the loop.
        try:
            if capture_prints:
                _disable_print_capture()
        finally:
            event_loop.close()

def _thread_worker_func(channel, worker_id, print_flush_interval: float = 0.1, capture_prints: bool = True):
    """Entry point for ThreadPool workers.

    Thread workers live in the manager's own process, so their results are passed back
    as-is, without a pickle-ability check.
    """
    return _worker_func(True, channel, worker_id, print_flush_interval, capture_prints)

# Multiprocess workers run in another process: a result has to be pickleable to travel back,
# otherwise it is replaced by `str(result)` before being sent.
def _multiprocess_worker_func(channel, worker_id, print_flush_interval: float = 0.1, capture_prints: bool = True):
    """Entry point for MultiprocessPool workers (results are checked for pickle-ability)."""
    return _worker_func(False, channel, worker_id, print_flush_interval, capture_prints)

In [None]:
#|exporti
async def _async_worker_func(channel, worker_id, print_flush_interval: float = 0.1, capture_prints: bool = True):
    """Async worker function that handles execution manager protocol messages.

    Awaits `channel.recv()` in an endless loop and handles RUN and SEND_FUNCTION messages like
    the synchronous worker. The differences are that functions run through
    `_async_func_runner` and that results are returned as-is, without a pickle-ability check.

    Note: For async workers, print buffer is sent all at once at the end of execution
    since the print callback cannot be async. Interval-based flushing is only
    supported for sync workers (thread/multiprocess pools).

    The loop has no normal exit; when an exception or cancellation ends it, the worker
    releases its reference on the global print patch.

    Args:
        channel: Async RPC channel for communication.
        worker_id: ID of this worker (not used by the loop).
        print_flush_interval: Not used for async workers (kept for API consistency).
        capture_prints: If True, capture print statements and send them back. If False, prints go to stdout normally.

    Raises:
        ValueError: If a message with an unknown protocol key is received.
    """
    # Enable print capture if configured
    if capture_prints:
        _enable_print_capture()

    registered_functions: dict[str, Callable[..., Any]] = {}

    try:
        while True:
            key, data = await channel.recv()
            # RUN
            if key == ExecutionManagerProtocolKeys.RUN.value:
                msg_id, func_import_path_or_key, run_id, send_channel, args, kwargs = data
                if func_import_path_or_key in registered_functions:
                    func = registered_functions[func_import_path_or_key]
                else:
                    module_path, func_name = func_import_path_or_key.rsplit(".", 1)
                    module = importlib.import_module(module_path)
                    func = getattr(module, func_name)

                # For async workers, we just collect prints and send at the end
                # (interval-based flushing requires sync channel.send which isn't available)
                print_buffer: list[tuple[datetime, str]] = []

                if capture_prints:
                    def print_callback(text: str):
                        print_buffer.append((get_timestamp_utc(), text))
                else:
                    print_callback = None

                timestamp_utc_started = get_timestamp_utc()
                await channel.send(ExecutionManagerProtocolKeys.UP_RUN_STARTED.value, (msg_id, timestamp_utc_started))
                res = await _async_func_runner(
                    channel=channel,
                    func=func,
                    send_channel=send_channel,
                    print_callback=print_callback,
                    args=args,
                    kwargs=kwargs,
                )
                timestamp_utc_completed = get_timestamp_utc()
                converted_to_str, _res = False, res

                # Send print buffer before response
                if capture_prints and print_buffer:
                    await channel.send(ExecutionManagerProtocolKeys.UP_PRINT_BUFFER.value, (msg_id, run_id, list(print_buffer)))

                await channel.send(ExecutionManagerProtocolKeys.UP_RUN_RESPONSE.value, (msg_id, timestamp_utc_started, timestamp_utc_completed, converted_to_str, _res))
            # SEND_FUNCTION
            elif key == ExecutionManagerProtocolKeys.SEND_FUNCTION.value:
                msg_id, func_key, func = data
                registered_functions[func_key] = func
                await channel.send(ExecutionManagerProtocolKeys.UP_SEND_FUNCTION_RESPONSE.value, (msg_id,))
            else:
                raise ValueError(f"Unknown execution manager protocol key: '{key}'.")
    finally:
        # Balance the _enable_print_capture() above so builtins.print is not left patched
        # after this worker is gone.
        if capture_prints:
            _disable_print_capture()

## ExecutionManager

In [None]:
#|export
@dataclass
class JobResult:
    """Result of a job execution.

    Built by `ExecutionManager.run` from the worker's UP_RUN_RESPONSE message.
    """
    # Time the manager submitted the job (taken with get_timestamp_utc()).
    timestamp_utc_submitted: datetime
    # Start time reported by the worker in its UP_RUN_STARTED message.
    timestamp_utc_started: datetime
    # Completion time reported by the worker in UP_RUN_RESPONSE.
    timestamp_utc_completed: datetime
    # Dotted import path of the function, or the key it was registered under via send_function.
    func_import_path_or_key: str
    pool_id: str
    # Index of the worker within its pool.
    worker_id: int
    # True if the worker could not pickle the return value and sent str(result) instead.
    converted_to_str: bool
    # The function's return value (its str() when converted_to_str is True).
    result: Any
    print_buffer: list[tuple[datetime, str]]
    """List of (timestamp, text) tuples for each print statement captured during execution."""

@dataclass
class SubmittedJobInfo:
    """Information about a submitted job.

    `ExecutionManager` keeps one of these per job that has been submitted to a worker and has
    not finished yet (see `ExecutionManager.get_worker_jobs`).
    """
    # Unique ID generated for this run (a uuid4 string); also sent to the worker with RUN.
    run_id: str
    # Time the manager submitted the job.
    timestamp_utc_submitted: datetime
    # Worker-reported start time; None until the UP_RUN_STARTED message arrives.
    timestamp_utc_started: datetime | None
    # Dotted import path of the function, or the key it was registered under via send_function.
    func_import_path_or_key: str
    pool_id: str
    # Index of the worker within its pool (0 to num_workers - 1).
    worker_id: int

class RunAllocationMethod(Enum):
    """Method for allocating a job to a worker (used by `ExecutionManager.run_allocate`)."""
    # Prefer a candidate that has never been given a job; otherwise the least recently used one.
    ROUND_ROBIN = "round-robin"
    # Pick one of the candidates uniformly at random.
    RANDOM = "random"
    # Pick the candidate with the fewest submitted-but-unfinished jobs.
    LEAST_BUSY = "least-busy"

In [None]:
#|export
# Union of the pool classes an ExecutionManager can be configured with (see ExecutionManager.start).
PoolType = ThreadPool | MultiprocessPool | SingleWorkerPool | RemotePoolClient

class ExecutionManager:
    """Run functions on the workers of a set of pools and collect their results.

    Every pool is created with a worker function that speaks the `ExecutionManagerProtocolKeys`
    protocol. One background task per pool receives the workers' messages and routes each one,
    by its `msg_id`, to the `asyncio.Queue` of the request waiting for it.
    Use it as an async context manager, or call `start()` and `close()` explicitly.
    """

    def __init__(self, pool_configs: dict[str, tuple[type[PoolType], dict[str, Any]]]):
        """
        Create an ExecutionManager with the given pool configurations.

        Args:
            pool_configs: A dictionary mapping pool_id to (pool_type, pool_init_kwargs).
                pool_type is one of the classes ThreadPool, MultiprocessPool,
                SingleWorkerPool or RemotePoolClient.
                pool_init_kwargs are passed to the pool constructor (excluding worker_fn).
                The optional keys 'print_flush_interval' (default 0.1) and 'capture_prints'
                (default True) are removed from them and configure the worker function
                instead. They have no effect for RemotePoolClient.
        """
        self._pool_configs = pool_configs
        self._pools: dict[str, PoolType] = {}
        self._msg_recv_tasks: dict[str, asyncio.Task] = {}
        # pool_id -> msg_id -> queue receiving the worker's replies to that request
        self._msgs: dict[str, dict[str, asyncio.Queue]] = {}
        self._started = False
        # True while this manager holds a reference on the global print-capture patch
        self._print_capture_enabled = False

        # (pool_id, worker_id) -> jobs submitted to that worker that have not finished yet
        self._worker_jobs: dict[tuple[str, int], list[SubmittedJobInfo]] = {}
        # Workers that have been given a job, ordered from least to most recently used
        self._worker_round_robin_lst: list[tuple[str, int]] = []

    async def start(self) -> None:
        """Start all pools and initialize the execution manager.

        If start-up fails part-way, the pools that were already started are closed (best
        effort). The manager is left unstarted and the original exception is re-raised.

        Raises:
            RuntimeError: If the manager is already started.
            ValueError: If a pool config contains 'worker_fn' or names an unknown pool type.
        """
        if self._started:
            raise RuntimeError("ExecutionManager is already started.")

        # Enable print capture in the main process
        _enable_print_capture()
        self._print_capture_enabled = True

        started_pools: list[PoolType] = []
        try:
            for pool_id, (pool_type, pool_init_kwargs) in self._pool_configs.items():
                self._pools[pool_id] = self._create_pool(pool_type, pool_init_kwargs)
                self._msgs[pool_id] = {}

            # Start all pools
            for pool in self._pools.values():
                await pool.start()
                started_pools.append(pool)

            # Initialize worker jobs tracking for each worker
            for pool_id, pool in self._pools.items():
                for worker_id in range(pool.num_workers):
                    self._worker_jobs[(pool_id, worker_id)] = []

            # Start message receiver tasks after pools are started
            for pool_id in self._pools:
                self._msg_recv_tasks[pool_id] = asyncio.create_task(self._msg_recv_task_func(pool_id))
        except BaseException:
            await self._abort_start(started_pools)
            raise

        self._started = True

    @staticmethod
    def _create_pool(pool_type: type[PoolType], pool_init_kwargs: dict[str, Any]) -> PoolType:
        """Instantiate one pool, wiring in the execution-manager worker function for its type."""
        if 'worker_fn' in pool_init_kwargs:
            raise ValueError("The 'worker_fn' argument should not be specified in the pool config.")

        # print_flush_interval and capture_prints configure the worker function, not the pool
        pool_kwargs = dict(pool_init_kwargs)
        print_flush_interval = pool_kwargs.pop('print_flush_interval', 0.1)
        capture_prints = pool_kwargs.pop('capture_prints', True)

        def make_worker_fn(worker_func):
            return functools.partial(worker_func, print_flush_interval=print_flush_interval, capture_prints=capture_prints)

        if pool_type == ThreadPool:
            return ThreadPool(**pool_kwargs, worker_fn=make_worker_fn(_thread_worker_func))
        if pool_type == MultiprocessPool:
            return MultiprocessPool(**pool_kwargs, worker_fn=make_worker_fn(_multiprocess_worker_func))
        if pool_type == RemotePoolClient:
            return RemotePoolClient(**pool_kwargs)
        if pool_type == SingleWorkerPool:
            return SingleWorkerPool(**pool_kwargs, worker_fn=make_worker_fn(_async_worker_func))
        raise ValueError(f"Unknown pool type: '{pool_type}'.")

    async def _abort_start(self, started_pools: list[PoolType]) -> None:
        """Undo a partially completed `start()` so the manager is left cleanly unstarted."""
        self._release_print_capture()
        for task in self._msg_recv_tasks.values():
            task.cancel()
        self._msg_recv_tasks.clear()
        self._pools.clear()
        self._msgs.clear()
        self._worker_jobs.clear()
        for pool in started_pools:
            try:
                await pool.close()
            except Exception:
                # Best effort: the start-up error being re-raised is the one worth reporting.
                pass

    def _release_print_capture(self) -> None:
        """Drop this manager's reference on the print patch, at most once per acquisition."""
        if self._print_capture_enabled:
            self._print_capture_enabled = False
            _disable_print_capture()

    async def _msg_recv_task_func(self, pool_id: str):
        """Forever receive messages from a pool and queue each one for the request it answers."""
        pool = self._pools[pool_id]
        while True:
            msg = await pool.recv()
            msg_id = msg.data[0]
            msg.data = msg.data[1:]
            msg_queue = self._msgs[pool_id].get(msg_id)
            if msg_queue is None:
                # The request owning this msg_id was cancelled or failed and has already cleaned
                # up, so nobody will read this reply. Drop it rather than crash the receive loop
                # (which would fail every other request on this pool).
                continue
            await msg_queue.put(msg)

    @staticmethod
    def _recv_task_failure(msg_recv_task: asyncio.Task) -> BaseException:
        """Return the exception describing why a finished message receiver task stopped."""
        if msg_recv_task.cancelled():
            # Happens after close(). Re-raising CancelledError in the caller would wrongly
            # look as if the caller itself had been cancelled.
            return RuntimeError("Message receiver task was cancelled (is the ExecutionManager closed?)")
        exc = msg_recv_task.exception()
        if exc is not None:
            return exc
        # Recv task ended without exception - shouldn't happen in normal operation
        return RuntimeError("Message receiver task ended unexpectedly")

    async def _send_msg(self, pool_id: str, worker_id: int, key: str, data: Any) -> str:
        """Send `(msg_id, *data)` under `key` to a worker and return the new msg_id.

        A reply queue is registered for the msg_id before sending; it is removed again if
        sending fails.
        """
        pool = self._pools[pool_id]
        msg_id = str(uuid.uuid4())

        # Check if the message receiver task for the pool is running, else propagate its failure
        msg_recv_task = self._msg_recv_tasks.get(pool_id)
        if msg_recv_task is not None and msg_recv_task.done():
            if msg_recv_task.cancelled() or msg_recv_task.exception() is not None:
                raise self._recv_task_failure(msg_recv_task)

        # Create the queue BEFORE sending to avoid race condition where response
        # arrives before the queue is created
        self._msgs[pool_id][msg_id] = asyncio.Queue()
        try:
            await pool.send(
                worker_id=worker_id,
                key=key,
                data=(msg_id, *data),
            )
        except BaseException:
            # Nobody will wait on this queue; don't leak it
            self._msgs[pool_id].pop(msg_id, None)
            raise

        return msg_id

    async def _wait_next_msg(self, pool_id: str, msg_id: str) -> Any:
        """Wait for the next message queued for `msg_id`, failing fast if the pool's receiver dies."""
        msg_queue = self._msgs[pool_id][msg_id]
        msg_recv_task = self._msg_recv_tasks.get(pool_id)
        if msg_recv_task is None:
            return await msg_queue.get()

        # Wait for either the message or the recv task to complete (crash)
        get_task = asyncio.create_task(msg_queue.get())
        try:
            done, _ = await asyncio.wait(
                [get_task, msg_recv_task],
                return_when=asyncio.FIRST_COMPLETED,
            )
        except BaseException:
            # Caller cancelled: don't leave an orphan task waiting on the queue
            get_task.cancel()
            raise

        # If the recv task completed (crashed), cancel the get and propagate its failure
        if msg_recv_task in done:
            get_task.cancel()
            try:
                await get_task
            except asyncio.CancelledError:
                pass
            raise self._recv_task_failure(msg_recv_task)

        return get_task.result()

    async def _recv_msg(self, pool_id: str, msg_id: str, expect: ExecutionManagerProtocolKeys, close_msg_queue: bool) -> Any:
        """Wait for the next reply to `msg_id` and check that it carries the `expect` key.

        If `close_msg_queue` is True, the reply queue is removed afterwards, including on failure.

        Raises:
            ValueError: If the reply has a different key.
        """
        try:
            msg = await self._wait_next_msg(pool_id, msg_id)
        finally:
            if close_msg_queue:
                self._msgs[pool_id].pop(msg_id, None)

        if msg.key != expect.value:
            raise ValueError(f"Expected message key '{expect.value}', got '{msg.key}'.")

        return msg

    async def run(
        self,
        pool_id: str,
        worker_id: int,
        func_import_path_or_key: str,
        send_channel: bool,
        func_args,
        func_kwargs,
        on_print: Callable[[list[tuple[datetime, str]]], None] | None = None,
    ) -> JobResult:
        """
        Run a function in a pool.

        Args:
            pool_id: The ID of the pool to run the function in.
            worker_id: The index of the worker to run the function on (0 to num_workers - 1).
            func_import_path_or_key: The import path or key of the function to run (for the latter, use send_function to register the function first)
            send_channel: Whether to send the worker RPC channel to the function.
            func_args: The arguments to pass to the function.
            func_kwargs: The keyword arguments to pass to the function.
            on_print: Optional callback called when print output is received from the worker.
                Called with a list of (timestamp, text) tuples. It may be called several times
                during execution as the worker flushes its print buffer.

        Returns:
            A JobResult with the function's result, timestamps and captured prints.

        Raises:
            KeyError: If the pool or worker is unknown.
            ValueError: If the worker replies with an unexpected message.
        """
        worker_key = (pool_id, worker_id)
        worker_jobs = self._worker_jobs[worker_key]

        run_id = str(uuid.uuid4())
        timestamp_utc_submitted = get_timestamp_utc()
        job_info = SubmittedJobInfo(
            run_id=run_id,
            timestamp_utc_submitted=timestamp_utc_submitted,
            timestamp_utc_started=None,
            func_import_path_or_key=func_import_path_or_key,
            pool_id=pool_id,
            worker_id=worker_id,
        )
        # Do the bookkeeping BEFORE the first await, so that concurrent run_allocate() calls
        # already see this worker as busy / most recently used when they pick a worker.
        worker_jobs.append(job_info)
        if worker_key in self._worker_round_robin_lst:
            self._worker_round_robin_lst.remove(worker_key)
        self._worker_round_robin_lst.append(worker_key)

        msg_id = None
        try:
            msg_id = await self._send_msg(
                pool_id=pool_id,
                worker_id=worker_id,
                key=ExecutionManagerProtocolKeys.RUN.value,
                data=(func_import_path_or_key, run_id, send_channel, func_args, func_kwargs),
            )

            # Wait for UP_RUN_STARTED
            started_msg = await self._recv_msg(pool_id, msg_id, expect=ExecutionManagerProtocolKeys.UP_RUN_STARTED, close_msg_queue=False)
            job_info.timestamp_utc_started = started_msg.data[0]  # timestamp_utc_started

            # Accumulate print buffers and wait for UP_RUN_RESPONSE
            accumulated_print_buffer: list[tuple[datetime, str]] = []
            while True:
                msg = await self._wait_next_msg(pool_id, msg_id)
                if msg.key == ExecutionManagerProtocolKeys.UP_PRINT_BUFFER.value:
                    # Intermediate print buffer
                    _run_id, buffer = msg.data
                    accumulated_print_buffer.extend(buffer)
                    if on_print is not None:
                        on_print(buffer)
                elif msg.key == ExecutionManagerProtocolKeys.UP_RUN_RESPONSE.value:
                    response = msg
                    break
                else:
                    raise ValueError(f"Unexpected message key: '{msg.key}'")
        finally:
            # Clean up on success, error and cancellation alike, so the reply queue and the
            # job entry used by LEAST_BUSY allocation do not leak.
            if msg_id is not None:
                self._msgs[pool_id].pop(msg_id, None)
            worker_jobs.remove(job_info)

        _timestamp_utc_started, timestamp_utc_completed, converted_to_str, _res = response.data
        return JobResult(
            timestamp_utc_submitted=job_info.timestamp_utc_submitted,
            timestamp_utc_started=job_info.timestamp_utc_started,
            timestamp_utc_completed=timestamp_utc_completed,
            func_import_path_or_key=job_info.func_import_path_or_key,
            pool_id=job_info.pool_id,
            worker_id=job_info.worker_id,
            converted_to_str=converted_to_str,
            result=_res,
            print_buffer=accumulated_print_buffer,
        )

    async def run_allocate(
        self,
        pool_worker_ids: list[str | tuple[str, int]],
        allocation_method: RunAllocationMethod,
        func_import_path_or_key: str,
        send_channel: bool,
        func_args,
        func_kwargs,
        on_print: Callable[[list[tuple[datetime, str]]], None] | None = None,
    ) -> JobResult:
        """
        Pick a worker with `allocation_method` and run a function on it (see `run`).

        Args:
            pool_worker_ids: Candidate workers. A pool ID (str) stands for every worker of
                that pool; a (pool_id, worker_id) tuple names a single worker.
            allocation_method: ROUND_ROBIN picks the first candidate that was never given a
                job, otherwise the least recently used one. RANDOM picks uniformly at
                random. LEAST_BUSY picks the candidate with the fewest unfinished jobs.
            func_import_path_or_key, send_channel, func_args, func_kwargs, on_print: As for `run`.

        Returns:
            The JobResult from `run`.

        Raises:
            ValueError: If there are no candidate workers or the allocation method is unknown.
        """
        worker_ids: list[tuple[str, int]] = []
        # Convert pool_worker_ids to a list of (pool_id, worker_id) tuples
        for _id in pool_worker_ids:
            if isinstance(_id, str):
                pool = self._pools[_id]
                for worker_id in range(pool.num_workers):
                    worker_ids.append((_id, worker_id))
            else:
                pool_id, worker_id = _id
                worker_ids.append((pool_id, worker_id))

        if not worker_ids:
            raise ValueError("No workers available for allocation.")

        # Select worker based on allocation method
        if allocation_method == RunAllocationMethod.ROUND_ROBIN:
            candidates = set(worker_ids)
            used = [p for p in self._worker_round_robin_lst if p in candidates]
            used_set = set(used)
            never_used = [p for p in worker_ids if p not in used_set]
            pool_id, worker_id = never_used[0] if never_used else used[0]

        elif allocation_method == RunAllocationMethod.RANDOM:
            pool_id, worker_id = random.choice(worker_ids)

        elif allocation_method == RunAllocationMethod.LEAST_BUSY:
            # Jobs stay in _worker_jobs from submission until completion, so the length of a
            # worker's list is its number of active jobs. Unknown workers count as idle.
            pool_id, worker_id = min(worker_ids, key=lambda pw: len(self._worker_jobs.get(pw, ())))

        else:
            raise ValueError(f"Unknown allocation method: '{allocation_method}'.")

        return await self.run(
            pool_id=pool_id,
            worker_id=worker_id,
            func_import_path_or_key=func_import_path_or_key,
            send_channel=send_channel,
            func_args=func_args,
            func_kwargs=func_kwargs,
            on_print=on_print,
        )

    async def send_function(self, pool_id: str, worker_id: int, func_key: str, func: Callable[..., Any]|Callable[..., Awaitable[Any]]) -> None:
        """
        Send a function to a worker in a pool, such that it can be run using the given key using `ExecutionManager.run`.

        Args:
            pool_id: The ID of the pool to send the function to.
            worker_id: The index of the worker to send the function to.
            func_key: The key of the function to send.
            func: The function to send.

        Raises:
            ValueError: If the pool is a MultiprocessPool or RemotePoolClient and `func` cannot be pickled.
        """
        # If the pool is a multiprocess pool or remote pool, the function needs to be pickleable.
        pool = self._pools[pool_id]
        if isinstance(pool, (MultiprocessPool, RemotePoolClient)):
            try:
                pickle.dumps(func)
            except (pickle.PicklingError, TypeError, AttributeError) as e:
                # Local functions raise AttributeError and some objects raise TypeError
                # rather than PicklingError.
                raise ValueError(f"Function {func} (key = '{func_key}') is not pickleable. Cannot send to worker in pool '{pool_id}'.") from e

        msg_id = await self._send_msg(
            pool_id=pool_id,
            worker_id=worker_id,
            key=ExecutionManagerProtocolKeys.SEND_FUNCTION.value,
            data=(func_key, func),
        )
        await self._recv_msg(pool_id, msg_id, expect=ExecutionManagerProtocolKeys.UP_SEND_FUNCTION_RESPONSE, close_msg_queue=True)

    async def send_function_to_pool(self, pool_id: str, func_key: str, func: Callable[..., Any]|Callable[..., Awaitable[Any]]) -> None:
        """
        Send a function to all workers in a pool.

        Args:
            pool_id: The ID of the pool to send the function to.
            func_key: The key of the function to send.
            func: The function to send.
        """
        tasks = [asyncio.create_task(self.send_function(pool_id, worker_id, func_key, func)) for worker_id in range(self._pools[pool_id].num_workers)]
        await asyncio.gather(*tasks)

    async def close(self):
        """Close the execution manager and all its pools.

        Receiver tasks are cancelled first, then every pool is closed. Even if closing a pool
        fails, the manager's reference on the print patch is released and the manager is
        marked as not started. If a receiver task had failed (rather than been cancelled),
        its exception is re-raised at the end.
        """
        # Cancel message receiver tasks FIRST to prevent them from trying to
        # recv from closed pools (which would raise PoolNotStarted)
        errors = []
        for task in self._msg_recv_tasks.values():
            task.cancel()
        for task in self._msg_recv_tasks.values():
            try:
                await task
            except asyncio.CancelledError:
                pass
            except Exception as e:
                errors.append(e)

        try:
            # Now close the pools
            for pool in self._pools.values():
                await pool.close()
        finally:
            # Release the main-process print capture only if start() acquired it, so that a
            # close() without start(), or a second close(), cannot unbalance the refcount.
            self._release_print_capture()
            self._started = False

        # Propagate any errors from the recv tasks
        if errors:
            raise errors[0]

    @property
    def pools(self) -> list[tuple[str, type[PoolType]]]:
        """Get (pool_id, pool class) pairs for the pools created by `start()`."""
        return [(k, type(v)) for k, v in self._pools.items()]

    def get_num_workers(self, pool_id: str) -> int:
        """Get the number of workers in a pool."""
        return self._pools[pool_id].num_workers

    def get_worker_ids(self, pool_id: str) -> list[str]:
        """Get string labels ("<pool_id>_<index>") for the workers of a pool.

        Note: these are labels only. `run`, `send_function` and `get_worker_jobs` take the
        integer worker index (0 to num_workers - 1).
        """
        return [f"{pool_id}_{i}" for i in range(self._pools[pool_id].num_workers)]

    def get_worker_jobs(self, pool_id: str, worker_id: int) -> list[SubmittedJobInfo]:
        """Get a copy of the list of currently submitted (unfinished) jobs for a worker."""
        return list(self._worker_jobs.get((pool_id, worker_id), []))

    def get_process_ids(self, pool_id: str) -> list[int]:
        """Get all process IDs for a MultiprocessPool.

        Args:
            pool_id: The ID of the pool.

        Returns:
            List of process indices (0 to num_processes-1).

        Raises:
            ValueError: If the pool is not a MultiprocessPool.
        """
        pool = self._pools[pool_id]
        if not isinstance(pool, MultiprocessPool):
            raise ValueError(f"Pool '{pool_id}' is not a MultiprocessPool (got {type(pool).__name__})")
        return list(range(pool.num_processes))

    async def flush_pool_stdout(
        self, pool_id: str, process_idx: int, timeout: float | None = None
    ) -> list[tuple[datetime, bool, str]]:
        """Flush and retrieve stdout/stderr buffer from a specific process in a MultiprocessPool.

        Args:
            pool_id: The ID of the pool.
            process_idx: Index of the process (0 to num_processes-1).
            timeout: Maximum time to wait for response in seconds.

        Returns:
            List of (timestamp, is_stdout, text) tuples.
            is_stdout is True for stdout, False for stderr.

        Raises:
            ValueError: If the pool is not a MultiprocessPool.
            PoolNotStarted: If the pool is not running.
        """
        pool = self._pools[pool_id]
        if not isinstance(pool, MultiprocessPool):
            raise ValueError(f"Pool '{pool_id}' is not a MultiprocessPool (got {type(pool).__name__})")
        return await pool.flush_stdout(process_idx, timeout=timeout)

    async def flush_all_pool_stdout(
        self, pool_id: str, timeout: float | None = None
    ) -> dict[int, list[tuple[datetime, bool, str]]]:
        """Flush and retrieve stdout/stderr buffers from all processes in a MultiprocessPool.

        Args:
            pool_id: The ID of the pool.
            timeout: Maximum time to wait for each process response.

        Returns:
            Dict mapping process_idx to list of (timestamp, is_stdout, text) tuples.

        Raises:
            ValueError: If the pool is not a MultiprocessPool.
            PoolNotStarted: If the pool is not running.
        """
        pool = self._pools[pool_id]
        if not isinstance(pool, MultiprocessPool):
            raise ValueError(f"Pool '{pool_id}' is not a MultiprocessPool (got {type(pool).__name__})")
        return await pool.flush_all_stdout(timeout=timeout)

    async def __aenter__(self) -> "ExecutionManager":
        """Context manager entry - starts the manager."""
        await self.start()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
        """Context manager exit - closes the manager."""
        await self.close()

## Examples

The following examples demonstrate how to use the ExecutionManager.

### Example 1: Basic usage with ThreadPool

This example shows the basic workflow of running a function on a worker.

In [None]:
# Banner for this example's output.
print("=" * 50)
print("Example: Basic ExecutionManager Usage")
print("=" * 50)

def example_add(a: int, b: int) -> int:
    """A simple function that adds two numbers."""
    # When this runs on a worker with print capture enabled, the text is collected into
    # result.print_buffer (as the recorded output shows) instead of being written to stdout.
    print(f"Adding {a} + {b}")
    return a + b

# Create an ExecutionManager with a thread pool
# (pool_id -> (pool class, pool constructor kwargs)).
manager = ExecutionManager({
    "workers": (ThreadPool, {"num_workers": 2}),
})

# `async with` starts the manager on entry and closes it on exit. Top-level `async with`
# relies on the notebook's top-level await support.
async with manager:
    print(f"Pool IDs: {[pool_id for pool_id, _ in manager.pools]}")
    print(f"Workers in 'workers' pool: {manager.get_num_workers('workers')}")

    # Send a function to all workers in the pool
    await manager.send_function_to_pool("workers", "add", example_add)

    # Run the function on worker 0
    result = await manager.run(
        pool_id="workers",
        worker_id=0,
        func_import_path_or_key="add",
        send_channel=False,
        func_args=(3, 4),
        func_kwargs={},
    )

    # Inspect the returned JobResult.
    print(f"\nResult: {result.result}")
    print(f"Submitted at: {result.timestamp_utc_submitted}")
    print(f"Started at: {result.timestamp_utc_started}")
    print(f"Completed at: {result.timestamp_utc_completed}")
    print(f"Was converted to str: {result.converted_to_str}")
    print(f"Print buffer: {result.print_buffer}")

Example: Basic ExecutionManager Usage
Pool IDs: ['workers']
Workers in 'workers' pool: 2

Result: 7
Submitted at: 2026-01-19 11:21:43.219941+00:00
Started at: 2026-01-19 11:21:43.219997+00:00
Completed at: 2026-01-19 11:21:43.220022+00:00
Was converted to str: False
Print buffer: [(datetime.datetime(2026, 1, 19, 11, 21, 43, 220011, tzinfo=datetime.timezone.utc), 'Adding 3 + 4\n')]


### Example 2: Running multiple jobs with allocation

This example shows how to use `run_allocate` with an allocation method (here `RunAllocationMethod.ROUND_ROBIN`) to distribute jobs across the workers of a pool.

In [None]:
def example_multiply(x: int, y: int) -> int:
    """Multiply two numbers."""
    import time
    time.sleep(0.1)  # Simulate some work
    print(f"Multiplying {x} * {y}")
    return x * y

print("=" * 50)
print("Example: Job Allocation")
print("=" * 50)

manager = ExecutionManager({
    "compute": (ThreadPool, {"num_workers": 3}),
})

async with manager:
    # Send the multiply function to all workers
    await manager.send_function_to_pool("compute", "multiply", example_multiply)

    # Run multiple jobs using round-robin allocation
    print("\nRunning 6 jobs with ROUND_ROBIN allocation:")
    tasks = []
    for i in range(6):
        task = asyncio.create_task(
            manager.run_allocate(
                pool_worker_ids=["compute"],  # Use all workers in "compute" pool
                allocation_method=RunAllocationMethod.ROUND_ROBIN,
                func_import_path_or_key="multiply",
                send_channel=False,
                func_args=(i, i + 1),
                func_kwargs={},
            )
        )
        tasks.append(task)

    results = await asyncio.gather(*tasks)
    for i, result in enumerate(results):
        print(f"  Job {i}: {result.result} (worker {result.worker_id})")

Example: Job Allocation

Running 6 jobs with ROUND_ROBIN allocation:


  Job 0: 0 (worker 0)
  Job 1: 2 (worker 0)
  Job 2: 6 (worker 0)
  Job 3: 12 (worker 0)
  Job 4: 20 (worker 0)
  Job 5: 30 (worker 0)


### Example 3: Streaming print output with on_print callback

This example shows how to use the `on_print` callback to receive printed output in near real time, batch by batch, as the worker flushes it. The pool's `print_flush_interval` option controls how often the worker sends its accumulated prints.

In [None]:
import time as _time

def example_slow_print(iterations: int, delay: float) -> str:
    """A function that prints multiple times with delays."""
    for i in range(iterations):
        print(f"Processing step {i + 1}/{iterations}...")
        _time.sleep(delay)
    return f"Completed {iterations} steps"

print("=" * 50)
print("Example: Streaming Print Output")
print("=" * 50)

# Use a short flush interval (50ms) to see prints arrive in chunks
manager = ExecutionManager({
    "workers": (ThreadPool, {"num_workers": 1, "print_flush_interval": 0.05}),
})

received_count = 0

def print_callback(buffer):
    global received_count
    received_count += 1
    print(f"  [Callback {received_count}] Received {len(buffer)} print(s)")
    for timestamp, text in buffer:
        print(f"    {text.strip()}")

async with manager:
    await manager.send_function_to_pool("workers", "slow_print", example_slow_print)

    print("\nRunning function with on_print callback:")
    result = await manager.run(
        pool_id="workers",
        worker_id=0,
        func_import_path_or_key="slow_print",
        send_channel=False,
        func_args=(5, 0.1),  # 5 iterations, 100ms delay each
        func_kwargs={},
        on_print=print_callback,
    )

    print(f"\nResult: {result.result}")
    print(f"Total prints captured: {len(result.print_buffer)}")

Example: Streaming Print Output

Running function with on_print callback:
  [Callback 1] Received 2 print(s)
    Processing step 1/5...
    Processing step 2/5...


  [Callback 2] Received 1 print(s)
    Processing step 3/5...
  [Callback 3] Received 1 print(s)
    Processing step 4/5...


  [Callback 4] Received 1 print(s)
    Processing step 5/5...

Result: Completed 5 steps
Total prints captured: 5


In [None]:
# Keep just the text of each captured print, dropping its timestamp.
[text for _ts, text in result.print_buffer]

['Processing step 1/5...\n',
 'Processing step 2/5...\n',
 'Processing step 3/5...\n',
 'Processing step 4/5...\n',
 'Processing step 5/5...\n']