# Thread and Process Pools

This module provides pool management for executing nodes in parallel.
It supports pools backed by either `ThreadPoolExecutor` or `ProcessPoolExecutor`; work sent to a process pool must be picklable.

In [None]:
#|default_exp pools

In [None]:
#|hide
from nblite import nbl_export, show_doc; nbl_export();

In [None]:
#|export
import asyncio
import threading
import pickle
import multiprocessing
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor, Future
from dataclasses import dataclass, field
from enum import Enum, auto
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, TYPE_CHECKING
from queue import Queue, Empty

if TYPE_CHECKING:
    from netrun.net import Net


class PoolType(Enum):
    """Kind of executor that backs a pool: threads or processes."""

    THREAD = 1  # backed by concurrent.futures.ThreadPoolExecutor
    PROCESS = 2  # backed by concurrent.futures.ProcessPoolExecutor


@dataclass
class PoolConfig:
    """Configuration for a single named pool.

    Attributes:
        name: Name the pool is registered and looked up under.
        pool_type: Whether the pool is backed by threads or processes.
        size: Number of workers; must be an integer >= 1.

    Raises:
        TypeError: If ``size`` is not an integer (e.g. ``4.0`` or ``"4"``).
        ValueError: If ``size`` is less than 1.
    """
    name: str
    pool_type: PoolType
    size: int

    def __post_init__(self):
        import operator

        # Reject non-integral sizes up front: a float such as 4.0 would pass
        # the range check below and only fail later (e.g. in ``range(size)``
        # when the pool starts), and a string fails with an opaque comparison
        # error. ``operator.index`` accepts any true integer type.
        try:
            operator.index(self.size)
        except TypeError:
            raise TypeError(
                f"Pool size must be an integer, got {self.size!r}"
            ) from None
        if self.size < 1:
            raise ValueError(f"Pool size must be at least 1, got {self.size}")


class PoolInitMode(Enum):
    """How start/stop functions are called for pooled nodes.

    Members:
        PER_WORKER: called once per worker thread/process.
        GLOBAL: called once globally.
    """

    PER_WORKER = "per_worker"
    GLOBAL = "global"

## Pool Manager

`WorkerState` tracks per-worker task counts, `ManagedPool` wraps a single executor, and the `PoolManager` class manages all thread and process pools for a Net.

In [None]:
#|export
class WorkerState:
    """Bookkeeping for one worker slot of a pool.

    Holds a count of in-flight tasks and the names of nodes that have been
    initialized on this worker. Every method serializes on a private
    ``threading.Lock``.
    """

    def __init__(self, worker_id: int, pool_name: str):
        self.worker_id = worker_id
        self.pool_name = pool_name
        self._lock = threading.Lock()
        self.initialized_nodes: Set[str] = set()
        self.active_tasks: int = 0

    def _adjust_active(self, delta: int) -> None:
        # Single place where the task counter changes, always under the lock.
        with self._lock:
            self.active_tasks = self.active_tasks + delta

    def acquire_task(self) -> None:
        """Record that one more task has started on this worker."""
        self._adjust_active(1)

    def release_task(self) -> None:
        """Record that a task running on this worker has finished."""
        self._adjust_active(-1)

    def mark_node_initialized(self, node_name: str) -> None:
        """Remember that ``node_name`` has been initialized on this worker."""
        with self._lock:
            self.initialized_nodes |= {node_name}

    def is_node_initialized(self, node_name: str) -> bool:
        """Return True if ``node_name`` was marked initialized on this worker."""
        with self._lock:
            found = node_name in self.initialized_nodes
        return found


class ManagedPool:
    """Wrapper around a ThreadPoolExecutor or ProcessPoolExecutor.

    Besides owning the executor, the pool keeps ``size`` ``WorkerState``
    trackers (used by callers for load-aware scheduling via
    ``submit(..., worker_state=...)``) and a set of node names that have been
    initialized once for the whole pool.
    """

    def __init__(self, config: "PoolConfig"):
        self.config = config
        self.name = config.name
        self.pool_type = config.pool_type
        self.size = config.size

        self._executor: Optional[ThreadPoolExecutor | ProcessPoolExecutor] = None
        self._workers: List[WorkerState] = []
        self._started = False
        self._lock = threading.Lock()

        # For tracking which nodes have been globally initialized
        self._globally_initialized_nodes: Set[str] = set()

    def start(self) -> None:
        """Create the executor and worker trackers; no-op if already started."""
        with self._lock:
            if self._started:
                return

            if self.pool_type == PoolType.THREAD:
                self._executor = ThreadPoolExecutor(
                    max_workers=self.size,
                    thread_name_prefix=f"netrun-{self.name}-"
                )
            else:
                self._executor = ProcessPoolExecutor(max_workers=self.size)

            # Create worker state trackers
            self._workers = [WorkerState(i, self.name) for i in range(self.size)]
            self._started = True

    def stop(self) -> None:
        """Stop the pool and wait for pending tasks.

        The executor is detached under the lock but shut down outside it.
        Tasks that call back into this pool while we wait for them (e.g.
        ``mark_node_globally_initialized``), and threads querying worker
        counts, therefore do not deadlock or block behind ``stop``.
        """
        with self._lock:
            if not self._started or self._executor is None:
                return
            executor = self._executor
            self._executor = None
            self._workers = []
            self._started = False

        executor.shutdown(wait=True)

        with self._lock:
            # Clear only after the tasks above have finished (they may still
            # mark nodes as initialized), unless the pool was restarted since.
            if not self._started:
                self._globally_initialized_nodes.clear()

    def get_available_workers(self) -> int:
        """Get the number of workers with no active tasks."""
        with self._lock:
            return sum(1 for w in self._workers if w.active_tasks == 0)

    def get_total_active_tasks(self) -> int:
        """Get total number of active tasks across all workers."""
        with self._lock:
            return sum(w.active_tasks for w in self._workers)

    def get_least_busy_worker(self) -> Optional["WorkerState"]:
        """Get the worker with the fewest active tasks (None if not started)."""
        with self._lock:
            if not self._workers:
                return None
            return min(self._workers, key=lambda w: w.active_tasks)

    def submit(
        self,
        fn: Callable,
        *args,
        worker_state: Optional["WorkerState"] = None,
        **kwargs
    ) -> Future:
        """Submit ``fn(*args, **kwargs)`` to the pool.

        Args:
            fn: Callable to run. For process pools, ``fn`` and its arguments
                must be picklable.
            worker_state: Optional tracker whose task count is incremented
                now and decremented once the task finishes.

        Returns:
            The executor's ``Future`` for the task.

        Raises:
            RuntimeError: If the pool is not started, or the executor refuses
                the task (e.g. because it is shutting down).
        """
        # Snapshot under the lock so a concurrent stop() cannot swap the
        # executor to None between the check and the submit.
        with self._lock:
            executor = self._executor if self._started else None
        if executor is None:
            raise RuntimeError(f"Pool '{self.name}' is not started")

        if worker_state is None:
            return executor.submit(fn, *args, **kwargs)

        worker_state.acquire_task()
        try:
            if isinstance(executor, ProcessPoolExecutor):
                # A nested wrapper cannot be pickled for a child process, and a
                # release performed in the child would only touch the child's
                # copy of worker_state. So the release happens in the parent
                # when the future settles. Note that it may run just after
                # waiters on the future have been woken.
                future = executor.submit(fn, *args, **kwargs)
            else:
                def wrapped_fn(*call_args, **call_kwargs):
                    try:
                        return fn(*call_args, **call_kwargs)
                    finally:
                        # Runs in the worker thread, before the result is set.
                        worker_state.release_task()

                return executor.submit(wrapped_fn, *args, **kwargs)
        except BaseException:
            # The task was never scheduled, so undo the acquire.
            worker_state.release_task()
            raise

        future.add_done_callback(lambda _future: worker_state.release_task())
        return future

    def mark_node_globally_initialized(self, node_name: str) -> None:
        """Mark a node as globally initialized."""
        with self._lock:
            self._globally_initialized_nodes.add(node_name)

    def is_node_globally_initialized(self, node_name: str) -> bool:
        """Check if a node has been globally initialized."""
        with self._lock:
            return node_name in self._globally_initialized_nodes

    @property
    def is_thread_pool(self) -> bool:
        """Check if this is a thread pool."""
        return self.pool_type == PoolType.THREAD

    @property
    def is_process_pool(self) -> bool:
        """Check if this is a process pool."""
        return self.pool_type == PoolType.PROCESS


class PoolManager:
    """Manages all pools for a Net instance.

    Thread and process pools share a single name namespace: ``get_pool`` and
    ``select_pool`` look pools up by name alone.
    """

    def __init__(
        self,
        thread_pools: Optional[Dict[str, dict]] = None,
        process_pools: Optional[Dict[str, dict]] = None,
    ):
        """
        Initialize the pool manager.

        Args:
            thread_pools: Dict of pool name -> config ({"size": N}); size defaults to 4
            process_pools: Dict of pool name -> config ({"size": N}); size defaults to 4

        Raises:
            ValueError: If a name appears in both ``thread_pools`` and
                ``process_pools``, or if ``PoolConfig`` rejects a pool size.
        """
        self._pools: Dict[str, ManagedPool] = {}
        self._started = False
        self._lock = threading.Lock()

        sources = (
            (PoolType.THREAD, thread_pools),
            (PoolType.PROCESS, process_pools),
        )
        for pool_type, pools in sources:
            for name, config in (pools or {}).items():
                # Previously a process pool silently replaced a thread pool of
                # the same name; reject the ambiguous configuration instead.
                if name in self._pools:
                    raise ValueError(
                        f"Pool name '{name}' is used by more than one pool; "
                        "thread and process pool names must be unique"
                    )
                pool_config = PoolConfig(
                    name=name,
                    pool_type=pool_type,
                    size=config.get("size", 4),
                )
                self._pools[name] = ManagedPool(pool_config)

    def start(self) -> None:
        """Start all pools."""
        with self._lock:
            if self._started:
                return
            for pool in self._pools.values():
                pool.start()
            self._started = True

    def stop(self) -> None:
        """Stop all pools.

        Every pool is asked to stop even if an earlier one raises, so no
        executor is left running; the first error is re-raised afterwards.
        """
        first_error: Optional[Exception] = None
        with self._lock:
            if not self._started:
                return
            for pool in self._pools.values():
                try:
                    pool.stop()
                except Exception as exc:
                    if first_error is None:
                        first_error = exc
            self._started = False
        if first_error is not None:
            raise first_error

    def get_pool(self, name: str) -> Optional["ManagedPool"]:
        """Get a pool by name."""
        return self._pools.get(name)

    def get_pools(self, names: List[str]) -> List["ManagedPool"]:
        """Get multiple pools by name, skipping unknown names."""
        return [self._pools[n] for n in names if n in self._pools]

    def select_pool(
        self,
        pool_names: List[str],
        algorithm: str = "least_busy"
    ) -> Optional["ManagedPool"]:
        """
        Select a pool from multiple options.

        Args:
            pool_names: List of pool names to choose from
            algorithm: Selection algorithm ("least_busy")

        Returns:
            The selected pool, or None if no valid pools

        Raises:
            ValueError: If ``algorithm`` is not recognised.
        """
        # Validate first so a bad algorithm is reported even when no names match.
        if algorithm != "least_busy":
            raise ValueError(f"Unknown pool selection algorithm: {algorithm}")

        pools = self.get_pools(pool_names)
        if not pools:
            return None

        # Select pool with most available workers (least busy); ties go to
        # the first matching pool in pool_names order.
        return max(pools, key=lambda p: p.get_available_workers())

    def has_pools(self) -> bool:
        """Check if any pools are configured."""
        return len(self._pools) > 0

    @property
    def pool_names(self) -> List[str]:
        """Get all pool names."""
        return list(self._pools.keys())

## Process Pool Utilities

Utilities for running tasks in process pools, including pickling validation.

In [None]:
#|export
def validate_picklable(value: Any, context: str = "") -> None:
    """
    Check that ``value`` survives ``pickle.dumps``, as process pools require.

    Args:
        value: Object that will be sent to a worker process.
        context: Optional label included in the error message.

    Raises:
        ValueError: If pickling fails with ``PicklingError``, ``TypeError`` or
            ``AttributeError``; the original error is chained as the cause.
    """
    try:
        pickle.dumps(value)
    except (pickle.PicklingError, TypeError, AttributeError) as err:
        label = "" if not context else f" ({context})"
        message = f"Value{label} cannot be pickled for process pool execution: {err}"
        raise ValueError(message) from err


def run_in_process_with_event_loop(func: Callable, *args, **kwargs) -> Any:
    """
    Run ``func(*args, **kwargs)``, driving it on a fresh event loop if async.

    Synchronous callables are simply called and their result returned. If the
    call produces a coroutine, a new event loop is created for this call and
    the coroutine is run to completion on it. This covers ``async def``
    functions, ``functools.partial`` wrappers of them, and objects with an
    ``async def __call__``.

    Afterwards the loop's async generators are shut down, the loop is closed,
    and it is unregistered as the current event loop. A reused pool worker is
    therefore not left holding a closed loop.

    Returns:
        The function's (or awaited coroutine's) return value.
    """
    import asyncio
    import inspect

    # Calling an async function only creates the coroutine object; its body
    # does not run until it is awaited on the loop below.
    result = func(*args, **kwargs)
    if not inspect.iscoroutine(result):
        return result

    # Create new event loop for this process
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        return loop.run_until_complete(result)
    finally:
        try:
            loop.run_until_complete(loop.shutdown_asyncgens())
        finally:
            asyncio.set_event_loop(None)
            loop.close()

## Background Net Runner

`BackgroundNetRunner` runs a Net in a background thread and provides methods to wait for, poll, pause, and stop it.

In [None]:
#|export
class BackgroundNetRunner:
    """
    Runs a Net in a background thread with control methods.

    Provides methods to control the net execution:
    - wait_until_blocked(): Wait for net to block
    - poll(): Check if net is blocked
    - pause(): Pause net execution
    - stop(): Stop net execution
    - get_exception(): Get an exception raised by the background loop

    The "blocked" flag is set in three cases:
    - a step makes no progress and no epochs can be started;
    - the net is paused or stopped;
    - the background loop raised an exception.

    The flag is cleared again whenever a step makes progress, so it reflects
    the net's current state rather than "has ever blocked".
    """

    def __init__(self, net: "Net"):
        self._net = net
        self._thread: Optional[threading.Thread] = None
        self._stop_event = threading.Event()
        self._blocked_event = threading.Event()
        self._exception: Optional[Exception] = None
        self._lock = threading.Lock()

    def start(self) -> None:
        """Start running the net in a background (daemon) thread.

        Raises:
            RuntimeError: If the background thread is already running.
        """
        with self._lock:
            if self._thread is not None and self._thread.is_alive():
                raise RuntimeError("Net is already running in background")

            self._stop_event.clear()
            self._blocked_event.clear()
            self._exception = None

            self._thread = threading.Thread(
                target=self._run_loop,
                name="netrun-background",
                daemon=True,
            )
            self._thread.start()

    def _run_loop(self) -> None:
        """Main loop for background execution.

        Steps the net until stop() is requested, the net is paused/stopped, or
        an exception occurs (which is stored for get_exception()).
        """
        try:
            from netrun.net import NetState

            # Run until blocked, stopped, or error
            while not self._stop_event.is_set():
                # Check if net is paused or stopped
                if self._net.state in (NetState.PAUSED, NetState.STOPPED):
                    self._blocked_event.set()
                    break

                # Run one step
                had_work = self._net.run_step(start_epochs=True)

                if had_work:
                    # Progress was made, so any earlier "blocked" report is stale.
                    self._blocked_event.clear()
                elif not self._net.get_startable_epochs():
                    # Fully blocked: report it, then back off briefly before
                    # checking again (returns early if stop() is requested).
                    self._blocked_event.set()
                    if self._stop_event.wait(timeout=0.01):
                        break
        except Exception as e:
            self._exception = e
            self._blocked_event.set()

    def wait_until_blocked(self, timeout: Optional[float] = None) -> bool:
        """
        Wait until the net is blocked (no progress can be made).

        Returns True if blocked, False if timeout expired.
        """
        return self._blocked_event.wait(timeout=timeout)

    def poll(self) -> bool:
        """Check if the net is currently blocked."""
        return self._blocked_event.is_set()

    def pause(self) -> None:
        """Request the net to pause (delegates to the net's pause())."""
        self._net.pause()

    def stop(self) -> None:
        """Stop the background execution, waiting up to 5 seconds for the thread."""
        self._stop_event.set()
        if self._thread is not None:
            self._thread.join(timeout=5.0)

    def get_exception(self) -> Optional[Exception]:
        """Get any exception that occurred during execution."""
        return self._exception

    @property
    def is_running(self) -> bool:
        """Check if the background thread is running."""
        return self._thread is not None and self._thread.is_alive()