Skip to content
This repository

HTTPS clone URL

Subversion checkout URL

You can clone with HTTPS or Subversion.

Download ZIP
Browse code

Fixed #5416 -- Added TestCase.assertNumQueries, which tests that a gi…

…ven function executes the correct number of queries.

git-svn-id: http://code.djangoproject.com/svn/django/trunk@14183 bcc190cf-cafb-0310-a4f2-bffc1f526a37
  • Loading branch information...
commit 5506653b777d7547d21ea2d74e9588fb94314b77 1 parent ceef628
Alex Gaynor authored October 12, 2010
4  django/db/backends/__init__.py
@@ -21,6 +21,7 @@ def __init__(self, settings_dict, alias=DEFAULT_DB_ALIAS):
21 21
         self.settings_dict = settings_dict
22 22
         self.alias = alias
23 23
         self.vendor = 'unknown'
  24
+        self.use_debug_cursor = None
24 25
 
25 26
     def __eq__(self, other):
26 27
         return self.settings_dict == other.settings_dict
@@ -74,7 +75,8 @@ def close(self):
74 75
     def cursor(self):
75 76
         from django.conf import settings
76 77
         cursor = self._cursor()
77  
-        if settings.DEBUG:
  78
+        if (self.use_debug_cursor or
  79
+            (self.use_debug_cursor is None and settings.DEBUG)):
78 80
             return self.make_debug_cursor(cursor)
79 81
         return cursor
80 82
 
44  django/test/testcases.py
... ...
@@ -1,4 +1,5 @@
1 1
 import re
  2
+import sys
2 3
 from urlparse import urlsplit, urlunsplit
3 4
 from xml.dom.minidom import parseString, Node
4 5
 
@@ -205,6 +206,33 @@ def report_unexpected_exception(self, out, test, example, exc_info):
205 206
         for conn in connections:
206 207
             transaction.rollback_unless_managed(using=conn)
207 208
 
  209
+class _AssertNumQueriesContext(object):
  210
+    def __init__(self, test_case, num, connection):
  211
+        self.test_case = test_case
  212
+        self.num = num
  213
+        self.connection = connection
  214
+
  215
+    def __enter__(self):
  216
+        self.old_debug_cursor = self.connection.use_debug_cursor
  217
+        self.connection.use_debug_cursor = True
  218
+        self.starting_queries = len(self.connection.queries)
  219
+        return self
  220
+
  221
+    def __exit__(self, exc_type, exc_value, traceback):
  222
+        if exc_type is not None:
  223
+            return
  224
+
  225
+        self.connection.use_debug_cursor = self.old_debug_cursor
  226
+        final_queries = len(self.connection.queries)
  227
+        executed = final_queries - self.starting_queries
  228
+
  229
+        self.test_case.assertEqual(
  230
+            executed, self.num, "%d queries executed, %d expected" % (
  231
+                executed, self.num
  232
+            )
  233
+        )
  234
+
  235
+
208 236
 class TransactionTestCase(unittest.TestCase):
209 237
     # The class we'll use for the test client self.client.
210 238
     # Can be overridden in derived classes.
@@ -469,6 +497,22 @@ def assertTemplateNotUsed(self, response, template_name, msg_prefix=''):
469 497
     def assertQuerysetEqual(self, qs, values, transform=repr):
470 498
         return self.assertEqual(map(transform, qs), values)
471 499
 
  500
+    def assertNumQueries(self, num, func=None, *args, **kwargs):
  501
+        using = kwargs.pop("using", DEFAULT_DB_ALIAS)
  502
+        connection = connections[using]
  503
+
  504
+        context = _AssertNumQueriesContext(self, num, connection)
  505
+        if func is None:
  506
+            return context
  507
+
  508
+        # Basically emulate the `with` statement here.
  509
+
  510
+        context.__enter__()
  511
+        try:
  512
+            func(*args, **kwargs)
  513
+        finally:
  514
+            context.__exit__(*sys.exc_info())
  515
+
472 516
 def connections_support_transactions():
473 517
     """
474 518
     Returns True if all connections support transactions.  This is messy
26  docs/topics/testing.txt
@@ -1372,6 +1372,32 @@ cause of an failure in your test suite.
1372 1372
     implicit ordering, you will need to apply a ``order_by()`` clause to your
1373 1373
     queryset to ensure that the test will pass reliably.
1374 1374
 
  1375
+.. method:: TestCase.assertNumQueries(num, func, *args, **kwargs):
  1376
+
  1377
+    .. versionadded:: 1.3
  1378
+
  1379
+    Asserts that when ``func`` is called with ``*args`` and ``**kwargs`` that
  1380
+    ``num`` database queries are executed.
  1381
+
  1382
+    If a ``"using"`` key is present in ``kwargs`` it is used as the database
  1383
+    alias for which to check the number of queries.  If you wish to call a
  1384
+    function with a ``using`` parameter you can do it by wrapping the call with
  1385
+    a ``lambda`` to add an extra parameter::
  1386
+
  1387
+        self.assertNumQueries(7, lambda: my_function(using=7))
  1388
+
  1389
+    If you're using Python 2.5 or greater you can also use this as a context
  1390
+    manager::
  1391
+
  1392
+        # This is necessary in Python 2.5 to enable the with statement, in 2.6
  1393
+        # and up it is no longer necessary.
  1394
+        from __future__ import with_statement
  1395
+
  1396
+        with self.assertNumQueries(2):
  1397
+            Person.objects.create(name="Aaron")
  1398
+            Person.objects.create(name="Daniel")
  1399
+
  1400
+
1375 1401
 .. _topics-testing-email:
1376 1402
 
1377 1403
 E-mail services
118  tests/modeltests/select_related/tests.py
... ...
@@ -1,6 +1,4 @@
@@ -36,73 +34,73 @@ def setUp(self):
@@ -116,11 +114,12 @@ def test_list_with_depth(self):
@@ -136,28 +135,31 @@ def test_certain_fields(self):
45  tests/modeltests/validation/test_unique.py
@@ -2,9 +2,11 @@
2 2
 
3 3
 from django.conf import settings
4 4
 from django.db import connection
  5
+from django.test import TestCase
5 6
 from django.utils import unittest
6 7
 
7  
-from models import CustomPKModel, UniqueTogetherModel, UniqueFieldsModel, UniqueForDateModel, ModelToValidate
  8
+from models import (CustomPKModel, UniqueTogetherModel, UniqueFieldsModel,
  9
+    UniqueForDateModel, ModelToValidate)
8 10
 
9 11
 
10 12
 class GetUniqueCheckTests(unittest.TestCase):
@@ -51,37 +53,26 @@ def test_unique_for_date_exclusion(self):
51 53
             ), m._get_unique_checks(exclude='start_date')
52 54
         )
53 55
 
54  
-class PerformUniqueChecksTest(unittest.TestCase):
55  
-    def setUp(self):
56  
-        # Set debug to True to gain access to connection.queries.
57  
-        self._old_debug, settings.DEBUG = settings.DEBUG, True
58  
-        super(PerformUniqueChecksTest, self).setUp()
59  
-
60  
-    def tearDown(self):
61  
-        # Restore old debug value.
62  
-        settings.DEBUG = self._old_debug
63  
-        super(PerformUniqueChecksTest, self).tearDown()
64  
-
  56
+class PerformUniqueChecksTest(TestCase):
65 57
     def test_primary_key_unique_check_not_performed_when_adding_and_pk_not_specified(self):
66 58
         # Regression test for #12560
67  
-        query_count = len(connection.queries)
68  
-        mtv = ModelToValidate(number=10, name='Some Name')
69  
-        setattr(mtv, '_adding', True)
70  
-        mtv.full_clean()
71  
-        self.assertEqual(query_count, len(connection.queries))
  59
+        def test():
  60
+            mtv = ModelToValidate(number=10, name='Some Name')
  61
+            setattr(mtv, '_adding', True)
  62
+            mtv.full_clean()
  63
+        self.assertNumQueries(0, test)
72 64
 
73 65
     def test_primary_key_unique_check_performed_when_adding_and_pk_specified(self):
74 66
         # Regression test for #12560
75  
-        query_count = len(connection.queries)
76  
-        mtv = ModelToValidate(number=10, name='Some Name', id=123)
77  
-        setattr(mtv, '_adding', True)
78  
-        mtv.full_clean()
79  
-        self.assertEqual(query_count + 1, len(connection.queries))
  67
+        def test():
  68
+            mtv = ModelToValidate(number=10, name='Some Name', id=123)
  69
+            setattr(mtv, '_adding', True)
  70
+            mtv.full_clean()
  71
+        self.assertNumQueries(1, test)
80 72
 
81 73
     def test_primary_key_unique_check_not_performed_when_not_adding(self):
82 74
         # Regression test for #12132
83  
-        query_count= len(connection.queries)
84  
-        mtv = ModelToValidate(number=10, name='Some Name')
85  
-        mtv.full_clean()
86  
-        self.assertEqual(query_count, len(connection.queries))
87  
-
  75
+        def test():
  76
+            mtv = ModelToValidate(number=10, name='Some Name')
  77
+            mtv.full_clean()
  78
+        self.assertNumQueries(0, test)
4  tests/modeltests/validation/tests.py
@@ -6,7 +6,8 @@
6 6
 
7 7
 # Import other tests for this package.
8 8
 from modeltests.validation.validators import TestModelsWithValidators
9  
-from modeltests.validation.test_unique import GetUniqueCheckTests, PerformUniqueChecksTest
  9
+from modeltests.validation.test_unique import (GetUniqueCheckTests,
  10
+    PerformUniqueChecksTest)
10 11
 from modeltests.validation.test_custom_messages import CustomMessagesTest
11 12
 
12 13
 
@@ -111,4 +112,3 @@ def test_validation_with_invalid_blank_field(self):
111 112
         article = Article(author_id=self.author.id)
112 113
         form = ArticleForm(data, instance=article)
113 114
         self.assertEqual(form.errors.keys(), ['pub_date'])
114  
-
19  tests/regressiontests/defer_regress/tests.py
@@ -11,17 +11,6 @@
11 11
 
12 12
 
13 13
 class DeferRegressionTest(TestCase):
14  
-    def assert_num_queries(self, n, func, *args, **kwargs):
15  
-        old_DEBUG = settings.DEBUG
16  
-        settings.DEBUG = True
17  
-        starting_queries = len(connection.queries)
18  
-        try:
19  
-            func(*args, **kwargs)
20  
-        finally:
21  
-            settings.DEBUG = old_DEBUG
22  
-        self.assertEqual(starting_queries + n, len(connection.queries))
23  
-
24  
-
25 14
     def test_basic(self):
26 15
         # Deferred fields should really be deferred and not accidentally use
27 16
         # the field's default value just because they aren't passed to __init__
@@ -33,19 +22,19 @@ def test_basic(self):
33 22
         def test():
34 23
             self.assertEqual(obj.name, "first")
35 24
             self.assertEqual(obj.other_value, 0)
36  
-        self.assert_num_queries(0, test)
  25
+        self.assertNumQueries(0, test)
37 26
 
38 27
         def test():
39 28
             self.assertEqual(obj.value, 42)
40  
-        self.assert_num_queries(1, test)
  29
+        self.assertNumQueries(1, test)
41 30
 
42 31
         def test():
43 32
             self.assertEqual(obj.text, "xyzzy")
44  
-        self.assert_num_queries(1, test)
  33
+        self.assertNumQueries(1, test)
45 34
 
46 35
         def test():
47 36
             self.assertEqual(obj.text, "xyzzy")
48  
-        self.assert_num_queries(0, test)
  37
+        self.assertNumQueries(0, test)
49 38
 
50 39
         # Regression test for #10695. Make sure different instances don't
51 40
         # inadvertently share data in the deferred descriptor objects.
19  tests/regressiontests/forms/models.py
... ...
@@ -1,10 +1,9 @@
1 1
 # -*- coding: utf-8 -*-
2 2
 import datetime
3  
-import tempfile
4 3
 import shutil
  4
+import tempfile
5 5
 
6  
-from django.db import models, connection
7  
-from django.conf import settings
  6
+from django.db import models
8 7
 # Can't import as "forms" due to implementation details in the test suite (the
9 8
 # current file is called "forms" and is already imported).
10 9
 from django import forms as django_forms
@@ -77,19 +76,13 @@ class TestTicket12510(TestCase):
77 76
     ''' It is not necessary to generate choices for ModelChoiceField (regression test for #12510). '''
78 77
     def setUp(self):
79 78
         self.groups = [Group.objects.create(name=name) for name in 'abc']
80  
-        self.old_debug = settings.DEBUG
81  
-        # turn debug on to get access to connection.queries
82  
-        settings.DEBUG = True
83  
-
84  
-    def tearDown(self):
85  
-        settings.DEBUG = self.old_debug
86 79
 
87 80
     def test_choices_not_fetched_when_not_rendering(self):
88  
-        initial_queries = len(connection.queries)
89  
-        field = django_forms.ModelChoiceField(Group.objects.order_by('-name'))
90  
-        self.assertEqual('a', field.clean(self.groups[0].pk).name)
  81
+        def test():
  82
+            field = django_forms.ModelChoiceField(Group.objects.order_by('-name'))
  83
+            self.assertEqual('a', field.clean(self.groups[0].pk).name)
91 84
         # only one query is required to pull the model from DB
92  
-        self.assertEqual(initial_queries+1, len(connection.queries))
  85
+        self.assertNumQueries(1, test)
93 86
 
94 87
 class ModelFormCallableModelDefault(TestCase):
95 88
     def test_no_empty_option(self):
16  tests/regressiontests/model_forms_regress/tests.py
... ...
@@ -1,10 +1,8 @@
1 1
 import unittest
2 2
 from datetime import date
3 3
 
4  
-from django import db
5 4
 from django import forms
6 5
 from django.forms.models import modelform_factory, ModelChoiceField
7  
-from django.conf import settings
8 6
 from django.test import TestCase
9 7
 from django.core.exceptions import FieldError, ValidationError
10 8
 from django.core.files.uploadedfile import SimpleUploadedFile
@@ -14,14 +12,6 @@
14 12
 
15 13
 
16 14
 class ModelMultipleChoiceFieldTests(TestCase):
17  
-
18  
-    def setUp(self):
19  
-        self.old_debug = settings.DEBUG
20  
-        settings.DEBUG = True
21  
-
22  
-    def tearDown(self):
23  
-        settings.DEBUG = self.old_debug
24  
-
25 15
     def test_model_multiple_choice_number_of_queries(self):
26 16
         """
27 17
         Test that ModelMultipleChoiceField does O(1) queries instead of
@@ -30,10 +20,8 @@ def test_model_multiple_choice_number_of_queries(self):
30 20
         for i in range(30):
31 21
             Person.objects.create(name="Person %s" % i)
32 22
 
33  
-        db.reset_queries()
34 23
         f = forms.ModelMultipleChoiceField(queryset=Person.objects.all())
35  
-        selected = f.clean([1, 3, 5, 7, 9])
36  
-        self.assertEquals(len(db.connection.queries), 1)
  24
+        self.assertNumQueries(1, f.clean, [1, 3, 5, 7, 9])
37 25
 
38 26
 class TripleForm(forms.ModelForm):
39 27
     class Meta:
@@ -312,7 +300,7 @@ class Meta:
312 300
                     model = Person
313 301
                     fields = ('name', 'no-field')
314 302
         except FieldError, e:
315  
-            # Make sure the exception contains some reference to the 
  303
+            # Make sure the exception contains some reference to the
316 304
             # field responsible for the problem.
317 305
             self.assertTrue('no-field' in e.args[0])
318 306
         else:
90  tests/regressiontests/select_related_onetoone/tests.py
@@ -7,11 +7,6 @@
@@ -26,65 +21,66 @@ def setUp(self):
5  tests/regressiontests/test_utils/models.py
... ...
@@ -0,0 +1,5 @@
  1
+from django.db import models
  2
+
  3
+
  4
+class Person(models.Model):
  5
+    name = models.CharField(max_length=100)
30  tests/regressiontests/test_utils/python_25.py
... ...
@@ -0,0 +1,30 @@
  1
+from __future__ import with_statement
  2
+
  3
+from django.test import TestCase
  4
+
  5
+from models import Person
  6
+
  7
+
  8
+class AssertNumQueriesTests(TestCase):
  9
+    def test_simple(self):
  10
+        with self.assertNumQueries(0):
  11
+            pass
  12
+
  13
+        with self.assertNumQueries(1):
  14
+            # Guy who wrote Linux
  15
+            Person.objects.create(name="Linus Torvalds")
  16
+
  17
+        with self.assertNumQueries(2):
  18
+            # Guy who owns the bagel place I like
  19
+            Person.objects.create(name="Uncle Ricky")
  20
+            self.assertEqual(Person.objects.count(), 2)
  21
+
  22
+    def test_failure(self):
  23
+        with self.assertRaises(AssertionError) as exc_info:
  24
+            with self.assertNumQueries(2):
  25
+                Person.objects.count()
  26
+        self.assertEqual(str(exc_info.exception), "1 != 2 : 1 queries executed, 2 expected")
  27
+
  28
+        with self.assertRaises(TypeError):
  29
+            with self.assertNumQueries(4000):
  30
+                raise TypeError
14  tests/regressiontests/test_utils/tests.py
... ...
@@ -1,6 +1,12 @@
1  
-r"""
  1
+import sys
  2
+
  3
+if sys.version_info >= (2, 5):
  4
+    from python_25 import AssertNumQueriesTests
  5
+
  6
+
  7
+__test__ = {"API_TEST": r"""
2 8
 # Some checks of the doctest output normalizer.
3  
-# Standard doctests do fairly 
  9
+# Standard doctests do fairly
4 10
 >>> from django.utils import simplejson
5 11
 >>> from django.utils.xmlutils import SimplerXMLGenerator
6 12
 >>> from StringIO import StringIO
@@ -55,7 +61,7 @@
55 61
 >>> produce_json()
56 62
 '["foo", {"whiz": 42, "bar": ["baz", null, 1.0, 2]}]'
57 63
 
58  
-# XML output is normalized for attribute order, so it doesn't matter 
  64
+# XML output is normalized for attribute order, so it doesn't matter
59 65
 # which order XML element attributes are listed in output
60 66
 >>> produce_xml()
61 67
 '<?xml version="1.0" encoding="UTF-8"?>\n<foo aaa="1.0" bbb="2.0"><bar ccc="3.0">Hello</bar><whiz>Goodbye</whiz></foo>'
@@ -69,4 +75,4 @@
69 75
 >>> produce_xml_fragment()
70 76
 '<foo bbb="2.0" aaa="1.0">Hello</foo><bar ddd="4.0" ccc="3.0"></bar>'
71 77
 
72  
-"""
  78
+"""}

0 notes on commit 5506653

Please sign in to comment.
Something went wrong with that request. Please try again.