Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: merge table argument helper #158

Merged
merged 6 commits into from
Apr 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 45 additions & 1 deletion advanced_alchemy/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
"SlugKey",
"SQLQuery",
"orm_registry",
"merge_table_arguments",
)


Expand All @@ -75,6 +76,49 @@
"""Templates for automated constraint name generation."""


def merge_table_arguments(
cls: DeclarativeBase,
*mixins: Any,
table_args: dict | tuple | None = None,
) -> tuple | dict:
"""Merge Table Arguments.

When using mixins that include their own table args, it is difficult to append info into the model such as a comment.

This function helps you merge the args together.

Args:
cls (DeclarativeBase): This is the model that will get the table args
*mixins (Any): The mixins to add into the model
table_args: additional information to add to tableargs

Returns:
tuple | dict: The merged __table_args__ property
"""
args: list[Any] = []
kwargs: dict[str, Any] = {}

mixin_table_args = (getattr(super(base_cls, cls), "__table_args__", None) for base_cls in (cls, *mixins))

for arg_to_merge in (*mixin_table_args, table_args):
if arg_to_merge:
if isinstance(arg_to_merge, tuple):
last_positional_arg = arg_to_merge[-1]
args.extend(arg_to_merge[:-1])
if isinstance(last_positional_arg, dict):
kwargs.update(last_positional_arg)
else:
args.append(last_positional_arg)
elif isinstance(arg_to_merge, dict):
kwargs.update(arg_to_merge)

if args:
if kwargs:
return (*args, kwargs)
return tuple(args)
return kwargs


@runtime_checkable
class ModelProtocol(Protocol):
"""The base SQLAlchemy model protocol."""
Expand Down Expand Up @@ -199,7 +243,7 @@ def _create_unique_slug_constraint(*_args: Any, **kwargs: Any) -> bool:
return not kwargs["dialect"].name.startswith("spanner")

@declared_attr.directive
def __table_args__(cls) -> tuple:
def __table_args__(cls) -> tuple | dict:
return (
UniqueConstraint(
cls.slug,
Expand Down
12 changes: 10 additions & 2 deletions tests/models_bigint.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
from typing import Any, List

from sqlalchemy import Column, FetchedValue, ForeignKey, String, Table, func
from sqlalchemy.orm import Mapped, mapped_column, relationship
from sqlalchemy.orm import Mapped, declared_attr, mapped_column, relationship

from advanced_alchemy.base import BigIntAuditBase, BigIntBase, SlugKey
from advanced_alchemy.base import BigIntAuditBase, BigIntBase, SlugKey, merge_table_arguments
from advanced_alchemy.repository import (
SQLAlchemyAsyncRepository,
SQLAlchemyAsyncSlugRepository,
Expand Down Expand Up @@ -59,6 +59,14 @@ class BigIntSlugBook(BigIntBase, SlugKey):
title: Mapped[str] = mapped_column(String(length=250)) # pyright: ignore
author_id: Mapped[str] = mapped_column(String(length=250)) # pyright: ignore

@declared_attr.directive
def __table_args__(cls) -> dict | tuple:
return merge_table_arguments(
cls,
SlugKey,
table_args={"comment": "Slugbook"},
)


class BigIntEventLog(BigIntAuditBase):
"""The event log domain object."""
Expand Down
12 changes: 10 additions & 2 deletions tests/models_uuid.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@
from uuid import UUID

from sqlalchemy import Column, FetchedValue, ForeignKey, String, Table, func
from sqlalchemy.orm import Mapped, mapped_column, relationship
from sqlalchemy.orm import Mapped, declared_attr, mapped_column, relationship

from advanced_alchemy.base import SlugKey, UUIDAuditBase, UUIDBase, UUIDv6Base, UUIDv7Base
from advanced_alchemy.base import SlugKey, UUIDAuditBase, UUIDBase, UUIDv6Base, UUIDv7Base, merge_table_arguments
from advanced_alchemy.repository import (
SQLAlchemyAsyncRepository,
SQLAlchemyAsyncSlugRepository,
Expand Down Expand Up @@ -55,6 +55,14 @@ class UUIDSlugBook(UUIDBase, SlugKey):
title: Mapped[str] = mapped_column(String(length=250)) # pyright: ignore
author_id: Mapped[str] = mapped_column(String(length=250)) # pyright: ignore

@declared_attr.directive
def __table_args__(cls) -> dict | tuple:
return merge_table_arguments(
cls,
SlugKey,
table_args={"comment": "Slugbook"},
)


class UUIDEventLog(UUIDAuditBase):
"""The event log domain object."""
Expand Down
Loading