Deferred Listener Connection and README Update (#5)
janbjorge committed Feb 20, 2024
1 parent 1ac83c8 commit e635682
Showing 9 changed files with 115 additions and 77 deletions.
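In short, the commit replaces PGEventQueue's awaited `create(...)` factory with a plain constructor plus an explicit, deferred `connect(...)` call. A sketch distilled from the diffs below (the channel name is taken from the README; an open asyncpg connection is assumed):

```python
import asyncpg

from pgcachewatch import listeners, models


async def make_listener(conn: asyncpg.Connection) -> listeners.PGEventQueue:
    # Old API (removed in this commit):
    #     listener = await listeners.PGEventQueue.create(channel, conn)
    # New API: construct without doing any I/O, then connect explicitly.
    listener = listeners.PGEventQueue()
    await listener.connect(conn, models.PGChannel("ch_pgcachewatch_table_change"))
    return listener
```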
9 changes: 9 additions & 0 deletions .github/workflows/ci.yml
@@ -9,6 +9,7 @@ on:
 jobs:
   ci:
     strategy:
+      fail-fast: false
       matrix:
         python-version: ["3.10", "3.11", "3.12"]
         postgres-version: ["14", "15", "16"]
@@ -54,3 +55,11 @@ jobs:
 
       - name: Full test
         run: pytest -v
+
+  check:
+    name: Check test matrix passed.
+    needs: ci
+    runs-on: ubuntu-latest
+    steps:
+      - name: Check status
+        run: echo "All tests passed; ready to merge."
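The new `check` job is, presumably, a single aggregation gate: with `fail-fast: false` every cell of the 3x3 Python/PostgreSQL matrix runs to completion, and branch protection can then require just this one job (via `needs: ci`) instead of listing all nine matrix statuses individually.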
40 changes: 28 additions & 12 deletions README.md
@@ -29,24 +29,40 @@ pgcachewatch install <tables-to-cache>
 Example showing how to use PGCacheWatch for cache invalidation in a FastAPI app
 
 ```python
+import contextlib
+import typing
+
 import asyncpg
 from fastapi import FastAPI
 from pgcachewatch import decorators, listeners, models, strategies
 
-app = FastAPI()
+listener = listeners.PGEventQueue()
 
 
-async def setup_app(channel: models.PGChannel) -> FastAPI:
+@contextlib.asynccontextmanager
+async def app_setup_teardown(_: FastAPI) -> typing.AsyncGenerator[None, None]:
     conn = await asyncpg.connect()
-    listener = await listeners.PGEventQueue.create(channel, conn)
+    await listener.connect(conn, models.PGChannel("ch_pgcachewatch_table_change"))
+    yield
+    await conn.close()
 
-    @decorators.cache(strategy=strategies.Greedy(listener=listener))
-    async def cached_query():
-        # Simulate a database query
-        return {"data": "query result"}
 
-    @app.get("/data")
-    async def get_data():
-        return await cached_query()
+APP = FastAPI(lifespan=app_setup_teardown)
 
-    return app
+
+# Only allow for cache refresh after an update
+@decorators.cache(
+    strategy=strategies.Gready(
+        listener=listener,
+        predicate=lambda x: x.operation == "update",
+    )
+)
+async def cached_query() -> dict[str, str]:
+    # Simulate a database query
+    return {"data": "query result"}
+
+
+@APP.get("/data")
+async def get_data() -> dict:
+    return await cached_query()
 ```
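A minimal way to exercise the README app above; a sketch assuming a reachable PostgreSQL instance with the pgcachewatch triggers installed (`pgcachewatch install ...`), that the app lives in a hypothetical module `myapp`, and using the third-party `asgi-lifespan` package, since httpx's `ASGITransport` does not run startup/shutdown hooks on its own:

```python
import asyncio

import httpx
from asgi_lifespan import LifespanManager

from myapp import APP  # hypothetical module holding the README app


async def main() -> None:
    async with LifespanManager(APP):  # runs app_setup_teardown
        transport = httpx.ASGITransport(app=APP)
        async with httpx.AsyncClient(transport=transport, base_url="http://test") as client:
            first = await client.get("/data")   # first call hits the (simulated) database
            second = await client.get("/data")  # served from cache until an update event arrives
            assert first.json() == second.json()


asyncio.run(main())
```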
80 changes: 46 additions & 34 deletions src/pgcachewatch/listeners.py
@@ -12,7 +12,7 @@ def _critical_termination_listener(*_: object, **__: object) -> None:
     # Must be defined in the global namespace, as asyncpg keeps
     # a set of functions to call. This will then happen only once, as
     # all instances will point to the same function.
-    logging.critical("Connection is closed / terminated!")
+    logging.critical("Connection is closed / terminated.")
 
 
 class PGEventQueue(asyncio.Queue[models.Event]):
@@ -23,48 +23,59 @@ class PGEventQueue(asyncio.Queue[models.Event]):
 
     def __init__(
         self,
-        pgchannel: models.PGChannel,
-        pgconn: asyncpg.Connection,
         max_size: int = 0,
         max_latency: datetime.timedelta = datetime.timedelta(milliseconds=500),
-        _called_by_create: bool = False,
     ) -> None:
-        """
-        Initializes the PGEventQueue instance. Use the create() classmethod to
-        instantiate.
-        """
-        if not _called_by_create:
-            raise RuntimeError(
-                "Use classmethod create(...) to instantiate PGEventQueue."
-            )
         super().__init__(maxsize=max_size)
-        self._pg_channel = pgchannel
-        self._pg_connection = pgconn
+        self._pg_channel: None | models.PGChannel = None
+        self._pg_connection: None | asyncpg.Connection = None
         self._max_latency = max_latency
 
-    @classmethod
-    async def create(
-        cls,
-        pgchannel: models.PGChannel,
-        pgconn: asyncpg.Connection,
-        maxsize: int = 0,
-        max_latency: datetime.timedelta = datetime.timedelta(milliseconds=500),
-    ) -> "PGEventQueue":
-        """
-        Creates and initializes a new PGEventQueue instance, connecting to the specified
-        PostgreSQL channel. Returns the initialized PGEventQueue instance.
-        """
-        me = cls(
-            pgchannel=pgchannel,
-            pgconn=pgconn,
-            max_size=maxsize,
-            max_latency=max_latency,
-            _called_by_create=True,
-        )
-        me._pg_connection.add_termination_listener(_critical_termination_listener)
-        await me._pg_connection.add_listener(me._pg_channel, me.parse_and_put)  # type: ignore[arg-type]
-
-        return me
+    async def connect(
+        self,
+        connection: asyncpg.Connection,
+        channel: models.PGChannel,
+    ) -> None:
+        """
+        Asynchronously connects the PGEventQueue to a specified
+        PostgreSQL channel and connection.
+
+        This method establishes a listener on a PostgreSQL channel
+        using the provided connection. It is designed to be called
+        once per PGEventQueue instance to ensure a one-to-one relationship
+        between the event queue and a database channel. If an attempt is
+        made to connect a PGEventQueue instance to more than one channel
+        or connection, a RuntimeError is raised to enforce this constraint.
+
+        Parameters:
+        - connection: asyncpg.Connection
+            The asyncpg connection object to be used for listening to database events.
+        - channel: models.PGChannel
+            The database channel to listen on for events.
+
+        Raises:
+        - RuntimeError: If the PGEventQueue is already connected to a
+          channel or connection.
+
+        Usage:
+        ```python
+        await pg_event_queue.connect(
+            connection=your_asyncpg_connection,
+            channel=your_pg_channel,
+        )
+        ```
+        """
+        if self._pg_channel or self._pg_connection:
+            raise RuntimeError(
+                "PGEventQueue instance is already connected to a channel and/or "
+                "connection. Only supports one channel and connection per "
+                "PGEventQueue instance."
+            )
+
+        self._pg_channel = channel
+        self._pg_connection = connection
+        self._pg_connection.add_termination_listener(_critical_termination_listener)
+        await self._pg_connection.add_listener(self._pg_channel, self.parse_and_put)  # type: ignore[arg-type]
 
     def parse_and_put(
         self,
@@ -87,6 +98,7 @@ def parse_and_put(
         except Exception:
             logging.exception("Unable to parse `%s`.", payload)
         else:
+            logging.info("Received event: %s on %s", parsed, channel)
             try:
                 self.put_nowait(parsed)
             except Exception:
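The `connect()` contract above, end to end; a sketch assuming a reachable PostgreSQL instance (`asyncpg.connect()` reads the standard libpq environment variables) and the default pgcachewatch channel name:

```python
import asyncio

import asyncpg

from pgcachewatch import listeners, models


async def main() -> None:
    listener = listeners.PGEventQueue()  # plain constructor, no I/O yet

    conn = await asyncpg.connect()
    await listener.connect(conn, models.PGChannel("ch_pgcachewatch_table_change"))

    # PGEventQueue subclasses asyncio.Queue[models.Event], so events are
    # consumed with the ordinary queue API.
    event = await listener.get()
    print(event.operation)

    # Connecting the same instance a second time raises RuntimeError; the
    # queue is bound one-to-one to a channel and connection.
    await conn.close()


asyncio.run(main())
```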
6 changes: 2 additions & 4 deletions tests/test_decoraters.py
@@ -10,10 +10,8 @@
 @pytest.mark.parametrize("N", (4, 16, 64, 512))
 async def test_gready_cache_decorator(N: int, pgconn: asyncpg.Connection) -> None:
     statistics = collections.Counter[str]()
-    listener = await listeners.PGEventQueue.create(
-        models.PGChannel("test_cache_decorator"),
-        pgconn=pgconn,
-    )
+    listener = listeners.PGEventQueue()
+    await listener.connect(pgconn, models.PGChannel("test_cache_decorator"))
 
     @decorators.cache(
         strategy=strategies.Gready(listener=listener),
3 changes: 2 additions & 1 deletion tests/test_fastapi.py
@@ -14,7 +14,8 @@ async def fastapitestapp(
 ) -> fastapi.FastAPI:
     app = fastapi.FastAPI()
 
-    listener = await listeners.PGEventQueue.create(channel, pgconn)
+    listener = listeners.PGEventQueue()
+    await listener.connect(pgconn, channel)
 
     @decorators.cache(strategy=strategies.Gready(listener=listener))
     async def slow_db_read() -> dict:
21 changes: 11 additions & 10 deletions tests/test_integration.py
@@ -37,10 +37,8 @@ async def test_2_caching(
     pgpool: asyncpg.Pool,
 ) -> None:
     statistics = collections.Counter[str]()
-    listener = await listeners.PGEventQueue.create(
-        models.PGChannel("test_2_caching"),
-        pgconn=pgconn,
-    )
+    listener = listeners.PGEventQueue()
+    await listener.connect(pgconn, models.PGChannel("test_2_caching"))
 
     cnt = 0
 
@@ -64,9 +62,10 @@ async def test_3_cache_invalidation_update(
     pgpool: asyncpg.Pool,
 ) -> None:
     statistics = collections.Counter[str]()
-    listener = await listeners.PGEventQueue.create(
+    listener = listeners.PGEventQueue()
+    await listener.connect(
+        pgconn,
         models.PGChannel("ch_pgcachewatch_table_change"),
-        pgconn=pgconn,
     )
 
     @decorators.cache(
@@ -97,9 +96,10 @@ async def test_3_cache_invalidation_insert(
     pgpool: asyncpg.Pool,
 ) -> None:
     statistics = collections.Counter[str]()
-    listener = await listeners.PGEventQueue.create(
+    listener = listeners.PGEventQueue()
+    await listener.connect(
+        pgconn,
         models.PGChannel("ch_pgcachewatch_table_change"),
-        pgconn=pgconn,
     )
 
     @decorators.cache(
@@ -131,9 +131,10 @@ async def test_3_cache_invalidation_delete(
     pgpool: asyncpg.Pool,
 ) -> None:
     statistics = collections.Counter[str]()
-    listener = await listeners.PGEventQueue.create(
+    listener = listeners.PGEventQueue()
+    await listener.connect(
+        pgconn,
         models.PGChannel("ch_pgcachewatch_table_change"),
-        pgconn=pgconn,
     )
 
     @decorators.cache(
5 changes: 3 additions & 2 deletions tests/test_listeners.py
@@ -16,7 +16,8 @@ async def test_eventqueue_and_pglistner(
     pgpool: asyncpg.Pool,
 ) -> None:
     channel = models.PGChannel(f"test_eventqueue_and_pglistner_{N}_{operation}")
-    eq = await listeners.PGEventQueue.create(channel, pgconn)
+    listener = listeners.PGEventQueue()
+    await listener.connect(pgconn, channel)
 
     for _ in range(N):
         await utils.emit_event(
@@ -32,7 +33,7 @@
     evnets = list[models.Event]()
     while True:
         try:
-            evnets.append(eq.get_nowait())
+            evnets.append(listener.get_nowait())
         except asyncio.QueueEmpty:
             break
 
11 changes: 8 additions & 3 deletions tests/test_strategies.py
@@ -9,7 +9,10 @@
 @pytest.mark.parametrize("N", (4, 16, 64))
 async def test_gready_strategy(N: int, pgconn: asyncpg.Connection) -> None:
     channel = models.PGChannel("test_gready_strategy")
-    listener = await listeners.PGEventQueue.create(channel, pgconn)
+
+    listener = listeners.PGEventQueue()
+    await listener.connect(pgconn, channel)
+
     strategy = strategies.Gready(
         listener=listener,
         predicate=lambda e: e.operation == "insert",
@@ -47,7 +50,8 @@ async def test_windowed_strategy(
     pgconn: asyncpg.Connection,
 ) -> None:
     channel = models.PGChannel("test_windowed_strategy")
-    listener = await listeners.PGEventQueue.create(channel, pgconn)
+    listener = listeners.PGEventQueue()
+    await listener.connect(pgconn, channel)
     strategy = strategies.Windowed(
         listener=listener, window=["insert", "update", "delete"]
     )
@@ -111,7 +115,8 @@ async def test_timed_strategy(
     pgconn: asyncpg.Connection,
 ) -> None:
     channel = models.PGChannel("test_timed_strategy")
-    listener = await listeners.PGEventQueue.create(channel, pgconn)
+    listener = listeners.PGEventQueue()
+    await listener.connect(pgconn, channel)
     strategy = strategies.Timed(listener=listener, timedelta=dt)
 
     # Bursts spaced out according to the min dt required to trigger a refresh.
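For reference, the three refresh strategies these tests exercise, constructed side by side; a sketch with constructor arguments taken from the tests above (the one-line semantics are inferred from the test names, and the listener is assumed to be connected already):

```python
import datetime

from pgcachewatch import listeners, strategies


def build_strategies(
    listener: listeners.PGEventQueue,
) -> tuple[strategies.Gready, strategies.Windowed, strategies.Timed]:
    # Gready: refresh as soon as an event matching the predicate arrives.
    gready = strategies.Gready(
        listener=listener,
        predicate=lambda e: e.operation == "insert",
    )
    # Windowed: refresh once the listed sequence of operations has been seen.
    windowed = strategies.Windowed(
        listener=listener,
        window=["insert", "update", "delete"],
    )
    # Timed: refresh at most once per elapsed time window.
    timed = strategies.Timed(
        listener=listener,
        timedelta=datetime.timedelta(seconds=5),
    )
    return gready, windowed, timed
```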
17 changes: 6 additions & 11 deletions tests/test_utils.py
@@ -17,9 +17,8 @@ async def test_emit_event(
     pgpool: asyncpg.Pool,
 ) -> None:
     channel = "test_emit_event"
-    listener = await listeners.PGEventQueue.create(
-        models.PGChannel(channel), pgconn=pgconn
-    )
+    listener = listeners.PGEventQueue()
+    await listener.connect(pgconn, models.PGChannel(channel))
     await asyncio.gather(
         *[
             utils.emit_event(
@@ -47,10 +46,8 @@ async def test_pick_until_deadline_max_iter(
     pgconn: asyncpg.Connection,
 ) -> None:
     channel = "test_pick_until_deadline_max_iter"
-    listener = await listeners.PGEventQueue.create(
-        models.PGChannel(channel),
-        pgconn=pgconn,
-    )
+    listener = listeners.PGEventQueue()
+    await listener.connect(pgconn, models.PGChannel(channel))
 
     items = list(range(max_iter * 2))
     for item in items:
@@ -87,10 +84,8 @@ async def test_pick_until_deadline_max_time(
    pgconn: asyncpg.Connection,
 ) -> None:
     channel = "test_pick_until_deadline_max_time"
-    listener = await listeners.PGEventQueue.create(
-        models.PGChannel(channel),
-        pgconn=pgconn,
-    )
+    listener = listeners.PGEventQueue()
+    await listener.connect(pgconn, models.PGChannel(channel))
 
     x = -1
