diff --git a/docs/source/changelog.rst b/docs/source/changelog.rst index 1eecc81..89e184a 100644 --- a/docs/source/changelog.rst +++ b/docs/source/changelog.rst @@ -1,6 +1,20 @@ Changelog ============ +`v0.3.0 `__ +-------------------------------------------------------------------------------- + +**Features:** + +- `AUTHTOKEN_SELECT_RELATED_LIST `_ setting to enable performance optimization. (Issue 16_) + +**Other:** + +- Test cases now cover the :doc:`permissions`. + +.. _16: https://github.com/Eshaan7/django-rest-durin/issues/16 + + `v0.2.0 `__ -------------------------------------------------------------------------------- diff --git a/docs/source/settings.rst b/docs/source/settings.rst index a07810b..346c52b 100644 --- a/docs/source/settings.rst +++ b/docs/source/settings.rst @@ -18,6 +18,7 @@ Example ``settings.py``:: "EXPIRY_DATETIME_FORMAT": api_settings.DATETIME_FORMAT, "TOKEN_CACHE_TIMEOUT": 60, "REFRESH_TOKEN_ON_LOGIN": False, + "AUTHTOKEN_SELECT_RELATED_LIST": ["user"], } #...snip... @@ -83,4 +84,20 @@ Example ``settings.py``:: In the first case, the already existing token is sent in response. So this setting if set to ``True`` should extend the expiry time of the - token by it's :class:`durin.models.Client` ``token_ttl`` everytime login happens. \ No newline at end of file + token by it's :class:`durin.models.Client` ``token_ttl`` everytime login happens. + +.. data:: AUTHTOKEN_SELECT_RELATED_LIST + + Default: ``["user"]`` + + This is passed as an argument to ``select_related`` when the :class:`durin.auth.TokenAuthentication` class + fetches the :class:`durin.models.AuthToken` instance. For example, + + .. code-block:: python + + AuthToken.objects.select_related(*AUTHTOKEN_SELECT_RELATED_LIST).get(token=token_string) + + Otherwise, set to a falsy value such as ``None`` or ``False`` to not use ``select_related``. + + .. Hint:: Refer to `Django's select_related docs `_ + to see how this can boost performance by reducing number of SQL queries made. diff --git a/durin/auth.py b/durin/auth.py index 59f732e..28ddf87 100644 --- a/durin/auth.py +++ b/durin/auth.py @@ -53,9 +53,20 @@ def authenticate_credentials(cls, token): """ Verify that the given token exists in the database """ - token = token.decode("utf-8") + token_str = token.decode("utf-8") try: - auth_token = AuthToken.objects.get(token=token) + # read settings + to_select = durin_settings.AUTHTOKEN_SELECT_RELATED_LIST + + # get AuthToken object + if isinstance(to_select, list): + auth_token = AuthToken.objects.select_related(*to_select).get( + token=token_str + ) + else: + auth_token = AuthToken.objects.get(token=token_str) + + # validate token if cls._cleanup_token(auth_token): e = _("The given token has expired.") raise exceptions.AuthenticationFailed(e) diff --git a/durin/models.py b/durin/models.py index 22ff230..74ace5b 100644 --- a/durin/models.py +++ b/durin/models.py @@ -76,7 +76,8 @@ class Client(models.Model): def __str__(self): td = humanize.naturaldelta(self.token_ttl) - return "({0}, {1})".format(self.name, td) + rate = self.throttle_rate or "null" + return "({0}: {1}, {2})".format(self.name, td, rate) class AuthTokenManager(models.Manager): diff --git a/durin/permissions.py b/durin/permissions.py index cfa12fb..5fa99e6 100644 --- a/durin/permissions.py +++ b/durin/permissions.py @@ -26,9 +26,9 @@ class AllowSpecificClients(BasePermission): allowed_clients_name = () def has_permission(self, request, view): - if not hasattr(request, "_auth"): + if not request.auth: return False - return request._auth.client.name in self.allowed_clients_name + return request.auth.client.name in self.allowed_clients_name class DisallowSpecificClients(BasePermission): @@ -41,6 +41,6 @@ class DisallowSpecificClients(BasePermission): disallowed_clients_name = () def has_permission(self, request, view): - if not hasattr(request, "_auth"): + if not request.auth: return False - return request._auth.client.name not in self.disallowed_clients_name + return request.auth.client.name not in self.disallowed_clients_name diff --git a/durin/settings.py b/durin/settings.py index 707a38c..e151ef2 100644 --- a/durin/settings.py +++ b/durin/settings.py @@ -14,6 +14,7 @@ "EXPIRY_DATETIME_FORMAT": api_settings.DATETIME_FORMAT, "TOKEN_CACHE_TIMEOUT": 60, "REFRESH_TOKEN_ON_LOGIN": False, + "AUTHTOKEN_SELECT_RELATED_LIST": ["user"], } IMPORT_STRINGS = { diff --git a/example_project/permissions.py b/example_project/permissions.py new file mode 100644 index 0000000..f8694d5 --- /dev/null +++ b/example_project/permissions.py @@ -0,0 +1,13 @@ +from durin import permissions + +TEST_CLIENT_NAME = "web-browser-client-test" + + +class CustomAllowSpecificClients(permissions.AllowSpecificClients): + + allowed_clients_name = (TEST_CLIENT_NAME,) + + +class CustomDisallowSpecificClients(permissions.DisallowSpecificClients): + + disallowed_clients_name = (TEST_CLIENT_NAME,) diff --git a/example_project/urls.py b/example_project/urls.py index 6b0c5fd..d4cda8f 100644 --- a/example_project/urls.py +++ b/example_project/urls.py @@ -2,7 +2,13 @@ from django.urls import include, path, re_path from django.views.generic.base import RedirectView -from .views import CachedRootView, RootView, ThrottledView +from .views import ( + CachedRootView, + NoWebClientView, + OnlyWebClientView, + RootView, + ThrottledView, +) urlpatterns = [ path("", RedirectView.as_view(url="admin/", permanent=False)), @@ -11,4 +17,14 @@ re_path(r"^api/$", RootView.as_view(), name="api-root"), re_path(r"^api/cached$", CachedRootView.as_view(), name="cached-auth-api"), re_path(r"^api/throttled$", ThrottledView.as_view(), name="throttled-api"), + re_path( + r"^api/onlywebclient$", + OnlyWebClientView.as_view(), + name="onlywebclient-api", + ), + re_path( + r"^api/nowebclient$", + NoWebClientView.as_view(), + name="nowebclient-api", + ), ] diff --git a/example_project/views.py b/example_project/views.py index 6f23312..665cd20 100644 --- a/example_project/views.py +++ b/example_project/views.py @@ -5,27 +5,50 @@ from durin.auth import CachedTokenAuthentication, TokenAuthentication from durin.throttling import UserClientRateThrottle +from .permissions import CustomAllowSpecificClients, CustomDisallowSpecificClients -class RootView(APIView): + +class _BaseAPIView(APIView): authentication_classes = (TokenAuthentication,) permission_classes = (IsAuthenticated,) + +class RootView(_BaseAPIView): def get(self, request): return Response("api root") -class CachedRootView(APIView): +class CachedRootView(_BaseAPIView): authentication_classes = (CachedTokenAuthentication,) - permission_classes = (IsAuthenticated,) def get(self, request): return Response("cached api root") -class ThrottledView(APIView): - authentication_classes = (TokenAuthentication,) - permission_classes = (IsAuthenticated,) +class ThrottledView(_BaseAPIView): throttle_classes = (UserClientRateThrottle,) def get(self, request): return Response("ThrottledView") + + +class OnlyWebClientView(_BaseAPIView): + """ + Only accessible to TEST_CLIENT_NAME + """ + + permission_classes = (CustomAllowSpecificClients, IsAuthenticated) + + def get(self, request): + return Response("OnlyWebClientView") + + +class NoWebClientView(_BaseAPIView): + """ + Not accessible to TEST_CLIENT_NAME + """ + + permission_classes = (CustomDisallowSpecificClients, IsAuthenticated) + + def get(self, request): + return Response("NoWebClientView") diff --git a/tests/__init__.py b/tests/__init__.py index deae2cf..bd5d645 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -2,7 +2,7 @@ from django.core.cache import cache as default_cache from rest_framework.test import APITestCase -from durin.models import Client +from durin.models import AuthToken, Client User = get_user_model() @@ -11,6 +11,7 @@ class CustomTestCase(APITestCase): def setUp(self): # cleanup default_cache.clear() + AuthToken.objects.all().delete() Client.objects.all().delete() # setup self.authclient = Client.objects.create(name="authclientfortest") diff --git a/tests/test_auth.py b/tests/test_auth.py new file mode 100644 index 0000000..f57254e --- /dev/null +++ b/tests/test_auth.py @@ -0,0 +1,64 @@ +from importlib import reload + +from django.db import reset_queries +from django.test import override_settings +from django.urls import reverse +from rest_framework.test import APIRequestFactory + +from durin import auth +from durin.models import AuthToken, Client +from durin.settings import durin_settings + +from . import CustomTestCase + +root_url = reverse("api-root") + +new_settings = durin_settings.defaults.copy() + + +class AuthTestCase(CustomTestCase): + def setUp(self): + super().setUp() + # authenticate client + self.token_instance = AuthToken.objects.create(self.user, self.authclient) + self.client.credentials( + HTTP_AUTHORIZATION=("Token %s" % self.token_instance.token) + ) + # reset queries + reset_queries() + self.assertNumQueries(0, msg="Queries were reset") + + def test_authtoken_lookup_1_sql_query(self): + with self.assertNumQueries( + 1, + msg="Since we use ``select_related`` it should take only 1 query", + ): + resp = self.client.get(root_url) + self.assertEqual(resp.status_code, 200) + + def test_authtoken_lookup_2_sql_query(self): + # override settings + new_settings["AUTHTOKEN_SELECT_RELATED_LIST"] = False + with override_settings(REST_DURIN=new_settings): + reload(auth) + with self.assertNumQueries( + 2, + msg="Since we didn't use ``select_related`` it should take 2 queries", + ): + resp = self.client.get(root_url) + self.assertEqual(resp.status_code, 200) + + def test_update_token_key(self): + self.assertEqual(AuthToken.objects.count(), 1) + self.assertEqual(Client.objects.count(), 1) + rf = APIRequestFactory() + request = rf.get("/") + request.META = { + "HTTP_AUTHORIZATION": "Token {}".format(self.token_instance.token) + } + (auth_user, auth_token) = auth.TokenAuthentication().authenticate(request) + self.assertEqual( + self.token_instance.token, + auth_token.token, + ) + self.assertEqual(self.user, auth_user) diff --git a/tests/test_models.py b/tests/test_models.py new file mode 100644 index 0000000..41f5a7c --- /dev/null +++ b/tests/test_models.py @@ -0,0 +1,45 @@ +from django.core.exceptions import ValidationError as DjValidationError +from django.test import TestCase + +from durin.models import Client + + +class ClientTestCase(TestCase): + @classmethod + def setUpClass(cls): + cls.client_names = ["web", "mobile", "cli"] + return super().setUpClass() + + def test_create_clients(self): + Client.objects.all().delete() + self.assertEqual(Client.objects.count(), 0) + for name in self.client_names: + Client.objects.create(name=name) + self.assertEqual(Client.objects.count(), len(self.client_names)) + + def test_throttle_rate_validation_ok(self): + testclient = Client.objects.create( + name="test_throttle_rate_validation", throttle_rate="2/m" + ) + testclient.full_clean() + + self.assertIsNotNone(testclient.pk) + self.assertIsNotNone(testclient.token_ttl) + self.assertIsNotNone(testclient.throttle_rate) + + def test_throttle_rate_validation_raises_exc(self): + + with self.assertRaises(DjValidationError): + testclient1 = Client.objects.create( + name="testclient1", throttle_rate="blahblah" + ) + testclient1.full_clean() + testclient1.delete() + + with self.assertRaises(DjValidationError): + testclient2 = Client.objects.create( + name="testclient2", + throttle_rate="2/minute", + ) + testclient2.full_clean() + testclient2.delete() diff --git a/tests/test_permissions.py b/tests/test_permissions.py new file mode 100644 index 0000000..f087855 --- /dev/null +++ b/tests/test_permissions.py @@ -0,0 +1,58 @@ +from django.urls import reverse +from rest_framework import status + +from durin.models import AuthToken, Client +from example_project.permissions import TEST_CLIENT_NAME + +from . import CustomTestCase + +onlywebclient_url = reverse("onlywebclient-api") +nowebclient_url = reverse("nowebclient-api") + + +class PermissionsTestCase(CustomTestCase): + def setUp(self): + super().setUp() + # authenticate client + Client.objects.all().delete() + self.token1 = AuthToken.objects.create( + self.user, Client.objects.create(name=TEST_CLIENT_NAME) + ) + + self.token2 = AuthToken.objects.create( + self.user, Client.objects.create(name="someotherclient") + ) + + def test_onlywebclient_view_200(self): + self.client.credentials(HTTP_AUTHORIZATION=("Token %s" % self.token1.token)) + resp = self.client.get(onlywebclient_url) + + self.assertEqual(resp.status_code, status.HTTP_200_OK) + + def test_onlywebclient_view_403(self): + self.client.credentials(HTTP_AUTHORIZATION=("Token %s" % self.token2.token)) + resp = self.client.get(onlywebclient_url) + + self.assertEqual(resp.status_code, status.HTTP_403_FORBIDDEN) + + def test_nowebclient_view_403(self): + self.client.credentials(HTTP_AUTHORIZATION=("Token %s" % self.token1.token)) + resp = self.client.get(nowebclient_url) + + self.assertEqual(resp.status_code, status.HTTP_403_FORBIDDEN) + + def test_nowebclient_view_200(self): + self.client.credentials(HTTP_AUTHORIZATION=("Token %s" % self.token2.token)) + resp = self.client.get(nowebclient_url) + + self.assertEqual(resp.status_code, status.HTTP_200_OK) + + def test_onlywebclient_view_no_auth_401(self): + resp = self.client.get(onlywebclient_url) + + self.assertEqual(resp.status_code, status.HTTP_401_UNAUTHORIZED) + + def test_nowebclient_view_no_auth_401(self): + resp = self.client.get(nowebclient_url) + + self.assertEqual(resp.status_code, status.HTTP_401_UNAUTHORIZED) diff --git a/tests/tests.py b/tests/test_views.py similarity index 87% rename from tests/tests.py rename to tests/test_views.py index fa01e9a..0ee4557 100644 --- a/tests/tests.py +++ b/tests/test_views.py @@ -2,15 +2,12 @@ from datetime import timedelta from importlib import reload -from django.core.exceptions import ValidationError as DjValidationError -from django.test import TestCase, override_settings +from django.test import override_settings from django.urls import reverse from rest_framework import status from rest_framework.serializers import DateTimeField -from rest_framework.test import APIRequestFactory from durin import views -from durin.auth import TokenAuthentication from durin.models import AuthToken, Client from durin.serializers import UserSerializer from durin.settings import durin_settings @@ -30,7 +27,7 @@ new_settings = durin_settings.defaults.copy() -class AuthTestCase(CustomTestCase): +class AuthViewsTestCase(CustomTestCase): def test_create_tokens_for_users(self): AuthToken.objects.all().delete() self.assertEqual(AuthToken.objects.count(), 0) @@ -184,20 +181,6 @@ def test_logout_all_deletes_only_targets_keys(self): "tokens from other users should not be affected by logout all", ) - def test_update_token_key(self): - self.assertEqual(AuthToken.objects.count(), 0) - self.assertEqual(Client.objects.count(), 1) - instance = AuthToken.objects.create(self.user, self.authclient) - rf = APIRequestFactory() - request = rf.get("/") - request.META = {"HTTP_AUTHORIZATION": "Token {}".format(instance.token)} - (auth_user, auth_token) = TokenAuthentication().authenticate(request) - self.assertEqual( - instance.token, - auth_token.token, - ) - self.assertEqual(self.user, auth_user) - def test_invalid_token_length_returns_401_code(self): invalid_token = "1" * (durin_settings.TOKEN_CHARACTER_LENGTH - 1) self.client.credentials(HTTP_AUTHORIZATION=("Token %s" % invalid_token)) @@ -331,47 +314,6 @@ def __create_clients(self): self.assertEqual(Client.objects.count(), len(self.client_names)) -class ClientTestCase(TestCase): - @classmethod - def setUpClass(cls): - cls.client_names = ["web", "mobile", "cli"] - return super().setUpClass() - - def test_create_clients(self): - Client.objects.all().delete() - self.assertEqual(Client.objects.count(), 0) - for name in self.client_names: - Client.objects.create(name=name) - self.assertEqual(Client.objects.count(), len(self.client_names)) - - def test_throttle_rate_validation_ok(self): - testclient = Client.objects.create( - name="test_throttle_rate_validation", throttle_rate="2/m" - ) - testclient.full_clean() - - self.assertIsNotNone(testclient.pk) - self.assertIsNotNone(testclient.token_ttl) - self.assertIsNotNone(testclient.throttle_rate) - - def test_throttle_rate_validation_raises_exc(self): - - with self.assertRaises(DjValidationError): - testclient1 = Client.objects.create( - name="testclient1", throttle_rate="blahblah" - ) - testclient1.full_clean() - testclient1.delete() - - with self.assertRaises(DjValidationError): - testclient2 = Client.objects.create( - name="testclient2", - throttle_rate="2/minute", - ) - testclient2.full_clean() - testclient2.delete() - - class ExampleProjectViewsTestCase(CustomTestCase): def test_cached_api(self): self.assertEqual(AuthToken.objects.count(), 0)