Skip to content

Commit

Permalink
Use AuthenticationFailed
Browse files Browse the repository at this point in the history
  • Loading branch information
Kevin Etienne committed Jan 27, 2015
1 parent 46646fb commit 59f8df1
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 10 deletions.
12 changes: 6 additions & 6 deletions user_management/api/tests/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,7 +381,7 @@ def test_put_invalid_user(self):
request = self.create_request('put', auth=False)
view = self.view_class.as_view()
response = view(request, uidb64=invalid_uid)
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)

def test_put_invalid_token(self):
user = UserFactory.create()
Expand All @@ -392,7 +392,7 @@ def test_put_invalid_token(self):
request = self.create_request('put', auth=False)
view = self.view_class.as_view()
response = view(request, uidb64=uid, token=token)
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)

def test_full_stack_wrong_url(self):
user = UserFactory.create()
Expand All @@ -402,7 +402,7 @@ def test_full_stack_wrong_url(self):
view_name = 'user_management_api:password_reset_confirm'
url = reverse(view_name, kwargs={'uidb64': uid, 'token': token})
response = self.client.put(url)
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)

self.assertTrue(hasattr(response, 'accepted_renderer'))

Expand Down Expand Up @@ -551,7 +551,7 @@ def test_post_invalid_user(self):
request = self.create_request('post')
view = self.view_class.as_view()
response = view(request, uidb64=invalid_uid)
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)

def test_post_invalid_token(self):
user = UserFactory.create()
Expand All @@ -562,7 +562,7 @@ def test_post_invalid_token(self):
request = self.create_request('post')
view = self.view_class.as_view()
response = view(request, uidb64=uid, token=token)
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)

def test_post_verified_email(self):
user = UserFactory.create(email_verification_required=False)
Expand Down Expand Up @@ -599,7 +599,7 @@ def test_full_stack_wrong_url(self):
view_name = 'user_management_api:verify_user'
url = reverse(view_name, kwargs={'uidb64': uid, 'token': token})
response = self.client.post(url)
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)

self.assertTrue(hasattr(response, 'accepted_renderer'))

Expand Down
8 changes: 4 additions & 4 deletions user_management/api/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from rest_framework import generics, renderers, response, status, views
from rest_framework.authentication import get_authorization_header
from rest_framework.authtoken.views import ObtainAuthToken
from rest_framework.exceptions import PermissionDenied
from rest_framework.exceptions import AuthenticationFailed
from rest_framework.permissions import AllowAny, IsAuthenticated

from . import models, permissions, serializers, throttling
Expand Down Expand Up @@ -152,12 +152,12 @@ def initial(self, request, *args, **kwargs):
try:
self.user = User.objects.get(pk=uid)
except User.DoesNotExist:
msg = _('Invalid or expired token.')
raise PermissionDenied(detail=msg)
raise AuthenticationFailed()

token = kwargs['token']
if not default_token_generator.check_token(self.user, token):
raise Http404()
msg = _('Invalid or expired token.')
raise AuthenticationFailed(detail=msg)

return super(OneTimeUseAPIMixin, self).initial(
request,
Expand Down

0 comments on commit 59f8df1

Please sign in to comment.