Add Json Support To Params
emirthab committed Sep 1, 2023
1 parent a3f1689 commit 0ed16c0
Showing 2 changed files with 131 additions and 16 deletions.
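The changes below teach FastAPI to read dict-shaped query parameters straight from the query string: a parameter annotated as a mapping of scalars (e.g. Dict[str, int]) or as a mapping of scalar sequences (e.g. Dict[str, List[str]]) is no longer treated as a body field. A minimal sketch of the endpoint style this enables; the routes and parameter names are illustrative, not part of the commit:

    from typing import Dict, List

    from fastapi import FastAPI

    app = FastAPI()


    @app.get("/search")
    async def search(
        # Scalar mapping: each query key carries one scalar value,
        # e.g. /search?page=1&size=10 -> {"page": 1, "size": 10}
        filters: Dict[str, int] = {},
    ):
        return filters


    @app.get("/tags")
    async def tags(
        # Scalar-sequence mapping: repeated keys are grouped into lists,
        # e.g. /tags?color=red&color=blue&size=xl
        #      -> {"color": ["red", "blue"], "size": ["xl"]}
        groups: Dict[str, List[str]] = {},
    ):
        return groups
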
100 changes: 84 additions & 16 deletions fastapi/_compat.py
@@ -43,6 +43,13 @@
 
 sequence_types = tuple(sequence_annotation_to_type.keys())
 
+mapping_annotation_to_type = {
+    Mapping: dict,
+}
+
+mapping_types = tuple(mapping_annotation_to_type.keys())
+
 
 if PYDANTIC_V2:
     from pydantic import PydanticSchemaGenerationError as PydanticSchemaGenerationError
     from pydantic import TypeAdapter
@@ -56,6 +63,7 @@
     from pydantic.json_schema import GenerateJsonSchema as GenerateJsonSchema
     from pydantic.json_schema import JsonSchemaValue as JsonSchemaValue
     from pydantic_core import CoreSchema as CoreSchema
+    from pydantic_core import MultiHostUrl as MultiHostUrl
     from pydantic_core import PydanticUndefined, PydanticUndefinedType
     from pydantic_core import Url as Url
     from pydantic_core.core_schema import (
@@ -181,13 +189,9 @@ def get_schema_from_model_field(
     field_mapping: Dict[
         Tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue
     ],
-    separate_input_output_schemas: bool = True,
 ) -> Dict[str, Any]:
-    override_mode: Union[Literal["validation"], None] = (
-        None if separate_input_output_schemas else "validation"
-    )
     # This expects that GenerateJsonSchema was already used to generate the definitions
-    json_schema = field_mapping[(field, override_mode or field.mode)]
+    json_schema = field_mapping[(field, field.mode)]
     if "$ref" not in json_schema:
         # TODO remove when deprecating Pydantic v1
         # Ref: https://github.com/pydantic/pydantic/blob/d61792cc42c80b13b23e3ffa74bc37ec7c77f7d1/pydantic/schema.py#L207
@@ -204,19 +208,14 @@ def get_definitions(
     fields: List[ModelField],
     schema_generator: GenerateJsonSchema,
     model_name_map: ModelNameMap,
-    separate_input_output_schemas: bool = True,
 ) -> Tuple[
     Dict[
         Tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue
     ],
     Dict[str, Dict[str, Any]],
 ]:
-    override_mode: Union[Literal["validation"], None] = (
-        None if separate_input_output_schemas else "validation"
-    )
     inputs = [
-        (field, override_mode or field.mode, field._type_adapter.core_schema)
-        for field in fields
+        (field, field.mode, field._type_adapter.core_schema) for field in fields
     ]
     field_mapping, definitions = schema_generator.generate_definitions(
         inputs=inputs
@@ -236,6 +235,12 @@ def is_sequence_field(field: ModelField) -> bool:
     def is_scalar_sequence_field(field: ModelField) -> bool:
         return field_annotation_is_scalar_sequence(field.field_info.annotation)
 
+    def is_scalar_sequence_mapping_field(field: ModelField) -> bool:
+        return field_annotation_is_scalar_sequence_mapping(field.field_info.annotation)
+
+    def is_scalar_mapping_field(field: ModelField) -> bool:
+        return field_annotation_is_scalar_mapping(field.field_info.annotation)
+
     def is_bytes_field(field: ModelField) -> bool:
         return is_bytes_or_nonable_bytes_annotation(field.type_)
@@ -283,6 +288,7 @@ def create_body_model(
     from pydantic.fields import (  # type: ignore[attr-defined]
         SHAPE_FROZENSET,
         SHAPE_LIST,
+        SHAPE_MAPPING,
         SHAPE_SEQUENCE,
         SHAPE_SET,
         SHAPE_SINGLETON,
@@ -302,6 +308,9 @@ def create_body_model(
     from pydantic.fields import (  # type: ignore[no-redef, attr-defined]
         UndefinedType as UndefinedType,  # noqa: F401
     )
+    from pydantic.networks import (  # type: ignore[no-redef]
+        MultiHostDsn as MultiHostUrl,  # noqa: F401
+    )
     from pydantic.schema import (
         field_schema,
         get_flat_models_from_fields,
@@ -330,6 +339,7 @@ def create_body_model(
         SHAPE_SEQUENCE,
         SHAPE_TUPLE_ELLIPSIS,
     }
+
     sequence_shape_to_type = {
         SHAPE_LIST: list,
         SHAPE_SET: set,
@@ -338,6 +348,11 @@
         SHAPE_TUPLE_ELLIPSIS: list,
     }
 
+    mapping_shapes = {
+        SHAPE_MAPPING,
+    }
+    mapping_shapes_to_type = {SHAPE_MAPPING: Mapping}
+
     @dataclass
     class GenerateJsonSchema:  # type: ignore[no-redef]
         ref_template: str
@@ -405,6 +420,28 @@ def is_pv1_scalar_sequence_field(field: ModelField) -> bool:
             return True
         return False
 
+    def is_pv1_scalar_mapping_field(field: ModelField) -> bool:
+        if (field.shape in mapping_shapes) and not lenient_issubclass(  # type: ignore[attr-defined]
+            field.type_, BaseModel
+        ):
+            if field.sub_fields is not None:  # type: ignore[attr-defined]
+                for sub_field in field.sub_fields:  # type: ignore[attr-defined]
+                    if not is_scalar_field(sub_field):
+                        return False
+            return True
+        return False
+
+    def is_pv1_scalar_sequence_mapping_field(field: ModelField) -> bool:
+        if (field.shape in mapping_shapes) and not lenient_issubclass(  # type: ignore[attr-defined]
+            field.type_, BaseModel
+        ):
+            if field.sub_fields is not None:  # type: ignore[attr-defined]
+                for sub_field in field.sub_fields:  # type: ignore[attr-defined]
+                    if not is_scalar_sequence_field(sub_field):
+                        return False
+            return True
+        return False
+
     def _normalize_errors(errors: Sequence[Any]) -> List[Dict[str, Any]]:
         use_errors: List[Any] = []
         for error in errors:
@@ -438,7 +475,6 @@ def get_schema_from_model_field(
         field_mapping: Dict[
             Tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue
         ],
-        separate_input_output_schemas: bool = True,
     ) -> Dict[str, Any]:
         # This expects that GenerateJsonSchema was already used to generate the definitions
         return field_schema(  # type: ignore[no-any-return]
@@ -454,7 +490,6 @@ def get_definitions(
         fields: List[ModelField],
         schema_generator: GenerateJsonSchema,
         model_name_map: ModelNameMap,
-        separate_input_output_schemas: bool = True,
     ) -> Tuple[
         Dict[
             Tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue
@@ -475,6 +510,12 @@ def is_sequence_field(field: ModelField) -> bool:
     def is_scalar_sequence_field(field: ModelField) -> bool:
         return is_pv1_scalar_sequence_field(field)
 
+    def is_scalar_sequence_mapping_field(field: ModelField) -> bool:
+        return is_pv1_scalar_sequence_mapping_field(field)
+
+    def is_scalar_mapping_field(field: ModelField) -> bool:
+        return is_pv1_scalar_mapping_field(field)
+
     def is_bytes_field(field: ModelField) -> bool:
         return lenient_issubclass(field.type_, bytes)
@@ -524,14 +565,27 @@ def field_annotation_is_sequence(annotation: Union[Type[Any], None]) -> bool:
     )
 
 
+def _annotation_is_mapping(annotation: Union[Type[Any], None]) -> bool:
+    if lenient_issubclass(annotation, (str, bytes)):
+        return False
+    return lenient_issubclass(annotation, mapping_types)
+
+
+def field_annotation_is_mapping(annotation: Union[Type[Any], None]) -> bool:
+    return _annotation_is_mapping(annotation) or _annotation_is_mapping(
+        get_origin(annotation)
+    )
+
+
 def value_is_sequence(value: Any) -> bool:
     return isinstance(value, sequence_types) and not isinstance(value, (str, bytes))  # type: ignore[arg-type]
 
 
 def _annotation_is_complex(annotation: Union[Type[Any], None]) -> bool:
     return (
-        lenient_issubclass(annotation, (BaseModel, Mapping, UploadFile))
+        lenient_issubclass(annotation, (BaseModel, UploadFile))
         or _annotation_is_sequence(annotation)
+        or _annotation_is_mapping(annotation)
         or is_dataclass(annotation)
     )
 
@@ -562,15 +616,29 @@ def field_annotation_is_scalar_sequence(annotation: Union[Type[Any], None]) -> bool:
             if field_annotation_is_scalar_sequence(arg):
                 at_least_one_scalar_sequence = True
                 continue
             elif not field_annotation_is_scalar(arg):
                 return False
         return at_least_one_scalar_sequence
     return field_annotation_is_sequence(annotation) and all(
         field_annotation_is_scalar(sub_annotation)
         for sub_annotation in get_args(annotation)
     )
 
 
+def field_annotation_is_scalar_mapping(annotation: Union[Type[Any], None]) -> bool:
+    return field_annotation_is_mapping(annotation) and all(
+        field_annotation_is_scalar(sub_annotation)
+        for sub_annotation in get_args(annotation)
+    )
+
+
+def field_annotation_is_scalar_sequence_mapping(
+    annotation: Union[Type[Any], None]
+) -> bool:
+    return field_annotation_is_mapping(annotation) and all(
+        field_annotation_is_scalar_sequence(sub_annotation)
+        for sub_annotation in get_args(annotation)[1:]
+    )
+
+
 def is_bytes_or_nonable_bytes_annotation(annotation: Any) -> bool:
     if lenient_issubclass(annotation, bytes):
         return True
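Taken together, the _compat.py changes add one classification layer that both Pydantic branches expose under the same names (is_scalar_mapping_field and is_scalar_sequence_mapping_field), backed by the module-level field_annotation_is_* helpers. A hedged illustration of how those helpers classify annotations; the example annotations are ours, the function names come from the diff above:

    from typing import Dict, List

    # Private helpers introduced by this commit; imported here only for illustration.
    from fastapi._compat import (
        field_annotation_is_scalar_mapping,
        field_annotation_is_scalar_sequence_mapping,
    )

    assert field_annotation_is_scalar_mapping(Dict[str, int])  # key and value are scalars
    assert not field_annotation_is_scalar_mapping(Dict[str, List[int]])  # value is a sequence
    assert field_annotation_is_scalar_sequence_mapping(Dict[str, List[int]])  # value args are scalar sequences
    assert not field_annotation_is_scalar_sequence_mapping(Dict[str, int])
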
47 changes: 47 additions & 0 deletions fastapi/dependencies/utils.py
@@ -1,4 +1,5 @@
 import inspect
+from collections import defaultdict
 from contextlib import contextmanager
 from copy import deepcopy
 from typing import (
@@ -35,7 +36,9 @@
     is_bytes_field,
     is_bytes_sequence_field,
     is_scalar_field,
+    is_scalar_mapping_field,
     is_scalar_sequence_field,
+    is_scalar_sequence_mapping_field,
     is_sequence_field,
     is_uploadfile_or_nonable_uploadfile_annotation,
     is_uploadfile_sequence_annotation,
@@ -450,6 +453,11 @@ def is_body_param(*, param_field: ModelField, is_path_param: bool) -> bool:
         param_field.field_info, (params.Query, params.Header)
     ) and is_scalar_sequence_field(param_field):
         return False
+    elif isinstance(param_field.field_info, params.Query) and (
+        is_scalar_sequence_mapping_field(param_field)
+        or is_scalar_mapping_field(param_field)
+    ):
+        return False
     else:
         assert isinstance(
             param_field.field_info, params.Body
@@ -633,6 +641,10 @@ async def solve_dependencies(
     return values, errors, background_tasks, response, dependency_cache
 
 
+class Marker:
+    pass
+
+
 def request_params_to_args(
     required_params: Sequence[ModelField],
     received_params: Union[Mapping[str, Any], QueryParams, Headers],
@@ -644,6 +656,16 @@
             received_params, (QueryParams, Headers)
         ):
             value = received_params.getlist(field.alias) or field.default
+        elif is_scalar_mapping_field(field) and isinstance(
+            received_params, QueryParams
+        ):
+            value = dict(received_params.multi_items()) or field.default
+        elif is_scalar_sequence_mapping_field(field) and isinstance(
+            received_params, QueryParams
+        ):
+            value = defaultdict(list)
+            for k, v in received_params.multi_items():
+                value[k].append(v)
         else:
             value = received_params.get(field.alias)
         field_info = field.field_info
@@ -660,6 +682,31 @@
         v_, errors_ = field.validate(value, values, loc=loc)
         if isinstance(errors_, ErrorWrapper):
             errors.append(errors_)
+        elif (
+            isinstance(errors_, list)
+            and is_scalar_sequence_mapping_field(field)
+            and isinstance(received_params, QueryParams)
+        ):
+            new_errors = _regenerate_error_with_loc(errors=errors_, loc_prefix=())
+            # remove all invalid parameters
+            marker = Marker()
+            for _, _, key, index in [error["loc"] for error in new_errors]:
+                value[key][index] = marker
+            for key in value:
+                value[key] = [x for x in value[key] if x != marker]
+            v_, _ = field.validate(value, values, loc=loc)
+            values[field.name] = v_
+        elif (
+            isinstance(errors_, list)
+            and is_scalar_mapping_field(field)
+            and isinstance(received_params, QueryParams)
+        ):
+            new_errors = _regenerate_error_with_loc(errors=errors_, loc_prefix=())
+            # remove all invalid parameters
+            for _, _, key in [error["loc"] for error in new_errors]:
+                value.pop(key)
+            v_, _ = field.validate(value, values, loc=loc)
+            values[field.name] = v_
         elif isinstance(errors_, list):
             new_errors = _regenerate_error_with_loc(errors=errors_, loc_prefix=())
             errors.extend(new_errors)
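
The two new QueryParams branches above can be previewed in isolation. A short sketch of the aggregation step, using a query string of our own choosing (QueryParams and multi_items are real Starlette APIs; the data is illustrative):

    from collections import defaultdict

    from starlette.datastructures import QueryParams

    received_params = QueryParams("color=red&color=blue&size=xl")

    # Scalar mapping branch: dict() keeps one value per key (the last one wins).
    assert dict(received_params.multi_items()) == {"color": "blue", "size": "xl"}

    # Scalar-sequence mapping branch: repeated keys are grouped into lists,
    # mirroring the defaultdict loop in request_params_to_args.
    grouped = defaultdict(list)
    for k, v in received_params.multi_items():
        grouped[k].append(v)
    assert grouped == {"color": ["red", "blue"], "size": ["xl"]}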