Skip to content
This repository

HTTPS clone URL

Subversion checkout URL

You can clone with HTTPS or Subversion.

Download ZIP
Browse code

Added a context manager to capture queries while testing.

Also made some import cleanups while I was there.

Refs #10399.
  • Loading branch information...
commit 952ba5237ea62e7647cdd5214b1df79c0e7cea38 1 parent 203c17c
Simon Charette authored March 01, 2013
35  django/test/testcases.py
@@ -24,31 +24,30 @@
24 24
 from django.core.handlers.wsgi import WSGIHandler
25 25
 from django.core.management import call_command
26 26
 from django.core.management.color import no_style
27  
-from django.core.signals import request_started
28 27
 from django.core.servers.basehttp import (WSGIRequestHandler, WSGIServer,
29 28
     WSGIServerException)
30 29
 from django.core.urlresolvers import clear_url_caches
31 30
 from django.core.validators import EMPTY_VALUES
32  
-from django.db import (transaction, connection, connections, DEFAULT_DB_ALIAS,
33  
-    reset_queries)
  31
+from django.db import connection, connections, DEFAULT_DB_ALIAS, transaction
34 32
 from django.forms.fields import CharField
35 33
 from django.http import QueryDict
36 34
 from django.test import _doctest as doctest
37 35
 from django.test.client import Client
38 36
 from django.test.html import HTMLParseError, parse_html
39 37
 from django.test.signals import template_rendered
40  
-from django.test.utils import (override_settings, compare_xml, strip_quotes)
41  
-from django.test.utils import ContextList
42  
-from django.utils import unittest as ut2
  38
+from django.test.utils import (CaptureQueriesContext, ContextList,
  39
+    override_settings, compare_xml, strip_quotes)
  40
+from django.utils import six, unittest as ut2
43 41
 from django.utils.encoding import force_text
44  
-from django.utils import six
  42
+from django.utils.unittest import skipIf # Imported here for backward compatibility
45 43
 from django.utils.unittest.util import safe_repr
46  
-from django.utils.unittest import skipIf
47 44
 from django.views.static import serve
48 45
 
  46
+
49 47
 __all__ = ('DocTestRunner', 'OutputChecker', 'TestCase', 'TransactionTestCase',
50 48
            'SimpleTestCase', 'skipIfDBFeature', 'skipUnlessDBFeature')
51 49
 
  50
+
52 51
 normalize_long_ints = lambda s: re.sub(r'(?<![\w])(\d+)L(?![\w])', '\\1', s)
53 52
 normalize_decimals = lambda s: re.sub(r"Decimal\('(\d+(\.\d*)?)'\)",
54 53
                                 lambda m: "Decimal(\"%s\")" % m.groups()[0], s)
@@ -168,28 +167,17 @@ def report_unexpected_exception(self, out, test, example, exc_info):
168 167
             transaction.rollback_unless_managed(using=conn)
169 168
 
170 169
 
171  
-class _AssertNumQueriesContext(object):
  170
+class _AssertNumQueriesContext(CaptureQueriesContext):
172 171
     def __init__(self, test_case, num, connection):
173 172
         self.test_case = test_case
174 173
         self.num = num
175  
-        self.connection = connection
176  
-
177  
-    def __enter__(self):
178  
-        self.old_debug_cursor = self.connection.use_debug_cursor
179  
-        self.connection.use_debug_cursor = True
180  
-        self.starting_queries = len(self.connection.queries)
181  
-        request_started.disconnect(reset_queries)
182  
-        return self
  174
+        super(_AssertNumQueriesContext, self).__init__(connection)
183 175
 
184 176
     def __exit__(self, exc_type, exc_value, traceback):
185  
-        self.connection.use_debug_cursor = self.old_debug_cursor
186  
-        request_started.connect(reset_queries)
187 177
         if exc_type is not None:
188 178
             return
189  
-
190  
-        final_queries = len(self.connection.queries)
191  
-        executed = final_queries - self.starting_queries
192  
-
  179
+        super(_AssertNumQueriesContext, self).__exit__(exc_type, exc_value, traceback)
  180
+        executed = len(self)
193 181
         self.test_case.assertEqual(
194 182
             executed, self.num, "%d queries executed, %d expected" % (
195 183
                 executed, self.num
@@ -1051,7 +1039,6 @@ def run(self):
1051 1039
         http requests.
1052 1040
         """
1053 1041
         if self.connections_override:
1054  
-            from django.db import connections
1055 1042
             # Override this thread's database connections with the ones
1056 1043
             # provided by the main thread.
1057 1044
             for alias, conn in self.connections_override.items():
39  django/test/utils.py
@@ -4,6 +4,8 @@
4 4
 
5 5
 from django.conf import settings, UserSettingsHolder
6 6
 from django.core import mail
  7
+from django.core.signals import request_started
  8
+from django.db import reset_queries
7 9
 from django.template import Template, loader, TemplateDoesNotExist
8 10
 from django.template.loaders import cached
9 11
 from django.test.signals import template_rendered, setting_changed
@@ -339,5 +341,42 @@ def is_quoted_unicode(s):
339 341
         got = got.strip()[2:-1]
340 342
     return want, got
341 343
 
  344
+
342 345
 def str_prefix(s):
343 346
     return s % {'_': '' if six.PY3 else 'u'}
  347
+
  348
+
  349
+class CaptureQueriesContext(object):
  350
+    """
  351
+    Context manager that captures queries executed by the specified connection.
  352
+    """
  353
+    def __init__(self, connection):
  354
+        self.connection = connection
  355
+
  356
+    def __iter__(self):
  357
+        return iter(self.captured_queries)
  358
+
  359
+    def __getitem__(self, index):
  360
+        return self.captured_queries[index]
  361
+
  362
+    def __len__(self):
  363
+        return len(self.captured_queries)
  364
+
  365
+    @property
  366
+    def captured_queries(self):
  367
+        return self.connection.queries[self.initial_queries:self.final_queries]
  368
+
  369
+    def __enter__(self):
  370
+        self.use_debug_cursor = self.connection.use_debug_cursor
  371
+        self.connection.use_debug_cursor = True
  372
+        self.initial_queries = len(self.connection.queries)
  373
+        self.final_queries = None
  374
+        request_started.disconnect(reset_queries)
  375
+        return self
  376
+
  377
+    def __exit__(self, exc_type, exc_value, traceback):
  378
+        self.connection.use_debug_cursor = self.use_debug_cursor
  379
+        request_started.connect(reset_queries)
  380
+        if exc_type is not None:
  381
+            return
  382
+        self.final_queries = len(self.connection.queries)
67  tests/test_utils/tests.py
... ...
@@ -1,10 +1,14 @@
1 1
 # -*- coding: utf-8 -*-
2 2
 from __future__ import absolute_import, unicode_literals
  3
+import warnings
3 4
 
  5
+from django.db import connection
4 6
 from django.forms import EmailField, IntegerField
5 7
 from django.http import HttpResponse
6 8
 from django.template.loader import render_to_string
7 9
 from django.test import SimpleTestCase, TestCase, skipUnlessDBFeature
  10
+from django.test.html import HTMLParseError, parse_html
  11
+from django.test.utils import CaptureQueriesContext
8 12
 from django.utils import six
9 13
 from django.utils.unittest import skip
10 14
 
@@ -94,6 +98,60 @@ def test_undefined_order(self):
94 98
         )
95 99
 
96 100
 
  101
+class CaptureQueriesContextManagerTests(TestCase):
  102
+    urls = 'test_utils.urls'
  103
+
  104
+    def setUp(self):
  105
+        self.person_pk = six.text_type(Person.objects.create(name='test').pk)
  106
+
  107
+    def test_simple(self):
  108
+        with CaptureQueriesContext(connection) as captured_queries:
  109
+            Person.objects.get(pk=self.person_pk)
  110
+        self.assertEqual(len(captured_queries), 1)
  111
+        self.assertIn(self.person_pk, captured_queries[0]['sql'])
  112
+
  113
+        with CaptureQueriesContext(connection) as captured_queries:
  114
+            pass
  115
+        self.assertEqual(0, len(captured_queries))
  116
+
  117
+    def test_within(self):
  118
+        with CaptureQueriesContext(connection) as captured_queries:
  119
+            Person.objects.get(pk=self.person_pk)
  120
+            self.assertEqual(len(captured_queries), 1)
  121
+            self.assertIn(self.person_pk, captured_queries[0]['sql'])
  122
+
  123
+    def test_nested(self):
  124
+        with CaptureQueriesContext(connection) as captured_queries:
  125
+            Person.objects.count()
  126
+            with CaptureQueriesContext(connection) as nested_captured_queries:
  127
+                Person.objects.count()
  128
+        self.assertEqual(1, len(nested_captured_queries))
  129
+        self.assertEqual(2, len(captured_queries))
  130
+
  131
+    def test_failure(self):
  132
+        with self.assertRaises(TypeError):
  133
+            with CaptureQueriesContext(connection):
  134
+                raise TypeError
  135
+
  136
+    def test_with_client(self):
  137
+        with CaptureQueriesContext(connection) as captured_queries:
  138
+            self.client.get("/test_utils/get_person/%s/" % self.person_pk)
  139
+        self.assertEqual(len(captured_queries), 1)
  140
+        self.assertIn(self.person_pk, captured_queries[0]['sql'])
  141
+
  142
+        with CaptureQueriesContext(connection) as captured_queries:
  143
+            self.client.get("/test_utils/get_person/%s/" % self.person_pk)
  144
+        self.assertEqual(len(captured_queries), 1)
  145
+        self.assertIn(self.person_pk, captured_queries[0]['sql'])
  146
+
  147
+        with CaptureQueriesContext(connection) as captured_queries:
  148
+            self.client.get("/test_utils/get_person/%s/" % self.person_pk)
  149
+            self.client.get("/test_utils/get_person/%s/" % self.person_pk)
  150
+        self.assertEqual(len(captured_queries), 2)
  151
+        self.assertIn(self.person_pk, captured_queries[0]['sql'])
  152
+        self.assertIn(self.person_pk, captured_queries[1]['sql'])
  153
+
  154
+
97 155
 class AssertNumQueriesContextManagerTests(TestCase):
98 156
     urls = 'test_utils.urls'
99 157
 
@@ -219,7 +277,6 @@ def test_save_restore_warnings_state(self):
219 277
         # In reality this test could be satisfied by many broken implementations
220 278
         # of save_warnings_state/restore_warnings_state (e.g. just
221 279
         # warnings.resetwarnings()) , but it is difficult to test more.
222  
-        import warnings
223 280
         with warnings.catch_warnings():
224 281
             warnings.simplefilter("ignore", DeprecationWarning)
225 282
 
@@ -245,7 +302,6 @@ class MyWarning(Warning):
245 302
 
246 303
 class HTMLEqualTests(TestCase):
247 304
     def test_html_parser(self):
248  
-        from django.test.html import parse_html
249 305
         element = parse_html('<div><p>Hello</p></div>')
250 306
         self.assertEqual(len(element.children), 1)
251 307
         self.assertEqual(element.children[0].name, 'p')
@@ -259,7 +315,6 @@ def test_html_parser(self):
259 315
         self.assertEqual(dom[0], 'foo')
260 316
 
261 317
     def test_parse_html_in_script(self):
262  
-        from django.test.html import parse_html
263 318
         parse_html('<script>var a = "<p" + ">";</script>');
264 319
         parse_html('''
265 320
             <script>
@@ -275,8 +330,6 @@ def test_parse_html_in_script(self):
275 330
         self.assertEqual(dom.children[0], "<p>foo</p> '</scr'+'ipt>' <span>bar</span>")
276 331
 
277 332
     def test_self_closing_tags(self):
278  
-        from django.test.html import parse_html
279  
-
280 333
         self_closing_tags = ('br' , 'hr', 'input', 'img', 'meta', 'spacer',
281 334
             'link', 'frame', 'base', 'col')
282 335
         for tag in self_closing_tags:
@@ -400,7 +453,6 @@ def test_complex_examples(self):
400 453
         </html>""")
401 454
 
402 455
     def test_html_contain(self):
403  
-        from django.test.html import parse_html
404 456
         # equal html contains each other
405 457
         dom1 = parse_html('<p>foo')
406 458
         dom2 = parse_html('<p>foo</p>')
@@ -424,7 +476,6 @@ def test_html_contain(self):
424 476
         self.assertTrue(dom1 in dom2)
425 477
 
426 478
     def test_count(self):
427  
-        from django.test.html import parse_html
428 479
         # equal html contains each other one time
429 480
         dom1 = parse_html('<p>foo')
430 481
         dom2 = parse_html('<p>foo</p>')
@@ -459,7 +510,6 @@ def test_count(self):
459 510
         self.assertEqual(dom2.count(dom1), 0)
460 511
 
461 512
     def test_parsing_errors(self):
462  
-        from django.test.html import HTMLParseError, parse_html
463 513
         with self.assertRaises(AssertionError):
464 514
             self.assertHTMLEqual('<p>', '')
465 515
         with self.assertRaises(AssertionError):
@@ -488,7 +538,6 @@ def test_contains_html(self):
488 538
             self.assertContains(response, '<p "whats" that>')
489 539
 
490 540
     def test_unicode_handling(self):
491  
-        from django.http import HttpResponse
492 541
         response = HttpResponse('<p class="help">Some help text for the title (with unicode ŠĐĆŽćžšđ)</p>')
493 542
         self.assertContains(response, '<p class="help">Some help text for the title (with unicode ŠĐĆŽćžšđ)</p>', html=True)
494 543
 

0 notes on commit 952ba52

Please sign in to comment.
Something went wrong with that request. Please try again.