Skip to content

Commit

Permalink
use a separate connection to publish/listen
Browse files Browse the repository at this point in the history
Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com>
  • Loading branch information
provinzkraut committed Nov 29, 2023
1 parent 29bc1f4 commit e1a9adc
Showing 1 changed file with 9 additions and 10 deletions.
19 changes: 9 additions & 10 deletions litestar/channels/backends/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,34 +11,33 @@
class PostgresChannelsBackend(ChannelsBackend):
def __init__(self, url: str) -> None:
self._pg_url = url
self._connection: asyncpg.Connection
self._listener_conn: asyncpg.Connection
self._queue: asyncio.Queue[tuple[str, bytes]] = asyncio.Queue()
self._subscribed_channels: set[str] = set()

async def on_startup(self) -> None:
self._connection = await asyncpg.connect(self._pg_url)
self._listener_conn = await asyncpg.connect(self._pg_url)

async def on_shutdown(self) -> None:
await self._connection.close()
self._connection = None
await self._listener_conn.close()
self._listener_conn = None

async def publish(self, data: bytes, channels: Iterable[str]) -> None:
while not self._connection._listeners:
await asyncio.sleep(0.001)

dec_data = data.decode("utf-8")

conn = await asyncpg.connect(self._pg_url)
for channel in channels:
await self._connection.execute("SELECT pg_notify($1, $2);", channel, dec_data)
await conn.execute("SELECT pg_notify($1, $2);", channel, dec_data)
await conn.close()

async def subscribe(self, channels: Iterable[str]) -> None:
for channel in set(channels) - self._subscribed_channels:
await self._connection.add_listener(channel, self._listener)
await self._listener_conn.add_listener(channel, self._listener)
self._subscribed_channels.add(channel)

async def unsubscribe(self, channels: Iterable[str]) -> None:
for channel in channels:
await self._connection.remove_listener(channel, self._listener)
await self._listener_conn.remove_listener(channel, self._listener)
self._subscribed_channels = self._subscribed_channels - set(channels)

async def stream_events(self) -> AsyncGenerator[tuple[str, bytes], None]:
Expand Down

0 comments on commit e1a9adc

Please sign in to comment.