Skip to content

Commit

Permalink
Add support for nested generics (pydantic#1104)
Browse files Browse the repository at this point in the history
* Add support for nested generics

* Allow instantiation of unparameterized generics

* Add better more partial instantiation tests

* Add changes

* Add docs
  • Loading branch information
dmontagu authored and andreshndz committed Jan 17, 2020
1 parent d943d95 commit 0b37746
Show file tree
Hide file tree
Showing 6 changed files with 272 additions and 39 deletions.
1 change: 1 addition & 0 deletions changes/1104-dmontagu.md
@@ -0,0 +1 @@
Add support for nested generic models
21 changes: 21 additions & 0 deletions docs/examples/models_generics_nested.py
@@ -0,0 +1,21 @@
from typing import Generic, TypeVar

from pydantic import ValidationError
from pydantic.generics import GenericModel

T = TypeVar('T')

class InnerT(GenericModel, Generic[T]):
inner: T

class OuterT(GenericModel, Generic[T]):
outer: T
nested: InnerT[T]

nested = InnerT[int](inner=1)
print(OuterT[int](outer=1, nested=nested))
try:
nested = InnerT[str](inner='a')
print(OuterT[int](outer='a', nested=nested))
except ValidationError as e:
print(e)
24 changes: 24 additions & 0 deletions docs/examples/models_generics_typevars.py
@@ -0,0 +1,24 @@
from typing import Generic, TypeVar

from pydantic import ValidationError
from pydantic.generics import GenericModel

AT = TypeVar('AT')
BT = TypeVar('BT')

class Model(GenericModel, Generic[AT, BT]):
a: AT
b: BT

print(Model(a='a', b='a'))

IntT = TypeVar('IntT', bound=int)
typevar_model = Model[int, IntT]
print(typevar_model(a=1, b=1))
try:
typevar_model(a='a', b='a')
except ValidationError as exc:
print(exc)

concrete_model = typevar_model[int]
print(concrete_model(a=1, b=1))
20 changes: 20 additions & 0 deletions docs/usage/models.md
Expand Up @@ -303,6 +303,26 @@ If the name of the concrete subclasses is important, you can also override the d
```
_(This script is complete, it should run "as is")_

Using the same TypeVar in nested models allows you to enforce typing relationships at different points in your model:

```py
{!.tmp_examples/models_generics_nested.py!}
```
_(This script is complete, it should run "as is")_

Pydantic also treats `GenericModel` similarly to how it treats built-in generic types like `List` and `Dict` when it
comes to leaving them unparameterized, or using bounded `TypeVar` instances:

* If you don't specify parameters before instantiating the generic model, they will be treated as `Any`
* You can parametrize models with one or more *bounded* parameters to add subclass checks

Also, like `List` and `Dict`, any parameters specified using a `TypeVar` can later be substituted with concrete types.

```py
{!.tmp_examples/models_generics_typevars.py!}
```
_(This script is complete, it should run "as is")_

## Dynamic model creation

There are some occasions where the shape of a model is not known until runtime. For this *pydantic* provides
Expand Down
72 changes: 52 additions & 20 deletions pydantic/generics.py
@@ -1,21 +1,24 @@
from typing import Any, ClassVar, Dict, Generic, Tuple, Type, TypeVar, Union, get_type_hints
from typing import TYPE_CHECKING, Any, ClassVar, Dict, Tuple, Type, TypeVar, Union, cast, get_type_hints

from .class_validators import gather_all_validators
from .fields import FieldInfo, ModelField
from .main import BaseModel, create_model
from .utils import lenient_issubclass

_generic_types_cache: Dict[Tuple[Type[Any], Union[Any, Tuple[Any, ...]]], Type[BaseModel]] = {}
GenericModelT = TypeVar('GenericModelT', bound='GenericModel')
TypeVarType = Any # since mypy doesn't allow the use of TypeVar as a type


class GenericModel(BaseModel):
__slots__ = ()
__concrete__: ClassVar[bool] = False

def __new__(cls, *args: Any, **kwargs: Any) -> Any:
if cls.__concrete__:
return super().__new__(cls)
else:
raise TypeError(f'Type {cls.__name__} cannot be used without generic parameters, e.g. {cls.__name__}[T]')
if TYPE_CHECKING:
# Putting this in a TYPE_CHECKING block allows us to replace `if Generic not in cls.__bases__` with
# `not hasattr(cls, "__parameters__")`. This means we don't need to force non-concrete subclasses of
# `GenericModel` to also inherit from `Generic`, which would require changes to the use of `create_model` below.
__parameters__: ClassVar[Tuple[TypeVarType, ...]]

# Setting the return type as Type[Any] instead of Type[BaseModel] prevents PyCharm warnings
def __class_getitem__(cls: Type[GenericModelT], params: Union[Type[Any], Tuple[Type[Any], ...]]) -> Type[Any]:
Expand All @@ -28,11 +31,11 @@ def __class_getitem__(cls: Type[GenericModelT], params: Union[Type[Any], Tuple[T
params = (params,)
if cls is GenericModel and any(isinstance(param, TypeVar) for param in params): # type: ignore
raise TypeError(f'Type parameters should be placed on typing.Generic, not GenericModel')
if Generic not in cls.__bases__:
if not hasattr(cls, '__parameters__'):
raise TypeError(f'Type {cls.__name__} must inherit from typing.Generic before being parameterized')

check_parameters_count(cls, params)
typevars_map: Dict[Any, Any] = dict(zip(cls.__parameters__, params)) # type: ignore
typevars_map: Dict[TypeVarType, Type[Any]] = dict(zip(cls.__parameters__, params))
type_hints = get_type_hints(cls).items()
instance_type_hints = {k: v for k, v in type_hints if getattr(v, '__origin__', None) is not ClassVar}
concrete_type_hints: Dict[str, Type[Any]] = {
Expand All @@ -41,19 +44,25 @@ def __class_getitem__(cls: Type[GenericModelT], params: Union[Type[Any], Tuple[T

model_name = cls.__concrete_name__(params)
validators = gather_all_validators(cls)
fields: Dict[str, Tuple[Type[Any], Any]] = {
k: (v, cls.__fields__[k].field_info) for k, v in concrete_type_hints.items() if k in cls.__fields__
}
created_model = create_model(
model_name=model_name,
__module__=cls.__module__,
__base__=cls,
__config__=None,
__validators__=validators,
**fields,
fields = _build_generic_fields(cls.__fields__, concrete_type_hints, typevars_map)
created_model = cast(
Type[GenericModel], # casting ensures mypy is aware of the __concrete__ and __parameters__ attributes
create_model(
model_name=model_name,
__module__=cls.__module__,
__base__=cls,
__config__=None,
__validators__=validators,
**fields,
),
)
created_model.Config = cls.Config
created_model.__concrete__ = True # type: ignore
concrete = all(not _is_typevar(v) for v in concrete_type_hints.values())
created_model.__concrete__ = concrete
if not concrete:
parameters = tuple(v for v in concrete_type_hints.values() if _is_typevar(v))
parameters = tuple({k: None for k in parameters}.keys()) # get unique params while maintaining order
created_model.__parameters__ = parameters
_generic_types_cache[(cls, params)] = created_model
if len(params) == 1:
_generic_types_cache[(cls, params[0])] = created_model
Expand All @@ -78,7 +87,30 @@ def resolve_type_hint(type_: Any, typevars_map: Dict[Any, Any]) -> Type[Any]:

def check_parameters_count(cls: Type[GenericModel], parameters: Tuple[Any, ...]) -> None:
actual = len(parameters)
expected = len(cls.__parameters__) # type: ignore
expected = len(cls.__parameters__)
if actual != expected:
description = 'many' if actual > expected else 'few'
raise TypeError(f'Too {description} parameters for {cls.__name__}; actual {actual}, expected {expected}')


def _build_generic_fields(
raw_fields: Dict[str, ModelField],
concrete_type_hints: Dict[str, Type[Any]],
typevars_map: Dict[TypeVarType, Type[Any]],
) -> Dict[str, Tuple[Type[Any], FieldInfo]]:
return {
k: (_parameterize_generic_field(v, typevars_map), raw_fields[k].field_info)
for k, v in concrete_type_hints.items()
if k in raw_fields
}


def _parameterize_generic_field(field_type: Type[Any], typevars_map: Dict[TypeVarType, Type[Any]]) -> Type[Any]:
if lenient_issubclass(field_type, GenericModel) and not field_type.__concrete__:
parameters = tuple(typevars_map.get(param, param) for param in field_type.__parameters__)
field_type = field_type[parameters]
return field_type


def _is_typevar(v: Any) -> bool:
return isinstance(v, TypeVar) # type: ignore
173 changes: 154 additions & 19 deletions tests/test_generics.py
Expand Up @@ -248,25 +248,6 @@ class Config:
result.data = 2


@skip_36
def test_generic_instantiation_error():
with pytest.raises(TypeError) as exc_info:
GenericModel()
assert str(exc_info.value) == 'Type GenericModel cannot be used without generic parameters, e.g. GenericModel[T]'


@skip_36
def test_parameterized_generic_instantiation_error():
data_type = TypeVar('data_type')

class Result(GenericModel, Generic[data_type]):
data: data_type

with pytest.raises(TypeError) as exc_info:
Result(data=1)
assert str(exc_info.value) == 'Type Result cannot be used without generic parameters, e.g. Result[T]'


@skip_36
def test_deep_generic():
T = TypeVar('T')
Expand Down Expand Up @@ -444,3 +425,157 @@ def __concrete_name__(cls: Type[Any], params: Tuple[Type[Any], ...]) -> str:

assert repr(MyModel[int](value=1)) == 'OptionalIntWrapper(value=1)'
assert repr(MyModel[str](value=None)) == 'OptionalStrWrapper(value=None)'


@skip_36
def test_nested():
AT = TypeVar('AT')

class InnerT(GenericModel, Generic[AT]):
a: AT

inner_int = InnerT[int](a=8)
inner_str = InnerT[str](a='ate')
inner_dict_any = InnerT[Any](a={})
inner_int_any = InnerT[Any](a=7)

class OuterT_SameType(GenericModel, Generic[AT]):
i: InnerT[AT]

OuterT_SameType[int](i=inner_int)
OuterT_SameType[str](i=inner_str)
OuterT_SameType[int](i=inner_int_any) # ensure parsing the broader inner type works

with pytest.raises(ValidationError) as exc_info:
OuterT_SameType[int](i=inner_str)
assert exc_info.value.errors() == [
{'loc': ('i', 'a'), 'msg': 'value is not a valid integer', 'type': 'type_error.integer'}
]

with pytest.raises(ValidationError) as exc_info:
OuterT_SameType[int](i=inner_dict_any)
assert exc_info.value.errors() == [
{'loc': ('i', 'a'), 'msg': 'value is not a valid integer', 'type': 'type_error.integer'}
]


@skip_36
def test_partial_specification():
AT = TypeVar('AT')
BT = TypeVar('BT')

class Model(GenericModel, Generic[AT, BT]):
a: AT
b: BT

partial_model = Model[int, BT]
concrete_model = partial_model[str]
concrete_model(a=1, b='abc')
with pytest.raises(ValidationError) as exc_info:
concrete_model(a='abc', b=None)
assert exc_info.value.errors() == [
{'loc': ('a',), 'msg': 'value is not a valid integer', 'type': 'type_error.integer'},
{'loc': ('b',), 'msg': 'none is not an allowed value', 'type': 'type_error.none.not_allowed'},
]


@skip_36
def test_partial_specification_name():
AT = TypeVar('AT')
BT = TypeVar('BT')

class Model(GenericModel, Generic[AT, BT]):
a: AT
b: BT

partial_model = Model[int, BT]
assert partial_model.__name__ == 'Model[int, BT]'
concrete_model = partial_model[str]
assert concrete_model.__name__ == 'Model[int, BT][str]'


@skip_36
def test_partial_specification_instantiation():
AT = TypeVar('AT')
BT = TypeVar('BT')

class Model(GenericModel, Generic[AT, BT]):
a: AT
b: BT

partial_model = Model[int, BT]
partial_model(a=1, b=2)

partial_model(a=1, b='a')

with pytest.raises(ValidationError) as exc_info:
partial_model(a='a', b=2)
assert exc_info.value.errors() == [
{'loc': ('a',), 'msg': 'value is not a valid integer', 'type': 'type_error.integer'}
]


@skip_36
def test_partial_specification_instantiation_bounded():
AT = TypeVar('AT')
BT = TypeVar('BT', bound=int)

class Model(GenericModel, Generic[AT, BT]):
a: AT
b: BT

Model(a=1, b=1)
with pytest.raises(ValidationError) as exc_info:
Model(a=1, b='a')
assert exc_info.value.errors() == [
{'loc': ('b',), 'msg': 'value is not a valid integer', 'type': 'type_error.integer'}
]

partial_model = Model[int, BT]
partial_model(a=1, b=1)
with pytest.raises(ValidationError) as exc_info:
partial_model(a=1, b='a')
assert exc_info.value.errors() == [
{'loc': ('b',), 'msg': 'value is not a valid integer', 'type': 'type_error.integer'}
]


@skip_36
def test_typevar_parametrization():
AT = TypeVar('AT')
BT = TypeVar('BT')

class Model(GenericModel, Generic[AT, BT]):
a: AT
b: BT

CT = TypeVar('CT', bound=int)
DT = TypeVar('DT', bound=int)

with pytest.raises(ValidationError) as exc_info:
Model[CT, DT](a='a', b='b')
assert exc_info.value.errors() == [
{'loc': ('a',), 'msg': 'value is not a valid integer', 'type': 'type_error.integer'},
{'loc': ('b',), 'msg': 'value is not a valid integer', 'type': 'type_error.integer'},
]


@skip_36
def test_multiple_specification():
AT = TypeVar('AT')
BT = TypeVar('BT')

class Model(GenericModel, Generic[AT, BT]):
a: AT
b: BT

CT = TypeVar('CT')
partial_model = Model[CT, CT]
concrete_model = partial_model[str]

with pytest.raises(ValidationError) as exc_info:
concrete_model(a=None, b=None)
assert exc_info.value.errors() == [
{'loc': ('a',), 'msg': 'none is not an allowed value', 'type': 'type_error.none.not_allowed'},
{'loc': ('b',), 'msg': 'none is not an allowed value', 'type': 'type_error.none.not_allowed'},
]

0 comments on commit 0b37746

Please sign in to comment.