Skip to content

Commit

Permalink
Merge pull request #99 from incuna/invalid-token-401
Browse files Browse the repository at this point in the history
Return HTTP Bad Request 400 code when the token and uidb64 are not valid
  • Loading branch information
LilyFoote committed Jan 28, 2015
2 parents 9711b19 + a827f43 commit 57abfd2
Show file tree
Hide file tree
Showing 5 changed files with 37 additions and 13 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
## v5.0.0 (Upcoming)

* Return 400 instead of 401 when `uidb64` or `token` is expired or not valid.

## v4.2.0

* Return `AuthenticationFailed` `401` instead of `404` `NotFound` for not valid
Expand Down
9 changes: 9 additions & 0 deletions user_management/api/exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from django.utils.translation import ugettext_lazy as _
from rest_framework import status
from rest_framework.exceptions import APIException


class InvalidExpiredToken(APIException):
"""Exception to confirm an account."""
status_code = status.HTTP_400_BAD_REQUEST
default_detail = _('Invalid or expired token.')
15 changes: 15 additions & 0 deletions user_management/api/tests/test_exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from django.test import TestCase
from rest_framework.status import HTTP_400_BAD_REQUEST

from ..exceptions import InvalidExpiredToken


class InvalidExpiredTokenTest(TestCase):
"""Assert `InvalidExpiredToken` behaves as expected."""
def test_raise(self):
"""Assert `InvalidExpiredToken` can be raised."""
with self.assertRaises(InvalidExpiredToken) as error:
raise InvalidExpiredToken
self.assertEqual(error.exception.status_code, HTTP_400_BAD_REQUEST)
message = error.exception.detail.format()
self.assertEqual(message, 'Invalid or expired token.')
12 changes: 6 additions & 6 deletions user_management/api/tests/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,7 +381,7 @@ def test_put_invalid_user(self):
request = self.create_request('put', auth=False)
view = self.view_class.as_view()
response = view(request, uidb64=invalid_uid)
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)

def test_put_invalid_token(self):
user = UserFactory.create()
Expand All @@ -392,7 +392,7 @@ def test_put_invalid_token(self):
request = self.create_request('put', auth=False)
view = self.view_class.as_view()
response = view(request, uidb64=uid, token=token)
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)

def test_full_stack_wrong_url(self):
user = UserFactory.create()
Expand All @@ -402,7 +402,7 @@ def test_full_stack_wrong_url(self):
view_name = 'user_management_api:password_reset_confirm'
url = reverse(view_name, kwargs={'uidb64': uid, 'token': token})
response = self.client.put(url)
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)

self.assertTrue(hasattr(response, 'accepted_renderer'))

Expand Down Expand Up @@ -551,7 +551,7 @@ def test_post_invalid_user(self):
request = self.create_request('post')
view = self.view_class.as_view()
response = view(request, uidb64=invalid_uid)
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)

def test_post_invalid_token(self):
user = UserFactory.create()
Expand All @@ -562,7 +562,7 @@ def test_post_invalid_token(self):
request = self.create_request('post')
view = self.view_class.as_view()
response = view(request, uidb64=uid, token=token)
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)

def test_post_verified_email(self):
user = UserFactory.create(email_verification_required=False)
Expand Down Expand Up @@ -599,7 +599,7 @@ def test_full_stack_wrong_url(self):
view_name = 'user_management_api:verify_user'
url = reverse(view_name, kwargs={'uidb64': uid, 'token': token})
response = self.client.post(url)
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)

self.assertTrue(hasattr(response, 'accepted_renderer'))

Expand Down
10 changes: 3 additions & 7 deletions user_management/api/views.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,16 @@
from django.contrib.auth import get_user_model
from django.contrib.auth.tokens import default_token_generator
from django.contrib.sites.models import Site
from django.http import Http404
from django.utils.encoding import force_bytes, force_text
from django.utils.http import urlsafe_base64_decode, urlsafe_base64_encode
from django.utils.translation import ugettext_lazy as _
from incuna_mail import send
from rest_framework import generics, renderers, response, status, views
from rest_framework.authentication import get_authorization_header
from rest_framework.authtoken.views import ObtainAuthToken
from rest_framework.exceptions import AuthenticationFailed
from rest_framework.permissions import AllowAny, IsAuthenticated

from . import models, permissions, serializers, throttling
from . import exceptions, models, permissions, serializers, throttling


User = get_user_model()
Expand Down Expand Up @@ -145,20 +143,18 @@ def send_email(self, user):


class OneTimeUseAPIMixin(object):
message = _('Invalid or expired token.')

def initial(self, request, *args, **kwargs):
uidb64 = kwargs['uidb64']
uid = urlsafe_base64_decode(force_text(uidb64))

try:
self.user = User.objects.get(pk=uid)
except User.DoesNotExist:
raise AuthenticationFailed(detail=self.message)
raise exceptions.InvalidExpiredToken()

token = kwargs['token']
if not default_token_generator.check_token(self.user, token):
raise AuthenticationFailed(detail=self.message)
raise exceptions.InvalidExpiredToken()

return super(OneTimeUseAPIMixin, self).initial(
request,
Expand Down

0 comments on commit 57abfd2

Please sign in to comment.