Skip to content

Commit

Permalink
Refactoring for better separation of database package implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
dmezzogori committed May 31, 2023
1 parent 5149255 commit 2c7d311
Show file tree
Hide file tree
Showing 12 changed files with 152 additions and 121 deletions.
4 changes: 3 additions & 1 deletion kwik/api/endpoints/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

from typing import Any

import kwik.crud
Expand Down Expand Up @@ -36,7 +38,7 @@ def test_db_switcher() -> Any:
"""

user_db = kwik.crud.user.get(id=1)
with kwik.database.db_context_switcher() as db:
with kwik.database.DBContextSwitcher() as db:
db = db
kwik.logger.error(kwik.crud.user.db.get_bind())
kwik.logger.error(db.get_bind())
Expand Down
2 changes: 1 addition & 1 deletion kwik/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def assemble_cors_origins(cls, v: Union[str, list[str]]) -> Union[list[str], str
POSTGRES_DB: str = "db"
POSTGRES_MAX_CONNECTIONS: int
ENABLE_SOFT_DELETE: bool = False
SQLALCHEMY_DATABASE_URI: PostgresDsn | None = None
SQLALCHEMY_DATABASE_URI: PostgresDsn | str | None = None

@validator("SQLALCHEMY_DATABASE_URI", pre=True)
def assemble_db_connection(cls, v: str | None, values: dict[str, Any]) -> Any:
Expand Down
32 changes: 26 additions & 6 deletions kwik/crud/base.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,42 @@
from __future__ import annotations

import abc
from typing import Any, Type, TYPE_CHECKING, Generic, get_args, NoReturn
from typing import TYPE_CHECKING, Any, Generic, NoReturn, Type, get_args

from kwik.database import db_context_manager
import kwik
from kwik.database.context_vars import current_user_ctx_var, db_conn_ctx_var
from kwik.models import User
from kwik.typings import ModelType, CreateSchemaType, UpdateSchemaType
from kwik.typings import ParsedSortingQuery, PaginatedCRUDResult
from kwik.typings import (
CreateSchemaType,
ModelType,
PaginatedCRUDResult,
ParsedSortingQuery,
UpdateSchemaType,
)

if TYPE_CHECKING:
from kwik.database.session import KwikSession
from sqlalchemy.orm import Session

T = Generic[ModelType, CreateSchemaType, UpdateSchemaType]


class DBSession:
def __get__(self, obj, objtype=None) -> Session:
if (db := db_conn_ctx_var.get()) is not None:
return db
raise Exception("No database connection available")


class CurrentUser:
def __get__(self, obj, objtype=None) -> kwik.models.User | None:
user = current_user_ctx_var.get()
return user


class CRUDBase(abc.ABC, Generic[ModelType]):
db: KwikSession = db_context_manager.DBSession()
user: User | None = db_context_manager.CurrentUser()
db: KwikSession = DBSession()
user: User | None = CurrentUser()
model: Type[ModelType]

_instances: dict[str, T] = {}
Expand Down
8 changes: 4 additions & 4 deletions kwik/database/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from . import base
from . import mixins
from . import session
from __future__ import annotations

from .db_context_manager import db_context_switcher
from . import base, mixins, session
from .db_context_manager import DBContextManager
from .db_context_switcher import DBContextSwitcher
109 changes: 46 additions & 63 deletions kwik/database/db_context_manager.py
Original file line number Diff line number Diff line change
@@ -1,60 +1,15 @@
from __future__ import annotations

from contextlib import contextmanager
from contextvars import Token
from types import TracebackType
from typing import TYPE_CHECKING

import kwik
from kwik import models
from kwik.core.config import Settings
from kwik.database.context_vars import current_user_ctx_var, db_conn_ctx_var
from sqlalchemy import create_engine
from sqlalchemy.orm import Session, sessionmaker
from kwik.database.context_vars import db_conn_ctx_var
from kwik.database.session_local import SessionLocal

if TYPE_CHECKING:
from .session import KwikSession
from contextvars import Token


@contextmanager
def db_context_switcher():
from kwik import settings

prev_db_conn_ctx_var = db_conn_ctx_var.get()
with DBContextManager(
db_uri=settings.alternate_db.ALTERNATE_SQLALCHEMY_DATABASE_URI,
settings=settings.alternate_db,
) as db:
yield db

db_conn_ctx_var.set(prev_db_conn_ctx_var)


class DBSession:
def __get__(self, obj, objtype=None) -> KwikSession:
db = db_conn_ctx_var.get()
if db is not None:
return db
raise Exception("No database connection available")


class CurrentUser:
def __get__(self, obj, objtype=None) -> models.User | None:
user = current_user_ctx_var.get()
return user


engine = create_engine(
url=kwik.settings.SQLALCHEMY_DATABASE_URI,
pool_pre_ping=True,
pool_size=kwik.settings.POSTGRES_MAX_CONNECTIONS // kwik.settings.BACKEND_WORKERS,
max_overflow=0,
)

SessionLocal = sessionmaker(
autocommit=False,
autoflush=False,
bind=engine,
)
from sqlalchemy.orm import Session


class DBContextManager:
Expand All @@ -65,39 +20,67 @@ class DBContextManager:
automatically rollback a transaction if any exception is raised by the application.
"""

def __init__(self, *, settings: Settings) -> None:
def __init__(self) -> None:
"""
Initialize the DBContextManager.
Requires a Settings object instance to be passed in.
"""
self.settings: Settings = settings
self.db: KwikSession | Session | None = None
self.token: Token | None = None
self.db: Session | None = None
self.token: Token[Session | None] | None = None

def __enter__(self) -> KwikSession | Session:
def __enter__(self) -> Session:
"""
Enter the context manager.
Enter the context manager, which returns a database session.
Retrieves a database session from the context variable.
If no session is found, a new session is created and stored in the context variable.
Returns a database session.
"""

token = db_conn_ctx_var.get()
if token is not None:
if token is None:
# No session found in the context variable.

# Create a new session.
self.db = SessionLocal()
# Store the session in the context variable.
self.token = db_conn_ctx_var.set(self.db)
else:
# Session found in the context variable.
self.db = token
self.token = token
return self.db

self.db = SessionLocal()

self.token = db_conn_ctx_var.set(self.db)
return self.db

def __exit__(self, exception_type, exception_value, exception_traceback) -> None:
def __exit__(
self,
exception_type: type[BaseException] | None,
exception_value: BaseException | None,
exception_traceback: TracebackType | None,
) -> None:
"""
Exit the context manager, handling any exceptions raised by the application.
If an exception is raised by the application, rollback the transaction.
Otherwise, commit the transaction.
Then, closes the database session and reset the context variable to its previous value.
"""

if exception_type is not None:
# An exception was raised by the application.

# Rollback the transaction.
self.db.rollback()
else:
# No exception was raised by the application.

# Commit the transaction.
self.db.commit()

# Close the database session.
self.db.close()

# Reset the context variable to its previous value.
db_conn_ctx_var.reset(self.token)
32 changes: 32 additions & 0 deletions kwik/database/db_context_switcher.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from __future__ import annotations

from contextlib import contextmanager
from typing import TYPE_CHECKING, Generator

from kwik.database import DBContextManager
from kwik.database.context_vars import db_conn_ctx_var

if TYPE_CHECKING:
from sqlalchemy.orm import Session


@contextmanager
def DBContextSwitcher() -> Generator[Session, None, None]: # noqa: N802
"""
Context manager to switch to an alternate database.
Example:
with db_context_switcher():
# Do something with the alternate database.
pass
"""

# Get the current database session from the context variable.
prev_db_conn_ctx_var = db_conn_ctx_var.get()

# Create a new database session.
with DBContextManager() as db:
yield db

# Restore the previous database session in the context variable.
db_conn_ctx_var.set(prev_db_conn_ctx_var)
11 changes: 11 additions & 0 deletions kwik/database/engine.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from __future__ import annotations

import kwik
from sqlalchemy import create_engine

engine = create_engine(
url=kwik.settings.SQLALCHEMY_DATABASE_URI,
pool_pre_ping=True,
pool_size=kwik.settings.POSTGRES_MAX_CONNECTIONS // kwik.settings.BACKEND_WORKERS,
max_overflow=0,
)
10 changes: 10 additions & 0 deletions kwik/database/session_local.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from __future__ import annotations

from kwik.database.engine import engine
from sqlalchemy.orm import sessionmaker

SessionLocal = sessionmaker(
autocommit=False,
autoflush=False,
bind=engine,
)
5 changes: 3 additions & 2 deletions kwik/middlewares/db_session.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import kwik.crud.base
from __future__ import annotations

from kwik.database.db_context_manager import DBContextManager
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.requests import Request


class DBSessionMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next: RequestResponseEndpoint):
with DBContextManager(settings=kwik.settings) as db:
with DBContextManager() as db:
request.state.db = db
response = await call_next(request)

Expand Down
20 changes: 11 additions & 9 deletions kwik/tests/utils/setup.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,19 @@
import os
from __future__ import annotations

from typing import Callable

import kwik
from kwik.database import DBContextManager
from kwik.database.base import Base


def init_test_db(db_path: str, init_db: Callable, *args, **kwargs) -> None:
# Create a temporary database
if os.path.exists(db_path):
os.remove(db_path)
os.mknod(db_path)

def init_test_db(init_db: Callable, *args, **kwargs) -> None:
# Initialize the database
with kwik.utils.tests.test_db(db_path=db_path, setup=False) as db:
with DBContextManager() as db:
Base.metadata.create_all(bind=db.get_bind())
init_db(*args, **kwargs)


def drop_test_db() -> None:
# Drop the database
with DBContextManager() as db:
Base.metadata.drop_all(bind=db.get_bind())
12 changes: 5 additions & 7 deletions kwik/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
from .files import store_file
from __future__ import annotations

from .emails import (
send_email,
send_test_email,
send_new_account_email,
send_reset_password_email,
send_test_email,
)
from .login import (
generate_password_reset_token,
verify_password_reset_token,
)
from .files import store_file
from .login import generate_password_reset_token, verify_password_reset_token
from .query import sort_query
from . import tests
28 changes: 0 additions & 28 deletions kwik/utils/tests.py

This file was deleted.

0 comments on commit 2c7d311

Please sign in to comment.