@@ -1,5 +1,5 @@
from dataclasses import asdict
from datetime import date, datetime
import datetime
from dataclasses import asdict, field
from typing import Any, Dict, List, Optional, Union, cast

import httpx
@@ -13,7 +13,7 @@


async def get_user_list(
*, client: Client, an_enum_value: List[AnEnum], some_date: Union[date, datetime],
*, client: Client, an_enum_value: List[AnEnum], some_date: Union[datetime.date, datetime.datetime],
) -> Union[
List[AModel], HTTPValidationError,
]:
@@ -29,7 +29,7 @@ async def get_user_list(

json_an_enum_value.append(an_enum_value_item)

if isinstance(some_date, date):
if isinstance(some_date, datetime.date):
json_some_date = some_date.isoformat()

else:
@@ -97,3 +97,82 @@ async def json_body_tests_json_body_post(
return HTTPValidationError.from_dict(cast(Dict[str, Any], response.json()))
else:
raise ApiResponseError(response=response)


async def test_defaults_tests_test_defaults_post(
*,
client: Client,
json_body: Dict[Any, Any],
string_prop: Optional[str] = "the default string",
datetime_prop: Optional[datetime.datetime] = datetime.datetime(1010, 10, 10, 0, 0),
date_prop: Optional[datetime.date] = datetime.date(1010, 10, 10),
float_prop: Optional[float] = 3.14,
int_prop: Optional[int] = 7,
boolean_prop: Optional[bool] = False,
list_prop: Optional[List[AnEnum]] = field(
default_factory=lambda: cast(Optional[List[AnEnum]], [AnEnum.FIRST_VALUE, AnEnum.SECOND_VALUE])
),
union_prop: Optional[Union[Optional[float], Optional[str]]] = "not a float",
enum_prop: Optional[AnEnum] = None,
) -> Union[
None, HTTPValidationError,
]:

""" """
url = "{}/tests/test_defaults".format(client.base_url,)

headers: Dict[str, Any] = client.get_headers()

json_datetime_prop = datetime_prop.isoformat() if datetime_prop else None

json_date_prop = date_prop.isoformat() if date_prop else None

if list_prop is None:
json_list_prop = None
else:
json_list_prop = []
for list_prop_item_data in list_prop:
list_prop_item = list_prop_item_data.value

json_list_prop.append(list_prop_item)

if union_prop is None:
json_union_prop: Optional[Union[Optional[float], Optional[str]]] = None
elif isinstance(union_prop, float):
json_union_prop = union_prop
else:
json_union_prop = union_prop

json_enum_prop = enum_prop.value if enum_prop else None

params: Dict[str, Any] = {}
if string_prop is not None:
params["string_prop"] = string_prop
if datetime_prop is not None:
params["datetime_prop"] = json_datetime_prop
if date_prop is not None:
params["date_prop"] = json_date_prop
if float_prop is not None:
params["float_prop"] = float_prop
if int_prop is not None:
params["int_prop"] = int_prop
if boolean_prop is not None:
params["boolean_prop"] = boolean_prop
if list_prop is not None:
params["list_prop"] = json_list_prop
if union_prop is not None:
params["union_prop"] = json_union_prop
if enum_prop is not None:
params["enum_prop"] = json_enum_prop

json_json_body = json_body

async with httpx.AsyncClient() as _client:
response = await _client.post(url=url, headers=headers, json=json_json_body, params=params,)

if response.status_code == 200:
return None
if response.status_code == 422:
return HTTPValidationError.from_dict(cast(Dict[str, Any], response.json()))
else:
raise ApiResponseError(response=response)
@@ -1,7 +1,7 @@
from __future__ import annotations

import datetime
from dataclasses import dataclass, field
from datetime import date, datetime
from typing import Any, Dict, List, Optional, Union, cast

from .an_enum import AnEnum
@@ -13,17 +13,19 @@ class AModel:
""" A Model for testing all the ways custom objects can be used """

an_enum_value: AnEnum
a_camel_date_time: Union[datetime, date]
a_date: date
some_dict: Dict[Any, Any]
a_camel_date_time: Union[datetime.datetime, datetime.date]
a_date: datetime.date
nested_list_of_enums: Optional[List[List[DifferentEnum]]] = field(
default_factory=lambda: cast(Optional[List[List[DifferentEnum]]], [])
)
some_dict: Optional[Dict[Any, Any]] = field(default_factory=lambda: cast(Optional[Dict[Any, Any]], {}))

def to_dict(self) -> Dict[str, Any]:
an_enum_value = self.an_enum_value.value

if isinstance(self.a_camel_date_time, datetime):
some_dict = self.some_dict

if isinstance(self.a_camel_date_time, datetime.datetime):
a_camel_date_time = self.a_camel_date_time.isoformat()

else:
@@ -44,35 +46,35 @@ def to_dict(self) -> Dict[str, Any]:

nested_list_of_enums.append(nested_list_of_enums_item)

some_dict = self.some_dict

return {
"an_enum_value": an_enum_value,
"some_dict": some_dict,
"aCamelDateTime": a_camel_date_time,
"a_date": a_date,
"nested_list_of_enums": nested_list_of_enums,
"some_dict": some_dict,
}

@staticmethod
def from_dict(d: Dict[str, Any]) -> AModel:
an_enum_value = AnEnum(d["an_enum_value"])

def _parse_a_camel_date_time(data: Dict[str, Any]) -> Union[datetime, date]:
a_camel_date_time: Union[datetime, date]
some_dict = d["some_dict"]

def _parse_a_camel_date_time(data: Dict[str, Any]) -> Union[datetime.datetime, datetime.date]:
a_camel_date_time: Union[datetime.datetime, datetime.date]
try:
a_camel_date_time = datetime.fromisoformat(d["aCamelDateTime"])
a_camel_date_time = datetime.datetime.fromisoformat(d["aCamelDateTime"])

return a_camel_date_time
except:
pass
a_camel_date_time = date.fromisoformat(d["aCamelDateTime"])
a_camel_date_time = datetime.date.fromisoformat(d["aCamelDateTime"])

return a_camel_date_time

a_camel_date_time = _parse_a_camel_date_time(d["aCamelDateTime"])

a_date = date.fromisoformat(d["a_date"])
a_date = datetime.date.fromisoformat(d["a_date"])

nested_list_of_enums = []
for nested_list_of_enums_item_data in d.get("nested_list_of_enums") or []:
@@ -84,12 +86,10 @@ def _parse_a_camel_date_time(data: Dict[str, Any]) -> Union[datetime, date]:

nested_list_of_enums.append(nested_list_of_enums_item)

some_dict = d.get("some_dict")

return AModel(
an_enum_value=an_enum_value,
some_dict=some_dict,
a_camel_date_time=a_camel_date_time,
a_date=a_date,
nested_list_of_enums=nested_list_of_enums,
some_dict=some_dict,
)
@@ -96,7 +96,9 @@ def __init__(self, *, openapi: GeneratorData) -> None:

self.package_name: str = self.package_name_override or self.project_name.replace("-", "_")
self.package_dir: Path = self.project_dir / self.package_name
self.package_description: str = f"A client library for accessing {self.openapi.title}"
self.package_description: str = utils.remove_string_escapes(
f"A client library for accessing {self.openapi.title}"
)
self.version: str = openapi.version

self.env.filters.update(self.TEMPLATE_FILTERS)
@@ -37,3 +37,7 @@ class PropertyError(ParseError):
""" Error raised when there's a problem creating a Property """

header = "Problem creating a Property: "


class ValidationError(Exception):
pass
@@ -8,6 +8,7 @@
from pydantic import ValidationError

from .. import schema as oai
from .. import utils
from .errors import GeneratorError, ParseError, PropertyError
from .properties import EnumProperty, Property, property_from_data
from .reference import Reference
@@ -182,7 +183,7 @@ def from_data(*, data: oai.Operation, path: str, method: str, tag: str) -> Union
endpoint = Endpoint(
path=path,
method=method,
description=data.description,
description=utils.remove_string_escapes(data.description) if data.description else "",
name=data.operationId,
requires_security=bool(data.security),
tag=tag,
@@ -1,11 +1,12 @@
from __future__ import annotations

from dataclasses import InitVar, dataclass, field
from datetime import date, datetime
from typing import Any, ClassVar, Dict, Generic, List, Optional, Set, TypeVar, Union

from .. import schema as oai
from .. import utils
from .errors import PropertyError
from .errors import PropertyError, ValidationError
from .reference import Reference


@@ -19,6 +20,9 @@ class Property:
templates/property_templates and must contain two macros: construct and transform. Construct will be used to
build this property from JSON data (a response from an API). Transform will be used to convert this property
to JSON data (when sending a request to the API).
Raises:
ValidationError: Raised when the default value fails to be converted to the expected type
"""

name: str
@@ -32,10 +36,16 @@ class Property:

def __post_init__(self) -> None:
self.python_name = utils.snake_case(self.name)
if self.default is not None:
self.default = self._validate_default(default=self.default)

def _validate_default(self, default: Any) -> Any:
""" Check that the default value is valid for the property's type + perform any necessary sanitization """
raise ValidationError

def get_type_string(self) -> str:
def get_type_string(self, no_optional: bool = False) -> str:
""" Get a string representation of type that should be used when declaring this property """
if self.required:
if self.required or no_optional:
return self._type_string
return f"Optional[{self._type_string}]"

@@ -74,10 +84,8 @@ class StringProperty(Property):

_type_string: ClassVar[str] = "str"

def __post_init__(self) -> None:
super().__post_init__()
if self.default is not None:
self.default = f'"{self.default}"'
def _validate_default(self, default: Any) -> str:
return f'"{utils.remove_string_escapes(default)}"'


@dataclass
@@ -86,7 +94,7 @@ class DateTimeProperty(Property):
A property of type datetime.datetime
"""

_type_string: ClassVar[str] = "datetime"
_type_string: ClassVar[str] = "datetime.datetime"
template: ClassVar[str] = "datetime_property.pyi"

def get_imports(self, *, prefix: str) -> Set[str]:
@@ -97,15 +105,23 @@ def get_imports(self, *, prefix: str) -> Set[str]:
prefix: A prefix to put before any relative (local) module names.
"""
imports = super().get_imports(prefix=prefix)
imports.update({"from datetime import datetime", "from typing import cast"})
imports.update({"import datetime", "from typing import cast"})
return imports

def _validate_default(self, default: Any) -> str:
for format_string in ("%Y-%m-%dT%H:%M:%S", "%Y-%m-%dT%H:%M:%S%z"):
try:
return repr(datetime.strptime(default, format_string))
except (TypeError, ValueError):
continue
raise ValidationError


@dataclass
class DateProperty(Property):
""" A property of type datetime.date """

_type_string: ClassVar[str] = "date"
_type_string: ClassVar[str] = "datetime.date"
template: ClassVar[str] = "date_property.pyi"

def get_imports(self, *, prefix: str) -> Set[str]:
@@ -116,9 +132,15 @@ def get_imports(self, *, prefix: str) -> Set[str]:
prefix: A prefix to put before any relative (local) module names.
"""
imports = super().get_imports(prefix=prefix)
imports.update({"from datetime import date", "from typing import cast"})
imports.update({"import datetime", "from typing import cast"})
return imports

def _validate_default(self, default: Any) -> str:
try:
return repr(date.fromisoformat(default))
except (TypeError, ValueError) as e:
raise ValidationError() from e


@dataclass
class FileProperty(Property):
@@ -146,6 +168,12 @@ class FloatProperty(Property):
default: Optional[float] = None
_type_string: ClassVar[str] = "float"

def _validate_default(self, default: Any) -> float:
try:
return float(default)
except (TypeError, ValueError) as e:
raise ValidationError() from e


@dataclass
class IntProperty(Property):
@@ -154,13 +182,23 @@ class IntProperty(Property):
default: Optional[int] = None
_type_string: ClassVar[str] = "int"

def _validate_default(self, default: Any) -> int:
try:
return int(default)
except (TypeError, ValueError) as e:
raise ValidationError() from e


@dataclass
class BooleanProperty(Property):
""" Property for bool """

_type_string: ClassVar[str] = "bool"

def _validate_default(self, default: Any) -> bool:
# no try/except needed as anything that comes from the initial load from json/yaml will be boolable
return bool(default)


InnerProp = TypeVar("InnerProp", bound=Property)

@@ -172,14 +210,9 @@ class ListProperty(Property, Generic[InnerProp]):
inner_property: InnerProp
template: ClassVar[str] = "list_property.pyi"

def __post_init__(self) -> None:
super().__post_init__()
if self.default is not None:
self.default = f"field(default_factory=lambda: cast({self.get_type_string()}, {self.default}))"

def get_type_string(self) -> str:
def get_type_string(self, no_optional: bool = False) -> str:
""" Get a string representation of type that should be used when declaring this property """
if self.required:
if self.required or no_optional:
return f"List[{self.inner_property.get_type_string()}]"
return f"Optional[List[{self.inner_property.get_type_string()}]]"

@@ -198,6 +231,16 @@ def get_imports(self, *, prefix: str) -> Set[str]:
imports.add("from typing import cast")
return imports

def _validate_default(self, default: Any) -> str:
if not isinstance(default, list):
raise ValidationError()

default = list(map(self.inner_property._validate_default, default))
if isinstance(self.inner_property, RefProperty): # Fix enums to use the actual value
default = str(default).replace("'", "")

return f"field(default_factory=lambda: cast({self.get_type_string()}, {default}))"


@dataclass
class UnionProperty(Property):
@@ -206,11 +249,11 @@ class UnionProperty(Property):
inner_properties: List[Property]
template: ClassVar[str] = "union_property.pyi"

def get_type_string(self) -> str:
def get_type_string(self, no_optional: bool = False) -> str:
""" Get a string representation of type that should be used when declaring this property """
inner_types = [p.get_type_string() for p in self.inner_properties]
inner_prop_string = ", ".join(inner_types)
if self.required:
if self.required or no_optional:
return f"Union[{inner_prop_string}]"
return f"Optional[Union[{inner_prop_string}]]"

@@ -227,6 +270,15 @@ def get_imports(self, *, prefix: str) -> Set[str]:
imports.add("from typing import Union")
return imports

def _validate_default(self, default: Any) -> Any:
for property in self.inner_properties:
try:
val = property._validate_default(default)
return val
except ValidationError:
continue
raise ValidationError()


_existing_enums: Dict[str, EnumProperty] = {}

@@ -242,7 +294,6 @@ class EnumProperty(Property):
template: ClassVar[str] = "enum_property.pyi"

def __post_init__(self, title: str) -> None: # type: ignore
super().__post_init__()
reference = Reference.from_ref(title)
dedup_counter = 0
while reference.class_name in _existing_enums:
@@ -253,9 +304,7 @@ def __post_init__(self, title: str) -> None: # type: ignore
reference = Reference.from_ref(f"{reference.class_name}{dedup_counter}")

self.reference = reference
inverse_values = {v: k for k, v in self.values.items()}
if self.default is not None:
self.default = f"{self.reference.class_name}.{inverse_values[self.default]}"
super().__post_init__()
_existing_enums[self.reference.class_name] = self

@staticmethod
@@ -268,10 +317,10 @@ def get_enum(name: str) -> Optional[EnumProperty]:
""" Get all the EnumProperties that have been registered keyed by class name """
return _existing_enums.get(name)

def get_type_string(self) -> str:
def get_type_string(self, no_optional: bool = False) -> str:
""" Get a string representation of type that should be used when declaring this property """

if self.required:
if self.required or no_optional:
return self.reference.class_name
return f"Optional[{self.reference.class_name}]"

@@ -298,10 +347,18 @@ def values_from_list(values: List[str]) -> Dict[str, str]:
key = f"VALUE_{i}"
if key in output:
raise ValueError(f"Duplicate key {key} in Enum")
output[key] = value
sanitized_key = utils.fix_keywords(utils.sanitize(key))
output[sanitized_key] = utils.remove_string_escapes(value)

return output

def _validate_default(self, default: Any) -> str:
inverse_values = {v: k for k, v in self.values.items()}
try:
return f"{self.reference.class_name}.{inverse_values[default]}"
except KeyError as e:
raise ValidationError() from e


@dataclass
class RefProperty(Property):
@@ -316,9 +373,9 @@ def template(self) -> str: # type: ignore
return "enum_property.pyi"
return "ref_property.pyi"

def get_type_string(self) -> str:
def get_type_string(self, no_optional: bool = False) -> str:
""" Get a string representation of type that should be used when declaring this property """
if self.required:
if self.required or no_optional:
return self.reference.class_name
return f"Optional[{self.reference.class_name}]"

@@ -339,17 +396,20 @@ def get_imports(self, *, prefix: str) -> Set[str]:
)
return imports

def _validate_default(self, default: Any) -> Any:
enum = EnumProperty.get_enum(self.reference.class_name)
if enum:
return enum._validate_default(default)
else:
raise ValidationError


@dataclass
class DictProperty(Property):
""" Property that is a general Dict """

_type_string: ClassVar[str] = "Dict[Any, Any]"

def __post_init__(self) -> None:
super().__post_init__()
if self.default is not None:
self.default = f"field(default_factory=lambda: cast({self.get_type_string()}, {self.default}))"
template: ClassVar[str] = "dict_property.pyi"

def get_imports(self, *, prefix: str) -> Set[str]:
"""
@@ -365,6 +425,11 @@ def get_imports(self, *, prefix: str) -> Set[str]:
imports.add("from typing import cast")
return imports

def _validate_default(self, default: Any) -> str:
if isinstance(default, dict):
return repr(default)
raise ValidationError


def _string_based_property(
name: str, required: bool, data: oai.Schema
@@ -381,10 +446,11 @@ def _string_based_property(
return StringProperty(name=name, default=data.default, required=required, pattern=data.pattern)


def property_from_data(
def _property_from_data(
name: str, required: bool, data: Union[oai.Reference, oai.Schema]
) -> Union[Property, PropertyError]:
""" Generate a Property from the OpenAPI dictionary representation of it """
name = utils.remove_string_escapes(name)
if isinstance(data, oai.Reference):
return RefProperty(name=name, required=required, reference=Reference.from_ref(data.ref), default=None)
if data.enum:
@@ -423,3 +489,12 @@ def property_from_data(
elif data.type == "object":
return DictProperty(name=name, required=required, default=data.default)
return PropertyError(data=data, detail=f"unknown type {data.type}")


def property_from_data(
name: str, required: bool, data: Union[oai.Reference, oai.Schema]
) -> Union[Property, PropertyError]:
try:
return _property_from_data(name=name, required=required, data=data)
except ValidationError:
return PropertyError(detail="Failed to validate default value", data=data)
@@ -70,7 +70,7 @@ async def {{ endpoint.name | snakecase }}(
files=multipart_data.to_dict(),
{% endif %}
{% if endpoint.json_body %}
json={{ "json_" + endpoint.json_body.python_name }},
json={{ "json_" + endpoint.json_body.python_name }},
{% endif %}
{% if endpoint.query_parameters %}
params=params,
@@ -1,10 +1,10 @@
{% macro construct(property, source) %}
{% if property.required %}
{{ property.python_name }} = date.fromisoformat({{ source }})
{{ property.python_name }} = datetime.date.fromisoformat({{ source }})
{% else %}
{{ property.python_name }} = None
if {{ source }} is not None:
{{ property.python_name }} = date.fromisoformat(cast(str, {{ source }}))
{{ property.python_name }} = datetime.date.fromisoformat(cast(str, {{ source }}))
{% endif %}
{% endmacro %}

@@ -1,10 +1,10 @@
{% macro construct(property, source) %}
{% if property.required %}
{{ property.python_name }} = datetime.fromisoformat({{ source }})
{{ property.python_name }} = datetime.datetime.fromisoformat({{ source }})
{% else %}
{{ property.python_name }} = None
if {{ source }} is not None:
{{ property.python_name }} = datetime.fromisoformat(cast(str, {{ source }}))
{{ property.python_name }} = datetime.datetime.fromisoformat(cast(str, {{ source }}))
{% endif %}
{% endmacro %}

@@ -0,0 +1,17 @@
{% macro construct(property, source) %}
{% if property.required %}
{{ property.python_name }} = {{ source }}
{% else %}
{{ property.python_name }} = None
if {{ source }} is not None:
{{ property.python_name }} = {{ source }}
{% endif %}
{% endmacro %}

{% macro transform(property, source, destination) %}
{% if property.required %}
{{ destination }} = {{ source }}
{% else %}
{{ destination }} = {{ source }} if {{ source }} else None
{% endif %}
{% endmacro %}
@@ -24,13 +24,13 @@ def _parse_{{ property.python_name }}(data: Dict[str, Any]) -> {{ property.get_t
{% macro transform(property, source, destination) %}
{% if not property.required %}
if {{ source }} is None:
{{ destination }} = None
{{ destination }}: {{ property.get_type_string() }} = None
{% endif %}
{% for inner_property in property.inner_properties %}
{% if loop.first and property.required %}{# No if None statement before this #}
if isinstance({{ source }}, {{ inner_property.get_type_string() }}):
if isinstance({{ source }}, {{ inner_property.get_type_string(no_optional=True) }}):
{% elif not loop.last %}
elif isinstance({{ source }}, {{ inner_property.get_type_string() }}):
elif isinstance({{ source }}, {{ inner_property.get_type_string(no_optional=True) }}):
{% else %}
else:
{% endif %}
@@ -1,25 +1,36 @@
import re
from keyword import iskeyword

import stringcase


def _sanitize(value: str) -> str:
return re.sub(r"[^\w _-]+", "", value)
def sanitize(value: str) -> str:
return re.sub(r"[^\w _\-]+", "", value)


def fix_keywords(value: str) -> str:
if iskeyword(value):
return f"{value}_"
return value


def group_title(value: str) -> str:
value = re.sub(r"([A-Z]{2,})([A-Z][a-z]|[ -_]|$)", lambda m: m.group(1).title() + m.group(2), value.strip())
value = re.sub(r"([A-Z]{2,})([A-Z][a-z]|[ \-_]|$)", lambda m: m.group(1).title() + m.group(2), value.strip())
value = re.sub(r"(^|[ _-])([A-Z])", lambda m: m.group(1) + m.group(2).lower(), value)
return value


def snake_case(value: str) -> str:
return stringcase.snakecase(group_title(_sanitize(value)))
return fix_keywords(stringcase.snakecase(group_title(sanitize(value))))


def pascal_case(value: str) -> str:
return stringcase.pascalcase(_sanitize(value))
return fix_keywords(stringcase.pascalcase(sanitize(value)))


def kebab_case(value: str) -> str:
return stringcase.spinalcase(group_title(_sanitize(value)))
return fix_keywords(stringcase.spinalcase(group_title(sanitize(value))))


def remove_string_escapes(value: str) -> str:
return value.replace('"', r"\"")
@@ -2,7 +2,7 @@
from pydantic.error_wrappers import ErrorWrapper

import openapi_python_client.schema as oai
from openapi_python_client import GeneratorError
from openapi_python_client import GeneratorError, utils
from openapi_python_client.parser.errors import ParseError

MODULE_NAME = "openapi_python_client.parser.openapi"
@@ -546,6 +546,8 @@ def test_from_data(self, mocker):
responses=mocker.MagicMock(),
)

mocker.patch("openapi_python_client.utils.remove_string_escapes", return_value=data.description)

endpoint = Endpoint.from_data(data=data, path=path, method=method, tag="default")

assert endpoint == _add_body.return_value

Large diffs are not rendered by default.

@@ -22,3 +22,15 @@ def test_snake_case_from_camel():

def test_kebab_case():
assert utils.kebab_case("keep_alive") == "keep-alive"


def test__sanitize():
assert utils.sanitize("something*~with lots_- of weird things}=") == "somethingwith lots_- of weird things"


def test_no_string_escapes():
assert utils.remove_string_escapes('an "evil" string') == 'an \\"evil\\" string'


def test__fix_keywords():
assert utils.fix_keywords("None") == "None_"