diff --git a/.gitignore b/.gitignore
index 5bf2443ccb..0c3c190c57 100644
--- a/.gitignore
+++ b/.gitignore
@@ -8,6 +8,7 @@
 .scannerwork/
 .unasyncd_cache/
 .venv/
+.venv*
 .vscode/
 __pycache__/
 build/
diff --git a/docs/reference/channels/backends/asyncpg.rst b/docs/reference/channels/backends/asyncpg.rst
new file mode 100644
index 0000000000..91d44ecdf1
--- /dev/null
+++ b/docs/reference/channels/backends/asyncpg.rst
@@ -0,0 +1,5 @@
+asyncpg
+=======
+
+.. automodule:: litestar.channels.backends.asyncpg
+    :members:
diff --git a/docs/reference/channels/backends/index.rst b/docs/reference/channels/backends/index.rst
index 010ae7e509..02deff518a 100644
--- a/docs/reference/channels/backends/index.rst
+++ b/docs/reference/channels/backends/index.rst
@@ -6,3 +6,5 @@ backends
     base
     memory
     redis
+    psycopg
+    asyncpg
diff --git a/docs/reference/channels/backends/psycopg.rst b/docs/reference/channels/backends/psycopg.rst
new file mode 100644
index 0000000000..4a8163db60
--- /dev/null
+++ b/docs/reference/channels/backends/psycopg.rst
@@ -0,0 +1,5 @@
+psycopg
+=======
+
+.. automodule:: litestar.channels.backends.psycopg
+    :members:
diff --git a/docs/usage/channels.rst b/docs/usage/channels.rst
index cbf0ef2721..6e2cd13c02 100644
--- a/docs/usage/channels.rst
+++ b/docs/usage/channels.rst
@@ -399,7 +399,7 @@ implemented are:
     A basic in-memory backend, mostly useful for testing and local development, but
     still fully capable. Since it stores all data in-process, it can achieve the
     highest performance of all the backends, but at the same time is not suitable for
-    applications running on multiple processes.
+    applications running on multiple processes
 
 :class:`RedisChannelsPubSubBackend <.redis.RedisChannelsPubSubBackend>`
     A Redis based backend, using `Pub/Sub `_ to
@@ -413,6 +413,17 @@ implemented are:
     when history is needed
 
 
+:class:`AsyncPgChannelsBackend <.asyncpg.AsyncPgChannelsBackend>`
+    A postgres backend using the
+    `asyncpg `_ driver
+
+
+:class:`PsycoPgChannelsBackend <.psycopg.PsycoPgChannelsBackend>`
+    A postgres backend using the `psycopg3 `_
+    async driver
+
+
+
 Integrating with websocket handlers
 -----------------------------------
 
diff --git a/litestar/channels/backends/asyncpg.py b/litestar/channels/backends/asyncpg.py
new file mode 100644
index 0000000000..4b3948d179
--- /dev/null
+++ b/litestar/channels/backends/asyncpg.py
@@ -0,0 +1,82 @@
+from __future__ import annotations
+
+import asyncio
+from contextlib import AsyncExitStack
+from functools import partial
+from typing import AsyncGenerator, Awaitable, Callable, Iterable, overload
+
+import asyncpg
+
+from litestar.channels import ChannelsBackend
+from litestar.exceptions import ImproperlyConfiguredException
+
+
+class AsyncPgChannelsBackend(ChannelsBackend):
+    _listener_conn: asyncpg.Connection
+    _queue: asyncio.Queue[tuple[str, bytes]]
+
+    @overload
+    def __init__(self, dsn: str) -> None:
+        ...
+
+    @overload
+    def __init__(
+        self,
+        *,
+        make_connection: Callable[[], Awaitable[asyncpg.Connection]],
+    ) -> None:
+        ...
+
+    def __init__(
+        self,
+        dsn: str | None = None,
+        *,
+        make_connection: Callable[[], Awaitable[asyncpg.Connection]] | None = None,
+    ) -> None:
+        if not (dsn or make_connection):
+            raise ImproperlyConfiguredException("Need to specify dsn or make_connection")
+
+        self._subscribed_channels: set[str] = set()
+        self._exit_stack = AsyncExitStack()
+        self._connect = make_connection or partial(asyncpg.connect, dsn=dsn)
+
+    async def on_startup(self) -> None:
+        self._queue = asyncio.Queue()
+        self._listener_conn = await self._connect()
+
+    async def on_shutdown(self) -> None:
+        await self._listener_conn.close()
+        del self._queue
+
+    async def publish(self, data: bytes, channels: Iterable[str]) -> None:
+        dec_data = data.decode("utf-8")
+
+        conn = await self._connect()
+        try:
+            for channel in channels:
+                await conn.execute("SELECT pg_notify($1, $2);", channel, dec_data)
+        finally:
+            await conn.close()
+
+    async def subscribe(self, channels: Iterable[str]) -> None:
+        for channel in set(channels) - self._subscribed_channels:
+            await self._listener_conn.add_listener(channel, self._listener)  # type: ignore[arg-type]
+            self._subscribed_channels.add(channel)
+
+    async def unsubscribe(self, channels: Iterable[str]) -> None:
+        for channel in channels:
+            await self._listener_conn.remove_listener(channel, self._listener)  # type: ignore[arg-type]
+        self._subscribed_channels = self._subscribed_channels - set(channels)
+
+    async def stream_events(self) -> AsyncGenerator[tuple[str, bytes], None]:
+        while self._queue:
+            yield await self._queue.get()
+            self._queue.task_done()
+
+    async def get_history(self, channel: str, limit: int | None = None) -> list[bytes]:
+        raise NotImplementedError()
+
+    def _listener(self, /, connection: asyncpg.Connection, pid: int, channel: str, payload: object) -> None:
+        if not isinstance(payload, str):
+            raise RuntimeError("Invalid data received")
+        self._queue.put_nowait((channel, payload.encode("utf-8")))
diff --git a/litestar/channels/backends/psycopg.py b/litestar/channels/backends/psycopg.py
new file mode 100644
index 0000000000..8cd5fff543
--- /dev/null
+++ b/litestar/channels/backends/psycopg.py
@@ -0,0 +1,54 @@
+from __future__ import annotations
+
+from contextlib import AsyncExitStack
+from typing import AsyncGenerator, Iterable
+
+import psycopg
+
+from .base import ChannelsBackend
+
+
+def _safe_quote(ident: str) -> str:
+    return '"{}"'.format(ident.replace('"', '""'))
+
+
+class PsycoPgChannelsBackend(ChannelsBackend):
+    _listener_conn: psycopg.AsyncConnection
+
+    def __init__(self, pg_dsn: str) -> None:
+        self._pg_dsn = pg_dsn
+        self._subscribed_channels: set[str] = set()
+        self._exit_stack = AsyncExitStack()
+
+    async def on_startup(self) -> None:
+        self._listener_conn = await psycopg.AsyncConnection.connect(self._pg_dsn, autocommit=True)
+        await self._exit_stack.enter_async_context(self._listener_conn)
+
+    async def on_shutdown(self) -> None:
+        await self._exit_stack.aclose()
+
+    async def publish(self, data: bytes, channels: Iterable[str]) -> None:
+        dec_data = data.decode("utf-8")
+        async with await psycopg.AsyncConnection.connect(self._pg_dsn) as conn:
+            for channel in channels:
+                await conn.execute("SELECT pg_notify(%s, %s);", (channel, dec_data))
+
+    async def subscribe(self, channels: Iterable[str]) -> None:
+        for channel in set(channels) - self._subscribed_channels:
+            # can't use placeholders in LISTEN
+            await self._listener_conn.execute(f"LISTEN {_safe_quote(channel)};")  # pyright: ignore
+
+            self._subscribed_channels.add(channel)
+
+    async def unsubscribe(self, channels: Iterable[str]) -> None:
+        for channel in channels:
+            # can't use placeholders in UNLISTEN
+            await self._listener_conn.execute(f"UNLISTEN {_safe_quote(channel)};")  # pyright: ignore
+        self._subscribed_channels = self._subscribed_channels - set(channels)
+
+    async def stream_events(self) -> AsyncGenerator[tuple[str, bytes], None]:
+        async for notify in self._listener_conn.notifies():
+            yield notify.channel, notify.payload.encode("utf-8")
+
+    async def get_history(self, channel: str, limit: int | None = None) -> list[bytes]:
+        raise NotImplementedError()
diff --git a/litestar/channels/plugin.py b/litestar/channels/plugin.py
index ae7dcc78b3..985c337847 100644
--- a/litestar/channels/plugin.py
+++ b/litestar/channels/plugin.py
@@ -311,10 +311,10 @@ async def _sub_worker(self) -> None:
                 subscriber.put_nowait(payload)
 
     async def _on_startup(self) -> None:
+        await self._backend.on_startup()
         self._pub_queue = Queue()
         self._pub_task = create_task(self._pub_worker())
         self._sub_task = create_task(self._sub_worker())
-        await self._backend.on_startup()
 
         if self._channels:
             await self._backend.subscribe(list(self._channels))
diff --git a/pdm.lock b/pdm.lock
index 036a51bae2..eb8e513db3 100644
--- a/pdm.lock
+++ b/pdm.lock
@@ -5,7 +5,7 @@
 groups = ["default", "standard", "jwt", "pydantic", "cli", "picologging", "dev-contrib", "piccolo", "prometheus", "dev", "mako", "test", "brotli", "cryptography", "linting", "attrs", "opentelemetry", "docs", "redis", "sqlalchemy", "full", "annotated-types", "jinja", "structlog", "minijinja"]
 strategy = ["cross_platform"]
 lock_version = "4.4"
-content_hash = "sha256:621796e9ccd87bf1be1aa40fd5e3ec170a89be761aae62fed47879256848d88f"
+content_hash = "sha256:d9d9469bfaa3932d8be1987dfba4650c9e88b5509e9fcce3ee7f039833980121"
 
 [[package]]
 name = "accessible-pygments"
@@ -237,6 +237,10 @@ summary = "Automatically generate code examples for different Python versions in
 dependencies = [
     "ruff>=0.0.260",
 ]
+files = [
+    {file = "auto_pytabs-0.4.0-py3-none-any.whl", hash = "sha256:941ca4f21b218249ee4d026ebaf4a8a7788a066fdb223571f1f7b93d44ac6a74"},
+    {file = "auto_pytabs-0.4.0.tar.gz", hash = "sha256:4c596aa02ea20c6c85809e5f60a22aa60499dcaa637e52d6313d07c58c5bb61e"},
+]
 
 [[package]]
 name = "auto-pytabs"
diff --git a/pyproject.toml b/pyproject.toml
index 02d9cd1aa6..78c06456af 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -107,6 +107,9 @@ dev = [
     "trio",
     "aiosqlite",
     "exceptiongroup; python_version < \"3.11\"",
+    "asyncpg>=0.29.0",
+    "psycopg[pool,binary]>=3.1.10",
+    "psycopg2-binary",
 ]
 dev-contrib = ["opentelemetry-sdk", "httpx-sse"]
 docs = [
diff --git a/tests/docker_service_fixtures.py b/tests/docker_service_fixtures.py
index de71aace45..6efca36979 100644
--- a/tests/docker_service_fixtures.py
+++ b/tests/docker_service_fixtures.py
@@ -139,3 +139,8 @@ async def postgres_responsive(host: str) -> bool:
         return (await conn.fetchrow("SELECT 1"))[0] == 1  # type: ignore
     finally:
         await conn.close()
+
+
+@pytest.fixture()
+async def postgres_service(docker_services: DockerServiceRegistry) -> None:
+    await docker_services.start("postgres", check=postgres_responsive)
diff --git a/tests/unit/test_channels/conftest.py b/tests/unit/test_channels/conftest.py
index c95799143d..1f041ae901 100644
--- a/tests/unit/test_channels/conftest.py
+++ b/tests/unit/test_channels/conftest.py
@@ -3,7 +3,9 @@
 import pytest
 from redis.asyncio import Redis as AsyncRedis
 
+from litestar.channels.backends.asyncpg import AsyncPgChannelsBackend
 from litestar.channels.backends.memory import MemoryChannelsBackend
+from litestar.channels.backends.psycopg import PsycoPgChannelsBackend
 from litestar.channels.backends.redis import RedisChannelsPubSubBackend, RedisChannelsStreamBackend
 
 
@@ -37,3 +39,13 @@ def redis_pub_sub_backend(redis_client: AsyncRedis) -> RedisChannelsPubSubBacken
 @pytest.fixture()
 def memory_backend() -> MemoryChannelsBackend:
     return MemoryChannelsBackend(history=10)
+
+
+@pytest.fixture()
+def postgres_asyncpg_backend(postgres_service: None, docker_ip: str) -> AsyncPgChannelsBackend:
+    return AsyncPgChannelsBackend(f"postgres://postgres:super-secret@{docker_ip}:5423")
+
+
+@pytest.fixture()
+def postgres_psycopg_backend(postgres_service: None, docker_ip: str) -> PsycoPgChannelsBackend:
+    return PsycoPgChannelsBackend(f"postgres://postgres:super-secret@{docker_ip}:5423")
diff --git a/tests/unit/test_channels/test_backends.py b/tests/unit/test_channels/test_backends.py
index 17eb0baf8f..871d355131 100644
--- a/tests/unit/test_channels/test_backends.py
+++ b/tests/unit/test_channels/test_backends.py
@@ -3,14 +3,18 @@
 import asyncio
 from datetime import timedelta
 from typing import AsyncGenerator, cast
+from unittest.mock import AsyncMock, MagicMock
 
 import pytest
 from _pytest.fixtures import FixtureRequest
 from redis.asyncio.client import Redis
 
 from litestar.channels import ChannelsBackend
+from litestar.channels.backends.asyncpg import AsyncPgChannelsBackend
 from litestar.channels.backends.memory import MemoryChannelsBackend
+from litestar.channels.backends.psycopg import PsycoPgChannelsBackend
 from litestar.channels.backends.redis import RedisChannelsPubSubBackend, RedisChannelsStreamBackend
+from litestar.exceptions import ImproperlyConfiguredException
 from litestar.utils.compat import async_next
 
 
@@ -18,6 +22,8 @@
     params=[
         pytest.param("redis_pub_sub_backend", id="redis:pubsub", marks=pytest.mark.xdist_group("redis")),
         pytest.param("redis_stream_backend", id="redis:stream", marks=pytest.mark.xdist_group("redis")),
+        pytest.param("postgres_asyncpg_backend", id="postgres:asyncpg", marks=pytest.mark.xdist_group("postgres")),
+        pytest.param("postgres_psycopg_backend", id="postgres:psycopg", marks=pytest.mark.xdist_group("postgres")),
         pytest.param("memory_backend", id="memory"),
     ]
 )
@@ -82,7 +88,7 @@ async def test_unsubscribe_without_subscription(channels_backend: ChannelsBacken
 async def test_get_history(
     channels_backend: ChannelsBackend, history_limit: int | None, expected_history_length: int
 ) -> None:
-    if isinstance(channels_backend, RedisChannelsPubSubBackend):
+    if isinstance(channels_backend, (RedisChannelsPubSubBackend, AsyncPgChannelsBackend, PsycoPgChannelsBackend)):
         pytest.skip("Redis pub/sub backend does not support history")
 
     messages = [str(i).encode() for i in range(100)]
@@ -97,7 +103,7 @@ async def test_get_history(
 
 
 async def test_discards_history_entries(channels_backend: ChannelsBackend) -> None:
-    if isinstance(channels_backend, RedisChannelsPubSubBackend):
+    if isinstance(channels_backend, (RedisChannelsPubSubBackend, AsyncPgChannelsBackend, PsycoPgChannelsBackend)):
         pytest.skip("Redis pub/sub backend does not support history")
 
     for _ in range(20):
@@ -133,3 +139,35 @@ async def test_memory_publish_not_initialized_raises() -> None:
 
     with pytest.raises(RuntimeError):
         await backend.publish(b"foo", ["something"])
+
+
+@pytest.mark.xdist_group("postgres")
+async def test_asyncpg_get_history(postgres_asyncpg_backend: AsyncPgChannelsBackend) -> None:
+    with pytest.raises(NotImplementedError):
+        await postgres_asyncpg_backend.get_history("something")
+
+
+@pytest.mark.xdist_group("postgres")
+async def test_psycopg_get_history(postgres_psycopg_backend: PsycoPgChannelsBackend) -> None:
+    with pytest.raises(NotImplementedError):
+        await postgres_psycopg_backend.get_history("something")
+
+
+async def test_asyncpg_make_connection() -> None:
+    make_connection = AsyncMock()
+
+    backend = AsyncPgChannelsBackend(make_connection=make_connection)
+    await backend.on_startup()
+
+    make_connection.assert_awaited_once()
+
+
+async def test_asyncpg_no_make_conn_or_dsn_passed_raises() -> None:
+    with pytest.raises(ImproperlyConfiguredException):
+        AsyncPgChannelsBackend()  # type: ignore[call-overload]
+
+
+def test_asyncpg_listener_raises_on_non_string_payload() -> None:
+    backend = AsyncPgChannelsBackend(make_connection=AsyncMock())
+    with pytest.raises(RuntimeError):
+        backend._listener(connection=MagicMock(), pid=1, payload=b"abc", channel="foo")
diff --git a/tests/unit/test_channels/test_plugin.py b/tests/unit/test_channels/test_plugin.py
index c6743e6102..856c8bc9a8 100644
--- a/tests/unit/test_channels/test_plugin.py
+++ b/tests/unit/test_channels/test_plugin.py
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import asyncio
+import time
 from secrets import token_hex
 from typing import cast
 from unittest.mock import AsyncMock, MagicMock
@@ -24,6 +25,8 @@
     params=[
         pytest.param("redis_pub_sub_backend", id="redis:pubsub", marks=pytest.mark.xdist_group("redis")),
         pytest.param("redis_stream_backend", id="redis:stream", marks=pytest.mark.xdist_group("redis")),
+        pytest.param("postgres_asyncpg_backend", id="postgres:asyncpg", marks=pytest.mark.xdist_group("postgres")),
+        pytest.param("postgres_psycopg_backend", id="postgres:psycopg", marks=pytest.mark.xdist_group("postgres")),
         pytest.param("memory_backend", id="memory"),
     ]
 )
@@ -119,7 +122,7 @@ def test_create_ws_route_handlers(
 
 
 @pytest.mark.flaky(reruns=5)
-def test_ws_route_handlers_receive_arbitrary_message(channels_backend: ChannelsBackend) -> None:
+async def test_ws_route_handlers_receive_arbitrary_message(channels_backend: ChannelsBackend) -> None:
     """The websocket handlers await `WebSocket.receive()` to detect disconnection and stop the subscription.
 
     This test ensures that the subscription is only stopped in the case of receiving a `websocket.disconnect` message.
@@ -140,7 +143,7 @@ def test_ws_route_handlers_receive_arbitrary_message(channels_backend: ChannelsB
 
 
 @pytest.mark.flaky(reruns=5)
-async def test_create_ws_route_handlers_arbitrary_channels_allowed(channels_backend: ChannelsBackend) -> None:
+def test_create_ws_route_handlers_arbitrary_channels_allowed(channels_backend: ChannelsBackend) -> None:
     channels_plugin = ChannelsPlugin(
         backend=channels_backend,
         arbitrary_channels_allowed=True,
@@ -155,6 +158,8 @@ async def test_create_ws_route_handlers_arbitrary_channels_allowed(channels_back
         channels_plugin.publish("something", "foo")
         assert ws.receive_text(timeout=2) == "something"
 
+    time.sleep(0.1)
+
     with client.websocket_connect("/ws/bar") as ws:
         channels_plugin.publish("something else", "bar")
         assert ws.receive_text(timeout=2) == "something else"
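
A minimal usage sketch, not part of the patch itself: the DSN, channel name, and app wiring below are placeholders, while the imports and constructor calls follow the code added in this diff (the asyncpg backend accepts a DSN positionally, per its __init__ overloads; the psycopg backend added here could be swapped in the same way).

# Sketch only: assumes a reachable Postgres instance; the DSN and channel name
# are illustrative placeholders, not values taken from this patch.
from litestar import Litestar
from litestar.channels import ChannelsPlugin
from litestar.channels.backends.asyncpg import AsyncPgChannelsBackend

# Backend construction mirrors the conftest fixtures above: a plain Postgres DSN.
backend = AsyncPgChannelsBackend("postgres://app:app-secret@localhost:5432/app")

app = Litestar(
    plugins=[
        # Channels declared here are subscribed by the plugin on startup,
        # after the backend's own on_startup() runs (see the plugin.py hunk).
        ChannelsPlugin(backend=backend, channels=["notifications"]),
    ],
)

Publishing then goes through pg_notify, and subscriptions through LISTEN/UNLISTEN, exactly as implemented in the two backend modules added above.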