diff --git a/rest_framework/test.py b/rest_framework/test.py index 04409f9621..5af1bb4cea 100644 --- a/rest_framework/test.py +++ b/rest_framework/test.py @@ -152,14 +152,19 @@ def _encode_data(self, data, format=None, content_type=None): Encode the data returning a two tuple of (bytes, content_type) """ - if data is None: - return ('', content_type) - assert format is None or content_type is None, ( 'You may not set both `format` and `content_type`.' ) if content_type: + try: + data = self._encode_json(data, content_type) + except AttributeError: + pass + + if data is None: + data = '' + # Content type specified explicitly, treat data as a raw bytestring ret = force_bytes(data, settings.DEFAULT_CHARSET) @@ -177,7 +182,6 @@ def _encode_data(self, data, format=None, content_type=None): # Use format and render the data into a bytestring renderer = self.renderer_classes[format]() - ret = renderer.render(data) # Determine the content-type header from the renderer content_type = renderer.media_type @@ -186,6 +190,11 @@ def _encode_data(self, data, format=None, content_type=None): content_type, renderer.charset ) + if data is None: + ret = '' + else: + ret = renderer.render(data) + # Coerce text to bytes if required. if isinstance(ret, str): ret = ret.encode(renderer.charset) diff --git a/tests/test_testing.py b/tests/test_testing.py index 196319a29e..068ea27e03 100644 --- a/tests/test_testing.py +++ b/tests/test_testing.py @@ -9,9 +9,9 @@ from django.test import TestCase, override_settings from django.urls import path -from rest_framework import fields, serializers +from rest_framework import fields, parsers, serializers from rest_framework.authtoken.models import Token -from rest_framework.decorators import api_view +from rest_framework.decorators import api_view, parser_classes from rest_framework.response import Response from rest_framework.test import ( APIClient, APIRequestFactory, URLPatternsTestCase, force_authenticate @@ -51,6 +51,12 @@ class BasicSerializer(serializers.Serializer): flag = fields.BooleanField(default=lambda: True) +@api_view(['POST']) +@parser_classes((parsers.JSONParser,)) +def post_json_view(request): + return Response(request.data) + + @api_view(['POST']) def post_view(request): serializer = BasicSerializer(data=request.data) @@ -63,7 +69,8 @@ def post_view(request): path('session-view/', session_view), path('redirect-view/', redirect_view), path('redirect-view//', redirect_307_308_view), - path('post-view/', post_view) + path('post-json-view/', post_json_view), + path('post-view/', post_view), ] @@ -237,6 +244,17 @@ def test_empty_post_uses_default_boolean_value(self): assert response.status_code == 200 assert response.data == {"flag": True} + def test_post_encodes_data_based_on_json_content_type(self): + data = {'data': True} + response = self.client.post( + '/post-json-view/', + data=data, + content_type='application/json' + ) + + assert response.status_code == 200 + assert response.data == data + class TestAPIRequestFactory(TestCase): def test_csrf_exempt_by_default(self):