-
-
Notifications
You must be signed in to change notification settings - Fork 22
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: relationships loading #105
Changes from all commits
830559a
5a51e55
2d226a2
e96a4d1
96f8910
74b5490
7f6cec2
5e8ae4f
a395db0
5655a35
aae19a8
d40773f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -32,18 +32,22 @@ | |
SearchFilter, | ||
) | ||
from advanced_alchemy.operations import Merge | ||
from advanced_alchemy.repository._load import SQLAlchemyLoad, SQLAlchemyLoadConfig | ||
from advanced_alchemy.repository._util import get_instrumented_attr, wrap_sqlalchemy_exception | ||
from advanced_alchemy.repository.typing import ModelT | ||
from advanced_alchemy.utils.deprecation import deprecated | ||
|
||
if TYPE_CHECKING: | ||
from collections import abc | ||
from datetime import datetime | ||
from typing import Self | ||
|
||
from sqlalchemy.engine.interfaces import _CoreSingleExecuteParams | ||
from sqlalchemy.ext.asyncio import AsyncSession | ||
from sqlalchemy.ext.asyncio.scoping import async_scoped_session | ||
|
||
from advanced_alchemy.repository._load import AnySQLAtrategy | ||
|
||
DEFAULT_INSERTMANYVALUES_MAX_PARAMETERS: Final = 950 | ||
POSTGRES_VERSION_SUPPORTING_MERGE: Final = 15 | ||
|
||
|
@@ -68,6 +72,7 @@ def __init__( | |
auto_expunge: bool = False, | ||
auto_refresh: bool = True, | ||
auto_commit: bool = False, | ||
load: SQLAlchemyLoad | None = None, | ||
**kwargs: Any, | ||
) -> None: | ||
"""Repository pattern for SQLAlchemy models. | ||
|
@@ -78,6 +83,7 @@ def __init__( | |
auto_expunge: Remove object from session before returning. | ||
auto_refresh: Refresh object from session before returning. | ||
auto_commit: Commit objects before returning. | ||
load: Default relationships load. | ||
**kwargs: Additional arguments. | ||
|
||
""" | ||
|
@@ -86,6 +92,7 @@ def __init__( | |
self.auto_refresh = auto_refresh | ||
self.auto_commit = auto_commit | ||
self.session = session | ||
self.default_load = load or SQLAlchemyLoad() | ||
if isinstance(statement, Select): | ||
self.statement = lambda_stmt(lambda: statement) | ||
elif statement is None: | ||
|
@@ -95,6 +102,7 @@ def __init__( | |
self.statement = statement | ||
self._dialect = self.session.bind.dialect if self.session.bind is not None else self.session.get_bind().dialect | ||
self._prefer_any = any(self._dialect.name == engine_type for engine_type in self.prefer_any_dialects or ()) | ||
self._load: SQLAlchemyLoad = self.default_load | ||
|
||
@classmethod | ||
def get_id_attribute_value(cls, item: ModelT | type[ModelT], id_attribute: str | None = None) -> Any: | ||
|
@@ -146,6 +154,24 @@ def check_not_found(item_or_none: ModelT | None) -> ModelT: | |
raise NotFoundError(msg) | ||
return item_or_none | ||
|
||
def load( | ||
self, | ||
config: SQLAlchemyLoadConfig | None = None, | ||
/, | ||
**kwargs: AnySQLAtrategy, | ||
) -> Self: | ||
"""Set relationships to be loaded on the model | ||
|
||
Args: | ||
config: Override default load config. Defaults to None. | ||
kwargs: Relationship paths to load | ||
|
||
Returns: | ||
The repository instance | ||
""" | ||
self._load = SQLAlchemyLoad(config, **kwargs) | ||
return self | ||
|
||
async def add( | ||
self, | ||
data: ModelT, | ||
|
@@ -169,6 +195,9 @@ async def add( | |
""" | ||
with wrap_sqlalchemy_exception(): | ||
instance = await self._attach_to_session(data) | ||
if self._load: | ||
await self._flush_or_commit(auto_commit=True) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think it should stay There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Just thinking through this a bit more, and it's still not totally clear. We definitely need a flush so that the inserted row is loaded into relationship. As long as it isn't a new session though, we don't have to commit to make that happen. However, i'm not quite following why the commit is necessary. Is there a simple use case that I can walk through to visualize why the commit over a flush is needed? |
||
return await self._refresh_with_load(instance) | ||
await self._flush_or_commit(auto_commit=auto_commit) | ||
await self._refresh(instance, auto_refresh=auto_refresh) | ||
self._expunge(instance, auto_expunge=auto_expunge) | ||
|
@@ -812,6 +841,17 @@ async def _refresh( | |
else None | ||
) | ||
|
||
async def _refresh_with_load(self, instance: ModelT) -> ModelT: | ||
with wrap_sqlalchemy_exception(): | ||
statement = self._get_base_stmt() | ||
statement = self._filter_select_by_kwargs( | ||
statement, | ||
{self.id_attribute: getattr(instance, self.id_attribute)}, | ||
) | ||
result = await self._execute(statement) | ||
refreshed_instance = result.scalar_one_or_none() | ||
return self.check_not_found(refreshed_instance) | ||
|
||
async def _list_and_count_window( | ||
self, | ||
*filters: FilterTypes | ColumnElement[bool], | ||
|
@@ -1175,7 +1215,16 @@ async def _attach_to_session(self, model: ModelT, strategy: Literal["add", "merg | |
raise ValueError(msg) | ||
|
||
async def _execute(self, statement: Select[Any] | StatementLambdaElement) -> Result[Any]: | ||
return await self.session.execute(statement) | ||
if self._load: | ||
if isinstance(statement, Select): | ||
statement = lambda_stmt(lambda: statement) | ||
loaders = self._load.loaders(self.model_type) | ||
statement += lambda s: s.options(*loaders) | ||
result = await self.session.execute(statement) | ||
if self._load and self._load.has_wildcards(): | ||
result = result.unique() | ||
self._load = self.default_load | ||
return result | ||
|
||
def _apply_limit_offset_pagination( | ||
self, | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,132 @@ | ||
from __future__ import annotations | ||
|
||
from dataclasses import dataclass | ||
from typing import TYPE_CHECKING, Any, Literal, Tuple, Union | ||
|
||
from sqlalchemy import inspect | ||
from sqlalchemy.orm import defaultload, joinedload, noload, raiseload, selectinload, subqueryload | ||
|
||
if TYPE_CHECKING: | ||
from collections.abc import Callable | ||
from types import EllipsisType | ||
|
||
from sqlalchemy.orm import Mapper, RelationshipProperty | ||
from sqlalchemy.orm.strategy_options import _AbstractLoad | ||
from typing_extensions import TypeAlias | ||
|
||
from advanced_alchemy import ModelT | ||
|
||
AnySQLAtrategy: TypeAlias = Union["SQLALoadStrategy", bool, EllipsisType] | ||
LoadPath: TypeAlias = Tuple[Tuple[str, ...], AnySQLAtrategy] | ||
|
||
SQLALoadStrategy = Literal["defaultload", "noload", "joinedload", "selectinload", "subqueryload", "raiseload"] | ||
|
||
|
||
@dataclass | ||
class SQLAlchemyLoadConfig: | ||
sep: str = "__" | ||
default_strategy: AnySQLAtrategy | None = None | ||
|
||
|
||
class SQLAlchemyLoad: | ||
_strategy_map: dict[SQLALoadStrategy, Callable[..., _AbstractLoad]] = { | ||
"defaultload": defaultload, | ||
"joinedload": joinedload, | ||
"noload": noload, | ||
"raiseload": raiseload, | ||
"selectinload": selectinload, | ||
"subqueryload": subqueryload, | ||
} | ||
|
||
def __init__( | ||
self, | ||
config: SQLAlchemyLoadConfig | None = None, | ||
/, | ||
**kwargs: AnySQLAtrategy, | ||
) -> None: | ||
self._config = config if config is not None else SQLAlchemyLoadConfig() | ||
self._kwargs = kwargs | ||
self._paths = self._load_paths(**self._kwargs) | ||
self._identity = (self._config.default_strategy, self._paths) | ||
|
||
def __eq__(self, other: object) -> bool: | ||
if isinstance(other, SQLAlchemyLoad): | ||
return self._config.default_strategy == other._config.default_strategy and self._paths == other._paths | ||
return False | ||
|
||
def __bool__(self) -> bool: | ||
return bool(self._config.default_strategy or self._kwargs) | ||
|
||
def _load_paths(self, **kwargs: AnySQLAtrategy) -> tuple[LoadPath, ...]: | ||
"""Split loading paths into tuples.""" | ||
# Resolve path conflicts: the last takes precedence | ||
# - {"a": False, "a__b": True} -> {"a__b": True} | ||
to_remove: set[str] = set() | ||
for key, strategy in kwargs.items(): | ||
for other_key, other_strategy in kwargs.items(): | ||
if ( | ||
other_key != key | ||
and other_key.startswith(key) | ||
and not self._strategy_will_load(strategy) | ||
and other_strategy != strategy | ||
): | ||
to_remove.add(key) | ||
kwargs = {key: val for key, val in kwargs.items() if key not in to_remove} | ||
return tuple((tuple(key.split(self._config.sep)), kwargs[key]) for key in sorted(kwargs)) | ||
|
||
@classmethod | ||
def _strategy_to_load_fn(cls, strategy: AnySQLAtrategy, uselist: bool = False) -> Callable[..., _AbstractLoad]: | ||
if not strategy: | ||
return raiseload | ||
if isinstance(strategy, str): | ||
return cls._strategy_map[strategy] | ||
if uselist: | ||
return selectinload | ||
return joinedload | ||
|
||
def _default_load_strategy(self) -> _AbstractLoad | None: | ||
if self._config.default_strategy is not None: | ||
return self._strategy_to_load_fn(self._config.default_strategy)("*") | ||
return None | ||
|
||
def has_wildcards(self) -> bool: | ||
"""Check if wildcard loading is used in any of loading path. | ||
|
||
Returns: | ||
True if there is at least one wildcard use, False otherwise | ||
""" | ||
return self._config.default_strategy is not None or Ellipsis in self._kwargs.values() | ||
|
||
def loaders(self, model_type: type[ModelT]) -> list[_AbstractLoad]: | ||
loaders: list[_AbstractLoad] = [] | ||
for path, strategy in self._paths: | ||
mapper: Mapper[ModelT] = inspect(model_type, raiseerr=True) | ||
loader_chain: list[_AbstractLoad] = [] | ||
relationship: RelationshipProperty[Any] | None = None | ||
# Builder loaders | ||
for i, key in enumerate(path): | ||
key_strategy = strategy | ||
current_prefix_strategy = self._kwargs.get(self._config.sep.join(path[: i + 1]), None) | ||
if not self._strategy_will_load(strategy) and self._strategy_will_load(current_prefix_strategy): | ||
key_strategy = True | ||
relationship = mapper.relationships[key] | ||
load_arg = "*" if strategy is Ellipsis else relationship.class_attribute | ||
key_loader = self._strategy_to_load_fn(key_strategy, bool(relationship.uselist))(load_arg) | ||
loader_chain.append(key_loader) | ||
if relationship is not None: | ||
mapper = inspect(relationship.entity.class_, raiseerr=True) | ||
# Chain them together | ||
path_loader = loader_chain[-1] | ||
for loader in loader_chain[-2::-1]: | ||
path_loader = loader.options(path_loader) | ||
loaders.append(path_loader) | ||
if (default_load := self._default_load_strategy()) is not None: | ||
loaders.append(default_load) | ||
return loaders | ||
|
||
def _strategy_will_load(self, strategy: SQLALoadStrategy | bool | EllipsisType | None) -> bool: | ||
if strategy is False or strategy is None: | ||
return False | ||
if isinstance(strategy, str): | ||
return strategy not in ["noload", "raiseload"] | ||
return True |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This may need to be
typing_extensions
for us to get 3.8 support?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, just fixed it, but
EllipsisType
does not seem to available before 3.10