Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(channels): Postgres backends #2803

Merged
merged 22 commits into from
Dec 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
.scannerwork/
.unasyncd_cache/
.venv/
.venv*
.vscode/
__pycache__/
build/
Expand Down
5 changes: 5 additions & 0 deletions docs/reference/channels/backends/asyncpg.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
asyncpg
=======

.. automodule:: litestar.channels.backends.asyncpg
:members:
2 changes: 2 additions & 0 deletions docs/reference/channels/backends/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,5 @@ backends
base
memory
redis
psycopg
asyncpg
5 changes: 5 additions & 0 deletions docs/reference/channels/backends/psycopg.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
psycopg
=======

.. automodule:: litestar.channels.backends.psycopg
:members:
13 changes: 12 additions & 1 deletion docs/usage/channels.rst
Original file line number Diff line number Diff line change
Expand Up @@ -399,7 +399,7 @@ implemented are:
A basic in-memory backend, mostly useful for testing and local development, but
still fully capable. Since it stores all data in-process, it can achieve the highest
performance of all the backends, but at the same time is not suitable for
applications running on multiple processes.
applications running on multiple processes

:class:`RedisChannelsPubSubBackend <.redis.RedisChannelsPubSubBackend>`
A Redis based backend, using `Pub/Sub <https://redis.io/docs/manual/pubsub/>`_ to
Expand All @@ -413,6 +413,17 @@ implemented are:
when history is needed


:class:`AsyncPgChannelsBackend <.asyncpg.AsyncPgChannelsBackend>`
A postgres backend using the
`asyncpg <https://magicstack.github.io/asyncpg/current/>`_ driver


:class:`PsycoPgChannelsBackend <.psycopg.AsyncPgChannelsBackend>`
A postgres backend using the `psycopg3 <https://www.psycopg.org/psycopg3/docs/>`_
async driver




Integrating with websocket handlers
-----------------------------------
Expand Down
82 changes: 82 additions & 0 deletions litestar/channels/backends/asyncpg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
from __future__ import annotations

import asyncio
from contextlib import AsyncExitStack
from functools import partial
from typing import AsyncGenerator, Awaitable, Callable, Iterable, overload

import asyncpg

from litestar.channels import ChannelsBackend
from litestar.exceptions import ImproperlyConfiguredException


class AsyncPgChannelsBackend(ChannelsBackend):
_listener_conn: asyncpg.Connection
_queue: asyncio.Queue[tuple[str, bytes]]

@overload
def __init__(self, dsn: str) -> None:
...

@overload
def __init__(
self,
*,
make_connection: Callable[[], Awaitable[asyncpg.Connection]],
) -> None:
...

def __init__(
self,
dsn: str | None = None,
*,
make_connection: Callable[[], Awaitable[asyncpg.Connection]] | None = None,
) -> None:
if not (dsn or make_connection):
raise ImproperlyConfiguredException("Need to specify dsn or make_connection")

self._subscribed_channels: set[str] = set()
self._exit_stack = AsyncExitStack()
self._connect = make_connection or partial(asyncpg.connect, dsn=dsn)

async def on_startup(self) -> None:
self._queue = asyncio.Queue()
self._listener_conn = await self._connect()

async def on_shutdown(self) -> None:
await self._listener_conn.close()
del self._queue

async def publish(self, data: bytes, channels: Iterable[str]) -> None:
dec_data = data.decode("utf-8")

conn = await self._connect()
try:
for channel in channels:
await conn.execute("SELECT pg_notify($1, $2);", channel, dec_data)
finally:
await conn.close()

async def subscribe(self, channels: Iterable[str]) -> None:
for channel in set(channels) - self._subscribed_channels:
await self._listener_conn.add_listener(channel, self._listener) # type: ignore[arg-type]
self._subscribed_channels.add(channel)

async def unsubscribe(self, channels: Iterable[str]) -> None:
for channel in channels:
await self._listener_conn.remove_listener(channel, self._listener) # type: ignore[arg-type]
self._subscribed_channels = self._subscribed_channels - set(channels)

async def stream_events(self) -> AsyncGenerator[tuple[str, bytes], None]:
while self._queue:
yield await self._queue.get()
self._queue.task_done()

async def get_history(self, channel: str, limit: int | None = None) -> list[bytes]:
raise NotImplementedError()

def _listener(self, /, connection: asyncpg.Connection, pid: int, channel: str, payload: object) -> None:
if not isinstance(payload, str):
raise RuntimeError("Invalid data received")
self._queue.put_nowait((channel, payload.encode("utf-8")))
54 changes: 54 additions & 0 deletions litestar/channels/backends/psycopg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
from __future__ import annotations

from contextlib import AsyncExitStack
from typing import AsyncGenerator, Iterable

import psycopg

from .base import ChannelsBackend


def _safe_quote(ident: str) -> str:
return '"{}"'.format(ident.replace('"', '""'))


class PsycoPgChannelsBackend(ChannelsBackend):
_listener_conn: psycopg.AsyncConnection

def __init__(self, pg_dsn: str) -> None:
self._pg_dsn = pg_dsn
self._subscribed_channels: set[str] = set()
self._exit_stack = AsyncExitStack()

async def on_startup(self) -> None:
self._listener_conn = await psycopg.AsyncConnection.connect(self._pg_dsn, autocommit=True)
await self._exit_stack.enter_async_context(self._listener_conn)

async def on_shutdown(self) -> None:
await self._exit_stack.aclose()

async def publish(self, data: bytes, channels: Iterable[str]) -> None:
dec_data = data.decode("utf-8")
async with await psycopg.AsyncConnection.connect(self._pg_dsn) as conn:
for channel in channels:
await conn.execute("SELECT pg_notify(%s, %s);", (channel, dec_data))

async def subscribe(self, channels: Iterable[str]) -> None:
for channel in set(channels) - self._subscribed_channels:
# can't use placeholders in LISTEN
await self._listener_conn.execute(f"LISTEN {_safe_quote(channel)};") # pyright: ignore

self._subscribed_channels.add(channel)

async def unsubscribe(self, channels: Iterable[str]) -> None:
for channel in channels:
# can't use placeholders in UNLISTEN
await self._listener_conn.execute(f"UNLISTEN {_safe_quote(channel)};") # pyright: ignore
self._subscribed_channels = self._subscribed_channels - set(channels)

async def stream_events(self) -> AsyncGenerator[tuple[str, bytes], None]:
async for notify in self._listener_conn.notifies():
yield notify.channel, notify.payload.encode("utf-8")

async def get_history(self, channel: str, limit: int | None = None) -> list[bytes]:
raise NotImplementedError()
2 changes: 1 addition & 1 deletion litestar/channels/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,10 +311,10 @@ async def _sub_worker(self) -> None:
subscriber.put_nowait(payload)

async def _on_startup(self) -> None:
await self._backend.on_startup()
self._pub_queue = Queue()
self._pub_task = create_task(self._pub_worker())
self._sub_task = create_task(self._sub_worker())
await self._backend.on_startup()
if self._channels:
await self._backend.subscribe(list(self._channels))

Expand Down
Loading