# Tests for MultiprocessPool

In [None]:
#|default_exp pool.test_multiprocess

In [None]:
#|export
import pytest
import asyncio
from netrun.rpc.base import RecvTimeout
from netrun.pool.base import (
    PoolNotStarted,
    PoolAlreadyStarted,
)
from netrun.pool.multiprocess import MultiprocessPool

## Import Worker Functions

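The worker functions live in a separate, importable module (`.workers`) rather than in this notebook. `multiprocessing` must be able to pickle them by reference and re-import them by name inside each child process, which does not work for functions defined in a notebook cell.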

In [None]:
#|export
from .workers import echo_worker, compute_worker, pid_worker

## Test Pool Creation

In [None]:
#|export
def test_pool_creation():
    """A newly constructed pool exposes its sizing and has not started yet."""
    sized_pool = MultiprocessPool(echo_worker, num_processes=2, threads_per_process=2)

    # 2 processes x 2 threads each -> 4 addressable workers.
    expected_sizes = {"num_workers": 4, "num_processes": 2, "threads_per_process": 2}
    for attr_name, expected in expected_sizes.items():
        assert getattr(sized_pool, attr_name) == expected

    assert not sized_pool.is_running

In [None]:
test_pool_creation();

In [None]:
#|export
def test_pool_default_threads():
    """Leaving out threads_per_process yields one worker per process."""
    default_pool = MultiprocessPool(echo_worker, num_processes=3)

    per_process = default_pool.threads_per_process
    total = default_pool.num_workers
    assert total == 3
    assert per_process == 1

In [None]:
test_pool_default_threads();

In [None]:
#|export
def test_pool_invalid_num_processes():
    """A non-positive num_processes (zero or negative) is rejected with ValueError."""
    for bad_count in (0, -1):
        with pytest.raises(ValueError):
            MultiprocessPool(echo_worker, num_processes=bad_count)

In [None]:
test_pool_invalid_num_processes();

In [None]:
#|export
def test_pool_invalid_threads_per_process():
    """Test that invalid threads_per_process raises ValueError.

    Both zero and a negative value are checked, matching the coverage used
    for num_processes validation.
    """
    for bad_threads in (0, -1):
        with pytest.raises(ValueError):
            MultiprocessPool(echo_worker, num_processes=2, threads_per_process=bad_threads)

In [None]:
test_pool_invalid_threads_per_process();

## Test Pool Lifecycle

In [None]:
#|export
@pytest.mark.asyncio
async def test_start_and_close():
    """Test starting and closing a pool.

    The pool is closed in a ``finally`` block so that a failing assertion
    after ``start()`` cannot leave worker processes running behind the test.
    """
    pool = MultiprocessPool(echo_worker, num_processes=1, threads_per_process=1)

    assert not pool.is_running
    await pool.start()
    try:
        assert pool.is_running
    finally:
        await pool.close()
    assert not pool.is_running

In [None]:
await test_start_and_close();

In [None]:
#|export
@pytest.mark.asyncio
async def test_double_start_raises():
    """Calling start() on an already-running pool raises PoolAlreadyStarted."""
    running_pool = MultiprocessPool(echo_worker, num_processes=1)
    await running_pool.start()

    try:
        with pytest.raises(PoolAlreadyStarted):
            await running_pool.start()
    finally:
        # Close the pool started above even if the expectation fails.
        await running_pool.close()

In [None]:
await test_double_start_raises();

In [None]:
#|export
@pytest.mark.asyncio
async def test_context_manager():
    """The pool is running inside `async with` and no longer running after it."""
    managed = MultiprocessPool(echo_worker, num_processes=1)
    async with managed as pool:
        assert pool.is_running

    assert not pool.is_running

In [None]:
await test_context_manager();

## Test Send/Recv

In [None]:
#|export
@pytest.mark.asyncio
async def test_send_recv_single():
    """One message sent to worker 0 comes back as an echo tagged with that worker."""
    async with MultiprocessPool(echo_worker, num_processes=1, threads_per_process=1) as pool:
        await pool.send(worker_id=0, key="test", data="hello")
        reply = await pool.recv(timeout=10.0)

        # echo_worker prefixes the key and wraps the payload with its worker id.
        expected_payload = {"worker_id": 0, "data": "hello"}
        assert reply.worker_id == 0
        assert reply.key == "echo:test"
        assert reply.data == expected_payload

In [None]:
await test_send_recv_single();

In [None]:
#|export
@pytest.mark.asyncio
async def test_send_recv_multiple_workers():
    """Test sending to multiple workers across processes.

    Worker ``i`` is sent ``data=i``. Every worker must reply exactly once, and
    each reply must echo back exactly the payload sent to that worker. This
    catches messages that are routed to the wrong worker, which a check on
    worker ids alone would miss.
    """
    async with MultiprocessPool(echo_worker, num_processes=2, threads_per_process=2) as pool:
        num_workers = pool.num_workers

        # Send to each worker
        for i in range(num_workers):
            await pool.send(worker_id=i, key="ping", data=i)

        # Receive all responses
        responses = [await pool.recv(timeout=10.0) for _ in range(num_workers)]

        assert len(responses) == num_workers
        assert {msg.worker_id for msg in responses} == set(range(num_workers))
        for msg in responses:
            assert msg.key == "echo:ping"
            # echo_worker replies {"worker_id": ..., "data": ...}; worker i was sent i.
            assert msg.data == {"worker_id": msg.worker_id, "data": msg.worker_id}

In [None]:
await test_send_recv_multiple_workers();

## Test Worker ID Mapping

In [None]:
#|export
@pytest.mark.asyncio
async def test_worker_id_mapping():
    """Worker ids are grouped by process.

    With 2 processes x 2 threads, ids 0 and 1 share the first process and
    ids 2 and 3 share the second.
    """
    async with MultiprocessPool(pid_worker, num_processes=2, threads_per_process=2) as pool:
        all_ids = range(4)
        for wid in all_ids:
            await pool.send(worker_id=wid, key="get_pid", data=None)

        pid_of = {}
        for _ in all_ids:
            reply = await pool.recv(timeout=10.0)
            pid_of[reply.data["worker_id"]] = reply.data["pid"]

        first_proc, second_proc = pid_of[0], pid_of[2]
        # Threads in the same process report the same PID.
        assert pid_of[1] == first_proc
        assert pid_of[3] == second_proc
        # The two groups run in distinct processes.
        assert first_proc != second_proc

In [None]:
await test_worker_id_mapping();

## Test try_recv

In [None]:
#|export
@pytest.mark.asyncio
async def test_try_recv_empty():
    """try_recv returns None when nothing has been sent to the pool."""
    async with MultiprocessPool(echo_worker, num_processes=1) as pool:
        assert await pool.try_recv() is None

In [None]:
await test_try_recv_empty();

In [None]:
#|export
@pytest.mark.asyncio
async def test_try_recv_with_message():
    """Test try_recv with pending message.

    Instead of sleeping a fixed 0.5 s and hoping the worker has replied, this
    polls ``try_recv`` until a reply appears or a 10 s deadline passes. The
    deadline matches the ``recv`` timeouts used elsewhere, so the test does
    not fail spuriously on slow or loaded machines.
    """
    async with MultiprocessPool(echo_worker, num_processes=1) as pool:
        await pool.send(worker_id=0, key="test", data="data")

        loop = asyncio.get_running_loop()
        deadline = loop.time() + 10.0
        result = await pool.try_recv()
        while result is None and loop.time() < deadline:
            await asyncio.sleep(0.05)  # Let worker process
            result = await pool.try_recv()

        assert result is not None
        assert result.key == "echo:test"

In [None]:
await test_try_recv_with_message();

## Test Broadcast

In [None]:
#|export
@pytest.mark.asyncio
async def test_broadcast():
    """Test broadcasting to all workers.

    Every worker must reply exactly once, and each reply must echo the
    broadcast key and payload. This confirms the broadcast content actually
    reached every worker, not just that each worker produced some reply.
    """
    payload = {"setting": "value"}
    async with MultiprocessPool(echo_worker, num_processes=2, threads_per_process=2) as pool:
        await pool.broadcast("config", payload)

        responses = [await pool.recv(timeout=10.0) for _ in range(pool.num_workers)]

        # Sorted-list equality implies exactly one reply per worker id.
        worker_ids = sorted(msg.worker_id for msg in responses)
        assert worker_ids == list(range(pool.num_workers))
        for msg in responses:
            assert msg.key == "echo:config"
            assert msg.data == {"worker_id": msg.worker_id, "data": payload}

In [None]:
await test_broadcast();

## Test Timeout

In [None]:
#|export
@pytest.mark.asyncio
async def test_recv_timeout():
    """recv raises RecvTimeout when no reply arrives within the given timeout."""
    short_timeout = 0.5
    async with MultiprocessPool(echo_worker, num_processes=1) as pool:
        with pytest.raises(RecvTimeout):
            await pool.recv(timeout=short_timeout)

In [None]:
await test_recv_timeout();

## Test Error Handling

In [None]:
#|export
@pytest.mark.asyncio
async def test_send_before_start_raises():
    """send on a pool that was never started raises PoolNotStarted."""
    unstarted = MultiprocessPool(echo_worker, num_processes=1)
    with pytest.raises(PoolNotStarted):
        await unstarted.send(worker_id=0, key="test", data="data")

In [None]:
await test_send_before_start_raises();

In [None]:
#|export
@pytest.mark.asyncio
async def test_recv_before_start_raises():
    """recv on a pool that was never started raises PoolNotStarted."""
    unstarted = MultiprocessPool(echo_worker, num_processes=1)
    with pytest.raises(PoolNotStarted):
        await unstarted.recv()

In [None]:
await test_recv_before_start_raises();

In [None]:
#|export
@pytest.mark.asyncio
async def test_invalid_worker_id():
    """send rejects worker ids outside the valid range with ValueError."""
    async with MultiprocessPool(echo_worker, num_processes=2, threads_per_process=2) as pool:
        # 4 workers -> valid ids 0..3; probe one below and one above that range.
        for out_of_range in (-1, 4):
            with pytest.raises(ValueError):
                await pool.send(worker_id=out_of_range, key="test", data="data")

In [None]:
await test_invalid_worker_id();

## Test Computation

In [None]:
#|export
@pytest.mark.asyncio
async def test_compute_workers():
    """compute_worker answers "square" and "double" requests with computed results."""
    async with MultiprocessPool(compute_worker, num_processes=2, threads_per_process=1) as pool:
        jobs = [(0, "square", 7), (1, "double", 21)]
        for wid, operation, operand in jobs:
            await pool.send(worker_id=wid, key=operation, data=operand)

        outcomes = []
        for _ in jobs:
            reply = await pool.recv(timeout=10.0)
            outcomes.append((reply.worker_id, reply.data))

        # Replies may arrive in any order; compare after ordering by worker id.
        assert sorted(outcomes) == [(0, 49), (1, 42)]

In [None]:
await test_compute_workers();

## Test Consistency

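These tests check that every sent message gets a response. They cover a single pool (including a higher-volume case with many messages per worker) and repeated pool create/destroy cycles.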

In [None]:
#|export
@pytest.mark.asyncio
async def test_consistency_single_run():
    """One ping to each worker produces exactly one reply per worker."""
    async with MultiprocessPool(echo_worker, num_processes=2, threads_per_process=2) as pool:
        expected = list(range(pool.num_workers))

        for wid in expected:
            await pool.send(worker_id=wid, key="ping", data=wid)

        got = sorted([(await pool.recv(timeout=10.0)).worker_id for _ in expected])

        assert got == expected, f"Missing responses: expected {expected}, got {got}"

In [None]:
await test_consistency_single_run();

In [None]:
#|export
@pytest.mark.asyncio
async def test_consistency_multiple_runs():
    """Each of five freshly created pools replies once per worker."""
    for run in range(5):
        async with MultiprocessPool(echo_worker, num_processes=2, threads_per_process=2) as pool:
            ids = list(range(pool.num_workers))

            for wid in ids:
                await pool.send(worker_id=wid, key="ping", data=wid)

            replies = [await pool.recv(timeout=10.0) for _ in ids]

            assert sorted(r.worker_id for r in replies) == ids, f"Run {run}: Missing responses"

In [None]:
await test_consistency_multiple_runs();

In [None]:
#|export
@pytest.mark.asyncio
async def test_consistency_many_messages():
    """Test consistency with many messages per worker.

    Each worker is sent ``messages_per_worker`` messages whose data is the
    round number. Replies are grouped by worker, and every worker must have
    echoed every round exactly once.

    The previous count-only check was tautological: the receive loop always
    runs exactly ``total_messages`` times or raises. It could not detect
    duplicated, misrouted or mismatched replies.
    """
    async with MultiprocessPool(echo_worker, num_processes=2, threads_per_process=2) as pool:
        messages_per_worker = 10
        total_messages = pool.num_workers * messages_per_worker

        # Send multiple messages to each worker
        for round_num in range(messages_per_worker):
            for worker_id in range(pool.num_workers):
                await pool.send(worker_id=worker_id, key=f"msg{round_num}", data=round_num)

        # Receive all responses, recording which rounds each worker echoed.
        rounds_by_worker = {worker_id: [] for worker_id in range(pool.num_workers)}
        for _ in range(total_messages):
            msg = await pool.recv(timeout=10.0)
            assert msg.worker_id in rounds_by_worker, f"Reply from unknown worker {msg.worker_id}"
            # echo_worker replies {"worker_id": ..., "data": ...} under key "echo:<key>".
            assert msg.data["worker_id"] == msg.worker_id
            round_num = msg.data["data"]
            assert msg.key == f"echo:msg{round_num}", f"Key {msg.key!r} does not match round {round_num}"
            rounds_by_worker[msg.worker_id].append(round_num)

        expected_rounds = list(range(messages_per_worker))
        for worker_id, rounds in rounds_by_worker.items():
            assert sorted(rounds) == expected_rounds, (
                f"Worker {worker_id}: expected rounds {expected_rounds}, got {sorted(rounds)}"
            )

In [None]:
await test_consistency_many_messages();

In [None]:
#|export
@pytest.mark.asyncio
async def test_consistency_rapid_cycles():
    """Ten short-lived single-process pools each get a reply from both workers."""
    for run in range(10):
        async with MultiprocessPool(echo_worker, num_processes=1, threads_per_process=2) as pool:
            # One quick message to each of the two threads in the process.
            for wid in (0, 1):
                await pool.send(worker_id=wid, key="quick", data=run)

            first = await pool.recv(timeout=10.0)
            second = await pool.recv(timeout=10.0)

            worker_ids = sorted((first.worker_id, second.worker_id))
            assert worker_ids == [0, 1], f"Run {run}: Expected workers [0, 1], got {worker_ids}"

In [None]:
await test_consistency_rapid_cycles();