Skip to content

Commit

Permalink
Add support for manual relation fields
Browse files Browse the repository at this point in the history
  • Loading branch information
glowka committed Apr 8, 2020
1 parent e63804b commit ed86906
Show file tree
Hide file tree
Showing 7 changed files with 212 additions and 106 deletions.
5 changes: 3 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -89,13 +89,14 @@ SWAGGER_SETTINGS = {
'drf_yasg.inspectors.RecursiveFieldInspector',
'drf_yasg_json_api.inspectors.XPropertiesFilter', # Added
'drf_yasg_json_api.inspectors.InlineSerializerSmartInspector', # Replaces ReferencingSerializerInspector
'drf_yasg_json_api.inspectors.IDFieldInspector', # Added
'drf_yasg_json_api.inspectors.IntegerIDFieldInspector', # Added
'drf_yasg.inspectors.ChoiceFieldInspector',
'drf_yasg.inspectors.FileFieldInspector',
'drf_yasg.inspectors.DictFieldInspector',
'drf_yasg.inspectors.JSONFieldInspector',
'drf_yasg.inspectors.HiddenFieldInspector',
'drf_yasg_json_api.inspectors.ManyRelatedFieldInspector', # Added
'drf_yasg_json_api.inspectors.ManyRelatedFieldInspector', # Added
'drf_yasg_json_api.inspectors.IntegerPrimaryKeyRelatedFieldInspector', # Added
'drf_yasg.inspectors.RelatedFieldInspector',
'drf_yasg.inspectors.SerializerMethodFieldInspector',
'drf_yasg.inspectors.SimpleFieldInspector',
Expand Down
99 changes: 61 additions & 38 deletions drf_yasg_json_api/inspectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from django.utils.functional import cached_property
from drf_yasg import inspectors
from drf_yasg import openapi
from drf_yasg.inspectors.field import get_basic_type_info
from drf_yasg.inspectors.field import get_model_field
from drf_yasg.inspectors.field import get_parent_serializer
from drf_yasg.utils import filter_none
Expand All @@ -22,6 +21,8 @@
from rest_framework_json_api.utils import get_resource_type_from_serializer

from .utils import get_field_by_source
from .utils import get_field_model
from .utils import get_field_source
from .utils import get_related_model
from .utils import get_serializer_model_primary_key
from .utils import is_json_api
Expand Down Expand Up @@ -292,73 +293,95 @@ class InlineSerializerSmartInspector(InlineSerializerInspector):
strip_write_fields_from_response = True


class IDIntegerFieldInspector(inspectors.FieldInspector):
class IntegerFieldInspectorMixin:
int64_fields = (models.BigIntegerField, models.BigAutoField)

def get_format(self, model_field):
return openapi.FORMAT_INT64 if isinstance(model_field, self.int64_fields) else openapi.FORMAT_INT32


class IntegerIDFieldInspector(IntegerFieldInspectorMixin, inspectors.FieldInspector):
"""
Force string type on Integer ID, it can happen when:
- Primary Key of model is Integer
- Serializer's field named "id" is Integer
Force string type on ID field that on model level is integer. Since we get here just an integer, we look for:
- Primary Key of model that is models.IntegerField
- Serializer field named "id" that is serializers.IntegerField
"""

def field_to_swagger_object(self, field, swagger_object_type, **kwargs):
if is_many_related_field(field) or not is_json_api(self.view):
if not isinstance(field, serializers.IntegerField) or not is_json_api(self.view):
return inspectors.NotHandled

stringify_id = False
integer_format = None

parent_serializer = get_parent_serializer(field)
serializer_meta = getattr(parent_serializer, 'Meta', None)
model = getattr(serializer_meta, 'model', None)
if model is not None:
field_name = getattr(field, 'source', None) or field.field_name
model_field = get_model_field(model, field_name)
# Check for primary key only if sure it is pure Field (not FieldCacheMixin or anything)
if (
model_field is not None and
isinstance(model_field, models.Field) and # avoid OneToOneRel and other field like caches
model_field.primary_key and
isinstance(model_field, (models.IntegerField, models.AutoField))
isinstance(model_field, (models.IntegerField, models.AutoField)) and
model_field.primary_key
):
SwaggerType, ChildSwaggerType = self._get_partial_types(field, swagger_object_type, **kwargs)
return SwaggerType(
type=openapi.TYPE_STRING,
format=openapi.FORMAT_INT64
if isinstance(model_field, models.BigIntegerField) else openapi.FORMAT_INT32
)
stringify_id = True
integer_format = self.get_format(model_field)

elif field.field_name == 'id' and isinstance(field, serializers.IntegerField):
type_info = get_basic_type_info(field)
SwaggerType, ChildSwaggerType = self._get_partial_types(field, swagger_object_type, **kwargs)
type_info['type'] = openapi.TYPE_STRING
return SwaggerType(**type_info)
stringify_id = True

return inspectors.NotHandled
if not stringify_id:
return inspectors.NotHandled

SwaggerType, ChildSwaggerType = self._get_partial_types(field, swagger_object_type, **kwargs)
return SwaggerType(type=openapi.TYPE_STRING, format=integer_format)

class ManyRelatedFieldInspector(inspectors.SimpleFieldInspector):

class IntegerPrimaryKeyRelatedFieldInspector(IntegerFieldInspectorMixin, inspectors.FieldInspector):
"""
Minimum: unwrap ManyRelatedField child relation and pass further for probing.
Maximum: unwrap and force string type of Integer RelatedID
Force string type on PrimaryRelatedField that refers to model with integer primary key.
"""

def field_to_swagger_object(self, field, swagger_object_type, **kwargs):
if not is_many_related_field(field) or not is_json_api(self.view):
if not isinstance(field, serializers.PrimaryKeyRelatedField) or not is_json_api(self.view):
return inspectors.NotHandled

id_field = field.child_relation
if is_many_related_field(field):
return inspectors.NotHandled

parent_serializer = get_parent_serializer(field)
serializer_meta = getattr(parent_serializer, 'Meta', None)
model = getattr(serializer_meta, 'model', None)

related_model = None
# Try extracting by traversing model and model fields
if model is not None:
source = getattr(field, 'source', None) or field.field_name
field_model = get_related_model(model, source)
if field_model is not None:
model_field = get_model_field(field_model, 'pk')
if isinstance(model_field, (models.IntegerField, models.AutoField)):
SwaggerType, ChildSwaggerType = self._get_partial_types(id_field, swagger_object_type, **kwargs)
return SwaggerType(
type=openapi.TYPE_STRING,
format=openapi.FORMAT_INT64
if isinstance(model_field, models.BigIntegerField) else openapi.FORMAT_INT32
)

return self.probe_field_inspectors(id_field, swagger_object_type, **kwargs)
related_model = get_related_model(model, source=get_field_source(field))
# Try extracting by extracting directly from field
if not related_model:
related_model = get_field_model(field)

if related_model:
related_model_pk_field = get_model_field(related_model, 'pk')
if isinstance(related_model_pk_field, (models.IntegerField, models.AutoField)):
SwaggerType, ChildSwaggerType = self._get_partial_types(field, swagger_object_type, **kwargs)
return SwaggerType(type=openapi.TYPE_STRING, format=self.get_format(related_model_pk_field))

return inspectors.NotHandled


class ManyRelatedFieldInspector(inspectors.SimpleFieldInspector):
"""
Unwrap ManyRelatedField child relation as array node has already been added by SerializerInspector.
"""

def field_to_swagger_object(self, field, swagger_object_type, **kwargs):
if is_many_related_field(field) and is_json_api(self.view):
return self.probe_field_inspectors(field.child_relation, swagger_object_type, **kwargs)

return inspectors.NotHandled


class NamesFormatFilter(inspectors.FieldInspector):
Expand Down
22 changes: 22 additions & 0 deletions drf_yasg_json_api/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from typing import Optional

from django.db import models
from rest_framework import serializers


Expand Down Expand Up @@ -48,7 +51,26 @@ def get_field_by_source(fields: list, source):
return None


def get_field_model(field: serializers.Field) -> Optional[models.Model]:
field_model = getattr(field, 'model', None)
if field_model:
return field_model

try:
return field.queryset.model
except AttributeError:
return None


def is_many_related_field(field):
# Check for child relation attribute covers ManyRelationField as well as other possible cases like
# hacky SerializerMethodResourceRelatedField
return getattr(field, 'child_relation', None)


def get_field_source(field: serializers.Field):
source = field.source or field.field_name
# If no source and parent is not serializer it is child_field of other field
if not source and not isinstance(field.parent, serializers.BaseSerializer):
return field.parent.source or field.parent.field_name
return None
7 changes: 6 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,13 @@ def pytest_configure(config):
'django.contrib.auth.hashers.MD5PasswordHasher',
),
JSON_API_FORMAT_TYPES='dasherize',
JSON_API_FORMAT_KEYS='dasherize',
JSON_API_FORMAT_FIELD_NAMES='dasherize',
JSON_API_PLURALIZE_TYPES=True,

# For JSON API django-rest-framework-json-api <= 2.8
JSON_API_FORMAT_KEYS='dasherize',
JSON_API_RELATION_KEYS='dasherize',
JSON_API_PLURALIZE_RELATION_TYPE=True,
)

django.setup()
9 changes: 4 additions & 5 deletions tests/test_noop.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,7 @@ class MyModel(models.Model):
name = models.CharField(max_length=255)


class TestNoop:
@pytest.mark.django_db
def test_noop(self):
instance = MyModel(name='qwerty')
instance.save()
@pytest.mark.django_db
def test_django_models_setup():
instance = MyModel(name='qwerty')
instance.save()

0 comments on commit ed86906

Please sign in to comment.