Skip to content

Commit

Permalink
feat: merge table argument helper (#158)
Browse files Browse the repository at this point in the history
Co-authored-by: Peter Schutt <peter.github@proton.me>
  • Loading branch information
cofin and peterschutt committed Apr 10, 2024
1 parent 96aad8d commit 99c5446
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 5 deletions.
46 changes: 45 additions & 1 deletion advanced_alchemy/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
"SlugKey",
"SQLQuery",
"orm_registry",
"merge_table_arguments",
)


Expand All @@ -75,6 +76,49 @@
"""Templates for automated constraint name generation."""


def merge_table_arguments(
cls: DeclarativeBase,
*mixins: Any,
table_args: dict | tuple | None = None,
) -> tuple | dict:
"""Merge Table Arguments.
When using mixins that include their own table args, it is difficult to append info into the model such as a comment.
This function helps you merge the args together.
Args:
cls (DeclarativeBase): This is the model that will get the table args
*mixins (Any): The mixins to add into the model
table_args: additional information to add to tableargs
Returns:
tuple | dict: The merged __table_args__ property
"""
args: list[Any] = []
kwargs: dict[str, Any] = {}

mixin_table_args = (getattr(super(base_cls, cls), "__table_args__", None) for base_cls in (cls, *mixins))

for arg_to_merge in (*mixin_table_args, table_args):
if arg_to_merge:
if isinstance(arg_to_merge, tuple):
last_positional_arg = arg_to_merge[-1]
args.extend(arg_to_merge[:-1])
if isinstance(last_positional_arg, dict):
kwargs.update(last_positional_arg)
else:
args.append(last_positional_arg)
elif isinstance(arg_to_merge, dict):
kwargs.update(arg_to_merge)

if args:
if kwargs:
return (*args, kwargs)
return tuple(args)
return kwargs


@runtime_checkable
class ModelProtocol(Protocol):
"""The base SQLAlchemy model protocol."""
Expand Down Expand Up @@ -199,7 +243,7 @@ def _create_unique_slug_constraint(*_args: Any, **kwargs: Any) -> bool:
return not kwargs["dialect"].name.startswith("spanner")

@declared_attr.directive
def __table_args__(cls) -> tuple:
def __table_args__(cls) -> tuple | dict:
return (
UniqueConstraint(
cls.slug,
Expand Down
12 changes: 10 additions & 2 deletions tests/models_bigint.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
from typing import Any, List

from sqlalchemy import Column, FetchedValue, ForeignKey, String, Table, func
from sqlalchemy.orm import Mapped, mapped_column, relationship
from sqlalchemy.orm import Mapped, declared_attr, mapped_column, relationship

from advanced_alchemy.base import BigIntAuditBase, BigIntBase, SlugKey
from advanced_alchemy.base import BigIntAuditBase, BigIntBase, SlugKey, merge_table_arguments
from advanced_alchemy.repository import (
SQLAlchemyAsyncRepository,
SQLAlchemyAsyncSlugRepository,
Expand Down Expand Up @@ -59,6 +59,14 @@ class BigIntSlugBook(BigIntBase, SlugKey):
title: Mapped[str] = mapped_column(String(length=250)) # pyright: ignore
author_id: Mapped[str] = mapped_column(String(length=250)) # pyright: ignore

@declared_attr.directive
def __table_args__(cls) -> dict | tuple:
return merge_table_arguments(
cls,
SlugKey,
table_args={"comment": "Slugbook"},
)


class BigIntEventLog(BigIntAuditBase):
"""The event log domain object."""
Expand Down
12 changes: 10 additions & 2 deletions tests/models_uuid.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@
from uuid import UUID

from sqlalchemy import Column, FetchedValue, ForeignKey, String, Table, func
from sqlalchemy.orm import Mapped, mapped_column, relationship
from sqlalchemy.orm import Mapped, declared_attr, mapped_column, relationship

from advanced_alchemy.base import SlugKey, UUIDAuditBase, UUIDBase, UUIDv6Base, UUIDv7Base
from advanced_alchemy.base import SlugKey, UUIDAuditBase, UUIDBase, UUIDv6Base, UUIDv7Base, merge_table_arguments
from advanced_alchemy.repository import (
SQLAlchemyAsyncRepository,
SQLAlchemyAsyncSlugRepository,
Expand Down Expand Up @@ -55,6 +55,14 @@ class UUIDSlugBook(UUIDBase, SlugKey):
title: Mapped[str] = mapped_column(String(length=250)) # pyright: ignore
author_id: Mapped[str] = mapped_column(String(length=250)) # pyright: ignore

@declared_attr.directive
def __table_args__(cls) -> dict | tuple:
return merge_table_arguments(
cls,
SlugKey,
table_args={"comment": "Slugbook"},
)


class UUIDEventLog(UUIDAuditBase):
"""The event log domain object."""
Expand Down

0 comments on commit 99c5446

Please sign in to comment.