Skip to content

Commit

Permalink
Return 400 when token is invalid
Browse files Browse the repository at this point in the history
  • Loading branch information
Kevin Etienne committed Jan 28, 2015
1 parent 9711b19 commit b7987fe
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 10 deletions.
12 changes: 6 additions & 6 deletions user_management/api/tests/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,7 +381,7 @@ def test_put_invalid_user(self):
request = self.create_request('put', auth=False)
view = self.view_class.as_view()
response = view(request, uidb64=invalid_uid)
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)

def test_put_invalid_token(self):
user = UserFactory.create()
Expand All @@ -392,7 +392,7 @@ def test_put_invalid_token(self):
request = self.create_request('put', auth=False)
view = self.view_class.as_view()
response = view(request, uidb64=uid, token=token)
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)

def test_full_stack_wrong_url(self):
user = UserFactory.create()
Expand All @@ -402,7 +402,7 @@ def test_full_stack_wrong_url(self):
view_name = 'user_management_api:password_reset_confirm'
url = reverse(view_name, kwargs={'uidb64': uid, 'token': token})
response = self.client.put(url)
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)

self.assertTrue(hasattr(response, 'accepted_renderer'))

Expand Down Expand Up @@ -551,7 +551,7 @@ def test_post_invalid_user(self):
request = self.create_request('post')
view = self.view_class.as_view()
response = view(request, uidb64=invalid_uid)
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)

def test_post_invalid_token(self):
user = UserFactory.create()
Expand All @@ -562,7 +562,7 @@ def test_post_invalid_token(self):
request = self.create_request('post')
view = self.view_class.as_view()
response = view(request, uidb64=uid, token=token)
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)

def test_post_verified_email(self):
user = UserFactory.create(email_verification_required=False)
Expand Down Expand Up @@ -599,7 +599,7 @@ def test_full_stack_wrong_url(self):
view_name = 'user_management_api:verify_user'
url = reverse(view_name, kwargs={'uidb64': uid, 'token': token})
response = self.client.post(url)
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)

self.assertTrue(hasattr(response, 'accepted_renderer'))

Expand Down
7 changes: 3 additions & 4 deletions user_management/api/views.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
from django.contrib.auth import get_user_model
from django.contrib.auth.tokens import default_token_generator
from django.contrib.sites.models import Site
from django.http import Http404
from django.utils.encoding import force_bytes, force_text
from django.utils.http import urlsafe_base64_decode, urlsafe_base64_encode
from django.utils.translation import ugettext_lazy as _
from incuna_mail import send
from rest_framework import generics, renderers, response, status, views
from rest_framework.authentication import get_authorization_header
from rest_framework.authtoken.views import ObtainAuthToken
from rest_framework.exceptions import AuthenticationFailed
from rest_framework.exceptions import ParseError
from rest_framework.permissions import AllowAny, IsAuthenticated

from . import models, permissions, serializers, throttling
Expand Down Expand Up @@ -154,11 +153,11 @@ def initial(self, request, *args, **kwargs):
try:
self.user = User.objects.get(pk=uid)
except User.DoesNotExist:
raise AuthenticationFailed(detail=self.message)
raise ParseError(detail=self.message)

token = kwargs['token']
if not default_token_generator.check_token(self.user, token):
raise AuthenticationFailed(detail=self.message)
raise ParseError(detail=self.message)

return super(OneTimeUseAPIMixin, self).initial(
request,
Expand Down

0 comments on commit b7987fe

Please sign in to comment.