Skip to content

Commit

Permalink
Separation of database context vars
Browse files Browse the repository at this point in the history
  • Loading branch information
dmezzogori committed May 31, 2023
1 parent 4bbed32 commit 5149255
Show file tree
Hide file tree
Showing 6 changed files with 41 additions and 19 deletions.
4 changes: 4 additions & 0 deletions kwik/database/context_vars/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from __future__ import annotations

from .current_user import current_user_ctx_var
from .db_conn import db_conn_ctx_var
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,6 @@

if TYPE_CHECKING:
from kwik import models
from .session import KwikSession

db_conn_ctx_var: ContextVar[KwikSession | None] = ContextVar(
"db_conn_ctx_var", default=None
)
current_user_ctx_var: ContextVar[models.User | None] = ContextVar(
"current_user_ctx_var", default=None
)


current_user_ctx_var: ContextVar[models.User | None] = ContextVar("current_user_ctx_var", default=None)
9 changes: 9 additions & 0 deletions kwik/database/context_vars/db_conn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from __future__ import annotations

from contextvars import ContextVar
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from sqlalchemy.orm import Session

db_conn_ctx_var: ContextVar[Session | None] = ContextVar("db_conn_ctx_var", default=None)
24 changes: 17 additions & 7 deletions kwik/database/db_context_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,10 @@
import kwik
from kwik import models
from kwik.core.config import Settings
from kwik.database.context_vars import current_user_ctx_var, db_conn_ctx_var
from sqlalchemy import create_engine
from sqlalchemy.orm import Session, sessionmaker

from .db_context_var import current_user_ctx_var, db_conn_ctx_var

if TYPE_CHECKING:
from .session import KwikSession

Expand Down Expand Up @@ -61,17 +60,28 @@ def __get__(self, obj, objtype=None) -> models.User | None:
class DBContextManager:
"""
DB Session Context Manager.
Correctly initialize the session by overriding the Session and Query class.
Implemented as a python context manager, automatically rollback a transaction
if any exception is raised by the application.
Implemented as a context manager,
automatically rollback a transaction if any exception is raised by the application.
"""

def __init__(self, *, settings: Settings | None = None) -> None:
self.db: KwikSession | Session | None = None
def __init__(self, *, settings: Settings) -> None:
"""
Initialize the DBContextManager.
Requires a Settings object instance to be passed in.
"""
self.settings: Settings = settings
self.db: KwikSession | Session | None = None
self.token: Token | None = None

def __enter__(self) -> KwikSession | Session:
"""
Enter the context manager.
Returns a database session.
"""

token = db_conn_ctx_var.get()
if token is not None:
self.db = token
Expand Down
6 changes: 4 additions & 2 deletions kwik/routers/auditor.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import time
from typing import Callable

Expand All @@ -6,7 +8,7 @@
from fastapi.routing import APIRoute
from jose import jwt
from kwik import crud, schemas
from kwik.api.deps import get_current_user, current_token, get_token
from kwik.api.deps import current_token, get_current_user, get_token
from kwik.core import security
from kwik.middlewares import get_request_id

Expand All @@ -32,7 +34,7 @@ def get_route_handler(self) -> Callable:
original_route_handler = super().get_route_handler()

async def custom_route_handler(request: Request) -> Response:
from kwik.database.db_context_var import current_user_ctx_var
from kwik.database.context_vars import current_user_ctx_var

# start the timer
start = time.time()
Expand Down
6 changes: 4 additions & 2 deletions kwik/utils/tests.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from __future__ import annotations

from contextlib import contextmanager
from typing import Generator

import kwik.crud
from kwik.database.context_vars.db_conn import db_conn_ctx_var
from kwik.database.db_context_manager import DBContextManager
from kwik.database.db_context_var import db_conn_ctx_var


@contextmanager
Expand All @@ -17,7 +19,7 @@ def test_db(*, db_path: str, setup=True) -> Generator:
) as db_cxt:
if setup:
superuser = kwik.crud.user.get_by_email(email=kwik.settings.FIRST_SUPERUSER)
from kwik.database.db_context_var import current_user_ctx_var
from kwik.database.context_vars import current_user_ctx_var

current_user_ctx_var.set(superuser)

Expand Down

0 comments on commit 5149255

Please sign in to comment.