@@ -32,13 +32,21 @@ def tearDown(self):
3232 def test_post (self ):
3333 username = 'Test@example.com'
3434 password = 'myepicstrongpassword'
35- UserFactory .create (email = username .lower (), password = password , is_active = True )
35+ UserFactory .create (
36+ email = username .lower (),
37+ password = password ,
38+ is_active = True ,
39+ )
3640
3741 data = {'username' : username , 'password' : password }
3842 request = self .create_request ('post' , auth = False , data = data )
3943 view = self .view_class .as_view ()
4044 response = view (request )
41- self .assertEqual (response .status_code , status .HTTP_200_OK , msg = response .data )
45+ self .assertEqual (
46+ response .status_code ,
47+ status .HTTP_200_OK ,
48+ msg = response .data ,
49+ )
4250
4351 def test_post_username (self ):
4452 username = 'Test@example.com'
@@ -49,7 +57,11 @@ def test_post_username(self):
4957 request = self .create_request ('post' , auth = False , data = data )
5058 view = self .view_class .as_view ()
5159 response = view (request )
52- self .assertEqual (response .status_code , status .HTTP_200_OK , msg = response .data )
60+ self .assertEqual (
61+ response .status_code ,
62+ status .HTTP_200_OK ,
63+ msg = response .data ,
64+ )
5365
5466 def test_delete (self ):
5567 user = UserFactory .create ()
@@ -153,7 +165,9 @@ def test_unauthenticated_user_post(self):
153165 @patch ('user_management.api.serializers.RegistrationSerializer.Meta.model' ,
154166 new = BasicUser )
155167 def test_unauthenticated_user_post_no_verify_email (self ):
156- """An email should not be sent if email_verification_required is False."""
168+ """
169+ An email should not be sent if email_verification_required is False.
170+ """
157171 request = self .create_request ('post' , auth = False , data = self .data )
158172
159173 response = self .view_class .as_view ()(request )
@@ -210,14 +224,50 @@ def test_existent_email(self):
210224 email = 'exists@example.com'
211225 user = UserFactory .create (email = email )
212226
213- request = self .create_request ('post' , data = {'email' : email }, auth = False )
227+ request = self .create_request (
228+ 'post' ,
229+ data = {'email' : email },
230+ auth = False ,
231+ )
214232 view = self .view_class .as_view ()
215233 with patch .object (self .view_class , 'send_email' ) as send_email :
216234 response = view (request )
217235 self .assertEqual (response .status_code , status .HTTP_204_NO_CONTENT )
218236
219237 send_email .assert_called_once_with (user )
220238
239+ def assert_post_returns_status (self , view , data , expected_status ):
240+ request = self .create_request ('post' , data = data , auth = False )
241+ response = view (request )
242+ self .assertEqual (response .status_code , expected_status )
243+
244+ def test_post_rate_limit (self ):
245+ """
246+ Ensure the post requests are rate limited.
247+
248+ The PasswordResetRateThrottle sets a limit of 3/hour requests.
249+ """
250+ email = 'exists@example.com'
251+ UserFactory .create (email = email )
252+ view = self .view_class .as_view ()
253+ rate_limit = 3
254+
255+ # Test the the first 3 requests aren't limited.
256+ data = {'email' : email }
257+ for i in range (rate_limit ):
258+ self .assert_post_returns_status (
259+ view ,
260+ data ,
261+ status .HTTP_204_NO_CONTENT ,
262+ )
263+
264+ # Test that the 4th request is throttled.
265+ self .assert_post_returns_status (
266+ view ,
267+ data ,
268+ status .HTTP_429_TOO_MANY_REQUESTS ,
269+ )
270+
221271 def test_authenticated (self ):
222272 request = self .create_request ('post' , auth = True )
223273 view = self .view_class .as_view ()
@@ -228,7 +278,11 @@ def test_non_existent_email(self):
228278 email = 'doesnotexist@example.com'
229279 UserFactory .create (email = 'exists@example.com' )
230280
231- request = self .create_request ('post' , data = {'email' : email }, auth = False )
281+ request = self .create_request (
282+ 'post' ,
283+ data = {'email' : email },
284+ auth = False ,
285+ )
232286 view = self .view_class .as_view ()
233287 with patch .object (self .view_class , 'send_email' ) as send_email :
234288 response = view (request )
@@ -261,7 +315,11 @@ def test_send_email(self):
261315 self .assertIn ('https://' , sent_mail .body )
262316
263317 def test_options (self ):
264- """Ensure information about email field is included in options request"""
318+ """
319+ Ensure information about email field is included in options request.
320+
321+ Ensure that the options request isn't rate limited.
322+ """
265323 request = self .create_request ('options' , auth = False )
266324 view = self .view_class .as_view ()
267325 response = view (request )
@@ -281,38 +339,25 @@ def test_options(self):
281339 expected_post_options ,
282340 )
283341
284- def test_default_user_password_reset_throttle (self ):
285- default_rate = 3
286- auth_url = reverse ('user_management_api:password_reset' )
287- expected_status = status .HTTP_429_TOO_MANY_REQUESTS
342+ def test_options_delimit (self ):
343+ """
344+ Ensure information about email field is included in options request.
288345
289- request = APIRequestFactory ().get (auth_url )
346+ Ensure that the options request isn't rate limited.
347+ """
348+ request = self .create_request ('options' , auth = False )
290349 view = self .view_class .as_view ()
291-
292- # make all but one of our allowed requests
293- for i in range (default_rate - 1 ):
350+ default_rate = 3
351+ # First three requests
352+ for i in range (default_rate ):
294353 view (request )
295354
296- response = view (request ) # our last allowed request
297- self .assertNotEqual (response .status_code , expected_status )
298-
299- response = view (request ) # our throttled request
300- self .assertEqual (response .status_code , expected_status )
301-
302- @patch ('rest_framework.throttling.ScopedRateThrottle.THROTTLE_RATES' , new = {
303- 'passwords' : '1/minute' ,
304- })
305- def test_user_password_reset_throttle (self ):
306- auth_url = reverse ('user_management_api:password_reset' )
307- expected_status = status .HTTP_429_TOO_MANY_REQUESTS
308-
309- request = APIRequestFactory ().get (auth_url )
310-
311- response = self .view_class .as_view ()(request )
312- self .assertNotEqual (response .status_code , expected_status )
313-
314- response = self .view_class .as_view ()(request )
315- self .assertEqual (response .status_code , expected_status )
355+ # Assert fourth request is not throttled
356+ response = view (request )
357+ self .assertNotEqual (
358+ response .status_code ,
359+ status .HTTP_429_TOO_MANY_REQUESTS ,
360+ )
316361
317362
318363class TestPasswordReset (APIRequestTestCase ):
@@ -360,7 +405,10 @@ def test_password_mismatch(self):
360405
361406 request = self .create_request (
362407 'put' ,
363- data = {'new_password' : new_password , 'new_password2' : invalid_password },
408+ data = {
409+ 'new_password' : new_password ,
410+ 'new_password2' : invalid_password ,
411+ },
364412 auth = False ,
365413 )
366414 view = self .view_class .as_view ()
@@ -674,7 +722,10 @@ class TestUserList(APIRequestTestCase):
674722 view_class = views .UserList
675723
676724 def expected_data (self , user ):
677- url = reverse ('user_management_api:user_detail' , kwargs = {'pk' : user .pk })
725+ url = reverse (
726+ 'user_management_api:user_detail' ,
727+ kwargs = {'pk' : user .pk },
728+ )
678729 expected = {
679730 'url' : TEST_SERVER + url ,
680731 'name' : user .name ,
@@ -734,7 +785,10 @@ def setUp(self):
734785 self .user , self .other_user = UserFactory .create_batch (2 )
735786
736787 def expected_data (self , user ):
737- url = reverse ('user_management_api:user_detail' , kwargs = {'pk' : user .pk })
788+ url = reverse (
789+ 'user_management_api:user_detail' ,
790+ kwargs = {'pk' : user .pk },
791+ )
738792 expected = {
739793 'url' : TEST_SERVER + url ,
740794 'name' : user .name ,
@@ -787,7 +841,11 @@ def test_put(self):
787841 view = self .view_class .as_view ()
788842
789843 response = view (request , pk = self .other_user .pk )
790- self .assertEqual (response .status_code , status .HTTP_200_OK , response .data )
844+ self .assertEqual (
845+ response .status_code ,
846+ status .HTTP_200_OK ,
847+ response .data ,
848+ )
791849
792850 user = User .objects .get (pk = self .other_user .pk )
793851 self .assertEqual (user .name , data ['name' ])
@@ -803,7 +861,11 @@ def test_patch(self):
803861 view = self .view_class .as_view ()
804862
805863 response = view (request , pk = self .other_user .pk )
806- self .assertEqual (response .status_code , status .HTTP_200_OK , response .data )
864+ self .assertEqual (
865+ response .status_code ,
866+ status .HTTP_200_OK ,
867+ response .data ,
868+ )
807869
808870 user = User .objects .get (pk = self .other_user .pk )
809871 self .assertEqual (user .name , data ['name' ])
0 commit comments