Skip to content

Commit e353acc

Browse files
author
Ian Foote
committed
Merge pull request #71 from incuna/options-rate-delimit
OPTIONS requests shouldn't be rate limited
2 parents 54fb754 + c418ca7 commit e353acc

File tree

3 files changed

+119
-42
lines changed

3 files changed

+119
-42
lines changed

user_management/api/tests/test_views.py

Lines changed: 102 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -32,13 +32,21 @@ def tearDown(self):
3232
def test_post(self):
3333
username = 'Test@example.com'
3434
password = 'myepicstrongpassword'
35-
UserFactory.create(email=username.lower(), password=password, is_active=True)
35+
UserFactory.create(
36+
email=username.lower(),
37+
password=password,
38+
is_active=True,
39+
)
3640

3741
data = {'username': username, 'password': password}
3842
request = self.create_request('post', auth=False, data=data)
3943
view = self.view_class.as_view()
4044
response = view(request)
41-
self.assertEqual(response.status_code, status.HTTP_200_OK, msg=response.data)
45+
self.assertEqual(
46+
response.status_code,
47+
status.HTTP_200_OK,
48+
msg=response.data,
49+
)
4250

4351
def test_post_username(self):
4452
username = 'Test@example.com'
@@ -49,7 +57,11 @@ def test_post_username(self):
4957
request = self.create_request('post', auth=False, data=data)
5058
view = self.view_class.as_view()
5159
response = view(request)
52-
self.assertEqual(response.status_code, status.HTTP_200_OK, msg=response.data)
60+
self.assertEqual(
61+
response.status_code,
62+
status.HTTP_200_OK,
63+
msg=response.data,
64+
)
5365

5466
def test_delete(self):
5567
user = UserFactory.create()
@@ -153,7 +165,9 @@ def test_unauthenticated_user_post(self):
153165
@patch('user_management.api.serializers.RegistrationSerializer.Meta.model',
154166
new=BasicUser)
155167
def test_unauthenticated_user_post_no_verify_email(self):
156-
"""An email should not be sent if email_verification_required is False."""
168+
"""
169+
An email should not be sent if email_verification_required is False.
170+
"""
157171
request = self.create_request('post', auth=False, data=self.data)
158172

159173
response = self.view_class.as_view()(request)
@@ -210,14 +224,50 @@ def test_existent_email(self):
210224
email = 'exists@example.com'
211225
user = UserFactory.create(email=email)
212226

213-
request = self.create_request('post', data={'email': email}, auth=False)
227+
request = self.create_request(
228+
'post',
229+
data={'email': email},
230+
auth=False,
231+
)
214232
view = self.view_class.as_view()
215233
with patch.object(self.view_class, 'send_email') as send_email:
216234
response = view(request)
217235
self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT)
218236

219237
send_email.assert_called_once_with(user)
220238

239+
def assert_post_returns_status(self, view, data, expected_status):
240+
request = self.create_request('post', data=data, auth=False)
241+
response = view(request)
242+
self.assertEqual(response.status_code, expected_status)
243+
244+
def test_post_rate_limit(self):
245+
"""
246+
Ensure the post requests are rate limited.
247+
248+
The PasswordResetRateThrottle sets a limit of 3/hour requests.
249+
"""
250+
email = 'exists@example.com'
251+
UserFactory.create(email=email)
252+
view = self.view_class.as_view()
253+
rate_limit = 3
254+
255+
# Test the the first 3 requests aren't limited.
256+
data = {'email': email}
257+
for i in range(rate_limit):
258+
self.assert_post_returns_status(
259+
view,
260+
data,
261+
status.HTTP_204_NO_CONTENT,
262+
)
263+
264+
# Test that the 4th request is throttled.
265+
self.assert_post_returns_status(
266+
view,
267+
data,
268+
status.HTTP_429_TOO_MANY_REQUESTS,
269+
)
270+
221271
def test_authenticated(self):
222272
request = self.create_request('post', auth=True)
223273
view = self.view_class.as_view()
@@ -228,7 +278,11 @@ def test_non_existent_email(self):
228278
email = 'doesnotexist@example.com'
229279
UserFactory.create(email='exists@example.com')
230280

231-
request = self.create_request('post', data={'email': email}, auth=False)
281+
request = self.create_request(
282+
'post',
283+
data={'email': email},
284+
auth=False,
285+
)
232286
view = self.view_class.as_view()
233287
with patch.object(self.view_class, 'send_email') as send_email:
234288
response = view(request)
@@ -261,7 +315,11 @@ def test_send_email(self):
261315
self.assertIn('https://', sent_mail.body)
262316

263317
def test_options(self):
264-
"""Ensure information about email field is included in options request"""
318+
"""
319+
Ensure information about email field is included in options request.
320+
321+
Ensure that the options request isn't rate limited.
322+
"""
265323
request = self.create_request('options', auth=False)
266324
view = self.view_class.as_view()
267325
response = view(request)
@@ -281,38 +339,25 @@ def test_options(self):
281339
expected_post_options,
282340
)
283341

284-
def test_default_user_password_reset_throttle(self):
285-
default_rate = 3
286-
auth_url = reverse('user_management_api:password_reset')
287-
expected_status = status.HTTP_429_TOO_MANY_REQUESTS
342+
def test_options_delimit(self):
343+
"""
344+
Ensure information about email field is included in options request.
288345
289-
request = APIRequestFactory().get(auth_url)
346+
Ensure that the options request isn't rate limited.
347+
"""
348+
request = self.create_request('options', auth=False)
290349
view = self.view_class.as_view()
291-
292-
# make all but one of our allowed requests
293-
for i in range(default_rate - 1):
350+
default_rate = 3
351+
# First three requests
352+
for i in range(default_rate):
294353
view(request)
295354

296-
response = view(request) # our last allowed request
297-
self.assertNotEqual(response.status_code, expected_status)
298-
299-
response = view(request) # our throttled request
300-
self.assertEqual(response.status_code, expected_status)
301-
302-
@patch('rest_framework.throttling.ScopedRateThrottle.THROTTLE_RATES', new={
303-
'passwords': '1/minute',
304-
})
305-
def test_user_password_reset_throttle(self):
306-
auth_url = reverse('user_management_api:password_reset')
307-
expected_status = status.HTTP_429_TOO_MANY_REQUESTS
308-
309-
request = APIRequestFactory().get(auth_url)
310-
311-
response = self.view_class.as_view()(request)
312-
self.assertNotEqual(response.status_code, expected_status)
313-
314-
response = self.view_class.as_view()(request)
315-
self.assertEqual(response.status_code, expected_status)
355+
# Assert fourth request is not throttled
356+
response = view(request)
357+
self.assertNotEqual(
358+
response.status_code,
359+
status.HTTP_429_TOO_MANY_REQUESTS,
360+
)
316361

317362

318363
class TestPasswordReset(APIRequestTestCase):
@@ -360,7 +405,10 @@ def test_password_mismatch(self):
360405

361406
request = self.create_request(
362407
'put',
363-
data={'new_password': new_password, 'new_password2': invalid_password},
408+
data={
409+
'new_password': new_password,
410+
'new_password2': invalid_password,
411+
},
364412
auth=False,
365413
)
366414
view = self.view_class.as_view()
@@ -674,7 +722,10 @@ class TestUserList(APIRequestTestCase):
674722
view_class = views.UserList
675723

676724
def expected_data(self, user):
677-
url = reverse('user_management_api:user_detail', kwargs={'pk': user.pk})
725+
url = reverse(
726+
'user_management_api:user_detail',
727+
kwargs={'pk': user.pk},
728+
)
678729
expected = {
679730
'url': TEST_SERVER + url,
680731
'name': user.name,
@@ -734,7 +785,10 @@ def setUp(self):
734785
self.user, self.other_user = UserFactory.create_batch(2)
735786

736787
def expected_data(self, user):
737-
url = reverse('user_management_api:user_detail', kwargs={'pk': user.pk})
788+
url = reverse(
789+
'user_management_api:user_detail',
790+
kwargs={'pk': user.pk},
791+
)
738792
expected = {
739793
'url': TEST_SERVER + url,
740794
'name': user.name,
@@ -787,7 +841,11 @@ def test_put(self):
787841
view = self.view_class.as_view()
788842

789843
response = view(request, pk=self.other_user.pk)
790-
self.assertEqual(response.status_code, status.HTTP_200_OK, response.data)
844+
self.assertEqual(
845+
response.status_code,
846+
status.HTTP_200_OK,
847+
response.data,
848+
)
791849

792850
user = User.objects.get(pk=self.other_user.pk)
793851
self.assertEqual(user.name, data['name'])
@@ -803,7 +861,11 @@ def test_patch(self):
803861
view = self.view_class.as_view()
804862

805863
response = view(request, pk=self.other_user.pk)
806-
self.assertEqual(response.status_code, status.HTTP_200_OK, response.data)
864+
self.assertEqual(
865+
response.status_code,
866+
status.HTTP_200_OK,
867+
response.data,
868+
)
807869

808870
user = User.objects.get(pk=self.other_user.pk)
809871
self.assertEqual(user.name, data['name'])

user_management/api/throttling.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,11 @@ class LoginRateThrottle(DefaultRateMixin, ScopedRateThrottle):
1515

1616
class PasswordResetRateThrottle(DefaultRateMixin, ScopedRateThrottle):
1717
default_rate = '3/hour'
18+
19+
def allow_request(self, request, view):
20+
if request.META['REQUEST_METHOD'] != 'POST':
21+
return True
22+
return super(PasswordResetRateThrottle, self).allow_request(
23+
request,
24+
view,
25+
)

user_management/api/views.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,10 @@ class UserRegister(generics.CreateAPIView):
3737
permission_classes = [permissions.IsNotAuthenticated]
3838

3939
def create(self, request, *args, **kwargs):
40-
serializer = self.get_serializer(data=request.DATA, files=request.FILES)
40+
serializer = self.get_serializer(
41+
data=request.DATA,
42+
files=request.FILES,
43+
)
4144
if serializer.is_valid():
4245
return self.is_valid(serializer)
4346
return self.is_invalid(serializer)
@@ -123,7 +126,11 @@ def initial(self, request, *args, **kwargs):
123126
if not default_token_generator.check_token(self.user, token):
124127
raise Http404()
125128

126-
return super(OneTimeUseAPIMixin, self).initial(request, *args, **kwargs)
129+
return super(OneTimeUseAPIMixin, self).initial(
130+
request,
131+
*args,
132+
**kwargs
133+
)
127134

128135

129136
class PasswordReset(OneTimeUseAPIMixin, generics.UpdateAPIView):

0 commit comments

Comments
 (0)