Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 64 additions & 17 deletions drf_openapi/entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,11 @@
from rest_framework.schemas import SchemaGenerator
from rest_framework.schemas.generators import insert_into, distribute_links, LinkNode
from rest_framework.schemas.inspectors import get_pk_description, field_to_schema

import copy
from drf_openapi.codec import _get_parameters

class PaginatedListSerializer:
pass

class VersionedSerializers:
"""Adapted from https://github.com/avanov/Rhetoric/ :)
Expand Down Expand Up @@ -84,8 +86,9 @@ def get(cls, request_version):


class OpenApiSchemaGenerator(SchemaGenerator):
def __init__(self, version, title=None, url=None, description=None, patterns=None, urlconf=None):
def __init__(self, version, actions=None, title=None, url=None, description=None, patterns=None, urlconf=None):
self.version = version
self.actions = actions
super(OpenApiSchemaGenerator, self).__init__(title, url, description, patterns, urlconf)

def get_schema(self, request=None, public=False):
Expand All @@ -108,6 +111,7 @@ def get_schema(self, request=None, public=False):
url=url, content=links
)


def get_links(self, request=None):
"""
Return a dictionary containing all the links that should be
Expand All @@ -122,6 +126,16 @@ def get_links(self, request=None):
view = self.create_view(callback, method, request)
if getattr(view, 'exclude_from_schema', False):
continue

action = None
if hasattr(view, "action"):
action = view.action
# In case actions is defined we want to render only actions in the actions list
if self.actions is not None and action not in self.actions:
continue



path = self.coerce_path(path, method, view)
paths.append(path)
view_endpoints.append((path, method, view))
Expand Down Expand Up @@ -161,6 +175,11 @@ def get_link(self, path, method, view, version=None):
fields += view.schema.get_pagination_fields(path, method)
fields += view.schema.get_filter_fields(path, method)

if hasattr(view.schema, "get_custom_fields"):
fields += view.schema.get_custom_fields(path, method)



if fields and any([field.location in ('form', 'body') for field in fields]):
encoding = view.schema.get_encoding(path, method)
else:
Expand All @@ -181,14 +200,12 @@ def get_link(self, path, method, view, version=None):
description = description + '\n\n**Response Description:**\n' + res_doc
response_serializer_class = response_serializer_class.get(version)

if not response_serializer_class and method_name in ('list', 'retrieve'):
if hasattr(view, 'get_serializer_class'):
response_serializer_class = view.get_serializer_class()
elif hasattr(view, 'serializer_class'):
response_serializer_class = view.serializer_class
if response_serializer_class and method_name == 'list':
response_serializer_class = self.get_paginator_serializer(
view, response_serializer_class)
if response_serializer_class and issubclass(response_serializer_class, PaginatedListSerializer):
response_serializer_class = self._default_response_class("list", view)


if response_serializer_class is None:
response_serializer_class = self._default_response_class(method_name, view)
response_schema, error_status_codes = self.get_response_object(
response_serializer_class, method_func.__doc__) if response_serializer_class else ({}, {})

Expand All @@ -202,6 +219,19 @@ def get_link(self, path, method, view, version=None):
description=description
)

def _default_response_class(self, method_name, view):
response_serializer_class = None
if method_name in ('list', 'retrieve'):
if hasattr(view, 'get_serializer_class'):
response_serializer_class = view.get_serializer_class()
elif hasattr(view, 'serializer_class'):
response_serializer_class = view.serializer_class
if response_serializer_class and method_name == 'list':
response_serializer_class = self.get_paginator_serializer(view, response_serializer_class)
return response_serializer_class



def get_paginator_serializer(self, view, child_serializer_class):
class BaseFakeListSerializer(serializers.Serializer):
results = child_serializer_class(many=True)
Expand Down Expand Up @@ -317,6 +347,11 @@ def get_serializer_fields(self, path, method, view, version=None, method_func=No
else:
location = 'query'

method_name = getattr(view, 'action', method.lower())
# I don't see reason to return the serializers fields in other actions
if method_name not in ['update', 'create']:
return []

serializer_class = self.get_serializer_class(view, method_func)
if not serializer_class:
return []
Expand Down Expand Up @@ -355,34 +390,47 @@ def get_serializer_fields(self, path, method, view, version=None, method_func=No

return fields

def remove_write_only_fields(self, field):
if isinstance(field, serializers.ListSerializer):
fields = [key for key, value in field.child.fields.items() if value.write_only]
for field_name in fields:
field.child.fields.pop(field_name)

for sub_field in field.child.fields:
self.remove_write_only_fields(sub_field)



def get_response_object(self, response_serializer_class, description):

fields = []
serializer = response_serializer_class()
nested_obj = {}

# I copied the serializer so I will be able to alter and not to affect other behaviours
serializer = copy.deepcopy(serializer)
for field in serializer.fields.values():
self.remove_write_only_fields(field)

# we don't want to render write only fields in the response
if field.write_only:
continue
# If field is a serializer, attempt to get its schema.
if isinstance(field, serializers.Serializer):
subfield_schema = self.get_response_object(field.__class__, None)[0].get('schema')

# If the schema exists, use it as the nested_obj
if subfield_schema is not None:
nested_obj[field.field_name] = subfield_schema
nested_obj[field.field_name]['description'] = field.help_text
continue

# Otherwise, carry-on and use the field's schema.
# Otherwise, carry-on and use the field's schema.get_filter_fields
fallback_schema = self.fallback_schema_from_field(field)
fields.append(Field(
name=field.field_name,
location='form',
required=field.required,
schema=fallback_schema if fallback_schema else field_to_schema(field),
))

res = _get_parameters(Link(fields=fields), None)

if not res:
if nested_obj:
return {
Expand All @@ -408,7 +456,6 @@ def get_response_object(self, response_serializer_class, description):

for status_code, description in getattr(response_meta, 'error_status_codes', {}).items():
error_status_codes[status_code] = {'description': description}

return response_schema, error_status_codes


Expand Down