Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ sqlalchemy2-stubs = {version = "*", allow-prereleases = true}

[tool.poetry.dev-dependencies]
pytest = "^6.2.4"
mypy = "^0.812"
mypy = "^0.910"
flake8 = "^3.9.2"
black = {version = "^21.5-beta.1", python = "^3.7"}
mkdocs = "^1.2.1"
Expand Down Expand Up @@ -98,3 +98,7 @@ warn_return_any = true
implicit_reexport = false
strict_equality = true
# --strict end

[[tool.mypy.overrides]]
module = "sqlmodel.sql.expression"
warn_unused_ignores = false
2 changes: 1 addition & 1 deletion sqlmodel/engine/create.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,4 +136,4 @@ def create_engine(
if not isinstance(query_cache_size, _DefaultPlaceholder):
current_kwargs["query_cache_size"] = query_cache_size
current_kwargs.update(kwargs)
return _create_engine(url, **current_kwargs)
return _create_engine(url, **current_kwargs) # type: ignore
8 changes: 4 additions & 4 deletions sqlmodel/engine/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def __iter__(self) -> Iterator[_T]:
return super().__iter__()

def __next__(self) -> _T:
return super().__next__()
return super().__next__() # type: ignore

def first(self) -> Optional[_T]:
return super().first()
Expand All @@ -32,7 +32,7 @@ def one_or_none(self) -> Optional[_T]:
return super().one_or_none()

def one(self) -> _T:
return super().one()
return super().one() # type: ignore


class Result(_Result, Generic[_T]):
Expand Down Expand Up @@ -70,10 +70,10 @@ def scalar_one(self) -> _T:
return super().scalar_one() # type: ignore

def scalar_one_or_none(self) -> Optional[_T]:
return super().scalar_one_or_none() # type: ignore
return super().scalar_one_or_none()

def one(self) -> _T: # type: ignore
return super().one() # type: ignore

def scalar(self) -> Optional[_T]:
return super().scalar() # type: ignore
return super().scalar()
4 changes: 2 additions & 2 deletions sqlmodel/ext/asyncio/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def __init__(
self,
bind: Optional[Union[AsyncConnection, AsyncEngine]] = None,
binds: Optional[Mapping[object, Union[AsyncConnection, AsyncEngine]]] = None,
**kw,
**kw: Any,
):
# All the same code of the original AsyncSession
kw["future"] = True
Expand Down Expand Up @@ -52,7 +52,7 @@ async def exec(
# util.immutabledict has the union() method. Is this a bug in SQLAlchemy?
execution_options = execution_options.union({"prebuffer_rows": True}) # type: ignore

return await greenlet_spawn( # type: ignore
return await greenlet_spawn(
self.sync_session.exec,
statement,
params=params,
Expand Down
73 changes: 42 additions & 31 deletions sqlmodel/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def __init__(
*,
back_populates: Optional[str] = None,
link_model: Optional[Any] = None,
sa_relationship: Optional[RelationshipProperty] = None,
sa_relationship: Optional[RelationshipProperty] = None, # type: ignore
sa_relationship_args: Optional[Sequence[Any]] = None,
sa_relationship_kwargs: Optional[Mapping[str, Any]] = None,
) -> None:
Expand All @@ -127,32 +127,32 @@ def Field(
default: Any = Undefined,
*,
default_factory: Optional[NoArgAnyCallable] = None,
alias: str = None,
title: str = None,
description: str = None,
alias: Optional[str] = None,
title: Optional[str] = None,
description: Optional[str] = None,
exclude: Union[
AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any
] = None,
include: Union[
AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any
] = None,
const: bool = None,
gt: float = None,
ge: float = None,
lt: float = None,
le: float = None,
multiple_of: float = None,
min_items: int = None,
max_items: int = None,
min_length: int = None,
max_length: int = None,
const: Optional[bool] = None,
gt: Optional[float] = None,
ge: Optional[float] = None,
lt: Optional[float] = None,
le: Optional[float] = None,
multiple_of: Optional[float] = None,
min_items: Optional[int] = None,
max_items: Optional[int] = None,
min_length: Optional[int] = None,
max_length: Optional[int] = None,
allow_mutation: bool = True,
regex: str = None,
regex: Optional[str] = None,
primary_key: bool = False,
foreign_key: Optional[Any] = None,
nullable: Union[bool, UndefinedType] = Undefined,
index: Union[bool, UndefinedType] = Undefined,
sa_column: Union[Column, UndefinedType] = Undefined,
sa_column: Union[Column, UndefinedType] = Undefined, # type: ignore
sa_column_args: Union[Sequence[Any], UndefinedType] = Undefined,
sa_column_kwargs: Union[Mapping[str, Any], UndefinedType] = Undefined,
schema_extra: Optional[Dict[str, Any]] = None,
Expand Down Expand Up @@ -195,7 +195,7 @@ def Relationship(
*,
back_populates: Optional[str] = None,
link_model: Optional[Any] = None,
sa_relationship: Optional[RelationshipProperty] = None,
sa_relationship: Optional[RelationshipProperty] = None, # type: ignore
sa_relationship_args: Optional[Sequence[Any]] = None,
sa_relationship_kwargs: Optional[Mapping[str, Any]] = None,
) -> Any:
Expand All @@ -217,19 +217,25 @@ class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta):

# Replicate SQLAlchemy
def __setattr__(cls, name: str, value: Any) -> None:
if getattr(cls.__config__, "table", False): # type: ignore
if getattr(cls.__config__, "table", False):
DeclarativeMeta.__setattr__(cls, name, value)
else:
super().__setattr__(name, value)

def __delattr__(cls, name: str) -> None:
if getattr(cls.__config__, "table", False): # type: ignore
if getattr(cls.__config__, "table", False):
DeclarativeMeta.__delattr__(cls, name)
else:
super().__delattr__(name)

# From Pydantic
def __new__(cls, name, bases, class_dict: dict, **kwargs) -> Any:
def __new__(
cls,
name: str,
bases: Tuple[Type[Any], ...],
class_dict: Dict[str, Any],
**kwargs: Any,
) -> Any:
relationships: Dict[str, RelationshipInfo] = {}
dict_for_pydantic = {}
original_annotations = resolve_annotations(
Expand Down Expand Up @@ -342,7 +348,7 @@ def __init__(
)
relationship_to = temp_field.type_
if isinstance(temp_field.type_, ForwardRef):
relationship_to = temp_field.type_.__forward_arg__ # type: ignore
relationship_to = temp_field.type_.__forward_arg__
rel_kwargs: Dict[str, Any] = {}
if rel_info.back_populates:
rel_kwargs["back_populates"] = rel_info.back_populates
Expand All @@ -360,7 +366,7 @@ def __init__(
rel_args.extend(rel_info.sa_relationship_args)
if rel_info.sa_relationship_kwargs:
rel_kwargs.update(rel_info.sa_relationship_kwargs)
rel_value: RelationshipProperty = relationship(
rel_value: RelationshipProperty = relationship( # type: ignore
relationship_to, *rel_args, **rel_kwargs
)
dict_used[rel_name] = rel_value
Expand Down Expand Up @@ -408,7 +414,7 @@ def get_sqlachemy_type(field: ModelField) -> Any:
return GUID


def get_column_from_field(field: ModelField) -> Column:
def get_column_from_field(field: ModelField) -> Column: # type: ignore
sa_column = getattr(field.field_info, "sa_column", Undefined)
if isinstance(sa_column, Column):
return sa_column
Expand Down Expand Up @@ -440,10 +446,10 @@ def get_column_from_field(field: ModelField) -> Column:
kwargs["default"] = sa_default
sa_column_args = getattr(field.field_info, "sa_column_args", Undefined)
if sa_column_args is not Undefined:
args.extend(list(cast(Sequence, sa_column_args)))
args.extend(list(cast(Sequence[Any], sa_column_args)))
sa_column_kwargs = getattr(field.field_info, "sa_column_kwargs", Undefined)
if sa_column_kwargs is not Undefined:
kwargs.update(cast(dict, sa_column_kwargs))
kwargs.update(cast(Dict[Any, Any], sa_column_kwargs))
return Column(sa_type, *args, **kwargs)


Expand All @@ -452,24 +458,27 @@ def get_column_from_field(field: ModelField) -> Column:
default_registry = registry()


def _value_items_is_true(v) -> bool:
def _value_items_is_true(v: Any) -> bool:
# Re-implement Pydantic's ValueItems.is_true() as it hasn't been released as of
# the current latest, Pydantic 1.8.2
return v is True or v is ...


_TSQLModel = TypeVar("_TSQLModel", bound="SQLModel")


class SQLModel(BaseModel, metaclass=SQLModelMetaclass, registry=default_registry):
# SQLAlchemy needs to set weakref(s), Pydantic will set the other slots values
__slots__ = ("__weakref__",)
__tablename__: ClassVar[Union[str, Callable[..., str]]]
__sqlmodel_relationships__: ClassVar[Dict[str, RelationshipProperty]]
__sqlmodel_relationships__: ClassVar[Dict[str, RelationshipProperty]] # type: ignore
__name__: ClassVar[str]
metadata: ClassVar[MetaData]

class Config:
orm_mode = True

def __new__(cls, *args, **kwargs) -> Any:
def __new__(cls, *args: Any, **kwargs: Any) -> Any:
new_object = super().__new__(cls)
# SQLAlchemy doesn't call __init__ on the base class
# Ref: https://docs.sqlalchemy.org/en/14/orm/constructors.html
Expand Down Expand Up @@ -520,7 +529,9 @@ def __setattr__(self, name: str, value: Any) -> None:
super().__setattr__(name, value)

@classmethod
def from_orm(cls: Type["SQLModel"], obj: Any, update: Dict[str, Any] = None):
def from_orm(
cls: Type[_TSQLModel], obj: Any, update: Optional[Dict[str, Any]] = None
) -> _TSQLModel:
# Duplicated from Pydantic
if not cls.__config__.orm_mode:
raise ConfigError(
Expand All @@ -533,7 +544,7 @@ def from_orm(cls: Type["SQLModel"], obj: Any, update: Dict[str, Any] = None):
# End SQLModel support dict
if not getattr(cls.__config__, "table", False):
# If not table, normal Pydantic code
m = cls.__new__(cls)
m: _TSQLModel = cls.__new__(cls)
else:
# If table, create the new instance normally to make SQLAlchemy create
# the _sa_instance_state attribute
Expand All @@ -554,7 +565,7 @@ def from_orm(cls: Type["SQLModel"], obj: Any, update: Dict[str, Any] = None):

@classmethod
def parse_obj(
cls: Type["SQLModel"], obj: Any, update: Dict[str, Any] = None
cls: Type["SQLModel"], obj: Any, update: Optional[Dict[str, Any]] = None
) -> "SQLModel":
obj = cls._enforce_dict_if_root(obj)
# SQLModel, support update dict
Expand Down
6 changes: 3 additions & 3 deletions sqlmodel/orm/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def exec(
results = super().execute(
statement,
params=params,
execution_options=execution_options, # type: ignore
execution_options=execution_options,
bind_arguments=bind_arguments,
_parent_execute_state=_parent_execute_state,
_add_event=_add_event,
Expand All @@ -74,7 +74,7 @@ def execute(
self,
statement: _Executable,
params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None,
execution_options: Mapping[str, Any] = util.EMPTY_DICT,
execution_options: Optional[Mapping[str, Any]] = util.EMPTY_DICT,
bind_arguments: Optional[Mapping[str, Any]] = None,
_parent_execute_state: Optional[Any] = None,
_add_event: Optional[Any] = None,
Expand All @@ -101,7 +101,7 @@ def execute(
return super().execute( # type: ignore
statement,
params=params,
execution_options=execution_options, # type: ignore
execution_options=execution_options,
bind_arguments=bind_arguments,
_parent_execute_state=_parent_execute_state,
_add_event=_add_event,
Expand Down
4 changes: 1 addition & 3 deletions sqlmodel/sql/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,4 @@


class Executable(_Executable, Generic[_T]):
def __init__(self, *args, **kwargs):
self.__dict__["_exec_options"] = kwargs.pop("_exec_options", None)
super(_Executable, self).__init__(*args, **kwargs)
pass
32 changes: 16 additions & 16 deletions sqlmodel/sql/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,10 @@ class SelectOfScalar(_Select, Generic[_TSelect]):
class GenericSelectMeta(GenericMeta, _Select.__class__): # type: ignore
pass

class _Py36Select(_Select, Generic[_TSelect], metaclass=GenericSelectMeta): # type: ignore
class _Py36Select(_Select, Generic[_TSelect], metaclass=GenericSelectMeta):
pass

class _Py36SelectOfScalar(_Select, Generic[_TSelect], metaclass=GenericSelectMeta): # type: ignore
class _Py36SelectOfScalar(_Select, Generic[_TSelect], metaclass=GenericSelectMeta):
pass

# Cast them for editors to work correctly, from several tricks tried, this works
Expand All @@ -65,9 +65,9 @@ class _Py36SelectOfScalar(_Select, Generic[_TSelect], metaclass=GenericSelectMet

_TScalar_0 = TypeVar(
"_TScalar_0",
Column,
Sequence,
Mapping,
Column, # type: ignore
Sequence, # type: ignore
Mapping, # type: ignore
UUID,
datetime,
float,
Expand All @@ -83,9 +83,9 @@ class _Py36SelectOfScalar(_Select, Generic[_TSelect], metaclass=GenericSelectMet

_TScalar_1 = TypeVar(
"_TScalar_1",
Column,
Sequence,
Mapping,
Column, # type: ignore
Sequence, # type: ignore
Mapping, # type: ignore
UUID,
datetime,
float,
Expand All @@ -101,9 +101,9 @@ class _Py36SelectOfScalar(_Select, Generic[_TSelect], metaclass=GenericSelectMet

_TScalar_2 = TypeVar(
"_TScalar_2",
Column,
Sequence,
Mapping,
Column, # type: ignore
Sequence, # type: ignore
Mapping, # type: ignore
UUID,
datetime,
float,
Expand All @@ -119,9 +119,9 @@ class _Py36SelectOfScalar(_Select, Generic[_TSelect], metaclass=GenericSelectMet

_TScalar_3 = TypeVar(
"_TScalar_3",
Column,
Sequence,
Mapping,
Column, # type: ignore
Sequence, # type: ignore
Mapping, # type: ignore
UUID,
datetime,
float,
Expand Down Expand Up @@ -446,14 +446,14 @@ def select( # type: ignore
# Generated overloads end


def select(*entities: Any, **kw: Any) -> Union[Select, SelectOfScalar]:
def select(*entities: Any, **kw: Any) -> Union[Select, SelectOfScalar]: # type: ignore
if len(entities) == 1:
return SelectOfScalar._create(*entities, **kw) # type: ignore
return Select._create(*entities, **kw) # type: ignore


# TODO: add several @overload from Python types to SQLAlchemy equivalents
def col(column_expression: Any) -> ColumnClause:
def col(column_expression: Any) -> ColumnClause: # type: ignore
if not isinstance(column_expression, (ColumnClause, Column, InstrumentedAttribute)):
raise RuntimeError(f"Not a SQLAlchemy column: {column_expression}")
return column_expression
Loading