Skip to content

Commit

Permalink
feat(channels): Postgres backends (#2803)
Browse files Browse the repository at this point in the history
* wip

Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com>

* some debugging

Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com>

* formatting

Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com>

* use a separate connection to publish/listen

Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com>

* reintroduce flaky

Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com>

* Fix typing

Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com>

* Add psycopg backend

Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com>

* Fix backend issues

Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com>

* Undo test debugging changes

Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com>

* mark groups

Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com>

* Ensure channel names ar quoted

Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com>

* sleep debugging

Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com>

* update docs

Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com>

* Add missing test

Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com>

* Fix docs link

Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com>

* Add missing listener test

Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com>

* Formatting

Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com>

* Fix test typing

Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com>

* Fix some coverage issue

Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com>

---------

Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com>
Co-authored-by: Cody Fincher <204685+cofin@users.noreply.github.com>
  • Loading branch information
provinzkraut and cofin committed Jan 4, 2024
1 parent 2409574 commit 6300249
Show file tree
Hide file tree
Showing 14 changed files with 234 additions and 7 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
.scannerwork/
.unasyncd_cache/
.venv/
.venv*
.vscode/
__pycache__/
build/
Expand Down
5 changes: 5 additions & 0 deletions docs/reference/channels/backends/asyncpg.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
asyncpg
=======

.. automodule:: litestar.channels.backends.asyncpg
:members:
2 changes: 2 additions & 0 deletions docs/reference/channels/backends/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,5 @@ backends
base
memory
redis
psycopg
asyncpg
5 changes: 5 additions & 0 deletions docs/reference/channels/backends/psycopg.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
psycopg
=======

.. automodule:: litestar.channels.backends.psycopg
:members:
13 changes: 12 additions & 1 deletion docs/usage/channels.rst
Original file line number Diff line number Diff line change
Expand Up @@ -399,7 +399,7 @@ implemented are:
A basic in-memory backend, mostly useful for testing and local development, but
still fully capable. Since it stores all data in-process, it can achieve the highest
performance of all the backends, but at the same time is not suitable for
applications running on multiple processes.
applications running on multiple processes

:class:`RedisChannelsPubSubBackend <.redis.RedisChannelsPubSubBackend>`
A Redis based backend, using `Pub/Sub <https://redis.io/docs/manual/pubsub/>`_ to
Expand All @@ -413,6 +413,17 @@ implemented are:
when history is needed


:class:`AsyncPgChannelsBackend <.asyncpg.AsyncPgChannelsBackend>`
A postgres backend using the
`asyncpg <https://magicstack.github.io/asyncpg/current/>`_ driver


:class:`PsycoPgChannelsBackend <.psycopg.AsyncPgChannelsBackend>`
A postgres backend using the `psycopg3 <https://www.psycopg.org/psycopg3/docs/>`_
async driver




Integrating with websocket handlers
-----------------------------------
Expand Down
82 changes: 82 additions & 0 deletions litestar/channels/backends/asyncpg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
from __future__ import annotations

import asyncio
from contextlib import AsyncExitStack
from functools import partial
from typing import AsyncGenerator, Awaitable, Callable, Iterable, overload

import asyncpg

from litestar.channels import ChannelsBackend
from litestar.exceptions import ImproperlyConfiguredException


class AsyncPgChannelsBackend(ChannelsBackend):
_listener_conn: asyncpg.Connection
_queue: asyncio.Queue[tuple[str, bytes]]

@overload
def __init__(self, dsn: str) -> None:
...

@overload
def __init__(
self,
*,
make_connection: Callable[[], Awaitable[asyncpg.Connection]],
) -> None:
...

def __init__(
self,
dsn: str | None = None,
*,
make_connection: Callable[[], Awaitable[asyncpg.Connection]] | None = None,
) -> None:
if not (dsn or make_connection):
raise ImproperlyConfiguredException("Need to specify dsn or make_connection")

self._subscribed_channels: set[str] = set()
self._exit_stack = AsyncExitStack()
self._connect = make_connection or partial(asyncpg.connect, dsn=dsn)

async def on_startup(self) -> None:
self._queue = asyncio.Queue()
self._listener_conn = await self._connect()

async def on_shutdown(self) -> None:
await self._listener_conn.close()
del self._queue

async def publish(self, data: bytes, channels: Iterable[str]) -> None:
dec_data = data.decode("utf-8")

conn = await self._connect()
try:
for channel in channels:
await conn.execute("SELECT pg_notify($1, $2);", channel, dec_data)
finally:
await conn.close()

async def subscribe(self, channels: Iterable[str]) -> None:
for channel in set(channels) - self._subscribed_channels:
await self._listener_conn.add_listener(channel, self._listener) # type: ignore[arg-type]
self._subscribed_channels.add(channel)

async def unsubscribe(self, channels: Iterable[str]) -> None:
for channel in channels:
await self._listener_conn.remove_listener(channel, self._listener) # type: ignore[arg-type]
self._subscribed_channels = self._subscribed_channels - set(channels)

async def stream_events(self) -> AsyncGenerator[tuple[str, bytes], None]:
while self._queue:
yield await self._queue.get()
self._queue.task_done()

async def get_history(self, channel: str, limit: int | None = None) -> list[bytes]:
raise NotImplementedError()

def _listener(self, /, connection: asyncpg.Connection, pid: int, channel: str, payload: object) -> None:
if not isinstance(payload, str):
raise RuntimeError("Invalid data received")
self._queue.put_nowait((channel, payload.encode("utf-8")))
54 changes: 54 additions & 0 deletions litestar/channels/backends/psycopg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
from __future__ import annotations

from contextlib import AsyncExitStack
from typing import AsyncGenerator, Iterable

import psycopg

from .base import ChannelsBackend


def _safe_quote(ident: str) -> str:
return '"{}"'.format(ident.replace('"', '""'))


class PsycoPgChannelsBackend(ChannelsBackend):
_listener_conn: psycopg.AsyncConnection

def __init__(self, pg_dsn: str) -> None:
self._pg_dsn = pg_dsn
self._subscribed_channels: set[str] = set()
self._exit_stack = AsyncExitStack()

async def on_startup(self) -> None:
self._listener_conn = await psycopg.AsyncConnection.connect(self._pg_dsn, autocommit=True)
await self._exit_stack.enter_async_context(self._listener_conn)

async def on_shutdown(self) -> None:
await self._exit_stack.aclose()

async def publish(self, data: bytes, channels: Iterable[str]) -> None:
dec_data = data.decode("utf-8")
async with await psycopg.AsyncConnection.connect(self._pg_dsn) as conn:
for channel in channels:
await conn.execute("SELECT pg_notify(%s, %s);", (channel, dec_data))

async def subscribe(self, channels: Iterable[str]) -> None:
for channel in set(channels) - self._subscribed_channels:
# can't use placeholders in LISTEN
await self._listener_conn.execute(f"LISTEN {_safe_quote(channel)};") # pyright: ignore

self._subscribed_channels.add(channel)

async def unsubscribe(self, channels: Iterable[str]) -> None:
for channel in channels:
# can't use placeholders in UNLISTEN
await self._listener_conn.execute(f"UNLISTEN {_safe_quote(channel)};") # pyright: ignore
self._subscribed_channels = self._subscribed_channels - set(channels)

async def stream_events(self) -> AsyncGenerator[tuple[str, bytes], None]:
async for notify in self._listener_conn.notifies():
yield notify.channel, notify.payload.encode("utf-8")

async def get_history(self, channel: str, limit: int | None = None) -> list[bytes]:
raise NotImplementedError()
2 changes: 1 addition & 1 deletion litestar/channels/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,10 +311,10 @@ async def _sub_worker(self) -> None:
subscriber.put_nowait(payload)

async def _on_startup(self) -> None:
await self._backend.on_startup()
self._pub_queue = Queue()
self._pub_task = create_task(self._pub_worker())
self._sub_task = create_task(self._sub_worker())
await self._backend.on_startup()
if self._channels:
await self._backend.subscribe(list(self._channels))

Expand Down
6 changes: 5 additions & 1 deletion pdm.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,9 @@ dev = [
"trio",
"aiosqlite",
"exceptiongroup; python_version < \"3.11\"",
"asyncpg>=0.29.0",
"psycopg[pool,binary]>=3.1.10",
"psycopg2-binary",
]
dev-contrib = ["opentelemetry-sdk", "httpx-sse"]
docs = [
Expand Down
5 changes: 5 additions & 0 deletions tests/docker_service_fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,3 +139,8 @@ async def postgres_responsive(host: str) -> bool:
return (await conn.fetchrow("SELECT 1"))[0] == 1 # type: ignore
finally:
await conn.close()


@pytest.fixture()
async def postgres_service(docker_services: DockerServiceRegistry) -> None:
await docker_services.start("postgres", check=postgres_responsive)
12 changes: 12 additions & 0 deletions tests/unit/test_channels/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
import pytest
from redis.asyncio import Redis as AsyncRedis

from litestar.channels.backends.asyncpg import AsyncPgChannelsBackend
from litestar.channels.backends.memory import MemoryChannelsBackend
from litestar.channels.backends.psycopg import PsycoPgChannelsBackend
from litestar.channels.backends.redis import RedisChannelsPubSubBackend, RedisChannelsStreamBackend


Expand Down Expand Up @@ -37,3 +39,13 @@ def redis_pub_sub_backend(redis_client: AsyncRedis) -> RedisChannelsPubSubBacken
@pytest.fixture()
def memory_backend() -> MemoryChannelsBackend:
return MemoryChannelsBackend(history=10)


@pytest.fixture()
def postgres_asyncpg_backend(postgres_service: None, docker_ip: str) -> AsyncPgChannelsBackend:
return AsyncPgChannelsBackend(f"postgres://postgres:super-secret@{docker_ip}:5423")


@pytest.fixture()
def postgres_psycopg_backend(postgres_service: None, docker_ip: str) -> PsycoPgChannelsBackend:
return PsycoPgChannelsBackend(f"postgres://postgres:super-secret@{docker_ip}:5423")
42 changes: 40 additions & 2 deletions tests/unit/test_channels/test_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,27 @@
import asyncio
from datetime import timedelta
from typing import AsyncGenerator, cast
from unittest.mock import AsyncMock, MagicMock

import pytest
from _pytest.fixtures import FixtureRequest
from redis.asyncio.client import Redis

from litestar.channels import ChannelsBackend
from litestar.channels.backends.asyncpg import AsyncPgChannelsBackend
from litestar.channels.backends.memory import MemoryChannelsBackend
from litestar.channels.backends.psycopg import PsycoPgChannelsBackend
from litestar.channels.backends.redis import RedisChannelsPubSubBackend, RedisChannelsStreamBackend
from litestar.exceptions import ImproperlyConfiguredException
from litestar.utils.compat import async_next


@pytest.fixture(
params=[
pytest.param("redis_pub_sub_backend", id="redis:pubsub", marks=pytest.mark.xdist_group("redis")),
pytest.param("redis_stream_backend", id="redis:stream", marks=pytest.mark.xdist_group("redis")),
pytest.param("postgres_asyncpg_backend", id="postgres:asyncpg", marks=pytest.mark.xdist_group("postgres")),
pytest.param("postgres_psycopg_backend", id="postgres:psycopg", marks=pytest.mark.xdist_group("postgres")),
pytest.param("memory_backend", id="memory"),
]
)
Expand Down Expand Up @@ -82,7 +88,7 @@ async def test_unsubscribe_without_subscription(channels_backend: ChannelsBacken
async def test_get_history(
channels_backend: ChannelsBackend, history_limit: int | None, expected_history_length: int
) -> None:
if isinstance(channels_backend, RedisChannelsPubSubBackend):
if isinstance(channels_backend, (RedisChannelsPubSubBackend, AsyncPgChannelsBackend, PsycoPgChannelsBackend)):
pytest.skip("Redis pub/sub backend does not support history")

messages = [str(i).encode() for i in range(100)]
Expand All @@ -97,7 +103,7 @@ async def test_get_history(


async def test_discards_history_entries(channels_backend: ChannelsBackend) -> None:
if isinstance(channels_backend, RedisChannelsPubSubBackend):
if isinstance(channels_backend, (RedisChannelsPubSubBackend, AsyncPgChannelsBackend, PsycoPgChannelsBackend)):
pytest.skip("Redis pub/sub backend does not support history")

for _ in range(20):
Expand Down Expand Up @@ -133,3 +139,35 @@ async def test_memory_publish_not_initialized_raises() -> None:

with pytest.raises(RuntimeError):
await backend.publish(b"foo", ["something"])


@pytest.mark.xdist_group("postgres")
async def test_asyncpg_get_history(postgres_asyncpg_backend: AsyncPgChannelsBackend) -> None:
with pytest.raises(NotImplementedError):
await postgres_asyncpg_backend.get_history("something")


@pytest.mark.xdist_group("postgres")
async def test_psycopg_get_history(postgres_psycopg_backend: PsycoPgChannelsBackend) -> None:
with pytest.raises(NotImplementedError):
await postgres_psycopg_backend.get_history("something")


async def test_asyncpg_make_connection() -> None:
make_connection = AsyncMock()

backend = AsyncPgChannelsBackend(make_connection=make_connection)
await backend.on_startup()

make_connection.assert_awaited_once()


async def test_asyncpg_no_make_conn_or_dsn_passed_raises() -> None:
with pytest.raises(ImproperlyConfiguredException):
AsyncPgChannelsBackend() # type: ignore[call-overload]


def test_asyncpg_listener_raises_on_non_string_payload() -> None:
backend = AsyncPgChannelsBackend(make_connection=AsyncMock())
with pytest.raises(RuntimeError):
backend._listener(connection=MagicMock(), pid=1, payload=b"abc", channel="foo")
9 changes: 7 additions & 2 deletions tests/unit/test_channels/test_plugin.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import asyncio
import time
from secrets import token_hex
from typing import cast
from unittest.mock import AsyncMock, MagicMock
Expand All @@ -24,6 +25,8 @@
params=[
pytest.param("redis_pub_sub_backend", id="redis:pubsub", marks=pytest.mark.xdist_group("redis")),
pytest.param("redis_stream_backend", id="redis:stream", marks=pytest.mark.xdist_group("redis")),
pytest.param("postgres_asyncpg_backend", id="postgres:asyncpg", marks=pytest.mark.xdist_group("postgres")),
pytest.param("postgres_psycopg_backend", id="postgres:psycopg", marks=pytest.mark.xdist_group("postgres")),
pytest.param("memory_backend", id="memory"),
]
)
Expand Down Expand Up @@ -119,7 +122,7 @@ def test_create_ws_route_handlers(


@pytest.mark.flaky(reruns=5)
def test_ws_route_handlers_receive_arbitrary_message(channels_backend: ChannelsBackend) -> None:
async def test_ws_route_handlers_receive_arbitrary_message(channels_backend: ChannelsBackend) -> None:
"""The websocket handlers await `WebSocket.receive()` to detect disconnection and stop the subscription.
This test ensures that the subscription is only stopped in the case of receiving a `websocket.disconnect` message.
Expand All @@ -140,7 +143,7 @@ def test_ws_route_handlers_receive_arbitrary_message(channels_backend: ChannelsB


@pytest.mark.flaky(reruns=5)
async def test_create_ws_route_handlers_arbitrary_channels_allowed(channels_backend: ChannelsBackend) -> None:
def test_create_ws_route_handlers_arbitrary_channels_allowed(channels_backend: ChannelsBackend) -> None:
channels_plugin = ChannelsPlugin(
backend=channels_backend,
arbitrary_channels_allowed=True,
Expand All @@ -155,6 +158,8 @@ async def test_create_ws_route_handlers_arbitrary_channels_allowed(channels_back
channels_plugin.publish("something", "foo")
assert ws.receive_text(timeout=2) == "something"

time.sleep(0.1)

with client.websocket_connect("/ws/bar") as ws:
channels_plugin.publish("something else", "bar")
assert ws.receive_text(timeout=2) == "something else"
Expand Down

0 comments on commit 6300249

Please sign in to comment.