Skip to content

Commit

Permalink
Provide enum for openapi schema when using ChoiceField (#1168)
Browse files Browse the repository at this point in the history
  • Loading branch information
imomaliev committed Feb 4, 2020
1 parent e6d5f2d commit 400469f
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 5 deletions.
16 changes: 12 additions & 4 deletions django_filters/rest_framework/backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,14 +149,22 @@ def get_schema_operation_parameters(self, view):
)

filterset_class = self.get_filterset_class(view, queryset)
return [] if not filterset_class else [
({

if not filterset_class:
return []

parameters = []
for field_name, field in filterset_class.base_filters.items():
parameter = {
'name': field_name,
'required': field.extra['required'],
'in': 'query',
'description': field.label if field.label is not None else field_name,
'schema': {
'type': 'string',
},
}) for field_name, field in filterset_class.base_filters.items()
]
}
if field.extra and 'choices' in field.extra:
parameter['schema']['enum'] = [c[0] for c in field.extra['choices']]
parameters.append(parameter)
return parameters
4 changes: 4 additions & 0 deletions tests/rest_framework/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,7 @@ class DjangoFilterOrderingModel(models.Model):

class Meta:
ordering = ['-date']


class CategoryItem(BaseFilterableItem):
category = models.CharField(max_length=10, choices=(("home", "Home"), ("office", "Office")))
34 changes: 33 additions & 1 deletion tests/rest_framework/test_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
)

from ..models import Article
from .models import FilterableItem
from .models import CategoryItem, FilterableItem

factory = APIRequestFactory()

Expand All @@ -26,6 +26,12 @@ class Meta:
fields = '__all__'


class CategoryItemSerializer(serializers.ModelSerializer):
class Meta:
model = CategoryItem
fields = '__all__'


# These class are used to test a filter class.
class SeveralFieldsFilter(FilterSet):
text = filters.CharFilter(lookup_expr='icontains')
Expand All @@ -52,6 +58,13 @@ class FilterClassRootView(FilterableItemView):
filterset_class = SeveralFieldsFilter


class CategoryItemView(generics.ListCreateAPIView):
queryset = CategoryItem.objects.all()
serializer_class = CategoryItemSerializer
filter_backends = (DjangoFilterBackend,)
filterset_fields = ["category"]


class GetFilterClassTests(TestCase):

def test_filterset_class(self):
Expand Down Expand Up @@ -237,6 +250,25 @@ def test_get_operation_parameters_with_filterset_fields_list(self):

self.assertEqual(fields, ['decimal', 'date'])

def test_get_operation_parameters_with_filterset_fields_list_with_choices(self):
backend = DjangoFilterBackend()
fields = backend.get_schema_operation_parameters(CategoryItemView())

self.assertEqual(
fields,
[{
'name': 'category',
'required': False,
'in': 'query',
'description': 'category',
'schema': {
'type': 'string',
'enum': ['home', 'office']
},

}]
)


class TemplateTests(TestCase):
def test_backend_output(self):
Expand Down

0 comments on commit 400469f

Please sign in to comment.