From 2b39e3281d10a08ec9cc7c35a7881ff97aaf3045 Mon Sep 17 00:00:00 2001 From: Raphael Gibson Date: Tue, 7 Sep 2021 00:20:17 -0300 Subject: [PATCH 1/2] feat: add unique constraint param to Field function --- sqlmodel/main.py | 6 +++ tests/test_main.py | 91 ++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 97 insertions(+) create mode 100644 tests/test_main.py diff --git a/sqlmodel/main.py b/sqlmodel/main.py index 4d6d2f2712..33ab957d3c 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -69,6 +69,7 @@ def __init__(self, default: Any = Undefined, **kwargs: Any) -> None: primary_key = kwargs.pop("primary_key", False) nullable = kwargs.pop("nullable", Undefined) foreign_key = kwargs.pop("foreign_key", Undefined) + unique = kwargs.pop("unique", False) index = kwargs.pop("index", Undefined) sa_column = kwargs.pop("sa_column", Undefined) sa_column_args = kwargs.pop("sa_column_args", Undefined) @@ -88,6 +89,7 @@ def __init__(self, default: Any = Undefined, **kwargs: Any) -> None: self.primary_key = primary_key self.nullable = nullable self.foreign_key = foreign_key + self.unique = unique self.index = index self.sa_column = sa_column self.sa_column_args = sa_column_args @@ -149,6 +151,7 @@ def Field( regex: Optional[str] = None, primary_key: bool = False, foreign_key: Optional[Any] = None, + unique: bool = False, nullable: Union[bool, UndefinedType] = Undefined, index: Union[bool, UndefinedType] = Undefined, sa_column: Union[Column, UndefinedType] = Undefined, # type: ignore @@ -179,6 +182,7 @@ def Field( regex=regex, primary_key=primary_key, foreign_key=foreign_key, + unique=unique, nullable=nullable, index=index, sa_column=sa_column, @@ -432,12 +436,14 @@ def get_column_from_field(field: ModelField) -> Column: # type: ignore nullable = field_nullable args = [] foreign_key = getattr(field.field_info, "foreign_key", None) + unique = getattr(field.field_info, "unique", False) if foreign_key: args.append(ForeignKey(foreign_key)) kwargs = { "primary_key": primary_key, "nullable": nullable, "index": index, + "unique": unique } sa_default = Undefined if field.field_info.default_factory: diff --git a/tests/test_main.py b/tests/test_main.py new file mode 100644 index 0000000000..65ad0d9b56 --- /dev/null +++ b/tests/test_main.py @@ -0,0 +1,91 @@ +import pytest +from typing import Optional + +from sqlmodel import Field, Session, SQLModel, create_engine +from sqlalchemy.exc import IntegrityError + + +def test_should_allow_duplicate_row_if_unique_constraint_is_not_passed(clear_sqlmodel): + class Hero(SQLModel, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + name: str + secret_name: str + age: Optional[int] = None + + hero_1 = Hero(name="Deadpond", secret_name="Dive Wilson") + hero_2 = Hero(name="Deadpond", secret_name="Dive Wilson") + + engine = create_engine("sqlite://") + + SQLModel.metadata.create_all(engine) + + with Session(engine) as session: + session.add(hero_1) + session.commit() + session.refresh(hero_1) + + with Session(engine) as session: + session.add(hero_2) + session.commit() + session.refresh(hero_2) + + with Session(engine) as session: + heroes = session.query(Hero).all() + assert len(heroes) == 2 + assert heroes[0].name == heroes[1].name + + +def test_should_allow_duplicate_row_if_unique_constraint_is_false(clear_sqlmodel): + class Hero(SQLModel, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + name: str + secret_name: str = Field(unique=False) + age: Optional[int] = None + + hero_1 = Hero(name="Deadpond", secret_name="Dive Wilson") + hero_2 = Hero(name="Deadpond", secret_name="Dive Wilson") + + engine = create_engine("sqlite://") + + SQLModel.metadata.create_all(engine) + + with Session(engine) as session: + session.add(hero_1) + session.commit() + session.refresh(hero_1) + + with Session(engine) as session: + session.add(hero_2) + session.commit() + session.refresh(hero_2) + + with Session(engine) as session: + heroes = session.query(Hero).all() + assert len(heroes) == 2 + assert heroes[0].name == heroes[1].name + + +def test_should_raise_exception_when_try_to_duplicate_row_if_unique_constraint_is_true(clear_sqlmodel): + class Hero(SQLModel, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + name: str + secret_name: str = Field(unique=True) + age: Optional[int] = None + + hero_1 = Hero(name="Deadpond", secret_name="Dive Wilson") + hero_2 = Hero(name="Deadpond", secret_name="Dive Wilson") + + engine = create_engine("sqlite://") + + SQLModel.metadata.create_all(engine) + + with Session(engine) as session: + session.add(hero_1) + session.commit() + session.refresh(hero_1) + + with pytest.raises(IntegrityError): + with Session(engine) as session: + session.add(hero_2) + session.commit() + session.refresh(hero_2) From 9f80059870961f7b9af13f547bddf77151433690 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Sun, 28 Aug 2022 01:47:38 +0200 Subject: [PATCH 2/2] =?UTF-8?q?=F0=9F=8E=A8=20Format=20code?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- sqlmodel/main.py | 2 +- tests/test_main.py | 8 +++++--- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/sqlmodel/main.py b/sqlmodel/main.py index 092400a63b..7c79edd2e3 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -437,7 +437,7 @@ def get_column_from_field(field: ModelField) -> Column: # type: ignore "primary_key": primary_key, "nullable": nullable, "index": index, - "unique": unique + "unique": unique, } sa_default = Undefined if field.field_info.default_factory: diff --git a/tests/test_main.py b/tests/test_main.py index 65ad0d9b56..22c62327da 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -1,8 +1,8 @@ -import pytest from typing import Optional -from sqlmodel import Field, Session, SQLModel, create_engine +import pytest from sqlalchemy.exc import IntegrityError +from sqlmodel import Field, Session, SQLModel, create_engine def test_should_allow_duplicate_row_if_unique_constraint_is_not_passed(clear_sqlmodel): @@ -65,7 +65,9 @@ class Hero(SQLModel, table=True): assert heroes[0].name == heroes[1].name -def test_should_raise_exception_when_try_to_duplicate_row_if_unique_constraint_is_true(clear_sqlmodel): +def test_should_raise_exception_when_try_to_duplicate_row_if_unique_constraint_is_true( + clear_sqlmodel, +): class Hero(SQLModel, table=True): id: Optional[int] = Field(default=None, primary_key=True) name: str