Skip to content

Commit

Permalink
Fix SerializerMethodResourceRelatedField bug
Browse files Browse the repository at this point in the history
  • Loading branch information
glowka committed Apr 8, 2020
1 parent 99c6f42 commit e63804b
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 8 deletions.
24 changes: 19 additions & 5 deletions drf_yasg_json_api/inspectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from .utils import is_json_api
from .utils import is_json_api_request
from .utils import is_json_api_response
from .utils import is_many_related_field

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -161,6 +162,8 @@ def extract_relationships(self, fields, ChildSwaggerType, use_references, is_req
relationships = OrderedDict()
required_relationships = []
for field_name, field in fields.items():
self.maybe_fix_broken_parent_relation(field)

if self.should_strip_from_schema(field, is_request):
continue
# Self url field
Expand All @@ -183,7 +186,7 @@ def extract_relationships(self, fields, ChildSwaggerType, use_references, is_req
required=['id', 'type'] if (self.is_request_or_unknown(is_request)) and not field.read_only else None,
)))

if isinstance(field, serializers.ManyRelatedField):
if is_many_related_field(field):
relation_data_schema = openapi.Schema(type=openapi.TYPE_ARRAY, items=relation_data_schema)

relation_links_schema = self.get_links_from_id_field(field_name, field)
Expand Down Expand Up @@ -228,7 +231,7 @@ def get_resource_name_from_related_id_field(self, field_name, id_field):
this_model = getattr(serializer_meta, 'model', None)

source = getattr(id_field, 'source', '') or id_field.field_name
if not source and isinstance(id_field.parent, serializers.ManyRelatedField):
if not source and is_many_related_field(id_field.parent):
source = getattr(id_field.parent, 'source', '') or id_field.parent.field_name

model = get_related_model(this_model, source)
Expand Down Expand Up @@ -274,6 +277,15 @@ def is_request_or_unknown(self, is_request):
# evaluate None to True as well
return is_request is None or is_request

def maybe_fix_broken_parent_relation(self, candidate_field):
"""
SerializerMethodResourceRelatedField is a bit hacky and breaks field.parent.parent...serializer chain when used
with many=True. To avoid multiple exception in random inspectors, we just check every relation if fix is needed.
"""
child_field = getattr(candidate_field, 'child_relation', None)
if child_field and not getattr(child_field, 'parent', None):
setattr(child_field, 'parent', candidate_field)


class InlineSerializerSmartInspector(InlineSerializerInspector):
strip_read_fields_from_request = True
Expand All @@ -288,7 +300,7 @@ class IDIntegerFieldInspector(inspectors.FieldInspector):
"""

def field_to_swagger_object(self, field, swagger_object_type, **kwargs):
if isinstance(field, serializers.ManyRelatedField) or not is_json_api(self.view):
if is_many_related_field(field) or not is_json_api(self.view):
return inspectors.NotHandled

parent_serializer = get_parent_serializer(field)
Expand All @@ -299,6 +311,7 @@ def field_to_swagger_object(self, field, swagger_object_type, **kwargs):
model_field = get_model_field(model, field_name)
if (
model_field is not None and
isinstance(model_field, models.Field) and # avoid OneToOneRel and other field like caches
model_field.primary_key and
isinstance(model_field, (models.IntegerField, models.AutoField))
):
Expand All @@ -311,7 +324,8 @@ def field_to_swagger_object(self, field, swagger_object_type, **kwargs):
elif field.field_name == 'id' and isinstance(field, serializers.IntegerField):
type_info = get_basic_type_info(field)
SwaggerType, ChildSwaggerType = self._get_partial_types(field, swagger_object_type, **kwargs)
return SwaggerType(**type_info, type=openapi.TYPE_STRING)
type_info['type'] = openapi.TYPE_STRING
return SwaggerType(**type_info)

return inspectors.NotHandled

Expand All @@ -323,7 +337,7 @@ class ManyRelatedFieldInspector(inspectors.SimpleFieldInspector):
"""

def field_to_swagger_object(self, field, swagger_object_type, **kwargs):
if not isinstance(field, serializers.ManyRelatedField) or not is_json_api(self.view):
if not is_many_related_field(field) or not is_json_api(self.view):
return inspectors.NotHandled

id_field = field.child_relation
Expand Down
6 changes: 6 additions & 0 deletions drf_yasg_json_api/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,3 +46,9 @@ def get_field_by_source(fields: list, source):
if field.source == source:
return field
return None


def is_many_related_field(field):
# Check for child relation attribute covers ManyRelationField as well as other possible cases like
# hacky SerializerMethodResourceRelatedField
return getattr(field, 'child_relation', None)
13 changes: 10 additions & 3 deletions tests/test_serializer_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,11 +96,16 @@ class ProjectViewSet(mixins.RetrieveModelMixin, viewsets.GenericViewSet):

def test_get__serializer_method_resource():
class ProjectSerializer(serializers.ModelSerializer):
members = relations.SerializerMethodResourceRelatedField(model=Member, source='get_members', read_only=True)
member = relations.SerializerMethodResourceRelatedField(model=Member, source='get_member', read_only=True)
members = relations.SerializerMethodResourceRelatedField(model=Member, many=True,
source='get_members', read_only=True)

class Meta:
model = Project
fields = ['name', 'archived', 'members']
fields = ['name', 'archived', 'member', 'members']

def get_member(self):
pass

def get_members(self):
pass
Expand All @@ -126,7 +131,9 @@ class ProjectViewSet(mixins.RetrieveModelMixin, viewsets.GenericViewSet):
assert 'attributes' in response_schema['data']['properties']
assert list(response_schema['data']['properties']['attributes']['properties'].keys()) == ['name', 'archived']
assert 'relationships' in response_schema['data']['properties']
assert list(response_schema['data']['properties']['relationships']['properties'].keys()) == ['members']
assert list(response_schema['data']['properties']['relationships']['properties'].keys()) == ['member', 'members']
relationships_schema = response_schema['data']['properties']['relationships']['properties']
assert 'items' in relationships_schema['members']['properties']['data']


def test_get__included():
Expand Down

0 comments on commit e63804b

Please sign in to comment.