diff --git a/drf_openapi/entities.py b/drf_openapi/entities.py index 70bba2d..825c6e8 100644 --- a/drf_openapi/entities.py +++ b/drf_openapi/entities.py @@ -15,9 +15,11 @@ from rest_framework.schemas import SchemaGenerator from rest_framework.schemas.generators import insert_into, distribute_links, LinkNode from rest_framework.schemas.inspectors import get_pk_description, field_to_schema - +import copy from drf_openapi.codec import _get_parameters +class PaginatedListSerializer: + pass class VersionedSerializers: """Adapted from https://github.com/avanov/Rhetoric/ :) @@ -84,8 +86,9 @@ def get(cls, request_version): class OpenApiSchemaGenerator(SchemaGenerator): - def __init__(self, version, title=None, url=None, description=None, patterns=None, urlconf=None): + def __init__(self, version, actions=None, title=None, url=None, description=None, patterns=None, urlconf=None): self.version = version + self.actions = actions super(OpenApiSchemaGenerator, self).__init__(title, url, description, patterns, urlconf) def get_schema(self, request=None, public=False): @@ -108,6 +111,7 @@ def get_schema(self, request=None, public=False): url=url, content=links ) + def get_links(self, request=None): """ Return a dictionary containing all the links that should be @@ -122,6 +126,16 @@ def get_links(self, request=None): view = self.create_view(callback, method, request) if getattr(view, 'exclude_from_schema', False): continue + + action = None + if hasattr(view, "action"): + action = view.action + # In case actions is defined we want to render only actions in the actions list + if self.actions is not None and action not in self.actions: + continue + + + path = self.coerce_path(path, method, view) paths.append(path) view_endpoints.append((path, method, view)) @@ -161,6 +175,11 @@ def get_link(self, path, method, view, version=None): fields += view.schema.get_pagination_fields(path, method) fields += view.schema.get_filter_fields(path, method) + if hasattr(view.schema, "get_custom_fields"): + fields += view.schema.get_custom_fields(path, method) + + + if fields and any([field.location in ('form', 'body') for field in fields]): encoding = view.schema.get_encoding(path, method) else: @@ -181,14 +200,12 @@ def get_link(self, path, method, view, version=None): description = description + '\n\n**Response Description:**\n' + res_doc response_serializer_class = response_serializer_class.get(version) - if not response_serializer_class and method_name in ('list', 'retrieve'): - if hasattr(view, 'get_serializer_class'): - response_serializer_class = view.get_serializer_class() - elif hasattr(view, 'serializer_class'): - response_serializer_class = view.serializer_class - if response_serializer_class and method_name == 'list': - response_serializer_class = self.get_paginator_serializer( - view, response_serializer_class) + if response_serializer_class and issubclass(response_serializer_class, PaginatedListSerializer): + response_serializer_class = self._default_response_class("list", view) + + + if response_serializer_class is None: + response_serializer_class = self._default_response_class(method_name, view) response_schema, error_status_codes = self.get_response_object( response_serializer_class, method_func.__doc__) if response_serializer_class else ({}, {}) @@ -202,6 +219,19 @@ def get_link(self, path, method, view, version=None): description=description ) + def _default_response_class(self, method_name, view): + response_serializer_class = None + if method_name in ('list', 'retrieve'): + if hasattr(view, 'get_serializer_class'): + response_serializer_class = view.get_serializer_class() + elif hasattr(view, 'serializer_class'): + response_serializer_class = view.serializer_class + if response_serializer_class and method_name == 'list': + response_serializer_class = self.get_paginator_serializer(view, response_serializer_class) + return response_serializer_class + + + def get_paginator_serializer(self, view, child_serializer_class): class BaseFakeListSerializer(serializers.Serializer): results = child_serializer_class(many=True) @@ -317,6 +347,11 @@ def get_serializer_fields(self, path, method, view, version=None, method_func=No else: location = 'query' + method_name = getattr(view, 'action', method.lower()) + # I don't see reason to return the serializers fields in other actions + if method_name not in ['update', 'create']: + return [] + serializer_class = self.get_serializer_class(view, method_func) if not serializer_class: return [] @@ -355,24 +390,39 @@ def get_serializer_fields(self, path, method, view, version=None, method_func=No return fields + def remove_write_only_fields(self, field): + if isinstance(field, serializers.ListSerializer): + fields = [key for key, value in field.child.fields.items() if value.write_only] + for field_name in fields: + field.child.fields.pop(field_name) + + for sub_field in field.child.fields: + self.remove_write_only_fields(sub_field) + + + def get_response_object(self, response_serializer_class, description): fields = [] serializer = response_serializer_class() nested_obj = {} - + # I copied the serializer so I will be able to alter and not to affect other behaviours + serializer = copy.deepcopy(serializer) for field in serializer.fields.values(): + self.remove_write_only_fields(field) + + # we don't want to render write only fields in the response + if field.write_only: + continue # If field is a serializer, attempt to get its schema. if isinstance(field, serializers.Serializer): subfield_schema = self.get_response_object(field.__class__, None)[0].get('schema') - # If the schema exists, use it as the nested_obj if subfield_schema is not None: nested_obj[field.field_name] = subfield_schema nested_obj[field.field_name]['description'] = field.help_text continue - - # Otherwise, carry-on and use the field's schema. + # Otherwise, carry-on and use the field's schema.get_filter_fields fallback_schema = self.fallback_schema_from_field(field) fields.append(Field( name=field.field_name, @@ -380,9 +430,7 @@ def get_response_object(self, response_serializer_class, description): required=field.required, schema=fallback_schema if fallback_schema else field_to_schema(field), )) - res = _get_parameters(Link(fields=fields), None) - if not res: if nested_obj: return { @@ -408,7 +456,6 @@ def get_response_object(self, response_serializer_class, description): for status_code, description in getattr(response_meta, 'error_status_codes', {}).items(): error_status_codes[status_code] = {'description': description} - return response_schema, error_status_codes