Skip to content

Commit

Permalink
model - nullable schemas
Browse files Browse the repository at this point in the history
  • Loading branch information
commonism committed Feb 2, 2024
1 parent b06667b commit 6e71716
Show file tree
Hide file tree
Showing 6 changed files with 154 additions and 32 deletions.
19 changes: 8 additions & 11 deletions aiopenapi3/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
else:
from typing_extensions import TypeGuard

from pydantic import BaseModel, Field, AnyUrl, model_validator, PrivateAttr, ConfigDict
from pydantic import RootModel, BaseModel, TypeAdapter, Field, AnyUrl, model_validator, PrivateAttr, ConfigDict

from .json import JSONPointer, JSONReference
from .errors import ReferenceResolutionError, OperationParameterValidationError
Expand Down Expand Up @@ -424,7 +424,7 @@ def get_type(
discriminators: Optional[Sequence[DiscriminatorBase]] = None,
extra: Optional["SchemaBase"] = None,
fwdref: bool = False,
) -> Union[Type[BaseModel], ForwardRef]:
) -> Union[Type[BaseModel], Type[TypeAdapter], ForwardRef]:
if fwdref:
if "module" in ForwardRef.__init__.__code__.co_varnames:
# FIXME Python < 3.9 compat
Expand All @@ -449,18 +449,15 @@ def model(self, data: "JSON") -> Union[BaseModel, List[BaseModel]]:
:rtype: self.get_type()
"""

if self.type == "boolean":
assert len(self.properties) == 0
t = Model.typeof(cast("SchemaType", self))
if not isinstance(data, t):
return t(data)
return data
type_ = cast("SchemaType", self.get_type())
if isinstance(type_, TypeAdapter):
r = type_.validate_python(data)
else:
type_ = cast("SchemaType", self.get_type())
r = type_.model_validate(data)
if self.type in ("string", "number", "integer", "array"):
if self.type in ("string", "number", "integer", "array", "boolean"):
if isinstance(r, RootModel):
return r.root
return r
return r


class OperationBase:
Expand Down
40 changes: 23 additions & 17 deletions aiopenapi3/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from typing import List, Optional, Union, Tuple, Dict
from typing_extensions import Annotated, Literal

from pydantic import BaseModel, Field, RootModel, ConfigDict
from pydantic import BaseModel, TypeAdapter, Field, RootModel, ConfigDict
import pydantic

from .base import ReferenceBase, SchemaBase
Expand Down Expand Up @@ -69,12 +69,15 @@ def class_from_schema(s, _type):
return b


import pydantic_core


@dataclasses.dataclass
class _ClassInfo:
@dataclasses.dataclass
class _PropertyInfo:
annotation: Any = None
default: Any = None
default: Any = pydantic_core.PydanticUndefined

root: Any = None
config: Dict[str, Any] = dataclasses.field(default_factory=dict)
Expand Down Expand Up @@ -115,16 +118,18 @@ def from_schema(
schemanames: Optional[List[str]] = None,
discriminators: Optional[List["DiscriminatorType"]] = None,
extra: Optional["SchemaType"] = None,
) -> Type[BaseModel]:
) -> Union[Type[BaseModel], Type[TypeAdapter]]:
if schemanames is None:
schemanames = []

if discriminators is None:
discriminators = []

r: List[Type[BaseModel]] = list()
r: List[Union[Type[BaseModel], Type[TypeAdapter]]] = list()

for _type in Model.types(schema):
if _type == "null":
continue
r.append(Model.from_schema_type(schema, _type, schemanames, discriminators, extra))

if len(r) > 1:
Expand All @@ -134,7 +139,13 @@ def from_schema(
elif len(r) == 1:
m: Type[BaseModel] = cast(Type[BaseModel], r[0])
else: # == 0
raise ValueError(r)
assert schema.type == "null"
return TypeAdapter(None.__class__)

if not isinstance(m, TypeAdapter) and Model.is_nullable(schema):
n = TypeAdapter(Optional[m])
return cast(Type[TypeAdapter], n)

return cast(Type[BaseModel], m)

@classmethod
Expand All @@ -152,11 +163,8 @@ def from_schema_type(

classinfo = _ClassInfo()

# do not create models for primitive types
# create models for primitive types to be nullable
if _type in ("string", "integer", "number", "boolean"):
if _type == "boolean":
return bool

if typing.get_origin((_t := Model.typeof(schema, _type=_type))) != Literal:
classinfo.root = Annotated[_t, Model.fieldof_args(schema, None)]
else:
Expand Down Expand Up @@ -325,7 +333,7 @@ def typeof(
if schema is None:
return BaseModel
if isinstance(schema, SchemaBase):
nullable = False
nullable = Model.is_nullable(schema)
schema = cast("SchemaType", schema)
"""
Required, can be None: Optional[str]
Expand Down Expand Up @@ -520,7 +528,7 @@ def or_type(schema: "SchemaType", type_: str, l: Optional[int] = 2) -> bool:

@staticmethod
def is_nullable(schema: "SchemaType") -> bool:
return Model.or_type(schema, "null", l=None) or getattr(schema, "nullable", False)
return Model.or_type(schema, "null", l=None) or getattr(schema, "nullable", False) is True

@staticmethod
def is_type_any(schema: "SchemaType"):
Expand All @@ -537,14 +545,12 @@ def fieldof(schema: "SchemaType", classinfo: _ClassInfo):
for name, f in schema.properties.items():
args: Dict[str, Any] = dict()
assert schema.required is not None
if name not in schema.required:
if (v := getattr(f, "default", None)) is not None:
args["default"] = v
elif name not in schema.required:
args["default"] = None

name = Model.nameof(name, args=args)
if Model.is_nullable(f):
args["default"] = None
for i in ["default"]:
if (v := getattr(f, i, None)) is not None:
args[i] = v
classinfo.properties[name].default = Model.fieldof_args(f, args)
else:
raise ValueError(schema.type)
Expand Down
5 changes: 5 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -493,3 +493,8 @@ def with_paths_response_error():
@pytest.fixture
def with_schema_ref_nesting():
yield _get_parsed_yaml("schema-ref-nesting.yaml")


@pytest.fixture
def with_schema_nullable(openapi_version):
yield _get_parsed_yaml(f"schema-nullable-v{openapi_version.major}{openapi_version.minor}.yaml")
45 changes: 45 additions & 0 deletions tests/fixtures/schema-nullable-v30.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
openapi: 3.0.3
info:
title: ''
version: 0.0.0
servers:
- url: http://127.0.0.1/api

security:
- {}

paths: {}

components:
schemas:
object:
type: object
additionalProperties: false
properties:
attr:
$ref: '#/components/schemas/nullable'
nullable: true
required:
- attr

array:
type: array
items:
$ref: '#/components/schemas/nullable'
nullable: true

string:
type: string
nullable: true

integer:
type: integer
nullable: true

boolean:
type: boolean
nullable: true

nullable:
nullable: true
type: string
40 changes: 40 additions & 0 deletions tests/fixtures/schema-nullable-v31.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
openapi: 3.1.0
info:
title: ''
version: 0.0.0
servers:
- url: http://127.0.0.1/api

security:
- {}

paths: {}

components:
schemas:
object:
type: [object, "null"]
additionalProperties: false
properties:
attr:
$ref: '#/components/schemas/nullable'
required:
- attr

array:
type: [array, "null"]
items:
$ref: '#/components/schemas/nullable'


string:
type: [string, "null"]

integer:
type: [integer, "null"]

boolean:
type: [boolean, "null"]

nullable:
type: [string, "null"]
37 changes: 33 additions & 4 deletions tests/schema_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,11 +423,11 @@ def test_schema_enum(with_schema_enum):
with pytest.raises(ValidationError):
String(None)

Nullable = api.components.schemas["Nullable"].get_type()
Nullable("a")
Nullable(None)
Nullable = api.components.schemas["Nullable"]
Nullable.model("a")
Nullable.model(None)
with pytest.raises(ValidationError):
Nullable("c")
Nullable.model("c")

Mixed = api.components.schemas["Mixed"].get_type()
Mixed(1)
Expand Down Expand Up @@ -470,3 +470,32 @@ def test_schema_baseurl_v20(with_schema_baseurl_v20):
def test_schema_ref_nesting(with_schema_ref_nesting):
for i in range(10):
OpenAPI("/", with_schema_ref_nesting)


@pytest.mark.parametrize(
"schema, input, output, okay",
[
("object", None, None, True),
("object", {"attr": "a"}, {"attr": "a"}, True),
("object", {"attr": None}, {"attr": None}, True),
("object", {}, {}, False),
("integer", None, None, True),
("integer", 1, 1, True),
("boolean", None, None, True),
("boolean", True, True, True),
("string", None, None, True),
("string", "a", "a", True),
("array", None, None, True),
("array", [], [], True),
],
)
def test_schema_nullable(with_schema_nullable, schema, input, output, okay):
api = OpenAPI("/", with_schema_nullable) # , plugins=[NullableRefs()])

m = api.components.schemas[schema]
t = m.get_type()
if okay:
m.model(input)
else:
with pytest.raises(ValidationError):
m.model(input)

0 comments on commit 6e71716

Please sign in to comment.