Skip to content

Commit

Permalink
Refs #28478 -- Prevented connection attempts against disallowed datab…
Browse files Browse the repository at this point in the history
…ases in tests.

Mocking connect as well as cursor methods makes sure an appropriate error
message is surfaced when running a subset of test attempting to access a
a disallowed database.
  • Loading branch information
charettes authored and timgraham committed Jan 14, 2019
1 parent a96b901 commit f5b6350
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 27 deletions.
49 changes: 29 additions & 20 deletions django/test/testcases.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def message(self):
return '%s was rendered.' % self.template_name


class _CursorFailure:
class _DatabaseFailure:
def __init__(self, wrapped, message):
self.wrapped = wrapped
self.message = message
Expand Down Expand Up @@ -173,11 +173,17 @@ class SimpleTestCase(unittest.TestCase):

databases = _SimpleTestCaseDatabasesDescriptor()
_disallowed_database_msg = (
'Database queries are not allowed in SimpleTestCase subclasses. '
'Either subclass TestCase or TransactionTestCase to ensure proper '
'test isolation or add %(alias)r to %(test)s.databases to silence '
'Database %(operation)s to %(alias)r are not allowed in SimpleTestCase '
'subclasses. Either subclass TestCase or TransactionTestCase to ensure '
'proper test isolation or add %(alias)r to %(test)s.databases to silence '
'this failure.'
)
_disallowed_connection_methods = [
('connect', 'connections'),
('temporary_connection', 'connections'),
('cursor', 'queries'),
('chunked_cursor', 'queries'),
]

@classmethod
def setUpClass(cls):
Expand All @@ -188,7 +194,7 @@ def setUpClass(cls):
if cls._modified_settings:
cls._cls_modified_context = modify_settings(cls._modified_settings)
cls._cls_modified_context.enable()
cls._add_cursor_failures()
cls._add_databases_failures()

@classmethod
def _validate_databases(cls):
Expand All @@ -208,31 +214,34 @@ def _validate_databases(cls):
return frozenset(cls.databases)

@classmethod
def _add_cursor_failures(cls):
def _add_databases_failures(cls):
cls.databases = cls._validate_databases()
for alias in connections:
if alias in cls.databases:
continue
connection = connections[alias]
message = cls._disallowed_database_msg % {
'test': '%s.%s' % (cls.__module__, cls.__qualname__),
'alias': alias,
}
connection.cursor = _CursorFailure(connection.cursor, message)
connection.chunked_cursor = _CursorFailure(connection.chunked_cursor, message)
for name, operation in cls._disallowed_connection_methods:
message = cls._disallowed_database_msg % {
'test': '%s.%s' % (cls.__module__, cls.__qualname__),
'alias': alias,
'operation': operation,
}
method = getattr(connection, name)
setattr(connection, name, _DatabaseFailure(method, message))

@classmethod
def _remove_cursor_failures(cls):
def _remove_databases_failures(cls):
for alias in connections:
if alias in cls.databases:
continue
connection = connections[alias]
connection.cursor = connection.cursor.wrapped
connection.chunked_cursor = connection.chunked_cursor.wrapped
for name, _ in cls._disallowed_connection_methods:
method = getattr(connection, name)
setattr(connection, name, method.wrapped)

@classmethod
def tearDownClass(cls):
cls._remove_cursor_failures()
cls._remove_databases_failures()
if hasattr(cls, '_cls_modified_context'):
cls._cls_modified_context.disable()
delattr(cls, '_cls_modified_context')
Expand Down Expand Up @@ -894,8 +903,8 @@ class TransactionTestCase(SimpleTestCase):

databases = _TransactionTestCaseDatabasesDescriptor()
_disallowed_database_msg = (
'Database queries to %(alias)r are not allowed in this test. Add '
'%(alias)r to %(test)s.databases to ensure proper test isolation '
'Database %(operation)s to %(alias)r are not allowed in this test. '
'Add %(alias)r to %(test)s.databases to ensure proper test isolation '
'and silence this failure.'
)

Expand Down Expand Up @@ -1121,13 +1130,13 @@ def setUpClass(cls):
call_command('loaddata', *cls.fixtures, **{'verbosity': 0, 'database': db_name})
except Exception:
cls._rollback_atomics(cls.cls_atomics)
cls._remove_cursor_failures()
cls._remove_databases_failures()
raise
try:
cls.setUpTestData()
except Exception:
cls._rollback_atomics(cls.cls_atomics)
cls._remove_cursor_failures()
cls._remove_databases_failures()
raise

@classmethod
Expand Down
13 changes: 12 additions & 1 deletion tests/test_utils/test_testcase.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from django.db import IntegrityError, transaction
from django.db import IntegrityError, connections, transaction
from django.test import TestCase, skipUnlessDBFeature

from .models import Car, PossessedCar
Expand All @@ -19,6 +19,17 @@ def test_fixture_teardown_checks_constraints(self):
finally:
self._rollback_atomics = rollback_atomics

def test_disallowed_database_connection(self):
message = (
"Database connections to 'other' are not allowed in this test. "
"Add 'other' to test_utils.test_testcase.TestTestCase.databases to "
"ensure proper test isolation and silence this failure."
)
with self.assertRaisesMessage(AssertionError, message):
connections['other'].connect()
with self.assertRaisesMessage(AssertionError, message):
connections['other'].temporary_connection()

def test_disallowed_database_queries(self):
message = (
"Database queries to 'other' are not allowed in this test. "
Expand Down
25 changes: 19 additions & 6 deletions tests/test_utils/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -1159,11 +1159,24 @@ def test_failure_in_setUpTestData_should_rollback_transaction(self):


class DisallowedDatabaseQueriesTests(SimpleTestCase):
def test_disallowed_database_connections(self):
expected_message = (
"Database connections to 'default' are not allowed in SimpleTestCase "
"subclasses. Either subclass TestCase or TransactionTestCase to "
"ensure proper test isolation or add 'default' to "
"test_utils.tests.DisallowedDatabaseQueriesTests.databases to "
"silence this failure."
)
with self.assertRaisesMessage(AssertionError, expected_message):
connection.connect()
with self.assertRaisesMessage(AssertionError, expected_message):
connection.temporary_connection()

def test_disallowed_database_queries(self):
expected_message = (
"Database queries are not allowed in SimpleTestCase subclasses. "
"Either subclass TestCase or TransactionTestCase to ensure proper "
"test isolation or add 'default' to "
"Database queries to 'default' are not allowed in SimpleTestCase "
"subclasses. Either subclass TestCase or TransactionTestCase to "
"ensure proper test isolation or add 'default' to "
"test_utils.tests.DisallowedDatabaseQueriesTests.databases to "
"silence this failure."
)
Expand All @@ -1172,9 +1185,9 @@ def test_disallowed_database_queries(self):

def test_disallowed_database_chunked_cursor_queries(self):
expected_message = (
"Database queries are not allowed in SimpleTestCase subclasses. "
"Either subclass TestCase or TransactionTestCase to ensure proper "
"test isolation or add 'default' to "
"Database queries to 'default' are not allowed in SimpleTestCase "
"subclasses. Either subclass TestCase or TransactionTestCase to "
"ensure proper test isolation or add 'default' to "
"test_utils.tests.DisallowedDatabaseQueriesTests.databases to "
"silence this failure."
)
Expand Down

0 comments on commit f5b6350

Please sign in to comment.