diff --git a/sqlmodel/main.py b/sqlmodel/main.py index 661276b31d..ac409ac6a2 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -53,6 +53,7 @@ from .sql.sqltypes import GUID, AutoString _T = TypeVar("_T") +_Model = TypeVar("_Model", bound="SQLModel") def __dataclass_transform__( @@ -520,7 +521,7 @@ def __setattr__(self, name: str, value: Any) -> None: super().__setattr__(name, value) @classmethod - def from_orm(cls: Type["SQLModel"], obj: Any, update: Dict[str, Any] = None): + def from_orm(cls: Type["_Model"], obj: Any, update: Dict[str, Any] = None) -> "_Model": # Duplicated from Pydantic if not cls.__config__.orm_mode: raise ConfigError( @@ -554,8 +555,8 @@ def from_orm(cls: Type["SQLModel"], obj: Any, update: Dict[str, Any] = None): @classmethod def parse_obj( - cls: Type["SQLModel"], obj: Any, update: Dict[str, Any] = None - ) -> "SQLModel": + cls: Type["_Model"], obj: Any, update: Dict[str, Any] = None + ) -> "_Model": obj = cls._enforce_dict_if_root(obj) # SQLModel, support update dict if update is not None: @@ -569,7 +570,7 @@ def __repr_args__(self) -> Sequence[Tuple[Optional[str], Any]]: # From Pydantic, override to enforce validation with dict @classmethod - def validate(cls: Type["SQLModel"], value: Any) -> "SQLModel": + def validate(cls: Type["_Model"], value: Any) -> "_Model": if isinstance(value, cls): return value.copy() if cls.__config__.copy_on_model_validation else value