Skip to content

Commit

Permalink
Fix PydanticFieldMeta.from_model_field() to get constraints from oute…
Browse files Browse the repository at this point in the history
…r_type instead of annotation
  • Loading branch information
gsakkis committed May 17, 2023
1 parent e74fd8e commit 5adce2d
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 19 deletions.
43 changes: 25 additions & 18 deletions polyfactory/factories/pydantic_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,30 +68,37 @@ def from_model_field(cls, model_field: ModelField, use_alias: bool) -> PydanticF

name = model_field.alias if model_field.alias and use_alias else model_field.name

annotation = unwrap_new_type(
model_field.annotation if not isinstance(model_field.annotation, DeferredType) else model_field.outer_type_
outer_type = unwrap_new_type(model_field.outer_type_)
annotation = (
unwrap_new_type(model_field.annotation)
if not isinstance(model_field.annotation, DeferredType)
else outer_type
)

constraints = cast(
"Constraints",
{
"constant": bool(model_field.field_info.const) or None,
"ge": getattr(annotation, "ge", model_field.field_info.ge),
"gt": getattr(annotation, "gt", model_field.field_info.gt),
"le": getattr(annotation, "le", model_field.field_info.le),
"lt": getattr(annotation, "lt", model_field.field_info.lt),
"min_length": getattr(annotation, "min_length", model_field.field_info.min_length)
or getattr(annotation, "min_items", model_field.field_info.min_items),
"max_length": getattr(annotation, "max_length", model_field.field_info.max_length)
or getattr(annotation, "max_items", model_field.field_info.max_items),
"pattern": getattr(annotation, "regex", model_field.field_info.regex),
"unique_items": getattr(annotation, "unique_items", model_field.field_info.unique_items),
"decimal_places": getattr(annotation, "decimal_places", None),
"max_digits": getattr(annotation, "max_digits", None),
"multiple_of": getattr(annotation, "multiple_of", None),
"upper_case": getattr(annotation, "to_upper", None),
"lower_case": getattr(annotation, "to_lower", None),
"item_type": getattr(annotation, "item_type", None),
"ge": getattr(outer_type, "ge", model_field.field_info.ge),
"gt": getattr(outer_type, "gt", model_field.field_info.gt),
"le": getattr(outer_type, "le", model_field.field_info.le),
"lt": getattr(outer_type, "lt", model_field.field_info.lt),
"min_length": (
getattr(outer_type, "min_length", model_field.field_info.min_length)
or getattr(outer_type, "min_items", model_field.field_info.min_items)
),
"max_length": (
getattr(outer_type, "max_length", model_field.field_info.max_length)
or getattr(outer_type, "max_items", model_field.field_info.max_items)
),
"pattern": getattr(outer_type, "regex", model_field.field_info.regex),
"unique_items": getattr(outer_type, "unique_items", model_field.field_info.unique_items),
"decimal_places": getattr(outer_type, "decimal_places", None),
"max_digits": getattr(outer_type, "max_digits", None),
"multiple_of": getattr(outer_type, "multiple_of", None),
"upper_case": getattr(outer_type, "to_upper", None),
"lower_case": getattr(outer_type, "to_lower", None),
"item_type": getattr(outer_type, "item_type", None),
},
)

Expand Down
4 changes: 3 additions & 1 deletion tests/test_constrained_attribute_parsing.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import re
from decimal import Decimal
from typing import Dict, List, Tuple
from typing import Dict, List, Optional, Tuple

from pydantic import (
BaseModel,
Expand Down Expand Up @@ -44,6 +44,7 @@ class ConstrainedModel(BaseModel):
decimal_field: Decimal = Field(ge=100, le=1000)
list_field: List[str] = Field(min_items=1, max_items=10)
constant_field: int = Field(const=True, default=100)
optional_field: Optional[constr(min_length=1)] # type: ignore[valid-type]

class MyFactory(ModelFactory):
__model__ = ConstrainedModel
Expand Down Expand Up @@ -83,6 +84,7 @@ class MyFactory(ModelFactory):
assert len(result.list_field) <= 10
assert all(isinstance(r, str) for r in result.list_field)
assert result.constant_field == 100
assert result.optional_field is None or len(result.optional_field) >= 1


def test_complex_constrained_attribute_parsing() -> None:
Expand Down

0 comments on commit 5adce2d

Please sign in to comment.