Skip to content

Commit

Permalink
feat: implement a repository proto and other typing improvements (#213)
Browse files Browse the repository at this point in the history
* feat: wip on a repo protocol

* feat: refactor test fixtures

* feat: linting

* fix: import changes

* fix: updated test import

* fix: additional import fix

* chore: updated types

* fix: fixture name

* feat: type signature fix for to_schema

* chore: fix formatting

* chore: add an additional override signature

* chore: linting

* fix: typing

* fix: text case

* fix: init method & protocol

* chore: protocol updates

* chore: pyright stuff

* additional mock fixture fixes

* fix: add runtime checkable

* feat: runtime checking tweak

* fix: test case type

* fix: 3.8
  • Loading branch information
cofin authored Jun 10, 2024
1 parent 1a04f3f commit f0a9554
Show file tree
Hide file tree
Showing 40 changed files with 3,146 additions and 2,081 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ repos:
- id: unasyncd
additional_dependencies: ["ruff"]
- repo: https://github.com/charliermarsh/ruff-pre-commit
rev: "v0.4.7"
rev: "v0.4.8"
hooks:
- id: ruff
args: ["--fix"]
Expand Down
6 changes: 3 additions & 3 deletions advanced_alchemy/config/asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,16 +48,16 @@ class AlembicAsyncConfig(GenericAlembicConfig):


@dataclass
class SQLAlchemyAsyncConfig(GenericSQLAlchemyConfig[AsyncEngine, AsyncSession, async_sessionmaker]):
class SQLAlchemyAsyncConfig(GenericSQLAlchemyConfig[AsyncEngine, AsyncSession, async_sessionmaker[AsyncSession]]):
"""Async SQLAlchemy Configuration."""

create_engine_callable: Callable[[str], AsyncEngine] = create_async_engine
"""Callable that creates an :class:`AsyncEngine <sqlalchemy.ext.asyncio.AsyncEngine>` instance or instance of its
subclass.
"""
session_config: AsyncSessionConfig = field(default_factory=AsyncSessionConfig)
session_config: AsyncSessionConfig = field(default_factory=AsyncSessionConfig) # pyright: ignore[reportIncompatibleVariableOverride]
"""Configuration options for the :class:`async_sessionmaker<sqlalchemy.ext.asyncio.async_sessionmaker>`."""
session_maker_class: type[async_sessionmaker] = async_sessionmaker
session_maker_class: type[async_sessionmaker[AsyncSession]] = async_sessionmaker # pyright: ignore[reportIncompatibleVariableOverride]
"""Sessionmaker class to use."""
alembic_config: AlembicAsyncConfig = field(default_factory=AlembicAsyncConfig)
"""Configuration for the SQLAlchemy Alembic migrations.
Expand Down
13 changes: 7 additions & 6 deletions advanced_alchemy/config/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from dataclasses import dataclass, field
from pathlib import Path
from typing import TYPE_CHECKING, Callable, Generic, TypeVar
from typing import TYPE_CHECKING, Callable, Generic, TypeVar, cast

from advanced_alchemy.base import orm_registry
from advanced_alchemy.config.engine import EngineConfig
Expand All @@ -27,14 +27,15 @@
)

ALEMBIC_TEMPLATE_PATH = f"{Path(__file__).parent.parent}/alembic/templates"

"""Path to the Alembic templates."""
ConnectionT = TypeVar("ConnectionT", bound="Connection | AsyncConnection")
"""Type variable for a SQLAlchemy connection."""
EngineT = TypeVar("EngineT", bound="Engine | AsyncEngine")
"""Type variable for a SQLAlchemy engine."""
SessionT = TypeVar("SessionT", bound="Session | AsyncSession")
"""Type variable for a SQLAlchemy session."""
SessionMakerT = TypeVar("SessionMakerT", bound="sessionmaker | async_sessionmaker")
SessionMakerT = TypeVar("SessionMakerT", bound="sessionmaker[Session] | async_sessionmaker[AsyncSession]")
"""Type variable for a SQLAlchemy sessionmaker."""


Expand All @@ -50,7 +51,7 @@ class GenericSessionConfig(Generic[ConnectionT, EngineT, SessionT]):
bind: EngineT | ConnectionT | None | EmptyType = Empty
"""The :class:`Engine <sqlalchemy.engine.Engine>` or :class:`Connection <sqlalchemy.engine.Connection>` that new
:class:`Session <sqlalchemy.orm.Session>` objects will be bound to."""
binds: dict[type[Any] | Mapper | TableClause | str, EngineT | ConnectionT] | None | EmptyType = Empty
binds: dict[type[Any] | Mapper | TableClause | str, EngineT | ConnectionT] | None | EmptyType = Empty # pyright: ignore[reportMissingTypeArgument]
"""A dictionary which may specify any number of :class:`Engine <sqlalchemy.engine.Engine>` or :class:`Connection
<sqlalchemy.engine.Connection>` objects as the source of connectivity for SQL operations on a per-entity basis. The
keys of the dictionary consist of any series of mapped classes, arbitrary Python classes that are bases for mapped
Expand All @@ -68,7 +69,7 @@ class GenericSessionConfig(Generic[ConnectionT, EngineT, SessionT]):
"""Describes the transactional behavior to take when a given bind is a Connection that has already begun a
transaction outside the scope of this Session; in other words the
:attr:`Connection.in_transaction() <sqlalchemy.Connection.in_transaction>` method returns True."""
query_cls: type[Query] | None | EmptyType = Empty
query_cls: type[Query] | None | EmptyType = Empty # pyright: ignore[reportMissingTypeArgument]
"""Class which should be used to create new Query objects, as returned by the
:attr:`Session.query() <sqlalchemy.orm.Session.query>` method."""
twophase: bool | EmptyType = Empty
Expand All @@ -92,7 +93,7 @@ class GenericSQLAlchemyConfig(Generic[EngineT, SessionT, SessionMakerT]):
"""Configuration options for either the :class:`async_sessionmaker <sqlalchemy.ext.asyncio.async_sessionmaker>`
or :class:`sessionmaker <sqlalchemy.orm.sessionmaker>`.
"""
session_maker_class: type[sessionmaker | async_sessionmaker]
session_maker_class: type[sessionmaker[Session] | async_sessionmaker[AsyncSession]]
"""Sessionmaker class to use."""
connection_string: str | None = field(default=None)
"""Database connection string in one of the formats supported by SQLAlchemy.
Expand Down Expand Up @@ -196,7 +197,7 @@ def create_session_maker(self) -> Callable[[], SessionT]:
session_kws = self.session_config_dict
if session_kws.get("bind") is None:
session_kws["bind"] = self.get_engine()
return self.session_maker_class(**session_kws)
return cast("Callable[[], SessionT]", self.session_maker_class(**session_kws))


@dataclass
Expand Down
6 changes: 3 additions & 3 deletions advanced_alchemy/config/sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,16 +33,16 @@ class AlembicSyncConfig(GenericAlembicConfig):


@dataclass
class SQLAlchemySyncConfig(GenericSQLAlchemyConfig[Engine, Session, sessionmaker]):
class SQLAlchemySyncConfig(GenericSQLAlchemyConfig[Engine, Session, sessionmaker[Session]]):
"""Sync SQLAlchemy Configuration."""

create_engine_callable: Callable[[str], Engine] = create_engine
"""Callable that creates an :class:`AsyncEngine <sqlalchemy.ext.asyncio.AsyncEngine>` instance or instance of its
subclass.
"""
session_config: SyncSessionConfig = field(default_factory=SyncSessionConfig) # pyright:ignore # noqa: PGH003
session_config: SyncSessionConfig = field(default_factory=SyncSessionConfig) # pyright: ignore[reportIncompatibleVariableOverride]
"""Configuration options for the :class:`sessionmaker<sqlalchemy.orm.sessionmaker>`."""
session_maker_class: type[sessionmaker] = sessionmaker
session_maker_class: type[sessionmaker[Session]] = sessionmaker # pyright: ignore[reportIncompatibleVariableOverride]
"""Sessionmaker class to use."""
alembic_config: AlembicSyncConfig = field(default_factory=AlembicSyncConfig)
"""Configuration for the SQLAlchemy Alembic migrations.
Expand Down
2 changes: 1 addition & 1 deletion advanced_alchemy/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,7 @@ def append_to_lambda_statement(
model: type[ModelT],
) -> StatementLambdaElement:
where_clause = self._operator(*self.get_search_clauses(model))
statement += lambda s: s.where(where_clause)
statement += lambda s: s.where(where_clause) # pyright: ignore[reportUnknownLambdaType,reportUnknownMemberType]
return statement


Expand Down
16 changes: 15 additions & 1 deletion advanced_alchemy/repository/__init__.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,31 @@
from advanced_alchemy.repository._async import (
SQLAlchemyAsyncQueryRepository,
SQLAlchemyAsyncRepository,
SQLAlchemyAsyncRepositoryProtocol,
SQLAlchemyAsyncSlugRepository,
SQLAlchemyAsyncSlugRepositoryProtocol,
)
from advanced_alchemy.repository._sync import (
SQLAlchemySyncQueryRepository,
SQLAlchemySyncRepository,
SQLAlchemySyncRepositoryProtocol,
SQLAlchemySyncSlugRepository,
SQLAlchemySyncSlugRepositoryProtocol,
)
from advanced_alchemy.repository._util import (
FilterableRepositoryProtocol,
LoadSpec,
get_instrumented_attr,
model_from_dict,
)
from advanced_alchemy.repository._util import LoadSpec, get_instrumented_attr, model_from_dict

__all__ = (
"SQLAlchemyAsyncRepository",
"SQLAlchemyAsyncRepositoryProtocol",
"SQLAlchemyAsyncSlugRepositoryProtocol",
"FilterableRepositoryProtocol",
"SQLAlchemySyncRepositoryProtocol",
"SQLAlchemySyncSlugRepositoryProtocol",
"SQLAlchemyAsyncQueryRepository",
"SQLAlchemyAsyncSlugRepository",
"SQLAlchemySyncSlugRepository",
Expand Down
Loading

0 comments on commit f0a9554

Please sign in to comment.