From f48e1d880bebf0e9dfd2d6b05b0f3bc83cf01287 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Sat, 28 Oct 2023 17:28:33 +0400 Subject: [PATCH 1/6] =?UTF-8?q?=E2=99=BB=EF=B8=8F=20Make=20sa=5Fcolumn=20e?= =?UTF-8?q?xclusive,=20do=20not=20allow=20incompatible=20arguments,=20sa?= =?UTF-8?q?=5Fcolumn=5Fargs,=20primary=5Fkey,=20etc?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- sqlmodel/main.py | 130 +++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 120 insertions(+), 10 deletions(-) diff --git a/sqlmodel/main.py b/sqlmodel/main.py index 3015aa9fbd..66536d4143 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -22,6 +22,7 @@ TypeVar, Union, cast, + overload, ) from pydantic import BaseConfig, BaseModel @@ -87,6 +88,28 @@ def __init__(self, default: Any = Undefined, **kwargs: Any) -> None: "Passing sa_column_kwargs is not supported when " "also passing a sa_column" ) + if primary_key is not Undefined: + raise RuntimeError( + "Passing primary_key is not supported when " + "also passing a sa_column" + ) + if nullable is not Undefined: + raise RuntimeError( + "Passing nullable is not supported when " "also passing a sa_column" + ) + if foreign_key is not Undefined: + raise RuntimeError( + "Passing foreign_key is not supported when " + "also passing a sa_column" + ) + if unique is not Undefined: + raise RuntimeError( + "Passing unique is not supported when " "also passing a sa_column" + ) + if index is not Undefined: + raise RuntimeError( + "Passing index is not supported when " "also passing a sa_column" + ) super().__init__(default=default, **kwargs) self.primary_key = primary_key self.nullable = nullable @@ -126,6 +149,86 @@ def __init__( self.sa_relationship_kwargs = sa_relationship_kwargs +@overload +def Field( + default: Any = Undefined, + *, + default_factory: Optional[NoArgAnyCallable] = None, + alias: Optional[str] = None, + title: Optional[str] = None, + description: Optional[str] = None, + exclude: Union[ + AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any + ] = None, + include: Union[ + AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any + ] = None, + const: Optional[bool] = None, + gt: Optional[float] = None, + ge: Optional[float] = None, + lt: Optional[float] = None, + le: Optional[float] = None, + multiple_of: Optional[float] = None, + max_digits: Optional[int] = None, + decimal_places: Optional[int] = None, + min_items: Optional[int] = None, + max_items: Optional[int] = None, + unique_items: Optional[bool] = None, + min_length: Optional[int] = None, + max_length: Optional[int] = None, + allow_mutation: bool = True, + regex: Optional[str] = None, + discriminator: Optional[str] = None, + repr: bool = True, + primary_key: Union[bool, UndefinedType] = Undefined, + foreign_key: Any = Undefined, + unique: Union[bool, UndefinedType] = Undefined, + nullable: Union[bool, UndefinedType] = Undefined, + index: Union[bool, UndefinedType] = Undefined, + sa_column_args: Union[Sequence[Any], UndefinedType] = Undefined, + sa_column_kwargs: Union[Mapping[str, Any], UndefinedType] = Undefined, + schema_extra: Optional[Dict[str, Any]] = None, +) -> Any: + ... + + +@overload +def Field( + default: Any = Undefined, + *, + default_factory: Optional[NoArgAnyCallable] = None, + alias: Optional[str] = None, + title: Optional[str] = None, + description: Optional[str] = None, + exclude: Union[ + AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any + ] = None, + include: Union[ + AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any + ] = None, + const: Optional[bool] = None, + gt: Optional[float] = None, + ge: Optional[float] = None, + lt: Optional[float] = None, + le: Optional[float] = None, + multiple_of: Optional[float] = None, + max_digits: Optional[int] = None, + decimal_places: Optional[int] = None, + min_items: Optional[int] = None, + max_items: Optional[int] = None, + unique_items: Optional[bool] = None, + min_length: Optional[int] = None, + max_length: Optional[int] = None, + allow_mutation: bool = True, + regex: Optional[str] = None, + discriminator: Optional[str] = None, + repr: bool = True, + sa_column: Union[Column, UndefinedType] = Undefined, # type: ignore + schema_extra: Optional[Dict[str, Any]] = None, +) -> Any: + ... + + def Field( default: Any = Undefined, *, @@ -156,9 +259,9 @@ def Field( regex: Optional[str] = None, discriminator: Optional[str] = None, repr: bool = True, - primary_key: bool = False, - foreign_key: Optional[Any] = None, - unique: bool = False, + primary_key: Union[bool, UndefinedType] = Undefined, + foreign_key: Any = Undefined, + unique: Union[bool, UndefinedType] = Undefined, nullable: Union[bool, UndefinedType] = Undefined, index: Union[bool, UndefinedType] = Undefined, sa_column: Union[Column, UndefinedType] = Undefined, # type: ignore @@ -440,21 +543,28 @@ def get_column_from_field(field: ModelField) -> Column: # type: ignore if isinstance(sa_column, Column): return sa_column sa_type = get_sqlalchemy_type(field) - primary_key = getattr(field.field_info, "primary_key", False) + primary_key = getattr(field.field_info, "primary_key", Undefined) + if primary_key is Undefined: + primary_key = False index = getattr(field.field_info, "index", Undefined) if index is Undefined: index = False nullable = not primary_key and _is_field_noneable(field) # Override derived nullability if the nullable property is set explicitly # on the field - if hasattr(field.field_info, "nullable"): - field_nullable = getattr(field.field_info, "nullable") # noqa: B009 - if field_nullable != Undefined: - nullable = field_nullable + field_nullable = getattr(field.field_info, "nullable", Undefined) # noqa: B009 + if field_nullable != Undefined: + assert not isinstance(field_nullable, UndefinedType) + nullable = field_nullable args = [] - foreign_key = getattr(field.field_info, "foreign_key", None) - unique = getattr(field.field_info, "unique", False) + foreign_key = getattr(field.field_info, "foreign_key", Undefined) + if foreign_key is Undefined: + foreign_key = None + unique = getattr(field.field_info, "unique", Undefined) + if unique is Undefined: + unique = False if foreign_key: + assert isinstance(foreign_key, str) args.append(ForeignKey(foreign_key)) kwargs = { "primary_key": primary_key, From 41ff76664db9680e690665bffade7bd2a312521f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Sat, 28 Oct 2023 17:29:25 +0400 Subject: [PATCH 2/6] =?UTF-8?q?=E2=9C=85=20Add=20tests=20for=20new=20error?= =?UTF-8?q?s=20when=20incorrectly=20using=20sa=5Fcolumn?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/test_field_sa_column.py | 99 +++++++++++++++++++++++++++++++++++ 1 file changed, 99 insertions(+) create mode 100644 tests/test_field_sa_column.py diff --git a/tests/test_field_sa_column.py b/tests/test_field_sa_column.py new file mode 100644 index 0000000000..51cfdfa797 --- /dev/null +++ b/tests/test_field_sa_column.py @@ -0,0 +1,99 @@ +from typing import Optional + +import pytest +from sqlalchemy import Column, Integer, String +from sqlmodel import Field, SQLModel + + +def test_sa_column_takes_precedence() -> None: + class Item(SQLModel, table=True): + id: Optional[int] = Field( + default=None, + sa_column=Column(String, primary_key=True, nullable=False), + ) + + # It would have been nullable with no sa_column + assert Item.id.nullable is False # type: ignore + assert isinstance(Item.id.type, String) # type: ignore + + +def test_sa_column_no_sa_args() -> None: + with pytest.raises(RuntimeError): + + class Item(SQLModel, table=True): + id: Optional[int] = Field( + default=None, + sa_column_args=[Integer], + sa_column=Column(Integer, primary_key=True), + ) + + +def test_sa_column_no_sa_kargs() -> None: + with pytest.raises(RuntimeError): + + class Item(SQLModel, table=True): + id: Optional[int] = Field( + default=None, + sa_column_kwargs={"primary_key": True}, + sa_column=Column(Integer, primary_key=True), + ) + + +def test_sa_column_no_primary_key() -> None: + with pytest.raises(RuntimeError): + + class Item(SQLModel, table=True): + id: Optional[int] = Field( + default=None, + primary_key=True, + sa_column=Column(Integer, primary_key=True), + ) + + +def test_sa_column_no_nullable() -> None: + with pytest.raises(RuntimeError): + + class Item(SQLModel, table=True): + id: Optional[int] = Field( + default=None, + nullable=True, + sa_column=Column(Integer, primary_key=True), + ) + + +def test_sa_column_no_foreign_key() -> None: + with pytest.raises(RuntimeError): + + class Team(SQLModel, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + name: str + + class Hero(SQLModel, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + team_id: Optional[int] = Field( + default=None, + foreign_key="team.id", + sa_column=Column(Integer, primary_key=True), + ) + + +def test_sa_column_no_unique() -> None: + with pytest.raises(RuntimeError): + + class Item(SQLModel, table=True): + id: Optional[int] = Field( + default=None, + unique=True, + sa_column=Column(Integer, primary_key=True), + ) + + +def test_sa_column_no_index() -> None: + with pytest.raises(RuntimeError): + + class Item(SQLModel, table=True): + id: Optional[int] = Field( + default=None, + index=True, + sa_column=Column(Integer, primary_key=True), + ) From 73e52967316a23384e1a695338bbedfda832dfbc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Sat, 28 Oct 2023 17:29:55 +0400 Subject: [PATCH 3/6] =?UTF-8?q?=E2=9C=85=20Add=20tests=20for=20sa=5Fcolumn?= =?UTF-8?q?=5Fargs=20and=20sa=5Fcolumn=5Fkwargs?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/test_field_sa_args_kwargs.py | 39 ++++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) create mode 100644 tests/test_field_sa_args_kwargs.py diff --git a/tests/test_field_sa_args_kwargs.py b/tests/test_field_sa_args_kwargs.py new file mode 100644 index 0000000000..cf70565a94 --- /dev/null +++ b/tests/test_field_sa_args_kwargs.py @@ -0,0 +1,39 @@ +from typing import Optional + +from sqlalchemy import Column, ForeignKey +from sqlmodel import Field, SQLModel, create_engine + + +def test_sa_column_args(clear_sqlmodel, caplog) -> None: + class Team(SQLModel, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + name: str + + class Hero(SQLModel, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + team_id: Optional[int] = Field( + default=None, + sa_column=Column(ForeignKey("team.id")), + ) + + engine = create_engine("sqlite://", echo=True) + SQLModel.metadata.create_all(engine) + create_table_log = [ + message for message in caplog.messages if "CREATE TABLE hero" in message + ][0] + assert "FOREIGN KEY(team_id) REFERENCES team (id)" in create_table_log + + +def test_sa_column_kargs(clear_sqlmodel, caplog) -> None: + class Item(SQLModel, table=True): + id: Optional[int] = Field( + default=None, + sa_column_kwargs={"primary_key": True}, + ) + + engine = create_engine("sqlite://", echo=True) + SQLModel.metadata.create_all(engine) + create_table_log = [ + message for message in caplog.messages if "CREATE TABLE item" in message + ][0] + assert "PRIMARY KEY (id)" in create_table_log From 41f16355ee91cad805c9a4d9cb9029a3aef3cd6e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Sat, 28 Oct 2023 17:45:46 +0400 Subject: [PATCH 4/6] =?UTF-8?q?=E2=99=BB=EF=B8=8F=20Do=20not=20allow=20sa?= =?UTF-8?q?=5Frelationship=20with=20sa=5Frelationship=5Fargs=20or=20sa=5Fr?= =?UTF-8?q?elationship=5Fkwargs?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- sqlmodel/main.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/sqlmodel/main.py b/sqlmodel/main.py index 66536d4143..f48e388e13 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -309,6 +309,27 @@ def Field( return field_info +@overload +def Relationship( + *, + back_populates: Optional[str] = None, + link_model: Optional[Any] = None, + sa_relationship_args: Optional[Sequence[Any]] = None, + sa_relationship_kwargs: Optional[Mapping[str, Any]] = None, +) -> Any: + ... + + +@overload +def Relationship( + *, + back_populates: Optional[str] = None, + link_model: Optional[Any] = None, + sa_relationship: Optional[RelationshipProperty] = None, # type: ignore +) -> Any: + ... + + def Relationship( *, back_populates: Optional[str] = None, From beafb6a8a1d64e9439c502f66721b5f210b4be20 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Sat, 28 Oct 2023 17:46:26 +0400 Subject: [PATCH 5/6] =?UTF-8?q?=E2=9C=85=20Add=20tests=20for=20relationshi?= =?UTF-8?q?p=20errors?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/test_field_sa_relationship.py | 53 +++++++++++++++++++++++++++++ 1 file changed, 53 insertions(+) create mode 100644 tests/test_field_sa_relationship.py diff --git a/tests/test_field_sa_relationship.py b/tests/test_field_sa_relationship.py new file mode 100644 index 0000000000..7606fd86d8 --- /dev/null +++ b/tests/test_field_sa_relationship.py @@ -0,0 +1,53 @@ +from typing import List, Optional + +import pytest +from sqlalchemy.orm import relationship +from sqlmodel import Field, Relationship, SQLModel + + +def test_sa_relationship_no_args() -> None: + with pytest.raises(RuntimeError): + + class Team(SQLModel, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + name: str = Field(index=True) + headquarters: str + + heroes: List["Hero"] = Relationship( + back_populates="team", + sa_relationship_args=["Hero"], + sa_relationship=relationship("Hero", back_populates="team"), + ) + + class Hero(SQLModel, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + name: str = Field(index=True) + secret_name: str + age: Optional[int] = Field(default=None, index=True) + + team_id: Optional[int] = Field(default=None, foreign_key="team.id") + team: Optional[Team] = Relationship(back_populates="heroes") + + +def test_sa_relationship_no_kwargs() -> None: + with pytest.raises(RuntimeError): + + class Team(SQLModel, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + name: str = Field(index=True) + headquarters: str + + heroes: List["Hero"] = Relationship( + back_populates="team", + sa_relationship_kwargs={"lazy": "selectin"}, + sa_relationship=relationship("Hero", back_populates="team"), + ) + + class Hero(SQLModel, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + name: str = Field(index=True) + secret_name: str + age: Optional[int] = Field(default=None, index=True) + + team_id: Optional[int] = Field(default=None, foreign_key="team.id") + team: Optional[Team] = Relationship(back_populates="heroes") From 306f202b3e9c35b8f1a6668d2718bfac086fca06 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Sat, 28 Oct 2023 17:46:42 +0400 Subject: [PATCH 6/6] =?UTF-8?q?=E2=9C=85=20Fix=20test=20for=20sa=5Fcolumn?= =?UTF-8?q?=5Fargs?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/test_field_sa_args_kwargs.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_field_sa_args_kwargs.py b/tests/test_field_sa_args_kwargs.py index cf70565a94..94a1a13483 100644 --- a/tests/test_field_sa_args_kwargs.py +++ b/tests/test_field_sa_args_kwargs.py @@ -1,6 +1,6 @@ from typing import Optional -from sqlalchemy import Column, ForeignKey +from sqlalchemy import ForeignKey from sqlmodel import Field, SQLModel, create_engine @@ -13,7 +13,7 @@ class Hero(SQLModel, table=True): id: Optional[int] = Field(default=None, primary_key=True) team_id: Optional[int] = Field( default=None, - sa_column=Column(ForeignKey("team.id")), + sa_column_args=[ForeignKey("team.id")], ) engine = create_engine("sqlite://", echo=True)