Permalink
Browse files

Fixed #5416 -- Added TestCase.assertNumQueries, which tests that a gi…

…ven function executes the correct number of queries.

git-svn-id: http://code.djangoproject.com/svn/django/trunk@14183 bcc190cf-cafb-0310-a4f2-bffc1f526a37
  • Loading branch information...
1 parent ceef628 commit 5506653b777d7547d21ea2d74e9588fb94314b77 @alex alex committed Oct 12, 2010
@@ -21,6 +21,7 @@ def __init__(self, settings_dict, alias=DEFAULT_DB_ALIAS):
self.settings_dict = settings_dict
self.alias = alias
self.vendor = 'unknown'
+ self.use_debug_cursor = None
def __eq__(self, other):
return self.settings_dict == other.settings_dict
@@ -74,7 +75,8 @@ def close(self):
def cursor(self):
from django.conf import settings
cursor = self._cursor()
- if settings.DEBUG:
+ if (self.use_debug_cursor or
+ (self.use_debug_cursor is None and settings.DEBUG)):
return self.make_debug_cursor(cursor)
return cursor
@@ -1,4 +1,5 @@
import re
+import sys
from urlparse import urlsplit, urlunsplit
from xml.dom.minidom import parseString, Node
@@ -205,6 +206,33 @@ def report_unexpected_exception(self, out, test, example, exc_info):
for conn in connections:
transaction.rollback_unless_managed(using=conn)
+class _AssertNumQueriesContext(object):
+ def __init__(self, test_case, num, connection):
+ self.test_case = test_case
+ self.num = num
+ self.connection = connection
+
+ def __enter__(self):
+ self.old_debug_cursor = self.connection.use_debug_cursor
+ self.connection.use_debug_cursor = True
+ self.starting_queries = len(self.connection.queries)
+ return self
+
+ def __exit__(self, exc_type, exc_value, traceback):
+ if exc_type is not None:
+ return
+
+ self.connection.use_debug_cursor = self.old_debug_cursor
+ final_queries = len(self.connection.queries)
+ executed = final_queries - self.starting_queries
+
+ self.test_case.assertEqual(
+ executed, self.num, "%d queries executed, %d expected" % (
+ executed, self.num
+ )
+ )
+
+
class TransactionTestCase(unittest.TestCase):
# The class we'll use for the test client self.client.
# Can be overridden in derived classes.
@@ -469,6 +497,22 @@ def assertTemplateNotUsed(self, response, template_name, msg_prefix=''):
def assertQuerysetEqual(self, qs, values, transform=repr):
return self.assertEqual(map(transform, qs), values)
+ def assertNumQueries(self, num, func=None, *args, **kwargs):
+ using = kwargs.pop("using", DEFAULT_DB_ALIAS)
+ connection = connections[using]
+
+ context = _AssertNumQueriesContext(self, num, connection)
+ if func is None:
+ return context
+
+ # Basically emulate the `with` statement here.
+
+ context.__enter__()
+ try:
+ func(*args, **kwargs)
+ finally:
+ context.__exit__(*sys.exc_info())
+
def connections_support_transactions():
"""
Returns True if all connections support transactions. This is messy
@@ -1372,6 +1372,32 @@ cause of an failure in your test suite.
implicit ordering, you will need to apply a ``order_by()`` clause to your
queryset to ensure that the test will pass reliably.
+.. method:: TestCase.assertNumQueries(num, func, *args, **kwargs):
+
+ .. versionadded:: 1.3
+
+ Asserts that when ``func`` is called with ``*args`` and ``**kwargs`` that
+ ``num`` database queries are executed.
+
+ If a ``"using"`` key is present in ``kwargs`` it is used as the database
+ alias for which to check the number of queries. If you wish to call a
+ function with a ``using`` parameter you can do it by wrapping the call with
+ a ``lambda`` to add an extra parameter::
+
+ self.assertNumQueries(7, lambda: my_function(using=7))
+
+ If you're using Python 2.5 or greater you can also use this as a context
+ manager::
+
+ # This is necessary in Python 2.5 to enable the with statement, in 2.6
+ # and up it is no longer necessary.
+ from __future__ import with_statement
+
+ with self.assertNumQueries(2):
+ Person.objects.create(name="Aaron")
+ Person.objects.create(name="Daniel")
+
+
.. _topics-testing-email:
E-mail services
@@ -1,6 +1,4 @@
from django.test import TestCase
-from django.conf import settings
-from django import db
from models import Domain, Kingdom, Phylum, Klass, Order, Family, Genus, Species
@@ -36,73 +34,73 @@ def setUp(self):
# queries so we'll set it to True here and reset it at the end of the
# test case.
self.create_base_data()
- settings.DEBUG = True
- db.reset_queries()
-
- def tearDown(self):
- settings.DEBUG = False
def test_access_fks_without_select_related(self):
"""
Normally, accessing FKs doesn't fill in related objects
"""
- fly = Species.objects.get(name="melanogaster")
- domain = fly.genus.family.order.klass.phylum.kingdom.domain
- self.assertEqual(domain.name, 'Eukaryota')
- self.assertEqual(len(db.connection.queries), 8)
+ def test():
+ fly = Species.objects.get(name="melanogaster")
+ domain = fly.genus.family.order.klass.phylum.kingdom.domain
+ self.assertEqual(domain.name, 'Eukaryota')
+ self.assertNumQueries(8, test)
def test_access_fks_with_select_related(self):
"""
A select_related() call will fill in those related objects without any
extra queries
"""
- person = Species.objects.select_related(depth=10).get(name="sapiens")
- domain = person.genus.family.order.klass.phylum.kingdom.domain
- self.assertEqual(domain.name, 'Eukaryota')
- self.assertEqual(len(db.connection.queries), 1)
+ def test():
+ person = Species.objects.select_related(depth=10).get(name="sapiens")
+ domain = person.genus.family.order.klass.phylum.kingdom.domain
+ self.assertEqual(domain.name, 'Eukaryota')
+ self.assertNumQueries(1, test)
def test_list_without_select_related(self):
"""
select_related() also of course applies to entire lists, not just
items. This test verifies the expected behavior without select_related.
"""
- world = Species.objects.all()
- families = [o.genus.family.name for o in world]
- self.assertEqual(families, [
- 'Drosophilidae',
- 'Hominidae',
- 'Fabaceae',
- 'Amanitacae',
- ])
- self.assertEqual(len(db.connection.queries), 9)
+ def test():
+ world = Species.objects.all()
+ families = [o.genus.family.name for o in world]
+ self.assertEqual(families, [
+ 'Drosophilidae',
+ 'Hominidae',
+ 'Fabaceae',
+ 'Amanitacae',
+ ])
+ self.assertNumQueries(9, test)
def test_list_with_select_related(self):
"""
select_related() also of course applies to entire lists, not just
items. This test verifies the expected behavior with select_related.
"""
- world = Species.objects.all().select_related()
- families = [o.genus.family.name for o in world]
- self.assertEqual(families, [
- 'Drosophilidae',
- 'Hominidae',
- 'Fabaceae',
- 'Amanitacae',
- ])
- self.assertEqual(len(db.connection.queries), 1)
+ def test():
+ world = Species.objects.all().select_related()
+ families = [o.genus.family.name for o in world]
+ self.assertEqual(families, [
+ 'Drosophilidae',
+ 'Hominidae',
+ 'Fabaceae',
+ 'Amanitacae',
+ ])
+ self.assertNumQueries(1, test)
def test_depth(self, depth=1, expected=7):
"""
The "depth" argument to select_related() will stop the descent at a
particular level.
"""
- pea = Species.objects.select_related(depth=depth).get(name="sativum")
- self.assertEqual(
- pea.genus.family.order.klass.phylum.kingdom.domain.name,
- 'Eukaryota'
- )
+ def test():
+ pea = Species.objects.select_related(depth=depth).get(name="sativum")
+ self.assertEqual(
+ pea.genus.family.order.klass.phylum.kingdom.domain.name,
+ 'Eukaryota'
+ )
# Notice: one fewer queries than above because of depth=1
- self.assertEqual(len(db.connection.queries), expected)
+ self.assertNumQueries(expected, test)
def test_larger_depth(self):
"""
@@ -116,11 +114,12 @@ def test_list_with_depth(self):
The "depth" argument to select_related() will stop the descent at a
particular level. This can be used on lists as well.
"""
- world = Species.objects.all().select_related(depth=2)
- orders = [o.genus.family.order.name for o in world]
- self.assertEqual(orders,
- ['Diptera', 'Primates', 'Fabales', 'Agaricales'])
- self.assertEqual(len(db.connection.queries), 5)
+ def test():
+ world = Species.objects.all().select_related(depth=2)
+ orders = [o.genus.family.order.name for o in world]
+ self.assertEqual(orders,
+ ['Diptera', 'Primates', 'Fabales', 'Agaricales'])
+ self.assertNumQueries(5, test)
def test_select_related_with_extra(self):
s = Species.objects.all().select_related(depth=1)\
@@ -136,28 +135,31 @@ def test_certain_fields(self):
In this case, we explicitly say to select the 'genus' and
'genus.family' models, leading to the same number of queries as before.
"""
- world = Species.objects.select_related('genus__family')
- families = [o.genus.family.name for o in world]
- self.assertEqual(families,
- ['Drosophilidae', 'Hominidae', 'Fabaceae', 'Amanitacae'])
- self.assertEqual(len(db.connection.queries), 1)
+ def test():
+ world = Species.objects.select_related('genus__family')
+ families = [o.genus.family.name for o in world]
+ self.assertEqual(families,
+ ['Drosophilidae', 'Hominidae', 'Fabaceae', 'Amanitacae'])
+ self.assertNumQueries(1, test)
def test_more_certain_fields(self):
"""
In this case, we explicitly say to select the 'genus' and
'genus.family' models, leading to the same number of queries as before.
"""
- world = Species.objects.filter(genus__name='Amanita')\
- .select_related('genus__family')
- orders = [o.genus.family.order.name for o in world]
- self.assertEqual(orders, [u'Agaricales'])
- self.assertEqual(len(db.connection.queries), 2)
+ def test():
+ world = Species.objects.filter(genus__name='Amanita')\
+ .select_related('genus__family')
+ orders = [o.genus.family.order.name for o in world]
+ self.assertEqual(orders, [u'Agaricales'])
+ self.assertNumQueries(2, test)
def test_field_traversal(self):
- s = Species.objects.all().select_related('genus__family__order'
- ).order_by('id')[0:1].get().genus.family.order.name
- self.assertEqual(s, u'Diptera')
- self.assertEqual(len(db.connection.queries), 1)
+ def test():
+ s = Species.objects.all().select_related('genus__family__order'
+ ).order_by('id')[0:1].get().genus.family.order.name
+ self.assertEqual(s, u'Diptera')
+ self.assertNumQueries(1, test)
def test_depth_fields_fails(self):
self.assertRaises(TypeError,
@@ -2,9 +2,11 @@
from django.conf import settings
from django.db import connection
+from django.test import TestCase
from django.utils import unittest
-from models import CustomPKModel, UniqueTogetherModel, UniqueFieldsModel, UniqueForDateModel, ModelToValidate
+from models import (CustomPKModel, UniqueTogetherModel, UniqueFieldsModel,
+ UniqueForDateModel, ModelToValidate)
class GetUniqueCheckTests(unittest.TestCase):
@@ -51,37 +53,26 @@ def test_unique_for_date_exclusion(self):
), m._get_unique_checks(exclude='start_date')
)
-class PerformUniqueChecksTest(unittest.TestCase):
- def setUp(self):
- # Set debug to True to gain access to connection.queries.
- self._old_debug, settings.DEBUG = settings.DEBUG, True
- super(PerformUniqueChecksTest, self).setUp()
-
- def tearDown(self):
- # Restore old debug value.
- settings.DEBUG = self._old_debug
- super(PerformUniqueChecksTest, self).tearDown()
-
+class PerformUniqueChecksTest(TestCase):
def test_primary_key_unique_check_not_performed_when_adding_and_pk_not_specified(self):
# Regression test for #12560
- query_count = len(connection.queries)
- mtv = ModelToValidate(number=10, name='Some Name')
- setattr(mtv, '_adding', True)
- mtv.full_clean()
- self.assertEqual(query_count, len(connection.queries))
+ def test():
+ mtv = ModelToValidate(number=10, name='Some Name')
+ setattr(mtv, '_adding', True)
+ mtv.full_clean()
+ self.assertNumQueries(0, test)
def test_primary_key_unique_check_performed_when_adding_and_pk_specified(self):
# Regression test for #12560
- query_count = len(connection.queries)
- mtv = ModelToValidate(number=10, name='Some Name', id=123)
- setattr(mtv, '_adding', True)
- mtv.full_clean()
- self.assertEqual(query_count + 1, len(connection.queries))
+ def test():
+ mtv = ModelToValidate(number=10, name='Some Name', id=123)
+ setattr(mtv, '_adding', True)
+ mtv.full_clean()
+ self.assertNumQueries(1, test)
def test_primary_key_unique_check_not_performed_when_not_adding(self):
# Regression test for #12132
- query_count= len(connection.queries)
- mtv = ModelToValidate(number=10, name='Some Name')
- mtv.full_clean()
- self.assertEqual(query_count, len(connection.queries))
-
+ def test():
+ mtv = ModelToValidate(number=10, name='Some Name')
+ mtv.full_clean()
+ self.assertNumQueries(0, test)
@@ -6,7 +6,8 @@
# Import other tests for this package.
from modeltests.validation.validators import TestModelsWithValidators
-from modeltests.validation.test_unique import GetUniqueCheckTests, PerformUniqueChecksTest
+from modeltests.validation.test_unique import (GetUniqueCheckTests,
+ PerformUniqueChecksTest)
from modeltests.validation.test_custom_messages import CustomMessagesTest
@@ -111,4 +112,3 @@ def test_validation_with_invalid_blank_field(self):
article = Article(author_id=self.author.id)
form = ArticleForm(data, instance=article)
self.assertEqual(form.errors.keys(), ['pub_date'])
-
Oops, something went wrong.

0 comments on commit 5506653

Please sign in to comment.