From b181aba7dd24c73ec9923c39e35393b0487a5f47 Mon Sep 17 00:00:00 2001 From: Simon Charette Date: Sat, 12 Jan 2019 16:14:54 -0500 Subject: [PATCH] Refs #28478 -- Prevented database feature based skipping on tests disallowing queries. Database features may require a connection to be established to determine whether or not they are enabled. --- django/test/testcases.py | 33 +++++++-- tests/backends/base/test_operations.py | 33 ++++----- .../test_ordinary_fields.py | 4 +- .../test_relative_fields.py | 11 +-- tests/model_indexes/tests.py | 7 +- tests/test_utils/tests.py | 72 +++++++++++++++---- tests/timezones/tests.py | 2 +- 7 files changed, 121 insertions(+), 41 deletions(-) diff --git a/django/test/testcases.py b/django/test/testcases.py index 0e1da33064a17..3dd14a08c09e2 100644 --- a/django/test/testcases.py +++ b/django/test/testcases.py @@ -1205,12 +1205,24 @@ def __get__(self, instance, cls=None): return False -def _deferredSkip(condition, reason): +def _deferredSkip(condition, reason, name): def decorator(test_func): + nonlocal condition if not (isinstance(test_func, type) and issubclass(test_func, unittest.TestCase)): @wraps(test_func) def skip_wrapper(*args, **kwargs): + if (args and isinstance(args[0], unittest.TestCase) and + connection.alias not in getattr(args[0], 'databases', {})): + raise ValueError( + "%s cannot be used on %s as %s doesn't allow queries " + "against the %r database." % ( + name, + args[0], + args[0].__class__.__qualname__, + connection.alias, + ) + ) if condition(): raise unittest.SkipTest(reason) return test_func(*args, **kwargs) @@ -1218,6 +1230,16 @@ def skip_wrapper(*args, **kwargs): else: # Assume a class is decorated test_item = test_func + databases = getattr(test_item, 'databases', None) + if not databases or connection.alias not in databases: + # Defer raising to allow importing test class's module. + def condition(): + raise ValueError( + "%s cannot be used on %s as it doesn't allow queries " + "against the '%s' database." % ( + name, test_item, connection.alias, + ) + ) # Retrieve the possibly existing value from the class's dict to # avoid triggering the descriptor. skip = test_func.__dict__.get('__unittest_skip__') @@ -1233,7 +1255,8 @@ def skipIfDBFeature(*features): """Skip a test if a database has at least one of the named features.""" return _deferredSkip( lambda: any(getattr(connection.features, feature, False) for feature in features), - "Database has feature(s) %s" % ", ".join(features) + "Database has feature(s) %s" % ", ".join(features), + 'skipIfDBFeature', ) @@ -1241,7 +1264,8 @@ def skipUnlessDBFeature(*features): """Skip a test unless a database has all the named features.""" return _deferredSkip( lambda: not all(getattr(connection.features, feature, False) for feature in features), - "Database doesn't support feature(s): %s" % ", ".join(features) + "Database doesn't support feature(s): %s" % ", ".join(features), + 'skipUnlessDBFeature', ) @@ -1249,7 +1273,8 @@ def skipUnlessAnyDBFeature(*features): """Skip a test unless a database has any of the named features.""" return _deferredSkip( lambda: not any(getattr(connection.features, feature, False) for feature in features), - "Database doesn't support any of the feature(s): %s" % ", ".join(features) + "Database doesn't support any of the feature(s): %s" % ", ".join(features), + 'skipUnlessAnyDBFeature', ) diff --git a/tests/backends/base/test_operations.py b/tests/backends/base/test_operations.py index 7ca0535135a45..607afb6dfc445 100644 --- a/tests/backends/base/test_operations.py +++ b/tests/backends/base/test_operations.py @@ -15,12 +15,6 @@ class SimpleDatabaseOperationTests(SimpleTestCase): def setUp(self): self.ops = BaseDatabaseOperations(connection=connection) - @skipIfDBFeature('can_distinct_on_fields') - def test_distinct_on_fields(self): - msg = 'DISTINCT ON fields is not supported by this database backend' - with self.assertRaisesMessage(NotSupportedError, msg): - self.ops.distinct_sql(['a', 'b'], None) - def test_deferrable_sql(self): self.assertEqual(self.ops.deferrable_sql(), '') @@ -123,6 +117,23 @@ def test_datetime_extract_sql(self): with self.assertRaisesMessage(NotImplementedError, self.may_requre_msg % 'datetime_extract_sql'): self.ops.datetime_extract_sql(None, None, None) + +class DatabaseOperationTests(TestCase): + def setUp(self): + self.ops = BaseDatabaseOperations(connection=connection) + + @skipIfDBFeature('supports_over_clause') + def test_window_frame_raise_not_supported_error(self): + msg = 'This backend does not support window expressions.' + with self.assertRaisesMessage(NotSupportedError, msg): + self.ops.window_frame_rows_start_end() + + @skipIfDBFeature('can_distinct_on_fields') + def test_distinct_on_fields(self): + msg = 'DISTINCT ON fields is not supported by this database backend' + with self.assertRaisesMessage(NotSupportedError, msg): + self.ops.distinct_sql(['a', 'b'], None) + @skipIfDBFeature('supports_temporal_subtraction') def test_subtract_temporals(self): duration_field = DurationField() @@ -133,13 +144,3 @@ def test_subtract_temporals(self): ) with self.assertRaisesMessage(NotSupportedError, msg): self.ops.subtract_temporals(duration_field_internal_type, None, None) - - -class DatabaseOperationTests(TestCase): - # Checking the 'supports_over_clause' feature requires a query for the - # MySQL backend to perform a version check. - @skipIfDBFeature('supports_over_clause') - def test_window_frame_raise_not_supported_error(self): - msg = 'This backend does not support window expressions.' - with self.assertRaisesMessage(NotSupportedError, msg): - connection.ops.window_frame_rows_start_end() diff --git a/tests/invalid_models_tests/test_ordinary_fields.py b/tests/invalid_models_tests/test_ordinary_fields.py index 3922c835b9e5d..9c7cf7f88c5cf 100644 --- a/tests/invalid_models_tests/test_ordinary_fields.py +++ b/tests/invalid_models_tests/test_ordinary_fields.py @@ -2,7 +2,7 @@ from django.core.checks import Error, Warning as DjangoWarning from django.db import connection, models -from django.test import SimpleTestCase, skipIfDBFeature +from django.test import SimpleTestCase, TestCase, skipIfDBFeature from django.test.utils import isolate_apps, override_settings from django.utils.functional import lazy from django.utils.timezone import now @@ -680,7 +680,7 @@ def test_fix_default_value_tz(self): @isolate_apps('invalid_models_tests') -class TextFieldTests(SimpleTestCase): +class TextFieldTests(TestCase): @skipIfDBFeature('supports_index_on_text_field') def test_max_length_warning(self): diff --git a/tests/invalid_models_tests/test_relative_fields.py b/tests/invalid_models_tests/test_relative_fields.py index cf1b3f737bb42..e68dd41c6f4a7 100644 --- a/tests/invalid_models_tests/test_relative_fields.py +++ b/tests/invalid_models_tests/test_relative_fields.py @@ -1,7 +1,9 @@ +from unittest import mock + from django.core.checks import Error, Warning as DjangoWarning -from django.db import models +from django.db import connection, models from django.db.models.fields.related import ForeignObject -from django.test.testcases import SimpleTestCase, skipIfDBFeature +from django.test.testcases import SimpleTestCase from django.test.utils import isolate_apps, override_settings @@ -501,13 +503,14 @@ class Model(models.Model): ), ]) - @skipIfDBFeature('interprets_empty_strings_as_nulls') def test_nullable_primary_key(self): class Model(models.Model): field = models.IntegerField(primary_key=True, null=True) field = Model._meta.get_field('field') - self.assertEqual(field.check(), [ + with mock.patch.object(connection.features, 'interprets_empty_strings_as_nulls', False): + results = field.check() + self.assertEqual(results, [ Error( 'Primary keys must not have null=True.', hint='Set null=False on the field, or remove primary_key=True argument.', diff --git a/tests/model_indexes/tests.py b/tests/model_indexes/tests.py index 0a34584c3e423..60fc0560e418f 100644 --- a/tests/model_indexes/tests.py +++ b/tests/model_indexes/tests.py @@ -1,13 +1,13 @@ from django.conf import settings from django.db import connection, models from django.db.models.query_utils import Q -from django.test import SimpleTestCase, skipUnlessDBFeature +from django.test import SimpleTestCase, TestCase, skipUnlessDBFeature from django.test.utils import isolate_apps from .models import Book, ChildModel1, ChildModel2 -class IndexesTests(SimpleTestCase): +class SimpleIndexesTests(SimpleTestCase): def test_suffix(self): self.assertEqual(models.Index.suffix, 'idx') @@ -156,6 +156,9 @@ def test_abstract_children(self): index_names = [index.name for index in ChildModel2._meta.indexes] self.assertEqual(index_names, ['model_index_name_b6c374_idx']) + +class IndexesTests(TestCase): + @skipUnlessDBFeature('supports_tablespaces') def test_db_tablespace(self): editor = connection.schema_editor() diff --git a/tests/test_utils/tests.py b/tests/test_utils/tests.py index 680924b8380fd..3a315e7c10445 100644 --- a/tests/test_utils/tests.py +++ b/tests/test_utils/tests.py @@ -29,15 +29,14 @@ class SkippingTestCase(SimpleTestCase): - def _assert_skipping(self, func, expected_exc): - # We cannot simply use assertRaises because a SkipTest exception will go unnoticed + def _assert_skipping(self, func, expected_exc, msg=None): try: - func() - except expected_exc: - pass - except Exception as e: - self.fail("No %s exception should have been raised for %s." % ( - e.__class__.__name__, func.__name__)) + if msg is not None: + self.assertRaisesMessage(expected_exc, msg, func) + else: + self.assertRaises(expected_exc, func) + except unittest.SkipTest: + self.fail('%s should not result in a skipped test.' % func.__name__) def test_skip_unless_db_feature(self): """ @@ -65,6 +64,20 @@ def test_func4(): self._assert_skipping(test_func3, ValueError) self._assert_skipping(test_func4, unittest.SkipTest) + class SkipTestCase(SimpleTestCase): + @skipUnlessDBFeature('missing') + def test_foo(self): + pass + + self._assert_skipping( + SkipTestCase('test_foo').test_foo, + ValueError, + "skipUnlessDBFeature cannot be used on test_foo (test_utils.tests." + "SkippingTestCase.test_skip_unless_db_feature..SkipTestCase) " + "as SkippingTestCase.test_skip_unless_db_feature..SkipTestCase " + "doesn't allow queries against the 'default' database." + ) + def test_skip_if_db_feature(self): """ Testing the django.test.skipIfDBFeature decorator. @@ -95,17 +108,31 @@ def test_func5(): self._assert_skipping(test_func4, unittest.SkipTest) self._assert_skipping(test_func5, ValueError) + class SkipTestCase(SimpleTestCase): + @skipIfDBFeature('missing') + def test_foo(self): + pass + + self._assert_skipping( + SkipTestCase('test_foo').test_foo, + ValueError, + "skipIfDBFeature cannot be used on test_foo (test_utils.tests." + "SkippingTestCase.test_skip_if_db_feature..SkipTestCase) " + "as SkippingTestCase.test_skip_if_db_feature..SkipTestCase " + "doesn't allow queries against the 'default' database." + ) + -class SkippingClassTestCase(SimpleTestCase): +class SkippingClassTestCase(TestCase): def test_skip_class_unless_db_feature(self): @skipUnlessDBFeature("__class__") - class NotSkippedTests(unittest.TestCase): + class NotSkippedTests(TestCase): def test_dummy(self): return @skipUnlessDBFeature("missing") @skipIfDBFeature("__class__") - class SkippedTests(unittest.TestCase): + class SkippedTests(TestCase): def test_will_be_skipped(self): self.fail("We should never arrive here.") @@ -119,13 +146,34 @@ class SkippedTestsSubclass(SkippedTests): test_suite.addTest(SkippedTests('test_will_be_skipped')) test_suite.addTest(SkippedTestsSubclass('test_will_be_skipped')) except unittest.SkipTest: - self.fail("SkipTest should not be raised at this stage") + self.fail('SkipTest should not be raised here.') result = unittest.TextTestRunner(stream=StringIO()).run(test_suite) self.assertEqual(result.testsRun, 3) self.assertEqual(len(result.skipped), 2) self.assertEqual(result.skipped[0][1], 'Database has feature(s) __class__') self.assertEqual(result.skipped[1][1], 'Database has feature(s) __class__') + def test_missing_default_databases(self): + @skipIfDBFeature('missing') + class MissingDatabases(SimpleTestCase): + def test_assertion_error(self): + pass + + suite = unittest.TestSuite() + try: + suite.addTest(MissingDatabases('test_assertion_error')) + except unittest.SkipTest: + self.fail("SkipTest should not be raised at this stage") + runner = unittest.TextTestRunner(stream=StringIO()) + msg = ( + "skipIfDBFeature cannot be used on ." + "MissingDatabases'> as it doesn't allow queries against the " + "'default' database." + ) + with self.assertRaisesMessage(ValueError, msg): + runner.run(suite) + @override_settings(ROOT_URLCONF='test_utils.urls') class AssertNumQueriesTests(TestCase): diff --git a/tests/timezones/tests.py b/tests/timezones/tests.py index 06ac51594c399..e54b011c04269 100644 --- a/tests/timezones/tests.py +++ b/tests/timezones/tests.py @@ -581,7 +581,7 @@ def test_write_datetime(self): @skipUnlessDBFeature('supports_timezones') @override_settings(TIME_ZONE='Africa/Nairobi', USE_TZ=True) -class UnsupportedTimeZoneDatabaseTests(SimpleTestCase): +class UnsupportedTimeZoneDatabaseTests(TestCase): def test_time_zone_parameter_not_supported_if_database_supports_timezone(self): connections.databases['tz'] = connections.databases['default'].copy()