Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: relationships loading #105

Closed
wants to merge 12 commits into from
4 changes: 4 additions & 0 deletions advanced_alchemy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
SyncSessionConfig,
)
from advanced_alchemy.repository._async import SQLAlchemyAsyncRepository
from advanced_alchemy.repository._load import SQLAlchemyLoad, SQLAlchemyLoadConfig, SQLALoadStrategy
from advanced_alchemy.repository._sync import SQLAlchemySyncRepository
from advanced_alchemy.repository._util import wrap_sqlalchemy_exception
from advanced_alchemy.repository.typing import ModelT
Expand All @@ -24,6 +25,9 @@
"NotFoundError",
"RepositoryError",
"SQLAlchemyAsyncRepository",
"SQLAlchemyLoad",
"SQLAlchemyLoadConfig",
"SQLALoadStrategy",
"SQLAlchemySyncRepository",
"SQLAlchemySyncRepositoryService",
"SQLAlchemySyncRepositoryReadService",
Expand Down
3 changes: 3 additions & 0 deletions advanced_alchemy/repository/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
from ._async import SQLAlchemyAsyncRepository
from ._load import SQLAlchemyLoad, SQLAlchemyLoadConfig
from ._sync import SQLAlchemySyncRepository
from ._util import get_instrumented_attr, model_from_dict

__all__ = (
"SQLAlchemyLoad",
"SQLAlchemyLoadConfig",
"SQLAlchemyAsyncRepository",
"SQLAlchemySyncRepository",
"get_instrumented_attr",
Expand Down
51 changes: 50 additions & 1 deletion advanced_alchemy/repository/_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,18 +32,22 @@
SearchFilter,
)
from advanced_alchemy.operations import Merge
from advanced_alchemy.repository._load import SQLAlchemyLoad, SQLAlchemyLoadConfig
from advanced_alchemy.repository._util import get_instrumented_attr, wrap_sqlalchemy_exception
from advanced_alchemy.repository.typing import ModelT
from advanced_alchemy.utils.deprecation import deprecated

if TYPE_CHECKING:
from collections import abc
from datetime import datetime
from typing import Self
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This may need to be typing_extensions for us to get 3.8 support?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, just fixed it, but EllipsisType does not seem to be available before 3.10


from sqlalchemy.engine.interfaces import _CoreSingleExecuteParams
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.ext.asyncio.scoping import async_scoped_session

from advanced_alchemy.repository._load import AnySQLAtrategy

DEFAULT_INSERTMANYVALUES_MAX_PARAMETERS: Final = 950
POSTGRES_VERSION_SUPPORTING_MERGE: Final = 15

Expand All @@ -68,6 +72,7 @@ def __init__(
auto_expunge: bool = False,
auto_refresh: bool = True,
auto_commit: bool = False,
load: SQLAlchemyLoad | None = None,
**kwargs: Any,
) -> None:
"""Repository pattern for SQLAlchemy models.
Expand All @@ -78,6 +83,7 @@ def __init__(
auto_expunge: Remove object from session before returning.
auto_refresh: Refresh object from session before returning.
auto_commit: Commit objects before returning.
load: Default relationships load.
**kwargs: Additional arguments.

"""
Expand All @@ -86,6 +92,7 @@ def __init__(
self.auto_refresh = auto_refresh
self.auto_commit = auto_commit
self.session = session
self.default_load = load or SQLAlchemyLoad()
if isinstance(statement, Select):
self.statement = lambda_stmt(lambda: statement)
elif statement is None:
Expand All @@ -95,6 +102,7 @@ def __init__(
self.statement = statement
self._dialect = self.session.bind.dialect if self.session.bind is not None else self.session.get_bind().dialect
self._prefer_any = any(self._dialect.name == engine_type for engine_type in self.prefer_any_dialects or ())
self._load: SQLAlchemyLoad = self.default_load

@classmethod
def get_id_attribute_value(cls, item: ModelT | type[ModelT], id_attribute: str | None = None) -> Any:
Expand Down Expand Up @@ -146,6 +154,24 @@ def check_not_found(item_or_none: ModelT | None) -> ModelT:
raise NotFoundError(msg)
return item_or_none

def load(
    self,
    config: SQLAlchemyLoadConfig | None = None,
    /,
    **kwargs: AnySQLAtrategy,
) -> Self:
    """Choose which relationships are loaded on the model.

    The request is wrapped in a ``SQLAlchemyLoad`` and stored on the
    repository as ``_load``.

    Args:
        config: Load configuration to use instead of the default one.
            Defaults to None.
        **kwargs: Relationship paths mapped to the strategy used to load them.

    Returns:
        The repository instance, so that calls can be chained.
    """
    requested_load = SQLAlchemyLoad(config, **kwargs)
    self._load = requested_load
    return self

async def add(
self,
data: ModelT,
Expand All @@ -169,6 +195,9 @@ async def add(
"""
with wrap_sqlalchemy_exception():
instance = await self._attach_to_session(data)
if self._load:
await self._flush_or_commit(auto_commit=True)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should the auto_commit follow what was sent in from the method or actually be True?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it should stay True since the following self._refresh_with_load() emits a select to get back the newly inserted rows with loaded relationships.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just thinking through this a bit more, and it's still not totally clear.

We definitely need a flush so that the inserted row is loaded into the relationship. As long as it isn't a new session though, we don't have to commit to make that happen. However, I'm not quite following why the commit is necessary.

Is there a simple use case that I can walk through to visualize why the commit over a flush is needed?

return await self._refresh_with_load(instance)
await self._flush_or_commit(auto_commit=auto_commit)
await self._refresh(instance, auto_refresh=auto_refresh)
self._expunge(instance, auto_expunge=auto_expunge)
Expand Down Expand Up @@ -812,6 +841,17 @@ async def _refresh(
else None
)

async def _refresh_with_load(self, instance: ModelT) -> ModelT:
    """Re-select ``instance`` by its identifier through ``_execute``.

    Args:
        instance: The model instance whose row should be fetched again.

    Returns:
        The instance produced by the select, after it has been passed
        through ``check_not_found``.
    """
    with wrap_sqlalchemy_exception():
        base_statement = self._get_base_stmt()
        # Restrict the base statement to the row carrying the instance's identifier.
        identity_filter = {self.id_attribute: getattr(instance, self.id_attribute)}
        select_by_id = self._filter_select_by_kwargs(base_statement, identity_filter)
        fetched = (await self._execute(select_by_id)).scalar_one_or_none()
        return self.check_not_found(fetched)

async def _list_and_count_window(
self,
*filters: FilterTypes | ColumnElement[bool],
Expand Down Expand Up @@ -1175,7 +1215,16 @@ async def _attach_to_session(self, model: ModelT, strategy: Literal["add", "merg
raise ValueError(msg)

async def _execute(self, statement: Select[Any] | StatementLambdaElement) -> Result[Any]:
    """Execute ``statement`` with the pending relationship loaders applied.

    When a load specification is active (``self._load`` is truthy), the
    loader options it produces for ``self.model_type`` are appended to the
    statement before execution. If that specification uses wildcards, the
    result is passed through ``unique()``.

    The active specification is always reset to ``self.default_load``
    afterwards, including when execution raises, so an override cannot leak
    into a later, unrelated statement.

    Args:
        statement: The select (or lambda statement) to execute.

    Returns:
        The result of executing the statement on the repository session.
    """
    active_load = self._load
    try:
        if active_load:
            if isinstance(statement, Select):
                # Bind the select to its own name: the lambda must not close over
                # ``statement``, which is rebound to the lambda statement below.
                select_statement = statement
                statement = lambda_stmt(lambda: select_statement)
            loaders = active_load.loaders(self.model_type)
            statement += lambda s: s.options(*loaders)
        result = await self.session.execute(statement)
        if active_load and active_load.has_wildcards():
            result = result.unique()
        return result
    finally:
        # One-shot semantics: whatever happened, fall back to the default load.
        self._load = self.default_load

def _apply_limit_offset_pagination(
self,
Expand Down
132 changes: 132 additions & 0 deletions advanced_alchemy/repository/_load.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Literal, Tuple, Union

from sqlalchemy import inspect
from sqlalchemy.orm import defaultload, joinedload, noload, raiseload, selectinload, subqueryload

if TYPE_CHECKING:
from collections.abc import Callable
from types import EllipsisType

from sqlalchemy.orm import Mapper, RelationshipProperty
from sqlalchemy.orm.strategy_options import _AbstractLoad
from typing_extensions import TypeAlias

from advanced_alchemy import ModelT

AnySQLAtrategy: TypeAlias = Union["SQLALoadStrategy", bool, EllipsisType]
LoadPath: TypeAlias = Tuple[Tuple[str, ...], AnySQLAtrategy]

# Loader strategies that may be requested by name; each value is a key of
# ``SQLAlchemyLoad._strategy_map``.
SQLALoadStrategy = Literal["defaultload", "noload", "joinedload", "selectinload", "subqueryload", "raiseload"]


@dataclass
class SQLAlchemyLoadConfig:
    """Options controlling how ``SQLAlchemyLoad`` interprets its arguments."""

    # Separator used to split a keyword name into a chain of relationship names.
    sep: str = "__"
    # Strategy applied through a "*" wildcard loader; ``None`` adds no wildcard loader.
    default_strategy: AnySQLAtrategy | None = None


class SQLAlchemyLoad:
    """Description of which relationships to load on a model, and how.

    Relationship paths are given as keyword arguments: the keyword name is a
    chain of relationship names joined by the configured separator
    (``"__"`` by default) and the value is the strategy for that path:

    - a strategy name (a key of ``_strategy_map``),
    - ``True``: ``selectinload`` for list relationships, ``joinedload`` otherwise,
    - ``False``: ``raiseload``,
    - ``...`` (Ellipsis): a wildcard (``"*"``) loader.

    Instances compare equal, and hash equal, when their default strategy and
    resolved paths are the same.
    """

    _strategy_map: dict[SQLALoadStrategy, Callable[..., _AbstractLoad]] = {
        "defaultload": defaultload,
        "joinedload": joinedload,
        "noload": noload,
        "raiseload": raiseload,
        "selectinload": selectinload,
        "subqueryload": subqueryload,
    }

    def __init__(
        self,
        config: SQLAlchemyLoadConfig | None = None,
        /,
        **kwargs: AnySQLAtrategy,
    ) -> None:
        """Store the configuration and resolve the requested paths.

        Args:
            config: Load configuration; a default ``SQLAlchemyLoadConfig`` is
                used when omitted.
            **kwargs: Relationship paths mapped to their loading strategy.
        """
        self._config = config if config is not None else SQLAlchemyLoadConfig()
        self._kwargs = kwargs
        self._paths = self._load_paths(**self._kwargs)
        # Everything that distinguishes two load specifications; drives __eq__/__hash__.
        self._identity = (self._config.default_strategy, self._paths)

    def __eq__(self, other: object) -> bool:
        """Return True when ``other`` describes the same default strategy and paths."""
        if isinstance(other, SQLAlchemyLoad):
            return self._identity == other._identity
        return False

    def __hash__(self) -> int:
        """Hash consistently with ``__eq__`` so instances stay usable in sets and dict keys."""
        return hash(self._identity)

    def __bool__(self) -> bool:
        """Return True when at least one loader would be produced.

        A default strategy of ``False`` still yields a wildcard loader, so the
        check is ``is not None`` (as in ``has_wildcards``), not truthiness.
        """
        return self._config.default_strategy is not None or bool(self._kwargs)

    def _load_paths(self, **kwargs: AnySQLAtrategy) -> tuple[LoadPath, ...]:
        """Split loading paths into tuples, sorted by their keyword name.

        Path conflicts are resolved in favour of the more specific path: a key
        whose strategy does not load is dropped when a path nested below it
        carries a different strategy, e.g.
        ``{"a": False, "a__b": True} -> {"a__b": True}``.
        """
        sep = self._config.sep
        to_remove: set[str] = set()
        for key, strategy in kwargs.items():
            if self._strategy_will_load(strategy):
                continue
            # Only keys nested below ``key`` conflict with it; including the
            # separator keeps "ab" from being mistaken for a child of "a".
            nested_prefix = key + sep
            if any(
                other_key.startswith(nested_prefix) and other_strategy != strategy
                for other_key, other_strategy in kwargs.items()
            ):
                to_remove.add(key)
        kept = {key: val for key, val in kwargs.items() if key not in to_remove}
        return tuple((tuple(key.split(sep)), kept[key]) for key in sorted(kept))

    @classmethod
    def _strategy_to_load_fn(cls, strategy: AnySQLAtrategy, uselist: bool = False) -> Callable[..., _AbstractLoad]:
        """Map a strategy value to the loader function implementing it.

        Args:
            strategy: A strategy name, a boolean or Ellipsis.
            uselist: Whether the target relationship is a list relationship.

        Returns:
            ``raiseload`` for falsy strategies, the mapped function for a
            strategy name, otherwise ``selectinload`` for list relationships
            and ``joinedload`` for the others.
        """
        if not strategy:
            return raiseload
        if isinstance(strategy, str):
            return cls._strategy_map[strategy]
        if uselist:
            return selectinload
        return joinedload

    def _default_load_strategy(self) -> _AbstractLoad | None:
        """Return the wildcard loader for the configured default strategy, if any."""
        if self._config.default_strategy is not None:
            return self._strategy_to_load_fn(self._config.default_strategy)("*")
        return None

    def has_wildcards(self) -> bool:
        """Check if wildcard loading is used in any of loading path.

        Returns:
            True if there is at least one wildcard use, False otherwise
        """
        return self._config.default_strategy is not None or Ellipsis in self._kwargs.values()

    def loaders(self, model_type: type[ModelT]) -> list[_AbstractLoad]:
        """Build the loader options for ``model_type``.

        Args:
            model_type: The mapped class the relationship paths start from.

        Returns:
            One loader per resolved path, each nesting the loaders of the
            path's later segments, followed by the default wildcard loader
            when a default strategy is configured.
        """
        loaders: list[_AbstractLoad] = []
        for path, strategy in self._paths:
            mapper: Mapper[ModelT] = inspect(model_type, raiseerr=True)
            loader_chain: list[_AbstractLoad] = []
            relationship: RelationshipProperty[Any] | None = None
            # Build one loader per path segment
            for i, key in enumerate(path):
                key_strategy = strategy
                current_prefix_strategy = self._kwargs.get(self._config.sep.join(path[: i + 1]), None)
                # A non-loading path must not disable a prefix that was
                # explicitly requested to load.
                if not self._strategy_will_load(strategy) and self._strategy_will_load(current_prefix_strategy):
                    key_strategy = True
                relationship = mapper.relationships[key]
                load_arg = "*" if strategy is Ellipsis else relationship.class_attribute
                key_loader = self._strategy_to_load_fn(key_strategy, bool(relationship.uselist))(load_arg)
                loader_chain.append(key_loader)
                if relationship is not None:
                    # Continue resolving the next segment from the related entity.
                    mapper = inspect(relationship.entity.class_, raiseerr=True)
            # Chain them together, innermost loader first
            path_loader = loader_chain[-1]
            for loader in loader_chain[-2::-1]:
                path_loader = loader.options(path_loader)
            loaders.append(path_loader)
        if (default_load := self._default_load_strategy()) is not None:
            loaders.append(default_load)
        return loaders

    def _strategy_will_load(self, strategy: SQLALoadStrategy | bool | EllipsisType | None) -> bool:
        """Return True when ``strategy`` results in the relationship being loaded.

        ``False``, ``None``, ``"noload"`` and ``"raiseload"`` do not load;
        every other value does.
        """
        if strategy is False or strategy is None:
            return False
        if isinstance(strategy, str):
            return strategy not in ["noload", "raiseload"]
        return True
Loading
Loading