From 514925568a1614ea5f14f46c76cd8babd57b574f Mon Sep 17 00:00:00 2001 From: Davide Mezzogori Date: Wed, 31 May 2023 13:10:12 +0200 Subject: [PATCH] Separation of database context vars --- kwik/database/context_vars/__init__.py | 4 ++++ .../current_user.py} | 11 +++------ kwik/database/context_vars/db_conn.py | 9 +++++++ kwik/database/db_context_manager.py | 24 +++++++++++++------ kwik/routers/auditor.py | 6 +++-- kwik/utils/tests.py | 6 +++-- 6 files changed, 41 insertions(+), 19 deletions(-) create mode 100644 kwik/database/context_vars/__init__.py rename kwik/database/{db_context_var.py => context_vars/current_user.py} (54%) create mode 100644 kwik/database/context_vars/db_conn.py diff --git a/kwik/database/context_vars/__init__.py b/kwik/database/context_vars/__init__.py new file mode 100644 index 0000000..6fd9753 --- /dev/null +++ b/kwik/database/context_vars/__init__.py @@ -0,0 +1,4 @@ +from __future__ import annotations + +from .current_user import current_user_ctx_var +from .db_conn import db_conn_ctx_var diff --git a/kwik/database/db_context_var.py b/kwik/database/context_vars/current_user.py similarity index 54% rename from kwik/database/db_context_var.py rename to kwik/database/context_vars/current_user.py index a4bde15..f62b15c 100644 --- a/kwik/database/db_context_var.py +++ b/kwik/database/context_vars/current_user.py @@ -5,11 +5,6 @@ if TYPE_CHECKING: from kwik import models - from .session import KwikSession - -db_conn_ctx_var: ContextVar[KwikSession | None] = ContextVar( - "db_conn_ctx_var", default=None -) -current_user_ctx_var: ContextVar[models.User | None] = ContextVar( - "current_user_ctx_var", default=None -) + + +current_user_ctx_var: ContextVar[models.User | None] = ContextVar("current_user_ctx_var", default=None) diff --git a/kwik/database/context_vars/db_conn.py b/kwik/database/context_vars/db_conn.py new file mode 100644 index 0000000..7a1434b --- /dev/null +++ b/kwik/database/context_vars/db_conn.py @@ -0,0 +1,9 @@ +from __future__ import annotations + +from contextvars import ContextVar +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from sqlalchemy.orm import Session + +db_conn_ctx_var: ContextVar[Session | None] = ContextVar("db_conn_ctx_var", default=None) diff --git a/kwik/database/db_context_manager.py b/kwik/database/db_context_manager.py index d2af469..a5e0454 100644 --- a/kwik/database/db_context_manager.py +++ b/kwik/database/db_context_manager.py @@ -7,11 +7,10 @@ import kwik from kwik import models from kwik.core.config import Settings +from kwik.database.context_vars import current_user_ctx_var, db_conn_ctx_var from sqlalchemy import create_engine from sqlalchemy.orm import Session, sessionmaker -from .db_context_var import current_user_ctx_var, db_conn_ctx_var - if TYPE_CHECKING: from .session import KwikSession @@ -61,17 +60,28 @@ def __get__(self, obj, objtype=None) -> models.User | None: class DBContextManager: """ DB Session Context Manager. - Correctly initialize the session by overriding the Session and Query class. - Implemented as a python context manager, automatically rollback a transaction - if any exception is raised by the application. + + Implemented as a context manager, + automatically rollback a transaction if any exception is raised by the application. """ - def __init__(self, *, settings: Settings | None = None) -> None: - self.db: KwikSession | Session | None = None + def __init__(self, *, settings: Settings) -> None: + """ + Initialize the DBContextManager. + + Requires a Settings object instance to be passed in. + """ self.settings: Settings = settings + self.db: KwikSession | Session | None = None self.token: Token | None = None def __enter__(self) -> KwikSession | Session: + """ + Enter the context manager. + + Returns a database session. + """ + token = db_conn_ctx_var.get() if token is not None: self.db = token diff --git a/kwik/routers/auditor.py b/kwik/routers/auditor.py index 4619713..8d5f7f4 100644 --- a/kwik/routers/auditor.py +++ b/kwik/routers/auditor.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import time from typing import Callable @@ -6,7 +8,7 @@ from fastapi.routing import APIRoute from jose import jwt from kwik import crud, schemas -from kwik.api.deps import get_current_user, current_token, get_token +from kwik.api.deps import current_token, get_current_user, get_token from kwik.core import security from kwik.middlewares import get_request_id @@ -32,7 +34,7 @@ def get_route_handler(self) -> Callable: original_route_handler = super().get_route_handler() async def custom_route_handler(request: Request) -> Response: - from kwik.database.db_context_var import current_user_ctx_var + from kwik.database.context_vars import current_user_ctx_var # start the timer start = time.time() diff --git a/kwik/utils/tests.py b/kwik/utils/tests.py index 29c50fa..933d700 100644 --- a/kwik/utils/tests.py +++ b/kwik/utils/tests.py @@ -1,9 +1,11 @@ +from __future__ import annotations + from contextlib import contextmanager from typing import Generator import kwik.crud +from kwik.database.context_vars.db_conn import db_conn_ctx_var from kwik.database.db_context_manager import DBContextManager -from kwik.database.db_context_var import db_conn_ctx_var @contextmanager @@ -17,7 +19,7 @@ def test_db(*, db_path: str, setup=True) -> Generator: ) as db_cxt: if setup: superuser = kwik.crud.user.get_by_email(email=kwik.settings.FIRST_SUPERUSER) - from kwik.database.db_context_var import current_user_ctx_var + from kwik.database.context_vars import current_user_ctx_var current_user_ctx_var.set(superuser)