Skip to content

Commit

Permalink
Merge pull request #94 from strollby/feat/pydantic-v2
Browse files Browse the repository at this point in the history
Pydantic V2 Support
  • Loading branch information
dantheman39 committed Jan 31, 2024
2 parents 54a159b + 58135e7 commit 5d66766
Show file tree
Hide file tree
Showing 13 changed files with 379 additions and 167 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: ["3.7", "3.8", "3.9", "3.10", "3.11"]
python-version: ["3.8", "3.9", "3.10", "3.11"]
os: [ubuntu-latest, macos-latest, windows-latest]
steps:
- uses: actions/checkout@v3
Expand Down
143 changes: 86 additions & 57 deletions graphene_pydantic/converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,54 +2,49 @@
import collections.abc
import datetime
import decimal
import inspect
import enum
import inspect
import sys
import typing as T
import uuid
from typing import Type, get_origin

import graphene
from graphene import (
UUID,
Boolean,
Enum,
Field,
Float,
ID,
InputField,
Int,
JSONString,
List,
String,
UUID,
Union,
)
import graphene
from graphene.types.base import BaseType
from graphene.types.datetime import Date, DateTime, Time
from pydantic import BaseModel
from pydantic.fields import ModelField
from pydantic.typing import evaluate_forwardref
from pydantic.fields import FieldInfo
from pydantic_core import PydanticUndefined

from .registry import Registry
from .util import construct_union_class_name
from .registry import Placeholder, Registry
from .util import construct_union_class_name, evaluate_forward_ref

from pydantic import fields
PYTHON10 = sys.version_info >= (3, 10)
if PYTHON10:
from types import UnionType

GRAPHENE2 = graphene.VERSION[0] < 3

SHAPE_SINGLETON = (fields.SHAPE_SINGLETON,)
SHAPE_SEQUENTIAL = (
fields.SHAPE_LIST,
fields.SHAPE_TUPLE,
fields.SHAPE_TUPLE_ELLIPSIS,
fields.SHAPE_SEQUENCE,
fields.SHAPE_SET,
)

if hasattr(fields, "SHAPE_DICT"):
SHAPE_MAPPING = T.cast(
T.Tuple, (fields.SHAPE_MAPPING, fields.SHAPE_DICT, fields.SHAPE_DEFAULTDICT)
)
else:
SHAPE_MAPPING = T.cast(T.Tuple, (fields.SHAPE_MAPPING,))
try:
from bson import ObjectId

BSON_OBJECT_ID_SUPPORTED = True
except ImportError:
BSON_OBJECT_ID_SUPPORTED = False

try:
from graphene.types.decimal import Decimal as GrapheneDecimal
Expand All @@ -59,7 +54,6 @@
# graphene 2.1.5+ is required for Decimals
DECIMAL_SUPPORTED = False


NONE_TYPE = None.__class__ # need to do this because mypy complains about type(None)


Expand All @@ -80,7 +74,7 @@ def _get_field(root, _info):


def convert_pydantic_input_field(
field: ModelField,
field: FieldInfo,
registry: Registry,
parent_type: T.Type = None,
model: T.Type[BaseModel] = None,
Expand All @@ -90,26 +84,29 @@ def convert_pydantic_input_field(
Convert a Pydantic model field into a Graphene type field that we can add
to the generated Graphene data model type.
"""
declared_type = getattr(field, "type_", None)
declared_type = getattr(field, "annotation", None)
field_kwargs.setdefault(
"type" if GRAPHENE2 else "type_",
convert_pydantic_type(
declared_type, field, registry, parent_type=parent_type, model=model
),
)
field_kwargs.setdefault("required", field.required)
field_kwargs.setdefault("default_value", field.default)
field_kwargs.setdefault("required", field.is_required())
field_kwargs.setdefault(
"default_value", None if field.default is PydanticUndefined else field.default
)
# TODO: find a better way to get a field's description. Some ideas include:
# - hunt down the description from the field's schema, or the schema
# from the field's base model
# - maybe even (Sphinx-style) parse attribute documentation
field_kwargs.setdefault("description", field.field_info.description)
field_kwargs.setdefault("description", field.description)

return InputField(**field_kwargs)


def convert_pydantic_field(
field: ModelField,
name: str,
field: FieldInfo,
registry: Registry,
parent_type: T.Type = None,
model: T.Type[BaseModel] = None,
Expand All @@ -119,44 +116,67 @@ def convert_pydantic_field(
Convert a Pydantic model field into a Graphene type field that we can add
to the generated Graphene data model type.
"""
declared_type = getattr(field, "type_", None)
declared_type = getattr(field, "annotation", None)

# Convert Python 10 UnionType to T.Union
if PYTHON10:
is_union_type = (
get_origin(declared_type) is T.Union
or get_origin(declared_type) is UnionType
)
else:
is_union_type = get_origin(declared_type) is T.Union

if is_union_type:
declared_type = T.Union[declared_type.__args__]

field_kwargs.setdefault(
"type" if GRAPHENE2 else "type_",
convert_pydantic_type(
declared_type, field, registry, parent_type=parent_type, model=model
),
)
field_kwargs.setdefault("required", not field.allow_none)
field_kwargs.setdefault("default_value", field.default)
if field.has_alias:
field_kwargs.setdefault(
"required",
field.is_required()
or (
type(field.default) is not PydanticUndefined
and getattr(declared_type, "_name", "") != "Optional"
and not is_union_type
),
)
field_kwargs.setdefault(
"default_value", None if field.default is PydanticUndefined else field.default
)
if field.alias:
field_kwargs.setdefault("name", field.alias)
# TODO: find a better way to get a field's description. Some ideas include:
# - hunt down the description from the field's schema, or the schema
# from the field's base model
# - maybe even (Sphinx-style) parse attribute documentation
field_kwargs.setdefault("description", field.field_info.description)
field_kwargs.setdefault("description", field.description)

# Handle Graphene 2 and 3
field_type = field_kwargs.pop("type", field_kwargs.pop("type_", None))
if field_type is None:
raise ValueError("No field type could be determined.")

resolver_function = getattr(parent_type, "resolve_" + field.name, None)
resolver_function = getattr(parent_type, "resolve_" + name, None)
if resolver_function and callable(resolver_function):
field_resolver = resolver_function
else:
field_resolver = get_attr_resolver(field.name)
field_resolver = get_attr_resolver(name)

return Field(field_type, resolver=field_resolver, **field_kwargs)


def convert_pydantic_type(
type_: T.Type,
field: ModelField,
field: FieldInfo,
registry: Registry,
parent_type: T.Type = None,
model: T.Type[BaseModel] = None,
) -> BaseType: # noqa: C901
) -> T.Union[Type[T.Union[BaseType, List]], Placeholder]: # noqa: C901
"""
Convert a Pydantic type to a Graphene Field type, including not just the
native Python type but any additional metadata (e.g. shape) that Pydantic
Expand All @@ -165,26 +185,30 @@ def convert_pydantic_type(
graphene_type = find_graphene_type(
type_, field, registry, parent_type=parent_type, model=model
)
if field.shape in SHAPE_SINGLETON:
return graphene_type
elif field.shape in SHAPE_SEQUENTIAL:
# TODO: _should_ Sets remain here?
return List(graphene_type)
elif field.shape in SHAPE_MAPPING:
field_type = getattr(field.annotation, "__origin__", None)
if field_type == map: # SHAPE_MAPPING
raise ConversionError("Don't know how to handle mappings in Graphene.")

return graphene_type


def find_graphene_type(
type_: T.Type,
field: ModelField,
field: FieldInfo,
registry: Registry,
parent_type: T.Type = None,
model: T.Type[BaseModel] = None,
) -> BaseType: # noqa: C901
) -> T.Union[Type[T.Union[BaseType, List]], Placeholder]: # noqa: C901
"""
Map a native Python type to a Graphene-supported Field type, where possible,
throwing an error if we don't know what to map it to.
"""

# Convert Python 10 UnionType to T.Union
if PYTHON10:
if isinstance(type_, UnionType):
type_ = T.Union[type_.__args__]

if type_ == uuid.UUID:
return UUID
elif type_ in (str, bytes):
Expand All @@ -199,6 +223,10 @@ def find_graphene_type(
return Boolean
elif type_ == float:
return Float
elif BSON_OBJECT_ID_SUPPORTED and type_ == ObjectId:
return ID
elif type_ == dict:
return JSONString
elif type_ == decimal.Decimal:
return GrapheneDecimal if DECIMAL_SUPPORTED else Float
elif type_ == int:
Expand Down Expand Up @@ -231,12 +259,13 @@ def find_graphene_type(
if not sibling:
raise ConversionError(
"Don't know how to convert the Pydantic field "
f"{field!r} ({field.type_}), could not resolve "
f"{field!r} ({field.annotation}), could not resolve "
"the forward reference. Did you call `resolve_placeholders()`? "
"See the README for more on forward references."
)

module_ns = sys.modules[sibling.__module__].__dict__
resolved = evaluate_forwardref(type_, module_ns, None)
resolved = evaluate_forward_ref(type_, module_ns, None)
# TODO: make this behavior optional. maybe this is a place for the TypeOptions to play a role?
if registry:
registry.add_placeholder_for_model(resolved)
Expand Down Expand Up @@ -265,20 +294,20 @@ def find_graphene_type(
return List
else:
raise ConversionError(
f"Don't know how to convert the Pydantic field {field!r} ({field.type_})"
f"Don't know how to convert the Pydantic field {field!r} ({field.annotation})"
)


def convert_generic_python_type(
type_: T.Type,
field: ModelField,
field: FieldInfo,
registry: Registry,
parent_type: T.Type = None,
model: T.Type[BaseModel] = None,
) -> BaseType: # noqa: C901
) -> T.Union[Type[T.Union[BaseType, List]], Placeholder]: # noqa: C901
"""
Convert annotated Python generic types into the most appropriate Graphene
Field type -- e.g. turn `typing.Union` into a Graphene Union.
Field type -- e.g., turn `typing.Union` into a Graphene Union.
"""
origin = type_.__origin__
if not origin: # pragma: no cover # this really should be impossible
Expand Down Expand Up @@ -321,14 +350,14 @@ def convert_generic_python_type(
elif origin in (T.Dict, T.Mapping, collections.OrderedDict, dict) or issubclass(
origin, collections.abc.Mapping
):
raise ConversionError("Don't know how to handle mappings in Graphene")
raise ConversionError("Don't know how to handle mappings in Graphene.")
else:
raise ConversionError(f"Don't know how to handle {type_} (generic: {origin})")


def convert_union_type(
type_: T.Type,
field: ModelField,
field: FieldInfo,
registry: Registry,
parent_type: T.Type = None,
model: T.Type[BaseModel] = None,
Expand Down Expand Up @@ -361,11 +390,11 @@ def convert_union_type(

def convert_literal_type(
type_: T.Type,
field: ModelField,
field: FieldInfo,
registry: Registry,
parent_type: T.Type = None,
model: T.Type[BaseModel] = None,
):
) -> T.Union[Type[T.Union[BaseType, List]], Placeholder]:
"""
Convert an annotated Python Literal type into a Graphene Scalar or Union of Scalars.
"""
Expand Down
22 changes: 16 additions & 6 deletions graphene_pydantic/inputobjecttype.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,24 @@ def construct_fields(
if exclude_fields:
excluded = exclude_fields
elif only_fields:
excluded = tuple(k for k in model.__fields__ if k not in only_fields)
excluded = tuple(k for k in model.model_fields if k not in only_fields)

fields_to_convert = (
(k, v) for k, v in model.__fields__.items() if k not in excluded
(k, v) for k, v in model.model_fields.items() if k not in excluded
)

fields = {}
for name, field in fields_to_convert:
# Graphql does not accept union as input. Refer https://github.com/graphql/graphql-spec/issues/488
annotation = getattr(field, "annotation", None)
if isinstance(annotation, str) or isinstance(annotation, int):
union_types = field.annotation.__args__
if type(None) not in union_types or len(union_types) > 2:
continue
# But str|None or Union[str, None] is valid input equivalent to Optional[str]
base_type = list(filter(lambda x: x is not type(None), union_types)).pop()
field.annotation = T.Optional[base_type]

converted = convert_pydantic_input_field(
field, registry, parent_type=obj_type, model=model
)
Expand Down Expand Up @@ -127,11 +137,11 @@ def resolve_placeholders(cls):
meta = cls._meta
fields_to_update = {}
for name, field in meta.fields.items():
target_type = field._type
if hasattr(target_type, "_of_type"):
target_type = target_type._of_type
target_type = field.type
while hasattr(target_type, "of_type"):
target_type = target_type.of_type
if isinstance(target_type, Placeholder):
pydantic_field = meta.model.__fields__[name]
pydantic_field = meta.model.model_fields[name]
graphene_field = convert_pydantic_input_field(
pydantic_field,
meta.registry,
Expand Down

0 comments on commit 5d66766

Please sign in to comment.