From 0024745ef94bcfb3b98a6ebeffe904f894ccddfa Mon Sep 17 00:00:00 2001 From: David Danier Date: Mon, 29 Aug 2022 17:20:10 +0200 Subject: [PATCH 01/10] Add failing test for test_get_sqlachemy_type() --- tests/test_get_sqlachemy_type.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) create mode 100644 tests/test_get_sqlachemy_type.py diff --git a/tests/test_get_sqlachemy_type.py b/tests/test_get_sqlachemy_type.py new file mode 100644 index 0000000000..dc88bbd3ec --- /dev/null +++ b/tests/test_get_sqlachemy_type.py @@ -0,0 +1,20 @@ +from typing import Any, Dict, List +from unittest.mock import MagicMock + +import pytest +from pydantic.fields import ModelField + +from sqlmodel.main import get_sqlachemy_type + + +@pytest.mark.parametrize( + "input_type", + [ + List[str], + Dict[str, Any], + ], +) +def test_non_type_does_not_break(input_type: type) -> None: + model_field_mock = MagicMock(ModelField, type_=input_type) + with pytest.raises(ValueError): + get_sqlachemy_type(model_field_mock) From 57f5112149e46e8ee6a3e02003064eb74598d706 Mon Sep 17 00:00:00 2001 From: David Danier Date: Mon, 29 Aug 2022 17:21:10 +0200 Subject: [PATCH 02/10] Fix get_sqlachemy_type() not checking for a valid type first --- sqlmodel/main.py | 79 ++++++++++++++++++++++++------------------------ 1 file changed, 40 insertions(+), 39 deletions(-) diff --git a/sqlmodel/main.py b/sqlmodel/main.py index a5ce8faf74..dfd6dafae6 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -372,45 +372,46 @@ def __init__( def get_sqlachemy_type(field: ModelField) -> Any: - if issubclass(field.type_, str): - if field.field_info.max_length: - return AutoString(length=field.field_info.max_length) - return AutoString - if issubclass(field.type_, float): - return Float - if issubclass(field.type_, bool): - return Boolean - if issubclass(field.type_, int): - return Integer - if issubclass(field.type_, datetime): - return DateTime - if issubclass(field.type_, date): - return Date - if issubclass(field.type_, timedelta): - return Interval - if issubclass(field.type_, time): - return Time - if issubclass(field.type_, Enum): - return sa_Enum(field.type_) - if issubclass(field.type_, bytes): - return LargeBinary - if issubclass(field.type_, Decimal): - return Numeric( - precision=getattr(field.type_, "max_digits", None), - scale=getattr(field.type_, "decimal_places", None), - ) - if issubclass(field.type_, ipaddress.IPv4Address): - return AutoString - if issubclass(field.type_, ipaddress.IPv4Network): - return AutoString - if issubclass(field.type_, ipaddress.IPv6Address): - return AutoString - if issubclass(field.type_, ipaddress.IPv6Network): - return AutoString - if issubclass(field.type_, Path): - return AutoString - if issubclass(field.type_, uuid.UUID): - return GUID + if isinstance(field.type_, type): + if issubclass(field.type_, str): + if field.field_info.max_length: + return AutoString(length=field.field_info.max_length) + return AutoString + if issubclass(field.type_, float): + return Float + if issubclass(field.type_, bool): + return Boolean + if issubclass(field.type_, int): + return Integer + if issubclass(field.type_, datetime): + return DateTime + if issubclass(field.type_, date): + return Date + if issubclass(field.type_, timedelta): + return Interval + if issubclass(field.type_, time): + return Time + if issubclass(field.type_, Enum): + return sa_Enum(field.type_) + if issubclass(field.type_, bytes): + return LargeBinary + if issubclass(field.type_, Decimal): + return Numeric( + precision=getattr(field.type_, "max_digits", None), + scale=getattr(field.type_, "decimal_places", None), + ) + if issubclass(field.type_, ipaddress.IPv4Address): + return AutoString + if issubclass(field.type_, ipaddress.IPv4Network): + return AutoString + if issubclass(field.type_, ipaddress.IPv6Address): + return AutoString + if issubclass(field.type_, ipaddress.IPv6Network): + return AutoString + if issubclass(field.type_, Path): + return AutoString + if issubclass(field.type_, uuid.UUID): + return GUID raise ValueError(f"The field {field.name} has no matching SQLAlchemy type") From 7967bcb48393aeea1e1ece2e4b64b8eb5fb22e36 Mon Sep 17 00:00:00 2001 From: David Danier Date: Mon, 29 Aug 2022 17:21:54 +0200 Subject: [PATCH 03/10] Add Union type to tests --- tests/test_get_sqlachemy_type.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_get_sqlachemy_type.py b/tests/test_get_sqlachemy_type.py index dc88bbd3ec..28bdd56b1c 100644 --- a/tests/test_get_sqlachemy_type.py +++ b/tests/test_get_sqlachemy_type.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List +from typing import Any, Dict, List, Union from unittest.mock import MagicMock import pytest @@ -12,6 +12,7 @@ [ List[str], Dict[str, Any], + Union[int, str], ], ) def test_non_type_does_not_break(input_type: type) -> None: From 55c10b5b97142875044b35ac4c3b91ce4e910de9 Mon Sep 17 00:00:00 2001 From: David Danier Date: Mon, 29 Aug 2022 19:48:05 +0200 Subject: [PATCH 04/10] =?UTF-8?q?Run=20isort=20to=20make=20linter=20happy?= =?UTF-8?q?=20=F0=9F=91=8D=F0=9F=A4=9E?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/test_get_sqlachemy_type.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_get_sqlachemy_type.py b/tests/test_get_sqlachemy_type.py index 28bdd56b1c..c43515b909 100644 --- a/tests/test_get_sqlachemy_type.py +++ b/tests/test_get_sqlachemy_type.py @@ -3,7 +3,6 @@ import pytest from pydantic.fields import ModelField - from sqlmodel.main import get_sqlachemy_type From 99a837e130112c1b85bdbcc6ddb28e512fa4119c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Sun, 22 Oct 2023 21:35:55 +0400 Subject: [PATCH 05/10] =?UTF-8?q?=F0=9F=9A=9A=20Rename=20function=20get=5F?= =?UTF-8?q?sqlaclhemy=5Ftype?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- sqlmodel/main.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sqlmodel/main.py b/sqlmodel/main.py index eb864bc62c..2bbbcc2914 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -373,7 +373,7 @@ def __init__( ModelMetaclass.__init__(cls, classname, bases, dict_, **kw) -def get_sqlachemy_type(field: ModelField) -> Any: +def get_sqlaclhemy_type(field: ModelField) -> Any: if isinstance(field.type_, type): if issubclass(field.type_, str): if field.field_info.max_length: @@ -421,7 +421,7 @@ def get_column_from_field(field: ModelField) -> Column: # type: ignore sa_column = getattr(field.field_info, "sa_column", Undefined) if isinstance(sa_column, Column): return sa_column - sa_type = get_sqlalchemy_type(field) + sa_type = get_sqlaclhemy_type(field) primary_key = getattr(field.field_info, "primary_key", False) index = getattr(field.field_info, "index", Undefined) if index is Undefined: From fd816bf0672f279395f112134d3b8cab285d39f7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Mon, 23 Oct 2023 10:18:35 +0400 Subject: [PATCH 06/10] =?UTF-8?q?=F0=9F=9A=9A=20Rename=20get=5Fsqlalchemy?= =?UTF-8?q?=5Ftype?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- sqlmodel/main.py | 4 ++-- tests/test_get_sqlachemy_type.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/sqlmodel/main.py b/sqlmodel/main.py index 2bbbcc2914..ba8e031d4b 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -373,7 +373,7 @@ def __init__( ModelMetaclass.__init__(cls, classname, bases, dict_, **kw) -def get_sqlaclhemy_type(field: ModelField) -> Any: +def get_sqlalchemy_type(field: ModelField) -> Any: if isinstance(field.type_, type): if issubclass(field.type_, str): if field.field_info.max_length: @@ -421,7 +421,7 @@ def get_column_from_field(field: ModelField) -> Column: # type: ignore sa_column = getattr(field.field_info, "sa_column", Undefined) if isinstance(sa_column, Column): return sa_column - sa_type = get_sqlaclhemy_type(field) + sa_type = get_sqlalchemy_type(field) primary_key = getattr(field.field_info, "primary_key", False) index = getattr(field.field_info, "index", Undefined) if index is Undefined: diff --git a/tests/test_get_sqlachemy_type.py b/tests/test_get_sqlachemy_type.py index c43515b909..79da3588a9 100644 --- a/tests/test_get_sqlachemy_type.py +++ b/tests/test_get_sqlachemy_type.py @@ -3,7 +3,7 @@ import pytest from pydantic.fields import ModelField -from sqlmodel.main import get_sqlachemy_type +from sqlmodel.main import get_sqlalchemy_type @pytest.mark.parametrize( @@ -17,4 +17,4 @@ def test_non_type_does_not_break(input_type: type) -> None: model_field_mock = MagicMock(ModelField, type_=input_type) with pytest.raises(ValueError): - get_sqlachemy_type(model_field_mock) + get_sqlalchemy_type(model_field_mock) From 2694e5d1d1eb778eb3f654f915136648df4f1274 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Mon, 23 Oct 2023 10:28:34 +0400 Subject: [PATCH 07/10] =?UTF-8?q?=E2=9C=85=20Update=20tests,=20use=20compl?= =?UTF-8?q?ete=20full=20use=20cases?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/test_get_sqlachemy_type.py | 20 -------------------- tests/test_sqlalchemy_type_errors.py | 28 ++++++++++++++++++++++++++++ 2 files changed, 28 insertions(+), 20 deletions(-) delete mode 100644 tests/test_get_sqlachemy_type.py create mode 100644 tests/test_sqlalchemy_type_errors.py diff --git a/tests/test_get_sqlachemy_type.py b/tests/test_get_sqlachemy_type.py deleted file mode 100644 index 79da3588a9..0000000000 --- a/tests/test_get_sqlachemy_type.py +++ /dev/null @@ -1,20 +0,0 @@ -from typing import Any, Dict, List, Union -from unittest.mock import MagicMock - -import pytest -from pydantic.fields import ModelField -from sqlmodel.main import get_sqlalchemy_type - - -@pytest.mark.parametrize( - "input_type", - [ - List[str], - Dict[str, Any], - Union[int, str], - ], -) -def test_non_type_does_not_break(input_type: type) -> None: - model_field_mock = MagicMock(ModelField, type_=input_type) - with pytest.raises(ValueError): - get_sqlalchemy_type(model_field_mock) diff --git a/tests/test_sqlalchemy_type_errors.py b/tests/test_sqlalchemy_type_errors.py new file mode 100644 index 0000000000..9a77318e23 --- /dev/null +++ b/tests/test_sqlalchemy_type_errors.py @@ -0,0 +1,28 @@ +from typing import Any, Dict, List, Optional, Union + +import pytest +from sqlmodel import SQLModel, Field + + +def test_type_list_breaks() -> None: + with pytest.raises(ValueError): + + class Hero(SQLModel, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + tags: List[str] + + +def test_type_dict_breaks() -> None: + with pytest.raises(ValueError): + + class Hero(SQLModel, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + tags: Dict[str, Any] + + +def test_type_union_breaks() -> None: + with pytest.raises(ValueError): + + class Hero(SQLModel, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + tags: Union[int, str] From 6ad38cbf0903f4269142ef017c00fb50fcd7e961 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Mon, 23 Oct 2023 10:29:12 +0400 Subject: [PATCH 08/10] =?UTF-8?q?=F0=9F=90=9B=20Fix=20implementation=20for?= =?UTF-8?q?=20complete=20use=20cases?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- sqlmodel/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sqlmodel/main.py b/sqlmodel/main.py index ba8e031d4b..7dec60ddac 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -374,7 +374,7 @@ def __init__( def get_sqlalchemy_type(field: ModelField) -> Any: - if isinstance(field.type_, type): + if isinstance(field.type_, type) and field.shape == SHAPE_SINGLETON: if issubclass(field.type_, str): if field.field_info.max_length: return AutoString(length=field.field_info.max_length) From 47eaf081b5e8f1dcbe3f4036a779f404e0f0bcef Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Mon, 23 Oct 2023 10:32:29 +0400 Subject: [PATCH 09/10] =?UTF-8?q?=F0=9F=8E=A8=20Fix=20format?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/test_sqlalchemy_type_errors.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_sqlalchemy_type_errors.py b/tests/test_sqlalchemy_type_errors.py index 9a77318e23..e3cdf2c925 100644 --- a/tests/test_sqlalchemy_type_errors.py +++ b/tests/test_sqlalchemy_type_errors.py @@ -1,7 +1,8 @@ from typing import Any, Dict, List, Optional, Union import pytest -from sqlmodel import SQLModel, Field + +from sqlmodel import Field, SQLModel def test_type_list_breaks() -> None: From 4e376dfed450e2ead5042877742dc785a216841c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Mon, 23 Oct 2023 10:38:13 +0400 Subject: [PATCH 10/10] =?UTF-8?q?=F0=9F=8E=A8=20Format=20tests?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/test_sqlalchemy_type_errors.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_sqlalchemy_type_errors.py b/tests/test_sqlalchemy_type_errors.py index e3cdf2c925..e211c46a34 100644 --- a/tests/test_sqlalchemy_type_errors.py +++ b/tests/test_sqlalchemy_type_errors.py @@ -1,7 +1,6 @@ from typing import Any, Dict, List, Optional, Union import pytest - from sqlmodel import Field, SQLModel