diff --git a/sqlmodel/main.py b/sqlmodel/main.py index 7c916f79af..3bbce9eadf 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -3,6 +3,7 @@ import ipaddress import uuid import weakref +from dataclasses import dataclass from datetime import date, datetime, time, timedelta from decimal import Decimal from enum import Enum @@ -347,6 +348,38 @@ def Field( ) -> Any: ... +@dataclass +class FieldInfoMetadata: + primary_key: Union[bool, UndefinedType] = Undefined + nullable: Union[bool, UndefinedType] = Undefined + foreign_key: Any = Undefined + ondelete: Union[OnDeleteType, UndefinedType] = Undefined + unique: Union[bool, UndefinedType] = Undefined + index: Union[bool, UndefinedType] = Undefined + sa_type: Union[Type[Any], UndefinedType] = Undefined + sa_column: Union[Column[Any], UndefinedType] = Undefined + sa_column_args: Union[Sequence[Any], UndefinedType] = Undefined + sa_column_kwargs: Union[Mapping[str, Any], UndefinedType] = Undefined + + +def _get_sqlmodel_field_metadata(field_info: Any) -> Optional[FieldInfoMetadata]: + metadata_items = getattr(field_info, "metadata", None) + if metadata_items: + for meta in metadata_items: + if isinstance(meta, FieldInfoMetadata): + return meta + return None + + +def _get_sqlmodel_field_value( + field_info: Any, attribute: str, default: Any = Undefined +) -> Any: + metadata = _get_sqlmodel_field_metadata(field_info) + if metadata is not None and hasattr(metadata, attribute): + return getattr(metadata, attribute) + return getattr(field_info, attribute, default) + + def Field( default: Any = Undefined, *, @@ -427,6 +460,20 @@ def Field( sa_column_kwargs=sa_column_kwargs, **current_schema_extra, ) + field_metadata = FieldInfoMetadata( + primary_key=primary_key, + nullable=nullable, + foreign_key=foreign_key, + ondelete=ondelete, + unique=unique, + index=index, + sa_type=sa_type, + sa_column=sa_column, + sa_column_args=sa_column_args, + sa_column_kwargs=sa_column_kwargs, + ) + if hasattr(field_info, "metadata"): + field_info.metadata.append(field_metadata) # type: ignore[attr-defined] post_init_field_info(field_info) return field_info @@ -651,7 +698,7 @@ def get_sqlalchemy_type(field: Any) -> Any: field_info = field else: field_info = field.field_info - sa_type = getattr(field_info, "sa_type", Undefined) # noqa: B009 + sa_type = _get_sqlmodel_field_value(field_info, "sa_type", Undefined) if sa_type is not Undefined: return sa_type @@ -708,39 +755,39 @@ def get_column_from_field(field: Any) -> Column: # type: ignore field_info = field else: field_info = field.field_info - sa_column = getattr(field_info, "sa_column", Undefined) + sa_column = _get_sqlmodel_field_value(field_info, "sa_column", Undefined) if isinstance(sa_column, Column): return sa_column sa_type = get_sqlalchemy_type(field) - primary_key = getattr(field_info, "primary_key", Undefined) + primary_key = _get_sqlmodel_field_value(field_info, "primary_key", Undefined) if primary_key is Undefined: primary_key = False - index = getattr(field_info, "index", Undefined) + index = _get_sqlmodel_field_value(field_info, "index", Undefined) if index is Undefined: index = False nullable = not primary_key and is_field_noneable(field) # Override derived nullability if the nullable property is set explicitly # on the field - field_nullable = getattr(field_info, "nullable", Undefined) # noqa: B009 + field_nullable = _get_sqlmodel_field_value(field_info, "nullable", Undefined) if field_nullable is not Undefined: assert not isinstance(field_nullable, UndefinedType) nullable = field_nullable args = [] - foreign_key = getattr(field_info, "foreign_key", Undefined) + foreign_key = _get_sqlmodel_field_value(field_info, "foreign_key", Undefined) if foreign_key is Undefined: foreign_key = None - unique = getattr(field_info, "unique", Undefined) + unique = _get_sqlmodel_field_value(field_info, "unique", Undefined) if unique is Undefined: unique = False if foreign_key: - if field_info.ondelete == "SET NULL" and not nullable: + ondelete_value = _get_sqlmodel_field_value(field_info, "ondelete", Undefined) + if ondelete_value is Undefined: + ondelete_value = None + if ondelete_value == "SET NULL" and not nullable: raise RuntimeError('ondelete="SET NULL" requires nullable=True') assert isinstance(foreign_key, str) - ondelete = getattr(field_info, "ondelete", Undefined) - if ondelete is Undefined: - ondelete = None - assert isinstance(ondelete, (str, type(None))) # for typing - args.append(ForeignKey(foreign_key, ondelete=ondelete)) + assert isinstance(ondelete_value, (str, type(None))) # for typing + args.append(ForeignKey(foreign_key, ondelete=ondelete_value)) kwargs = { "primary_key": primary_key, "nullable": nullable, @@ -754,10 +801,12 @@ def get_column_from_field(field: Any) -> Column: # type: ignore sa_default = field_info.default if sa_default is not Undefined: kwargs["default"] = sa_default - sa_column_args = getattr(field_info, "sa_column_args", Undefined) + sa_column_args = _get_sqlmodel_field_value(field_info, "sa_column_args", Undefined) if sa_column_args is not Undefined: args.extend(list(cast(Sequence[Any], sa_column_args))) - sa_column_kwargs = getattr(field_info, "sa_column_kwargs", Undefined) + sa_column_kwargs = _get_sqlmodel_field_value( + field_info, "sa_column_kwargs", Undefined + ) if sa_column_kwargs is not Undefined: kwargs.update(cast(Dict[Any, Any], sa_column_kwargs)) return Column(sa_type, *args, **kwargs) # type: ignore diff --git a/tests/test_field_sa_column.py b/tests/test_field_sa_column.py index e2ccc6d7ef..b9eddfe38b 100644 --- a/tests/test_field_sa_column.py +++ b/tests/test_field_sa_column.py @@ -3,6 +3,7 @@ import pytest from sqlalchemy import Column, Integer, String from sqlmodel import Field, SQLModel +from typing_extensions import Annotated def test_sa_column_takes_precedence() -> None: @@ -17,6 +18,17 @@ class Item(SQLModel, table=True): assert isinstance(Item.id.type, String) # type: ignore +def test_sa_column_with_annotated_metadata() -> None: + class Item(SQLModel, table=True): + id: Annotated[Optional[int], "meta"] = Field( + default=None, + sa_column=Column(String, primary_key=True, nullable=False), + ) + + assert Item.id.nullable is False # type: ignore + assert isinstance(Item.id.type, String) # type: ignore + + def test_sa_column_no_sa_args() -> None: with pytest.raises(RuntimeError): diff --git a/tests/test_main.py b/tests/test_main.py index 60d5c40ebb..54760f0b9d 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -4,6 +4,9 @@ from sqlalchemy.exc import IntegrityError from sqlalchemy.orm import RelationshipProperty from sqlmodel import Field, Relationship, Session, SQLModel, create_engine, select +from typing_extensions import Annotated + +from tests.conftest import needs_pydanticv2 def test_should_allow_duplicate_row_if_unique_constraint_is_not_passed(clear_sqlmodel): @@ -125,3 +128,96 @@ class Hero(SQLModel, table=True): # The next statement should not raise an AttributeError assert hero_rusty_man.team assert hero_rusty_man.team.name == "Preventers" + + +def test_composite_primary_key(clear_sqlmodel): + class UserPermission(SQLModel, table=True): + user_id: int = Field(primary_key=True) + resource_id: int = Field(primary_key=True) + permission: str + + engine = create_engine("sqlite://") + SQLModel.metadata.create_all(engine) + + pk_column_names = {column.name for column in UserPermission.__table__.primary_key} + assert pk_column_names == {"user_id", "resource_id"} + + with Session(engine) as session: + perm1 = UserPermission(user_id=1, resource_id=1, permission="read") + perm2 = UserPermission(user_id=1, resource_id=2, permission="write") + session.add(perm1) + session.add(perm2) + session.commit() + + with pytest.raises(IntegrityError): + with Session(engine) as session: + perm3 = UserPermission(user_id=1, resource_id=1, permission="admin") + session.add(perm3) + session.commit() + + +@needs_pydanticv2 +def test_composite_primary_key_and_validator(clear_sqlmodel): + from pydantic import AfterValidator + + def validate_resource_id(value: int) -> int: + if value < 1: + raise ValueError("Resource ID must be positive") + return value + + class UserPermission(SQLModel, table=True): + user_id: int = Field(primary_key=True) + resource_id: Annotated[int, AfterValidator(validate_resource_id)] = Field( + primary_key=True + ) + permission: str + + engine = create_engine("sqlite://") + SQLModel.metadata.create_all(engine) + + pk_column_names = {column.name for column in UserPermission.__table__.primary_key} + assert pk_column_names == {"user_id", "resource_id"} + + with Session(engine) as session: + perm1 = UserPermission(user_id=1, resource_id=1, permission="read") + perm2 = UserPermission(user_id=1, resource_id=2, permission="write") + session.add(perm1) + session.add(perm2) + session.commit() + + with pytest.raises(IntegrityError): + with Session(engine) as session: + perm3 = UserPermission(user_id=1, resource_id=1, permission="admin") + session.add(perm3) + session.commit() + + +@needs_pydanticv2 +def test_foreign_key_ondelete_with_annotated(clear_sqlmodel): + from pydantic import AfterValidator + + def ensure_positive(value: int) -> int: + if value < 0: + raise ValueError("Team ID must be positive") + return value + + class Team(SQLModel, table=True): + id: int = Field(primary_key=True) + name: str + + class Hero(SQLModel, table=True): + id: int = Field(primary_key=True) + team_id: Annotated[int, AfterValidator(ensure_positive)] = Field( + foreign_key="team.id", + ondelete="CASCADE", + ) + name: str + + engine = create_engine("sqlite://") + SQLModel.metadata.create_all(engine) + + team_id_column = Hero.__table__.c.team_id # type: ignore[attr-defined] + foreign_keys = list(team_id_column.foreign_keys) + assert len(foreign_keys) == 1 + assert foreign_keys[0].ondelete == "CASCADE" + assert team_id_column.nullable is False