# Tests for WebSocketChannel (RPC Remote Module)

In [None]:
#|default_exp rpc.test_remote

In [None]:
#|export
import pytest
import asyncio

# Probe whether the optional `websockets` package is importable. In the code
# visible here it is imported only for this check.
try:
    import websockets
    HAS_WEBSOCKETS = True
except ImportError:
    HAS_WEBSOCKETS = False

# Module-level mark: pytest applies it to every test in this module, so all of
# them are skipped when websockets is missing.
# NOTE(review): if netrun.rpc.remote imports websockets at module import time,
# the `from netrun.rpc.remote import ...` cell fails during collection before
# this mark can apply; pytest.importorskip("websockets") may be the more
# robust guard — confirm against netrun.rpc.remote.
pytestmark = pytest.mark.skipif(not HAS_WEBSOCKETS, reason="websockets not installed")

In [None]:
#|export
from netrun.rpc.base import ChannelClosed, RecvTimeout
from netrun.rpc.remote import (
    WebSocketChannel,
    connect_channel,
    serve_background,
)

## Test Server and Client Connection

In [None]:
#|export
@pytest.mark.asyncio
async def test_connect_and_send_recv():
    """A client message reaches the server and its payload is echoed back."""
    server_log = []

    async def handler(channel: WebSocketChannel):
        # Log the incoming message, then bounce its payload back as "echo".
        in_key, payload = await channel.recv()
        server_log.append((in_key, payload))
        await channel.send("echo", payload)

    host, port = "127.0.0.1", 18881
    async with serve_background(handler, host, port):
        client = await connect_channel(f"ws://{host}:{port}")
        try:
            await client.send("hello", "world")
            reply_key, reply_data = await client.recv(timeout=5.0)

            assert reply_key == "echo"
            assert reply_data == "world"
            assert server_log == [("hello", "world")]
        finally:
            await client.close()

In [None]:
await test_connect_and_send_recv();

In [None]:
#|export
@pytest.mark.asyncio
async def test_multiple_messages():
    """Test multiple messages between client and server.

    The server answers every ``(key, data)`` with ``("ack:<key>", data * 2)``.
    The client sends five messages, then checks that each acknowledgement
    carries the right key and doubled payload, in send order.
    """
    async def handler(channel: WebSocketChannel):
        try:
            while True:
                key, data = await channel.recv()
                await channel.send(f"ack:{key}", data * 2)
        except ChannelClosed:
            pass  # Client disconnected; end the handler normally.

    async with serve_background(handler, "127.0.0.1", 18882):
        client = await connect_channel("ws://127.0.0.1:18882")
        try:
            for i in range(5):
                await client.send(f"msg{i}", i)

            responses = []
            for _ in range(5):
                key, data = await client.recv(timeout=5.0)
                responses.append((key, data))

            # The handler replies sequentially on a single connection, so the
            # acks should match the sent messages one-for-one, in order.
            assert responses == [(f"ack:msg{i}", i * 2) for i in range(5)]
        finally:
            await client.close()

In [None]:
await test_multiple_messages();

In [None]:
#|export
@pytest.mark.asyncio
async def test_bidirectional():
    """Test bidirectional communication.

    The server speaks first, then waits for the client's reply. The server
    only records what it received and signals completion. All assertions run
    in the test body, because a failure raised inside the handler may not
    propagate to the test (that depends on how serve_background treats
    handler exceptions).
    """
    received_on_server = []
    server_done = asyncio.Event()

    async def handler(channel: WebSocketChannel):
        try:
            # Server initiates
            await channel.send("server_hello", "from server")
            # Wait for response
            key, data = await channel.recv(timeout=5.0)
            received_on_server.append((key, data))
        finally:
            # Always signal, so the client never waits out the full timeout
            # when the server side fails.
            server_done.set()

    async with serve_background(handler, "127.0.0.1", 18883):
        client = await connect_channel("ws://127.0.0.1:18883")
        try:
            # Receive server's hello
            key, data = await client.recv(timeout=5.0)
            assert key == "server_hello"
            assert data == "from server"

            # Send client's hello, and keep the connection open until the
            # server has finished reading it.
            await client.send("client_hello", "from client")
            await asyncio.wait_for(server_done.wait(), timeout=5.0)
        finally:
            await client.close()

    assert received_on_server == [("client_hello", "from client")]

In [None]:
await test_bidirectional();

## Test try_recv

In [None]:
#|export
@pytest.mark.asyncio
async def test_try_recv_empty():
    """try_recv yields None immediately when nothing has been sent."""
    async def idle_handler(channel: WebSocketChannel):
        # Hold the connection open without ever sending.
        await asyncio.sleep(10)

    async with serve_background(idle_handler, "127.0.0.1", 18884):
        client = await connect_channel("ws://127.0.0.1:18884")
        try:
            assert await client.try_recv() is None
        finally:
            await client.close()

In [None]:
await test_try_recv_empty();

In [None]:
#|export
@pytest.mark.asyncio
async def test_try_recv_with_message():
    """try_recv returns a pending message once it has arrived."""
    async def sending_handler(channel: WebSocketChannel):
        await channel.send("test", "data")
        await asyncio.sleep(10)  # Keep the connection alive after sending.

    async with serve_background(sending_handler, "127.0.0.1", 18885):
        client = await connect_channel("ws://127.0.0.1:18885")
        try:
            # Delivery over the socket is not instantaneous, so poll:
            # up to 20 attempts spaced 0.1 s apart (about 2 s in total).
            polled = None
            attempts = 0
            while attempts < 20:
                attempts += 1
                await asyncio.sleep(0.1)
                polled = await client.try_recv()
                if polled is not None:
                    break

            assert polled is not None
            msg_key, msg_data = polled
            assert (msg_key, msg_data) == ("test", "data")
        finally:
            await client.close()

In [None]:
await test_try_recv_with_message();

## Test Timeout

In [None]:
#|export
@pytest.mark.asyncio
async def test_recv_timeout():
    """recv raises RecvTimeout when nothing arrives within the timeout."""
    async def silent_handler(channel: WebSocketChannel):
        # Accept the connection but never send anything.
        await asyncio.sleep(10)

    url = "ws://127.0.0.1:18886"
    async with serve_background(silent_handler, "127.0.0.1", 18886):
        client = await connect_channel(url)
        try:
            with pytest.raises(RecvTimeout):
                await client.recv(timeout=0.1)
        finally:
            await client.close()

In [None]:
await test_recv_timeout();

## Test Channel Close

In [None]:
#|export
@pytest.mark.asyncio
async def test_client_close():
    """is_closed flips from false to true once the client closes."""
    async def idle_handler(channel: WebSocketChannel):
        await asyncio.sleep(10)

    async with serve_background(idle_handler, "127.0.0.1", 18887):
        channel = await connect_channel("ws://127.0.0.1:18887")
        # A freshly connected channel must report itself open.
        open_before_close = not channel.is_closed
        assert open_before_close
        await channel.close()
        assert channel.is_closed

In [None]:
await test_client_close();

In [None]:
#|export
@pytest.mark.asyncio
async def test_send_on_closed_raises():
    """send on an already-closed client channel raises ChannelClosed."""
    async def idle_handler(channel: WebSocketChannel):
        await asyncio.sleep(10)

    async with serve_background(idle_handler, "127.0.0.1", 18888):
        closed_client = await connect_channel("ws://127.0.0.1:18888")
        await closed_client.close()

        # Any send after close must be rejected.
        with pytest.raises(ChannelClosed):
            await closed_client.send("test", "data")

In [None]:
await test_send_on_closed_raises();

## Test Multiple Clients

In [None]:
#|export
@pytest.mark.asyncio
async def test_multiple_clients():
    """Test multiple clients connecting to same server.

    Each client sends its own index. The handler (which reads and replies
    once per connection) records the id and echoes it back as an "ack".
    Every client must get its own id back, and the server must have seen
    each id exactly once.
    """
    client_ids = []

    async def handler(channel: WebSocketChannel):
        _, data = await channel.recv(timeout=5.0)
        client_ids.append(data)
        await channel.send("ack", data)

    async with serve_background(handler, "127.0.0.1", 18889):
        clients = []
        try:
            # Connect multiple clients
            for i in range(3):
                client = await connect_channel("ws://127.0.0.1:18889")
                clients.append(client)
                await client.send("id", i)

            # Receive acks; each must echo that client's own id, which shows
            # replies are not crossed between connections.
            for i, client in enumerate(clients):
                key, data = await client.recv(timeout=5.0)
                assert key == "ack"
                assert data == i
        finally:
            for client in clients:
                await client.close()

    # Handlers may run in any order, so compare without regard to order.
    assert sorted(client_ids) == [0, 1, 2]

In [None]:
await test_multiple_clients();

## Test Data Types

In [None]:
#|export
@pytest.mark.asyncio
async def test_send_various_types():
    """Test sending various data types over WebSocket.

    Each value is sent under a key naming its type. The server records the
    message and replies ``("ack", key)``, so the test checks both the server's
    view of every payload and the client's view of every acknowledgement.
    """
    received = []
    test_data = [
        ("str", "hello"),
        ("int", 42),
        ("float", 3.14),
        ("list", [1, 2, 3]),
        ("dict", {"a": 1, "b": 2}),
        ("none", None),
        ("bool", True),
    ]

    async def handler(channel: WebSocketChannel):
        try:
            while True:
                key, data = await channel.recv()
                received.append((key, data))
                await channel.send("ack", key)
        except ChannelClosed:
            pass  # Client disconnected; end the handler normally.

    async with serve_background(handler, "127.0.0.1", 18890):
        client = await connect_channel("ws://127.0.0.1:18890")
        try:
            for key, data in test_data:
                await client.send(key, data)
                # The handler appends to `received` before it acks, so once
                # the ack arrives this message is already recorded. No extra
                # sleep is needed before the final check.
                ack_key, ack_data = await client.recv(timeout=5.0)
                assert ack_key == "ack"
                assert ack_data == key
        finally:
            await client.close()

    assert received == test_data

In [None]:
await test_send_various_types();