Permalink
Browse files

Fixed #4476 -- Added a ``follow`` option to the test client request m…

…ethods. This implements browser-like behavior for the test client, following redirect chains when a 30X response is received. Thanks to Marc Fargas and Keith Bussell for their work on this.

git-svn-id: http://code.djangoproject.com/svn/django/trunk@9911 bcc190cf-cafb-0310-a4f2-bffc1f526a37
  • Loading branch information...
1 parent e20f09c commit e735fe7160d786efeb2e8bea595174c1a68409a2 @freakboy3742 freakboy3742 committed Feb 27, 2009
View
@@ -1,5 +1,5 @@
import urllib
-from urlparse import urlparse, urlunparse
+from urlparse import urlparse, urlunparse, urlsplit
import sys
import os
try:
@@ -12,7 +12,7 @@
from django.core.handlers.base import BaseHandler
from django.core.handlers.wsgi import WSGIRequest
from django.core.signals import got_request_exception
-from django.http import SimpleCookie, HttpRequest
+from django.http import SimpleCookie, HttpRequest, QueryDict
from django.template import TemplateDoesNotExist
from django.test import signals
from django.utils.functional import curry
@@ -261,7 +261,7 @@ def request(self, **request):
return response
- def get(self, path, data={}, **extra):
+ def get(self, path, data={}, follow=False, **extra):
"""
Requests a response from the server using GET.
"""
@@ -275,9 +275,13 @@ def get(self, path, data={}, **extra):
}
r.update(extra)
- return self.request(**r)
+ response = self.request(**r)
+ if follow:
+ response = self._handle_redirects(response)
+ return response
- def post(self, path, data={}, content_type=MULTIPART_CONTENT, **extra):
+ def post(self, path, data={}, content_type=MULTIPART_CONTENT,
+ follow=False, **extra):
"""
Requests a response from the server using POST.
"""
@@ -297,9 +301,12 @@ def post(self, path, data={}, content_type=MULTIPART_CONTENT, **extra):
}
r.update(extra)
- return self.request(**r)
+ response = self.request(**r)
+ if follow:
+ response = self._handle_redirects(response)
+ return response
- def head(self, path, data={}, **extra):
+ def head(self, path, data={}, follow=False, **extra):
"""
Request a response from the server using HEAD.
"""
@@ -313,9 +320,12 @@ def head(self, path, data={}, **extra):
}
r.update(extra)
- return self.request(**r)
+ response = self.request(**r)
+ if follow:
+ response = self._handle_redirects(response)
+ return response
- def options(self, path, data={}, **extra):
+ def options(self, path, data={}, follow=False, **extra):
"""
Request a response from the server using OPTIONS.
"""
@@ -328,9 +338,13 @@ def options(self, path, data={}, **extra):
}
r.update(extra)
- return self.request(**r)
+ response = self.request(**r)
+ if follow:
+ response = self._handle_redirects(response)
+ return response
- def put(self, path, data={}, content_type=MULTIPART_CONTENT, **extra):
+ def put(self, path, data={}, content_type=MULTIPART_CONTENT,
+ follow=False, **extra):
"""
Send a resource to the server using PUT.
"""
@@ -350,9 +364,12 @@ def put(self, path, data={}, content_type=MULTIPART_CONTENT, **extra):
}
r.update(extra)
- return self.request(**r)
+ response = self.request(**r)
+ if follow:
+ response = self._handle_redirects(response)
+ return response
- def delete(self, path, data={}, **extra):
+ def delete(self, path, data={}, follow=False, **extra):
"""
Send a DELETE request to the server.
"""
@@ -365,7 +382,10 @@ def delete(self, path, data={}, **extra):
}
r.update(extra)
- return self.request(**r)
+ response = self.request(**r)
+ if follow:
+ response = self._handle_redirects(response)
+ return response
def login(self, **credentials):
"""
@@ -416,3 +436,27 @@ def logout(self):
session = __import__(settings.SESSION_ENGINE, {}, {}, ['']).SessionStore()
session.delete(session_key=self.cookies[settings.SESSION_COOKIE_NAME].value)
self.cookies = SimpleCookie()
+
+ def _handle_redirects(self, response):
+ "Follows any redirects by requesting responses from the server using GET."
+
+ response.redirect_chain = []
+ while response.status_code in (301, 302, 303, 307):
+ url = response['Location']
+ scheme, netloc, path, query, fragment = urlsplit(url)
+
+ redirect_chain = response.redirect_chain
+ redirect_chain.append((url, response.status_code))
+
+ # The test client doesn't handle external links,
+ # but since the situation is simulated in test_client,
+ # we fake things here by ignoring the netloc portion of the
+ # redirected URL.
+ response = self.get(path, QueryDict(query), follow=False)
+ response.redirect_chain = redirect_chain
+
+ # Prevent loops
+ if response.redirect_chain[-1] in response.redirect_chain[0:-1]:
+ break
+ return response
+
@@ -43,7 +43,7 @@ def disable_transaction_methods():
transaction.savepoint_commit = nop
transaction.savepoint_rollback = nop
transaction.enter_transaction_management = nop
- transaction.leave_transaction_management = nop
+ transaction.leave_transaction_management = nop
def restore_transaction_methods():
transaction.commit = real_commit
@@ -198,7 +198,7 @@ def report_unexpected_exception(self, out, test, example, exc_info):
# Rollback, in case of database errors. Otherwise they'd have
# side effects on other tests.
transaction.rollback_unless_managed()
-
+
class TransactionTestCase(unittest.TestCase):
def _pre_setup(self):
"""Performs any pre-test setup. This includes:
@@ -242,7 +242,7 @@ def __call__(self, result=None):
import sys
result.addError(self, sys.exc_info())
return
- super(TransactionTestCase, self).__call__(result)
+ super(TransactionTestCase, self).__call__(result)
try:
self._post_teardown()
except (KeyboardInterrupt, SystemExit):
@@ -263,7 +263,7 @@ def _post_teardown(self):
def _fixture_teardown(self):
pass
- def _urlconf_teardown(self):
+ def _urlconf_teardown(self):
if hasattr(self, '_old_root_urlconf'):
settings.ROOT_URLCONF = self._old_root_urlconf
clear_url_caches()
@@ -276,25 +276,48 @@ def assertRedirects(self, response, expected_url, status_code=302,
Note that assertRedirects won't work for external links since it uses
TestClient to do a request.
"""
- self.assertEqual(response.status_code, status_code,
- ("Response didn't redirect as expected: Response code was %d"
- " (expected %d)" % (response.status_code, status_code)))
- url = response['Location']
- scheme, netloc, path, query, fragment = urlsplit(url)
+ if hasattr(response, 'redirect_chain'):
+ # The request was a followed redirect
+ self.assertTrue(len(response.redirect_chain) > 0,
+ ("Response didn't redirect as expected: Response code was %d"
+ " (expected %d)" % (response.status_code, status_code)))
+
+ self.assertEqual(response.redirect_chain[0][1], status_code,
+ ("Initial response didn't redirect as expected: Response code was %d"
+ " (expected %d)" % (response.redirect_chain[0][1], status_code)))
+
+ url, status_code = response.redirect_chain[-1]
+
+ self.assertEqual(response.status_code, target_status_code,
+ ("Response didn't redirect as expected: Final Response code was %d"
+ " (expected %d)" % (response.status_code, target_status_code)))
+
+ else:
+ # Not a followed redirect
+ self.assertEqual(response.status_code, status_code,
+ ("Response didn't redirect as expected: Response code was %d"
+ " (expected %d)" % (response.status_code, status_code)))
+
+ url = response['Location']
+ scheme, netloc, path, query, fragment = urlsplit(url)
+
+ redirect_response = response.client.get(path, QueryDict(query))
+
+ # Get the redirection page, using the same client that was used
+ # to obtain the original response.
+ self.assertEqual(redirect_response.status_code, target_status_code,
+ ("Couldn't retrieve redirection page '%s': response code was %d"
+ " (expected %d)") %
+ (path, redirect_response.status_code, target_status_code))
+
e_scheme, e_netloc, e_path, e_query, e_fragment = urlsplit(expected_url)
if not (e_scheme or e_netloc):
expected_url = urlunsplit(('http', host or 'testserver', e_path,
- e_query, e_fragment))
+ e_query, e_fragment))
+
self.assertEqual(url, expected_url,
"Response redirected to '%s', expected '%s'" % (url, expected_url))
- # Get the redirection page, using the same client that was used
- # to obtain the original response.
- redirect_response = response.client.get(path, QueryDict(query))
- self.assertEqual(redirect_response.status_code, target_status_code,
- ("Couldn't retrieve redirection page '%s': response code was %d"
- " (expected %d)") %
- (path, redirect_response.status_code, target_status_code))
def assertContains(self, response, text, count=None, status_code=200):
"""
@@ -401,15 +424,15 @@ def assertTemplateNotUsed(self, response, template_name):
class TestCase(TransactionTestCase):
"""
Does basically the same as TransactionTestCase, but surrounds every test
- with a transaction, monkey-patches the real transaction management routines to
- do nothing, and rollsback the test transaction at the end of the test. You have
+ with a transaction, monkey-patches the real transaction management routines to
+ do nothing, and rollsback the test transaction at the end of the test. You have
to use TransactionTestCase, if you need transaction management inside a test.
"""
def _fixture_setup(self):
if not settings.DATABASE_SUPPORTS_TRANSACTIONS:
return super(TestCase, self)._fixture_setup()
-
+
transaction.enter_transaction_management()
transaction.managed(True)
disable_transaction_methods()
@@ -426,7 +449,7 @@ def _fixture_setup(self):
def _fixture_teardown(self):
if not settings.DATABASE_SUPPORTS_TRANSACTIONS:
return super(TestCase, self)._fixture_teardown()
-
+
restore_transaction_methods()
transaction.rollback()
transaction.leave_transaction_management()
Oops, something went wrong. Retry.

0 comments on commit e735fe7

Please sign in to comment.