diff --git a/.gitignore b/.gitignore
index 5bf2443ccb..0c3c190c57 100644
--- a/.gitignore
+++ b/.gitignore
@@ -8,6 +8,7 @@
.scannerwork/
.unasyncd_cache/
.venv/
+.venv*
.vscode/
__pycache__/
build/
diff --git a/docs/reference/channels/backends/asyncpg.rst b/docs/reference/channels/backends/asyncpg.rst
new file mode 100644
index 0000000000..91d44ecdf1
--- /dev/null
+++ b/docs/reference/channels/backends/asyncpg.rst
@@ -0,0 +1,5 @@
+asyncpg
+=======
+
+.. automodule:: litestar.channels.backends.asyncpg
+ :members:
diff --git a/docs/reference/channels/backends/index.rst b/docs/reference/channels/backends/index.rst
index 010ae7e509..02deff518a 100644
--- a/docs/reference/channels/backends/index.rst
+++ b/docs/reference/channels/backends/index.rst
@@ -6,3 +6,5 @@ backends
base
memory
redis
+ psycopg
+ asyncpg
diff --git a/docs/reference/channels/backends/psycopg.rst b/docs/reference/channels/backends/psycopg.rst
new file mode 100644
index 0000000000..4a8163db60
--- /dev/null
+++ b/docs/reference/channels/backends/psycopg.rst
@@ -0,0 +1,5 @@
+psycopg
+=======
+
+.. automodule:: litestar.channels.backends.psycopg
+ :members:
diff --git a/docs/usage/channels.rst b/docs/usage/channels.rst
index cbf0ef2721..6e2cd13c02 100644
--- a/docs/usage/channels.rst
+++ b/docs/usage/channels.rst
@@ -399,7 +399,7 @@ implemented are:
A basic in-memory backend, mostly useful for testing and local development, but
still fully capable. Since it stores all data in-process, it can achieve the highest
performance of all the backends, but at the same time is not suitable for
- applications running on multiple processes.
+ applications running on multiple processes
:class:`RedisChannelsPubSubBackend <.redis.RedisChannelsPubSubBackend>`
A Redis based backend, using `Pub/Sub `_ to
@@ -413,6 +413,17 @@ implemented are:
when history is needed
+:class:`AsyncPgChannelsBackend <.asyncpg.AsyncPgChannelsBackend>`
+ A postgres backend using the
+ `asyncpg `_ driver
+
+
+:class:`PsycoPgChannelsBackend <.psycopg.AsyncPgChannelsBackend>`
+ A postgres backend using the `psycopg3 `_
+ async driver
+
+
+
Integrating with websocket handlers
-----------------------------------
diff --git a/litestar/channels/backends/asyncpg.py b/litestar/channels/backends/asyncpg.py
new file mode 100644
index 0000000000..4b3948d179
--- /dev/null
+++ b/litestar/channels/backends/asyncpg.py
@@ -0,0 +1,82 @@
+from __future__ import annotations
+
+import asyncio
+from contextlib import AsyncExitStack
+from functools import partial
+from typing import AsyncGenerator, Awaitable, Callable, Iterable, overload
+
+import asyncpg
+
+from litestar.channels import ChannelsBackend
+from litestar.exceptions import ImproperlyConfiguredException
+
+
+class AsyncPgChannelsBackend(ChannelsBackend):
+ _listener_conn: asyncpg.Connection
+ _queue: asyncio.Queue[tuple[str, bytes]]
+
+ @overload
+ def __init__(self, dsn: str) -> None:
+ ...
+
+ @overload
+ def __init__(
+ self,
+ *,
+ make_connection: Callable[[], Awaitable[asyncpg.Connection]],
+ ) -> None:
+ ...
+
+ def __init__(
+ self,
+ dsn: str | None = None,
+ *,
+ make_connection: Callable[[], Awaitable[asyncpg.Connection]] | None = None,
+ ) -> None:
+ if not (dsn or make_connection):
+ raise ImproperlyConfiguredException("Need to specify dsn or make_connection")
+
+ self._subscribed_channels: set[str] = set()
+ self._exit_stack = AsyncExitStack()
+ self._connect = make_connection or partial(asyncpg.connect, dsn=dsn)
+
+ async def on_startup(self) -> None:
+ self._queue = asyncio.Queue()
+ self._listener_conn = await self._connect()
+
+ async def on_shutdown(self) -> None:
+ await self._listener_conn.close()
+ del self._queue
+
+ async def publish(self, data: bytes, channels: Iterable[str]) -> None:
+ dec_data = data.decode("utf-8")
+
+ conn = await self._connect()
+ try:
+ for channel in channels:
+ await conn.execute("SELECT pg_notify($1, $2);", channel, dec_data)
+ finally:
+ await conn.close()
+
+ async def subscribe(self, channels: Iterable[str]) -> None:
+ for channel in set(channels) - self._subscribed_channels:
+ await self._listener_conn.add_listener(channel, self._listener) # type: ignore[arg-type]
+ self._subscribed_channels.add(channel)
+
+ async def unsubscribe(self, channels: Iterable[str]) -> None:
+ for channel in channels:
+ await self._listener_conn.remove_listener(channel, self._listener) # type: ignore[arg-type]
+ self._subscribed_channels = self._subscribed_channels - set(channels)
+
+ async def stream_events(self) -> AsyncGenerator[tuple[str, bytes], None]:
+ while self._queue:
+ yield await self._queue.get()
+ self._queue.task_done()
+
+ async def get_history(self, channel: str, limit: int | None = None) -> list[bytes]:
+ raise NotImplementedError()
+
+ def _listener(self, /, connection: asyncpg.Connection, pid: int, channel: str, payload: object) -> None:
+ if not isinstance(payload, str):
+ raise RuntimeError("Invalid data received")
+ self._queue.put_nowait((channel, payload.encode("utf-8")))
diff --git a/litestar/channels/backends/psycopg.py b/litestar/channels/backends/psycopg.py
new file mode 100644
index 0000000000..8cd5fff543
--- /dev/null
+++ b/litestar/channels/backends/psycopg.py
@@ -0,0 +1,54 @@
+from __future__ import annotations
+
+from contextlib import AsyncExitStack
+from typing import AsyncGenerator, Iterable
+
+import psycopg
+
+from .base import ChannelsBackend
+
+
+def _safe_quote(ident: str) -> str:
+ return '"{}"'.format(ident.replace('"', '""'))
+
+
+class PsycoPgChannelsBackend(ChannelsBackend):
+ _listener_conn: psycopg.AsyncConnection
+
+ def __init__(self, pg_dsn: str) -> None:
+ self._pg_dsn = pg_dsn
+ self._subscribed_channels: set[str] = set()
+ self._exit_stack = AsyncExitStack()
+
+ async def on_startup(self) -> None:
+ self._listener_conn = await psycopg.AsyncConnection.connect(self._pg_dsn, autocommit=True)
+ await self._exit_stack.enter_async_context(self._listener_conn)
+
+ async def on_shutdown(self) -> None:
+ await self._exit_stack.aclose()
+
+ async def publish(self, data: bytes, channels: Iterable[str]) -> None:
+ dec_data = data.decode("utf-8")
+ async with await psycopg.AsyncConnection.connect(self._pg_dsn) as conn:
+ for channel in channels:
+ await conn.execute("SELECT pg_notify(%s, %s);", (channel, dec_data))
+
+ async def subscribe(self, channels: Iterable[str]) -> None:
+ for channel in set(channels) - self._subscribed_channels:
+ # can't use placeholders in LISTEN
+ await self._listener_conn.execute(f"LISTEN {_safe_quote(channel)};") # pyright: ignore
+
+ self._subscribed_channels.add(channel)
+
+ async def unsubscribe(self, channels: Iterable[str]) -> None:
+ for channel in channels:
+ # can't use placeholders in UNLISTEN
+ await self._listener_conn.execute(f"UNLISTEN {_safe_quote(channel)};") # pyright: ignore
+ self._subscribed_channels = self._subscribed_channels - set(channels)
+
+ async def stream_events(self) -> AsyncGenerator[tuple[str, bytes], None]:
+ async for notify in self._listener_conn.notifies():
+ yield notify.channel, notify.payload.encode("utf-8")
+
+ async def get_history(self, channel: str, limit: int | None = None) -> list[bytes]:
+ raise NotImplementedError()
diff --git a/litestar/channels/plugin.py b/litestar/channels/plugin.py
index ae7dcc78b3..985c337847 100644
--- a/litestar/channels/plugin.py
+++ b/litestar/channels/plugin.py
@@ -311,10 +311,10 @@ async def _sub_worker(self) -> None:
subscriber.put_nowait(payload)
async def _on_startup(self) -> None:
+ await self._backend.on_startup()
self._pub_queue = Queue()
self._pub_task = create_task(self._pub_worker())
self._sub_task = create_task(self._sub_worker())
- await self._backend.on_startup()
if self._channels:
await self._backend.subscribe(list(self._channels))
diff --git a/pdm.lock b/pdm.lock
index 036a51bae2..eb8e513db3 100644
--- a/pdm.lock
+++ b/pdm.lock
@@ -5,7 +5,7 @@
groups = ["default", "standard", "jwt", "pydantic", "cli", "picologging", "dev-contrib", "piccolo", "prometheus", "dev", "mako", "test", "brotli", "cryptography", "linting", "attrs", "opentelemetry", "docs", "redis", "sqlalchemy", "full", "annotated-types", "jinja", "structlog", "minijinja"]
strategy = ["cross_platform"]
lock_version = "4.4"
-content_hash = "sha256:621796e9ccd87bf1be1aa40fd5e3ec170a89be761aae62fed47879256848d88f"
+content_hash = "sha256:d9d9469bfaa3932d8be1987dfba4650c9e88b5509e9fcce3ee7f039833980121"
[[package]]
name = "accessible-pygments"
@@ -237,6 +237,10 @@ summary = "Automatically generate code examples for different Python versions in
dependencies = [
"ruff>=0.0.260",
]
+files = [
+ {file = "auto_pytabs-0.4.0-py3-none-any.whl", hash = "sha256:941ca4f21b218249ee4d026ebaf4a8a7788a066fdb223571f1f7b93d44ac6a74"},
+ {file = "auto_pytabs-0.4.0.tar.gz", hash = "sha256:4c596aa02ea20c6c85809e5f60a22aa60499dcaa637e52d6313d07c58c5bb61e"},
+]
[[package]]
name = "auto-pytabs"
diff --git a/pyproject.toml b/pyproject.toml
index 02d9cd1aa6..78c06456af 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -107,6 +107,9 @@ dev = [
"trio",
"aiosqlite",
"exceptiongroup; python_version < \"3.11\"",
+ "asyncpg>=0.29.0",
+ "psycopg[pool,binary]>=3.1.10",
+ "psycopg2-binary",
]
dev-contrib = ["opentelemetry-sdk", "httpx-sse"]
docs = [
diff --git a/tests/docker_service_fixtures.py b/tests/docker_service_fixtures.py
index de71aace45..6efca36979 100644
--- a/tests/docker_service_fixtures.py
+++ b/tests/docker_service_fixtures.py
@@ -139,3 +139,8 @@ async def postgres_responsive(host: str) -> bool:
return (await conn.fetchrow("SELECT 1"))[0] == 1 # type: ignore
finally:
await conn.close()
+
+
+@pytest.fixture()
+async def postgres_service(docker_services: DockerServiceRegistry) -> None:
+ await docker_services.start("postgres", check=postgres_responsive)
diff --git a/tests/unit/test_channels/conftest.py b/tests/unit/test_channels/conftest.py
index c95799143d..1f041ae901 100644
--- a/tests/unit/test_channels/conftest.py
+++ b/tests/unit/test_channels/conftest.py
@@ -3,7 +3,9 @@
import pytest
from redis.asyncio import Redis as AsyncRedis
+from litestar.channels.backends.asyncpg import AsyncPgChannelsBackend
from litestar.channels.backends.memory import MemoryChannelsBackend
+from litestar.channels.backends.psycopg import PsycoPgChannelsBackend
from litestar.channels.backends.redis import RedisChannelsPubSubBackend, RedisChannelsStreamBackend
@@ -37,3 +39,13 @@ def redis_pub_sub_backend(redis_client: AsyncRedis) -> RedisChannelsPubSubBacken
@pytest.fixture()
def memory_backend() -> MemoryChannelsBackend:
return MemoryChannelsBackend(history=10)
+
+
+@pytest.fixture()
+def postgres_asyncpg_backend(postgres_service: None, docker_ip: str) -> AsyncPgChannelsBackend:
+ return AsyncPgChannelsBackend(f"postgres://postgres:super-secret@{docker_ip}:5423")
+
+
+@pytest.fixture()
+def postgres_psycopg_backend(postgres_service: None, docker_ip: str) -> PsycoPgChannelsBackend:
+ return PsycoPgChannelsBackend(f"postgres://postgres:super-secret@{docker_ip}:5423")
diff --git a/tests/unit/test_channels/test_backends.py b/tests/unit/test_channels/test_backends.py
index 17eb0baf8f..871d355131 100644
--- a/tests/unit/test_channels/test_backends.py
+++ b/tests/unit/test_channels/test_backends.py
@@ -3,14 +3,18 @@
import asyncio
from datetime import timedelta
from typing import AsyncGenerator, cast
+from unittest.mock import AsyncMock, MagicMock
import pytest
from _pytest.fixtures import FixtureRequest
from redis.asyncio.client import Redis
from litestar.channels import ChannelsBackend
+from litestar.channels.backends.asyncpg import AsyncPgChannelsBackend
from litestar.channels.backends.memory import MemoryChannelsBackend
+from litestar.channels.backends.psycopg import PsycoPgChannelsBackend
from litestar.channels.backends.redis import RedisChannelsPubSubBackend, RedisChannelsStreamBackend
+from litestar.exceptions import ImproperlyConfiguredException
from litestar.utils.compat import async_next
@@ -18,6 +22,8 @@
params=[
pytest.param("redis_pub_sub_backend", id="redis:pubsub", marks=pytest.mark.xdist_group("redis")),
pytest.param("redis_stream_backend", id="redis:stream", marks=pytest.mark.xdist_group("redis")),
+ pytest.param("postgres_asyncpg_backend", id="postgres:asyncpg", marks=pytest.mark.xdist_group("postgres")),
+ pytest.param("postgres_psycopg_backend", id="postgres:psycopg", marks=pytest.mark.xdist_group("postgres")),
pytest.param("memory_backend", id="memory"),
]
)
@@ -82,7 +88,7 @@ async def test_unsubscribe_without_subscription(channels_backend: ChannelsBacken
async def test_get_history(
channels_backend: ChannelsBackend, history_limit: int | None, expected_history_length: int
) -> None:
- if isinstance(channels_backend, RedisChannelsPubSubBackend):
+ if isinstance(channels_backend, (RedisChannelsPubSubBackend, AsyncPgChannelsBackend, PsycoPgChannelsBackend)):
pytest.skip("Redis pub/sub backend does not support history")
messages = [str(i).encode() for i in range(100)]
@@ -97,7 +103,7 @@ async def test_get_history(
async def test_discards_history_entries(channels_backend: ChannelsBackend) -> None:
- if isinstance(channels_backend, RedisChannelsPubSubBackend):
+ if isinstance(channels_backend, (RedisChannelsPubSubBackend, AsyncPgChannelsBackend, PsycoPgChannelsBackend)):
pytest.skip("Redis pub/sub backend does not support history")
for _ in range(20):
@@ -133,3 +139,35 @@ async def test_memory_publish_not_initialized_raises() -> None:
with pytest.raises(RuntimeError):
await backend.publish(b"foo", ["something"])
+
+
+@pytest.mark.xdist_group("postgres")
+async def test_asyncpg_get_history(postgres_asyncpg_backend: AsyncPgChannelsBackend) -> None:
+ with pytest.raises(NotImplementedError):
+ await postgres_asyncpg_backend.get_history("something")
+
+
+@pytest.mark.xdist_group("postgres")
+async def test_psycopg_get_history(postgres_psycopg_backend: PsycoPgChannelsBackend) -> None:
+ with pytest.raises(NotImplementedError):
+ await postgres_psycopg_backend.get_history("something")
+
+
+async def test_asyncpg_make_connection() -> None:
+ make_connection = AsyncMock()
+
+ backend = AsyncPgChannelsBackend(make_connection=make_connection)
+ await backend.on_startup()
+
+ make_connection.assert_awaited_once()
+
+
+async def test_asyncpg_no_make_conn_or_dsn_passed_raises() -> None:
+ with pytest.raises(ImproperlyConfiguredException):
+ AsyncPgChannelsBackend() # type: ignore[call-overload]
+
+
+def test_asyncpg_listener_raises_on_non_string_payload() -> None:
+ backend = AsyncPgChannelsBackend(make_connection=AsyncMock())
+ with pytest.raises(RuntimeError):
+ backend._listener(connection=MagicMock(), pid=1, payload=b"abc", channel="foo")
diff --git a/tests/unit/test_channels/test_plugin.py b/tests/unit/test_channels/test_plugin.py
index c6743e6102..856c8bc9a8 100644
--- a/tests/unit/test_channels/test_plugin.py
+++ b/tests/unit/test_channels/test_plugin.py
@@ -1,6 +1,7 @@
from __future__ import annotations
import asyncio
+import time
from secrets import token_hex
from typing import cast
from unittest.mock import AsyncMock, MagicMock
@@ -24,6 +25,8 @@
params=[
pytest.param("redis_pub_sub_backend", id="redis:pubsub", marks=pytest.mark.xdist_group("redis")),
pytest.param("redis_stream_backend", id="redis:stream", marks=pytest.mark.xdist_group("redis")),
+ pytest.param("postgres_asyncpg_backend", id="postgres:asyncpg", marks=pytest.mark.xdist_group("postgres")),
+ pytest.param("postgres_psycopg_backend", id="postgres:psycopg", marks=pytest.mark.xdist_group("postgres")),
pytest.param("memory_backend", id="memory"),
]
)
@@ -119,7 +122,7 @@ def test_create_ws_route_handlers(
@pytest.mark.flaky(reruns=5)
-def test_ws_route_handlers_receive_arbitrary_message(channels_backend: ChannelsBackend) -> None:
+async def test_ws_route_handlers_receive_arbitrary_message(channels_backend: ChannelsBackend) -> None:
"""The websocket handlers await `WebSocket.receive()` to detect disconnection and stop the subscription.
This test ensures that the subscription is only stopped in the case of receiving a `websocket.disconnect` message.
@@ -140,7 +143,7 @@ def test_ws_route_handlers_receive_arbitrary_message(channels_backend: ChannelsB
@pytest.mark.flaky(reruns=5)
-async def test_create_ws_route_handlers_arbitrary_channels_allowed(channels_backend: ChannelsBackend) -> None:
+def test_create_ws_route_handlers_arbitrary_channels_allowed(channels_backend: ChannelsBackend) -> None:
channels_plugin = ChannelsPlugin(
backend=channels_backend,
arbitrary_channels_allowed=True,
@@ -155,6 +158,8 @@ async def test_create_ws_route_handlers_arbitrary_channels_allowed(channels_back
channels_plugin.publish("something", "foo")
assert ws.receive_text(timeout=2) == "something"
+ time.sleep(0.1)
+
with client.websocket_connect("/ws/bar") as ws:
channels_plugin.publish("something else", "bar")
assert ws.receive_text(timeout=2) == "something else"