Skip to content

Commit

Permalink
Use scopes for throttling views
Browse files Browse the repository at this point in the history
  • Loading branch information
Matt Lenc committed Jun 2, 2014
1 parent fcb0267 commit 9fbeb57
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 36 deletions.
99 changes: 66 additions & 33 deletions user_management/api/tests/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from django.contrib.auth.hashers import check_password
from django.contrib.auth.tokens import default_token_generator
from django.core import mail
from django.core.cache import cache
from django.core.urlresolvers import reverse
from django.utils.encoding import force_bytes
from django.utils.http import urlsafe_base64_encode
Expand All @@ -20,43 +21,15 @@
TEST_SERVER = 'http://testserver'


class TestThrottle(APIRequestTestCase):
class GetTokenTest(APIRequestTestCase):
view_class = views.GetToken

@patch('rest_framework.throttling.AnonRateThrottle.THROTTLE_RATES', new={
'anon': '1/day',
})
def test_user_auth_throttle(self):
auth_url = reverse('user_management_api:auth')
expected_status = status.HTTP_429_TOO_MANY_REQUESTS
def tearDown(self):
cache.clear()

request = APIRequestFactory().get(auth_url)

response = self.view_class.as_view()(request)
self.assertNotEqual(response.status_code, expected_status)

response = self.view_class.as_view()(request)
self.assertEqual(response.status_code, expected_status)

@patch('rest_framework.throttling.UserRateThrottle.THROTTLE_RATES', new={
'user': '1/day',
@patch('rest_framework.throttling.ScopedRateThrottle.THROTTLE_RATES', new={
'logins': '1/minute',
})
def test_user_password_reset_throttle(self):
auth_url = reverse('user_management_api:password_reset')
expected_status = status.HTTP_429_TOO_MANY_REQUESTS

request = APIRequestFactory().get(auth_url)

response = self.view_class.as_view()(request)
self.assertNotEqual(response.status_code, expected_status)

response = self.view_class.as_view()(request)
self.assertEqual(response.status_code, expected_status)


class GetTokenTest(APIRequestTestCase):
view_class = views.GetToken

def test_post(self):
username = 'Test@example.com'
password = 'myepicstrongpassword'
Expand All @@ -68,6 +41,9 @@ def test_post(self):
response = view(request)
self.assertEqual(response.status_code, status.HTTP_200_OK, msg=response.data)

@patch('rest_framework.throttling.ScopedRateThrottle.THROTTLE_RATES', new={
'logins': '1/minute',
})
def test_post_username(self):
username = 'Test@example.com'
password = 'myepicstrongpassword'
Expand All @@ -79,6 +55,9 @@ def test_post_username(self):
response = view(request)
self.assertEqual(response.status_code, status.HTTP_200_OK, msg=response.data)

@patch('rest_framework.throttling.ScopedRateThrottle.THROTTLE_RATES', new={
'logins': '1/minute',
})
def test_delete(self):
user = UserFactory.create()
token = Token.objects.create(user=user)
Expand All @@ -90,11 +69,29 @@ def test_delete(self):
with self.assertRaises(Token.DoesNotExist):
Token.objects.get(pk=token.pk)

@patch('rest_framework.throttling.ScopedRateThrottle.THROTTLE_RATES', new={
'logins': '1/minute',
})
def test_delete_no_token(self):
request = self.create_request('delete')
response = self.view_class.as_view()(request)
self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT)

@patch('rest_framework.throttling.ScopedRateThrottle.THROTTLE_RATES', new={
'logins': '1/minute',
})
def test_user_auth_throttle(self):
auth_url = reverse('user_management_api:auth')
expected_status = status.HTTP_429_TOO_MANY_REQUESTS

request = APIRequestFactory().get(auth_url)

response = self.view_class.as_view()(request)
self.assertNotEqual(response.status_code, expected_status)

response = self.view_class.as_view()(request)
self.assertEqual(response.status_code, expected_status)


class TestRegisterView(APIRequestTestCase):
view_class = views.UserRegister
Expand Down Expand Up @@ -189,6 +186,12 @@ def test_duplicate_email(self):
class TestPasswordResetEmail(APIRequestTestCase):
view_class = views.PasswordResetEmail

def tearDown(self):
cache.clear()

@patch('rest_framework.throttling.ScopedRateThrottle.THROTTLE_RATES', new={
'passwords': '1/minute',
})
def test_existent_email(self):
email = 'exists@example.com'
user = UserFactory.create(email=email)
Expand All @@ -201,12 +204,18 @@ def test_existent_email(self):

send_email.assert_called_once_with(user)

@patch('rest_framework.throttling.ScopedRateThrottle.THROTTLE_RATES', new={
'passwords': '1/minute',
})
def test_authenticated(self):
request = self.create_request('post', auth=True)
view = self.view_class.as_view()
response = view(request)
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)

@patch('rest_framework.throttling.ScopedRateThrottle.THROTTLE_RATES', new={
'passwords': '1/minute',
})
def test_non_existent_email(self):
email = 'doesnotexist@example.com'
UserFactory.create(email='exists@example.com')
Expand All @@ -219,12 +228,18 @@ def test_non_existent_email(self):

self.assertFalse(send_email.called)

@patch('rest_framework.throttling.ScopedRateThrottle.THROTTLE_RATES', new={
'passwords': '1/minute',
})
def test_missing_email(self):
request = self.create_request('post', auth=False)
view = self.view_class.as_view()
response = view(request)
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)

@patch('rest_framework.throttling.ScopedRateThrottle.THROTTLE_RATES', new={
'passwords': '1/minute',
})
def test_send_email(self):
email = 'test@example.com'
user = UserFactory.create(
Expand All @@ -243,6 +258,9 @@ def test_send_email(self):
self.assertIn('auth/password_reset/confirm/', sent_mail.body)
self.assertIn('https://', sent_mail.body)

@patch('rest_framework.throttling.ScopedRateThrottle.THROTTLE_RATES', new={
'passwords': '1/minute',
})
def test_options(self):
"""Ensure information about email field is included in options request"""
request = self.create_request('options', auth=False)
Expand All @@ -264,6 +282,21 @@ def test_options(self):
expected_post_options,
)

@patch('rest_framework.throttling.ScopedRateThrottle.THROTTLE_RATES', new={
'passwords': '1/minute',
})
def test_user_password_reset_throttle(self):
auth_url = reverse('user_management_api:password_reset')
expected_status = status.HTTP_429_TOO_MANY_REQUESTS

request = APIRequestFactory().get(auth_url)

response = self.view_class.as_view()(request)
self.assertNotEqual(response.status_code, expected_status)

response = self.view_class.as_view()(request)
self.assertEqual(response.status_code, expected_status)


class TestPasswordReset(APIRequestTestCase):
view_class = views.PasswordReset
Expand Down
8 changes: 5 additions & 3 deletions user_management/api/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from rest_framework.authtoken.models import Token
from rest_framework.authtoken.views import ObtainAuthToken
from rest_framework.permissions import AllowAny, IsAuthenticated
from rest_framework.throttling import AnonRateThrottle, UserRateThrottle
from rest_framework.throttling import ScopedRateThrottle

from . import serializers, permissions

Expand All @@ -20,7 +20,8 @@

class GetToken(ObtainAuthToken):
renderer_classes = (renderers.JSONRenderer, renderers.BrowsableAPIRenderer)
throttle_classes = [AnonRateThrottle, UserRateThrottle]
throttle_classes = [ScopedRateThrottle]
throttle_scope = 'logins'

def delete(self, request, *args, **kwargs):
try:
Expand Down Expand Up @@ -69,6 +70,8 @@ class PasswordResetEmail(generics.GenericAPIView):
permission_classes = [permissions.IsNotAuthenticated]
template_name = 'user_management/password_reset_email.html'
serializer_class = serializers.PasswordResetEmailSerializer
throttle_classes = [ScopedRateThrottle]
throttle_scope = 'passwords'

def email_context(self, site, user):
return {
Expand Down Expand Up @@ -126,7 +129,6 @@ def initial(self, request, *args, **kwargs):

class PasswordReset(OneTimeUseAPIMixin, generics.UpdateAPIView):
permission_classes = [permissions.IsNotAuthenticated]
throttle_classes = [AnonRateThrottle, UserRateThrottle]
model = User
serializer_class = serializers.PasswordResetSerializer

Expand Down

0 comments on commit 9fbeb57

Please sign in to comment.