Skip to content

Commit

Permalink
feat: UniqueMixin that instantiates objects ensuring uniqueness on …
Browse files Browse the repository at this point in the history
…some field(s) (#138)

Co-authored-by: Alc-Alc <alc@localhost>
Co-authored-by: Cody Fincher <cody.fincher@gmail.com>
  • Loading branch information
3 people committed Mar 19, 2024
1 parent 49b9062 commit 3364b6e
Show file tree
Hide file tree
Showing 12 changed files with 344 additions and 38 deletions.
2 changes: 2 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ repos:
hooks:
- id: codespell
exclude: "pdm.lock|examples/us_state_lookup.json"
additional_dependencies:
- tomli
- repo: https://github.com/psf/black
rev: 24.2.0
hooks:
Expand Down
13 changes: 9 additions & 4 deletions advanced_alchemy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,21 +8,26 @@
SQLAlchemySyncConfig,
SyncSessionConfig,
)
from advanced_alchemy.exceptions import (
IntegrityError,
MultipleResultsFoundError,
NotFoundError,
RepositoryError,
wrap_sqlalchemy_exception,
)
from advanced_alchemy.filters import FilterTypes
from advanced_alchemy.repository._async import SQLAlchemyAsyncRepository
from advanced_alchemy.repository._sync import SQLAlchemySyncRepository
from advanced_alchemy.repository._util import wrap_sqlalchemy_exception
from advanced_alchemy.repository.memory._async import SQLAlchemyAsyncMockRepository
from advanced_alchemy.repository.memory._sync import SQLAlchemySyncMockRepository
from advanced_alchemy.repository.typing import ModelT
from advanced_alchemy.service._async import SQLAlchemyAsyncRepositoryReadService, SQLAlchemyAsyncRepositoryService
from advanced_alchemy.service._sync import SQLAlchemySyncRepositoryReadService, SQLAlchemySyncRepositoryService

from .exceptions import IntegrityError, NotFoundError, RepositoryError
from .filters import FilterTypes

__all__ = (
"IntegrityError",
"FilterTypes",
"MultipleResultsFoundError",
"NotFoundError",
"RepositoryError",
"SQLAlchemyAsyncMockRepository",
Expand Down
35 changes: 35 additions & 0 deletions advanced_alchemy/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
from __future__ import annotations

from contextlib import contextmanager
from typing import Any

from sqlalchemy.exc import IntegrityError as SQLAlchemyIntegrityError
from sqlalchemy.exc import MultipleResultsFound, SQLAlchemyError

from advanced_alchemy.utils.deprecation import deprecated


Expand Down Expand Up @@ -84,3 +88,34 @@ class IntegrityError(RepositoryError):

class NotFoundError(RepositoryError):
"""An identity does not exist."""


class MultipleResultsFoundError(AdvancedAlchemyError):
"""A single database result was required but more than one were found."""


@contextmanager
def wrap_sqlalchemy_exception() -> Any:
"""Do something within context to raise a ``RepositoryError`` chained
from an original ``SQLAlchemyError``.
>>> try:
... with wrap_sqlalchemy_exception():
... raise SQLAlchemyError("Original Exception")
... except RepositoryError as exc:
... print(f"caught repository exception from {type(exc.__context__)}")
...
caught repository exception from <class 'sqlalchemy.exc.SQLAlchemyError'>
"""
try:
yield
except MultipleResultsFound as e:
msg = "Multiple rows matched the specified key"
raise MultipleResultsFoundError(msg) from e
except SQLAlchemyIntegrityError as exc:
raise IntegrityError from exc
except SQLAlchemyError as exc:
msg = f"An exception occurred: {exc}"
raise RepositoryError(msg) from exc
except AttributeError as exc:
raise RepositoryError from exc
3 changes: 3 additions & 0 deletions advanced_alchemy/mixins/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .unique import UniqueMixin

__all__ = ("UniqueMixin",)
160 changes: 160 additions & 0 deletions advanced_alchemy/mixins/unique.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
from __future__ import annotations

from contextlib import contextmanager
from typing import TYPE_CHECKING, Any

from sqlalchemy import ColumnElement, select

from advanced_alchemy.exceptions import wrap_sqlalchemy_exception

if TYPE_CHECKING:
from collections.abc import Hashable
from typing import Iterator

from sqlalchemy import Select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.ext.asyncio.scoping import async_scoped_session
from sqlalchemy.orm import Session
from sqlalchemy.orm.scoping import scoped_session
from typing_extensions import Self


class UniqueMixin:
"""Mixin for instantiating objects while ensuring uniqueness on some field(s).
This is a slightly modified implementation derived from https://github.com/sqlalchemy/sqlalchemy/wiki/UniqueObject
"""

@classmethod
@contextmanager
def _prevent_autoflush(
cls,
session: AsyncSession | async_scoped_session[AsyncSession] | Session | scoped_session[Session],
) -> Iterator[None]:
with session.no_autoflush, wrap_sqlalchemy_exception():
yield

@classmethod
def _check_uniqueness(
cls,
cache: dict[tuple[type[Self], Hashable], Self] | None,
session: AsyncSession | async_scoped_session[AsyncSession] | Session | scoped_session[Session],
key: tuple[type[Self], Hashable],
*args: Any,
**kwargs: Any,
) -> tuple[dict[tuple[type[Self], Hashable], Self], Select[tuple[Self]], Self | None]:
if cache is None:
cache = {}
setattr(session, "_unique_cache", cache)
statement = select(cls).where(cls.unique_filter(*args, **kwargs)).limit(2)
return cache, statement, cache.get(key)

@classmethod
async def as_unique_async(
cls,
session: AsyncSession | async_scoped_session[AsyncSession],
*args: Any,
**kwargs: Any,
) -> Self:
"""Instantiate and return a unique object within the provided session based on the given arguments.
If an object with the same unique identifier already exists in the session, it is returned from the cache.
Args:
session (AsyncSession | async_scoped_session[AsyncSession]): SQLAlchemy async session
*args (Any): Values used to instantiate the instance if no duplicate exists
**kwargs (Any): Values used to instantiate the instance if no duplicate exists
Returns:
Self: The unique object instance.
"""
key = cls, cls.unique_hash(*args, **kwargs)
cache, statement, obj = cls._check_uniqueness(
getattr(session, "_unique_cache", None),
session,
key,
*args,
**kwargs,
)
if obj:
return obj
with cls._prevent_autoflush(session):
if (obj := (await session.execute(statement)).scalar_one_or_none()) is None:
session.add(obj := cls(*args, **kwargs))
cache[key] = obj
return obj

@classmethod
def as_unique_sync(
cls,
session: Session | scoped_session[Session],
*args: Any,
**kwargs: Any,
) -> Self:
"""Instantiate and return a unique object within the provided session based on the given arguments.
If an object with the same unique identifier already exists in the session, it is returned from the cache.
Args:
session (Session | scoped_session[Session]): SQLAlchemy sync session
*args (Any): Values used to instantiate the instance if no duplicate exists
**kwargs (Any): Values used to instantiate the instance if no duplicate exists
Returns:
Self: The unique object instance.
"""
key = cls, cls.unique_hash(*args, **kwargs)
cache, statement, obj = cls._check_uniqueness(
getattr(session, "_unique_cache", None),
session,
key,
*args,
**kwargs,
)
if obj:
return obj
with cls._prevent_autoflush(session):
if (obj := session.execute(statement).scalar_one_or_none()) is None:
session.add(obj := cls(*args, **kwargs))
cache[key] = obj
return obj

@classmethod
def unique_hash(cls, *args: Any, **kwargs: Any) -> Hashable: # noqa: ARG003
"""Generate a unique key based on the provided arguments.
This method should be implemented in the subclass.
Args:
*args (Any): Values passed to the alternate classmethod constructors
**kwargs (Any): Values passed to the alternate classmethod constructors
Raises:
NotImplementedError: If not implemented in the subclass.
Returns:
Hashable: Any hashable object.
"""
msg = "Implement this in subclass"
raise NotImplementedError(msg)

@classmethod
def unique_filter(cls, *args: Any, **kwargs: Any) -> ColumnElement[bool]: # noqa: ARG003
"""Generate a filter condition for ensuring uniqueness.
This method should be implemented in the subclass.
Args:
*args (Any): Values passed to the alternate classmethod constructors
**kwargs (Any): Values passed to the alternate classmethod constructors
Raises:
NotImplementedError: If not implemented in the subclass.
Returns:
ColumnElement[bool]: Filter condition to establish the uniqueness.
"""
msg = "Implement this in subclass"
raise NotImplementedError(msg)
4 changes: 2 additions & 2 deletions advanced_alchemy/repository/_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from sqlalchemy.orm import InstrumentedAttribute
from sqlalchemy.sql import ColumnElement, ColumnExpressionArgument

from advanced_alchemy.exceptions import NotFoundError, RepositoryError
from advanced_alchemy.exceptions import NotFoundError, RepositoryError, wrap_sqlalchemy_exception
from advanced_alchemy.filters import (
BeforeAfter,
CollectionFilter,
Expand All @@ -32,7 +32,7 @@
SearchFilter,
)
from advanced_alchemy.operations import Merge
from advanced_alchemy.repository._util import get_instrumented_attr, wrap_sqlalchemy_exception
from advanced_alchemy.repository._util import get_instrumented_attr
from advanced_alchemy.repository.typing import MISSING, ModelT
from advanced_alchemy.utils.deprecation import deprecated

Expand Down
4 changes: 2 additions & 2 deletions advanced_alchemy/repository/_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from sqlalchemy.orm import InstrumentedAttribute, Session
from sqlalchemy.sql import ColumnElement, ColumnExpressionArgument

from advanced_alchemy.exceptions import NotFoundError, RepositoryError
from advanced_alchemy.exceptions import NotFoundError, RepositoryError, wrap_sqlalchemy_exception
from advanced_alchemy.filters import (
BeforeAfter,
CollectionFilter,
Expand All @@ -34,7 +34,7 @@
SearchFilter,
)
from advanced_alchemy.operations import Merge
from advanced_alchemy.repository._util import get_instrumented_attr, wrap_sqlalchemy_exception
from advanced_alchemy.repository._util import get_instrumented_attr
from advanced_alchemy.repository.typing import MISSING, ModelT
from advanced_alchemy.utils.deprecation import deprecated

Expand Down
30 changes: 0 additions & 30 deletions advanced_alchemy/repository/_util.py
Original file line number Diff line number Diff line change
@@ -1,44 +1,14 @@
from __future__ import annotations

from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, cast

from sqlalchemy.exc import IntegrityError as SQLAlchemyIntegrityError
from sqlalchemy.exc import SQLAlchemyError

from advanced_alchemy.exceptions import IntegrityError, RepositoryError

if TYPE_CHECKING:
from sqlalchemy.orm import InstrumentedAttribute

from advanced_alchemy.base import ModelProtocol
from advanced_alchemy.repository.typing import ModelT


@contextmanager
def wrap_sqlalchemy_exception() -> Any:
"""Do something within context to raise a `RepositoryError` chained
from an original `SQLAlchemyError`.
>>> try:
... with wrap_sqlalchemy_exception():
... raise SQLAlchemyError("Original Exception")
... except RepositoryError as exc:
... print(f"caught repository exception from {type(exc.__context__)}")
...
caught repository exception from <class 'sqlalchemy.exc.SQLAlchemyError'>
"""
try:
yield
except SQLAlchemyIntegrityError as exc:
raise IntegrityError from exc
except SQLAlchemyError as exc:
msg = f"An exception occurred: {exc}"
raise RepositoryError(msg) from exc
except AttributeError as exc:
raise RepositoryError from exc


def get_instrumented_attr(model: type[ModelProtocol], key: str | InstrumentedAttribute) -> InstrumentedAttribute:
if isinstance(key, str):
return cast("InstrumentedAttribute", getattr(model, key))
Expand Down
1 change: 1 addition & 0 deletions docs/reference/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,4 @@ Available API References
alembic/index
config/index
extensions/index
mixins/index
15 changes: 15 additions & 0 deletions docs/reference/mixins/index.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
======
mixins
======

API Reference for the ``mixins`` module

.. note:: Private methods and attributes are not included in the API reference.

Available API References
------------------------

.. toctree::
:titlesonly:

unique
6 changes: 6 additions & 0 deletions docs/reference/mixins/unique.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
======
unique
======

.. automodule:: advanced_alchemy.mixins.unique
:members:

0 comments on commit 3364b6e

Please sign in to comment.