From 15c613a9eb645c63102b9e894199bcf1c9bf4d65 Mon Sep 17 00:00:00 2001 From: Jameel Al-Aziz <247849+jalaziz@users.noreply.github.com> Date: Wed, 22 Feb 2023 07:39:01 -0800 Subject: [PATCH] Allow generic requests, responses, fields, views (#8825) Allow Request, Response, Field, and GenericAPIView to be subscriptable. This allows the classes to be made generic for type checking. This is especially useful since monkey patching DRF can be problematic as seen in this [issue][1]. [1]: https://github.com/typeddjango/djangorestframework-stubs/issues/299 --- rest_framework/fields.py | 4 ++++ rest_framework/generics.py | 4 ++++ rest_framework/request.py | 4 ++++ rest_framework/response.py | 4 ++++ tests/test_fields.py | 10 ++++++++++ tests/test_generics.py | 25 +++++++++++++++++++++++++ tests/test_request.py | 10 ++++++++++ tests/test_response.py | 12 ++++++++++++ 8 files changed, 73 insertions(+) diff --git a/rest_framework/fields.py b/rest_framework/fields.py index 1c64255962..613bd325a6 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -356,6 +356,10 @@ def __init__(self, *, read_only=False, write_only=False, messages.update(error_messages or {}) self.error_messages = messages + # Allow generic typing checking for fields. + def __class_getitem__(cls, *args, **kwargs): + return cls + def bind(self, field_name, parent): """ Initializes the field name and parent for the field instance. diff --git a/rest_framework/generics.py b/rest_framework/generics.py index 55cfafda44..1673033214 100644 --- a/rest_framework/generics.py +++ b/rest_framework/generics.py @@ -45,6 +45,10 @@ class GenericAPIView(views.APIView): # The style to use for queryset pagination. pagination_class = api_settings.DEFAULT_PAGINATION_CLASS + # Allow generic typing checking for generic views. + def __class_getitem__(cls, *args, **kwargs): + return cls + def get_queryset(self): """ Get the list of items for this view. diff --git a/rest_framework/request.py b/rest_framework/request.py index 194be5f6d4..93109226d9 100644 --- a/rest_framework/request.py +++ b/rest_framework/request.py @@ -186,6 +186,10 @@ def __repr__(self): self.method, self.get_full_path()) + # Allow generic typing checking for requests. + def __class_getitem__(cls, *args, **kwargs): + return cls + def _default_negotiator(self): return api_settings.DEFAULT_CONTENT_NEGOTIATION_CLASS() diff --git a/rest_framework/response.py b/rest_framework/response.py index 4954237347..6e756544c6 100644 --- a/rest_framework/response.py +++ b/rest_framework/response.py @@ -46,6 +46,10 @@ def __init__(self, data=None, status=None, for name, value in headers.items(): self[name] = value + # Allow generic typing checking for responses. + def __class_getitem__(cls, *args, **kwargs): + return cls + @property def rendered_content(self): renderer = getattr(self, 'accepted_renderer', None) diff --git a/tests/test_fields.py b/tests/test_fields.py index 512f3f7895..56e2a45bad 100644 --- a/tests/test_fields.py +++ b/tests/test_fields.py @@ -2,6 +2,7 @@ import math import os import re +import sys import uuid from decimal import ROUND_DOWN, ROUND_UP, Decimal @@ -625,6 +626,15 @@ def test_parent_binding(self): assert field.root is parent +class TestTyping(TestCase): + @pytest.mark.skipif( + sys.version_info < (3, 7), + reason="subscriptable classes requires Python 3.7 or higher", + ) + def test_field_is_subscriptable(self): + assert serializers.Field is serializers.Field["foo"] + + # Tests for field input and output values. # ---------------------------------------- diff --git a/tests/test_generics.py b/tests/test_generics.py index 78dc5afb64..9990389c94 100644 --- a/tests/test_generics.py +++ b/tests/test_generics.py @@ -1,3 +1,5 @@ +import sys + import pytest from django.db import models from django.http import Http404 @@ -698,3 +700,26 @@ def list(self, request): serializer = response.serializer assert serializer.context is context + + +class TestTyping(TestCase): + @pytest.mark.skipif( + sys.version_info < (3, 7), + reason="subscriptable classes requires Python 3.7 or higher", + ) + def test_genericview_is_subscriptable(self): + assert generics.GenericAPIView is generics.GenericAPIView["foo"] + + @pytest.mark.skipif( + sys.version_info < (3, 7), + reason="subscriptable classes requires Python 3.7 or higher", + ) + def test_listview_is_subscriptable(self): + assert generics.ListAPIView is generics.ListAPIView["foo"] + + @pytest.mark.skipif( + sys.version_info < (3, 7), + reason="subscriptable classes requires Python 3.7 or higher", + ) + def test_instanceview_is_subscriptable(self): + assert generics.RetrieveAPIView is generics.RetrieveAPIView["foo"] diff --git a/tests/test_request.py b/tests/test_request.py index 8c18aea9e6..e37aa7dda1 100644 --- a/tests/test_request.py +++ b/tests/test_request.py @@ -3,6 +3,7 @@ """ import copy import os.path +import sys import tempfile import pytest @@ -352,3 +353,12 @@ class TestDeepcopy(TestCase): def test_deepcopy_works(self): request = Request(factory.get('/', secure=False)) copy.deepcopy(request) + + +class TestTyping(TestCase): + @pytest.mark.skipif( + sys.version_info < (3, 7), + reason="subscriptable classes requires Python 3.7 or higher", + ) + def test_request_is_subscriptable(self): + assert Request is Request["foo"] diff --git a/tests/test_response.py b/tests/test_response.py index 0d5528dc9a..cab19a1eb8 100644 --- a/tests/test_response.py +++ b/tests/test_response.py @@ -1,3 +1,6 @@ +import sys + +import pytest from django.test import TestCase, override_settings from django.urls import include, path, re_path @@ -283,3 +286,12 @@ def test_form_has_label_and_help_text(self): self.assertEqual(resp['Content-Type'], 'text/html; charset=utf-8') # self.assertContains(resp, 'Text comes here') # self.assertContains(resp, 'Text description.') + + +class TestTyping(TestCase): + @pytest.mark.skipif( + sys.version_info < (3, 7), + reason="subscriptable classes requires Python 3.7 or higher", + ) + def test_response_is_subscriptable(self): + assert Response is Response["foo"]