Navigation Menu

Skip to content

Commit

Permalink
Cleans up refresh logic + Adds unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
iMerica committed Nov 17, 2020
1 parent 5d6e8ca commit 63bd99a
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 37 deletions.
1 change: 1 addition & 0 deletions .gitignore
Expand Up @@ -37,6 +37,7 @@ pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
coverage_html/
.tox/
.coverage
.coverage.*
Expand Down
71 changes: 58 additions & 13 deletions dj_rest_auth/tests/test_api.py
Expand Up @@ -11,15 +11,17 @@

from dj_rest_auth.registration.app_settings import register_permission_classes
from dj_rest_auth.registration.views import RegisterView

from .mixins import CustomPermissionClass, TestsMixin

try:
from django.urls import reverse
except ImportError:
from django.core.urlresolvers import reverse

from rest_framework_simplejwt.serializers import TokenObtainPairSerializer
from jwt import decode as decode_jwt
from rest_framework_simplejwt.serializers import TokenObtainPairSerializer


class TESTTokenObtainPairSerializer(TokenObtainPairSerializer):
@classmethod
Expand Down Expand Up @@ -71,8 +73,8 @@ def setUp(self):

def _generate_uid_and_token(self, user):
result = {}
from django.utils.encoding import force_bytes
from django.contrib.auth.tokens import default_token_generator
from django.utils.encoding import force_bytes
from django.utils.http import urlsafe_base64_encode

result['uid'] = urlsafe_base64_encode(force_bytes(user.pk))
Expand Down Expand Up @@ -559,6 +561,20 @@ def test_logout_jwt_deletes_cookie(self):
resp = self.post(self.logout_url, status=200)
self.assertEqual('', resp.cookies.get('jwt-auth').value)

@override_settings(JWT_AUTH_REFRESH_COOKIE='jwt-auth-refresh')
@override_settings(REST_USE_JWT=True)
@override_settings(JWT_AUTH_COOKIE='jwt-auth')
def test_logout_jwt_deletes_cookie_refresh(self):
payload = {
"username": self.USERNAME,
"password": self.PASS
}
get_user_model().objects.create_user(self.USERNAME, '', self.PASS)
self.post(self.login_url, data=payload, status_code=200)
resp = self.post(self.logout_url, status=200)
self.assertEqual('', resp.cookies.get('jwt-auth').value)
self.assertEqual('', resp.cookies.get('jwt-auth-refresh').value)

@override_settings(REST_USE_JWT=True)
@override_settings(JWT_AUTH_COOKIE='jwt-auth')
@override_settings(REST_FRAMEWORK=dict(
Expand Down Expand Up @@ -604,21 +620,15 @@ def test_blacklisting(self):
resp = self.post(self.login_url, data=payload, status_code=200)
token = resp.data['refresh_token']
# test refresh token not included in request data
resp = self.post(self.logout_url, status=200)
self.assertEqual(resp.status_code, 401)
self.post(self.logout_url, status_code=401)
# test token is invalid or expired
resp = self.post(self.logout_url, status=200, data={'refresh': '1'})
self.assertEqual(resp.status_code, 401)
self.post(self.logout_url, status_code=401, data={'refresh': '1'})
# test successful logout
resp = self.post(self.logout_url, status=200, data={'refresh': token})
self.assertEqual(resp.status_code, 200)
self.post(self.logout_url, status_code=200, data={'refresh': token})
# test token is blacklisted
resp = self.post(self.logout_url, status=200, data={'refresh': token})
self.assertEqual(resp.status_code, 401)
self.post(self.logout_url, status_code=401, data={'refresh': token})
# test other TokenError, AttributeError, TypeError (invalid format)
resp = self.post(self.logout_url, status=200, data=json.dumps({'refresh': token}))
self.assertEqual(resp.status_code, 500)

self.post(self.logout_url, status_code=500, data=json.dumps({'refresh': token}))

@override_settings(REST_USE_JWT=True)
@override_settings(JWT_AUTH_COOKIE=None)
Expand Down Expand Up @@ -868,3 +878,38 @@ def test_csrf_w_login_csrf_enforcement_2(self):
resp = client.post('/protected-view/', csrfparam)
self.assertEquals(resp.status_code, 200)

@override_settings(JWT_AUTH_RETURN_EXPIRATION=True)
@override_settings(REST_USE_JWT=True)
@override_settings(ACCOUNT_LOGOUT_ON_GET=True)
def test_return_expiration(self):
payload = {
"username": self.USERNAME,
"password": self.PASS
}

# create user
get_user_model().objects.create_user(self.USERNAME, '', self.PASS)

resp = self.post(self.login_url, data=payload, status_code=200)
self.assertIn('access_token_expiration', resp.data.keys())
self.assertIn('refresh_token_expiration', resp.data.keys())

@override_settings(JWT_AUTH_RETURN_EXPIRATION=True)
@override_settings(REST_USE_JWT=True)
@override_settings(JWT_AUTH_COOKIE='xxx')
@override_settings(ACCOUNT_LOGOUT_ON_GET=True)
@override_settings(JWT_AUTH_REFRESH_COOKIE='refresh-xxx')
@override_settings(JWT_AUTH_REFRESH_COOKIE_PATH='/foo/bar')
def test_refresh_cookie_name(self):
payload = {
"username": self.USERNAME,
"password": self.PASS
}

# create user
get_user_model().objects.create_user(self.USERNAME, '', self.PASS)

resp = self.post(self.login_url, data=payload, status_code=200)
self.assertIn('xxx', resp.cookies.keys())
self.assertIn('refresh-xxx', resp.cookies.keys())
self.assertEqual(resp.cookies.get('refresh-xxx').get('path'), '/foo/bar')
48 changes: 24 additions & 24 deletions dj_rest_auth/views.py
Expand Up @@ -2,6 +2,7 @@
from django.contrib.auth import get_user_model
from django.contrib.auth import login as django_login
from django.contrib.auth import logout as django_logout
from django.utils import timezone
from django.core.exceptions import ObjectDoesNotExist
from django.utils.decorators import method_decorator
from django.utils.translation import ugettext_lazy as _
Expand Down Expand Up @@ -76,12 +77,12 @@ def login(self):
def get_response(self):
serializer_class = self.get_response_serializer()

access_token_expiration = None
refresh_token_expiration = None
if getattr(settings, 'REST_USE_JWT', False):
from rest_framework_simplejwt.settings import api_settings as jwt_settings
from datetime import datetime

access_token_expiration = (datetime.utcnow() + jwt_settings.ACCESS_TOKEN_LIFETIME)
refresh_token_expiration = (datetime.utcnow() + jwt_settings.REFRESH_TOKEN_LIFETIME)
access_token_expiration = (timezone.now() + jwt_settings.ACCESS_TOKEN_LIFETIME)
refresh_token_expiration = (timezone.now() + jwt_settings.REFRESH_TOKEN_LIFETIME)
return_expiration_times = getattr(settings, 'JWT_AUTH_RETURN_EXPIRATION', False)

data = {
Expand Down Expand Up @@ -170,8 +171,10 @@ def logout(self, request):
if getattr(settings, 'REST_SESSION_LOGIN', True):
django_logout(request)

response = Response({"detail": _("Successfully logged out.")},
status=status.HTTP_200_OK)
response = Response(
{"detail": _("Successfully logged out.")},
status=status.HTTP_200_OK
)

if getattr(settings, 'REST_USE_JWT', False):
# NOTE: this import occurs here rather than at the top level
Expand All @@ -183,41 +186,38 @@ def logout(self, request):
cookie_name = getattr(settings, 'JWT_AUTH_COOKIE', None)
if cookie_name:
response.delete_cookie(cookie_name)

refresh_cookie_name = getattr(settings, 'JWT_AUTH_REFRESH_COOKIE', None)
if refresh_cookie_name:
response.delete_cookie(refresh_cookie_name)

elif 'rest_framework_simplejwt.token_blacklist' in settings.INSTALLED_APPS:
if 'rest_framework_simplejwt.token_blacklist' in settings.INSTALLED_APPS:
# add refresh token to blacklist
try:
token = RefreshToken(request.data['refresh'])
token.blacklist()

except KeyError:
response = Response({"detail": _("Refresh token was not included in request data.")},
status=status.HTTP_401_UNAUTHORIZED)

response.data = {"detail": _("Refresh token was not included in request data.")}
response.status_code =status.HTTP_401_UNAUTHORIZED
except (TokenError, AttributeError, TypeError) as error:
if hasattr(error, 'args'):
if 'Token is blacklisted' in error.args or 'Token is invalid or expired' in error.args:
response = Response({"detail": _(error.args[0])},
status=status.HTTP_401_UNAUTHORIZED)

response.data = {"detail": _(error.args[0])}
response.status_code = status.HTTP_401_UNAUTHORIZED
else:
response = Response({"detail": _("An error has occurred.")},
status=status.HTTP_500_INTERNAL_SERVER_ERROR)
response.data = {"detail": _("An error has occurred.")}
response.status_code = status.HTTP_500_INTERNAL_SERVER_ERROR

else:
response = Response({"detail": _("An error has occurred.")},
status=status.HTTP_500_INTERNAL_SERVER_ERROR)
response.data = {"detail": _("An error has occurred.")}
response.status_code = status.HTTP_500_INTERNAL_SERVER_ERROR

else:
response = Response({
"detail": _("Neither cookies or blacklist are enabled, so the token has not been deleted server "
"side. Please make sure the token is deleted client side."
)}, status=status.HTTP_200_OK)

message = _(
"Neither cookies or blacklist are enabled, so the token "
"has not been deleted server side. Please make sure the token is deleted client side."
)
response.data = {"detail": message}
response.status_code = status.HTTP_200_OK
return response


Expand Down

0 comments on commit 63bd99a

Please sign in to comment.