Skip to content
This repository

HTTPS clone URL

Subversion checkout URL

You can clone with HTTPS or Subversion.

Download ZIP
Browse code

Fixed #9002 -- Added a RequestFactory. This allows you to create requ…

…est instances so you can unit test views as standalone functions. Thanks to Simon Willison for the suggestion and snippet on which this patch was originally based.

git-svn-id: http://code.djangoproject.com/svn/django/trunk@14191 bcc190cf-cafb-0310-a4f2-bffc1f526a37
  • Loading branch information...
commit eec45e8b710b97201db106a6460fe051f8917833 1 parent 120aae2
Russell Keith-Magee authored October 12, 2010
2  django/test/__init__.py
@@ -2,6 +2,6 @@
2 2
 Django Unit Test and Doctest framework.
3 3
 """
4 4
 
5  
-from django.test.client import Client
  5
+from django.test.client import Client, RequestFactory
6 6
 from django.test.testcases import TestCase, TransactionTestCase, skipIfDBFeature, skipUnlessDBFeature
7 7
 from django.test.utils import Approximate
283  django/test/client.py
@@ -156,7 +156,165 @@ def encode_file(boundary, key, file):
156 156
         file.read()
157 157
     ]
158 158
 
159  
-class Client(object):
  159
+
  160
+
  161
+class RequestFactory(object):
  162
+    """
  163
+    Class that lets you create mock Request objects for use in testing.
  164
+
  165
+    Usage:
  166
+
  167
+    rf = RequestFactory()
  168
+    get_request = rf.get('/hello/')
  169
+    post_request = rf.post('/submit/', {'foo': 'bar'})
  170
+
  171
+    Once you have a request object you can pass it to any view function,
  172
+    just as if that view had been hooked up using a URLconf.
  173
+    """
  174
+    def __init__(self, **defaults):
  175
+        self.defaults = defaults
  176
+        self.cookies = SimpleCookie()
  177
+        self.errors = StringIO()
  178
+
  179
+    def _base_environ(self, **request):
  180
+        """
  181
+        The base environment for a request.
  182
+        """
  183
+        environ = {
  184
+            'HTTP_COOKIE':       self.cookies.output(header='', sep='; '),
  185
+            'PATH_INFO':         '/',
  186
+            'QUERY_STRING':      '',
  187
+            'REMOTE_ADDR':       '127.0.0.1',
  188
+            'REQUEST_METHOD':    'GET',
  189
+            'SCRIPT_NAME':       '',
  190
+            'SERVER_NAME':       'testserver',
  191
+            'SERVER_PORT':       '80',
  192
+            'SERVER_PROTOCOL':   'HTTP/1.1',
  193
+            'wsgi.version':      (1,0),
  194
+            'wsgi.url_scheme':   'http',
  195
+            'wsgi.errors':       self.errors,
  196
+            'wsgi.multiprocess': True,
  197
+            'wsgi.multithread':  False,
  198
+            'wsgi.run_once':     False,
  199
+        }
  200
+        environ.update(self.defaults)
  201
+        environ.update(request)
  202
+        return environ
  203
+
  204
+    def request(self, **request):
  205
+        "Construct a generic request object."
  206
+        return WSGIRequest(self._base_environ(**request))
  207
+
  208
+    def get(self, path, data={}, **extra):
  209
+        "Construct a GET request"
  210
+
  211
+        parsed = urlparse(path)
  212
+        r = {
  213
+            'CONTENT_TYPE':    'text/html; charset=utf-8',
  214
+            'PATH_INFO':       urllib.unquote(parsed[2]),
  215
+            'QUERY_STRING':    urlencode(data, doseq=True) or parsed[4],
  216
+            'REQUEST_METHOD': 'GET',
  217
+            'wsgi.input':      FakePayload('')
  218
+        }
  219
+        r.update(extra)
  220
+        return self.request(**r)
  221
+
  222
+    def post(self, path, data={}, content_type=MULTIPART_CONTENT,
  223
+             **extra):
  224
+        "Construct a POST request."
  225
+
  226
+        if content_type is MULTIPART_CONTENT:
  227
+            post_data = encode_multipart(BOUNDARY, data)
  228
+        else:
  229
+            # Encode the content so that the byte representation is correct.
  230
+            match = CONTENT_TYPE_RE.match(content_type)
  231
+            if match:
  232
+                charset = match.group(1)
  233
+            else:
  234
+                charset = settings.DEFAULT_CHARSET
  235
+            post_data = smart_str(data, encoding=charset)
  236
+
  237
+        parsed = urlparse(path)
  238
+        r = {
  239
+            'CONTENT_LENGTH': len(post_data),
  240
+            'CONTENT_TYPE':   content_type,
  241
+            'PATH_INFO':      urllib.unquote(parsed[2]),
  242
+            'QUERY_STRING':   parsed[4],
  243
+            'REQUEST_METHOD': 'POST',
  244
+            'wsgi.input':     FakePayload(post_data),
  245
+        }
  246
+        r.update(extra)
  247
+        return self.request(**r)
  248
+
  249
+    def head(self, path, data={}, **extra):
  250
+        "Construct a HEAD request."
  251
+
  252
+        parsed = urlparse(path)
  253
+        r = {
  254
+            'CONTENT_TYPE':    'text/html; charset=utf-8',
  255
+            'PATH_INFO':       urllib.unquote(parsed[2]),
  256
+            'QUERY_STRING':    urlencode(data, doseq=True) or parsed[4],
  257
+            'REQUEST_METHOD': 'HEAD',
  258
+            'wsgi.input':      FakePayload('')
  259
+        }
  260
+        r.update(extra)
  261
+        return self.request(**r)
  262
+
  263
+    def options(self, path, data={}, **extra):
  264
+        "Constrict an OPTIONS request"
  265
+
  266
+        parsed = urlparse(path)
  267
+        r = {
  268
+            'PATH_INFO':       urllib.unquote(parsed[2]),
  269
+            'QUERY_STRING':    urlencode(data, doseq=True) or parsed[4],
  270
+            'REQUEST_METHOD': 'OPTIONS',
  271
+            'wsgi.input':      FakePayload('')
  272
+        }
  273
+        r.update(extra)
  274
+        return self.request(**r)
  275
+
  276
+    def put(self, path, data={}, content_type=MULTIPART_CONTENT,
  277
+            **extra):
  278
+        "Construct a PUT request."
  279
+
  280
+        if content_type is MULTIPART_CONTENT:
  281
+            post_data = encode_multipart(BOUNDARY, data)
  282
+        else:
  283
+            post_data = data
  284
+
  285
+        # Make `data` into a querystring only if it's not already a string. If
  286
+        # it is a string, we'll assume that the caller has already encoded it.
  287
+        query_string = None
  288
+        if not isinstance(data, basestring):
  289
+            query_string = urlencode(data, doseq=True)
  290
+
  291
+        parsed = urlparse(path)
  292
+        r = {
  293
+            'CONTENT_LENGTH': len(post_data),
  294
+            'CONTENT_TYPE':   content_type,
  295
+            'PATH_INFO':      urllib.unquote(parsed[2]),
  296
+            'QUERY_STRING':   query_string or parsed[4],
  297
+            'REQUEST_METHOD': 'PUT',
  298
+            'wsgi.input':     FakePayload(post_data),
  299
+        }
  300
+        r.update(extra)
  301
+        return self.request(**r)
  302
+
  303
+    def delete(self, path, data={}, **extra):
  304
+        "Construct a DELETE request."
  305
+
  306
+        parsed = urlparse(path)
  307
+        r = {
  308
+            'PATH_INFO':       urllib.unquote(parsed[2]),
  309
+            'QUERY_STRING':    urlencode(data, doseq=True) or parsed[4],
  310
+            'REQUEST_METHOD': 'DELETE',
  311
+            'wsgi.input':      FakePayload('')
  312
+        }
  313
+        r.update(extra)
  314
+        return self.request(**r)
  315
+
  316
+
  317
+class Client(RequestFactory):
160 318
     """
161 319
     A class that can act as a client for testing purposes.
162 320
 
@@ -175,11 +333,9 @@ class Client(object):
175 333
     HTML rendered to the end-user.
176 334
     """
177 335
     def __init__(self, enforce_csrf_checks=False, **defaults):
  336
+        super(Client, self).__init__(**defaults)
178 337
         self.handler = ClientHandler(enforce_csrf_checks)
179  
-        self.defaults = defaults
180  
-        self.cookies = SimpleCookie()
181 338
         self.exc_info = None
182  
-        self.errors = StringIO()
183 339
 
184 340
     def store_exc_info(self, **kwargs):
185 341
         """
@@ -199,6 +355,7 @@ def _session(self):
199 355
         return {}
200 356
     session = property(_session)
201 357
 
  358
+
202 359
     def request(self, **request):
203 360
         """
204 361
         The master request method. Composes the environment dictionary
@@ -206,25 +363,7 @@ def request(self, **request):
206 363
         Assumes defaults for the query environment, which can be overridden
207 364
         using the arguments to the request.
208 365
         """
209  
-        environ = {
210  
-            'HTTP_COOKIE':       self.cookies.output(header='', sep='; '),
211  
-            'PATH_INFO':         '/',
212  
-            'QUERY_STRING':      '',
213  
-            'REMOTE_ADDR':       '127.0.0.1',
214  
-            'REQUEST_METHOD':    'GET',
215  
-            'SCRIPT_NAME':       '',
216  
-            'SERVER_NAME':       'testserver',
217  
-            'SERVER_PORT':       '80',
218  
-            'SERVER_PROTOCOL':   'HTTP/1.1',
219  
-            'wsgi.version':      (1,0),
220  
-            'wsgi.url_scheme':   'http',
221  
-            'wsgi.errors':       self.errors,
222  
-            'wsgi.multiprocess': True,
223  
-            'wsgi.multithread':  False,
224  
-            'wsgi.run_once':     False,
225  
-        }
226  
-        environ.update(self.defaults)
227  
-        environ.update(request)
  366
+        environ = self._base_environ(**request)
228 367
 
229 368
         # Curry a data dictionary into an instance of the template renderer
230 369
         # callback function.
@@ -290,22 +429,11 @@ def _get_template(self):
290 429
             signals.template_rendered.disconnect(dispatch_uid="template-render")
291 430
             got_request_exception.disconnect(dispatch_uid="request-exception")
292 431
 
293  
-
294 432
     def get(self, path, data={}, follow=False, **extra):
295 433
         """
296 434
         Requests a response from the server using GET.
297 435
         """
298  
-        parsed = urlparse(path)
299  
-        r = {
300  
-            'CONTENT_TYPE':    'text/html; charset=utf-8',
301  
-            'PATH_INFO':       urllib.unquote(parsed[2]),
302  
-            'QUERY_STRING':    urlencode(data, doseq=True) or parsed[4],
303  
-            'REQUEST_METHOD': 'GET',
304  
-            'wsgi.input':      FakePayload('')
305  
-        }
306  
-        r.update(extra)
307  
-
308  
-        response = self.request(**r)
  436
+        response = super(Client, self).get(path, data=data, **extra)
309 437
         if follow:
310 438
             response = self._handle_redirects(response, **extra)
311 439
         return response
@@ -315,29 +443,7 @@ def post(self, path, data={}, content_type=MULTIPART_CONTENT,
315 443
         """
316 444
         Requests a response from the server using POST.
317 445
         """
318  
-        if content_type is MULTIPART_CONTENT:
319  
-            post_data = encode_multipart(BOUNDARY, data)
320  
-        else:
321  
-            # Encode the content so that the byte representation is correct.
322  
-            match = CONTENT_TYPE_RE.match(content_type)
323  
-            if match:
324  
-                charset = match.group(1)
325  
-            else:
326  
-                charset = settings.DEFAULT_CHARSET
327  
-            post_data = smart_str(data, encoding=charset)
328  
-
329  
-        parsed = urlparse(path)
330  
-        r = {
331  
-            'CONTENT_LENGTH': len(post_data),
332  
-            'CONTENT_TYPE':   content_type,
333  
-            'PATH_INFO':      urllib.unquote(parsed[2]),
334  
-            'QUERY_STRING':   parsed[4],
335  
-            'REQUEST_METHOD': 'POST',
336  
-            'wsgi.input':     FakePayload(post_data),
337  
-        }
338  
-        r.update(extra)
339  
-
340  
-        response = self.request(**r)
  446
+        response = super(Client, self).post(path, data=data, content_type=content_type, **extra)
341 447
         if follow:
342 448
             response = self._handle_redirects(response, **extra)
343 449
         return response
@@ -346,17 +452,7 @@ def head(self, path, data={}, follow=False, **extra):
346 452
         """
347 453
         Request a response from the server using HEAD.
348 454
         """
349  
-        parsed = urlparse(path)
350  
-        r = {
351  
-            'CONTENT_TYPE':    'text/html; charset=utf-8',
352  
-            'PATH_INFO':       urllib.unquote(parsed[2]),
353  
-            'QUERY_STRING':    urlencode(data, doseq=True) or parsed[4],
354  
-            'REQUEST_METHOD': 'HEAD',
355  
-            'wsgi.input':      FakePayload('')
356  
-        }
357  
-        r.update(extra)
358  
-
359  
-        response = self.request(**r)
  455
+        response = super(Client, self).head(path, data=data, **extra)
360 456
         if follow:
361 457
             response = self._handle_redirects(response, **extra)
362 458
         return response
@@ -365,16 +461,7 @@ def options(self, path, data={}, follow=False, **extra):
365 461
         """
366 462
         Request a response from the server using OPTIONS.
367 463
         """
368  
-        parsed = urlparse(path)
369  
-        r = {
370  
-            'PATH_INFO':       urllib.unquote(parsed[2]),
371  
-            'QUERY_STRING':    urlencode(data, doseq=True) or parsed[4],
372  
-            'REQUEST_METHOD': 'OPTIONS',
373  
-            'wsgi.input':      FakePayload('')
374  
-        }
375  
-        r.update(extra)
376  
-
377  
-        response = self.request(**r)
  464
+        response = super(Client, self).options(path, data=data, **extra)
378 465
         if follow:
379 466
             response = self._handle_redirects(response, **extra)
380 467
         return response
@@ -384,29 +471,7 @@ def put(self, path, data={}, content_type=MULTIPART_CONTENT,
384 471
         """
385 472
         Send a resource to the server using PUT.
386 473
         """
387  
-        if content_type is MULTIPART_CONTENT:
388  
-            post_data = encode_multipart(BOUNDARY, data)
389  
-        else:
390  
-            post_data = data
391  
-
392  
-        # Make `data` into a querystring only if it's not already a string. If
393  
-        # it is a string, we'll assume that the caller has already encoded it.
394  
-        query_string = None
395  
-        if not isinstance(data, basestring):
396  
-            query_string = urlencode(data, doseq=True)
397  
-
398  
-        parsed = urlparse(path)
399  
-        r = {
400  
-            'CONTENT_LENGTH': len(post_data),
401  
-            'CONTENT_TYPE':   content_type,
402  
-            'PATH_INFO':      urllib.unquote(parsed[2]),
403  
-            'QUERY_STRING':   query_string or parsed[4],
404  
-            'REQUEST_METHOD': 'PUT',
405  
-            'wsgi.input':     FakePayload(post_data),
406  
-        }
407  
-        r.update(extra)
408  
-
409  
-        response = self.request(**r)
  474
+        response = super(Client, self).put(path, data=data, content_type=content_type, **extra)
410 475
         if follow:
411 476
             response = self._handle_redirects(response, **extra)
412 477
         return response
@@ -415,23 +480,14 @@ def delete(self, path, data={}, follow=False, **extra):
415 480
         """
416 481
         Send a DELETE request to the server.
417 482
         """
418  
-        parsed = urlparse(path)
419  
-        r = {
420  
-            'PATH_INFO':       urllib.unquote(parsed[2]),
421  
-            'QUERY_STRING':    urlencode(data, doseq=True) or parsed[4],
422  
-            'REQUEST_METHOD': 'DELETE',
423  
-            'wsgi.input':      FakePayload('')
424  
-        }
425  
-        r.update(extra)
426  
-
427  
-        response = self.request(**r)
  483
+        response = super(Client, self).delete(path, data=data, **extra)
428 484
         if follow:
429 485
             response = self._handle_redirects(response, **extra)
430 486
         return response
431 487
 
432 488
     def login(self, **credentials):
433 489
         """
434  
-        Sets the Client to appear as if it has successfully logged into a site.
  490
+        Sets the Factory to appear as if it has successfully logged into a site.
435 491
 
436 492
         Returns True if login is possible; False if the provided credentials
437 493
         are incorrect, or the user is inactive, or if the sessions framework is
@@ -506,4 +562,3 @@ def _handle_redirects(self, response, **extra):
506 562
             if response.redirect_chain[-1] in response.redirect_chain[0:-1]:
507 563
                 break
508 564
         return response
509  
-
45  docs/topics/testing.txt
@@ -1014,6 +1014,51 @@ The following is a simple unit test using the test client::
1014 1014
             # Check that the rendered context contains 5 customers.
1015 1015
             self.assertEqual(len(response.context['customers']), 5)
1016 1016
 
  1017
+The request factory
  1018
+-------------------
  1019
+
  1020
+.. Class:: RequestFactory
  1021
+
  1022
+The :class:`~django.test.client.RequestFactory` is a simplified
  1023
+version of the test client that provides a way to generate a request
  1024
+instance that can be used as the first argument to any view. This
  1025
+means you can test a view function the same way as you would test any
  1026
+other function -- as a black box, with exactly known inputs, testing
  1027
+for specific outputs.
  1028
+
  1029
+The API for the :class:`~django.test.client.RequestFactory` is a slightly
  1030
+restricted subset of the test client API:
  1031
+
  1032
+    * It only has access to the HTTP methods :meth:`~Client.get()`,
  1033
+      :meth:`~Client.post()`, :meth:`~Client.put()`,
  1034
+      :meth:`~Client.delete()`, :meth:`~Client.head()` and
  1035
+      :meth:`~Client.options()`.
  1036
+
  1037
+    * These methods accept all the same arguments *except* for
  1038
+      ``follows``. Since this is just a factory for producing
  1039
+      requests, it's up to you to handle the response.
  1040
+
  1041
+Example
  1042
+~~~~~~~
  1043
+
  1044
+The following is a simple unit test using the request factory::
  1045
+
  1046
+    from django.utils import unittest
  1047
+    from django.test.client import RequestFactory
  1048
+
  1049
+    class SimpleTest(unittest.TestCase):
  1050
+        def setUp(self):
  1051
+            # Every test needs a client.
  1052
+            self.factory = RequestFactory()
  1053
+
  1054
+        def test_details(self):
  1055
+            # Issue a GET request.
  1056
+            request = self.factory.get('/customer/details')
  1057
+
  1058
+            # Test my_view() as if it were deployed at /customer/details
  1059
+            response = my_view(request)
  1060
+            self.assertEquals(response.status_code, 200)
  1061
+
1017 1062
 TestCase
1018 1063
 --------
1019 1064
 
14  tests/modeltests/test_client/models.py
@@ -20,9 +20,12 @@
20 20
 rather than the HTML rendered to the end-user.
21 21
 
22 22
 """
23  
-from django.test import Client, TestCase
24 23
 from django.conf import settings
25 24
 from django.core import mail
  25
+from django.test import Client, TestCase, RequestFactory
  26
+
  27
+from views import get_view
  28
+
26 29
 
27 30
 class ClientTest(TestCase):
28 31
     fixtures = ['testdata.json']
@@ -469,3 +472,12 @@ def test_custom_test_client(self):
469 472
         """A test case can specify a custom class for self.client."""
470 473
         self.assertEqual(hasattr(self.client, "i_am_customized"), True)
471 474
 
  475
+
  476
+class RequestFactoryTest(TestCase):
  477
+    def test_request_factory(self):
  478
+        factory = RequestFactory()
  479
+        request = factory.get('/somewhere/')
  480
+        response = get_view(request)
  481
+
  482
+        self.assertEqual(response.status_code, 200)
  483
+        self.assertContains(response, 'This is a test')

0 notes on commit eec45e8

Please sign in to comment.
Something went wrong with that request. Please try again.