From 4be059eded48ff05a73c49b81bfb50e66479ea4b Mon Sep 17 00:00:00 2001 From: Jean-Christophe Bianic Date: Tue, 24 Jan 2023 17:09:55 +0100 Subject: [PATCH 01/10] :hammer: Rename test file --- tests/{test_request_parsing.py => test_inbound_handler.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename tests/{test_request_parsing.py => test_inbound_handler.py} (100%) diff --git a/tests/test_request_parsing.py b/tests/test_inbound_handler.py similarity index 100% rename from tests/test_request_parsing.py rename to tests/test_inbound_handler.py From e7d42d94aaa4793b0224aff7121c166cc7156844 Mon Sep 17 00:00:00 2001 From: Jean-Christophe Bianic Date: Tue, 31 Jan 2023 17:21:37 +0100 Subject: [PATCH 02/10] :package: Introducing ViewParams --- flask_jeroboam/__init__.py | 6 + flask_jeroboam/view_params/__init__.py | 11 + flask_jeroboam/view_params/functions.py | 336 +++++++++++++++++++++++ flask_jeroboam/view_params/parameters.py | 191 +++++++++++++ flask_jeroboam/view_params/solved.py | 157 +++++++++++ 5 files changed, 701 insertions(+) create mode 100644 flask_jeroboam/view_params/__init__.py create mode 100644 flask_jeroboam/view_params/functions.py create mode 100644 flask_jeroboam/view_params/parameters.py create mode 100644 flask_jeroboam/view_params/solved.py diff --git a/flask_jeroboam/__init__.py b/flask_jeroboam/__init__.py index 3b4fd13..c602c3a 100644 --- a/flask_jeroboam/__init__.py +++ b/flask_jeroboam/__init__.py @@ -1,2 +1,8 @@ from .jeroboam import Jeroboam from .jeroboam import JeroboamBlueprint +from .view_params.functions import Body +from .view_params.functions import File +from .view_params.functions import Form +from .view_params.functions import Header +from .view_params.functions import Path +from .view_params.functions import Query diff --git a/flask_jeroboam/view_params/__init__.py b/flask_jeroboam/view_params/__init__.py new file mode 100644 index 0000000..698d659 --- /dev/null +++ b/flask_jeroboam/view_params/__init__.py @@ -0,0 +1,11 @@ +from .functions import Body +from .functions import Cookie +from .functions import File +from .functions import Form +from .functions import Header +from .functions import Path +from .functions import Query +from .parameters import ParamLocation +from .parameters import ViewParameter +from .parameters import get_parameter_class +from .solved import SolvedParameter diff --git a/flask_jeroboam/view_params/functions.py b/flask_jeroboam/view_params/functions.py new file mode 100644 index 0000000..8947ad4 --- /dev/null +++ b/flask_jeroboam/view_params/functions.py @@ -0,0 +1,336 @@ +"""Function to declare the Type of Parameters. + +This functions are used to declare the parameters of the view functions. +By annotating the return value with Any, we make sure that the code editor +don't complain too much about assigning a default value of type ViewParameter +to a parameter that have been annotated with a pydantic-compatible type... + +Credits: This module is essentially a fork from the params module of FlaskAPI. +""" + +from typing import Any +from typing import Dict +from typing import Optional + +from pydantic.fields import Undefined + +from .parameters import BodyParameter +from .parameters import CookieParameter +from .parameters import FileParameter +from .parameters import FormParameter +from .parameters import HeaderParameter +from .parameters import ParamLocation +from .parameters import PathParameter +from .parameters import QueryParameter + + +def Path( # noqa:N802 + default: Any = Undefined, + *, + alias: Optional[str] = None, + title: Optional[str] = None, + description: Optional[str] = None, + gt: Optional[float] = None, + ge: Optional[float] = None, + lt: Optional[float] = None, + le: Optional[float] = None, + min_length: Optional[int] = None, + max_length: Optional[int] = None, + regex: Optional[str] = None, + example: Any = Undefined, + examples: Optional[Dict[str, Any]] = None, + deprecated: Optional[bool] = None, + include_in_schema: bool = True, + **extra: Any, +) -> Any: + """Declare A Path parameter.""" + return PathParameter( + default=default, + alias=alias, + title=title, + description=description, + gt=gt, + ge=ge, + lt=lt, + le=le, + min_length=min_length, + max_length=max_length, + regex=regex, + example=example, + examples=examples, + deprecated=deprecated, + include_in_schema=include_in_schema, + **extra, + ) + + +def Query( # noqa:N802 + default: Any = Undefined, + *, + alias: Optional[str] = None, + title: Optional[str] = None, + description: Optional[str] = None, + gt: Optional[float] = None, + ge: Optional[float] = None, + lt: Optional[float] = None, + le: Optional[float] = None, + min_length: Optional[int] = None, + max_length: Optional[int] = None, + regex: Optional[str] = None, + example: Any = Undefined, + examples: Optional[Dict[str, Any]] = None, + deprecated: Optional[bool] = None, + include_in_schema: bool = True, + **extra: Any, +) -> Any: + """Declare A Query parameter.""" + return QueryParameter( + default=default, + alias=alias, + title=title, + description=description, + gt=gt, + ge=ge, + lt=lt, + le=le, + min_length=min_length, + max_length=max_length, + regex=regex, + example=example, + examples=examples, + deprecated=deprecated, + include_in_schema=include_in_schema, + **extra, + ) + + +def Header( # noqa:N802 + default: Any = Undefined, + *, + alias: Optional[str] = None, + title: Optional[str] = None, + description: Optional[str] = None, + gt: Optional[float] = None, + ge: Optional[float] = None, + lt: Optional[float] = None, + le: Optional[float] = None, + min_length: Optional[int] = None, + max_length: Optional[int] = None, + regex: Optional[str] = None, + example: Any = Undefined, + examples: Optional[Dict[str, Any]] = None, + deprecated: Optional[bool] = None, + include_in_schema: bool = True, + convert_underscores: bool = True, # for headers + **extra: Any, +) -> Any: + """Declare A Header parameter.""" + return HeaderParameter( + default=default, + alias=alias, + title=title, + description=description, + gt=gt, + ge=ge, + lt=lt, + le=le, + min_length=min_length, + max_length=max_length, + regex=regex, + example=example, + examples=examples, + deprecated=deprecated, + include_in_schema=include_in_schema, + convert_underscores=convert_underscores, + **extra, + ) + + +def Cookie( # noqa:N802 + default: Any = Undefined, + *, + alias: Optional[str] = None, + title: Optional[str] = None, + description: Optional[str] = None, + gt: Optional[float] = None, + ge: Optional[float] = None, + lt: Optional[float] = None, + le: Optional[float] = None, + min_length: Optional[int] = None, + max_length: Optional[int] = None, + regex: Optional[str] = None, + example: Any = Undefined, + examples: Optional[Dict[str, Any]] = None, + deprecated: Optional[bool] = None, + include_in_schema: bool = True, + **extra: Any, +) -> Any: + """Declare A Cookie parameter.""" + return CookieParameter( + default=default, + alias=alias, + title=title, + description=description, + gt=gt, + ge=ge, + lt=lt, + le=le, + min_length=min_length, + max_length=max_length, + regex=regex, + example=example, + examples=examples, + deprecated=deprecated, + include_in_schema=include_in_schema, + **extra, + ) + + +def Body( # noqa:N802 + default: Any = Undefined, + *, + alias: Optional[str] = None, + title: Optional[str] = None, + description: Optional[str] = None, + gt: Optional[float] = None, + ge: Optional[float] = None, + lt: Optional[float] = None, + le: Optional[float] = None, + min_length: Optional[int] = None, + max_length: Optional[int] = None, + regex: Optional[str] = None, + example: Any = Undefined, + examples: Optional[Dict[str, Any]] = None, + deprecated: Optional[bool] = None, + include_in_schema: bool = True, + location: ParamLocation = ParamLocation.query, # for all + required: bool = False, + convert_underscores: bool = True, # for headers + embed: bool = True, # for body + media_type: str = "application/json", + **extra: Any, +) -> Any: + """Declare A Body parameter.""" + return BodyParameter( + default=default, + alias=alias, + title=title, + description=description, + gt=gt, + ge=ge, + lt=lt, + le=le, + min_length=min_length, + max_length=max_length, + regex=regex, + example=example, + examples=examples, + deprecated=deprecated, + include_in_schema=include_in_schema, + location=location, + required=required, + convert_underscores=convert_underscores, + embed=embed, + media_type=media_type, + **extra, + ) + + +def Form( # noqa: N802 + default: Any = Undefined, + *, + alias: Optional[str] = None, + title: Optional[str] = None, + description: Optional[str] = None, + gt: Optional[float] = None, + ge: Optional[float] = None, + lt: Optional[float] = None, + le: Optional[float] = None, + min_length: Optional[int] = None, + max_length: Optional[int] = None, + regex: Optional[str] = None, + example: Any = Undefined, + examples: Optional[Dict[str, Any]] = None, + deprecated: Optional[bool] = None, + include_in_schema: bool = True, + location: ParamLocation = ParamLocation.query, # for all + required: bool = False, + convert_underscores: bool = True, # for headers + embed: bool = False, # for body + media_type: str = "application/json", + **extra: Any, +) -> Any: + """Declare A Form parameter.""" + return FormParameter( + default=default, + alias=alias, + title=title, + description=description, + gt=gt, + ge=ge, + lt=lt, + le=le, + min_length=min_length, + max_length=max_length, + regex=regex, + example=example, + examples=examples, + deprecated=deprecated, + include_in_schema=include_in_schema, + location=location, + required=required, + convert_underscores=convert_underscores, + embed=embed, + media_type=media_type, + **extra, + ) + + +def File( # noqa: N802 + default: Any = Undefined, + *, + alias: Optional[str] = None, + title: Optional[str] = None, + description: Optional[str] = None, + gt: Optional[float] = None, + ge: Optional[float] = None, + lt: Optional[float] = None, + le: Optional[float] = None, + min_length: Optional[int] = None, + max_length: Optional[int] = None, + regex: Optional[str] = None, + example: Any = Undefined, + examples: Optional[Dict[str, Any]] = None, + deprecated: Optional[bool] = None, + include_in_schema: bool = True, + location: ParamLocation = ParamLocation.query, # for all + required: bool = False, + convert_underscores: bool = True, # for headers + embed: bool = True, # for body + media_type: str = "application/json", + **extra: Any, +) -> Any: + """Declare A File parameter.""" + return FileParameter( + default=default, + alias=alias, + title=title, + description=description, + gt=gt, + ge=ge, + lt=lt, + le=le, + min_length=min_length, + max_length=max_length, + regex=regex, + example=example, + examples=examples, + deprecated=deprecated, + include_in_schema=include_in_schema, + location=location, + required=required, + convert_underscores=convert_underscores, + embed=embed, + media_type=media_type, + **extra, + ) diff --git a/flask_jeroboam/view_params/parameters.py b/flask_jeroboam/view_params/parameters.py new file mode 100644 index 0000000..422b78f --- /dev/null +++ b/flask_jeroboam/view_params/parameters.py @@ -0,0 +1,191 @@ +"""View Parameters. + +Subclasses of pydantic.fields.FieldInfo that are used to define +localised fields with some extra information. +""" + +from enum import Enum +from typing import Any +from typing import Type + +from pydantic.fields import FieldInfo +from pydantic.fields import Undefined + + +class ParamLocation(Enum): + """Enum for the possible source location of a view_function parameter.""" + + query = "query" + header = "header" + path = "path" + cookie = "cookie" + body = "body" + form = "form" + file = "file" + + +class ViewParameter(FieldInfo): + """Base class for all View parameters.""" + + location: ParamLocation + + def __init__( + self, + default: Any = Undefined, + **kwargs: Any, + ): + self.example = kwargs.pop("example", Undefined) + self.examples = kwargs.pop("examples", None) + self.embed = kwargs.pop("embed", False) + super().__init__( + default=default, + **kwargs, + ) + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({self.default})" + + @property + def in_body(self): + """Is the parameter located in the body?""" + return self.location in { + ParamLocation.body, + ParamLocation.form, + ParamLocation.file, + } + + +class NonBodyParameter(ViewParameter): + """A Parameter that is not located in the body.""" + + def __init__( + self, + default: Any = Undefined, + **kwargs: Any, + ): + self.deprecated = kwargs.pop("deprecated", None) + self.include_in_schema = kwargs.pop("include_in_schema", True) + super().__init__( + default=default, + **kwargs, + ) + + +class QueryParameter(NonBodyParameter): + """A Parameter found in the Query String.""" + + location = ParamLocation.query + + +class PathParameter(NonBodyParameter): + """A Parameter found in Path.""" + + location = ParamLocation.path + + def __init__( + self, + default: Any = Undefined, + **kwargs: Any, + ): + self.required = True + super().__init__( + default=..., + **kwargs, + ) + + +class HeaderParameter(NonBodyParameter): + """A Header parameter.""" + + location = ParamLocation.header + + def __init__( + self, + default: Any = Undefined, + **kwargs: Any, + ): + self.convert_underscores = kwargs.pop("convert_underscores", True) + super().__init__( + default=default, + **kwargs, + ) + + +class CookieParameter(NonBodyParameter): + """A Parameter located in Cookies.""" + + location = ParamLocation.cookie + + +class BodyParameter(ViewParameter): + """A Parameter located in Body. + + Body Parameters can be embedded. which means that they must + be accessed by their name at the root of the body. + They also have a Media/Type that varies between body, form and file. + """ + + location = ParamLocation.body + + def __init__( + self, + default: Any = Undefined, + **kwargs: Any, + ): + self.embed = kwargs.get("embed", False) + self.media_type = kwargs.pop("media_type", "application/json") + super().__init__( + default=default, + **kwargs, + ) + + +class FormParameter(BodyParameter): + """A Parameter located in Body.""" + + location = ParamLocation.form + + def __init__( + self, + default: Any = Undefined, + **kwargs: Any, + ): + self.media_type = kwargs.pop("media_type", "application/x-www-form-urlencoded") + embed = kwargs.pop("embed", True) + super().__init__( + default=default, + embed=embed, + **kwargs, + ) + + +class FileParameter(FormParameter): + """A Parameter located in Body.""" + + location = ParamLocation.file + + def __init__( + self, + default: Any = Undefined, + **kwargs: Any, + ): + self.media_type = kwargs.pop("media_type", "multipart/form-data") + embed = kwargs.pop("embed", False) + super().__init__( + default=default, + embed=embed, + **kwargs, + ) + + +def get_parameter_class(location: ParamLocation) -> Type[ViewParameter]: + """Get the Parameter class for a given location.""" + return { + ParamLocation.query: QueryParameter, + ParamLocation.header: HeaderParameter, + ParamLocation.path: PathParameter, + ParamLocation.cookie: CookieParameter, + ParamLocation.body: BodyParameter, + ParamLocation.form: FormParameter, + ParamLocation.file: FileParameter, + }[location] diff --git a/flask_jeroboam/view_params/solved.py b/flask_jeroboam/view_params/solved.py new file mode 100644 index 0000000..9c94dad --- /dev/null +++ b/flask_jeroboam/view_params/solved.py @@ -0,0 +1,157 @@ +"""View params for solved problems.""" +import re +from copy import deepcopy +from typing import Any +from typing import Dict +from typing import List +from typing import Optional +from typing import Type +from typing import Union + +from flask import current_app +from flask import request +from pydantic import BaseConfig +from pydantic.error_wrappers import ErrorWrapper +from pydantic.errors import MissingError +from pydantic.fields import FieldInfo +from pydantic.fields import ModelField +from werkzeug.datastructures import MultiDict + +from flask_jeroboam.utils import is_scalar_sequence_field + +from .parameters import ParamLocation +from .parameters import ViewParameter + + +empty_field_info = FieldInfo() + + +class SolvedParameter(ModelField): + """A Parameter that have been solved, ready to validate data.""" + + def __init__( + self, + *, + name: str, + type_: type, + required: bool = False, + view_param: Optional[ViewParameter] = None, + class_validators: Optional[Dict] = None, + model_config: Type[BaseConfig] = BaseConfig, + field_info: FieldInfo = empty_field_info, + **kwargs, + ): + self.name = name + self.location: Optional[ParamLocation] = getattr(view_param, "location", None) + if self.location == ParamLocation.file: + BaseConfig.arbitrary_types_allowed = True + self.required = required + self.embed = getattr(view_param, "embed", None) + self.in_body = getattr(view_param, "in_body", None) + default = getattr(view_param, "default", field_info.default) + class_validators = class_validators or {} + if getattr(view_param, "convert_underscores", False): + self.alias = re.sub( + r"_(\w)", lambda x: f"-{x.group(1).upper()}", self.name.capitalize() + ) + kwargs["alias"] = self.alias + else: + kwargs["alias"] = kwargs.get("alias", getattr(view_param, "alias", None)) + super().__init__( + name=name, + type_=type_, + class_validators={}, + default=default, + required=required, + model_config=model_config, + field_info=view_param, + **kwargs, + ) + + def validate_request(self): + """Validate the request.""" + values = {} + errors = [] + + inbound_values = self._get_values() + if inbound_values is None: + if self.required: + errors.append( + ErrorWrapper(MissingError(), loc=(self.location.value, self.alias)) + ) + return values, errors + else: + values = {self.name: deepcopy(self.default)} + return values, errors + + values_, errors_ = self.validate( + inbound_values, values, loc=(self.location.value, self.alias) + ) + if isinstance(errors_, ErrorWrapper): + errors.append(errors_) + else: + values.update({self.name: values_}) + return values, errors + + def _get_values(self) -> Union[dict, Optional[str], List[Any]]: + """Get the values from the request.""" + if self.in_body: + return self._get_values_from_body() + else: + return self._get_values_from_request() + + def _get_values_from_body(self): + """Get the values from the request body.""" + if self.location == ParamLocation.form: + source = request.form + elif self.location == ParamLocation.file: + source = request.files + else: + source = request.json + if self.embed: + values = source.get(self.alias or self.name) + else: + values = source + return values + + def _get_values_from_request(self) -> Union[dict, Optional[str], List[Any]]: + """Get the values from the request. + + # TODO: Gestion des alias de fields. + # TODO: Gestion des default empty et des valeurs manquantes. + # Est-ce qu'on gère le embed dans les QueryParams ? + """ + values: Union[dict, Optional[str], List[Any]] = {} + source: MultiDict = MultiDict() + # Decide on the source of the values + if self.location == ParamLocation.query: + source = request.args + elif self.location == ParamLocation.path: + source = MultiDict(request.view_args) + elif self.location == ParamLocation.header: + source = MultiDict(request.headers) + elif self.location == ParamLocation.cookie: + source = request.cookies + else: + raise ValueError("Unknown location") + + if hasattr(self.type_, "__fields__"): + assert isinstance(values, dict) # noqa: S101 + for field_name, field in self.type_.__fields__.items(): + if is_scalar_sequence_field(field): + values[field_name] = source.getlist(field.alias or field_name) + else: + values[field_name] = source.get(field.alias or field_name) + if values[field_name] is None and getattr( + current_app, "query_string_key_transformer", False + ): + values_ = current_app.query_string_key_transformer( # type: ignore + current_app, source.to_dict() + ) + values[field_name] = values_.get(field.alias or field_name) + else: + if is_scalar_sequence_field(self): + values = source.getlist(self.alias or self.name) + else: + values = source.get(self.alias or self.name) + return values From 1b912f44da6ca30f7a032d809f2dd9b8d944355c Mon Sep 17 00:00:00 2001 From: Jean-Christophe Bianic Date: Tue, 31 Jan 2023 17:23:32 +0100 Subject: [PATCH 03/10] :hammer: InboundHandler as a Class --- flask_jeroboam/_inboundhandler.py | 260 ++++++++++++++++++--------- flask_jeroboam/view.py | 4 +- flask_jeroboam/view_params/solved.py | 1 + 3 files changed, 182 insertions(+), 83 deletions(-) diff --git a/flask_jeroboam/_inboundhandler.py b/flask_jeroboam/_inboundhandler.py index a7880d0..9aedd57 100644 --- a/flask_jeroboam/_inboundhandler.py +++ b/flask_jeroboam/_inboundhandler.py @@ -1,19 +1,29 @@ -import json +import inspect import re import typing as t from enum import Enum from functools import wraps +from typing import Any from typing import Callable -from typing import Type - -from flask import request -from flask.globals import current_app -from pydantic import BaseModel +from typing import Dict +from typing import List +from typing import Optional +from typing import Set +from typing import Tuple +from typing import Union + +from pydantic.error_wrappers import ErrorWrapper +from pydantic.fields import Undefined +from pydantic.schema import get_annotation_from_field_info from typing_extensions import ParamSpec from flask_jeroboam.exceptions import InvalidRequest from flask_jeroboam.typing import JeroboamResponseReturnValue from flask_jeroboam.typing import JeroboamRouteCallable +from flask_jeroboam.view_params import ParamLocation +from flask_jeroboam.view_params import SolvedParameter +from flask_jeroboam.view_params import ViewParameter +from flask_jeroboam.view_params.parameters import get_parameter_class from .utils import get_typed_signature @@ -44,89 +54,177 @@ class InboundHandler: view function. It is also responsible for raising an InvalidRequest exception. The InboundHandler will only be called if the view function has type-annotated parameters. + + + + #TODO: Get Better at laying Out Levels of the Algorythm. Most Likely in the View + # class. + # And Moving away from the decorator scheme which feels obstrusive sometimes. """ def __init__(self, view_func: Callable, main_http_verb: str, rule: str): - self.typed_params = get_typed_signature(view_func) self.main_http_verb = main_http_verb + self.default_param_location = self._solve_default_params_location( + main_http_verb + ) self.rule = rule - - def __bool__(self) -> bool: - return bool(self.typed_params.parameters) - - def __call__(self, func: JeroboamRouteCallable) -> JeroboamRouteCallable: + self.path_param_names = set(re.findall("<(?:.*:)?(.*?)>", rule)) + self.query_params: List[SolvedParameter] = [] + self.path_params: List[SolvedParameter] = [] + self.header_params: List[SolvedParameter] = [] + self.cookie_params: List[SolvedParameter] = [] + self.body_params: List[SolvedParameter] = [] + self.form_params: List[SolvedParameter] = [] + self.file_params: List[SolvedParameter] = [] + self.other_params: List[SolvedParameter] = [] + self.locations_to_visit: Set[ParamLocation] = set() + self._solve_params(view_func) + self._check_compliance() + + @staticmethod + def _solve_default_params_location( + main_http_verb: str, + ) -> ParamLocation: + """Return the default FieldInfo for the InboundHandler.""" + if main_http_verb in ("POST", "PUT"): + return ParamLocation.body + elif main_http_verb == "GET": + return ParamLocation.query + else: + return ParamLocation.path + + @property + def is_valid(self) -> bool: + """Check if the InboundHandler has any Configured Parameters.""" + return len(self.locations_to_visit) > 0 + + def add_inbound_handling_to( + self, view_func: JeroboamRouteCallable + ) -> JeroboamRouteCallable: """It injects inbound parsed and validated data into the view function.""" - @wraps(func) + @wraps(view_func) def wrapper(*args, **kwargs) -> JeroboamResponseReturnValue: - location = self._parse_incoming_request_data() - kwargs = self._validate_inbound_data(location, kwargs) - return current_app.ensure_sync(func)(*args, **kwargs) + inbound_values, errors = self._parse_and_validate_inbound_data(**kwargs) + if errors: + raise InvalidRequest([errors]) + return view_func(*args, **inbound_values) return wrapper - def _parse_incoming_request_data(self) -> dict: - """Getting the Data out of the Request Object.""" - if self.main_http_verb == MethodEnum.GET: - location = dict(request.args.lists()) - location = self._rename_query_params_keys(location, pattern) - elif self.main_http_verb == MethodEnum.POST: - location = dict(request.form.lists()) - location = self._rename_query_params_keys(location, pattern) - if request.data: - # TODO: on.3.8.drop location |= dict(json.loads(request.data)) - location.update(dict(json.loads(request.data))) - # TODO: on.3.8.drop location |= dict(request.files) # type: ignore - location.update(dict(request.files)) # type: ignore - else: # pragma: no cover - # TODO: Statement cannot be reached at this point. - location = {} - return location - - def _validate_inbound_data(self, location, kwargs) -> dict: - """Getting the Data out of the Request Object.""" - for arg_name, typed_param in self.typed_params.parameters.items(): - if getattr(typed_param.annotation, "__origin__", None) == t.Union: - kwargs[arg_name] = self._validate_input( - typed_param.annotation.__args__[0], **location - ) - elif issubclass(typed_param.annotation, BaseModel): - kwargs[arg_name] = self._validate_input( - typed_param.annotation, **location - ) - elif arg_name not in self.rule: - kwargs[arg_name] = self._simple_validate_input( - typed_param.annotation, location, arg_name - ) - return kwargs - - def _validate_input(self, model: Type[BaseModel], **kwargs: ParamSpec) -> BaseModel: - try: - return model(**kwargs) - except ValueError as e: - raise InvalidRequest(msg=str(e)) from e - - def _simple_validate_input(self, type_: T, payload: dict, key: str) -> T: - try: - return type_(payload.get(key, None)) - except ValueError as e: - raise InvalidRequest(msg=str(e)) from e - - def _rename_query_params_keys(self, inbound_dict: dict, pattern: str) -> dict: - """Rename keys in a dictionary.""" - renamings = [] - for key, value in inbound_dict.items(): - match = re.match(pattern, key) - if len(value) == 1 and match is None: - inbound_dict[key] = value[0] - elif match is not None: - new_key = f"{match[1]}[]" - new_value = {match[2]: value[0]} - renamings.append((key, new_key, new_value)) - for key, new_key, new_value in renamings: - if new_key not in inbound_dict: - inbound_dict[new_key] = [new_value] - else: - inbound_dict[new_key].append(new_value) - del inbound_dict[key] - return inbound_dict + def _check_compliance(self): + """Will warn the user if their view function does something a bit off.""" + if len(self.form_params + self.file_params) > 0 and self.main_http_verb not in { + "POST", + "PUT", + "PATCH", + }: + import warnings + + warnings.warn( + f"You have defined Form or File Parameters on a " + f"{self.main_http_verb} request. " + "This is not supported by Flask:" + "https://flask.palletsprojects.com/en/2.2.x/api/#incoming-request-data", + UserWarning, + ) + + def _solve_params(self, view_func: Callable): + """Registering the Parameters of the View Function.""" + signature = get_typed_signature(view_func) + for parameter_name, parameter in signature.parameters.items(): + solved_param = self._solve_view_function_parameter( + param_name=parameter_name, param=parameter + ) + # Check if Param is in Path (not needed for now) + self._register_view_parameter(solved_param) + + def _solve_view_function_parameter( + self, + param_name: str, + param: inspect.Parameter, + force_location: Optional[ParamLocation] = None, + ignore_default: bool = False, + ) -> SolvedParameter: + """Analyse the param and its annotation to solve its configiration. + + At the end of this process, we want to know the following things: + - What is its location? + - What is its type/annotation? + - Is it a scalar or a sequence? + - Is it required and/or has a default value? + """ + # Solving Location + if param_name in self.path_param_names: + solved_location = ParamLocation.path + else: + solved_location = getattr( + param.default, "location", force_location or self.default_param_location + ) + # Get the ViewParam + if isinstance(param.default, ViewParameter): + view_param = param.default + else: + param_class = get_parameter_class(solved_location) + view_param = param_class(default=param.default) + + # Solving Default Value + default_value: Any = getattr(param.default, "default", param.default) + if default_value == param.empty or ignore_default: + default_value = Undefined + + # Solving Required + required: bool = default_value is Undefined + + # Solving Type + annotation: Any = Any + if not param.annotation == param.empty: + annotation = param.annotation + annotation = get_annotation_from_field_info(annotation, view_param, param_name) + + return SolvedParameter( + name=param_name, + type_=annotation, + required=required, + view_param=view_param, + ) + + def _register_view_parameter(self, solved_parameter: SolvedParameter) -> None: + """Registering the Solved View parameters for the View Function. + + The registration will put the params in the right list + and add the location to the locations_to_visit set. + """ + assert solved_parameter.location is not None # noqa: S101 + self.locations_to_visit.add(solved_parameter.location) + { + ParamLocation.query: self.query_params, + ParamLocation.path: self.path_params, + ParamLocation.header: self.header_params, + ParamLocation.body: self.body_params, + ParamLocation.form: self.form_params, + ParamLocation.cookie: self.cookie_params, + ParamLocation.file: self.file_params, + }.get(solved_parameter.location, self.other_params).append(solved_parameter) + + def _parse_and_validate_inbound_data( + self, **kwargs + ) -> Tuple[Dict, Union[List, ErrorWrapper]]: + """Parse and Validate the request Inbound data.""" + errors = [] + values = {} + for location in self.locations_to_visit: + params = { + ParamLocation.query: self.query_params, + ParamLocation.path: self.path_params, + ParamLocation.header: self.header_params, + ParamLocation.body: self.body_params, + ParamLocation.form: self.form_params, + ParamLocation.cookie: self.cookie_params, + ParamLocation.file: self.file_params, + }.get(location, []) + for param in params: + values_, errors_ = param.validate_request() + errors.extend(errors_) + values.update(values_) + return values, errors diff --git a/flask_jeroboam/view.py b/flask_jeroboam/view.py index 56fdede..63d5ea4 100644 --- a/flask_jeroboam/view.py +++ b/flask_jeroboam/view.py @@ -62,8 +62,8 @@ def as_view(self) -> JeroboamRouteCallable: name = view_func.__name__ doc = view_func.__doc__ - if self.inbound_handler: - view_func = self.inbound_handler(view_func) + if self.inbound_handler.is_valid: + view_func = self.inbound_handler.add_inbound_handling_to(view_func) if self.outbound_handler.is_valid_handler(): view_func = self.outbound_handler.add_outbound_handling_to(view_func) diff --git a/flask_jeroboam/view_params/solved.py b/flask_jeroboam/view_params/solved.py index 9c94dad..c0e1877 100644 --- a/flask_jeroboam/view_params/solved.py +++ b/flask_jeroboam/view_params/solved.py @@ -102,6 +102,7 @@ def _get_values(self) -> Union[dict, Optional[str], List[Any]]: def _get_values_from_body(self): """Get the values from the request body.""" + source: Any = {} if self.location == ParamLocation.form: source = request.form elif self.location == ParamLocation.file: From 9ff32e87c1fd417a2c5ae57ce541fc6f206b4363 Mon Sep 17 00:00:00 2001 From: Jean-Christophe Bianic Date: Tue, 31 Jan 2023 17:24:38 +0100 Subject: [PATCH 04/10] :broom: Small Improvement and Utils (from FastAPI) --- flask_jeroboam/exceptions.py | 30 +++++++++---- flask_jeroboam/jeroboam.py | 9 ++++ flask_jeroboam/utils.py | 86 ++++++++++++++++++++++++++++++++++-- 3 files changed, 114 insertions(+), 11 deletions(-) diff --git a/flask_jeroboam/exceptions.py b/flask_jeroboam/exceptions.py index d867144..5cbf3e5 100644 --- a/flask_jeroboam/exceptions.py +++ b/flask_jeroboam/exceptions.py @@ -3,14 +3,24 @@ They are small wrappers around werkzeug HTTP exceptions that customize how the message is colllected and formatted. """ +from typing import Any from typing import Optional +from typing import Sequence from typing import Tuple +from typing import Type +from pydantic import BaseModel +from pydantic import ValidationError +from pydantic import create_model +from pydantic.error_wrappers import ErrorList from werkzeug.exceptions import BadRequest from werkzeug.exceptions import InternalServerError from werkzeug.exceptions import NotFound +RequestErrorModel: Type[BaseModel] = create_model("Request") + + class RessourceNotFound(NotFound): """A slightly modified version of Werkzeug's RessourceNotFound Exception.""" @@ -40,18 +50,16 @@ def handle(self) -> Tuple[str, int]: return str(self), 404 -class InvalidRequest(BadRequest): +class InvalidRequest(ValidationError, BadRequest): """A slightly modifiedversion of Werkzeug's BadRequest Exception.""" - def __init__(self, msg: Optional[str]): - self.msg = msg - - def __str__(self) -> str: - return f"BadRequest: {self.msg}" + def __init__(self, errors: Sequence[ErrorList], *, body: Any = None): + self.body = body + super().__init__(errors, RequestErrorModel) - def handle(self) -> Tuple[str, int]: + def handle(self) -> Tuple[dict, int]: """Handle the exception and return a message to the user.""" - return str(self), 400 + return {"detail": self.errors()}, 400 class ServerError(InternalServerError): @@ -78,3 +86,9 @@ class ResponseValidationError(ServerError): def __str__(self) -> str: return f"InternalServerError: {self.msg}" + + +class JeroboamError(Exception): + """Base Exception for Flask-Jeroboam.""" + + pass diff --git a/flask_jeroboam/jeroboam.py b/flask_jeroboam/jeroboam.py index f96ba8b..81163e5 100644 --- a/flask_jeroboam/jeroboam.py +++ b/flask_jeroboam/jeroboam.py @@ -2,15 +2,20 @@ Here we overide the route method of the Flask object to use our custom implementation. This allow us to introduce new functionality to the route registration process. + +TODO: A probably better way to override it is to override the url_rule_class +with a custom JeroboamRule Object """ from typing import Any from typing import Callable +from typing import Optional from flask import Flask from flask.blueprints import Blueprint from flask.scaffold import setupmethod from typing_extensions import TypeVar +from .responses import JSONResponse from .typing import JeroboamRouteCallable from .view import JeroboamView @@ -117,6 +122,10 @@ class Jeroboam(JeroboamScaffoldOverRide, Flask): # type:ignore route decorator. """ + response_class = JSONResponse + + query_string_key_transformer: Optional[Callable] = None + pass diff --git a/flask_jeroboam/utils.py b/flask_jeroboam/utils.py index 8b32e55..eaddae0 100644 --- a/flask_jeroboam/utils.py +++ b/flask_jeroboam/utils.py @@ -1,16 +1,42 @@ """Utility Functions for Flask-Jeroboam. -Credits: the three methods in this module get_typed_signature, -get_typed_annotation, get_typed_return_annotation are from -FASTApi Source Code https://github.com/tiangolo/fastapi +Credits: this is essentially a fork of FastAPI's own utils.py +Original Source Code at https://github.com/tiangolo/fastapi """ +import dataclasses import inspect +import re from typing import Any from typing import Callable from typing import Dict from typing import ForwardRef +from pydantic import BaseModel +from pydantic.fields import SHAPE_FROZENSET +from pydantic.fields import SHAPE_LIST +from pydantic.fields import SHAPE_SEQUENCE +from pydantic.fields import SHAPE_SET +from pydantic.fields import SHAPE_SINGLETON +from pydantic.fields import SHAPE_TUPLE +from pydantic.fields import SHAPE_TUPLE_ELLIPSIS +from pydantic.fields import ModelField from pydantic.typing import evaluate_forwardref +from pydantic.utils import lenient_issubclass + +from flask_jeroboam.view_params import ParamLocation +from flask_jeroboam.view_params import ViewParameter + + +sequence_shapes = { + SHAPE_LIST, + SHAPE_SET, + SHAPE_FROZENSET, + SHAPE_TUPLE, + SHAPE_SEQUENCE, + SHAPE_TUPLE_ELLIPSIS, +} +sequence_types = (list, set, tuple) +body_locations = {ParamLocation.body, ParamLocation.form, ParamLocation.file} def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature: @@ -47,3 +73,57 @@ def get_typed_return_annotation(call: Callable[..., Any]) -> Any: # pragma: no globalns = getattr(call, "__globals__", {}) return get_typed_annotation(annotation, globalns) + + +def is_scalar_field(field: ModelField) -> bool: + """Check if a field is a scalar field.""" + field_info = field.field_info + if not ( + field.shape == SHAPE_SINGLETON + and not lenient_issubclass(field.type_, BaseModel) + and not lenient_issubclass(field.type_, sequence_types + (dict,)) + and not dataclasses.is_dataclass(field.type_) + and not isinstance(field_info, ViewParameter) + and not getattr(field_info, "location", None) in body_locations + ): + return False + if field.sub_fields: # pragma: no cover + if not all(is_scalar_field(f) for f in field.sub_fields): + return False + return True + + +def is_scalar_sequence_field(field: ModelField) -> bool: + """Check if a field is a sequence field.""" + if (field.shape in sequence_shapes) and not lenient_issubclass( + field.type_, BaseModel + ): + if field.sub_fields is not None: # pragma: no cover + for sub_field in field.sub_fields: + if not is_scalar_field(sub_field): + return False + return True + if lenient_issubclass(field.type_, sequence_types): # pragma: no cover + return True + return False + + +def _rename_query_params_keys(self, inbound_dict: dict, pattern: str) -> dict: + """Rename keys in a dictionary. + + Probablement Obsolete. + """ + renamings = [] + for key, value in inbound_dict.items(): + match = re.match(pattern, key) + if match is not None: + new_key = f"{match[1]}[]" + new_value = {match[2]: value} + renamings.append((key, new_key, new_value)) + for key, new_key, new_value in renamings: + if new_key not in inbound_dict: + inbound_dict[new_key] = [new_value] + else: + inbound_dict[new_key].append(new_value) + del inbound_dict[key] + return inbound_dict From ff2974b2a0af0a308db573c4a72ee1457003b16d Mon Sep 17 00:00:00 2001 From: Jean-Christophe Bianic Date: Tue, 31 Jan 2023 17:32:54 +0100 Subject: [PATCH 05/10] :white_check_mark: Adding the Test Suite (Refactoring Needed) --- .flake8 | 2 +- docs/installation.rst | 2 +- flask_jeroboam/view_params/solved.py | 6 +- tests/conftest.py | 340 +++++++++++++++++ .../test_base_configuration.py | 12 + tests/inbound_handler/test_body_parameter.py | 68 ++++ .../inbound_handler/test_cookie_parameter.py | 38 ++ .../inbound_handler/test_header_parameter.py | 37 ++ tests/inbound_handler/test_path_operations.py | 355 ++++++++++++++++++ .../inbound_handler/test_query_operations.py | 69 ++++ tests/inbound_handler/test_sub_fields.py | 48 +++ tests/inbound_handler/test_warnings.py | 17 + tests/test_error_handling.py | 21 +- tests/test_inbound_handler.py | 51 ++- tests/test_outbound_handler.py | 9 +- tests/test_utils.py | 81 +++- 16 files changed, 1118 insertions(+), 38 deletions(-) create mode 100644 tests/inbound_handler/test_base_configuration.py create mode 100644 tests/inbound_handler/test_body_parameter.py create mode 100644 tests/inbound_handler/test_cookie_parameter.py create mode 100644 tests/inbound_handler/test_header_parameter.py create mode 100644 tests/inbound_handler/test_path_operations.py create mode 100644 tests/inbound_handler/test_query_operations.py create mode 100644 tests/inbound_handler/test_sub_fields.py create mode 100644 tests/inbound_handler/test_warnings.py diff --git a/.flake8 b/.flake8 index 29049fd..c560eda 100644 --- a/.flake8 +++ b/.flake8 @@ -5,7 +5,7 @@ max-line-length = 80 max-complexity = 10 docstring-convention = google per-file-ignores = - tests/*:S101,D100,D205,D415,S106 + tests/*:S101,D100,D205,D415,S106,B008,D101 __init__.py:F401 typing.py:F401 rst-roles = class,const,func,meth,mod,ref diff --git a/docs/installation.rst b/docs/installation.rst index 803fa89..7943229 100644 --- a/docs/installation.rst +++ b/docs/installation.rst @@ -53,7 +53,7 @@ Python Version Your first dependency, and the main one at that, is your Python installation. When you overlook this, you end up using your system default, often outdated, Python installation. -The best practice is to use the latest stable version of Python, which is 3.11 as I write this. :ref:`see how `. The Python core team is doing an amazing job and it would be a shame to miss out on all the improvement they bring to the game release after release. +The best practice is to use the latest stable version of Python, which is 3.11 as I write this. :ref:`see how `. The Python core team is doing an amazing job and it would be a shame to miss out on all the improvement they bring to the game release after release. That being said, **Flask-Jeroboam** supports Python down to its 3.8 installment. It means that the CI/CD pipeline tests the package from Python 3.8 to the most recent release. In the future, I will progressively diff --git a/flask_jeroboam/view_params/solved.py b/flask_jeroboam/view_params/solved.py index c0e1877..8750a76 100644 --- a/flask_jeroboam/view_params/solved.py +++ b/flask_jeroboam/view_params/solved.py @@ -72,7 +72,7 @@ def validate_request(self): """Validate the request.""" values = {} errors = [] - + assert self.location is not None # noqa: S101 inbound_values = self._get_values() if inbound_values is None: if self.required: @@ -100,7 +100,7 @@ def _get_values(self) -> Union[dict, Optional[str], List[Any]]: else: return self._get_values_from_request() - def _get_values_from_body(self): + def _get_values_from_body(self) -> Any: """Get the values from the request body.""" source: Any = {} if self.location == ParamLocation.form: @@ -108,7 +108,7 @@ def _get_values_from_body(self): elif self.location == ParamLocation.file: source = request.files else: - source = request.json + source = request.json or {} if self.embed: values = source.get(self.alias or self.name) else: diff --git a/tests/conftest.py b/tests/conftest.py index 12ae1d1..4ff5bd5 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,5 +1,7 @@ """Configuration File for pytest.""" import os +from typing import FrozenSet +from typing import Optional import pytest @@ -9,6 +11,11 @@ from flask_jeroboam.exceptions import ResponseValidationError from flask_jeroboam.exceptions import RessourceNotFound from flask_jeroboam.exceptions import ServerError +from flask_jeroboam.view_params import Body +from flask_jeroboam.view_params import Header +from flask_jeroboam.view_params import Path +from flask_jeroboam.view_params import Query +from flask_jeroboam.view_params.functions import Cookie @pytest.fixture @@ -20,10 +27,313 @@ def app() -> Jeroboam: SECRET_KEY="RandomSecretKey", ) # TODO: Add it by default with CONFIG OPT-OUT + + def handle_404(e): + return {"message": "Not Found"}, 404 + app.register_error_handler(InvalidRequest, InvalidRequest.handle) app.register_error_handler(RessourceNotFound, RessourceNotFound.handle) app.register_error_handler(ServerError, ServerError.handle) app.register_error_handler(ResponseValidationError, ResponseValidationError.handle) + app.register_error_handler(404, handle_404) + + @app.route("/api_route") + def non_operation(): + return {"message": "Hello World"} + + @app.get("/text") + def get_text(): + return "Hello World" + + return app + + +@pytest.fixture +def app_with_path_operations(app) -> Jeroboam: # noqa: C901 + """App with reigistered path operations.""" + + @app.get("/path/") + def get_id(item_id): + return {"item_id": item_id} + + @app.get("/path/str/") + def get_str_id(item_id: str): + return {"item_id": item_id} + + @app.get("/path/int/") + def get_int_id(item_id: int): + return {"item_id": item_id} + + @app.get("/path/float/") + def get_float_id(item_id: float): + return {"item_id": item_id} + + @app.get("/path/bool/") + def get_bool_id(item_id: bool): + return {"item_id": item_id} + + @app.get("/path/param/") + def get_path_param_id(item_id: str = Path()): + return {"item_id": item_id} + + @app.get("/path/param-required/") + def get_path_param_required_id(item_id: str = Path()): + return {"item_id": item_id} + + @app.get("/path/param-minlength/") + def get_path_param_min_length(item_id: str = Path(min_length=3)): + return {"item_id": item_id} + + @app.get("/path/param-maxlength/") + def get_path_param_max_length(item_id: str = Path(max_length=3)): + return {"item_id": item_id} + + @app.get("/path/param-min_maxlength/") + def get_path_param_min_max_length(item_id: str = Path(max_length=3, min_length=2)): + return {"item_id": item_id} + + @app.get("/path/param-gt/") + def get_path_param_gt(item_id: float = Path(gt=3)): + return {"item_id": item_id} + + @app.get("/path/param-gt0/") + def get_path_param_gt0(item_id: float = Path(gt=0)): + return {"item_id": item_id} + + @app.get("/path/param-ge/") + def get_path_param_ge(item_id: float = Path(ge=3)): + return {"item_id": item_id} + + @app.get("/path/param-lt/") + def get_path_param_lt(item_id: float = Path(lt=3)): + return {"item_id": item_id} + + @app.get("/path/param-lt0/") + def get_path_param_lt0(item_id: float = Path(lt=0)): + return {"item_id": item_id} + + @app.get("/path/param-le/") + def get_path_param_le(item_id: float = Path(le=3)): + return {"item_id": item_id} + + @app.get("/path/param-lt-gt/") + def get_path_param_lt_gt(item_id: float = Path(lt=3, gt=1)): + return {"item_id": item_id} + + @app.get("/path/param-le-ge/") + def get_path_param_le_ge(item_id: float = Path(le=3, ge=1)): + return {"item_id": item_id} + + @app.get("/path/param-lt-int/") + def get_path_param_lt_int(item_id: int = Path(lt=3)): + return {"item_id": item_id} + + @app.get("/path/param-gt-int/") + def get_path_param_gt_int(item_id: int = Path(gt=3)): + return {"item_id": item_id} + + @app.get("/path/param-le-int/") + def get_path_param_le_int(item_id: int = Path(le=3)): + return {"item_id": item_id} + + @app.get("/path/param-ge-int/") + def get_path_param_ge_int(item_id: int = Path(ge=3)): + return {"item_id": item_id} + + @app.get("/path/param-lt-gt-int/") + def get_path_param_lt_gt_int(item_id: int = Path(lt=3, gt=1)): + return {"item_id": item_id} + + @app.get("/path/param-le-ge-int/") + def get_path_param_le_ge_int(item_id: int = Path(le=3, ge=1)): + return {"item_id": item_id} + + @app.get("/path/with_converter/") + def get_with_preproc_id(item_id): + return {"item_id": item_id} + + @app.get("/path/with_converter/str/") + def get_with_preproc_str_id(item_id: str): + return {"item_id": item_id} + + @app.get("/path/with_converter/int/") + def get_with_preproc_int_id(item_id: int): + return {"item_id": item_id} + + @app.get("/path/with_converter/float/") + def get_with_preproc_float_id(item_id: float): + return {"item_id": item_id} + + @app.get("/path/with_converter/param/") + def get_with_preproc_path_param_id(item_id: str = Path()): + return {"item_id": item_id} + + @app.get("/path/with_converter/param-required/") + def get_with_preproc_path_param_required_id(item_id: str = Path()): + return {"item_id": item_id} + + @app.get("/path/with_converter/param-minlength/") + def get_with_preproc_path_param_min_length(item_id: str = Path(min_length=3)): + return {"item_id": item_id} + + @app.get("/path/with_converter/param-maxlength/") + def get_with_preproc_path_param_max_length(item_id: str = Path(max_length=3)): + return {"item_id": item_id} + + @app.get("/path/with_converter/param-min_maxlength/") + def get_with_preproc_path_param_min_max_length( + item_id: str = Path(max_length=3, min_length=2) + ): + return {"item_id": item_id} + + @app.get("/path/with_converter/param-gt/") + def get_with_preproc_path_param_gt(item_id: float = Path(gt=3)): + return {"item_id": item_id} + + @app.get("/path/with_converter/param-gt0/") + def get_with_preproc_path_param_gt0(item_id: float = Path(gt=0)): + return {"item_id": item_id} + + @app.get("/path/with_converter/param-ge/") + def get_with_preproc_path_param_ge(item_id: float = Path(ge=3)): + return {"item_id": item_id} + + @app.get("/path/with_converter/param-lt/") + def get_with_preproc_path_param_lt(item_id: float = Path(lt=3)): + return {"item_id": item_id} + + @app.get("/path/with_converter/param-lt0/") + def get_with_preproc_path_param_lt0(item_id: float = Path(lt=0)): + return {"item_id": item_id} + + @app.get("/path/with_converter/param-le/") + def get_with_preproc_path_param_le(item_id: float = Path(le=3)): + return {"item_id": item_id} + + @app.get("/path/with_converter/param-lt-gt/") + def get_with_preproc_path_param_lt_gt(item_id: float = Path(lt=3, gt=1)): + return {"item_id": item_id} + + @app.get("/path/with_converter/param-le-ge/") + def get_with_preproc_path_param_le_ge(item_id: float = Path(le=3, ge=1)): + return {"item_id": item_id} + + @app.get("/path/with_converter/param-lt-int/") + def get_with_preproc_path_param_lt_int(item_id: int = Path(lt=3)): + return {"item_id": item_id} + + @app.get("/path/with_converter/param-gt-int/") + def get_with_preproc_path_param_gt_int(item_id: int = Path(gt=3)): + return {"item_id": item_id} + + @app.get("/path/with_converter/param-le-int/") + def get_with_preproc_path_param_le_int(item_id: int = Path(le=3)): + return {"item_id": item_id} + + @app.get("/path/with_converter/param-ge-int/") + def get_with_preproc_path_param_ge_int(item_id: int = Path(ge=3)): + return {"item_id": item_id} + + @app.get("/path/with_converter/param-lt-gt-int/") + def get_with_preproc_path_param_lt_gt_int(item_id: int = Path(lt=3, gt=1)): + return {"item_id": item_id} + + @app.get("/path/with_converter/param-le-ge-int/") + def get_with_preproc_path_param_le_ge_int(item_id: int = Path(le=3, ge=1)): + return {"item_id": item_id} + + return app + + +@pytest.fixture +def app_with_query_operations(app: Jeroboam) -> Jeroboam: # noqa: C901 + """App with reigistered query operations.""" + + @app.get("/query/frozenset/") + def get_query_type_frozenset(query: FrozenSet[int] = Query(...)): + return {"query": ",".join(map(str, sorted(query)))} + + @app.get("/query") + def get_query(query): + return {"query": query} + + @app.get("/query/optional") + def get_query_optional(query=None): + return {"query": query} + + @app.get("/query/int") + def get_query_type(query: int): + return {"query": query} + + @app.get("/query/int/optional") + def get_query_type_optional(query: Optional[int] = None): + return {"query": query} + + @app.get("/query/int/default") + def get_query_type_int_default(query: int = 10): + return {"query": query} + + @app.get("/query/param") + def get_query_param(query=Query(default=None)): + return {"query": query} + + @app.get("/query/param-required") + def get_query_param_required(query=Query()): + return {"query": query} + + @app.get("/query/param-required/int") + def get_query_param_required_type(query: int = Query()): + return {"query": query} + + @app.get("/enum-status-code", status_code=201) + def get_enum_status_code(): + return {} + + return app + + +@pytest.fixture +def app_with_cookie_parameters(app: Jeroboam) -> Jeroboam: + """App with reigistered cookie parameters.""" + + @app.get("/cookie/int") + def get_cookie_as_int(cookie: int = Cookie()): + return {"cookie": cookie} + + @app.get("/cookie/str") + def get_cookie_as_str(cookie: str = Cookie()): + return {"cookie": cookie} + + return app + + +@pytest.fixture +def app_with_header_parameters(app: Jeroboam) -> Jeroboam: + """App with reigistered cookie parameters.""" + + @app.get("/headers/int") + def get_header_as_int(test_header: int = Header()): + return {"header": test_header} + + @app.get("/headers/str") + def get_header_as_str(test_header: str = Header()): + return {"header": test_header} + + return app + + +@pytest.fixture +def app_with_body_parameters(app: Jeroboam) -> Jeroboam: + """App with reigistered cookie parameters.""" + + @app.post("/body/int") + def post_body_as_int(payload: int = Body()): + return {"payload": payload} + + @app.post("/body/str") + def post_body_as_str(payload: str = Body()): + return {"payload": payload} + return app @@ -51,3 +361,33 @@ def request_context(app: Jeroboam): def client(app: Jeroboam): """Test Client from the Test App.""" return app.test_client() + + +@pytest.fixture +def query_client(app_with_query_operations: Jeroboam): + """Test Client from the Test App.""" + return app_with_query_operations.test_client() + + +@pytest.fixture +def path_client(app_with_path_operations: Jeroboam): + """Test Client from the Test App.""" + return app_with_path_operations.test_client() + + +@pytest.fixture +def cookie_client(app_with_cookie_parameters: Jeroboam): + """Test Client from the Test App.""" + return app_with_cookie_parameters.test_client() + + +@pytest.fixture +def header_client(app_with_header_parameters: Jeroboam): + """Test Client from the Test App.""" + return app_with_header_parameters.test_client() + + +@pytest.fixture +def body_client(app_with_body_parameters: Jeroboam): + """Test Client from the Test App.""" + return app_with_body_parameters.test_client() diff --git a/tests/inbound_handler/test_base_configuration.py b/tests/inbound_handler/test_base_configuration.py new file mode 100644 index 0000000..7ff3e53 --- /dev/null +++ b/tests/inbound_handler/test_base_configuration.py @@ -0,0 +1,12 @@ +def test_text_get(client): + """Test GET /text""" + response = client.get("/text") + assert response.status_code == 200, response.text + assert response.data == b"Hello World" + + +def test_nonexistent(client): + """Test GET /nonexistent""" + response = client.get("/nonexistent") + assert response.status_code == 404, response.text + assert response.json == {"message": "Not Found"} diff --git a/tests/inbound_handler/test_body_parameter.py b/tests/inbound_handler/test_body_parameter.py new file mode 100644 index 0000000..69be259 --- /dev/null +++ b/tests/inbound_handler/test_body_parameter.py @@ -0,0 +1,68 @@ +from typing import List + +import pytest + +from flask_jeroboam.models import Parser +from flask_jeroboam.view_params.functions import Body + + +response_not_valid_int = { + "detail": [ + { + "loc": ["body", "payload"], + "msg": "value is not a valid integer", + "type": "type_error.integer", + } + ] +} + + +def _valid(value) -> dict: + """Valid function.""" + return {"payload": value} + + +@pytest.mark.parametrize( + "url,body_value,expected_status,expected_response", + [ + ("/body/str", {"payload": "foobar"}, 200, _valid("foobar")), + ("/body/int", {"payload": 123}, 200, _valid(123)), + ("/body/int", {"payload": "not_a_valid_int"}, 400, response_not_valid_int), + ], +) +def test_post_body_operations( + body_client, url, body_value, expected_status, expected_response +): + """Testing Various GET operations with query parameters. + + GIVEN a GET endpoint configiured with query parameters + WHEN a request is made to the endpoint + THEN the request is parsed and validated accordingly + """ + response = body_client.post(url, json=body_value) + assert response.json == expected_response + assert response.status_code == expected_status + + +def test_post_body_list_of_base_model(app, client): + """Test Body Parameter with POST method.""" + + class InBound(Parser): + """Inbound model.""" + + item: str + count: int + + @app.post("/body/list_non_scalar", response_model=List[InBound]) + def post_body_list_non_scalar(payload: List[InBound] = Body(embed=False)): + return payload + + response = client.post( + "/body/list_non_scalar", + json=[{"item": "foobar", "count": 1}, {"item": "bar", "count": 3}], + ) + assert response.json == [ + {"item": "foobar", "count": 1}, + {"item": "bar", "count": 3}, + ] + assert response.status_code == 201 diff --git a/tests/inbound_handler/test_cookie_parameter.py b/tests/inbound_handler/test_cookie_parameter.py new file mode 100644 index 0000000..dd0c6e0 --- /dev/null +++ b/tests/inbound_handler/test_cookie_parameter.py @@ -0,0 +1,38 @@ +import pytest + + +not_a_valid_int = { + "detail": [ + { + "loc": ["cookie", "cookie"], + "msg": "value is not a valid integer", + "type": "type_error.integer", + } + ] +} + + +def _valid(x): + return {"cookie": x} + + +@pytest.mark.parametrize( + "url,cookie_value,expected_status,expected_response", + [ + ("/cookie/str", b"foobar", 200, _valid("foobar")), + ("/cookie/int", b"123", 200, _valid(123)), + ("/cookie/int", b"not_valid_int", 400, not_a_valid_int), + ], +) +def test_get_cookie( + cookie_client, url, cookie_value, expected_status, expected_response +): + """Test Cookie Parameter with GET method. + + + TODO: Allow Configuration of the returned Status Code. + """ + cookie_client.set_cookie("localhost", "cookie", cookie_value) + response = cookie_client.get(url) + assert response.status_code == expected_status + assert response.json == expected_response diff --git a/tests/inbound_handler/test_header_parameter.py b/tests/inbound_handler/test_header_parameter.py new file mode 100644 index 0000000..0754159 --- /dev/null +++ b/tests/inbound_handler/test_header_parameter.py @@ -0,0 +1,37 @@ +import pytest + + +not_a_valid_int = { + "detail": [ + { + "loc": ["header", "Test-Header"], + "msg": "value is not a valid integer", + "type": "type_error.integer", + } + ] +} + + +def _valid(x): + return {"header": x} + + +@pytest.mark.parametrize( + "url,header_value,expected_status,expected_response", + [ + ("/headers/str", {"test-header": "foobar"}, 200, _valid("foobar")), + ("/headers/int", {"test-header": "123"}, 200, _valid(123)), + ("/headers/int", {"test-header": "not_a_valid_int"}, 400, not_a_valid_int), + ], +) +def test_get_headers( + header_client, url, header_value, expected_status, expected_response +): + """Test Cookie Parameter with GET method. + + + TODO: Allow Configuration of the returned Status Code. + """ + response = header_client.get(url, headers=header_value) + assert response.status_code == expected_status + assert response.json == expected_response diff --git a/tests/inbound_handler/test_path_operations.py b/tests/inbound_handler/test_path_operations.py new file mode 100644 index 0000000..ba36095 --- /dev/null +++ b/tests/inbound_handler/test_path_operations.py @@ -0,0 +1,355 @@ +"""Testing path operations.""" +import pytest + + +response_not_valid_bool = { + "detail": [ + { + "loc": ["path", "item_id"], + "msg": "value could not be parsed to a boolean", + "type": "type_error.bool", + } + ] +} + +response_not_valid_int = { + "detail": [ + { + "loc": ["path", "item_id"], + "msg": "value is not a valid integer", + "type": "type_error.integer", + } + ] +} + +response_not_valid_float = { + "detail": [ + { + "loc": ["path", "item_id"], + "msg": "value is not a valid float", + "type": "type_error.float", + } + ] +} + +response_at_least_3 = { + "detail": [ + { + "loc": ["path", "item_id"], + "msg": "ensure this value has at least 3 characters", + "type": "value_error.any_str.min_length", + "ctx": {"limit_value": 3}, + } + ] +} + + +response_at_least_2 = { + "detail": [ + { + "loc": ["path", "item_id"], + "msg": "ensure this value has at least 2 characters", + "type": "value_error.any_str.min_length", + "ctx": {"limit_value": 2}, + } + ] +} + + +response_maximum_3 = { + "detail": [ + { + "loc": ["path", "item_id"], + "msg": "ensure this value has at most 3 characters", + "type": "value_error.any_str.max_length", + "ctx": {"limit_value": 3}, + } + ] +} + + +response_greater_than_3 = { + "detail": [ + { + "loc": ["path", "item_id"], + "msg": "ensure this value is greater than 3", + "type": "value_error.number.not_gt", + "ctx": {"limit_value": 3}, + } + ] +} + + +response_greater_than_0 = { + "detail": [ + { + "loc": ["path", "item_id"], + "msg": "ensure this value is greater than 0", + "type": "value_error.number.not_gt", + "ctx": {"limit_value": 0}, + } + ] +} + + +response_greater_than_1 = { + "detail": [ + { + "loc": ["path", "item_id"], + "msg": "ensure this value is greater than 1", + "type": "value_error.number.not_gt", + "ctx": {"limit_value": 1}, + } + ] +} + + +response_greater_than_equal_3 = { + "detail": [ + { + "loc": ["path", "item_id"], + "msg": "ensure this value is greater than or equal to 3", + "type": "value_error.number.not_ge", + "ctx": {"limit_value": 3}, + } + ] +} + + +response_less_than_3 = { + "detail": [ + { + "loc": ["path", "item_id"], + "msg": "ensure this value is less than 3", + "type": "value_error.number.not_lt", + "ctx": {"limit_value": 3}, + } + ] +} + + +response_less_than_0 = { + "detail": [ + { + "loc": ["path", "item_id"], + "msg": "ensure this value is less than 0", + "type": "value_error.number.not_lt", + "ctx": {"limit_value": 0}, + } + ] +} + + +response_less_than_equal_3 = { + "detail": [ + { + "loc": ["path", "item_id"], + "msg": "ensure this value is less than or equal to 3", + "type": "value_error.number.not_le", + "ctx": {"limit_value": 3}, + } + ] +} + + +valid_foobar_str = {"item_id": "foobar"} +valid_42_str = {"item_id": "42"} +valid_true_str = {"item_id": "True"} +valid_true_bool = {"item_id": True} +valid_false_bool = {"item_id": False} +valid_42_int = {"item_id": 42} +valid_42_5_float = {"item_id": 42.5} + + +def _valid(x): + return {"item_id": x} + + +not_found = {"message": "Not Found"} + + +@pytest.mark.parametrize( + "url,expected_status,expected_response", + [ + ("/path/foobar", 200, valid_foobar_str), + ("/path/str/foobar", 200, valid_foobar_str), + ("/path/str/42", 200, valid_42_str), + ("/path/str/True", 200, valid_true_str), + ("/path/int/foobar", 400, response_not_valid_int), + ("/path/int/True", 400, response_not_valid_int), + ("/path/int/42", 200, valid_42_int), + ("/path/int/42.5", 400, response_not_valid_int), + ("/path/float/foobar", 400, response_not_valid_float), + ("/path/float/True", 400, response_not_valid_float), + ("/path/float/42", 200, _valid(42.0)), + ("/path/float/42.5", 200, valid_42_5_float), + ("/path/bool/foobar", 400, response_not_valid_bool), + ("/path/bool/True", 200, valid_true_bool), + ("/path/bool/42", 400, response_not_valid_bool), + ("/path/bool/42.5", 400, response_not_valid_bool), + ("/path/bool/1", 200, valid_true_bool), + ("/path/bool/0", 200, valid_false_bool), + ("/path/bool/true", 200, valid_true_bool), + ("/path/bool/False", 200, valid_false_bool), + ("/path/bool/false", 200, valid_false_bool), + ("/path/param/foo", 200, _valid("foo")), + ("/path/param-required/foo", 200, _valid("foo")), + ("/path/param-minlength/foo", 200, _valid("foo")), + ("/path/param-minlength/fo", 400, response_at_least_3), + ("/path/param-maxlength/foo", 200, _valid("foo")), + ("/path/param-maxlength/foobar", 400, response_maximum_3), + ("/path/param-min_maxlength/foo", 200, _valid("foo")), + ("/path/param-min_maxlength/foobar", 400, response_maximum_3), + ("/path/param-min_maxlength/f", 400, response_at_least_2), + ("/path/param-gt/42", 200, valid_42_int), + ("/path/param-gt/2", 400, response_greater_than_3), + ("/path/param-gt0/0.05", 200, _valid(0.05)), + ("/path/param-gt0/0", 400, response_greater_than_0), + ("/path/param-ge/42", 200, valid_42_int), + ("/path/param-ge/3", 200, _valid(3)), + ("/path/param-ge/2", 400, response_greater_than_equal_3), + ("/path/param-lt/42", 400, response_less_than_3), + ("/path/param-lt/2", 200, _valid(2)), + ("/path/param-lt0/-1", 200, _valid(-1)), + ("/path/param-lt0/0", 400, response_less_than_0), + ("/path/param-le/42", 400, response_less_than_equal_3), + ("/path/param-le/3", 200, _valid(3)), + ("/path/param-le/2", 200, _valid(2)), + ("/path/param-lt-gt/2", 200, _valid(2)), + ("/path/param-lt-gt/4", 400, response_less_than_3), + ("/path/param-lt-gt/0", 400, response_greater_than_1), + ("/path/param-le-ge/2", 200, _valid(2)), + ("/path/param-le-ge/1", 200, _valid(1)), + ("/path/param-le-ge/3", 200, _valid(3)), + ("/path/param-le-ge/4", 400, response_less_than_equal_3), + ("/path/param-lt-int/2", 200, _valid(2)), + ("/path/param-lt-int/42", 400, response_less_than_3), + ("/path/param-lt-int/2.7", 400, response_not_valid_int), + ("/path/param-gt-int/42", 200, valid_42_int), + ("/path/param-gt-int/2", 400, response_greater_than_3), + ("/path/param-gt-int/2.7", 400, response_not_valid_int), + ("/path/param-le-int/42", 400, response_less_than_equal_3), + ("/path/param-le-int/3", 200, _valid(3)), + ("/path/param-le-int/2", 200, _valid(2)), + ("/path/param-le-int/2.7", 400, response_not_valid_int), + ("/path/param-ge-int/42", 200, valid_42_int), + ("/path/param-ge-int/3", 200, _valid(3)), + ("/path/param-ge-int/2", 400, response_greater_than_equal_3), + ("/path/param-ge-int/2.7", 400, response_not_valid_int), + ("/path/param-lt-gt-int/2", 200, _valid(2)), + ("/path/param-lt-gt-int/4", 400, response_less_than_3), + ("/path/param-lt-gt-int/0", 400, response_greater_than_1), + ("/path/param-lt-gt-int/2.7", 400, response_not_valid_int), + ("/path/param-le-ge-int/2", 200, _valid(2)), + ("/path/param-le-ge-int/1", 200, _valid(1)), + ("/path/param-le-ge-int/3", 200, _valid(3)), + ("/path/param-le-ge-int/4", 400, response_less_than_equal_3), + ("/path/param-le-ge-int/2.7", 400, response_not_valid_int), + ], +) +def test_get_path(path_client, url, expected_status, expected_response): + """Test Path Operation with GET method. + + + TODO: Allow Configuration of the returned Status Code. + """ + response = path_client.get(url) + assert response.status_code == expected_status + assert response.json == expected_response + + +@pytest.mark.parametrize( + "url,expected_status,expected_response", + [ + ("/path/with_converter/foobar", 200, valid_foobar_str), + ("/path/with_converter/str/foobar", 200, valid_foobar_str), + ("/path/with_converter/str/42", 200, valid_42_str), + ("/path/with_converter/str/True", 200, _valid("True")), + ("/path/with_converter/int/42", 200, valid_42_int), + ("/path/with_converter/float/42.5", 200, valid_42_5_float), + ("/path/with_converter/param/foo", 200, _valid("foo")), + ("/path/with_converter/param-required/foo", 200, _valid("foo")), + ("/path/with_converter/param-minlength/foo", 200, _valid("foo")), + ("/path/with_converter/param-minlength/fo", 400, response_at_least_3), + ("/path/with_converter/param-maxlength/foo", 200, _valid("foo")), + ("/path/with_converter/param-maxlength/foobar", 400, response_maximum_3), + ("/path/with_converter/param-min_maxlength/foo", 200, _valid("foo")), + ("/path/with_converter/param-min_maxlength/foobar", 400, response_maximum_3), + ("/path/with_converter/param-min_maxlength/f", 400, response_at_least_2), + ("/path/with_converter/param-gt/42.0", 200, _valid(42.0)), + ("/path/with_converter/param-gt/2.0", 400, response_greater_than_3), + ("/path/with_converter/param-gt0/0.05", 200, _valid(0.05)), + ("/path/with_converter/param-gt0/0.0", 400, response_greater_than_0), + ("/path/with_converter/param-ge/42.0", 200, _valid(42.0)), + ("/path/with_converter/param-ge/3.0", 200, _valid(3.0)), + ("/path/with_converter/param-ge/2.0", 400, response_greater_than_equal_3), + ("/path/with_converter/param-lt/42.0", 400, response_less_than_3), + ("/path/with_converter/param-lt/2.0", 200, _valid(2.0)), + ("/path/with_converter/param-lt0/-1.0", 200, _valid(-1)), + ("/path/with_converter/param-lt0/0.0", 400, response_less_than_0), + ("/path/with_converter/param-le/42.0", 400, response_less_than_equal_3), + ("/path/with_converter/param-le/3.0", 200, _valid(3)), + ("/path/with_converter/param-le/2.0", 200, _valid(2)), + ("/path/with_converter/param-lt-gt/2.0", 200, _valid(2)), + ("/path/with_converter/param-lt-gt/4.0", 400, response_less_than_3), + ("/path/with_converter/param-lt-gt/0.0", 400, response_greater_than_1), + ("/path/with_converter/param-le-ge/2.0", 200, _valid(2)), + ("/path/with_converter/param-le-ge/1.0", 200, _valid(1)), + ("/path/with_converter/param-le-ge/3.0", 200, _valid(3)), + ("/path/with_converter/param-le-ge/4.0", 400, response_less_than_equal_3), + ("/path/with_converter/param-lt-int/2", 200, _valid(2)), + ("/path/with_converter/param-lt-int/42", 400, response_less_than_3), + ("/path/with_converter/param-gt-int/42", 200, _valid(42)), + ("/path/with_converter/param-gt-int/2", 400, response_greater_than_3), + ("/path/with_converter/param-le-int/42", 400, response_less_than_equal_3), + ("/path/with_converter/param-le-int/3", 200, _valid(3)), + ("/path/with_converter/param-le-int/2", 200, _valid(2)), + ("/path/with_converter/param-ge-int/42", 200, valid_42_int), + ("/path/with_converter/param-ge-int/3", 200, _valid(3)), + ("/path/with_converter/param-ge-int/2", 400, response_greater_than_equal_3), + ("/path/with_converter/param-lt-gt-int/2", 200, _valid(2)), + ("/path/with_converter/param-lt-gt-int/4", 400, response_less_than_3), + ("/path/with_converter/param-lt-gt-int/0", 400, response_greater_than_1), + ("/path/with_converter/param-le-ge-int/2", 200, _valid(2)), + ("/path/with_converter/param-le-ge-int/1", 200, _valid(1)), + ("/path/with_converter/param-le-ge-int/3", 200, _valid(3)), + ("/path/with_converter/param-le-ge-int/4", 400, response_less_than_equal_3), + ], +) +def test_get_path_with_converter(path_client, url, expected_status, expected_response): + """Test Path Operation with GET method. + + + TODO: Allow Configuration of the returned Status Code. + """ + response = path_client.get(url) + assert response.json == expected_response + assert response.status_code == expected_status + + +@pytest.mark.parametrize( + "url,expected_status,expected_response", + [ + ("/path/with_converter/int/foobar", 404, not_found), + ("/path/with_converter/int/True", 404, not_found), + ("/path/with_converter/int/42.5", 404, not_found), + ("/path/with_converter/float/foobar", 404, not_found), + ("/path/with_converter/float/True", 404, not_found), + ("/path/with_converter/float/42", 404, not_found), + ("/path/with_converter/param-le-int/2.7", 404, not_found), + ("/path/with_converter/param-ge-int/2.7", 404, not_found), + ("/path/with_converter/param-le-ge-int/2.7", 404, not_found), + ], +) +def test_path_converter_error_override_jeroboam_validation( + path_client, url, expected_status, expected_response +): + """Test Url Converter Overides PathParams Validation. + + GIVEN a Path Parameter with both Path (Jeroboam) and Converter (Flask) validation + WHEN the url is called with a value that does not match the converter + THEN the converter error is returned (404), not the Jeroboam one (400) + """ + response = path_client.get(url) + assert response.json == expected_response + assert response.status_code == expected_status diff --git a/tests/inbound_handler/test_query_operations.py b/tests/inbound_handler/test_query_operations.py new file mode 100644 index 0000000..c5b1a02 --- /dev/null +++ b/tests/inbound_handler/test_query_operations.py @@ -0,0 +1,69 @@ +import pytest + + +response_missing = { + "detail": [ + { + "loc": ["query", "query"], + "msg": "field required", + "type": "value_error.missing", + } + ] +} + +response_not_valid_int = { + "detail": [ + { + "loc": ["query", "query"], + "msg": "value is not a valid integer", + "type": "type_error.integer", + } + ] +} + + +def _valid(value) -> dict: + """Valid function.""" + return {"query": value} + + +@pytest.mark.parametrize( + "url,expected_status,expected_response", + [ + ("/query", 400, response_missing), + ("/query?query=baz", 200, _valid("baz")), + ("/query?not_declared=baz", 400, response_missing), + ("/query/optional", 200, _valid(None)), + ("/query/optional?query=baz", 200, _valid("baz")), + ("/query/optional?not_declared=baz", 200, _valid(None)), + ("/query/int", 400, response_missing), + ("/query/int?query=42", 200, _valid(42)), + ("/query/int?query=42.5", 400, response_not_valid_int), + ("/query/int?query=baz", 400, response_not_valid_int), + ("/query/int?not_declared=baz", 400, response_missing), + ("/query/int/optional", 200, _valid(None)), + ("/query/int/optional?query=50", 200, _valid(50)), + ("/query/int/optional?query=foo", 400, response_not_valid_int), + ("/query/int/default", 200, _valid(10)), + ("/query/int/default?query=50", 200, _valid(50)), + ("/query/int/default?query=foo", 400, response_not_valid_int), + ("/query/param", 200, _valid(None)), + ("/query/param?query=50", 200, _valid("50")), + ("/query/param-required", 400, response_missing), + ("/query/param-required?query=50", 200, _valid("50")), + ("/query/param-required/int", 400, response_missing), + ("/query/param-required/int?query=50", 200, _valid(50)), + ("/query/param-required/int?query=foo", 400, response_not_valid_int), + ("/query/frozenset/?query=1&query=1&query=2", 200, _valid("1,2")), + ], +) +def test_get_query_operations(query_client, url, expected_status, expected_response): + """Testing Various GET operations with query parameters. + + GIVEN a GET endpoint configiured with query parameters + WHEN a request is made to the endpoint + THEN the request is parsed and validated accordingly + """ + response = query_client.get(url) + assert response.status_code == expected_status + assert response.json == expected_response diff --git a/tests/inbound_handler/test_sub_fields.py b/tests/inbound_handler/test_sub_fields.py new file mode 100644 index 0000000..f41ea36 --- /dev/null +++ b/tests/inbound_handler/test_sub_fields.py @@ -0,0 +1,48 @@ +from typing import Optional + +from pydantic import BaseModel +from pydantic import validator + + +class ModelB(BaseModel): + username: str + + +class ModelC(ModelB): + password: str + + +class ModelA(BaseModel): + name: str + description: Optional[str] = None + model_b: ModelB + + @validator("name") + def lower_username(cls, name: str, values): # noqa: B902, N805 + """Validate that the name ends in A.""" + if not name.endswith("A"): + raise ValueError("name must end in A") + return name + + +def test_get_query_operations(app, client): + """Testing Various GET operations with query parameters. + + GIVEN a GET endpoint configiured with query parameters + WHEN a request is made to the endpoint + THEN the request is parsed and validated accordingly + """ + + @app.post("/sub_model", response_model=ModelA) + def post_sub_model(model: ModelA): + return model + + response = client.post( + "/sub_model", json={"name": "fooA", "model_b": {"username": "bar"}} + ) + assert response.status_code == 201 + assert response.json == { + "name": "fooA", + "model_b": {"username": "bar"}, + "description": None, + } diff --git a/tests/inbound_handler/test_warnings.py b/tests/inbound_handler/test_warnings.py new file mode 100644 index 0000000..5339d41 --- /dev/null +++ b/tests/inbound_handler/test_warnings.py @@ -0,0 +1,17 @@ +import pytest + +from flask_jeroboam import Form + + +def test_form_param_on_get_raise_warning(app): + """A warning is raised when a Form parameter is used on a GET view. + + GIVEN a GET view with a Form parameter + WHEN registering the view + THEN a warning is raised + """ + with pytest.warns(UserWarning): + + @app.get("/form_on_get") + def form_on_get(not_allowed: str = Form(...)): + return "OK" diff --git a/tests/test_error_handling.py b/tests/test_error_handling.py index 964ea99..ef859a6 100644 --- a/tests/test_error_handling.py +++ b/tests/test_error_handling.py @@ -3,7 +3,6 @@ """ from flask.testing import FlaskClient -from flask_jeroboam.exceptions import InvalidRequest from flask_jeroboam.exceptions import RessourceNotFound from flask_jeroboam.exceptions import ServerError from flask_jeroboam.jeroboam import Jeroboam @@ -19,13 +18,21 @@ def test_invalid_request( """ @app.get("/invalid_request") - def ping(): - raise InvalidRequest(msg="The Request was not valid") + def ping(missing_param: int): + return {} r = client.get("invalid_request") assert r.status_code == 400 - assert r.data.startswith(b"BadRequest:") + assert r.json == { + "detail": [ + { + "loc": ["query", "missing_param"], + "msg": "field required", + "type": "value_error.missing", + } + ] + } def test_ressource_not_found_named_ressource( @@ -56,11 +63,11 @@ def test_ressource_not_found_generic_message( THEN I get a 404 with RessourceNotFound Generic Message """ - @app.get("/generic_ressource/") - def ping(id: int): + @app.get("/generic_ressource") + def generic_ressource(): raise RessourceNotFound(msg="My Message") - r = client.get("/generic_ressource/0") + r = client.get("/generic_ressource") assert r.status_code == 404 assert r.data.startswith(b"RessourceNotFound: My Message") diff --git a/tests/test_inbound_handler.py b/tests/test_inbound_handler.py index 71bcfab..ca7468d 100644 --- a/tests/test_inbound_handler.py +++ b/tests/test_inbound_handler.py @@ -3,7 +3,6 @@ verbs (GET, POST, DELETE) and configuration (Lists or Plain)... """ from io import BytesIO -from typing import Dict from typing import List from typing import Optional @@ -14,6 +13,9 @@ from flask_jeroboam.jeroboam import Jeroboam from flask_jeroboam.models import Parser +from flask_jeroboam.view_params.functions import Body +from flask_jeroboam.view_params.functions import File +from flask_jeroboam.view_params.functions import Form class InBoundModel(BaseModel): @@ -71,7 +73,7 @@ def test_valid_payload_in_json_is_injected( """ @app.post("/payload_in_json") - def read_test(payload: InBoundModel): + def read_test(payload: InBoundModel = Body(embed=False)): return payload.json() r = client.post("/payload_in_json", json={"page": 1, "type": "item"}) @@ -90,7 +92,7 @@ def test_valid_payload_in_data_is_injected( """ @app.post("/payload_in_json") - def read_test(payload: InBoundModel): + def read_test(payload: InBoundModel = Form()): return payload.json() r = client.post("/payload_in_json", data={"page": 1, "type": "item"}) @@ -105,16 +107,9 @@ def test_valid_payload_in_files_is_injected(app: Jeroboam, client: FlaskClient): THEN the parsed input is injected into the view function. """ - class InBoundModelWithFile(BaseModel): - type: str - file: FileStorage - - class Config: - arbitrary_types_allowed: bool = True - @app.post("/payload_in_file") - def ping(payload: InBoundModelWithFile): - return {"type": payload.type, "file_content": str(payload.file.read())} + def ping(file: FileStorage = File(...)): + return {"file_content": str(file.read())} data = {"file": (BytesIO(b"Hello World !!"), "hello.txt"), "type": "file"} @@ -127,7 +122,7 @@ def ping(payload: InBoundModelWithFile): ) assert r.status_code == 200 - assert r.data == b'{"file_content":"b\'Hello World !!\'","type":"file"}\n' + assert r.json == {"file_content": "b'Hello World !!'"} def test_invalid_query_string_raise_400( @@ -146,7 +141,20 @@ def read_test(payload: InBoundModel): r = client.get("/strict_endpoint?page=not_a_valid_param") assert r.status_code == 400 - assert r.data.startswith(b"BadRequest") + assert r.json == { + "detail": [ + { + "loc": ["query", "payload", "page"], + "msg": "value is not a valid integer", + "type": "type_error.integer", + }, + { + "loc": ["query", "payload", "type"], + "msg": "none is not an allowed value", + "type": "type_error.none.not_allowed", + }, + ] + } def test_invalid_simple_param_raise_400( @@ -165,7 +173,15 @@ def ping(simple_param: int): r = client.get("query_string_as_int?simple_param=imparsable") assert r.status_code == 400 - assert r.data.startswith(b"BadRequest:") + assert r.json == { + "detail": [ + { + "loc": ["query", "simple_param"], + "msg": "value is not a valid integer", + "type": "type_error.integer", + } + ] + } def test_query_string_for_list_arguments( @@ -179,15 +195,12 @@ def test_query_string_for_list_arguments( class QueryStringWithList(Parser): id: List[int] = Field(alias="id[]") - order: List[Dict[str, str]] = Field(alias="order[]") @app.get("/query_string_with_list") def ping(query_string: QueryStringWithList): return ",".join([str(id) for id in query_string.id]) - r = client.get( - "/query_string_with_list?order[name]=asc&order[group]=desc&id[]=1&id[]=2" - ) + r = client.get("/query_string_with_list?id[]=1&id[]=2") assert r.data == b"1,2" diff --git a/tests/test_outbound_handler.py b/tests/test_outbound_handler.py index ecba6f0..3eae048 100644 --- a/tests/test_outbound_handler.py +++ b/tests/test_outbound_handler.py @@ -2,7 +2,6 @@ We test for various return values (Response, Dict, ResponseModel), configuration (response_model or not) and error handling. """ -import json import warnings from dataclasses import dataclass from typing import List @@ -14,6 +13,7 @@ from pydantic import BaseModel from flask_jeroboam.jeroboam import Jeroboam +from flask_jeroboam.view_params.functions import Body class OutBoundModel(BaseModel): @@ -501,15 +501,16 @@ class InBoundUser(SecureOutBoundUser): password: str @app.post("/filters_data", response_model=SecureOutBoundUser) - def filters_data(sensitive_data: InBoundUser): + def filters_data(sensitive_data: InBoundUser = Body()): return sensitive_data r = client.post( - "/filters_data", data=json.dumps({"username": "test", "password": "test"}) + "/filters_data", + json={"sensitive_data": {"username": "test", "password": "test"}}, ) assert r.status_code == 201 - assert r.data == json.dumps({"username": "test"}).encode() + assert r.json == {"username": "test"} def test_wrong_tuple_length_raise_error( diff --git a/tests/test_utils.py b/tests/test_utils.py index b5069c5..61ac27d 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,12 +1,17 @@ """Testing Utils.""" +from functools import partial from typing import List +import pytest from flask.testing import FlaskClient from pydantic import Field +from flask_jeroboam import Body from flask_jeroboam.jeroboam import Jeroboam from flask_jeroboam.models import Parser from flask_jeroboam.models import Serializer +from flask_jeroboam.utils import _rename_query_params_keys +from flask_jeroboam.view_params.solved import SolvedParameter def test_pascal_case_in_and_out_snake_case(app: Jeroboam, client: FlaskClient): @@ -19,17 +24,87 @@ class OutboundModel(Serializer): page: int per_page: int ids: List[int] + order: List[dict] class InboundModel(Parser): page: int per_page: int ids: List[int] = Field(alias="id[]") + order: List[dict] = Field(alias="order[]") + + app.query_string_key_transformer = partial( + _rename_query_params_keys, pattern=r"(.*)\[(.+)\]$" + ) @app.get("/web_boundaries", response_model=OutboundModel) def read_items(payload: InboundModel): - return {"page": payload.page, "per_page": payload.per_page, "ids": payload.ids} + return payload - r = client.get("web_boundaries?page=1&perPage=10&id[]=1&id[]=2") + r = client.get( + "web_boundaries?page=1&perPage=10&id[]=1&id[]=2&order[name]=asc&order[age]=desc" + ) assert r.status_code == 200 - assert r.data == b'{"page": 1, "perPage": 10, "ids": [1, 2]}' + assert r.json == { + "page": 1, + "perPage": 10, + "ids": [1, 2], + "order": [{"name": "asc"}, {"age": "desc"}], + } + + +def test_pascal_case_in_and_out_snake_case_without_transformer( + app: Jeroboam, client: FlaskClient +): + """GIVEN an endpoint with param typed with a Parser and response_model a Serializer + WHEN payload is send in pascalCase + THEN it lives in python in snake_case and send back in pascalCase + """ + + class OutboundModel(Serializer): + page: int + per_page: int + ids: List[int] + order: List[dict] + + class InboundModel(Parser): + page: int + per_page: int + ids: List[int] = Field(alias="id[]") + order: List[dict] = Field(alias="order[]") + + @app.get("/web_boundaries", response_model=OutboundModel) + def read_items(payload: InboundModel): + return payload + + r = client.get( + "web_boundaries?page=1&perPage=10&id[]=1&id[]=2&order[name]=asc&order[age]=desc" + ) + + assert r.status_code == 400 + assert r.json == { + "detail": [ + { + "loc": ["query", "payload", "order[]"], + "msg": "none is not an allowed value", + "type": "type_error.none.not_allowed", + } + ] + } + + +def test_view_param_str_repr(): + """Test an internal function of ViewParameter.""" + param = Body("MyDefautValue") + assert param.__repr__() == "BodyParameter(MyDefautValue)" + + +def test_solved_param_erroring(): + """Test an internal function of SolvedParameter.""" + param = Body("MyDefautValue") + param.location = None + solved_param = SolvedParameter( + name="FaultySolvedParam", type_=str, view_param=param + ) + with pytest.raises(ValueError): + solved_param._get_values() From 4c43d10adcc69c72ecfc853f69dc52a3ee232c51 Mon Sep 17 00:00:00 2001 From: Jean-Christophe Bianic Date: Tue, 31 Jan 2023 22:27:48 +0100 Subject: [PATCH 06/10] :hammer: Sourcery Refactoring --- flask_jeroboam/_inboundhandler.py | 7 ++----- flask_jeroboam/utils.py | 28 ++++++++++++---------------- flask_jeroboam/view_params/solved.py | 28 +++++++++++----------------- 3 files changed, 25 insertions(+), 38 deletions(-) diff --git a/flask_jeroboam/_inboundhandler.py b/flask_jeroboam/_inboundhandler.py index 9aedd57..05aee3e 100644 --- a/flask_jeroboam/_inboundhandler.py +++ b/flask_jeroboam/_inboundhandler.py @@ -86,7 +86,7 @@ def _solve_default_params_location( main_http_verb: str, ) -> ParamLocation: """Return the default FieldInfo for the InboundHandler.""" - if main_http_verb in ("POST", "PUT"): + if main_http_verb in {"POST", "PUT"}: return ParamLocation.body elif main_http_verb == "GET": return ParamLocation.query @@ -176,10 +176,7 @@ def _solve_view_function_parameter( # Solving Required required: bool = default_value is Undefined - # Solving Type - annotation: Any = Any - if not param.annotation == param.empty: - annotation = param.annotation + annotation = param.annotation if param.annotation != param.empty else Any annotation = get_annotation_from_field_info(annotation, view_param, param_name) return SolvedParameter( diff --git a/flask_jeroboam/utils.py b/flask_jeroboam/utils.py index eaddae0..aefaac2 100644 --- a/flask_jeroboam/utils.py +++ b/flask_jeroboam/utils.py @@ -78,19 +78,17 @@ def get_typed_return_annotation(call: Callable[..., Any]) -> Any: # pragma: no def is_scalar_field(field: ModelField) -> bool: """Check if a field is a scalar field.""" field_info = field.field_info - if not ( - field.shape == SHAPE_SINGLETON - and not lenient_issubclass(field.type_, BaseModel) - and not lenient_issubclass(field.type_, sequence_types + (dict,)) - and not dataclasses.is_dataclass(field.type_) - and not isinstance(field_info, ViewParameter) - and not getattr(field_info, "location", None) in body_locations - ): - return False - if field.sub_fields: # pragma: no cover - if not all(is_scalar_field(f) for f in field.sub_fields): - return False - return True + return ( + False + if field.shape != SHAPE_SINGLETON + or lenient_issubclass(field.type_, BaseModel) + or lenient_issubclass(field.type_, sequence_types + (dict,)) + or dataclasses.is_dataclass(field.type_) + or isinstance(field_info, ViewParameter) + or getattr(field_info, "location", None) in body_locations + else not field.sub_fields # pragma: no cover + or all(is_scalar_field(f) for f in field.sub_fields) + ) def is_scalar_sequence_field(field: ModelField) -> bool: @@ -103,9 +101,7 @@ def is_scalar_sequence_field(field: ModelField) -> bool: if not is_scalar_field(sub_field): return False return True - if lenient_issubclass(field.type_, sequence_types): # pragma: no cover - return True - return False + return bool(lenient_issubclass(field.type_, sequence_types)) def _rename_query_params_keys(self, inbound_dict: dict, pattern: str) -> dict: diff --git a/flask_jeroboam/view_params/solved.py b/flask_jeroboam/view_params/solved.py index 8750a76..b79af6c 100644 --- a/flask_jeroboam/view_params/solved.py +++ b/flask_jeroboam/view_params/solved.py @@ -79,18 +79,16 @@ def validate_request(self): errors.append( ErrorWrapper(MissingError(), loc=(self.location.value, self.alias)) ) - return values, errors else: values = {self.name: deepcopy(self.default)} - return values, errors - + return values, errors values_, errors_ = self.validate( inbound_values, values, loc=(self.location.value, self.alias) ) if isinstance(errors_, ErrorWrapper): errors.append(errors_) else: - values.update({self.name: values_}) + values[self.name] = values_ return values, errors def _get_values(self) -> Union[dict, Optional[str], List[Any]]: @@ -109,11 +107,7 @@ def _get_values_from_body(self) -> Any: source = request.files else: source = request.json or {} - if self.embed: - values = source.get(self.alias or self.name) - else: - values = source - return values + return source.get(self.alias or self.name) if self.embed else source def _get_values_from_request(self) -> Union[dict, Optional[str], List[Any]]: """Get the values from the request. @@ -139,10 +133,11 @@ def _get_values_from_request(self) -> Union[dict, Optional[str], List[Any]]: if hasattr(self.type_, "__fields__"): assert isinstance(values, dict) # noqa: S101 for field_name, field in self.type_.__fields__.items(): - if is_scalar_sequence_field(field): - values[field_name] = source.getlist(field.alias or field_name) - else: - values[field_name] = source.get(field.alias or field_name) + values[field_name] = ( + source.getlist(field.alias or field_name) + if is_scalar_sequence_field(field) + else source.get(field.alias or field_name) + ) if values[field_name] is None and getattr( current_app, "query_string_key_transformer", False ): @@ -150,9 +145,8 @@ def _get_values_from_request(self) -> Union[dict, Optional[str], List[Any]]: current_app, source.to_dict() ) values[field_name] = values_.get(field.alias or field_name) + elif is_scalar_sequence_field(self): + values = source.getlist(self.alias or self.name) else: - if is_scalar_sequence_field(self): - values = source.getlist(self.alias or self.name) - else: - values = source.get(self.alias or self.name) + values = source.get(self.alias or self.name) return values From dffc4320e01898a19bf371a47efb120d9a0c9813 Mon Sep 17 00:00:00 2001 From: Jean-Christophe Bianic Date: Thu, 2 Feb 2023 16:20:21 +0100 Subject: [PATCH 07/10] :hammer: :rotating_light: Refactored the Test Suite --- .flake8 | 1 + tests/app_test/__init__.py | 0 tests/app_test/application_factory.py | 57 ++ tests/app_test/apps/__init__.py | 0 tests/app_test/apps/body.py | 28 + tests/app_test/apps/cookies.py | 19 + tests/app_test/apps/file.py | 16 + tests/app_test/apps/form.py | 16 + tests/app_test/apps/header.py | 19 + tests/app_test/apps/misc.py | 33 ++ tests/app_test/apps/outbound.py | 138 +++++ tests/app_test/apps/path.py | 247 ++++++++ tests/app_test/apps/query.py | 93 +++ tests/app_test/models/__init__.py | 0 tests/app_test/models/inbound.py | 36 ++ tests/app_test/models/outbound.py | 41 ++ tests/conftest.py | 371 +----------- .../inbound_handler/test_query_operations.py | 69 --- tests/test_error_handling.py | 86 +-- tests/test_inbound_handler.py | 243 -------- .../test_base_configuration.py | 0 .../test_body_parameter.py | 18 +- .../test_cookie_parameter.py | 8 +- .../test_file_parameter.py | 28 + .../test_inbound_handler/test_form_params.py | 14 + .../test_header_parameter.py | 6 +- .../test_path_operations.py | 12 +- .../test_query_operations.py | 145 +++++ .../test_sub_fields.py | 0 .../test_warnings.py | 0 tests/test_misc.py | 13 + tests/test_outbound_handler.py | 532 ------------------ tests/test_outbound_handler/__init__.py | 0 .../test_outbound_handler.py | 251 +++++++++ .../test_outbound_handler/test_status_code.py | 104 ++++ tests/test_utils.py | 63 +-- 36 files changed, 1369 insertions(+), 1338 deletions(-) create mode 100644 tests/app_test/__init__.py create mode 100644 tests/app_test/application_factory.py create mode 100644 tests/app_test/apps/__init__.py create mode 100644 tests/app_test/apps/body.py create mode 100644 tests/app_test/apps/cookies.py create mode 100644 tests/app_test/apps/file.py create mode 100644 tests/app_test/apps/form.py create mode 100644 tests/app_test/apps/header.py create mode 100644 tests/app_test/apps/misc.py create mode 100644 tests/app_test/apps/outbound.py create mode 100644 tests/app_test/apps/path.py create mode 100644 tests/app_test/apps/query.py create mode 100644 tests/app_test/models/__init__.py create mode 100644 tests/app_test/models/inbound.py create mode 100644 tests/app_test/models/outbound.py delete mode 100644 tests/inbound_handler/test_query_operations.py delete mode 100644 tests/test_inbound_handler.py rename tests/{inbound_handler => test_inbound_handler}/test_base_configuration.py (100%) rename tests/{inbound_handler => test_inbound_handler}/test_body_parameter.py (77%) rename tests/{inbound_handler => test_inbound_handler}/test_cookie_parameter.py (78%) create mode 100644 tests/test_inbound_handler/test_file_parameter.py create mode 100644 tests/test_inbound_handler/test_form_params.py rename tests/{inbound_handler => test_inbound_handler}/test_header_parameter.py (83%) rename tests/{inbound_handler => test_inbound_handler}/test_path_operations.py (97%) create mode 100644 tests/test_inbound_handler/test_query_operations.py rename tests/{inbound_handler => test_inbound_handler}/test_sub_fields.py (100%) rename tests/{inbound_handler => test_inbound_handler}/test_warnings.py (100%) create mode 100644 tests/test_misc.py delete mode 100644 tests/test_outbound_handler.py create mode 100644 tests/test_outbound_handler/__init__.py create mode 100644 tests/test_outbound_handler/test_outbound_handler.py create mode 100644 tests/test_outbound_handler/test_status_code.py diff --git a/.flake8 b/.flake8 index c560eda..4bc241b 100644 --- a/.flake8 +++ b/.flake8 @@ -6,6 +6,7 @@ max-complexity = 10 docstring-convention = google per-file-ignores = tests/*:S101,D100,D205,D415,S106,B008,D101 + tests/app_test/*:D103,B008 __init__.py:F401 typing.py:F401 rst-roles = class,const,func,meth,mod,ref diff --git a/tests/app_test/__init__.py b/tests/app_test/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/app_test/application_factory.py b/tests/app_test/application_factory.py new file mode 100644 index 0000000..56254ef --- /dev/null +++ b/tests/app_test/application_factory.py @@ -0,0 +1,57 @@ +"""Jeroboam Test App factory.""" +import os + +from flask_jeroboam import Jeroboam +from flask_jeroboam.exceptions import InvalidRequest +from flask_jeroboam.exceptions import ResponseValidationError +from flask_jeroboam.exceptions import RessourceNotFound +from flask_jeroboam.exceptions import ServerError + +from .apps.body import router as body_router +from .apps.cookies import router as cookies_router +from .apps.file import router as file_router +from .apps.form import router as form_router +from .apps.header import router as header_router +from .apps.misc import router as misc_router +from .apps.outbound import router as outbound_router +from .apps.path import router as path_router +from .apps.query import router as query_router + + +def create_test_app(): + """Jeroboam test app factory.""" + app = Jeroboam("jeroboam_test", root_path=os.path.dirname(__file__)) + app.config.update( + TESTING=True, + ) + # TODO: Add it by default with CONFIG OPT-OUT + + def handle_404(e): + return {"message": "Not Found"}, 404 + + app.register_error_handler(InvalidRequest, InvalidRequest.handle) + app.register_error_handler(RessourceNotFound, RessourceNotFound.handle) + app.register_error_handler(ServerError, ServerError.handle) + app.register_error_handler(ResponseValidationError, ResponseValidationError.handle) + app.register_error_handler(404, handle_404) + + @app.route("/api_route") + def non_operation(): + return {"message": "Hello World"} + + @app.get("/text") + def get_text(): + return "Hello World" + + # register blueprints. + app.register_blueprint(body_router) + app.register_blueprint(cookies_router) + app.register_blueprint(file_router) + app.register_blueprint(form_router) + app.register_blueprint(header_router) + app.register_blueprint(path_router) + app.register_blueprint(query_router) + app.register_blueprint(misc_router) + app.register_blueprint(outbound_router) + + return app diff --git a/tests/app_test/apps/__init__.py b/tests/app_test/apps/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/app_test/apps/body.py b/tests/app_test/apps/body.py new file mode 100644 index 0000000..9490de5 --- /dev/null +++ b/tests/app_test/apps/body.py @@ -0,0 +1,28 @@ +"""A Test Blueprint for testing Body Params. + +The corresponding test can be found in tests/test_inbound/test_body +""" +from flask_jeroboam import Body +from flask_jeroboam import JeroboamBlueprint +from tests.app_test.models.inbound import SimpleModelIn + + +router = JeroboamBlueprint("body_params_router", __name__) + + +@router.post("/body/int") +def post_body_as_int(payload: int = Body()): + """Body Param as plain int.""" + return {"payload": payload} + + +@router.post("/body/str") +def post_body_as_str(payload: str = Body()): + """Body Param as plain str.""" + return {"payload": payload} + + +@router.post("/body/base_model") +def post_base_model_in_form(payload: SimpleModelIn = Body(embed=False)): + """POST Form Parameter as pydantic BaseModel.""" + return payload.json() diff --git a/tests/app_test/apps/cookies.py b/tests/app_test/apps/cookies.py new file mode 100644 index 0000000..6ccfb74 --- /dev/null +++ b/tests/app_test/apps/cookies.py @@ -0,0 +1,19 @@ +"""A Test Blueprint for testing Cookie Params. + +The corresponding test can be found in tests/test_inbound/test_cookie +""" +from flask_jeroboam import Cookie +from flask_jeroboam import JeroboamBlueprint + + +router = JeroboamBlueprint("cookies_params_router", __name__) + + +@router.get("/cookie/int") +def get_cookie_as_int(cookie: int = Cookie()): + return {"cookie": cookie} + + +@router.get("/cookie/str") +def get_cookie_as_str(cookie: str = Cookie()): + return {"cookie": cookie} diff --git a/tests/app_test/apps/file.py b/tests/app_test/apps/file.py new file mode 100644 index 0000000..2821180 --- /dev/null +++ b/tests/app_test/apps/file.py @@ -0,0 +1,16 @@ +"""A Test Blueprint for testing File Params. + +The corresponding test can be found in tests/test_inbound/test_file +""" +from werkzeug.datastructures import FileStorage + +from flask_jeroboam import File +from flask_jeroboam import JeroboamBlueprint + + +router = JeroboamBlueprint("file_params_router", __name__) + + +@router.post("/file") +def ping(file: FileStorage = File(...)): + return {"file_content": str(file.read())} diff --git a/tests/app_test/apps/form.py b/tests/app_test/apps/form.py new file mode 100644 index 0000000..cfcbe3b --- /dev/null +++ b/tests/app_test/apps/form.py @@ -0,0 +1,16 @@ +"""A Test Blueprint for testing Form Params. + +The corresponding test can be found in tests/test_inbound/test_form +""" +from flask_jeroboam import Form +from flask_jeroboam import JeroboamBlueprint +from tests.app_test.models.inbound import SimpleModelIn + + +router = JeroboamBlueprint("form_params_router", __name__) + + +@router.post("/form/base_model") +def post_base_model_in_form(payload: SimpleModelIn = Form()): + """POST Form Parameter as pydantic BaseModel.""" + return payload.json() diff --git a/tests/app_test/apps/header.py b/tests/app_test/apps/header.py new file mode 100644 index 0000000..6b8d429 --- /dev/null +++ b/tests/app_test/apps/header.py @@ -0,0 +1,19 @@ +"""A Test Blueprint for testing Header Params. + +The corresponding test can be found in tests/test_inbound/test_header +""" +from flask_jeroboam import Header +from flask_jeroboam import JeroboamBlueprint + + +router = JeroboamBlueprint("headers_params_router", __name__) + + +@router.get("/headers/int") +def get_header_as_plain_int(test_header: int = Header()): + return {"header": test_header} + + +@router.get("/headers/str") +def get_header_as_plain_str(test_header: str = Header()): + return {"header": test_header} diff --git a/tests/app_test/apps/misc.py b/tests/app_test/apps/misc.py new file mode 100644 index 0000000..ced8029 --- /dev/null +++ b/tests/app_test/apps/misc.py @@ -0,0 +1,33 @@ +"""A Test Blueprint for testing Misc Behavior.""" + +from flask_jeroboam import JeroboamBlueprint +from flask_jeroboam.exceptions import RessourceNotFound +from flask_jeroboam.exceptions import ServerError + + +router = JeroboamBlueprint("misc_router", __name__) + + +@router.get("/invalid_request") +def get_invalid_request(missing_param: int): + return {"missing_param": missing_param} + + +@router.get("/ressource_not_found") +def get_ressource_not_found(): + raise RessourceNotFound(ressource_name="TestRessource", context=f"with id {id}") + + +@router.get("/generic_ressource") +def get_generic_ressource_not_found(): + raise RessourceNotFound(msg="My Message") + + +@router.get("/server_error") +def get_a_server_error(): + raise ServerError(msg="My Message", error=Exception(), trace="FakeTrace") + + +@router.delete("/delete") +def delete(): + return {} diff --git a/tests/app_test/apps/outbound.py b/tests/app_test/apps/outbound.py new file mode 100644 index 0000000..9156c35 --- /dev/null +++ b/tests/app_test/apps/outbound.py @@ -0,0 +1,138 @@ +"""A Test Blueprint for testing Outbound Behavior. + +The corresponding test can be found in tests/test_outbound.py +""" +from typing import List + +from flask_jeroboam import JeroboamBlueprint +from flask_jeroboam.responses import JSONResponse +from flask_jeroboam.view_params.functions import Body +from tests.app_test.models.outbound import MyDataClass +from tests.app_test.models.outbound import SimpleModelOut +from tests.app_test.models.outbound import UserIn +from tests.app_test.models.outbound import UserOut + + +router = JeroboamBlueprint("outbound_router", __name__) + +valid_outbound_data = {"total_count": 10, "items": ["Apple", "Banana"]} + + +@router.route( + "/methods/explicit_options", + methods=["GET", "OPTIONS"], + response_model=SimpleModelOut, +) +def get_with_explicit_verb_options(): + return valid_outbound_data + + +@router.route( + "/methods/explicit_options_and_head", + methods=["GET", "OPTIONS", "HEAD"], + response_model=SimpleModelOut, +) +def get_with_explicit_verbs_options_and_head(): + return valid_outbound_data + + +@router.put("/verb/put/without_explicit_status_code", response_model=SimpleModelOut) +def put_http_verb(): + return valid_outbound_data + + +@router.patch("/verb/patch/without_explicit_status_code", response_model=SimpleModelOut) +def patch_http_verb(): + return valid_outbound_data + + +@router.get("/response_model/no_response_model") +def no_response_model(): + return "Don't have a response model" + + +@router.get("/response_model/infered_from_return_annotation") +def response_model_is_infered_from_return_annotation() -> SimpleModelOut: + return SimpleModelOut(**valid_outbound_data) + + +@router.get( + "/response_model/configuration_over_inference", response_model=SimpleModelOut +) +def configuration_over_inference() -> dict: + return valid_outbound_data + + +@router.get("/response_model/turned_off", response_model=None) +def response_model_inference_is_turned_off() -> SimpleModelOut: + return SimpleModelOut(**{"total_count": 10, "items": ["Apple", "Banana"]}) + + +@router.get("/return_type/dict", response_model=SimpleModelOut) +def view_function_returns_a_dict() -> dict: + return valid_outbound_data + + +@router.get("/return_type/list", response_model=List[SimpleModelOut]) +def view_function_returns_a_list(): + return [valid_outbound_data, valid_outbound_data] + + +@router.get("/return_type/base_model", response_model=SimpleModelOut) +def view_function_returns_a_base_model(): + return SimpleModelOut(total_count=10, items=["Apple", "Banana"]) + + +@router.get("/return_type/response", response_model=SimpleModelOut) +def view_function_returns_a_response(): + return JSONResponse( + SimpleModelOut(total_count=10, items=["Apple", "Banana"]).json() + ) + + +@router.get("/return_type/dataclass", response_model=SimpleModelOut) +def view_function_returns_a_dataclass(): + return MyDataClass(**valid_outbound_data) + + +@router.get("/return_type/not_valid", response_model=SimpleModelOut) +def view_function_returns_unvalid(): + return "not a list" + + +@router.get("/return_shape/with_headers", response_model=SimpleModelOut) +def view_function_returns_dict_and_headers(): + return valid_outbound_data, {"X-Test": "Test"} + + +@router.get("/return_shape/with_headers_and_status_code", response_model=SimpleModelOut) +def view_function_returns_dict_status_code_and_headers(): + return valid_outbound_data, 218, {"X-Test": "Test"} + + +@router.get("/return_shape/with_status_code", response_model=SimpleModelOut) +def view_function_returns_dict_and_status_code(): + return valid_outbound_data, 218 + + +@router.get("/return_shape/wrong_tuple_length", response_model=SimpleModelOut) +def view_function_returns_wrong_tuple_length(): + return valid_outbound_data, 200, {"X-Test": "Test"}, "extra" + + +@router.get("/status_code/204_has_no_body/as_returned") +def returned_status_code_204_has_no_body_returned(): + return "Some Content that will be ignored", 204 + + +@router.get( + "/status_code/204_has_no_body/as_configured", + status_code=204, +) +def configured_status_code_204_has_no_body(): + return "Some Content that will be ignored" + + +@router.post("/sensitive_data", response_model=UserOut) +def reponse_model_filters_data(sensitive_data: UserIn = Body()): + return sensitive_data diff --git a/tests/app_test/apps/path.py b/tests/app_test/apps/path.py new file mode 100644 index 0000000..b81af70 --- /dev/null +++ b/tests/app_test/apps/path.py @@ -0,0 +1,247 @@ +"""A Test Blueprint for testing Path Params. + +The corresponding test can be found in tests/test_inbound/test_path +""" + +from flask_jeroboam import JeroboamBlueprint +from flask_jeroboam import Path + + +router = JeroboamBlueprint("path_params_router", __name__) + + +@router.get("/path/") +def get_id(item_id): + return {"item_id": item_id} + + +@router.get("/path/str/") +def get_str_id(item_id: str): + return {"item_id": item_id} + + +@router.get("/path/int/") +def get_int_id(item_id: int): + return {"item_id": item_id} + + +@router.get("/path/float/") +def get_float_id(item_id: float): + return {"item_id": item_id} + + +@router.get("/path/bool/") +def get_bool_id(item_id: bool): + return {"item_id": item_id} + + +@router.get("/path/param/") +def get_path_param_id(item_id: str = Path()): + return {"item_id": item_id} + + +@router.get("/path/param-required/") +def get_path_param_required_id(item_id: str = Path()): + return {"item_id": item_id} + + +@router.get("/path/param-minlength/") +def get_path_param_min_length(item_id: str = Path(min_length=3)): + return {"item_id": item_id} + + +@router.get("/path/param-maxlength/") +def get_path_param_max_length(item_id: str = Path(max_length=3)): + return {"item_id": item_id} + + +@router.get("/path/param-min_maxlength/") +def get_path_param_min_max_length(item_id: str = Path(max_length=3, min_length=2)): + return {"item_id": item_id} + + +@router.get("/path/param-gt/") +def get_path_param_gt(item_id: float = Path(gt=3)): + return {"item_id": item_id} + + +@router.get("/path/param-gt0/") +def get_path_param_gt0(item_id: float = Path(gt=0)): + return {"item_id": item_id} + + +@router.get("/path/param-ge/") +def get_path_param_ge(item_id: float = Path(ge=3)): + return {"item_id": item_id} + + +@router.get("/path/param-lt/") +def get_path_param_lt(item_id: float = Path(lt=3)): + return {"item_id": item_id} + + +@router.get("/path/param-lt0/") +def get_path_param_lt0(item_id: float = Path(lt=0)): + return {"item_id": item_id} + + +@router.get("/path/param-le/") +def get_path_param_le(item_id: float = Path(le=3)): + return {"item_id": item_id} + + +@router.get("/path/param-lt-gt/") +def get_path_param_lt_gt(item_id: float = Path(lt=3, gt=1)): + return {"item_id": item_id} + + +@router.get("/path/param-le-ge/") +def get_path_param_le_ge(item_id: float = Path(le=3, ge=1)): + return {"item_id": item_id} + + +@router.get("/path/param-lt-int/") +def get_path_param_lt_int(item_id: int = Path(lt=3)): + return {"item_id": item_id} + + +@router.get("/path/param-gt-int/") +def get_path_param_gt_int(item_id: int = Path(gt=3)): + return {"item_id": item_id} + + +@router.get("/path/param-le-int/") +def get_path_param_le_int(item_id: int = Path(le=3)): + return {"item_id": item_id} + + +@router.get("/path/param-ge-int/") +def get_path_param_ge_int(item_id: int = Path(ge=3)): + return {"item_id": item_id} + + +@router.get("/path/param-lt-gt-int/") +def get_path_param_lt_gt_int(item_id: int = Path(lt=3, gt=1)): + return {"item_id": item_id} + + +@router.get("/path/param-le-ge-int/") +def get_path_param_le_ge_int(item_id: int = Path(le=3, ge=1)): + return {"item_id": item_id} + + +@router.get("/path/with_converter/") +def get_with_preproc_id(item_id): + return {"item_id": item_id} + + +@router.get("/path/with_converter/str/") +def get_with_preproc_str_id(item_id: str): + return {"item_id": item_id} + + +@router.get("/path/with_converter/int/") +def get_with_preproc_int_id(item_id: int): + return {"item_id": item_id} + + +@router.get("/path/with_converter/float/") +def get_with_preproc_float_id(item_id: float): + return {"item_id": item_id} + + +@router.get("/path/with_converter/param/") +def get_with_preproc_path_param_id(item_id: str = Path()): + return {"item_id": item_id} + + +@router.get("/path/with_converter/param-required/") +def get_with_preproc_path_param_required_id(item_id: str = Path()): + return {"item_id": item_id} + + +@router.get("/path/with_converter/param-minlength/") +def get_with_preproc_path_param_min_length(item_id: str = Path(min_length=3)): + return {"item_id": item_id} + + +@router.get("/path/with_converter/param-maxlength/") +def get_with_preproc_path_param_max_length(item_id: str = Path(max_length=3)): + return {"item_id": item_id} + + +@router.get("/path/with_converter/param-min_maxlength/") +def get_with_preproc_path_param_min_max_length( + item_id: str = Path(max_length=3, min_length=2) +): + return {"item_id": item_id} + + +@router.get("/path/with_converter/param-gt/") +def get_with_preproc_path_param_gt(item_id: float = Path(gt=3)): + return {"item_id": item_id} + + +@router.get("/path/with_converter/param-gt0/") +def get_with_preproc_path_param_gt0(item_id: float = Path(gt=0)): + return {"item_id": item_id} + + +@router.get("/path/with_converter/param-ge/") +def get_with_preproc_path_param_ge(item_id: float = Path(ge=3)): + return {"item_id": item_id} + + +@router.get("/path/with_converter/param-lt/") +def get_with_preproc_path_param_lt(item_id: float = Path(lt=3)): + return {"item_id": item_id} + + +@router.get("/path/with_converter/param-lt0/") +def get_with_preproc_path_param_lt0(item_id: float = Path(lt=0)): + return {"item_id": item_id} + + +@router.get("/path/with_converter/param-le/") +def get_with_preproc_path_param_le(item_id: float = Path(le=3)): + return {"item_id": item_id} + + +@router.get("/path/with_converter/param-lt-gt/") +def get_with_preproc_path_param_lt_gt(item_id: float = Path(lt=3, gt=1)): + return {"item_id": item_id} + + +@router.get("/path/with_converter/param-le-ge/") +def get_with_preproc_path_param_le_ge(item_id: float = Path(le=3, ge=1)): + return {"item_id": item_id} + + +@router.get("/path/with_converter/param-lt-int/") +def get_with_preproc_path_param_lt_int(item_id: int = Path(lt=3)): + return {"item_id": item_id} + + +@router.get("/path/with_converter/param-gt-int/") +def get_with_preproc_path_param_gt_int(item_id: int = Path(gt=3)): + return {"item_id": item_id} + + +@router.get("/path/with_converter/param-le-int/") +def get_with_preproc_path_param_le_int(item_id: int = Path(le=3)): + return {"item_id": item_id} + + +@router.get("/path/with_converter/param-ge-int/") +def get_with_preproc_path_param_ge_int(item_id: int = Path(ge=3)): + return {"item_id": item_id} + + +@router.get("/path/with_converter/param-lt-gt-int/") +def get_with_preproc_path_param_lt_gt_int(item_id: int = Path(lt=3, gt=1)): + return {"item_id": item_id} + + +@router.get("/path/with_converter/param-le-ge-int/") +def get_with_preproc_path_param_le_ge_int(item_id: int = Path(le=3, ge=1)): + return {"item_id": item_id} diff --git a/tests/app_test/apps/query.py b/tests/app_test/apps/query.py new file mode 100644 index 0000000..fa7a6fb --- /dev/null +++ b/tests/app_test/apps/query.py @@ -0,0 +1,93 @@ +"""A Test Blueprint for testing Query Params. + +The corresponding test can be found in tests/test_inbound/test_query +""" +from typing import FrozenSet +from typing import Optional + +from flask_jeroboam import JeroboamBlueprint +from flask_jeroboam import Query +from tests.app_test.models.inbound import ModelWithListIn +from tests.app_test.models.inbound import OptionalModelIn +from tests.app_test.models.inbound import QueryStringWithList +from tests.app_test.models.inbound import SimpleModelIn +from tests.app_test.models.outbound import ModelWithListOut + + +router = JeroboamBlueprint("query_params_router", __name__) + + +@router.get("/query/frozenset/") +def get_query_type_frozenset(query: FrozenSet[int] = Query(...)): + return {"query": ",".join(map(str, sorted(query)))} + + +@router.get("/query") +def get_query(query): + return {"query": query} + + +@router.get("/query/optional") +def get_query_optional(query=None): + return {"query": query} + + +@router.get("/query/int") +def get_query_type(query: int): + return {"query": query} + + +@router.get("/query/int/optional") +def get_query_type_optional(query: Optional[int] = None): + return {"query": query} + + +@router.get("/query/int/default") +def get_query_type_int_default(query: int = 10): + return {"query": query} + + +@router.get("/query/param") +def get_query_param(query=Query(default=None)): + return {"query": query} + + +@router.get("/query/param-required") +def get_query_param_required(query=Query()): + return {"query": query} + + +@router.get("/query/param-required/int") +def get_query_param_required_type(query: int = Query()): + return {"query": query} + + +@router.get("/enum-status-code", status_code=201) +def get_enum_status_code(): + return {} + + +@router.get("/query/base_model") +def get_base_model(payload: SimpleModelIn): + """Base Model as Query Param.""" + return payload.json() + + +@router.get("/query/base_model/forward_ref") +def get_base_model_as_forward_ref(payload: "SimpleModelIn"): + return payload.json() + + +@router.get("/query/list_of_strings") +def get_list_of_strings(query_string: QueryStringWithList): + return query_string.json() + + +@router.get("/query/optional_model") +def get_optional_param(payload: Optional[OptionalModelIn]): + return payload.json() if payload else {} + + +@router.get("/query/special_pattern", response_model=ModelWithListOut) +def read_items(payload: ModelWithListIn): + return payload.json() diff --git a/tests/app_test/models/__init__.py b/tests/app_test/models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/app_test/models/inbound.py b/tests/app_test/models/inbound.py new file mode 100644 index 0000000..b310025 --- /dev/null +++ b/tests/app_test/models/inbound.py @@ -0,0 +1,36 @@ +"""Inbound Models for Testing.""" +from typing import List +from typing import Optional + +from pydantic import Field + +from flask_jeroboam import InboundModel + + +class SimpleModelIn(InboundModel): + """Simple InboundModel for Testing Parsing Request.""" + + page: int + type: str + + +class QueryStringWithList(InboundModel): + """A List of ids.""" + + id: List[int] = Field(alias="id[]") + + +class OptionalModelIn(InboundModel): + """a BaseModel with Optional Fields.""" + + page: Optional[int] + per_page: Optional[int] + + +class ModelWithListIn(InboundModel): + """InboundModel with lists.""" + + page: int + per_page: int + ids: List[int] = Field(alias="id[]") + order: List[dict] = Field(alias="order[]") diff --git a/tests/app_test/models/outbound.py b/tests/app_test/models/outbound.py new file mode 100644 index 0000000..3f6e988 --- /dev/null +++ b/tests/app_test/models/outbound.py @@ -0,0 +1,41 @@ +"""Outbound Models for Testing.""" +from dataclasses import dataclass +from typing import List + +from flask_jeroboam import OutboundModel + + +class SimpleModelOut(OutboundModel): + """Base OutBoundModel for Testing.""" + + total_count: int + items: List[str] + + +@dataclass +class MyDataClass: + """A Simple DataClass for Testing.""" + + total_count: int + items: List[str] + + +class UserOut(OutboundModel): + """Only the username must be returned.""" + + username: str + + +class UserIn(UserOut): + """Inbound contains the password.""" + + password: str + + +class ModelWithListOut(OutboundModel): + """OutboundModel with lists.""" + + page: int + per_page: int + ids: List[int] + order: List[dict] diff --git a/tests/conftest.py b/tests/conftest.py index 4ff5bd5..ffa8170 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,346 +1,15 @@ -"""Configuration File for pytest.""" -import os -from typing import FrozenSet -from typing import Optional - +"""Defining Fixtures for the Test Suite.""" import pytest from flask_jeroboam import Jeroboam -from flask_jeroboam import JeroboamBlueprint -from flask_jeroboam.exceptions import InvalidRequest -from flask_jeroboam.exceptions import ResponseValidationError -from flask_jeroboam.exceptions import RessourceNotFound -from flask_jeroboam.exceptions import ServerError -from flask_jeroboam.view_params import Body -from flask_jeroboam.view_params import Header -from flask_jeroboam.view_params import Path -from flask_jeroboam.view_params import Query -from flask_jeroboam.view_params.functions import Cookie - - -@pytest.fixture -def app() -> Jeroboam: - """A Basic Jeroboam Test App.""" - app = Jeroboam("jeroboam_test", root_path=os.path.dirname(__file__)) - app.config.update( - TESTING=True, - SECRET_KEY="RandomSecretKey", - ) - # TODO: Add it by default with CONFIG OPT-OUT - - def handle_404(e): - return {"message": "Not Found"}, 404 - - app.register_error_handler(InvalidRequest, InvalidRequest.handle) - app.register_error_handler(RessourceNotFound, RessourceNotFound.handle) - app.register_error_handler(ServerError, ServerError.handle) - app.register_error_handler(ResponseValidationError, ResponseValidationError.handle) - app.register_error_handler(404, handle_404) - - @app.route("/api_route") - def non_operation(): - return {"message": "Hello World"} - @app.get("/text") - def get_text(): - return "Hello World" - - return app +from .app_test.application_factory import create_test_app @pytest.fixture -def app_with_path_operations(app) -> Jeroboam: # noqa: C901 - """App with reigistered path operations.""" - - @app.get("/path/") - def get_id(item_id): - return {"item_id": item_id} - - @app.get("/path/str/") - def get_str_id(item_id: str): - return {"item_id": item_id} - - @app.get("/path/int/") - def get_int_id(item_id: int): - return {"item_id": item_id} - - @app.get("/path/float/") - def get_float_id(item_id: float): - return {"item_id": item_id} - - @app.get("/path/bool/") - def get_bool_id(item_id: bool): - return {"item_id": item_id} - - @app.get("/path/param/") - def get_path_param_id(item_id: str = Path()): - return {"item_id": item_id} - - @app.get("/path/param-required/") - def get_path_param_required_id(item_id: str = Path()): - return {"item_id": item_id} - - @app.get("/path/param-minlength/") - def get_path_param_min_length(item_id: str = Path(min_length=3)): - return {"item_id": item_id} - - @app.get("/path/param-maxlength/") - def get_path_param_max_length(item_id: str = Path(max_length=3)): - return {"item_id": item_id} - - @app.get("/path/param-min_maxlength/") - def get_path_param_min_max_length(item_id: str = Path(max_length=3, min_length=2)): - return {"item_id": item_id} - - @app.get("/path/param-gt/") - def get_path_param_gt(item_id: float = Path(gt=3)): - return {"item_id": item_id} - - @app.get("/path/param-gt0/") - def get_path_param_gt0(item_id: float = Path(gt=0)): - return {"item_id": item_id} - - @app.get("/path/param-ge/") - def get_path_param_ge(item_id: float = Path(ge=3)): - return {"item_id": item_id} - - @app.get("/path/param-lt/") - def get_path_param_lt(item_id: float = Path(lt=3)): - return {"item_id": item_id} - - @app.get("/path/param-lt0/") - def get_path_param_lt0(item_id: float = Path(lt=0)): - return {"item_id": item_id} - - @app.get("/path/param-le/") - def get_path_param_le(item_id: float = Path(le=3)): - return {"item_id": item_id} - - @app.get("/path/param-lt-gt/") - def get_path_param_lt_gt(item_id: float = Path(lt=3, gt=1)): - return {"item_id": item_id} - - @app.get("/path/param-le-ge/") - def get_path_param_le_ge(item_id: float = Path(le=3, ge=1)): - return {"item_id": item_id} - - @app.get("/path/param-lt-int/") - def get_path_param_lt_int(item_id: int = Path(lt=3)): - return {"item_id": item_id} - - @app.get("/path/param-gt-int/") - def get_path_param_gt_int(item_id: int = Path(gt=3)): - return {"item_id": item_id} - - @app.get("/path/param-le-int/") - def get_path_param_le_int(item_id: int = Path(le=3)): - return {"item_id": item_id} - - @app.get("/path/param-ge-int/") - def get_path_param_ge_int(item_id: int = Path(ge=3)): - return {"item_id": item_id} - - @app.get("/path/param-lt-gt-int/") - def get_path_param_lt_gt_int(item_id: int = Path(lt=3, gt=1)): - return {"item_id": item_id} - - @app.get("/path/param-le-ge-int/") - def get_path_param_le_ge_int(item_id: int = Path(le=3, ge=1)): - return {"item_id": item_id} - - @app.get("/path/with_converter/") - def get_with_preproc_id(item_id): - return {"item_id": item_id} - - @app.get("/path/with_converter/str/") - def get_with_preproc_str_id(item_id: str): - return {"item_id": item_id} - - @app.get("/path/with_converter/int/") - def get_with_preproc_int_id(item_id: int): - return {"item_id": item_id} - - @app.get("/path/with_converter/float/") - def get_with_preproc_float_id(item_id: float): - return {"item_id": item_id} - - @app.get("/path/with_converter/param/") - def get_with_preproc_path_param_id(item_id: str = Path()): - return {"item_id": item_id} - - @app.get("/path/with_converter/param-required/") - def get_with_preproc_path_param_required_id(item_id: str = Path()): - return {"item_id": item_id} - - @app.get("/path/with_converter/param-minlength/") - def get_with_preproc_path_param_min_length(item_id: str = Path(min_length=3)): - return {"item_id": item_id} - - @app.get("/path/with_converter/param-maxlength/") - def get_with_preproc_path_param_max_length(item_id: str = Path(max_length=3)): - return {"item_id": item_id} - - @app.get("/path/with_converter/param-min_maxlength/") - def get_with_preproc_path_param_min_max_length( - item_id: str = Path(max_length=3, min_length=2) - ): - return {"item_id": item_id} - - @app.get("/path/with_converter/param-gt/") - def get_with_preproc_path_param_gt(item_id: float = Path(gt=3)): - return {"item_id": item_id} - - @app.get("/path/with_converter/param-gt0/") - def get_with_preproc_path_param_gt0(item_id: float = Path(gt=0)): - return {"item_id": item_id} - - @app.get("/path/with_converter/param-ge/") - def get_with_preproc_path_param_ge(item_id: float = Path(ge=3)): - return {"item_id": item_id} - - @app.get("/path/with_converter/param-lt/") - def get_with_preproc_path_param_lt(item_id: float = Path(lt=3)): - return {"item_id": item_id} - - @app.get("/path/with_converter/param-lt0/") - def get_with_preproc_path_param_lt0(item_id: float = Path(lt=0)): - return {"item_id": item_id} - - @app.get("/path/with_converter/param-le/") - def get_with_preproc_path_param_le(item_id: float = Path(le=3)): - return {"item_id": item_id} - - @app.get("/path/with_converter/param-lt-gt/") - def get_with_preproc_path_param_lt_gt(item_id: float = Path(lt=3, gt=1)): - return {"item_id": item_id} - - @app.get("/path/with_converter/param-le-ge/") - def get_with_preproc_path_param_le_ge(item_id: float = Path(le=3, ge=1)): - return {"item_id": item_id} - - @app.get("/path/with_converter/param-lt-int/") - def get_with_preproc_path_param_lt_int(item_id: int = Path(lt=3)): - return {"item_id": item_id} - - @app.get("/path/with_converter/param-gt-int/") - def get_with_preproc_path_param_gt_int(item_id: int = Path(gt=3)): - return {"item_id": item_id} - - @app.get("/path/with_converter/param-le-int/") - def get_with_preproc_path_param_le_int(item_id: int = Path(le=3)): - return {"item_id": item_id} - - @app.get("/path/with_converter/param-ge-int/") - def get_with_preproc_path_param_ge_int(item_id: int = Path(ge=3)): - return {"item_id": item_id} - - @app.get("/path/with_converter/param-lt-gt-int/") - def get_with_preproc_path_param_lt_gt_int(item_id: int = Path(lt=3, gt=1)): - return {"item_id": item_id} - - @app.get("/path/with_converter/param-le-ge-int/") - def get_with_preproc_path_param_le_ge_int(item_id: int = Path(le=3, ge=1)): - return {"item_id": item_id} - - return app - - -@pytest.fixture -def app_with_query_operations(app: Jeroboam) -> Jeroboam: # noqa: C901 - """App with reigistered query operations.""" - - @app.get("/query/frozenset/") - def get_query_type_frozenset(query: FrozenSet[int] = Query(...)): - return {"query": ",".join(map(str, sorted(query)))} - - @app.get("/query") - def get_query(query): - return {"query": query} - - @app.get("/query/optional") - def get_query_optional(query=None): - return {"query": query} - - @app.get("/query/int") - def get_query_type(query: int): - return {"query": query} - - @app.get("/query/int/optional") - def get_query_type_optional(query: Optional[int] = None): - return {"query": query} - - @app.get("/query/int/default") - def get_query_type_int_default(query: int = 10): - return {"query": query} - - @app.get("/query/param") - def get_query_param(query=Query(default=None)): - return {"query": query} - - @app.get("/query/param-required") - def get_query_param_required(query=Query()): - return {"query": query} - - @app.get("/query/param-required/int") - def get_query_param_required_type(query: int = Query()): - return {"query": query} - - @app.get("/enum-status-code", status_code=201) - def get_enum_status_code(): - return {} - - return app - - -@pytest.fixture -def app_with_cookie_parameters(app: Jeroboam) -> Jeroboam: - """App with reigistered cookie parameters.""" - - @app.get("/cookie/int") - def get_cookie_as_int(cookie: int = Cookie()): - return {"cookie": cookie} - - @app.get("/cookie/str") - def get_cookie_as_str(cookie: str = Cookie()): - return {"cookie": cookie} - - return app - - -@pytest.fixture -def app_with_header_parameters(app: Jeroboam) -> Jeroboam: - """App with reigistered cookie parameters.""" - - @app.get("/headers/int") - def get_header_as_int(test_header: int = Header()): - return {"header": test_header} - - @app.get("/headers/str") - def get_header_as_str(test_header: str = Header()): - return {"header": test_header} - - return app - - -@pytest.fixture -def app_with_body_parameters(app: Jeroboam) -> Jeroboam: - """App with reigistered cookie parameters.""" - - @app.post("/body/int") - def post_body_as_int(payload: int = Body()): - return {"payload": payload} - - @app.post("/body/str") - def post_body_as_str(payload: str = Body()): - return {"payload": payload} - - return app - - -@pytest.fixture -def blueprint() -> JeroboamBlueprint: - """A Basic Jeroboam Test App.""" - return JeroboamBlueprint("TestBluePrint", __name__) +def app() -> Jeroboam: + """The Jeroboam Test App.""" + return create_test_app() @pytest.fixture @@ -361,33 +30,3 @@ def request_context(app: Jeroboam): def client(app: Jeroboam): """Test Client from the Test App.""" return app.test_client() - - -@pytest.fixture -def query_client(app_with_query_operations: Jeroboam): - """Test Client from the Test App.""" - return app_with_query_operations.test_client() - - -@pytest.fixture -def path_client(app_with_path_operations: Jeroboam): - """Test Client from the Test App.""" - return app_with_path_operations.test_client() - - -@pytest.fixture -def cookie_client(app_with_cookie_parameters: Jeroboam): - """Test Client from the Test App.""" - return app_with_cookie_parameters.test_client() - - -@pytest.fixture -def header_client(app_with_header_parameters: Jeroboam): - """Test Client from the Test App.""" - return app_with_header_parameters.test_client() - - -@pytest.fixture -def body_client(app_with_body_parameters: Jeroboam): - """Test Client from the Test App.""" - return app_with_body_parameters.test_client() diff --git a/tests/inbound_handler/test_query_operations.py b/tests/inbound_handler/test_query_operations.py deleted file mode 100644 index c5b1a02..0000000 --- a/tests/inbound_handler/test_query_operations.py +++ /dev/null @@ -1,69 +0,0 @@ -import pytest - - -response_missing = { - "detail": [ - { - "loc": ["query", "query"], - "msg": "field required", - "type": "value_error.missing", - } - ] -} - -response_not_valid_int = { - "detail": [ - { - "loc": ["query", "query"], - "msg": "value is not a valid integer", - "type": "type_error.integer", - } - ] -} - - -def _valid(value) -> dict: - """Valid function.""" - return {"query": value} - - -@pytest.mark.parametrize( - "url,expected_status,expected_response", - [ - ("/query", 400, response_missing), - ("/query?query=baz", 200, _valid("baz")), - ("/query?not_declared=baz", 400, response_missing), - ("/query/optional", 200, _valid(None)), - ("/query/optional?query=baz", 200, _valid("baz")), - ("/query/optional?not_declared=baz", 200, _valid(None)), - ("/query/int", 400, response_missing), - ("/query/int?query=42", 200, _valid(42)), - ("/query/int?query=42.5", 400, response_not_valid_int), - ("/query/int?query=baz", 400, response_not_valid_int), - ("/query/int?not_declared=baz", 400, response_missing), - ("/query/int/optional", 200, _valid(None)), - ("/query/int/optional?query=50", 200, _valid(50)), - ("/query/int/optional?query=foo", 400, response_not_valid_int), - ("/query/int/default", 200, _valid(10)), - ("/query/int/default?query=50", 200, _valid(50)), - ("/query/int/default?query=foo", 400, response_not_valid_int), - ("/query/param", 200, _valid(None)), - ("/query/param?query=50", 200, _valid("50")), - ("/query/param-required", 400, response_missing), - ("/query/param-required?query=50", 200, _valid("50")), - ("/query/param-required/int", 400, response_missing), - ("/query/param-required/int?query=50", 200, _valid(50)), - ("/query/param-required/int?query=foo", 400, response_not_valid_int), - ("/query/frozenset/?query=1&query=1&query=2", 200, _valid("1,2")), - ], -) -def test_get_query_operations(query_client, url, expected_status, expected_response): - """Testing Various GET operations with query parameters. - - GIVEN a GET endpoint configiured with query parameters - WHEN a request is made to the endpoint - THEN the request is parsed and validated accordingly - """ - response = query_client.get(url) - assert response.status_code == expected_status - assert response.json == expected_response diff --git a/tests/test_error_handling.py b/tests/test_error_handling.py index ef859a6..e464858 100644 --- a/tests/test_error_handling.py +++ b/tests/test_error_handling.py @@ -1,30 +1,17 @@ -"""Testing Error Handling. -We test for both Exception Initialisation and Error Handling. -""" +"""Testing Error Handling.""" from flask.testing import FlaskClient -from flask_jeroboam.exceptions import RessourceNotFound -from flask_jeroboam.exceptions import ServerError -from flask_jeroboam.jeroboam import Jeroboam - def test_invalid_request( - app: Jeroboam, client: FlaskClient, ): - """GIVEN an endpoint - WHEN it raises a InvalidRequest - THEN I get a 404 with InvalidRequest Message + """GIVEN a view function + WHEN it raises a ResponseValidationError + THEN I get a 400 response with message details """ - - @app.get("/invalid_request") - def ping(missing_param: int): - return {} - - r = client.get("invalid_request") - - assert r.status_code == 400 - assert r.json == { + response = client.get("/invalid_request") + assert response.status_code == 400 + assert response.json == { "detail": [ { "loc": ["query", "missing_param"], @@ -35,58 +22,37 @@ def ping(missing_param: int): } -def test_ressource_not_found_named_ressource( - app: Jeroboam, +def test_ressource_not_found( client: FlaskClient, ): - """GIVEN an endpoint - WHEN it raises a ResssourceNotFound - THEN I get a 404 with RessourceNotFound Message + """GIVEN a decorated view function + WHEN it raises a Generic ResssourceNotFound + THEN I get a 404 response with RessourceNotFound Generic Message """ + response = client.get("/ressource_not_found") + assert response.status_code == 404 + assert response.data.startswith(b"RessourceNotFound: TestRessource not found :") - @app.get("/ressource_not_found") - def ping(): - raise RessourceNotFound(ressource_name="TestRessource", context=f"with id {id}") - - r = client.get("/ressource_not_found") - - assert r.status_code == 404 - assert r.data.startswith(b"RessourceNotFound: TestRessource not found :") - -def test_ressource_not_found_generic_message( - app: Jeroboam, +def test_generic_ressource_not_found_message( client: FlaskClient, ): - """GIVEN an endpoint + """GIVEN a decorated view function WHEN it raises a ResssourceNotFound - THEN I get a 404 with RessourceNotFound Generic Message + THEN I get a 404 response with RessourceNotFound Generic Message """ - - @app.get("/generic_ressource") - def generic_ressource(): - raise RessourceNotFound(msg="My Message") - - r = client.get("/generic_ressource") - - assert r.status_code == 404 - assert r.data.startswith(b"RessourceNotFound: My Message") + response = client.get("/generic_ressource") + assert response.status_code == 404 + assert response.data.startswith(b"RessourceNotFound: My Message") def test_internal_server_error( - app: Jeroboam, client: FlaskClient, ): - """GIVEN an endpoint - WHEN it raises a InternalServerError - THEN I get a 500 with RessourceNotFound Generic Message + """GIVEN a decorated view function + WHEN it raises a ServerError + THEN I get a 500 response with ServerError Message """ - - @app.get("/server_error") - def ping(): - raise ServerError(msg="My Message", error=Exception(), trace="FakeTrace") - - r = client.get("/server_error") - - assert r.status_code == 500 - assert r.data.startswith(b"InternalServerError: My Message") + response = client.get("/server_error") + assert response.status_code == 500 + assert response.data.startswith(b"InternalServerError: My Message") diff --git a/tests/test_inbound_handler.py b/tests/test_inbound_handler.py deleted file mode 100644 index ca7468d..0000000 --- a/tests/test_inbound_handler.py +++ /dev/null @@ -1,243 +0,0 @@ -"""Testing Request Parsing Use Cases. -We test for various payload location (QueryString, Data, Json, Files), -verbs (GET, POST, DELETE) and configuration (Lists or Plain)... -""" -from io import BytesIO -from typing import List -from typing import Optional - -from flask.testing import FlaskClient -from pydantic import BaseModel -from pydantic import Field -from werkzeug.datastructures import FileStorage - -from flask_jeroboam.jeroboam import Jeroboam -from flask_jeroboam.models import Parser -from flask_jeroboam.view_params.functions import Body -from flask_jeroboam.view_params.functions import File -from flask_jeroboam.view_params.functions import Form - - -class InBoundModel(BaseModel): - """Base InboundModel for Testing Parsing Request.""" - - page: int - type: str - - -def test_valid_payload_in_query_string_is_injected( - app: Jeroboam, - client: FlaskClient, -): - """GIVEN a GET endpoint with properly annotated argument - WHEN hit with a valid query string - THEN the parsed input is injected into the view function. - """ - - @app.get("/payload_in_query_string") - def read_test(payload: InBoundModel): - return payload.json() - - r = client.get("/payload_in_query_string?page=1&type=item") - - assert r.status_code == 200 - assert r.data == b'{"page": 1, "type": "item"}' - - -def test_forward_ref_in_query_string_is_injected( - app: Jeroboam, - client: FlaskClient, -): - """GIVEN a GET endpoint with properly annotated as forward ref - WHEN hit with a valid query string - THEN the parsed input is injected into the view function. - """ - - @app.get("/payload_in_query_string") - def read_test(payload: "InBoundModel"): - return payload.json() - - r = client.get("/payload_in_query_string?page=1&type=item") - - assert r.status_code == 200 - assert r.data == b'{"page": 1, "type": "item"}' - - -def test_valid_payload_in_json_is_injected( - app: Jeroboam, - client: FlaskClient, -): - """GIVEN a POST endpoint with properly annotated argument - WHEN hit with a valid json payload - THEN the parsed input is injected into the view function. - """ - - @app.post("/payload_in_json") - def read_test(payload: InBoundModel = Body(embed=False)): - return payload.json() - - r = client.post("/payload_in_json", json={"page": 1, "type": "item"}) - - assert r.status_code == 200 - assert r.data == b'{"page": 1, "type": "item"}' - - -def test_valid_payload_in_data_is_injected( - app: Jeroboam, - client: FlaskClient, -): - """GIVEN a POST endpoint with properly annotated argument - WHEN hit with a valid data payload - THEN the parsed input is injected into the view function. - """ - - @app.post("/payload_in_json") - def read_test(payload: InBoundModel = Form()): - return payload.json() - - r = client.post("/payload_in_json", data={"page": 1, "type": "item"}) - - assert r.status_code == 200 - assert r.data == b'{"page": 1, "type": "item"}' - - -def test_valid_payload_in_files_is_injected(app: Jeroboam, client: FlaskClient): - """GIVEN a POST endpoint with properly annotated argument - WHEN hit with a valid files payload - THEN the parsed input is injected into the view function. - """ - - @app.post("/payload_in_file") - def ping(file: FileStorage = File(...)): - return {"file_content": str(file.read())} - - data = {"file": (BytesIO(b"Hello World !!"), "hello.txt"), "type": "file"} - - r = client.post( - "payload_in_file", - data=data, - headers={ - "enctype": "multipart/form-data", - }, - ) - - assert r.status_code == 200 - assert r.json == {"file_content": "b'Hello World !!'"} - - -def test_invalid_query_string_raise_400( - app: Jeroboam, - client: FlaskClient, -): - """GIVEN a GET endpoint with properly annotated argument - WHEN hit with invalid queryString - THEN the endpoint raise a 400 InvalidRequest Error - """ - - @app.get("/strict_endpoint") - def read_test(payload: InBoundModel): - return payload.json() - - r = client.get("/strict_endpoint?page=not_a_valid_param") - - assert r.status_code == 400 - assert r.json == { - "detail": [ - { - "loc": ["query", "payload", "page"], - "msg": "value is not a valid integer", - "type": "type_error.integer", - }, - { - "loc": ["query", "payload", "type"], - "msg": "none is not an allowed value", - "type": "type_error.none.not_allowed", - }, - ] - } - - -def test_invalid_simple_param_raise_400( - app: Jeroboam, - client: FlaskClient, -): - """GIVEN an endpoint that native-type-annotated argument - WHEN hit with a wrong parameters - THEN the endpoint raise a 400 InvalidRequest Error - """ - - @app.get("/query_string_as_int") - def ping(simple_param: int): - return {} - - r = client.get("query_string_as_int?simple_param=imparsable") - - assert r.status_code == 400 - assert r.json == { - "detail": [ - { - "loc": ["query", "simple_param"], - "msg": "value is not a valid integer", - "type": "type_error.integer", - } - ] - } - - -def test_query_string_for_list_arguments( - app: Jeroboam, - client: FlaskClient, -): - """GIVEN a GET endpoint with list arguments - WHEN hit with proper formatted queryString - THEN the arguments get injected into a Array - """ - - class QueryStringWithList(Parser): - id: List[int] = Field(alias="id[]") - - @app.get("/query_string_with_list") - def ping(query_string: QueryStringWithList): - return ",".join([str(id) for id in query_string.id]) - - r = client.get("/query_string_with_list?id[]=1&id[]=2") - assert r.data == b"1,2" - - -def test_optionnal_param( - app: Jeroboam, - client: FlaskClient, -): - """GIVEN an endpoint with Optionnal typed argument with Optional params - WHEN hit with an empty payload - THEN the endpoint is properly executed - """ - - class InBoundModel(BaseModel): - page: Optional[int] - per_page: Optional[int] - - @app.get("/optionnal_param") - def ping(payload: Optional[InBoundModel]): - return payload.json() if payload else {} - - r = client.get("/optionnal_param") - - assert r.status_code == 200 - assert r.data == b'{"page": null, "per_page": null}' - - -def test_other_methods(app: Jeroboam, client: FlaskClient): - """GVIEN an endpoint with a different verb than GET or POST - WHEN hit - THEN it works like a regular endpoint - """ - - @app.delete("/other_verb") - def ping(): - return {}, 201 - - r = client.delete("/other_verb") - - assert r.status_code == 201 - assert r.data == b"{}\n" diff --git a/tests/inbound_handler/test_base_configuration.py b/tests/test_inbound_handler/test_base_configuration.py similarity index 100% rename from tests/inbound_handler/test_base_configuration.py rename to tests/test_inbound_handler/test_base_configuration.py diff --git a/tests/inbound_handler/test_body_parameter.py b/tests/test_inbound_handler/test_body_parameter.py similarity index 77% rename from tests/inbound_handler/test_body_parameter.py rename to tests/test_inbound_handler/test_body_parameter.py index 69be259..4260f5a 100644 --- a/tests/inbound_handler/test_body_parameter.py +++ b/tests/test_inbound_handler/test_body_parameter.py @@ -2,7 +2,7 @@ import pytest -from flask_jeroboam.models import Parser +from flask_jeroboam.models import InboundModel from flask_jeroboam.view_params.functions import Body @@ -25,13 +25,19 @@ def _valid(value) -> dict: @pytest.mark.parametrize( "url,body_value,expected_status,expected_response", [ - ("/body/str", {"payload": "foobar"}, 200, _valid("foobar")), - ("/body/int", {"payload": 123}, 200, _valid(123)), + ("/body/str", {"payload": "foobar"}, 201, _valid("foobar")), + ("/body/int", {"payload": 123}, 201, _valid(123)), ("/body/int", {"payload": "not_a_valid_int"}, 400, response_not_valid_int), + ( + "/body/base_model", + {"page": 1, "type": "item"}, + 201, + {"page": 1, "type": "item"}, + ), ], ) def test_post_body_operations( - body_client, url, body_value, expected_status, expected_response + client, url, body_value, expected_status, expected_response ): """Testing Various GET operations with query parameters. @@ -39,7 +45,7 @@ def test_post_body_operations( WHEN a request is made to the endpoint THEN the request is parsed and validated accordingly """ - response = body_client.post(url, json=body_value) + response = client.post(url, json=body_value) assert response.json == expected_response assert response.status_code == expected_status @@ -47,7 +53,7 @@ def test_post_body_operations( def test_post_body_list_of_base_model(app, client): """Test Body Parameter with POST method.""" - class InBound(Parser): + class InBound(InboundModel): """Inbound model.""" item: str diff --git a/tests/inbound_handler/test_cookie_parameter.py b/tests/test_inbound_handler/test_cookie_parameter.py similarity index 78% rename from tests/inbound_handler/test_cookie_parameter.py rename to tests/test_inbound_handler/test_cookie_parameter.py index dd0c6e0..fa43dc0 100644 --- a/tests/inbound_handler/test_cookie_parameter.py +++ b/tests/test_inbound_handler/test_cookie_parameter.py @@ -24,15 +24,13 @@ def _valid(x): ("/cookie/int", b"not_valid_int", 400, not_a_valid_int), ], ) -def test_get_cookie( - cookie_client, url, cookie_value, expected_status, expected_response -): +def test_get_cookie(client, url, cookie_value, expected_status, expected_response): """Test Cookie Parameter with GET method. TODO: Allow Configuration of the returned Status Code. """ - cookie_client.set_cookie("localhost", "cookie", cookie_value) - response = cookie_client.get(url) + client.set_cookie("localhost", "cookie", cookie_value) + response = client.get(url) assert response.status_code == expected_status assert response.json == expected_response diff --git a/tests/test_inbound_handler/test_file_parameter.py b/tests/test_inbound_handler/test_file_parameter.py new file mode 100644 index 0000000..26e9fa4 --- /dev/null +++ b/tests/test_inbound_handler/test_file_parameter.py @@ -0,0 +1,28 @@ +"""Test the FileParam. + +The corresponding endpoints are defined in the test_app.apps.file.py module. +""" + + +from io import BytesIO + +from flask.testing import FlaskClient + + +def test_valid_payload_in_files_is_injected(client: FlaskClient): + """GIVEN a POST endpoint with FileParam + WHEN hit with a valid files payload + THEN the parsed input is injected into the view function. + """ + data = {"file": (BytesIO(b"Hello World !!"), "hello.txt"), "type": "file"} + + response = client.post( + "/file", + data=data, + headers={ + "enctype": "multipart/form-data", + }, + ) + + assert response.status_code == 201 + assert response.json == {"file_content": "b'Hello World !!'"} diff --git a/tests/test_inbound_handler/test_form_params.py b/tests/test_inbound_handler/test_form_params.py new file mode 100644 index 0000000..81aa950 --- /dev/null +++ b/tests/test_inbound_handler/test_form_params.py @@ -0,0 +1,14 @@ +from flask.testing import FlaskClient + + +def test_valid_payload_in_data_is_injected( + client: FlaskClient, +): + """GIVEN a POST endpoint with a BaseModel FormParam + WHEN hit with a valid json payload + THEN the parsed input is injected into the view function. + """ + response = client.post("/form/base_model", data={"page": 1, "type": "item"}) + + assert response.status_code == 201 + assert response.json == {"page": 1, "type": "item"} diff --git a/tests/inbound_handler/test_header_parameter.py b/tests/test_inbound_handler/test_header_parameter.py similarity index 83% rename from tests/inbound_handler/test_header_parameter.py rename to tests/test_inbound_handler/test_header_parameter.py index 0754159..ccb1b3d 100644 --- a/tests/inbound_handler/test_header_parameter.py +++ b/tests/test_inbound_handler/test_header_parameter.py @@ -24,14 +24,12 @@ def _valid(x): ("/headers/int", {"test-header": "not_a_valid_int"}, 400, not_a_valid_int), ], ) -def test_get_headers( - header_client, url, header_value, expected_status, expected_response -): +def test_get_headers(client, url, header_value, expected_status, expected_response): """Test Cookie Parameter with GET method. TODO: Allow Configuration of the returned Status Code. """ - response = header_client.get(url, headers=header_value) + response = client.get(url, headers=header_value) assert response.status_code == expected_status assert response.json == expected_response diff --git a/tests/inbound_handler/test_path_operations.py b/tests/test_inbound_handler/test_path_operations.py similarity index 97% rename from tests/inbound_handler/test_path_operations.py rename to tests/test_inbound_handler/test_path_operations.py index ba36095..283cd44 100644 --- a/tests/inbound_handler/test_path_operations.py +++ b/tests/test_inbound_handler/test_path_operations.py @@ -247,13 +247,13 @@ def _valid(x): ("/path/param-le-ge-int/2.7", 400, response_not_valid_int), ], ) -def test_get_path(path_client, url, expected_status, expected_response): +def test_get_path(client, url, expected_status, expected_response): """Test Path Operation with GET method. TODO: Allow Configuration of the returned Status Code. """ - response = path_client.get(url) + response = client.get(url) assert response.status_code == expected_status assert response.json == expected_response @@ -316,13 +316,13 @@ def test_get_path(path_client, url, expected_status, expected_response): ("/path/with_converter/param-le-ge-int/4", 400, response_less_than_equal_3), ], ) -def test_get_path_with_converter(path_client, url, expected_status, expected_response): +def test_get_path_with_converter(client, url, expected_status, expected_response): """Test Path Operation with GET method. TODO: Allow Configuration of the returned Status Code. """ - response = path_client.get(url) + response = client.get(url) assert response.json == expected_response assert response.status_code == expected_status @@ -342,7 +342,7 @@ def test_get_path_with_converter(path_client, url, expected_status, expected_res ], ) def test_path_converter_error_override_jeroboam_validation( - path_client, url, expected_status, expected_response + client, url, expected_status, expected_response ): """Test Url Converter Overides PathParams Validation. @@ -350,6 +350,6 @@ def test_path_converter_error_override_jeroboam_validation( WHEN the url is called with a value that does not match the converter THEN the converter error is returned (404), not the Jeroboam one (400) """ - response = path_client.get(url) + response = client.get(url) assert response.json == expected_response assert response.status_code == expected_status diff --git a/tests/test_inbound_handler/test_query_operations.py b/tests/test_inbound_handler/test_query_operations.py new file mode 100644 index 0000000..2a0902a --- /dev/null +++ b/tests/test_inbound_handler/test_query_operations.py @@ -0,0 +1,145 @@ +import pytest +from flask.testing import FlaskClient + +from tests.app_test.models.inbound import OptionalModelIn + + +response_missing = { + "detail": [ + { + "loc": ["query", "query"], + "msg": "field required", + "type": "value_error.missing", + } + ] +} + +response_not_valid_int = { + "detail": [ + { + "loc": ["query", "query"], + "msg": "value is not a valid integer", + "type": "type_error.integer", + } + ] +} + + +def _valid(value) -> dict: + """Valid function.""" + return {"query": value} + + +@pytest.mark.parametrize( + "url,expected_status,expected_response", + [ + ("/query", 400, response_missing), + ("/query?query=baz", 200, _valid("baz")), + ("/query?not_declared=baz", 400, response_missing), + ("/query/optional", 200, _valid(None)), + ("/query/optional?query=baz", 200, _valid("baz")), + ("/query/optional?not_declared=baz", 200, _valid(None)), + ("/query/int", 400, response_missing), + ("/query/int?query=42", 200, _valid(42)), + ("/query/int?query=42.5", 400, response_not_valid_int), + ("/query/int?query=baz", 400, response_not_valid_int), + ("/query/int?not_declared=baz", 400, response_missing), + ("/query/int/optional", 200, _valid(None)), + ("/query/int/optional?query=50", 200, _valid(50)), + ("/query/int/optional?query=foo", 400, response_not_valid_int), + ("/query/int/default", 200, _valid(10)), + ("/query/int/default?query=50", 200, _valid(50)), + ("/query/int/default?query=foo", 400, response_not_valid_int), + ("/query/param", 200, _valid(None)), + ("/query/param?query=50", 200, _valid("50")), + ("/query/param-required", 400, response_missing), + ("/query/param-required?query=50", 200, _valid("50")), + ("/query/param-required/int", 400, response_missing), + ("/query/param-required/int?query=50", 200, _valid(50)), + ("/query/param-required/int?query=foo", 400, response_not_valid_int), + ("/query/frozenset/?query=1&query=1&query=2", 200, _valid("1,2")), + ], +) +def test_get_query_operations(client, url, expected_status, expected_response): + """Testing Various GET operations with query parameters. + + GIVEN a GET endpoint configiured with query parameters + WHEN a request is made to the endpoint + THEN the request is parsed and validated accordingly + """ + response = client.get(url) + assert response.status_code == expected_status + assert response.json == expected_response + + +def test_valid_base_model_as_query_parameter( + client: FlaskClient, +): + """GIVEN a GET endpoint with a BaseModel as QueryParam + WHEN hit with a valid query string + THEN the parsed input is injected into the view function. + """ + response = client.get("/query/base_model?page=1&type=item") + assert response.status_code == 200 + assert response.json == {"page": 1, "type": "item"} + + +def test_valid_base_model_as_forwarded_query_parameter( + client: FlaskClient, +): + """GIVEN a GET endpoint with a BaseModel as Forward Ref QueryParam + WHEN hit with a valid query string + THEN the parsed input is injected into the view function. + """ + response = client.get("/query/base_model/forward_ref?page=1&type=item") + assert response.status_code == 200 + assert response.json == {"page": 1, "type": "item"} + + +def test_invalid_query_string_raise_400( + client: FlaskClient, +): + """GIVEN a GET endpoint with properly annotated argument + WHEN hit with invalid queryString + THEN the endpoint raise a 400 InvalidRequest Error + """ + response = client.get("/query/base_model?page=not_a_valid_param") + assert response.status_code == 400 + assert response.json == { + "detail": [ + { + "loc": ["query", "payload", "page"], + "msg": "value is not a valid integer", + "type": "type_error.integer", + }, + { + "loc": ["query", "payload", "type"], + "msg": "none is not an allowed value", + "type": "type_error.none.not_allowed", + }, + ] + } + + +def test_query_string_for_list_arguments( + client: FlaskClient, +): + """GIVEN a GET endpoint with list arguments + WHEN hit with proper formatted queryString + THEN the arguments get injected into a Array + """ + response = client.get("/query/list_of_strings?id[]=1&id[]=2") + assert response.json == {"id": [1, 2]} + assert response.status_code == 200 + + +def test_query_optionnal_base_model( + client: FlaskClient, +): + """GIVEN an endpoint with Optionnal typed argument with Optional fields + WHEN hit with an empty querystring + THEN the endpoint is properly executed + """ + response = client.get("/query/optional_model") + assert response.status_code == 200 + assert response.json == OptionalModelIn(**{}).dict() diff --git a/tests/inbound_handler/test_sub_fields.py b/tests/test_inbound_handler/test_sub_fields.py similarity index 100% rename from tests/inbound_handler/test_sub_fields.py rename to tests/test_inbound_handler/test_sub_fields.py diff --git a/tests/inbound_handler/test_warnings.py b/tests/test_inbound_handler/test_warnings.py similarity index 100% rename from tests/inbound_handler/test_warnings.py rename to tests/test_inbound_handler/test_warnings.py diff --git a/tests/test_misc.py b/tests/test_misc.py new file mode 100644 index 0000000..7d05252 --- /dev/null +++ b/tests/test_misc.py @@ -0,0 +1,13 @@ +"""Misc Testing for Flask-Jeroboam.""" + +from flask.testing import FlaskClient + + +def test_delete_method(client: FlaskClient): + """GIIVEN an endpoint with a different verb than GET or POST + WHEN hit + THEN it works like a regular endpoint + """ + response = client.delete("/delete") + assert response.status_code == 204 + assert response.data == b"" diff --git a/tests/test_outbound_handler.py b/tests/test_outbound_handler.py deleted file mode 100644 index 3eae048..0000000 --- a/tests/test_outbound_handler.py +++ /dev/null @@ -1,532 +0,0 @@ -"""Testing Outbound Handler Use Cases. -We test for various return values (Response, Dict, ResponseModel), -configuration (response_model or not) and error handling. -""" -import warnings -from dataclasses import dataclass -from typing import List -from unittest.mock import patch - -import pytest -from flask import Response -from flask.testing import FlaskClient -from pydantic import BaseModel - -from flask_jeroboam.jeroboam import Jeroboam -from flask_jeroboam.view_params.functions import Body - - -class OutBoundModel(BaseModel): - """Base OutBoundModel for Testing.""" - - total_count: int - items: List[str] - - -valid_outbound_data = {"total_count": 10, "items": ["Apple", "Banana"]} -valid_response_body = b'{"total_count": 10, "items": ["Apple", "Banana"]}' - - -def test_register_route_with_additionnal_secondary_verb( - app: Jeroboam, - client: FlaskClient, -): - """GIVEN an endpoint with only one main http verb - WHEN configured with secondary verbs (OPTIONS, HEAD) - THEN it register and keeps the main verb - """ - - @app.route( - "/endpoint_with_options", - methods=["GET", "OPTIONS"], - response_model=OutBoundModel, - ) - def with_one_secondary_verb(): - return valid_outbound_data - - @app.route( - "/endpoint_with_options", - methods=["GET", "OPTIONS", "HEAD"], - response_model=OutBoundModel, - ) - def with_two_secondary_verb(): - return valid_outbound_data - - r = client.get("/endpoint_with_options") - assert r.data == valid_response_body - assert app.url_map._rules_by_endpoint["with_one_secondary_verb"][0].methods == { - "GET", - "OPTIONS", - "HEAD", - } - assert app.url_map._rules_by_endpoint["with_two_secondary_verb"][0].methods == { - "GET", - "OPTIONS", - "HEAD", - } - - -def test_register_route_with_two_main_verb_raise_a_warning( - app: Jeroboam, - client: FlaskClient, -): - """GIVEN an endpoint configured with two main http verb - WHEN it is registered - THEN it raises a warning and keep the first one - """ - with pytest.warns(UserWarning): - - @app.route( - "/endpoint_with_stwo_main_verb", - methods=["GET", "POST"], - response_model=OutBoundModel, - ) - def with_two_main_verb(): - return valid_outbound_data - - r = client.get("/endpoint_with_stwo_main_verb") - assert r.data == valid_response_body - - -def test_register_route_with_method_route_and_methods_option_raise_a_exception( - app: Jeroboam, - client: FlaskClient, -): - """GIVEN an endpoint configured with the method_route - WHEN it is registered with the methods option - THEN it raises an Exception - """ - with pytest.raises(TypeError): - - @app.get( - "/route_method_and_methods_option", - methods=["GET"], - response_model=OutBoundModel, - ) - def route_method_and_methods_option(): - return valid_outbound_data - - -def test_endpoint_with_put( - app: Jeroboam, - client: FlaskClient, -): - """GIVEN an endpoint registered wit put - WHEN hit with a put request - THEN it responds with a 201 status code - """ - - @app.put("/put_http_verb", response_model=OutBoundModel) - def put_http_verb(): - return valid_outbound_data - - r = client.put("/put_http_verb") - assert r.status_code == 201 - assert r.data == valid_response_body - - -def test_endpoint_with_patch( - app: Jeroboam, - client: FlaskClient, -): - """GIVEN an endpoint registered wit put - WHEN hit with a patch request - THEN it responds with a 201 status code - """ - - @app.patch("/patch_http_verb", response_model=OutBoundModel) - def patch_http_verb(): - return valid_outbound_data - - r = client.patch("/patch_http_verb") - assert r.status_code == 200 - assert r.data == valid_response_body - - -def test_endpoint_without_response_model( - app: Jeroboam, - client: FlaskClient, -): - """GIVEN an endpoint with no response_model configured or return annotation - WHEN hit - THEN it behaves like a regular Flask endpoint - """ - - @app.get("/no_response_model") - def no_response_model(): - return "Don't have a response model" - - r = client.get("/no_response_model") - assert r.status_code == 200 - assert r.data == b"Don't have a response model" - - -def test_endpoint_with_valid_return_annocation( - app: Jeroboam, - client: FlaskClient, -): - """GIVEN an endpoint without response_model but with a valid return annotation - WHEN registered - THEN the annotation is stored as response_model - """ - - @app.get("/valid_return_annotation") - def valid_return_annotation() -> OutBoundModel: - return OutBoundModel(**valid_outbound_data) - - r = client.get("/valid_return_annotation") - - assert r.status_code == 200 - assert r.data == valid_response_body - - -def test_invalid_response_model_raise_type_error_at_registration( - app: Jeroboam, - client: FlaskClient, -): - """GIVEN an endpoint with invladidly typed response_model - WHEN registered - THEN it raises a TypeError - """ - with pytest.raises(TypeError): - - @app.get("/invalid_return_annotation") - def invalid_return_annotation() -> dict: - return valid_outbound_data - - with pytest.raises(TypeError): - - @app.get("/invalid_configuration", response_model=dict) - def invalid_configuration(): - return valid_outbound_data - - -def test_configured_response_model_take_prescedence_over_return_annotation( - app: Jeroboam, - client: FlaskClient, -): - """GIVEN an endpoint without a configured response_model and a return annotation - WHEN registered - THEN configrued response_model take prescedence over return annotation - """ - - @app.get("/configured_reponse_model_take_prescedence", response_model=OutBoundModel) - def test() -> dict: - return {"total_count": 10, "items": ["Apple", "Banana"]} - - r = client.get("/configured_reponse_model_take_prescedence") - - assert r.status_code == 200 - assert r.data == valid_response_body - - -def test_endpoint_can_turn_off_return_annocation( - app: Jeroboam, - client: FlaskClient, -): - """GIVEN an endpoint with a valid return annotation - WHEN response_model is configured to be None - THEN the response_model registration is turned off - """ - - @app.get("/with_return_annotation_turned_off", response_model=None) - def test() -> OutBoundModel: - return OutBoundModel(**{"total_count": 10, "items": ["Apple", "Banana"]}) - - with pytest.raises(TypeError): - client.get("/with_return_annotation_turned_off") - - -def test_endpoint_with_response_model_and_dict_as_return_value( - app: Jeroboam, - client: FlaskClient, -): - """GIVEN an endpoint with a response_model defined and dict return value - WHEN hit - THEN it serialize the dict using the response_model - #TODO: find a case where it wouldn't pass without the response_model !! - """ - - @app.get("/endpoint_returns_a_dict", response_model=OutBoundModel) - def endpoint_returns_a_dict() -> dict: - return valid_outbound_data - - r = client.get("/endpoint_returns_a_dict") - - assert r.status_code == 200 - assert r.data == valid_response_body - - -def test_endpoint_with_list_as_return_value( - app: Jeroboam, - client: FlaskClient, -): - """GIVEN an endpoint with a response_model as a List - WHEN hit - THEN it serialize the list using the response_model - """ - - @app.get("/endpoint_returns_a_list", response_model=List[OutBoundModel]) - def test(): - return [valid_outbound_data, valid_outbound_data] - - r = client.get("/endpoint_returns_a_list") - - assert r.status_code == 200 - assert ( - r.data == b'[{"total_count": 10, "items": ["Apple", "Banana"]}, ' - b'{"total_count": 10, "items": ["Apple", "Banana"]}]' - ) - - -def test_endpoint_with_response_model_and_response_model_as_return_value( - app: Jeroboam, - client: FlaskClient, -): - """GIVEN an endpoint with a response_model defined and dict return value - WHEN hit - THEN it serialize the dict using the response_model - #TODO: find a case where it wouldn't pass without the response_model !! - """ - - @app.get("/endpoint_returns_a_dict", response_model=OutBoundModel) - def test(): - return OutBoundModel(total_count=10, items=["Apple", "Banana"]) - - r = client.get("/endpoint_returns_a_dict") - - assert r.status_code == 200 - assert r.data == b'{"total_count": 10, "items": ["Apple", "Banana"]}' - - -def test_endpoint_with_response_model_and_response_as_return_value( - app: Jeroboam, - client: FlaskClient, -): - """GIVEN an endpoint with a response_model defined and a Response return value - WHEN hit - THEN it sends the Response untouched - """ - - @app.get("/endpoint_returns_a_response", response_model=OutBoundModel) - def test(): - return Response(OutBoundModel(total_count=10, items=["Apple", "Banana"]).json()) - - r = client.get("/endpoint_returns_a_response") - - assert r.status_code == 200 - assert r.data == b'{"total_count": 10, "items": ["Apple", "Banana"]}' - - -def test_wrong_dict_being_sent( - app: Jeroboam, - client: FlaskClient, -): - """GIVEN an endpoint with a response_model defined and a dict return value - WHEN hit and the return value is not valid - THEN it raises a InternalServerError, 500 - """ - - @app.get("/invalid_return_value", response_model=OutBoundModel) - def ping(): - return {"total_count": "not_valid", "items": ["Apple", "Banana"]} - - r = client.get("/invalid_return_value") - - assert r.status_code == 500 - assert r.data.startswith(b"InternalServerError") - - -def test_status_code_204_has_no_body( - app: Jeroboam, - client: FlaskClient, -): - """GIVEN an endpoint with a response_model defined and a dict return value - WHEN hit and the return value is not valid - THEN it raises a InternalServerError, 500 - """ - - @app.get("/status_code_204_has_no_body", response_model=OutBoundModel) - def status_code_204_has_no_body(): - return "Some Content that will be ignored", 204 - - r = client.get("/status_code_204_has_no_body") - - assert r.status_code == 204 - assert r.data == b"" - - -def test_can_pass_headers_values( - app: Jeroboam, - client: FlaskClient, -): - """GIVEN an endpoint with a response_model defined and a dict return value - WHEN hit and the return value is not valid - THEN it raises a InternalServerError, 500 - """ - - @app.get("/returned_headers_are_sent", response_model=OutBoundModel) - def returned_headers_are_sent(): - return valid_outbound_data, {"X-Test": "Test"} - - r = client.get("/returned_headers_are_sent") - - assert r.status_code == 200 - assert r.headers["X-Test"] == "Test" - - -def test_can_pass_headers_values_and_status_code( - app: Jeroboam, - client: FlaskClient, -): - """GIVEN an endpoint that returns a content, a status code and headers - WHEN hit - THEN the response has the status code and headers - """ - - @app.get("/returned_headers_and_status_code_are_sent", response_model=OutBoundModel) - def returned_headers_are_sent(): - return valid_outbound_data, 201, {"X-Test": "Test"} - - r = client.get("/returned_headers_and_status_code_are_sent") - - assert r.status_code == 201 - assert r.headers["X-Test"] == "Test" - - -@patch("flask_jeroboam._outboundhandler.METHODS_DEFAULT_STATUS_CODE", {"POST": 201}) -def test_exotic_http_verb_raise_a_warning_when_no_status_code_is_set( - app: Jeroboam, - client: FlaskClient, -): - """GIVEN an endpoint with an exotic HTTP verb and no status_code defined - WHEN registered - THEN a User Warning is raised - """ - with pytest.warns(UserWarning): - - @app.get( - "/exotic_http_verb_raise_a_warning", - response_model=OutBoundModel, - ) - def exotic_http_verb(): - return valid_outbound_data - - r = client.get("/exotic_http_verb_raise_a_warning") - - assert r.status_code == 200 - - -@patch("flask_jeroboam._outboundhandler.METHODS_DEFAULT_STATUS_CODE", {"POST": 201}) -def test_exotic_http_verb_dont_raise_a_warning_when_status_code_is_set( - app: Jeroboam, - client: FlaskClient, -): - """GIVEN an endpoint with an exotic HTTP verb and no status_code defined - WHEN registered - THEN a User Warning is raised - """ - with warnings.catch_warnings(): - warnings.simplefilter("error") - - @app.get( - "/exotic_http_verb_dont_raise_a_warning", - response_model=OutBoundModel, - status_code=200, - ) - def exotic_http_verb(): - return valid_outbound_data - - r = client.get("/exotic_http_verb_dont_raise_a_warning") - - assert r.status_code == 200 - - -def test_content_can_be_a_dataclass( - app: Jeroboam, - client: FlaskClient, -): - """GIVEN an endpoint with an exotic HTTP verb and no status_code defined - WHEN registered - THEN a User Warning is raised - """ - - @dataclass - class MyDataClass: - total_count: int - items: List[str] - - @app.get("/response_value_as_dataclass", response_model=OutBoundModel) - def response_value_as_list(): - return MyDataClass(**valid_outbound_data) - - r = client.get("/response_value_as_dataclass") - - assert r.status_code == 200 - assert r.data == valid_response_body - - -def test_content_raise_an_error_if_anything_else( - app: Jeroboam, - client: FlaskClient, -): - """GIVEN an endpoint with an exotic HTTP verb and no status_code defined - WHEN registered - THEN a User Warning is raised - """ - with pytest.raises(ValueError): - - @app.get("/response_value_is_not_valid_format", response_model=OutBoundModel) - def response_value_as_list(): - return "not a list" - - r = client.get("/response_value_is_not_valid_format") - - assert r.status_code == 500 - - -def test_reponse_model_filters_outbound_data_even_when_subclassing( - app: Jeroboam, - client: FlaskClient, -): - """GIVEN an endpoint with an exotic HTTP verb and no status_code defined - WHEN registered - THEN a User Warning is raised - """ - - class SecureOutBoundUser(BaseModel): - username: str - - class InBoundUser(SecureOutBoundUser): - password: str - - @app.post("/filters_data", response_model=SecureOutBoundUser) - def filters_data(sensitive_data: InBoundUser = Body()): - return sensitive_data - - r = client.post( - "/filters_data", - json={"sensitive_data": {"username": "test", "password": "test"}}, - ) - - assert r.status_code == 201 - assert r.json == {"username": "test"} - - -def test_wrong_tuple_length_raise_error( - app: Jeroboam, - client: FlaskClient, -): - """GIVEN an endpoint with an exotic HTTP verb and no status_code defined - WHEN registered - THEN a User Warning is raised - """ - - @app.get("/wrong_tuple_length", response_model=OutBoundModel) - def wrong_tuple_length(): - return valid_outbound_data, 200, {"X-Test": "Test"}, "extra" - - with pytest.raises(TypeError): - r = client.get("/wrong_tuple_length") - - assert r.status_code == 500 diff --git a/tests/test_outbound_handler/__init__.py b/tests/test_outbound_handler/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_outbound_handler/test_outbound_handler.py b/tests/test_outbound_handler/test_outbound_handler.py new file mode 100644 index 0000000..f82f5a9 --- /dev/null +++ b/tests/test_outbound_handler/test_outbound_handler.py @@ -0,0 +1,251 @@ +"""Testing Outbound Handler Use Cases. + +The endpoints are defined in the app_test.apps.outbound.py module. +""" +from typing import Any +from typing import Dict + +import pytest +from flask.testing import FlaskClient + +from flask_jeroboam.jeroboam import Jeroboam +from tests.app_test.models.outbound import SimpleModelOut + + +valid_outbound_data = {"items": ["Apple", "Banana"], "total_count": 10} +valid_response_body = {"items": ["Apple", "Banana"], "totalCount": 10} +unsorted_reponse_body = {"total_count": 10, "items": ["Apple", "Banana"]} + + +@pytest.mark.parametrize( + "url", ["/methods/explicit_options", "/methods/explicit_options_and_head"] +) +def test_register_route_with_additionnal_secondary_verb( + url: str, + client: FlaskClient, +): + """GIVEN an endpoint with secondary methods defined + WHEN configured with secondary verbs (OPTIONS, HEAD) + THEN it register and keeps the main verb + """ + response = client.get(url) + + assert response.json == valid_response_body + assert response.status_code == 200 + + +def test_register_route_with_two_main_verb_raise_a_warning( + app: Jeroboam, + client: FlaskClient, +): + """GIVEN an endpoint configured with two main http verb + WHEN it is registered + THEN it raises a warning and keep the first one + """ + with pytest.warns(UserWarning): + + @app.route( + "/endpoint_with_two_main_verb", + methods=["GET", "POST"], + response_model=SimpleModelOut, + ) + def with_two_main_verb(): + return valid_outbound_data + + response = client.get("/endpoint_with_two_main_verb") + assert response.json == valid_response_body + + +def test_register_route_with_method_route_and_methods_option_raise_a_exception( + app: Jeroboam, +): + """GIVEN an endpoint configured with the method_route + WHEN it is registered with the methods option + THEN it raises an Exception + """ + with pytest.raises(TypeError): + + @app.get( + "/route_method_and_methods_option", + methods=["GET"], + response_model=SimpleModelOut, + ) + def route_method_and_methods_option(): + return valid_outbound_data + + +def test_endpoint_without_response_model( + client: FlaskClient, +): + """GIVEN an endpoint with no response_model configured or return annotation + WHEN hit + THEN it behaves like a regular Flask endpoint + """ + response = client.get("/response_model/no_response_model") + assert response.status_code == 200 + assert response.data == b"Don't have a response model" + + +def test_endpoint_with_valid_return_annocation( + client: FlaskClient, +): + """GIVEN an endpoint without response_model but with a valid return annotation + WHEN registered + THEN the annotation is stored as response_model + """ + response = client.get("/response_model/infered_from_return_annotation") + assert response.status_code == 200 + assert response.json == valid_response_body + + +def test_invalid_response_model_raise_type_error_at_registration( + app: Jeroboam, +): + """GIVEN an endpoint with invalid response_model + WHEN registered + THEN it raises a TypeError + """ + with pytest.raises(TypeError): + + @app.get("/invalid_return_annotation_and_no_response_model") + def invalid_return_annotation() -> dict: + return valid_outbound_data + + with pytest.raises(TypeError): + + @app.get("/invalid_response_model", response_model=dict) + def invalid_configuration(): + return valid_outbound_data + + +def test_configured_response_model_take_prescedence_over_return_annotation( + client: FlaskClient, +): + """GIVEN an endpoint without a configured response_model and a return annotation + WHEN registered + THEN the configrued response_model take prescedence over the return annotation + """ + response = client.get("/response_model/configuration_over_inference") + assert response.status_code == 200 + assert response.json == valid_response_body + + +def test_endpoint_can_turn_off_return_annocation( + client: FlaskClient, +): + """GIVEN an endpoint with a valid return annotation + WHEN response_model is configured to be None + THEN the response_model registration is turned off + """ + with pytest.raises(TypeError): + client.get("/response_model/turned_off") + + +@pytest.mark.parametrize( + "type_,expected_response", + [ + ("dict", valid_response_body), + ("list", [unsorted_reponse_body, unsorted_reponse_body]), + ("base_model", valid_response_body), + ("response", valid_response_body), + ("dataclass", valid_response_body), + ], +) +def test_view_function_with_response_model_return_type( + type_: str, + expected_response: Any, + client: FlaskClient, +): + """GIVEN an endpoint with a response_model defined and dict return value + WHEN hit + THEN it serialize the dict using the response_model + #TODO: find a case where it wouldn't pass without the response_model !! + """ + response = client.get(f"/return_type/{type_}") + + assert response.status_code == 200 + assert response.json == expected_response + + +def test_wrong_dict_being_sent( + app: Jeroboam, + client: FlaskClient, +): + """GIVEN an endpoint with a response_model defined and a dict return value + WHEN hit and the return value is not valid + THEN it raises a InternalServerError, 500 + """ + + @app.get("/invalid_return_value", response_model=SimpleModelOut) + def ping(): + return {"total_count": "not_valid", "items": ["Apple", "Banana"]} + + response = client.get("/invalid_return_value") + + assert response.status_code == 500 + assert response.data.startswith(b"InternalServerError") + + +@pytest.mark.parametrize( + "shape,status_code,headers", + [ + ("with_headers", 200, {"X-Test": "Test"}), + ("with_status_code", 218, {}), + ("with_headers_and_status_code", 218, {"X-Test": "Test"}), + ], +) +def test_view_function_tuple_return_shape( + shape: str, + status_code: int, + headers: Dict[str, str], + client: FlaskClient, +): + """GIVEN an endpoint with a response_model defined and a dict return value + WHEN hit and the return value is not valid + THEN it raises a InternalServerError, 500 + """ + response = client.get(f"/return_shape/{shape}") + assert response.status_code == status_code + assert response.headers.get("X-Test", "Empty") == headers.get("X-Test", "Empty") + + +def test_wrong_tuple_length_raise_error( + client: FlaskClient, +): + """GIVEN a viewfunction with the wrongly shaped tuple (>3) + WHEN hit + THEN it raises a TypeError (as in Flask) and return a code 500 + """ + with pytest.raises(TypeError): + respone = client.get("/return_shape/wrong_tuple_length") + + assert respone.status_code == 500 + + +def test_content_raise_an_error_if_anything_else( + client: FlaskClient, +): + """GIVEN an endpoint with an exotic HTTP verb and no status_code defined + WHEN registered + THEN a Value Warning is raised + """ + with pytest.raises(ValueError): + response = client.get("/return_type/not_valid") + + assert response.status_code == 500 + + +def test_reponse_model_filters_outbound_data_even_when_subclassing( + client: FlaskClient, +): + """GIVEN an endpoint with an exotic HTTP verb and no status_code defined + WHEN registered + THEN a User Warning is raised + """ + response = client.post( + "/sensitive_data", + json={"sensitive_data": {"username": "test", "password": "test"}}, + ) + + assert response.status_code == 201 + assert response.json == {"username": "test"} diff --git a/tests/test_outbound_handler/test_status_code.py b/tests/test_outbound_handler/test_status_code.py new file mode 100644 index 0000000..6a1fe6b --- /dev/null +++ b/tests/test_outbound_handler/test_status_code.py @@ -0,0 +1,104 @@ +"""Testing for status code configuration. + +Endpoints are defined in the app_test.apps.outbound.py module. +""" +import warnings +from unittest.mock import patch + +import pytest +from flask.testing import FlaskClient + +from flask_jeroboam.jeroboam import Jeroboam + +from ..app_test.models.outbound import SimpleModelOut + + +valid_outbound_data = {"items": ["Apple", "Banana"], "total_count": 10} +valid_response_body = {"items": ["Apple", "Banana"], "totalCount": 10} +unsorted_reponse_body = {"total_count": 10, "items": ["Apple", "Banana"]} + + +def test_endpoint_with_put( + client: FlaskClient, +): + """GIVEN an endpoint registered wit put + WHEN hit with a put request + THEN it responds with a 201 status code + """ + response = client.put("/verb/put/without_explicit_status_code") + assert response.status_code == 201 + assert response.json == valid_response_body + + +def test_endpoint_with_patch( + client: FlaskClient, +): + """GIVEN an endpoint registered wit put + WHEN hit with a patch request + THEN it responds with a 201 status code + """ + response = client.patch("/verb/patch/without_explicit_status_code") + assert response.status_code == 200 + assert response.json == valid_response_body + + +@pytest.mark.parametrize("variant", ["as_returned", "as_configured"]) +def test_status_code_204_has_no_body( + variant: str, + client: FlaskClient, +): + """GIVEN a 204 Status Code + WHEN building the response + THEN the response has no body + """ + response = client.get(f"/status_code/204_has_no_body/{variant}") + assert response.status_code == 204 + assert response.data == b"" + + +@patch("flask_jeroboam._outboundhandler.METHODS_DEFAULT_STATUS_CODE", {"POST": 201}) +def test_exotic_http_verb_raise_a_warning_when_no_status_code_is_set( + app: Jeroboam, + client: FlaskClient, +): + """GIVEN an endpoint with an exotic HTTP verb and no status_code defined + WHEN registered + THEN a User Warning is raised + """ + with pytest.warns(UserWarning): + + @app.get( + "/exotic_http_verb_raise_a_warning", + response_model=SimpleModelOut, + ) + def exotic_http_verb(): + return valid_outbound_data + + response = client.get("/exotic_http_verb_raise_a_warning") + + assert response.status_code == 200 + + +@patch("flask_jeroboam._outboundhandler.METHODS_DEFAULT_STATUS_CODE", {"POST": 201}) +def test_exotic_http_verb_dont_raise_a_warning_when_status_code_is_set( + app: Jeroboam, + client: FlaskClient, +): + """GIVEN an endpoint with an exotic HTTP verb and no status_code defined + WHEN registered + THEN no User Warning is raised + """ + with warnings.catch_warnings(): + warnings.simplefilter("error") + + @app.get( + "/exotic_http_verb_dont_raise_a_warning", + response_model=SimpleModelOut, + status_code=200, + ) + def exotic_http_verb(): + return valid_outbound_data + + response = client.get("/exotic_http_verb_dont_raise_a_warning") + + assert response.status_code == 200 diff --git a/tests/test_utils.py b/tests/test_utils.py index 61ac27d..488ebf3 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,17 +1,15 @@ """Testing Utils.""" from functools import partial -from typing import List import pytest from flask.testing import FlaskClient -from pydantic import Field from flask_jeroboam import Body from flask_jeroboam.jeroboam import Jeroboam -from flask_jeroboam.models import Parser -from flask_jeroboam.models import Serializer from flask_jeroboam.utils import _rename_query_params_keys from flask_jeroboam.view_params.solved import SolvedParameter +from tests.app_test.models.inbound import ModelWithListIn +from tests.app_test.models.outbound import ModelWithListOut def test_pascal_case_in_and_out_snake_case(app: Jeroboam, client: FlaskClient): @@ -19,33 +17,23 @@ def test_pascal_case_in_and_out_snake_case(app: Jeroboam, client: FlaskClient): WHEN payload is send in pascalCase THEN it lives in python in snake_case and send back in pascalCase """ - - class OutboundModel(Serializer): - page: int - per_page: int - ids: List[int] - order: List[dict] - - class InboundModel(Parser): - page: int - per_page: int - ids: List[int] = Field(alias="id[]") - order: List[dict] = Field(alias="order[]") - + # We need to define the endpoint here to set the query_string_key_transformer first. app.query_string_key_transformer = partial( _rename_query_params_keys, pattern=r"(.*)\[(.+)\]$" ) - @app.get("/web_boundaries", response_model=OutboundModel) - def read_items(payload: InboundModel): + @app.get( + "/query/special_pattern/after_configuration", response_model=ModelWithListOut + ) + def read_items(payload: ModelWithListIn): return payload - r = client.get( - "web_boundaries?page=1&perPage=10&id[]=1&id[]=2&order[name]=asc&order[age]=desc" + response = client.get( + "/query/special_pattern/after_configuration?page=1&perPage=10&id[]=1&id[]=2&order[name]=asc&order[age]=desc" ) - assert r.status_code == 200 - assert r.json == { + assert response.status_code == 200 + assert response.json == { "page": 1, "perPage": 10, "ids": [1, 2], @@ -53,36 +41,17 @@ def read_items(payload: InboundModel): } -def test_pascal_case_in_and_out_snake_case_without_transformer( - app: Jeroboam, client: FlaskClient -): +def test_pascal_case_in_and_out_snake_case_without_transformer(client: FlaskClient): """GIVEN an endpoint with param typed with a Parser and response_model a Serializer WHEN payload is send in pascalCase THEN it lives in python in snake_case and send back in pascalCase """ - - class OutboundModel(Serializer): - page: int - per_page: int - ids: List[int] - order: List[dict] - - class InboundModel(Parser): - page: int - per_page: int - ids: List[int] = Field(alias="id[]") - order: List[dict] = Field(alias="order[]") - - @app.get("/web_boundaries", response_model=OutboundModel) - def read_items(payload: InboundModel): - return payload - - r = client.get( - "web_boundaries?page=1&perPage=10&id[]=1&id[]=2&order[name]=asc&order[age]=desc" + response = client.get( + "/query/special_pattern?page=1&perPage=10&id[]=1&id[]=2&order[name]=asc&order[age]=desc" ) - assert r.status_code == 400 - assert r.json == { + assert response.status_code == 400 + assert response.json == { "detail": [ { "loc": ["query", "payload", "order[]"], From 95fad35b5c570937d02e0e6603d474e43594dbf6 Mon Sep 17 00:00:00 2001 From: Jean-Christophe Bianic Date: Thu, 2 Feb 2023 16:23:15 +0100 Subject: [PATCH 08/10] :ok_hand: Improved Behavior when no response_model is present + others --- flask_jeroboam/__init__.py | 3 +++ flask_jeroboam/_outboundhandler.py | 16 ++++++++-------- flask_jeroboam/models.py | 4 ++-- flask_jeroboam/view.py | 3 +-- flask_jeroboam/view_params/solved.py | 2 +- 5 files changed, 15 insertions(+), 13 deletions(-) diff --git a/flask_jeroboam/__init__.py b/flask_jeroboam/__init__.py index c602c3a..c573e7a 100644 --- a/flask_jeroboam/__init__.py +++ b/flask_jeroboam/__init__.py @@ -1,6 +1,9 @@ from .jeroboam import Jeroboam from .jeroboam import JeroboamBlueprint +from .models import InboundModel +from .models import OutboundModel from .view_params.functions import Body +from .view_params.functions import Cookie from .view_params.functions import File from .view_params.functions import Form from .view_params.functions import Header diff --git a/flask_jeroboam/_outboundhandler.py b/flask_jeroboam/_outboundhandler.py index 05c139c..db4e803 100644 --- a/flask_jeroboam/_outboundhandler.py +++ b/flask_jeroboam/_outboundhandler.py @@ -78,10 +78,6 @@ def __init__( ) self.response_class = response_class - def is_valid_handler(self) -> bool: - """Should the handler add behavior to the view_function?""" - return bool(self.response_model) - def add_outbound_handling_to( self, view_func: JeroboamRouteCallable ) -> JeroboamRouteCallable: @@ -104,7 +100,6 @@ def outbound_handling(*args: Any, **kwargs: Any) -> JeroboamResponseReturnValue: Credits: this algorithm and subalgorithms are inspired by FastAPI. """ initial_return_value = current_app.ensure_sync(view_func)(*args, **kwargs) - # TODO: Do we need to deal with BackgroundTasks Here ?? if issubclass(initial_return_value.__class__, Response): return initial_return_value ( @@ -117,8 +112,13 @@ def outbound_handling(*args: Any, **kwargs: Any) -> JeroboamResponseReturnValue: return self._build_response( status_code=solved_status_code, headers=headers ) - content = self._serialize_content(returned_body) - return self._build_response(content, solved_status_code, headers=headers) + if self.response_model: + content = self._serialize_content(returned_body) + return self._build_response( + content, solved_status_code, headers=headers + ) + else: + return returned_body, solved_status_code, headers return outbound_handling @@ -139,7 +139,7 @@ def _unpack_view_function_return_value( return ( initial_return_value[0], initial_return_value[1], # type:ignore - None, + {}, ) else: return ( diff --git a/flask_jeroboam/models.py b/flask_jeroboam/models.py index 3c73a23..82f7e45 100644 --- a/flask_jeroboam/models.py +++ b/flask_jeroboam/models.py @@ -47,7 +47,7 @@ def json_dumps_to_camel_case(*args, **kwargs): return json.dumps(*args, **kwargs) -class Parser(BaseModel): +class InboundModel(BaseModel): """Basic configuration for parsing Requests.""" class Config: @@ -57,7 +57,7 @@ class Config: allow_population_by_field_name = True -class Serializer(BaseModel): +class OutboundModel(BaseModel): """Basic Configiration for serializing Responses.""" class Config: diff --git a/flask_jeroboam/view.py b/flask_jeroboam/view.py index 63d5ea4..ebc7428 100644 --- a/flask_jeroboam/view.py +++ b/flask_jeroboam/view.py @@ -64,8 +64,7 @@ def as_view(self) -> JeroboamRouteCallable: if self.inbound_handler.is_valid: view_func = self.inbound_handler.add_inbound_handling_to(view_func) - if self.outbound_handler.is_valid_handler(): - view_func = self.outbound_handler.add_outbound_handling_to(view_func) + view_func = self.outbound_handler.add_outbound_handling_to(view_func) view_func.__name__ = name view_func.__doc__ = doc diff --git a/flask_jeroboam/view_params/solved.py b/flask_jeroboam/view_params/solved.py index b79af6c..04aa5d7 100644 --- a/flask_jeroboam/view_params/solved.py +++ b/flask_jeroboam/view_params/solved.py @@ -44,7 +44,7 @@ def __init__( self.name = name self.location: Optional[ParamLocation] = getattr(view_param, "location", None) if self.location == ParamLocation.file: - BaseConfig.arbitrary_types_allowed = True + model_config.arbitrary_types_allowed = True self.required = required self.embed = getattr(view_param, "embed", None) self.in_body = getattr(view_param, "in_body", None) From 0a2b9d2950bf30160f3a8cf60050f06782b05760 Mon Sep 17 00:00:00 2001 From: Jean-Christophe Bianic Date: Thu, 2 Feb 2023 22:18:10 +0100 Subject: [PATCH 09/10] :hammer: Adressing CodeClimate Complexity Alerts --- .flake8 | 5 +- flask_jeroboam/__init__.py | 14 +- flask_jeroboam/_inboundhandler.py | 39 ++- flask_jeroboam/_outboundhandler.py | 9 +- flask_jeroboam/utils.py | 33 +- flask_jeroboam/view_params/functions.py | 303 ++---------------- flask_jeroboam/view_params/functions.pyi | 158 +++++++++ flask_jeroboam/view_params/parameters.py | 21 +- flask_jeroboam/view_params/solved.py | 104 +++--- poetry.lock | 34 +- pyproject.toml | 1 + tests/app_test/apps/body.py | 4 +- tests/app_test/apps/file.py | 2 +- tests/app_test/apps/form.py | 4 +- tests/app_test/apps/outbound.py | 2 +- tests/app_test/models/inbound.py | 8 + .../test_query_operations.py | 2 +- .../test_outbound_handler.py | 2 +- tests/test_utils.py | 4 +- 19 files changed, 371 insertions(+), 378 deletions(-) create mode 100644 flask_jeroboam/view_params/functions.pyi diff --git a/.flake8 b/.flake8 index 4bc241b..09b3aaf 100644 --- a/.flake8 +++ b/.flake8 @@ -2,12 +2,15 @@ select = B,B9,C,D,DAR,E,F,N,RST,S,W ignore = D104,E203,E501,RST201,RST203,RST301,W503,D105,D107 max-line-length = 80 -max-complexity = 10 +max-complexity = 5 docstring-convention = google per-file-ignores = tests/*:S101,D100,D205,D415,S106,B008,D101 tests/app_test/*:D103,B008 __init__.py:F401 typing.py:F401 + *.pyi:E302,E704,D103,N802 + functions.py:N802 + noxfile.py:C901 rst-roles = class,const,func,meth,mod,ref rst-directives = deprecated diff --git a/flask_jeroboam/__init__.py b/flask_jeroboam/__init__.py index c573e7a..25a36e0 100644 --- a/flask_jeroboam/__init__.py +++ b/flask_jeroboam/__init__.py @@ -2,10 +2,10 @@ from .jeroboam import JeroboamBlueprint from .models import InboundModel from .models import OutboundModel -from .view_params.functions import Body -from .view_params.functions import Cookie -from .view_params.functions import File -from .view_params.functions import Form -from .view_params.functions import Header -from .view_params.functions import Path -from .view_params.functions import Query +from .view_params import Body +from .view_params import Cookie +from .view_params import File +from .view_params import Form +from .view_params import Header +from .view_params import Path +from .view_params import Query diff --git a/flask_jeroboam/_inboundhandler.py b/flask_jeroboam/_inboundhandler.py index 05aee3e..d81d5d3 100644 --- a/flask_jeroboam/_inboundhandler.py +++ b/flask_jeroboam/_inboundhandler.py @@ -153,25 +153,17 @@ def _solve_view_function_parameter( - What is its type/annotation? - Is it a scalar or a sequence? - Is it required and/or has a default value? + # Split it into functions for each step. """ - # Solving Location - if param_name in self.path_param_names: - solved_location = ParamLocation.path - else: - solved_location = getattr( - param.default, "location", force_location or self.default_param_location - ) + solved_location = self._solve_location(param_name, param, force_location) # Get the ViewParam if isinstance(param.default, ViewParameter): view_param = param.default else: param_class = get_parameter_class(solved_location) - view_param = param_class(default=param.default) + view_param = param_class(param.default) - # Solving Default Value - default_value: Any = getattr(param.default, "default", param.default) - if default_value == param.empty or ignore_default: - default_value = Undefined + default_value = self._solve_default_value(param, ignore_default) # Solving Required required: bool = default_value is Undefined @@ -186,6 +178,29 @@ def _solve_view_function_parameter( view_param=view_param, ) + def _solve_location( + self, + param_name: str, + param: inspect.Parameter, + force_location: Optional[ParamLocation] = None, + ) -> ParamLocation: + if param_name in self.path_param_names: + return ParamLocation.path + else: + return getattr( + param.default, "location", force_location or self.default_param_location + ) + + def _solve_default_value( + self, + param: inspect.Parameter, + ignore_default: bool, + ) -> Any: + default_value: Any = getattr(param.default, "default", param.default) + if default_value == param.empty or ignore_default: + default_value = Undefined + return default_value + def _register_view_parameter(self, solved_parameter: SolvedParameter) -> None: """Registering the Solved View parameters for the View Function. diff --git a/flask_jeroboam/_outboundhandler.py b/flask_jeroboam/_outboundhandler.py index db4e803..6fec52a 100644 --- a/flask_jeroboam/_outboundhandler.py +++ b/flask_jeroboam/_outboundhandler.py @@ -112,13 +112,10 @@ def outbound_handling(*args: Any, **kwargs: Any) -> JeroboamResponseReturnValue: return self._build_response( status_code=solved_status_code, headers=headers ) - if self.response_model: - content = self._serialize_content(returned_body) - return self._build_response( - content, solved_status_code, headers=headers - ) - else: + if self.response_model is None: return returned_body, solved_status_code, headers + content = self._serialize_content(returned_body) + return self._build_response(content, solved_status_code, headers=headers) return outbound_handling diff --git a/flask_jeroboam/utils.py b/flask_jeroboam/utils.py index aefaac2..779175f 100644 --- a/flask_jeroboam/utils.py +++ b/flask_jeroboam/utils.py @@ -77,49 +77,38 @@ def get_typed_return_annotation(call: Callable[..., Any]) -> Any: # pragma: no def is_scalar_field(field: ModelField) -> bool: """Check if a field is a scalar field.""" - field_info = field.field_info return ( False if field.shape != SHAPE_SINGLETON or lenient_issubclass(field.type_, BaseModel) or lenient_issubclass(field.type_, sequence_types + (dict,)) or dataclasses.is_dataclass(field.type_) - or isinstance(field_info, ViewParameter) - or getattr(field_info, "location", None) in body_locations + or isinstance(field.field_info, ViewParameter) + or getattr(field.field_info, "location", None) in body_locations else not field.sub_fields # pragma: no cover or all(is_scalar_field(f) for f in field.sub_fields) ) -def is_scalar_sequence_field(field: ModelField) -> bool: +def is_sequence_field(field: ModelField) -> bool: """Check if a field is a sequence field.""" if (field.shape in sequence_shapes) and not lenient_issubclass( field.type_, BaseModel ): - if field.sub_fields is not None: # pragma: no cover - for sub_field in field.sub_fields: - if not is_scalar_field(sub_field): - return False return True return bool(lenient_issubclass(field.type_, sequence_types)) def _rename_query_params_keys(self, inbound_dict: dict, pattern: str) -> dict: - """Rename keys in a dictionary. - - Probablement Obsolete. - """ - renamings = [] - for key, value in inbound_dict.items(): - match = re.match(pattern, key) + """Rename keys in a dictionary.""" + frozen_inbound_dict = inbound_dict.copy() + for old_key, value in frozen_inbound_dict.items(): + match = re.match(pattern, old_key) if match is not None: new_key = f"{match[1]}[]" new_value = {match[2]: value} - renamings.append((key, new_key, new_value)) - for key, new_key, new_value in renamings: - if new_key not in inbound_dict: - inbound_dict[new_key] = [new_value] - else: - inbound_dict[new_key].append(new_value) - del inbound_dict[key] + new_array = inbound_dict.get(new_key, []) + new_array.append(new_value) + inbound_dict[new_key] = new_array + del inbound_dict[old_key] return inbound_dict diff --git a/flask_jeroboam/view_params/functions.py b/flask_jeroboam/view_params/functions.py index 8947ad4..e0ce63b 100644 --- a/flask_jeroboam/view_params/functions.py +++ b/flask_jeroboam/view_params/functions.py @@ -2,335 +2,98 @@ This functions are used to declare the parameters of the view functions. By annotating the return value with Any, we make sure that the code editor -don't complain too much about assigning a default value of type ViewParameter +don't complain about assigning a default value of type ViewParameter to a parameter that have been annotated with a pydantic-compatible type... +They're primary purprose is to trick the code editor as they only delegate +to actual ViewParameter instantiers. +Their signature are defined in adjacent file functions.pyi. -Credits: This module is essentially a fork from the params module of FlaskAPI. +Credits: This module is a fork of FlaskAPI params_function module. """ from typing import Any -from typing import Dict -from typing import Optional - -from pydantic.fields import Undefined from .parameters import BodyParameter from .parameters import CookieParameter from .parameters import FileParameter from .parameters import FormParameter from .parameters import HeaderParameter -from .parameters import ParamLocation from .parameters import PathParameter from .parameters import QueryParameter def Path( # noqa:N802 - default: Any = Undefined, - *, - alias: Optional[str] = None, - title: Optional[str] = None, - description: Optional[str] = None, - gt: Optional[float] = None, - ge: Optional[float] = None, - lt: Optional[float] = None, - le: Optional[float] = None, - min_length: Optional[int] = None, - max_length: Optional[int] = None, - regex: Optional[str] = None, - example: Any = Undefined, - examples: Optional[Dict[str, Any]] = None, - deprecated: Optional[bool] = None, - include_in_schema: bool = True, - **extra: Any, + *args: Any, + **kwargs: Any, ) -> Any: """Declare A Path parameter.""" return PathParameter( - default=default, - alias=alias, - title=title, - description=description, - gt=gt, - ge=ge, - lt=lt, - le=le, - min_length=min_length, - max_length=max_length, - regex=regex, - example=example, - examples=examples, - deprecated=deprecated, - include_in_schema=include_in_schema, - **extra, + *args, + **kwargs, ) def Query( # noqa:N802 - default: Any = Undefined, - *, - alias: Optional[str] = None, - title: Optional[str] = None, - description: Optional[str] = None, - gt: Optional[float] = None, - ge: Optional[float] = None, - lt: Optional[float] = None, - le: Optional[float] = None, - min_length: Optional[int] = None, - max_length: Optional[int] = None, - regex: Optional[str] = None, - example: Any = Undefined, - examples: Optional[Dict[str, Any]] = None, - deprecated: Optional[bool] = None, - include_in_schema: bool = True, - **extra: Any, + *args: Any, + **kwargs: Any, ) -> Any: """Declare A Query parameter.""" return QueryParameter( - default=default, - alias=alias, - title=title, - description=description, - gt=gt, - ge=ge, - lt=lt, - le=le, - min_length=min_length, - max_length=max_length, - regex=regex, - example=example, - examples=examples, - deprecated=deprecated, - include_in_schema=include_in_schema, - **extra, + *args, + **kwargs, ) def Header( # noqa:N802 - default: Any = Undefined, - *, - alias: Optional[str] = None, - title: Optional[str] = None, - description: Optional[str] = None, - gt: Optional[float] = None, - ge: Optional[float] = None, - lt: Optional[float] = None, - le: Optional[float] = None, - min_length: Optional[int] = None, - max_length: Optional[int] = None, - regex: Optional[str] = None, - example: Any = Undefined, - examples: Optional[Dict[str, Any]] = None, - deprecated: Optional[bool] = None, - include_in_schema: bool = True, - convert_underscores: bool = True, # for headers - **extra: Any, + *args: Any, + **kwargs: Any, ) -> Any: """Declare A Header parameter.""" return HeaderParameter( - default=default, - alias=alias, - title=title, - description=description, - gt=gt, - ge=ge, - lt=lt, - le=le, - min_length=min_length, - max_length=max_length, - regex=regex, - example=example, - examples=examples, - deprecated=deprecated, - include_in_schema=include_in_schema, - convert_underscores=convert_underscores, - **extra, + *args, + **kwargs, ) def Cookie( # noqa:N802 - default: Any = Undefined, - *, - alias: Optional[str] = None, - title: Optional[str] = None, - description: Optional[str] = None, - gt: Optional[float] = None, - ge: Optional[float] = None, - lt: Optional[float] = None, - le: Optional[float] = None, - min_length: Optional[int] = None, - max_length: Optional[int] = None, - regex: Optional[str] = None, - example: Any = Undefined, - examples: Optional[Dict[str, Any]] = None, - deprecated: Optional[bool] = None, - include_in_schema: bool = True, - **extra: Any, + *args: Any, + **kwargs: Any, ) -> Any: """Declare A Cookie parameter.""" return CookieParameter( - default=default, - alias=alias, - title=title, - description=description, - gt=gt, - ge=ge, - lt=lt, - le=le, - min_length=min_length, - max_length=max_length, - regex=regex, - example=example, - examples=examples, - deprecated=deprecated, - include_in_schema=include_in_schema, - **extra, + *args, + **kwargs, ) def Body( # noqa:N802 - default: Any = Undefined, - *, - alias: Optional[str] = None, - title: Optional[str] = None, - description: Optional[str] = None, - gt: Optional[float] = None, - ge: Optional[float] = None, - lt: Optional[float] = None, - le: Optional[float] = None, - min_length: Optional[int] = None, - max_length: Optional[int] = None, - regex: Optional[str] = None, - example: Any = Undefined, - examples: Optional[Dict[str, Any]] = None, - deprecated: Optional[bool] = None, - include_in_schema: bool = True, - location: ParamLocation = ParamLocation.query, # for all - required: bool = False, - convert_underscores: bool = True, # for headers - embed: bool = True, # for body - media_type: str = "application/json", - **extra: Any, + *args: Any, + **kwargs: Any, ) -> Any: """Declare A Body parameter.""" return BodyParameter( - default=default, - alias=alias, - title=title, - description=description, - gt=gt, - ge=ge, - lt=lt, - le=le, - min_length=min_length, - max_length=max_length, - regex=regex, - example=example, - examples=examples, - deprecated=deprecated, - include_in_schema=include_in_schema, - location=location, - required=required, - convert_underscores=convert_underscores, - embed=embed, - media_type=media_type, - **extra, + *args, + **kwargs, ) def Form( # noqa: N802 - default: Any = Undefined, - *, - alias: Optional[str] = None, - title: Optional[str] = None, - description: Optional[str] = None, - gt: Optional[float] = None, - ge: Optional[float] = None, - lt: Optional[float] = None, - le: Optional[float] = None, - min_length: Optional[int] = None, - max_length: Optional[int] = None, - regex: Optional[str] = None, - example: Any = Undefined, - examples: Optional[Dict[str, Any]] = None, - deprecated: Optional[bool] = None, - include_in_schema: bool = True, - location: ParamLocation = ParamLocation.query, # for all - required: bool = False, - convert_underscores: bool = True, # for headers - embed: bool = False, # for body - media_type: str = "application/json", - **extra: Any, + *args: Any, + **kwargs: Any, ) -> Any: """Declare A Form parameter.""" return FormParameter( - default=default, - alias=alias, - title=title, - description=description, - gt=gt, - ge=ge, - lt=lt, - le=le, - min_length=min_length, - max_length=max_length, - regex=regex, - example=example, - examples=examples, - deprecated=deprecated, - include_in_schema=include_in_schema, - location=location, - required=required, - convert_underscores=convert_underscores, - embed=embed, - media_type=media_type, - **extra, + *args, + **kwargs, ) def File( # noqa: N802 - default: Any = Undefined, - *, - alias: Optional[str] = None, - title: Optional[str] = None, - description: Optional[str] = None, - gt: Optional[float] = None, - ge: Optional[float] = None, - lt: Optional[float] = None, - le: Optional[float] = None, - min_length: Optional[int] = None, - max_length: Optional[int] = None, - regex: Optional[str] = None, - example: Any = Undefined, - examples: Optional[Dict[str, Any]] = None, - deprecated: Optional[bool] = None, - include_in_schema: bool = True, - location: ParamLocation = ParamLocation.query, # for all - required: bool = False, - convert_underscores: bool = True, # for headers - embed: bool = True, # for body - media_type: str = "application/json", - **extra: Any, + *args: Any, + **kwargs: Any, ) -> Any: """Declare A File parameter.""" return FileParameter( - default=default, - alias=alias, - title=title, - description=description, - gt=gt, - ge=ge, - lt=lt, - le=le, - min_length=min_length, - max_length=max_length, - regex=regex, - example=example, - examples=examples, - deprecated=deprecated, - include_in_schema=include_in_schema, - location=location, - required=required, - convert_underscores=convert_underscores, - embed=embed, - media_type=media_type, - **extra, + *args, + **kwargs, ) diff --git a/flask_jeroboam/view_params/functions.pyi b/flask_jeroboam/view_params/functions.pyi new file mode 100644 index 0000000..e5310fa --- /dev/null +++ b/flask_jeroboam/view_params/functions.pyi @@ -0,0 +1,158 @@ +"""Stub file for the adjacent functions module. + +Note that the default setting behavior is implemented +in each __init__ methods. + +Documentation on stubfiles can be found here: +https://mypy.readthedocs.io/en/stable/stubs.html +""" + +from typing import Any +from typing import Dict +from typing import Optional + +from pydantic.fields import Undefined + +def Path( + default: Any = Undefined, + *, + alias: Optional[str] = None, + title: Optional[str] = None, + description: Optional[str] = None, + gt: Optional[float] = None, + ge: Optional[float] = None, + lt: Optional[float] = None, + le: Optional[float] = None, + min_length: Optional[int] = None, + max_length: Optional[int] = None, + regex: Optional[str] = None, + example: Any = Undefined, + examples: Optional[Dict[str, Any]] = None, + deprecated: Optional[bool] = None, + include_in_schema: bool = True, + **extra: Any, +) -> Any: ... +def Query( + default: Any = Undefined, + *, + alias: Optional[str] = None, + title: Optional[str] = None, + description: Optional[str] = None, + gt: Optional[float] = None, + ge: Optional[float] = None, + lt: Optional[float] = None, + le: Optional[float] = None, + min_length: Optional[int] = None, + max_length: Optional[int] = None, + regex: Optional[str] = None, + example: Any = Undefined, + examples: Optional[Dict[str, Any]] = None, + deprecated: Optional[bool] = None, + include_in_schema: bool = True, + **extra: Any, +) -> Any: ... +def Header( + default: Any = Undefined, + *, + alias: Optional[str] = None, + title: Optional[str] = None, + description: Optional[str] = None, + gt: Optional[float] = None, + ge: Optional[float] = None, + lt: Optional[float] = None, + le: Optional[float] = None, + min_length: Optional[int] = None, + max_length: Optional[int] = None, + regex: Optional[str] = None, + example: Any = Undefined, + examples: Optional[Dict[str, Any]] = None, + deprecated: Optional[bool] = None, + include_in_schema: bool = True, + convert_underscores: bool = True, + **extra: Any, +) -> Any: ... +def Cookie( + default: Any = Undefined, + *, + alias: Optional[str] = None, + title: Optional[str] = None, + description: Optional[str] = None, + gt: Optional[float] = None, + ge: Optional[float] = None, + lt: Optional[float] = None, + le: Optional[float] = None, + min_length: Optional[int] = None, + max_length: Optional[int] = None, + regex: Optional[str] = None, + example: Any = Undefined, + examples: Optional[Dict[str, Any]] = None, + deprecated: Optional[bool] = None, + include_in_schema: bool = True, + **extra: Any, +) -> Any: ... +def Body( + default: Any = Undefined, + *, + alias: Optional[str] = None, + title: Optional[str] = None, + description: Optional[str] = None, + gt: Optional[float] = None, + ge: Optional[float] = None, + lt: Optional[float] = None, + le: Optional[float] = None, + min_length: Optional[int] = None, + max_length: Optional[int] = None, + regex: Optional[str] = None, + example: Any = Undefined, + examples: Optional[Dict[str, Any]] = None, + deprecated: Optional[bool] = None, + include_in_schema: bool = True, + required: bool = False, + embed: bool = True, # for body + media_type: str = "application/json", + **extra: Any, +) -> Any: ... +def Form( + default: Any = Undefined, + *, + alias: Optional[str] = None, + title: Optional[str] = None, + description: Optional[str] = None, + gt: Optional[float] = None, + ge: Optional[float] = None, + lt: Optional[float] = None, + le: Optional[float] = None, + min_length: Optional[int] = None, + max_length: Optional[int] = None, + regex: Optional[str] = None, + example: Any = Undefined, + examples: Optional[Dict[str, Any]] = None, + deprecated: Optional[bool] = None, + include_in_schema: bool = True, + required: bool = False, + embed: bool = True, + media_type: str = "application/x-www-form-urlencoded", + **extra: Any, +) -> Any: ... +def File( # noqa: N802 + default: Any = Undefined, + *, + alias: Optional[str] = None, + title: Optional[str] = None, + description: Optional[str] = None, + gt: Optional[float] = None, + ge: Optional[float] = None, + lt: Optional[float] = None, + le: Optional[float] = None, + min_length: Optional[int] = None, + max_length: Optional[int] = None, + regex: Optional[str] = None, + example: Any = Undefined, + examples: Optional[Dict[str, Any]] = None, + deprecated: Optional[bool] = None, + include_in_schema: bool = True, + required: bool = False, + embed: bool = False, + media_type: str = "multipart/form-data", + **extra: Any, +) -> Any: ... diff --git a/flask_jeroboam/view_params/parameters.py b/flask_jeroboam/view_params/parameters.py index 422b78f..68d8535 100644 --- a/flask_jeroboam/view_params/parameters.py +++ b/flask_jeroboam/view_params/parameters.py @@ -37,8 +37,9 @@ def __init__( self.example = kwargs.pop("example", Undefined) self.examples = kwargs.pop("examples", None) self.embed = kwargs.pop("embed", False) + self.include_in_schema = kwargs.get("include_in_schema", True) super().__init__( - default=default, + default, **kwargs, ) @@ -66,7 +67,7 @@ def __init__( self.deprecated = kwargs.pop("deprecated", None) self.include_in_schema = kwargs.pop("include_in_schema", True) super().__init__( - default=default, + default, **kwargs, ) @@ -84,12 +85,12 @@ class PathParameter(NonBodyParameter): def __init__( self, - default: Any = Undefined, + *args: Any, **kwargs: Any, ): self.required = True super().__init__( - default=..., + ..., **kwargs, ) @@ -106,7 +107,7 @@ def __init__( ): self.convert_underscores = kwargs.pop("convert_underscores", True) super().__init__( - default=default, + default, **kwargs, ) @@ -135,7 +136,7 @@ def __init__( self.embed = kwargs.get("embed", False) self.media_type = kwargs.pop("media_type", "application/json") super().__init__( - default=default, + default, **kwargs, ) @@ -150,10 +151,10 @@ def __init__( default: Any = Undefined, **kwargs: Any, ): - self.media_type = kwargs.pop("media_type", "application/x-www-form-urlencoded") embed = kwargs.pop("embed", True) + self.media_type = kwargs.pop("media_type", "application/x-www-form-urlencoded") super().__init__( - default=default, + default, embed=embed, **kwargs, ) @@ -169,10 +170,10 @@ def __init__( default: Any = Undefined, **kwargs: Any, ): - self.media_type = kwargs.pop("media_type", "multipart/form-data") embed = kwargs.pop("embed", False) + self.media_type = kwargs.pop("media_type", "multipart/form-data") super().__init__( - default=default, + default, embed=embed, **kwargs, ) diff --git a/flask_jeroboam/view_params/solved.py b/flask_jeroboam/view_params/solved.py index 04aa5d7..5e3b7c8 100644 --- a/flask_jeroboam/view_params/solved.py +++ b/flask_jeroboam/view_params/solved.py @@ -17,7 +17,7 @@ from pydantic.fields import ModelField from werkzeug.datastructures import MultiDict -from flask_jeroboam.utils import is_scalar_sequence_field +from flask_jeroboam.utils import is_sequence_field from .parameters import ParamLocation from .parameters import ViewParameter @@ -74,14 +74,14 @@ def validate_request(self): errors = [] assert self.location is not None # noqa: S101 inbound_values = self._get_values() - if inbound_values is None: - if self.required: - errors.append( - ErrorWrapper(MissingError(), loc=(self.location.value, self.alias)) - ) - else: - values = {self.name: deepcopy(self.default)} + if inbound_values is None and self.required: + errors.append( + ErrorWrapper(MissingError(), loc=(self.location.value, self.alias)) + ) return values, errors + # Should I return errors here ? + elif inbound_values is None and not self.required: + inbound_values = deepcopy(self.default) values_, errors_ = self.validate( inbound_values, values, loc=(self.location.value, self.alias) ) @@ -95,8 +95,22 @@ def _get_values(self) -> Union[dict, Optional[str], List[Any]]: """Get the values from the request.""" if self.in_body: return self._get_values_from_body() + elif self.location == ParamLocation.query: + source = request.args + elif self.location == ParamLocation.path: + source = MultiDict(request.view_args) + elif self.location == ParamLocation.header: + source = MultiDict(request.headers) + elif self.location == ParamLocation.cookie: + source = request.cookies else: - return self._get_values_from_request() + raise ValueError("Unknown location") + has_key_transformer = ( + getattr(current_app, "query_string_key_transformer", False) is not None + ) + return self._get_values_from_request( + self, source, self.name, self.alias, has_key_transformer + ) def _get_values_from_body(self) -> Any: """Get the values from the request body.""" @@ -109,7 +123,14 @@ def _get_values_from_body(self) -> Any: source = request.json or {} return source.get(self.alias or self.name) if self.embed else source - def _get_values_from_request(self) -> Union[dict, Optional[str], List[Any]]: + def _get_values_from_request( + self, + field: ModelField, + source: MultiDict, + name: str, + alias: str, + has_key_transformer: bool = False, + ) -> Union[dict, Optional[str], List[Any]]: """Get the values from the request. # TODO: Gestion des alias de fields. @@ -117,36 +138,41 @@ def _get_values_from_request(self) -> Union[dict, Optional[str], List[Any]]: # Est-ce qu'on gère le embed dans les QueryParams ? """ values: Union[dict, Optional[str], List[Any]] = {} - source: MultiDict = MultiDict() - # Decide on the source of the values - if self.location == ParamLocation.query: - source = request.args - elif self.location == ParamLocation.path: - source = MultiDict(request.view_args) - elif self.location == ParamLocation.header: - source = MultiDict(request.headers) - elif self.location == ParamLocation.cookie: - source = request.cookies - else: - raise ValueError("Unknown location") - - if hasattr(self.type_, "__fields__"): + if hasattr(field.type_, "__fields__"): assert isinstance(values, dict) # noqa: S101 - for field_name, field in self.type_.__fields__.items(): - values[field_name] = ( - source.getlist(field.alias or field_name) - if is_scalar_sequence_field(field) - else source.get(field.alias or field_name) + for field_name, subfield in field.type_.__fields__.items(): + values[field_name] = self._get_values_from_request( + subfield, source, field_name, subfield.alias, has_key_transformer ) - if values[field_name] is None and getattr( - current_app, "query_string_key_transformer", False - ): - values_ = current_app.query_string_key_transformer( # type: ignore - current_app, source.to_dict() - ) - values[field_name] = values_.get(field.alias or field_name) - elif is_scalar_sequence_field(self): - values = source.getlist(self.alias or self.name) + elif is_sequence_field(field): + values = _extract_sequence(source, alias, name) + if len(values) == 0 and has_key_transformer: + values = _extract_sequence_with_key_transformer(source, alias, name) else: - values = source.get(self.alias or self.name) + values = _extract_scalar(source, alias, name) return values + + +def _extract_scalar(source: MultiDict, name: Optional[str], alias: Optional[str]): + """Extract a scalar value from a source.""" + return source.get(alias, source.get(name)) + + +def _extract_sequence( + source: MultiDict, name: Optional[str], alias: Optional[str] +) -> List: + """Extract a Sequence value from a source.""" + _values = source.getlist(alias) + if len(_values) == 0: + _values = source.getlist(name) + return _values + + +def _extract_sequence_with_key_transformer( + source: MultiDict, name: Optional[str], alias: Optional[str] +): + """Apply the key transformer to the source.""" + transformed_source = current_app.query_string_key_transformer( # type: ignore + current_app, source.to_dict() + ) + return _extract_scalar(transformed_source, name, alias) diff --git a/poetry.lock b/poetry.lock index 4e78b4f..ce59a2f 100644 --- a/poetry.lock +++ b/poetry.lock @@ -27,6 +27,18 @@ files = [ [package.extras] test = ["coverage", "flake8", "pexpect", "wheel"] +[[package]] +name = "ast-decompiler" +version = "0.7.0" +description = "Python module to decompile AST to Python code" +category = "dev" +optional = false +python-versions = ">=3.6" +files = [ + {file = "ast_decompiler-0.7.0-py3-none-any.whl", hash = "sha256:5ebd37ba129227484daff4a15dd6056d87c488fa372036dd004ee84196b207d3"}, + {file = "ast_decompiler-0.7.0.tar.gz", hash = "sha256:efc3a507e5f8963ec7b4b2ce2ea693e3755c2f52b741c231bc344a4526738337"}, +] + [[package]] name = "attrs" version = "22.2.0" @@ -761,6 +773,26 @@ files = [ flake8 = ">=3" pydocstyle = ">=2.1" +[[package]] +name = "flake8-pyi" +version = "23.1.2" +description = "A plugin for flake8 to enable linting .pyi stub files." +category = "dev" +optional = false +python-versions = ">=3.7" +files = [ + {file = "flake8_pyi-23.1.2-py3-none-any.whl", hash = "sha256:8f6e18eebd395669d3b0531b3e58787617084b0c0229d341ee0fd19d9346d210"}, + {file = "flake8_pyi-23.1.2.tar.gz", hash = "sha256:226866b75e8ae264e47799576ebbac96323cee6b85196bd286baa30a19bddf65"}, +] + +[package.dependencies] +ast-decompiler = {version = ">=0.7.0,<1.0", markers = "python_version < \"3.9\""} +flake8 = ">=3.2.1,<7.0.0" +pyflakes = ">=2.1.1" + +[package.extras] +dev = ["black (==22.12.0)", "flake8-bugbear (==23.1.14)", "flake8-noqa (==1.3.0)", "isort (==5.12.0)", "mypy (==0.991)", "pre-commit-hooks (==4.4.0)", "pytest (==7.2.1)", "types-pyflakes (<4)"] + [[package]] name = "flake8-rst-docstrings" version = "0.3.0" @@ -3002,4 +3034,4 @@ testing = ["flake8 (<5)", "func-timeout", "jaraco.functools", "jaraco.itertools" [metadata] lock-version = "2.0" python-versions = "^3.8.1" -content-hash = "bb1052364826ad72301367c2a131f5badf94c976628950aae2bdc4841c3da47b" +content-hash = "a373b9f4dba970ff0c79da18c5feb86ab4c81f0722c7b917cd819e858732ef1e" diff --git a/pyproject.toml b/pyproject.toml index 0de11a1..e05add5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -71,6 +71,7 @@ typeguard = ">=2.13.3" xdoctest = {extras = ["colors"], version = ">=0.15.10"} certifi = ">=2022.12.7" sourcery-cli = "^1.0.3" +flake8-pyi = "^23.1.2" [tool.poetry.group.docs.dependencies] sphinx = "<5.3.0" diff --git a/tests/app_test/apps/body.py b/tests/app_test/apps/body.py index 9490de5..5979bd5 100644 --- a/tests/app_test/apps/body.py +++ b/tests/app_test/apps/body.py @@ -11,13 +11,13 @@ @router.post("/body/int") -def post_body_as_int(payload: int = Body()): +def post_body_as_int(payload: int = Body(embed=True)): """Body Param as plain int.""" return {"payload": payload} @router.post("/body/str") -def post_body_as_str(payload: str = Body()): +def post_body_as_str(payload: str = Body(embed=True)): """Body Param as plain str.""" return {"payload": payload} diff --git a/tests/app_test/apps/file.py b/tests/app_test/apps/file.py index 2821180..2f09198 100644 --- a/tests/app_test/apps/file.py +++ b/tests/app_test/apps/file.py @@ -12,5 +12,5 @@ @router.post("/file") -def ping(file: FileStorage = File(...)): +def ping(file: FileStorage = File(embed=True)): return {"file_content": str(file.read())} diff --git a/tests/app_test/apps/form.py b/tests/app_test/apps/form.py index cfcbe3b..97dce10 100644 --- a/tests/app_test/apps/form.py +++ b/tests/app_test/apps/form.py @@ -11,6 +11,6 @@ @router.post("/form/base_model") -def post_base_model_in_form(payload: SimpleModelIn = Form()): +def post_base_model_in_form(form_payload: SimpleModelIn = Form(embed=False)): """POST Form Parameter as pydantic BaseModel.""" - return payload.json() + return form_payload.json() diff --git a/tests/app_test/apps/outbound.py b/tests/app_test/apps/outbound.py index 9156c35..3478b63 100644 --- a/tests/app_test/apps/outbound.py +++ b/tests/app_test/apps/outbound.py @@ -134,5 +134,5 @@ def configured_status_code_204_has_no_body(): @router.post("/sensitive_data", response_model=UserOut) -def reponse_model_filters_data(sensitive_data: UserIn = Body()): +def reponse_model_filters_data(sensitive_data: UserIn = Body(embed=True)): return sensitive_data diff --git a/tests/app_test/models/inbound.py b/tests/app_test/models/inbound.py index b310025..9bb3b1b 100644 --- a/tests/app_test/models/inbound.py +++ b/tests/app_test/models/inbound.py @@ -3,6 +3,7 @@ from typing import Optional from pydantic import Field +from pydantic import validator from flask_jeroboam import InboundModel @@ -34,3 +35,10 @@ class ModelWithListIn(InboundModel): per_page: int ids: List[int] = Field(alias="id[]") order: List[dict] = Field(alias="order[]") + + @validator("order") + def order_validator(cls, value): # noqa: B902,N805 + """Validate order.""" + if len(value) == 0: + raise ValueError("Order must have at least 1 value") + return value diff --git a/tests/test_inbound_handler/test_query_operations.py b/tests/test_inbound_handler/test_query_operations.py index 2a0902a..eae4ac9 100644 --- a/tests/test_inbound_handler/test_query_operations.py +++ b/tests/test_inbound_handler/test_query_operations.py @@ -68,8 +68,8 @@ def test_get_query_operations(client, url, expected_status, expected_response): THEN the request is parsed and validated accordingly """ response = client.get(url) - assert response.status_code == expected_status assert response.json == expected_response + assert response.status_code == expected_status def test_valid_base_model_as_query_parameter( diff --git a/tests/test_outbound_handler/test_outbound_handler.py b/tests/test_outbound_handler/test_outbound_handler.py index f82f5a9..ce11929 100644 --- a/tests/test_outbound_handler/test_outbound_handler.py +++ b/tests/test_outbound_handler/test_outbound_handler.py @@ -247,5 +247,5 @@ def test_reponse_model_filters_outbound_data_even_when_subclassing( json={"sensitive_data": {"username": "test", "password": "test"}}, ) - assert response.status_code == 201 assert response.json == {"username": "test"} + assert response.status_code == 201 diff --git a/tests/test_utils.py b/tests/test_utils.py index 488ebf3..5cf6afa 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -55,8 +55,8 @@ def test_pascal_case_in_and_out_snake_case_without_transformer(client: FlaskClie "detail": [ { "loc": ["query", "payload", "order[]"], - "msg": "none is not an allowed value", - "type": "type_error.none.not_allowed", + "msg": "Order must have at least 1 value", + "type": "value_error", } ] } From 9838d1655e89e5a9ef4f43ce666f1e283b094f1b Mon Sep 17 00:00:00 2001 From: Jean-Christophe Bianic Date: Fri, 3 Feb 2023 15:04:14 +0100 Subject: [PATCH 10/10] :hammer: Speciazing SolvedParams --- flask_jeroboam/_inboundhandler.py | 2 +- flask_jeroboam/view_params/helpers.py | 83 +++++++ flask_jeroboam/view_params/parameters.py | 1 + flask_jeroboam/view_params/solved.py | 204 ++++++++++-------- .../test_inbound_handler/test_form_params.py | 2 +- tests/test_utils.py | 2 +- 6 files changed, 206 insertions(+), 88 deletions(-) create mode 100644 flask_jeroboam/view_params/helpers.py diff --git a/flask_jeroboam/_inboundhandler.py b/flask_jeroboam/_inboundhandler.py index d81d5d3..cdd63bd 100644 --- a/flask_jeroboam/_inboundhandler.py +++ b/flask_jeroboam/_inboundhandler.py @@ -171,7 +171,7 @@ def _solve_view_function_parameter( annotation = param.annotation if param.annotation != param.empty else Any annotation = get_annotation_from_field_info(annotation, view_param, param_name) - return SolvedParameter( + return SolvedParameter.specialize( name=param_name, type_=annotation, required=required, diff --git a/flask_jeroboam/view_params/helpers.py b/flask_jeroboam/view_params/helpers.py new file mode 100644 index 0000000..ab2bcaf --- /dev/null +++ b/flask_jeroboam/view_params/helpers.py @@ -0,0 +1,83 @@ +"""Helper functions for extracting values from request locations.""" +from typing import Dict +from typing import List +from typing import Optional +from typing import Union + +from flask import current_app +from werkzeug.datastructures import MultiDict + +from flask_jeroboam.utils import is_sequence_field + + +def _extract_scalar( + *, + source: Union[MultiDict, dict], + name: Optional[str], + alias: Optional[str], + **_kwargs, +): + """Extract a scalar value from a source.""" + return source.get(alias, source.get(name)) + + +def _extract_sequence( + *, source: MultiDict, name: Optional[str], alias: Optional[str], **_kwargs +) -> List: + """Extract a Sequence value from a source.""" + _values = source.getlist(alias) + if len(_values) == 0: + _values = source.getlist(name) + return _values + + +def _extract_sequence_with_key_transformer( + *, source: MultiDict, name: Optional[str], alias: Optional[str], **_kwargs +): + """Apply the key transformer to the source.""" + transformed_source = current_app.query_string_key_transformer( # type: ignore + current_app, source.to_dict() + ) + return _extract_scalar(source=transformed_source, name=name, alias=alias) + + +def _undirected_extraction( + *, + field, + source, + alias: str, + name: str, + has_key_transformer: bool, + **_kwargs, +): + if is_sequence_field(field): + values = _extract_sequence(source=source, name=name, alias=alias) + if len(values) == 0 and has_key_transformer: + values = _extract_sequence_with_key_transformer( + source=source, name=name, alias=alias + ) + else: + values = _extract_scalar(source=source, name=name, alias=alias) + return values + + +def _extract_subfields( + *, + source: MultiDict, + fields: Dict, + **_kwargs, +) -> Dict: + """Extract a Sequence from subfields.""" + values = {} + has_key_transformer = ( + getattr(current_app, "query_string_key_transformer", False) is not None + ) + for field_name, subfield in fields.items(): + values[field_name] = _undirected_extraction( + field=subfield, + source=source, + name=field_name, + alias=subfield.alias, + has_key_transformer=has_key_transformer, + ) + return values diff --git a/flask_jeroboam/view_params/parameters.py b/flask_jeroboam/view_params/parameters.py index 68d8535..9e2cfb2 100644 --- a/flask_jeroboam/view_params/parameters.py +++ b/flask_jeroboam/view_params/parameters.py @@ -22,6 +22,7 @@ class ParamLocation(Enum): body = "body" form = "form" file = "file" + unknown = "unknown" class ViewParameter(FieldInfo): diff --git a/flask_jeroboam/view_params/solved.py b/flask_jeroboam/view_params/solved.py index 5e3b7c8..98d6303 100644 --- a/flask_jeroboam/view_params/solved.py +++ b/flask_jeroboam/view_params/solved.py @@ -1,4 +1,8 @@ -"""View params for solved problems.""" +"""Solved Specialized Params. + +Params are solved at registration time. This way we reduce indirections when +handling requests thus reducing overhead. +""" import re from copy import deepcopy from typing import Any @@ -8,17 +12,20 @@ from typing import Type from typing import Union -from flask import current_app from flask import request from pydantic import BaseConfig from pydantic.error_wrappers import ErrorWrapper from pydantic.errors import MissingError from pydantic.fields import FieldInfo from pydantic.fields import ModelField +from werkzeug.datastructures import FileStorage from werkzeug.datastructures import MultiDict from flask_jeroboam.utils import is_sequence_field +from .helpers import _extract_scalar +from .helpers import _extract_sequence +from .helpers import _extract_subfields from .parameters import ParamLocation from .parameters import ViewParameter @@ -27,7 +34,7 @@ class SolvedParameter(ModelField): - """A Parameter that have been solved, ready to validate data.""" + """Generic Solved Parameter.""" def __init__( self, @@ -50,13 +57,7 @@ def __init__( self.in_body = getattr(view_param, "in_body", None) default = getattr(view_param, "default", field_info.default) class_validators = class_validators or {} - if getattr(view_param, "convert_underscores", False): - self.alias = re.sub( - r"_(\w)", lambda x: f"-{x.group(1).upper()}", self.name.capitalize() - ) - kwargs["alias"] = self.alias - else: - kwargs["alias"] = kwargs.get("alias", getattr(view_param, "alias", None)) + kwargs["alias"] = kwargs.get("alias", getattr(view_param, "alias", None)) super().__init__( name=name, type_=type_, @@ -68,6 +69,42 @@ def __init__( **kwargs, ) + @classmethod + def specialize( + cls, + *, + name: str, + type_: type, + required: bool = False, + view_param: Optional[ViewParameter] = None, + class_validators: Optional[Dict] = None, + model_config: Type[BaseConfig] = BaseConfig, + field_info: FieldInfo = empty_field_info, + **kwargs, + ): + """Specialize the Current class to each location.""" + location = getattr(view_param, "location", ParamLocation.unknown) + target_class = { + ParamLocation.query: SolvedQueryParameter, + ParamLocation.header: SolvedHeaderParameter, + ParamLocation.path: SolvedPathParameter, + ParamLocation.cookie: SolvedCookieParameter, + ParamLocation.body: SolvedBodyParameter, + ParamLocation.file: SolvedFileParameter, + ParamLocation.form: SolvedFormParameter, + }.get(location, cls) + + return target_class( + name=name, + type_=type_, + required=required, + view_param=view_param, + class_validators=class_validators, + model_config=model_config, + field_info=field_info, + **kwargs, + ) + def validate_request(self): """Validate the request.""" values = {} @@ -79,7 +116,6 @@ def validate_request(self): ErrorWrapper(MissingError(), loc=(self.location.value, self.alias)) ) return values, errors - # Should I return errors here ? elif inbound_values is None and not self.required: inbound_values = deepcopy(self.default) values_, errors_ = self.validate( @@ -91,88 +127,86 @@ def validate_request(self): values[self.name] = values_ return values, errors + def _get_values( + self, + ) -> Union[FileStorage, MultiDict, dict, Optional[str], List[Any]]: + """The Value extraction method is specialized by location.""" + raise NotImplementedError + + +class SolvedPathParameter(SolvedParameter): + """Solved Path parameter.""" + def _get_values(self) -> Union[dict, Optional[str], List[Any]]: - """Get the values from the request.""" - if self.in_body: - return self._get_values_from_body() - elif self.location == ParamLocation.query: - source = request.args - elif self.location == ParamLocation.path: - source = MultiDict(request.view_args) - elif self.location == ParamLocation.header: - source = MultiDict(request.headers) - elif self.location == ParamLocation.cookie: - source = request.cookies - else: - raise ValueError("Unknown location") - has_key_transformer = ( - getattr(current_app, "query_string_key_transformer", False) is not None - ) - return self._get_values_from_request( - self, source, self.name, self.alias, has_key_transformer + source: dict = request.view_args or {} + return _extract_scalar(source=source, alias=self.alias, name=self.name) + + +class SolvedHeaderParameter(SolvedParameter): + """Solved Header parameter.""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.alias = re.sub( + r"_(\w)", lambda x: f"-{x.group(1).upper()}", self.name.capitalize() ) - def _get_values_from_body(self) -> Any: - """Get the values from the request body.""" - source: Any = {} - if self.location == ParamLocation.form: - source = request.form - elif self.location == ParamLocation.file: - source = request.files - else: - source = request.json or {} - return source.get(self.alias or self.name) if self.embed else source + def _get_values(self) -> Union[dict, Optional[str], List[Any]]: + source: dict = request.headers or {} + return _extract_scalar(source=source, alias=self.alias, name=self.name) - def _get_values_from_request( - self, - field: ModelField, - source: MultiDict, - name: str, - alias: str, - has_key_transformer: bool = False, - ) -> Union[dict, Optional[str], List[Any]]: - """Get the values from the request. - - # TODO: Gestion des alias de fields. - # TODO: Gestion des default empty et des valeurs manquantes. - # Est-ce qu'on gère le embed dans les QueryParams ? - """ - values: Union[dict, Optional[str], List[Any]] = {} - if hasattr(field.type_, "__fields__"): - assert isinstance(values, dict) # noqa: S101 - for field_name, subfield in field.type_.__fields__.items(): - values[field_name] = self._get_values_from_request( - subfield, source, field_name, subfield.alias, has_key_transformer - ) - elif is_sequence_field(field): - values = _extract_sequence(source, alias, name) - if len(values) == 0 and has_key_transformer: - values = _extract_sequence_with_key_transformer(source, alias, name) + +class SolvedCookieParameter(SolvedParameter): + """Solved Cookie parameter.""" + + def _get_values(self) -> Union[dict, Optional[str], List[Any]]: + source: dict = request.cookies or {} + return _extract_scalar(source=source, alias=self.alias, name=self.name) + + +class SolvedQueryParameter(SolvedParameter): + """Solved Query parameter.""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + if hasattr(self.type_, "__fields__"): + self.extractor = _extract_subfields + elif is_sequence_field(self): + self.extractor = _extract_sequence else: - values = _extract_scalar(source, alias, name) - return values + self.extractor = _extract_scalar + def _get_values(self) -> Union[dict, Optional[str], List[Any]]: + source: MultiDict = request.args + return self.extractor( + source=source, + alias=self.alias, + name=self.name, + fields=getattr(self.type_, "__fields__", {}), + ) -def _extract_scalar(source: MultiDict, name: Optional[str], alias: Optional[str]): - """Extract a scalar value from a source.""" - return source.get(alias, source.get(name)) +class SolvedBodyParameter(SolvedParameter): + """Solved Scalar Query parameter.""" + + def _get_values(self) -> Union[dict, Optional[str], List[Any]]: + source: dict = request.json or {} + return source.get(self.alias or self.name) if self.embed else source -def _extract_sequence( - source: MultiDict, name: Optional[str], alias: Optional[str] -) -> List: - """Extract a Sequence value from a source.""" - _values = source.getlist(alias) - if len(_values) == 0: - _values = source.getlist(name) - return _values +class SolvedFileParameter(SolvedParameter): + """Solved File Parameter.""" -def _extract_sequence_with_key_transformer( - source: MultiDict, name: Optional[str], alias: Optional[str] -): - """Apply the key transformer to the source.""" - transformed_source = current_app.query_string_key_transformer( # type: ignore - current_app, source.to_dict() - ) - return _extract_scalar(transformed_source, name, alias) + def _get_values( + self, + ) -> Union[FileStorage, MultiDict, dict, Optional[str], List[Any]]: + source: MultiDict = request.files or MultiDict() + return source.get(self.alias or self.name) if self.embed else source + + +class SolvedFormParameter(SolvedParameter): + """Solved Form parameter.""" + + def _get_values(self) -> Union[dict, Optional[str], List[Any]]: + source: MultiDict = request.form or MultiDict() + return source.get(self.alias or self.name) if self.embed else source diff --git a/tests/test_inbound_handler/test_form_params.py b/tests/test_inbound_handler/test_form_params.py index 81aa950..3945731 100644 --- a/tests/test_inbound_handler/test_form_params.py +++ b/tests/test_inbound_handler/test_form_params.py @@ -10,5 +10,5 @@ def test_valid_payload_in_data_is_injected( """ response = client.post("/form/base_model", data={"page": 1, "type": "item"}) - assert response.status_code == 201 assert response.json == {"page": 1, "type": "item"} + assert response.status_code == 201 diff --git a/tests/test_utils.py b/tests/test_utils.py index 5cf6afa..070ab96 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -75,5 +75,5 @@ def test_solved_param_erroring(): solved_param = SolvedParameter( name="FaultySolvedParam", type_=str, view_param=param ) - with pytest.raises(ValueError): + with pytest.raises(NotImplementedError): solved_param._get_values()