diff --git a/django/db/backends/base/creation.py b/django/db/backends/base/creation.py
index 81cb34bd9f98..9bcebae9b0ea 100644
--- a/django/db/backends/base/creation.py
+++ b/django/db/backends/base/creation.py
@@ -340,3 +340,13 @@ def test_db_signature(self):
             settings_dict['ENGINE'],
             self._get_test_db_name(),
         )
+
+    def setup_worker_connection(self, _worker_id):
+        settings_dict = self.get_test_db_clone_settings(str(_worker_id))
+        # connection.settings_dict must be updated in place for changes to be
+        # reflected in django.db.connections. If the following line assigned
+        # connection.settings_dict = settings_dict, new threads would connect
+        # to the default database instead of the appropriate clone.
+        self.connection.settings_dict.update(settings_dict)
+        self.mark_expected_failures_and_skips()
+        self.connection.close()
diff --git a/django/db/backends/sqlite3/creation.py b/django/db/backends/sqlite3/creation.py
index 4a4046c670f4..2608e2e1f7c6 100644
--- a/django/db/backends/sqlite3/creation.py
+++ b/django/db/backends/sqlite3/creation.py
@@ -1,5 +1,7 @@
+import multiprocessing
 import os
 import shutil
+import sqlite3
 import sys
 from pathlib import Path
 
@@ -52,7 +54,11 @@ def get_test_db_clone_settings(self, suffix):
         orig_settings_dict = self.connection.settings_dict
         source_database_name = orig_settings_dict['NAME']
         if self.is_in_memory_db(source_database_name):
-            return orig_settings_dict
+            if multiprocessing.get_start_method() == 'spawn':
+                return {**orig_settings_dict, 'NAME': f'{self.connection.alias}_{suffix}.sqlite3'}
+            # With fork, workers inherit the parent's in-memory database, so
+            # its settings can be reused as-is.
+            return orig_settings_dict
         else:
             root, ext = os.path.splitext(orig_settings_dict['NAME'])
             return {**orig_settings_dict, 'NAME': '{}_{}{}'.format(root, suffix, ext)}
@@ -80,6 +85,13 @@ def _clone_test_db(self, suffix, verbosity, keepdb=False):
             except Exception as e:
                 self.log('Got an error cloning the test database: %s' % e)
                 sys.exit(2)
+        else:
+            if multiprocessing.get_start_method() == 'spawn':
+                # Dump the in-memory source database to the on-disk clone so
+                # that spawned workers can load it back into memory.
+                ondisk_db = sqlite3.connect(target_database_name, uri=True)
+                self.connection.connection.backup(ondisk_db)
+                ondisk_db.close()
 
     def _destroy_test_db(self, test_database_name, verbosity):
         if test_database_name and not self.is_in_memory_db(test_database_name):
@@ -101,3 +110,19 @@ def test_db_signature(self):
         else:
             sig.append(test_database_name)
         return tuple(sig)
+
+    def setup_worker_connection(self, _worker_id):
+        alias = self.connection.alias
+        worker_db = f'file:memorydb_{alias}_{_worker_id}?mode=memory&cache=shared'
+        # Load the on-disk clone written by _clone_test_db() into a shared
+        # in-memory database for this worker.
+        source_db = sqlite3.connect(f'file:{alias}_{_worker_id}.sqlite3', uri=True)
+        worker_conn = sqlite3.connect(worker_db, uri=True)
+        source_db.backup(worker_conn)
+        source_db.close()
+        # Update settings_dict in place so the change is reflected in
+        # django.db.connections, then reconnect to the in-memory database
+        # before closing the temporary handle that kept it alive.
+        self.connection.settings_dict['NAME'] = worker_db
+        self.connection.connect()
+        worker_conn.close()
+        self.mark_expected_failures_and_skips()
diff --git a/django/db/backends/sqlite3/features.py b/django/db/backends/sqlite3/features.py
index ff3e3f47a9fa..d44efc58c499 100644
--- a/django/db/backends/sqlite3/features.py
+++ b/django/db/backends/sqlite3/features.py
@@ -5,6 +5,7 @@
 from django.db.backends.base.features import BaseDatabaseFeatures
 from django.db.utils import OperationalError
 from django.utils.functional import cached_property
+from django.utils.version import PY37
 
 from .base import Database
 
@@ -24,7 +25,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
     can_rollback_ddl = True
     can_create_inline_fk = False
     supports_paramstyle_pyformat = False
-
can_clone_databases = True + can_clone_databases = PY37 supports_temporal_subtraction = True ignores_table_name_case = True supports_cast_with_precision = False diff --git a/django/test/runner.py b/django/test/runner.py index 878e62b1a885..2a7cc82dd834 100644 --- a/django/test/runner.py +++ b/django/test/runner.py @@ -22,7 +22,7 @@ from django.db import connections from django.test import SimpleTestCase, TestCase from django.test.utils import ( - NullTimeKeeper, TimeKeeper, iter_test_cases, + NullTimeKeeper, TimeKeeper, captured_stdout, iter_test_cases, setup_databases as _setup_databases, setup_test_environment, teardown_databases as _teardown_databases, teardown_test_environment, ) @@ -42,7 +42,7 @@ class DebugSQLTextTestResult(unittest.TextTestResult): def __init__(self, stream, descriptions, verbosity): - self.logger = logging.getLogger('django.db.backends') + self.logger = logging.getLogger("django.db.backends") self.logger.setLevel(logging.DEBUG) self.debug_sql_stream = None super().__init__(stream, descriptions, verbosity) @@ -65,7 +65,7 @@ def addError(self, test, err): super().addError(test, err) if self.debug_sql_stream is None: # Error before tests e.g. in setUpTestData(). - sql = '' + sql = "" else: self.debug_sql_stream.seek(0) sql = self.debug_sql_stream.read() @@ -80,7 +80,11 @@ def addSubTest(self, test, subtest, err): super().addSubTest(test, subtest, err) if err is not None: self.debug_sql_stream.seek(0) - errors = self.failures if issubclass(err[0], test.failureException) else self.errors + errors = ( + self.failures + if issubclass(err[0], test.failureException) + else self.errors + ) errors[-1] = errors[-1] + (self.debug_sql_stream.read(),) def printErrorList(self, flavour, errors): @@ -119,6 +123,7 @@ class DummyList: """ Dummy list class for faking storage of results in unittest.TestResult. """ + __slots__ = () def append(self, item): @@ -152,10 +157,10 @@ def __getstate__(self): # attributes. This is possible since they aren't used after unpickling # after being sent to ParallelTestSuite. state = self.__dict__.copy() - state.pop('_stdout_buffer', None) - state.pop('_stderr_buffer', None) - state.pop('_original_stdout', None) - state.pop('_original_stderr', None) + state.pop("_stdout_buffer", None) + state.pop("_stderr_buffer", None) + state.pop("_original_stdout", None) + state.pop("_original_stderr", None) return state @property @@ -171,7 +176,8 @@ def _confirm_picklable(self, obj): pickle.loads(pickle.dumps(obj)) def _print_unpicklable_subtest(self, test, subtest, pickle_exc): - print(""" + print( + """ Subtest failed: test: {} @@ -184,7 +190,10 @@ def _print_unpicklable_subtest(self, test, subtest, pickle_exc): You should re-run this test with --parallel=1 to reproduce the failure with a cleaner failure message. -""".format(test, subtest, pickle_exc)) +""".format( + test, subtest, pickle_exc + ) + ) def check_picklable(self, test, err): # Ensure that sys.exc_info() tuples are picklable. 
This displays a @@ -197,11 +206,16 @@ def check_picklable(self, test, err): self._confirm_picklable(err) except Exception as exc: original_exc_txt = repr(err[1]) - original_exc_txt = textwrap.fill(original_exc_txt, 75, initial_indent=' ', subsequent_indent=' ') + original_exc_txt = textwrap.fill( + original_exc_txt, 75, initial_indent=" ", subsequent_indent=" " + ) pickle_exc_txt = repr(exc) - pickle_exc_txt = textwrap.fill(pickle_exc_txt, 75, initial_indent=' ', subsequent_indent=' ') + pickle_exc_txt = textwrap.fill( + pickle_exc_txt, 75, initial_indent=" ", subsequent_indent=" " + ) if tblib is None: - print(""" + print( + """ {} failed: @@ -213,9 +227,13 @@ def check_picklable(self, test, err): In order to see the traceback, you should install tblib: python -m pip install tblib -""".format(test, original_exc_txt)) +""".format( + test, original_exc_txt + ) + ) else: - print(""" + print( + """ {} failed: @@ -230,7 +248,10 @@ def check_picklable(self, test, err): You should re-run this test with the --parallel=1 option to reproduce the failure and get a correct traceback. -""".format(test, original_exc_txt, pickle_exc_txt)) +""".format( + test, original_exc_txt, pickle_exc_txt + ) + ) raise def check_subtest_picklable(self, test, subtest): @@ -242,28 +263,28 @@ def check_subtest_picklable(self, test, subtest): def startTestRun(self): super().startTestRun() - self.events.append(('startTestRun',)) + self.events.append(("startTestRun",)) def stopTestRun(self): super().stopTestRun() - self.events.append(('stopTestRun',)) + self.events.append(("stopTestRun",)) def startTest(self, test): super().startTest(test) - self.events.append(('startTest', self.test_index)) + self.events.append(("startTest", self.test_index)) def stopTest(self, test): super().stopTest(test) - self.events.append(('stopTest', self.test_index)) + self.events.append(("stopTest", self.test_index)) def addError(self, test, err): self.check_picklable(test, err) - self.events.append(('addError', self.test_index, err)) + self.events.append(("addError", self.test_index, err)) super().addError(test, err) def addFailure(self, test, err): self.check_picklable(test, err) - self.events.append(('addFailure', self.test_index, err)) + self.events.append(("addFailure", self.test_index, err)) super().addFailure(test, err) def addSubTest(self, test, subtest, err): @@ -274,15 +295,15 @@ def addSubTest(self, test, subtest, err): # check_picklable() performs the tblib check. 
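
Note: the tblib handling above exists because a plain traceback object cannot cross a process boundary by pickling. A self-contained sketch of the failure mode check_picklable() guards against; it assumes the third-party tblib package (the one the message above tells users to install) is available:

    import pickle
    import sys

    try:
        1 / 0
    except ZeroDivisionError:
        err = sys.exc_info()

    try:
        pickle.dumps(err)  # the traceback member is the problem
    except TypeError as exc:
        print(exc)  # e.g. "cannot pickle 'traceback' object"

    # With tblib installed, the whole exc_info() triple round-trips:
    from tblib import pickling_support
    pickling_support.install()
    roundtripped = pickle.loads(pickle.dumps(err))
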
            self.check_picklable(test, err)
            self.check_subtest_picklable(test, subtest)
-            self.events.append(('addSubTest', self.test_index, subtest, err))
+            self.events.append(("addSubTest", self.test_index, subtest, err))
         super().addSubTest(test, subtest, err)
 
     def addSuccess(self, test):
-        self.events.append(('addSuccess', self.test_index))
+        self.events.append(("addSuccess", self.test_index))
         super().addSuccess(test)
 
     def addSkip(self, test, reason):
-        self.events.append(('addSkip', self.test_index, reason))
+        self.events.append(("addSkip", self.test_index, reason))
         super().addSkip(test, reason)
 
     def addExpectedFailure(self, test, err):
@@ -293,23 +314,23 @@
         if tblib is None:
             err = err[0], err[1], None
         self.check_picklable(test, err)
-        self.events.append(('addExpectedFailure', self.test_index, err))
+        self.events.append(("addExpectedFailure", self.test_index, err))
         super().addExpectedFailure(test, err)
 
     def addUnexpectedSuccess(self, test):
-        self.events.append(('addUnexpectedSuccess', self.test_index))
+        self.events.append(("addUnexpectedSuccess", self.test_index))
         super().addUnexpectedSuccess(test)
 
     def wasSuccessful(self):
         """Tells whether or not this result was a success."""
-        failure_types = {'addError', 'addFailure', 'addSubTest', 'addUnexpectedSuccess'}
+        failure_types = {"addError", "addFailure", "addSubTest", "addUnexpectedSuccess"}
         return all(e[0] not in failure_types for e in self.events)
 
     def _exc_info_to_string(self, err, test):
         # Make this method no-op. It only powers the default unittest behavior
         # for recording errors, but this class pickles errors into 'events'
         # instead.
-        return ''
+        return ""
 
 
 class RemoteTestRunner:
@@ -340,9 +361,9 @@ def parallel_type(value):
     """Parse value passed to the --parallel option."""
-    # The current implementation of the parallel test runner requires
-    # multiprocessing to start subprocesses with fork().
-    if multiprocessing.get_start_method() != 'fork':
+    # The parallel test runner requires multiprocessing to start its
+    # subprocesses with fork() or spawn().
+    if multiprocessing.get_start_method() not in {"fork", "spawn"}:
         return 1
-    if value == 'auto':
+    if value == "auto":
         return multiprocessing.cpu_count()
     try:
         return int(value)
@@ -355,7 +376,13 @@
 _worker_id = 0
 
 
-def _init_worker(counter):
+def _init_worker(
+    counter,
+    process_setup=None,
+    process_setup_args=None,
+    initial_settings=None,
+    serialized_contents=None,
+):
     """
     Switch to databases dedicated to this worker.
 
@@ -369,15 +396,21 @@
     counter.value += 1
     _worker_id = counter.value
 
+    if multiprocessing.get_start_method() == "spawn":
+        if process_setup is not None:
+            process_setup(*(process_setup_args or ()))
+        setup_test_environment()
+
     for alias in connections:
         connection = connections[alias]
-        settings_dict = connection.creation.get_test_db_clone_settings(str(_worker_id))
-        # connection.settings_dict must be updated in place for changes to be
-        # reflected in django.db.connections. If the following line assigned
-        # connection.settings_dict = settings_dict, new threads would connect
-        # to the default database instead of the appropriate clone.
-        connection.settings_dict.update(settings_dict)
-        connection.close()
+        if multiprocessing.get_start_method() == "spawn":
+            # Restore initial settings in spawned processes.
+            connection.settings_dict.update(initial_settings[str(alias)])
+            if serialized_contents and alias in serialized_contents:
+                connection._test_serialized_contents = serialized_contents[str(alias)]
+        connection.creation.setup_worker_connection(_worker_id)
+        with captured_stdout():
+            call_command("check", verbosity=-1, databases=connections)
 
 
 def _run_subsuite(args):
@@ -414,11 +446,25 @@ class ParallelTestSuite(unittest.TestSuite):
     run_subsuite = _run_subsuite
     runner_class = RemoteTestRunner
 
-    def __init__(self, subsuites, processes, failfast=False, buffer=False):
+    def __init__(
+        self,
+        subsuites,
+        processes,
+        failfast=False,
+        buffer=False,
+        process_setup=None,
+        process_setup_args=None,
+    ):
         self.subsuites = subsuites
         self.processes = processes
         self.failfast = failfast
         self.buffer = buffer
+        self.process_setup = process_setup
+        self.process_setup_args = process_setup_args
+        # Populated by DiscoverRunner.setup_spawn() when the spawn start
+        # method is in use; left as None under fork.
+        self.initial_settings = None
+        self.serialized_contents = None
         super().__init__()
 
     def run(self, result):
@@ -440,7 +482,13 @@ def run(self, result):
         pool = multiprocessing.Pool(
             processes=self.processes,
             initializer=self.init_worker.__func__,
-            initargs=[counter],
+            initargs=[
+                counter,
+                self.process_setup,
+                self.process_setup_args,
+                self.initial_settings,
+                self.serialized_contents,
+            ],
         )
         args = [
             (self.runner_class, index, subsuite, self.failfast, self.buffer)
@@ -490,30 +538,30 @@ class Shuffler:
     """
 
     # This doesn't need to be cryptographically strong, so use what's fastest.
-    hash_algorithm = 'md5'
+    hash_algorithm = "md5"
 
     @classmethod
     def _hash_text(cls, text):
         h = hashlib.new(cls.hash_algorithm)
-        h.update(text.encode('utf-8'))
+        h.update(text.encode("utf-8"))
         return h.hexdigest()
 
     def __init__(self, seed=None):
         if seed is None:
             # Limit seeds to 10 digits for simpler output.
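
Note: for SQLite, the setup_worker_connection() call above (implemented in sqlite3/creation.py earlier in this patch) restores each worker's in-memory database from the on-disk clone using the sqlite3 backup API and a shared-cache URI; Connection.backup() requires Python 3.7, which is presumably why can_clone_databases is now gated on PY37. A standalone sketch of the pattern, with illustrative file and URI names rather than the exact ones Django generates:

    import sqlite3

    # An on-disk source database standing in for the clone written by
    # _clone_test_db().
    source = sqlite3.connect("default_1.sqlite3")
    source.execute("CREATE TABLE IF NOT EXISTS demo (x INTEGER)")
    source.commit()

    # A named in-memory database; it stays alive only while at least one
    # connection to the URI is open, so close order matters.
    uri = "file:memorydb_default_1?mode=memory&cache=shared"
    mem = sqlite3.connect(uri, uri=True)
    source.backup(mem)  # Python 3.7+
    source.close()

    # Any other connection in this process sees the copied schema.
    other = sqlite3.connect(uri, uri=True)
    print(other.execute("SELECT name FROM sqlite_master").fetchall())
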
- seed = random.randint(0, 10**10 - 1) - seed_source = 'generated' + seed = random.randint(0, 10 ** 10 - 1) + seed_source = "generated" else: - seed_source = 'given' + seed_source = "given" self.seed = seed self.seed_source = seed_source @property def seed_display(self): - return f'{self.seed!r} ({self.seed_source})' + return f"{self.seed!r} ({self.seed_source})" def _hash_item(self, item, key): - text = '{}{}'.format(self.seed, key(item)) + text = "{}{}".format(self.seed, key(item)) return self._hash_text(text) def shuffle(self, items, key): @@ -529,8 +577,10 @@ def shuffle(self, items, key): for item in items: hashed = self._hash_item(item, key) if hashed in hashes: - msg = 'item {!r} has same hash {!r} as item {!r}'.format( - item, hashed, hashes[hashed], + msg = "item {!r} has same hash {!r} as item {!r}".format( + item, + hashed, + hashes[hashed], ) raise RuntimeError(msg) hashes[hashed] = item @@ -546,12 +596,28 @@ class DiscoverRunner: test_loader = unittest.defaultTestLoader reorder_by = (TestCase, SimpleTestCase) - def __init__(self, pattern=None, top_level=None, verbosity=1, - interactive=True, failfast=False, keepdb=False, - reverse=False, debug_mode=False, debug_sql=False, parallel=0, - tags=None, exclude_tags=None, test_name_patterns=None, - pdb=False, buffer=False, enable_faulthandler=True, - timing=False, shuffle=False, **kwargs): + def __init__( + self, + pattern=None, + top_level=None, + verbosity=1, + interactive=True, + failfast=False, + keepdb=False, + reverse=False, + debug_mode=False, + debug_sql=False, + parallel=0, + tags=None, + exclude_tags=None, + test_name_patterns=None, + pdb=False, + buffer=False, + enable_faulthandler=True, + timing=False, + shuffle=False, + **kwargs, + ): self.pattern = pattern self.top_level = top_level @@ -572,7 +638,9 @@ def __init__(self, pattern=None, top_level=None, verbosity=1, faulthandler.enable(file=sys.__stderr__.fileno()) self.pdb = pdb if self.pdb and self.parallel > 1: - raise ValueError('You cannot use --pdb with parallel tests; pass --parallel=1 to use it.') + raise ValueError( + "You cannot use --pdb with parallel tests; pass --parallel=1 to use it." + ) self.buffer = buffer self.test_name_patterns = None self.time_keeper = TimeKeeper() if timing else NullTimeKeeper() @@ -580,7 +648,7 @@ def __init__(self, pattern=None, top_level=None, verbosity=1, # unittest does not export the _convert_select_pattern function # that converts command-line arguments to patterns. self.test_name_patterns = { - pattern if '*' in pattern else '*%s*' % pattern + pattern if "*" in pattern else "*%s*" % pattern for pattern in test_name_patterns } self.shuffle = shuffle @@ -589,77 +657,103 @@ def __init__(self, pattern=None, top_level=None, verbosity=1, @classmethod def add_arguments(cls, parser): parser.add_argument( - '-t', '--top-level-directory', dest='top_level', - help='Top level of project for unittest discovery.', + "-t", + "--top-level-directory", + dest="top_level", + help="Top level of project for unittest discovery.", ) parser.add_argument( - '-p', '--pattern', default="test*.py", - help='The test matching pattern. Defaults to test*.py.', + "-p", + "--pattern", + default="test*.py", + help="The test matching pattern. Defaults to test*.py.", ) parser.add_argument( - '--keepdb', action='store_true', - help='Preserves the test DB between runs.' + "--keepdb", action="store_true", help="Preserves the test DB between runs." 
) parser.add_argument( - '--shuffle', nargs='?', default=False, type=int, metavar='SEED', - help='Shuffles test case order.', + "--shuffle", + nargs="?", + default=False, + type=int, + metavar="SEED", + help="Shuffles test case order.", ) parser.add_argument( - '-r', '--reverse', action='store_true', - help='Reverses test case order.', + "-r", + "--reverse", + action="store_true", + help="Reverses test case order.", ) parser.add_argument( - '--debug-mode', action='store_true', - help='Sets settings.DEBUG to True.', + "--debug-mode", + action="store_true", + help="Sets settings.DEBUG to True.", ) parser.add_argument( - '-d', '--debug-sql', action='store_true', - help='Prints logged SQL queries on failure.', + "-d", + "--debug-sql", + action="store_true", + help="Prints logged SQL queries on failure.", ) try: - default_parallel = int(os.environ['DJANGO_TEST_PROCESSES']) + default_parallel = int(os.environ["DJANGO_TEST_PROCESSES"]) except KeyError: default_parallel = 0 parser.add_argument( - '--parallel', nargs='?', const='auto', default=default_parallel, - type=parallel_type, metavar='N', + "--parallel", + nargs="?", + const="auto", + default=default_parallel, + type=parallel_type, + metavar="N", help=( - 'Run tests using up to N parallel processes. Use the value ' + "Run tests using up to N parallel processes. Use the value " '"auto" to run one test process for each processor core.' ), ) parser.add_argument( - '--tag', action='append', dest='tags', - help='Run only tests with the specified tag. Can be used multiple times.', + "--tag", + action="append", + dest="tags", + help="Run only tests with the specified tag. Can be used multiple times.", ) parser.add_argument( - '--exclude-tag', action='append', dest='exclude_tags', - help='Do not run tests with the specified tag. Can be used multiple times.', + "--exclude-tag", + action="append", + dest="exclude_tags", + help="Do not run tests with the specified tag. Can be used multiple times.", ) parser.add_argument( - '--pdb', action='store_true', - help='Runs a debugger (pdb, or ipdb if installed) on error or failure.' + "--pdb", + action="store_true", + help="Runs a debugger (pdb, or ipdb if installed) on error or failure.", ) parser.add_argument( - '-b', '--buffer', action='store_true', - help='Discard output from passing tests.', + "-b", + "--buffer", + action="store_true", + help="Discard output from passing tests.", ) parser.add_argument( - '--no-faulthandler', action='store_false', dest='enable_faulthandler', - help='Disables the Python faulthandler module during tests.', + "--no-faulthandler", + action="store_false", + dest="enable_faulthandler", + help="Disables the Python faulthandler module during tests.", ) parser.add_argument( - '--timing', action='store_true', - help=( - 'Output timings, including database set up and total run time.' - ), + "--timing", + action="store_true", + help=("Output timings, including database set up and total run time."), ) parser.add_argument( - '-k', action='append', dest='test_name_patterns', + "-k", + action="append", + dest="test_name_patterns", help=( - 'Only run test methods and classes that match the pattern ' - 'or substring. Can be used multiple times. Same as ' - 'unittest -k option.' + "Only run test methods and classes that match the pattern " + "or substring. Can be used multiple times. Same as " + "unittest -k option." 
), ) @@ -690,7 +784,7 @@ def setup_shuffler(self): if self.shuffle is False: return shuffler = Shuffler(seed=self.shuffle) - self.log(f'Using shuffle seed: {shuffler.seed_display}') + self.log(f"Using shuffle seed: {shuffler.seed_display}") self._shuffler = shuffler @contextmanager @@ -722,15 +816,15 @@ def load_tests_for_label(self, label, discover_kwargs): if os.path.exists(label_as_path): assert tests is None raise RuntimeError( - f'One of the test labels is a path to a file: {label!r}, ' - f'which is not supported. Use a dotted module name or ' - f'path to a directory instead.' + f"One of the test labels is a path to a file: {label!r}, " + f"which is not supported. Use a dotted module name or " + f"path to a directory instead." ) return tests kwargs = discover_kwargs.copy() if os.path.isdir(label_as_path) and not self.top_level: - kwargs['top_level_dir'] = find_top_level(label_as_path) + kwargs["top_level_dir"] = find_top_level(label_as_path) with self.load_with_patterns(): tests = self.test_loader.discover(start_dir=label, **kwargs) @@ -740,21 +834,28 @@ def load_tests_for_label(self, label, discover_kwargs): self.test_loader._top_level_dir = None return tests - def build_suite(self, test_labels=None, extra_tests=None, **kwargs): + def build_suite( + self, + test_labels=None, + extra_tests=None, + process_setup=None, + process_setup_args=None, + **kwargs, + ): if extra_tests is not None: warnings.warn( - 'The extra_tests argument is deprecated.', + "The extra_tests argument is deprecated.", RemovedInDjango50Warning, stacklevel=2, ) - test_labels = test_labels or ['.'] + test_labels = test_labels or ["."] extra_tests = extra_tests or [] discover_kwargs = {} if self.pattern is not None: - discover_kwargs['pattern'] = self.pattern + discover_kwargs["pattern"] = self.pattern if self.top_level is not None: - discover_kwargs['top_level_dir'] = self.top_level + discover_kwargs["top_level_dir"] = self.top_level self.setup_shuffler() all_tests = [] @@ -767,12 +868,12 @@ def build_suite(self, test_labels=None, extra_tests=None, **kwargs): if self.tags or self.exclude_tags: if self.tags: self.log( - 'Including test tag(s): %s.' % ', '.join(sorted(self.tags)), + "Including test tag(s): %s." % ", ".join(sorted(self.tags)), level=logging.DEBUG, ) if self.exclude_tags: self.log( - 'Excluding test tag(s): %s.' % ', '.join(sorted(self.exclude_tags)), + "Excluding test tag(s): %s." % ", ".join(sorted(self.exclude_tags)), level=logging.DEBUG, ) all_tests = filter_tests_by_tags(all_tests, self.tags, self.exclude_tags) @@ -781,13 +882,15 @@ def build_suite(self, test_labels=None, extra_tests=None, **kwargs): # _FailedTest objects include things like test modules that couldn't be # found or that couldn't be loaded due to syntax errors. test_types = (unittest.loader._FailedTest, *self.reorder_by) - all_tests = list(reorder_tests( - all_tests, - test_types, - shuffler=self._shuffler, - reverse=self.reverse, - )) - self.log('Found %d test(s).' % len(all_tests)) + all_tests = list( + reorder_tests( + all_tests, + test_types, + shuffler=self._shuffler, + reverse=self.reverse, + ) + ) + self.log("Found %d test(s)." 
% len(all_tests)) suite = self.test_suite(all_tests) if self.parallel > 1: @@ -804,15 +907,36 @@ def build_suite(self, test_labels=None, extra_tests=None, **kwargs): processes, self.failfast, self.buffer, + process_setup, + process_setup_args, ) return suite def setup_databases(self, **kwargs): return _setup_databases( - self.verbosity, self.interactive, time_keeper=self.time_keeper, keepdb=self.keepdb, - debug_sql=self.debug_sql, parallel=self.parallel, **kwargs + self.verbosity, + self.interactive, + time_keeper=self.time_keeper, + keepdb=self.keepdb, + debug_sql=self.debug_sql, + parallel=self.parallel, + **kwargs, ) + def setup_spawn(self, suite, serialized_aliases): + if self.parallel > 1 and multiprocessing.get_start_method() == "spawn": + suite.initial_settings = { + str(alias): connections[alias].settings_dict for alias in connections + } + suite.serialized_contents = { + str(alias): connections[alias]._test_serialized_contents + for alias in connections + if alias in serialized_aliases + } + else: + suite.initial_settings = None + suite.serialized_contents = None + def get_resultclass(self): if self.debug_sql: return DebugSQLTextTestResult @@ -821,16 +945,16 @@ def get_resultclass(self): def get_test_runner_kwargs(self): return { - 'failfast': self.failfast, - 'resultclass': self.get_resultclass(), - 'verbosity': self.verbosity, - 'buffer': self.buffer, + "failfast": self.failfast, + "resultclass": self.get_resultclass(), + "verbosity": self.verbosity, + "buffer": self.buffer, } def run_checks(self, databases): # Checks are run after database creation since some checks require # database access. - call_command('check', verbosity=self.verbosity, databases=databases) + call_command("check", verbosity=self.verbosity, databases=databases) def run_suite(self, suite, **kwargs): kwargs = self.get_test_runner_kwargs() @@ -840,7 +964,7 @@ def run_suite(self, suite, **kwargs): finally: if self._shuffler is not None: seed_display = self._shuffler.seed_display - self.log(f'Used shuffle seed: {seed_display}') + self.log(f"Used shuffle seed: {seed_display}") def teardown_databases(self, old_config, **kwargs): """Destroy all the non-mirror databases.""" @@ -861,11 +985,11 @@ def suite_result(self, suite, result, **kwargs): def _get_databases(self, suite): databases = {} for test in iter_test_cases(suite): - test_databases = getattr(test, 'databases', None) - if test_databases == '__all__': + test_databases = getattr(test, "databases", None) + if test_databases == "__all__": test_databases = connections if test_databases: - serialized_rollback = getattr(test, 'serialized_rollback', False) + serialized_rollback = getattr(test, "serialized_rollback", False) databases.update( (alias, serialized_rollback or databases.get(alias, False)) for alias in test_databases @@ -877,12 +1001,20 @@ def get_databases(self, suite): unused_databases = [alias for alias in connections if alias not in databases] if unused_databases: self.log( - 'Skipping setup of unused database(s): %s.' % ', '.join(sorted(unused_databases)), + "Skipping setup of unused database(s): %s." + % ", ".join(sorted(unused_databases)), level=logging.DEBUG, ) return databases - def run_tests(self, test_labels, extra_tests=None, **kwargs): + def run_tests( + self, + test_labels, + extra_tests=None, + process_setup=None, + process_setup_args=None, + **kwargs, + ): """ Run the unit tests for all the test labels in the provided list. 
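
Note: setup_spawn() above deliberately snapshots plain data (settings dicts and serialized fixture strings) because initializer arguments reach spawned workers by pickling, and live resources cannot make the trip. A minimal illustration, with a made-up settings snapshot:

    import pickle
    import sqlite3

    # Plain data survives the trip to a spawned worker...
    snapshot = {"default": {"ENGINE": "django.db.backends.sqlite3", "NAME": "test.sqlite3"}}
    pickle.dumps(snapshot)

    # ...but live resources do not, which is why each worker rebuilds its
    # connections instead of receiving them from the parent.
    conn = sqlite3.connect(":memory:")
    try:
        pickle.dumps(conn)
    except TypeError as exc:
        print(exc)  # e.g. "cannot pickle 'sqlite3.Connection' object"
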
@@ -893,18 +1025,19 @@ def run_tests(self, test_labels, extra_tests=None, **kwargs): """ if extra_tests is not None: warnings.warn( - 'The extra_tests argument is deprecated.', + "The extra_tests argument is deprecated.", RemovedInDjango50Warning, stacklevel=2, ) self.setup_test_environment() - suite = self.build_suite(test_labels, extra_tests) + suite = self.build_suite( + test_labels, extra_tests, process_setup, process_setup_args + ) databases = self.get_databases(suite) serialized_aliases = set( - alias - for alias, serialize in databases.items() if serialize + alias for alias, serialize in databases.items() if serialize ) - with self.time_keeper.timed('Total database setup'): + with self.time_keeper.timed("Total database setup"): old_config = self.setup_databases( aliases=databases, serialized_aliases=serialized_aliases, @@ -912,13 +1045,14 @@ def run_tests(self, test_labels, extra_tests=None, **kwargs): run_failed = False try: self.run_checks(databases) + self.setup_spawn(suite, serialized_aliases) result = self.run_suite(suite) except Exception: run_failed = True raise finally: try: - with self.time_keeper.timed('Total database teardown'): + with self.time_keeper.timed("Total database teardown"): self.teardown_databases(old_config) self.teardown_test_environment() except Exception: @@ -941,7 +1075,7 @@ def try_importing(label): except (ImportError, TypeError): return (False, False) - return (True, hasattr(mod, '__path__')) + return (True, hasattr(mod, "__path__")) def find_top_level(top_level): @@ -957,7 +1091,7 @@ def find_top_level(top_level): # top-level module or as a directory path, unittest unfortunately prefers # the latter. while True: - init_py = os.path.join(top_level, '__init__.py') + init_py = os.path.join(top_level, "__init__.py") if not os.path.exists(init_py): break try_next = os.path.dirname(top_level) @@ -969,7 +1103,7 @@ def find_top_level(top_level): def _class_shuffle_key(cls): - return f'{cls.__module__}.{cls.__qualname__}' + return f"{cls.__module__}.{cls.__qualname__}" def shuffle_tests(tests, shuffler): @@ -1054,9 +1188,7 @@ def partition_suite_by_case(suite): """Partition a test suite by test case, preserving the order of tests.""" suite_class = type(suite) all_tests = iter_test_cases(suite) - return [ - suite_class(tests) for _, tests in itertools.groupby(all_tests, type) - ] + return [suite_class(tests) for _, tests in itertools.groupby(all_tests, type)] def test_match_tags(test, tags, exclude_tags): @@ -1064,11 +1196,11 @@ def test_match_tags(test, tags, exclude_tags): # Tests that couldn't load always match to prevent tests from falsely # passing due e.g. to syntax errors. return True - test_tags = set(getattr(test, 'tags', [])) - test_fn_name = getattr(test, '_testMethodName', str(test)) + test_tags = set(getattr(test, "tags", [])) + test_fn_name = getattr(test, "_testMethodName", str(test)) if hasattr(test, test_fn_name): test_fn = getattr(test, test_fn_name) - test_fn_tags = list(getattr(test_fn, 'tags', [])) + test_fn_tags = list(getattr(test_fn, "tags", [])) test_tags = test_tags.union(test_fn_tags) if tags and test_tags.isdisjoint(tags): return False diff --git a/django/utils/autoreload.py b/django/utils/autoreload.py index 15df088c4812..f80b1676cd27 100644 --- a/django/utils/autoreload.py +++ b/django/utils/autoreload.py @@ -126,7 +126,7 @@ def iter_modules_and_files(modules, extra_files): # cause issues here. 
        if not isinstance(module, ModuleType):
             continue
-        if module.__name__ == '__main__':
+        if module.__name__ in ('__main__', '__mp_main__'):
             # __main__ (usually manage.py) doesn't always have a __spec__ set.
             # Handle this by falling back to using __file__, resolved below.
             # See https://docs.python.org/reference/import.html#main-spec
diff --git a/tests/admin_checks/tests.py b/tests/admin_checks/tests.py
index 67625c7c8685..e7d07fab5d04 100644
--- a/tests/admin_checks/tests.py
+++ b/tests/admin_checks/tests.py
@@ -68,6 +68,7 @@ class SessionMiddlewareSubclass(SessionMiddleware):
     ],
 )
 class SystemChecksTestCase(SimpleTestCase):
+    databases = '__all__'
 
     def test_checks_are_performed(self):
         admin.site.register(Song, MyAdmin)
diff --git a/tests/check_framework/tests.py b/tests/check_framework/tests.py
index f43abaca12d2..77c4b986cb90 100644
--- a/tests/check_framework/tests.py
+++ b/tests/check_framework/tests.py
@@ -337,5 +337,7 @@ class ModelWithDescriptorCalledCheck(models.Model):
 
 
 class ChecksRunDuringTests(SimpleTestCase):
+    databases = '__all__'
+
     def test_registered_check_did_run(self):
         self.assertTrue(my_check.did_run)
diff --git a/tests/contenttypes_tests/test_checks.py b/tests/contenttypes_tests/test_checks.py
index 44cd3c275852..d75caf001b60 100644
--- a/tests/contenttypes_tests/test_checks.py
+++ b/tests/contenttypes_tests/test_checks.py
@@ -13,6 +13,7 @@
 
 @isolate_apps('contenttypes_tests', attr_name='apps')
 class GenericForeignKeyTests(SimpleTestCase):
+    databases = '__all__'
 
     def test_missing_content_type_field(self):
         class TaggedItem(models.Model):
diff --git a/tests/contenttypes_tests/test_management.py b/tests/contenttypes_tests/test_management.py
index 57d3757abe98..4c618de420ca 100644
--- a/tests/contenttypes_tests/test_management.py
+++ b/tests/contenttypes_tests/test_management.py
@@ -22,6 +22,8 @@ class RemoveStaleContentTypesTests(TestCase):
 
     @classmethod
     def setUpTestData(cls):
+        with captured_stdout():
+            call_command('remove_stale_contenttypes', interactive=False, include_stale_apps=True, verbosity=2)
         cls.before_count = ContentType.objects.count()
         cls.content_type = ContentType.objects.create(app_label='contenttypes_tests', model='Fake')
diff --git a/tests/postgres_tests/__init__.py b/tests/postgres_tests/__init__.py
index 2b84fc25db0d..d9abe09f9c48 100644
--- a/tests/postgres_tests/__init__.py
+++ b/tests/postgres_tests/__init__.py
@@ -21,3 +21,9 @@ class PostgreSQLTestCase(TestCase):
 @modify_settings(INSTALLED_APPS={'append': 'django.contrib.postgres'})
 class PostgreSQLWidgetTestCase(WidgetTest, PostgreSQLSimpleTestCase):
     pass
+
+
+@unittest.skipUnless(connection.vendor == 'postgresql', 'PostgreSQL specific tests')
+@modify_settings(INSTALLED_APPS={'append': 'django.contrib.postgres'})
+class PostgreSQLHStoreTestCase(PostgreSQLTestCase):
+    pass
diff --git a/tests/postgres_tests/test_bulk_update.py b/tests/postgres_tests/test_bulk_update.py
index da5aee0f7059..64f0a2b2e1e4 100644
--- a/tests/postgres_tests/test_bulk_update.py
+++ b/tests/postgres_tests/test_bulk_update.py
@@ -1,6 +1,6 @@
 from datetime import date
 
-from . import PostgreSQLTestCase
+from . import PostgreSQLHStoreTestCase
 from .models import (
     HStoreModel, IntegerArrayModel, NestedIntegerArrayModel,
     NullableIntegerArrayModel, OtherTypesArrayModel, RangesModel,
@@ -12,7 +12,7 @@
     pass  # psycopg2 isn't installed.
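
Note: the `databases = '__all__'` additions above are needed because each worker now runs the check command against every configured connection, and SimpleTestCase refuses database access for undeclared aliases. A rough sketch of the guard being lifted; class names are illustrative:

    from django.db import connection
    from django.test import SimpleTestCase

    class WithoutDeclaration(SimpleTestCase):
        def test_query_is_blocked(self):
            # Queries to 'default' raise unless the alias is declared in
            # `databases` (the exact exception class varies by version).
            with self.assertRaises(Exception):
                with connection.cursor() as cursor:
                    cursor.execute('SELECT 1')

    class WithDeclaration(SimpleTestCase):
        databases = '__all__'

        def test_query_is_allowed(self):
            with connection.cursor() as cursor:
                cursor.execute('SELECT 1')
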
 
 
-class BulkSaveTests(PostgreSQLTestCase):
+class BulkSaveTests(PostgreSQLHStoreTestCase):
     def test_bulk_update(self):
         test_data = [
             (IntegerArrayModel, 'field', [], [1, 2, 3]),
diff --git a/tests/postgres_tests/test_hstore.py b/tests/postgres_tests/test_hstore.py
index 0c01129e180d..d55f70e84f62 100644
--- a/tests/postgres_tests/test_hstore.py
+++ b/tests/postgres_tests/test_hstore.py
@@ -7,7 +7,7 @@
 from django.forms import Form
 from django.test.utils import CaptureQueriesContext, isolate_apps
 
-from . import PostgreSQLSimpleTestCase, PostgreSQLTestCase
+from . import PostgreSQLHStoreTestCase, PostgreSQLSimpleTestCase
 from .models import HStoreModel, PostgreSQLModel
 
 try:
@@ -19,7 +19,7 @@
     pass
 
 
-class SimpleTests(PostgreSQLTestCase):
+class SimpleTests(PostgreSQLHStoreTestCase):
     def test_save_load_success(self):
         value = {'a': 'b'}
         instance = HStoreModel(field=value)
@@ -68,7 +68,7 @@ def test_array_field(self):
         self.assertEqual(instance.array_field, expected_value)
 
 
-class TestQuerying(PostgreSQLTestCase):
+class TestQuerying(PostgreSQLHStoreTestCase):
 
     @classmethod
     def setUpTestData(cls):
diff --git a/tests/runtests.py b/tests/runtests.py
index dfbc70818c3e..bc3de41135bb 100755
--- a/tests/runtests.py
+++ b/tests/runtests.py
@@ -16,7 +16,7 @@
     import django
 except ImportError as e:
     raise RuntimeError(
-        'Django module not found, reference tests/README.rst for instructions.'
+        "Django module not found, reference tests/README.rst for instructions."
     ) from e
 else:
     from django.apps import apps
@@ -37,21 +37,23 @@
     pass
 else:
     # Ignore informational warnings from QuerySet.explain().
-    warnings.filterwarnings('ignore', r'\(1003, *', category=MySQLdb.Warning)
+    warnings.filterwarnings("ignore", r"\(1003, *", category=MySQLdb.Warning)
 
 # Make deprecation warnings errors to ensure no usage of deprecated features.
-warnings.simplefilter('error', RemovedInDjango50Warning)
-warnings.simplefilter('error', RemovedInDjango41Warning)
+warnings.simplefilter("error", RemovedInDjango50Warning)
+warnings.simplefilter("error", RemovedInDjango41Warning)
 # Make resource and runtime warning errors to ensure no usage of error prone
 # patterns.
 warnings.simplefilter("error", ResourceWarning)
 warnings.simplefilter("error", RuntimeWarning)
 # Ignore known warnings in test dependencies.
-warnings.filterwarnings("ignore", "'U' mode is deprecated", DeprecationWarning, module='docutils.io')
+warnings.filterwarnings(
+    "ignore", "'U' mode is deprecated", DeprecationWarning, module="docutils.io"
+)
 # RemovedInDjango41Warning: Ignore MemcachedCache deprecation warning.
 warnings.filterwarnings(
-    'ignore',
-    'MemcachedCache is deprecated',
+    "ignore",
+    "MemcachedCache is deprecated",
     category=RemovedInDjango41Warning,
 )
@@ -64,13 +66,13 @@
 RUNTESTS_DIR = os.path.abspath(os.path.dirname(__file__))
-TEMPLATE_DIR = os.path.join(RUNTESTS_DIR, 'templates')
+TEMPLATE_DIR = os.path.join(RUNTESTS_DIR, "templates")
 
 # Create a specific subdirectory for the duration of the test suite.
-TMPDIR = tempfile.mkdtemp(prefix='django_')
+TMPDIR = tempfile.mkdtemp(prefix="django_")
 # Set the TMPDIR environment variable in addition to tempfile.tempdir
 # so that children processes inherit it.
-tempfile.tempdir = os.environ['TMPDIR'] = TMPDIR
+tempfile.tempdir = os.environ["TMPDIR"] = TMPDIR
 
 # Removing the temporary TMPDIR.
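
Note: runtests.py above sets both tempfile.tempdir and the TMPDIR environment variable. Under fork a child inherits the parent's interpreter state anyway, but a spawned worker starts a fresh interpreter and only sees the environment, so the env var is what keeps its temporary files inside the per-run directory. A small demonstration of that inheritance, assuming the spawn start method is available:

    import multiprocessing
    import os
    import tempfile

    def report():
        # A spawned child starts with fresh interpreter state, but it does
        # inherit os.environ, so TMPDIR steers tempfile here as well.
        print(tempfile.gettempdir())

    if __name__ == "__main__":
        tmp = tempfile.mkdtemp(prefix="django_")
        tempfile.tempdir = os.environ["TMPDIR"] = tmp
        ctx = multiprocessing.get_context("spawn")
        p = ctx.Process(target=report)
        p.start()
        p.join()
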
atexit.register(shutil.rmtree, TMPDIR) @@ -79,35 +81,35 @@ # This is a dict mapping RUNTESTS_DIR subdirectory to subdirectories of that # directory to skip when searching for test modules. SUBDIRS_TO_SKIP = { - '': {'import_error_package', 'test_runner_apps'}, - 'gis_tests': {'data'}, + "": {"import_error_package", "test_runner_apps"}, + "gis_tests": {"data"}, } ALWAYS_INSTALLED_APPS = [ - 'django.contrib.contenttypes', - 'django.contrib.auth', - 'django.contrib.sites', - 'django.contrib.sessions', - 'django.contrib.messages', - 'django.contrib.admin.apps.SimpleAdminConfig', - 'django.contrib.staticfiles', + "django.contrib.contenttypes", + "django.contrib.auth", + "django.contrib.sites", + "django.contrib.sessions", + "django.contrib.messages", + "django.contrib.admin.apps.SimpleAdminConfig", + "django.contrib.staticfiles", ] ALWAYS_MIDDLEWARE = [ - 'django.contrib.sessions.middleware.SessionMiddleware', - 'django.middleware.common.CommonMiddleware', - 'django.middleware.csrf.CsrfViewMiddleware', - 'django.contrib.auth.middleware.AuthenticationMiddleware', - 'django.contrib.messages.middleware.MessageMiddleware', + "django.contrib.sessions.middleware.SessionMiddleware", + "django.middleware.common.CommonMiddleware", + "django.middleware.csrf.CsrfViewMiddleware", + "django.contrib.auth.middleware.AuthenticationMiddleware", + "django.contrib.messages.middleware.MessageMiddleware", ] # Need to add the associated contrib app to INSTALLED_APPS in some cases to # avoid "RuntimeError: Model class X doesn't declare an explicit app_label # and isn't in an application in INSTALLED_APPS." CONTRIB_TESTS_TO_APPS = { - 'deprecation': ['django.contrib.flatpages', 'django.contrib.redirects'], - 'flatpages_tests': ['django.contrib.flatpages'], - 'redirects_tests': ['django.contrib.redirects'], + "deprecation": ["django.contrib.flatpages", "django.contrib.redirects"], + "flatpages_tests": ["django.contrib.flatpages"], + "redirects_tests": ["django.contrib.redirects"], } @@ -118,12 +120,12 @@ def get_test_modules(gis_enabled): The yielded names have either one dotted part like "test_runner" or, in the case of GIS tests, two dotted parts like "gis_tests.gdal_tests". """ - discovery_dirs = [''] + discovery_dirs = [""] if gis_enabled: # GIS tests are in nested apps - discovery_dirs.append('gis_tests') + discovery_dirs.append("gis_tests") else: - SUBDIRS_TO_SKIP[''].add('gis_tests') + SUBDIRS_TO_SKIP[""].add("gis_tests") for dirname in discovery_dirs: dirpath = os.path.join(RUNTESTS_DIR, dirname) @@ -131,15 +133,15 @@ def get_test_modules(gis_enabled): with os.scandir(dirpath) as entries: for f in entries: if ( - '.' in f.name or - os.path.basename(f.name) in subdirs_to_skip or - f.is_file() or - not os.path.exists(os.path.join(f.path, '__init__.py')) + "." in f.name + or os.path.basename(f.name) in subdirs_to_skip + or f.is_file() + or not os.path.exists(os.path.join(f.path, "__init__.py")) ): continue test_module = f.name if dirname: - test_module = dirname + '.' + test_module + test_module = dirname + "." + test_module yield test_module @@ -148,12 +150,12 @@ def get_label_module(label): path = Path(label) if len(path.parts) == 1: # Interpret the label as a dotted module name. - return label.split('.')[0] + return label.split(".")[0] # Otherwise, interpret the label as a path. Check existence first to # provide a better error message than relative_to() if it doesn't exist. 
if not path.exists(): - raise RuntimeError(f'Test label path {label} does not exist') + raise RuntimeError(f"Test label path {label} does not exist") path = path.resolve() rel_path = path.relative_to(RUNTESTS_DIR) return rel_path.parts[0] @@ -170,20 +172,20 @@ def get_filtered_test_modules(start_at, start_after, gis_enabled, test_labels=No # It would be nice to put this validation earlier but it must come after # django.setup() so that connection.features.gis_enabled can be accessed. - if 'gis_tests' in label_modules and not gis_enabled: - print('Aborting: A GIS database backend is required to run gis_tests.') + if "gis_tests" in label_modules and not gis_enabled: + print("Aborting: A GIS database backend is required to run gis_tests.") sys.exit(1) def _module_match_label(module_name, label): # Exact or ancestor match. - return module_name == label or module_name.startswith(label + '.') + return module_name == label or module_name.startswith(label + ".") start_label = start_at or start_after for test_module in get_test_modules(gis_enabled): if start_label: if not _module_match_label(test_module, start_label): continue - start_label = '' + start_label = "" if not start_at: assert start_after # Skip the current one before starting. @@ -191,58 +193,60 @@ def _module_match_label(module_name, label): # If the module (or an ancestor) was named on the command line, or # no modules were named (i.e., run all), include the test module. if not test_labels or any( - _module_match_label(test_module, label_module) for - label_module in label_modules + _module_match_label(test_module, label_module) + for label_module in label_modules ): yield test_module def setup_collect_tests(start_at, start_after, test_labels=None): state = { - 'INSTALLED_APPS': settings.INSTALLED_APPS, - 'ROOT_URLCONF': getattr(settings, "ROOT_URLCONF", ""), - 'TEMPLATES': settings.TEMPLATES, - 'LANGUAGE_CODE': settings.LANGUAGE_CODE, - 'STATIC_URL': settings.STATIC_URL, - 'STATIC_ROOT': settings.STATIC_ROOT, - 'MIDDLEWARE': settings.MIDDLEWARE, + "INSTALLED_APPS": settings.INSTALLED_APPS, + "ROOT_URLCONF": getattr(settings, "ROOT_URLCONF", ""), + "TEMPLATES": settings.TEMPLATES, + "LANGUAGE_CODE": settings.LANGUAGE_CODE, + "STATIC_URL": settings.STATIC_URL, + "STATIC_ROOT": settings.STATIC_ROOT, + "MIDDLEWARE": settings.MIDDLEWARE, } # Redirect some settings for the duration of these tests. 
settings.INSTALLED_APPS = ALWAYS_INSTALLED_APPS - settings.ROOT_URLCONF = 'urls' - settings.STATIC_URL = 'static/' - settings.STATIC_ROOT = os.path.join(TMPDIR, 'static') - settings.TEMPLATES = [{ - 'BACKEND': 'django.template.backends.django.DjangoTemplates', - 'DIRS': [TEMPLATE_DIR], - 'APP_DIRS': True, - 'OPTIONS': { - 'context_processors': [ - 'django.template.context_processors.debug', - 'django.template.context_processors.request', - 'django.contrib.auth.context_processors.auth', - 'django.contrib.messages.context_processors.messages', - ], - }, - }] - settings.LANGUAGE_CODE = 'en' + settings.ROOT_URLCONF = "urls" + settings.STATIC_URL = "static/" + settings.STATIC_ROOT = os.path.join(TMPDIR, "static") + settings.TEMPLATES = [ + { + "BACKEND": "django.template.backends.django.DjangoTemplates", + "DIRS": [TEMPLATE_DIR], + "APP_DIRS": True, + "OPTIONS": { + "context_processors": [ + "django.template.context_processors.debug", + "django.template.context_processors.request", + "django.contrib.auth.context_processors.auth", + "django.contrib.messages.context_processors.messages", + ], + }, + } + ] + settings.LANGUAGE_CODE = "en" settings.SITE_ID = 1 settings.MIDDLEWARE = ALWAYS_MIDDLEWARE settings.MIGRATION_MODULES = { # This lets us skip creating migrations for the test models as many of # them depend on one of the following contrib applications. - 'auth': None, - 'contenttypes': None, - 'sessions': None, + "auth": None, + "contenttypes": None, + "sessions": None, } log_config = copy.deepcopy(DEFAULT_LOGGING) # Filter out non-error logging so we don't have to capture it in lots of # tests. - log_config['loggers']['django']['level'] = 'ERROR' + log_config["loggers"]["django"]["level"] = "ERROR" settings.LOGGING = log_config settings.SILENCED_SYSTEM_CHECKS = [ - 'fields.W342', # ForeignKey(unique=True) -> OneToOneField + "fields.W342", # ForeignKey(unique=True) -> OneToOneField ] # Load all the ALWAYS_INSTALLED_APPS. @@ -253,9 +257,14 @@ def setup_collect_tests(start_at, start_after, test_labels=None): # backends (e.g. PostGIS). gis_enabled = connection.features.gis_enabled - test_modules = list(get_filtered_test_modules( - start_at, start_after, gis_enabled, test_labels=test_labels, - )) + test_modules = list( + get_filtered_test_modules( + start_at, + start_after, + gis_enabled, + test_labels=test_labels, + ) + ) return test_modules, state @@ -280,18 +289,20 @@ def get_apps_to_install(test_modules): # Add contrib.gis to INSTALLED_APPS if needed (rather than requiring # @override_settings(INSTALLED_APPS=...) on all test cases. if connection.features.gis_enabled: - yield 'django.contrib.gis' + yield "django.contrib.gis" def setup_run_tests(verbosity, start_at, start_after, test_labels=None): - test_modules, state = setup_collect_tests(start_at, start_after, test_labels=test_labels) + test_modules, state = setup_collect_tests( + start_at, start_after, test_labels=test_labels + ) installed_apps = set(get_installed()) for app in get_apps_to_install(test_modules): if app in installed_apps: continue if verbosity >= 2: - print(f'Importing application {app}') + print(f"Importing application {app}") settings.INSTALLED_APPS.append(app) installed_apps.add(app) @@ -300,15 +311,15 @@ def setup_run_tests(verbosity, start_at, start_after, test_labels=None): # Force declaring available_apps in TransactionTestCase for faster tests. def no_available_apps(self): raise Exception( - 'Please define available_apps in TransactionTestCase and its ' - 'subclasses.' 
+ "Please define available_apps in TransactionTestCase and its " "subclasses." ) + TransactionTestCase.available_apps = property(no_available_apps) TestCase.available_apps = None # Set an environment variable that other code may consult to see if # Django's own test suite is running. - os.environ['RUNNING_DJANGOS_TEST_SUITE'] = 'true' + os.environ["RUNNING_DJANGOS_TEST_SUITE"] = "true" test_labels = test_labels or test_modules return test_labels, state @@ -321,15 +332,16 @@ def teardown_run_tests(state): # atexit.register(shutil.rmtree, TMPDIR) handler. Prevents # FileNotFoundError at the end of a test run (#27890). from multiprocessing.util import _finalizer_registry + _finalizer_registry.pop((-100, 0), None) - del os.environ['RUNNING_DJANGOS_TEST_SUITE'] + del os.environ["RUNNING_DJANGOS_TEST_SUITE"] def actual_test_processes(parallel): if parallel == 0: # This doesn't work before django.setup() on some databases. if all(conn.features.can_clone_databases for conn in connections.all()): - return parallel_type('auto') + return parallel_type("auto") else: return 1 else: @@ -340,32 +352,53 @@ class ActionSelenium(argparse.Action): """ Validate the comma-separated list of requested browsers. """ + def __call__(self, parser, namespace, values, option_string=None): - browsers = values.split(',') + browsers = values.split(",") for browser in browsers: try: SeleniumTestCaseBase.import_webdriver(browser) except ImportError: - raise argparse.ArgumentError(self, "Selenium browser specification '%s' is not valid." % browser) + raise argparse.ArgumentError( + self, "Selenium browser specification '%s' is not valid." % browser + ) setattr(namespace, self.dest, browsers) -def django_tests(verbosity, interactive, failfast, keepdb, reverse, - test_labels, debug_sql, parallel, tags, exclude_tags, - test_name_patterns, start_at, start_after, pdb, buffer, - timing, shuffle): +def django_tests( + verbosity, + interactive, + failfast, + keepdb, + reverse, + test_labels, + debug_sql, + parallel, + tags, + exclude_tags, + test_name_patterns, + start_at, + start_after, + pdb, + buffer, + timing, + shuffle, +): actual_parallel = actual_test_processes(parallel) if verbosity >= 1: - msg = "Testing against Django installed in '%s'" % os.path.dirname(django.__file__) + msg = "Testing against Django installed in '%s'" % os.path.dirname( + django.__file__ + ) if actual_parallel > 1: msg += " with up to %d processes" % actual_parallel print(msg) test_labels, state = setup_run_tests(verbosity, start_at, start_after, test_labels) + process_setup_args = (verbosity, test_labels, parallel, start_at, start_after) # Run the test suite, including the extra validation tests. 
- if not hasattr(settings, 'TEST_RUNNER'): - settings.TEST_RUNNER = 'django.test.runner.DiscoverRunner' + if not hasattr(settings, "TEST_RUNNER"): + settings.TEST_RUNNER = "django.test.runner.DiscoverRunner" TestRunner = get_runner(settings) test_runner = TestRunner( verbosity=verbosity, @@ -383,7 +416,11 @@ def django_tests(verbosity, interactive, failfast, keepdb, reverse, timing=timing, shuffle=shuffle, ) - failures = test_runner.run_tests(test_labels) + failures = test_runner.run_tests( + test_labels, + process_setup=setup_run_tests, + process_setup_args=process_setup_args, + ) teardown_run_tests(state) return failures @@ -395,24 +432,22 @@ def collect_test_modules(start_at, start_after): def get_subprocess_args(options): - subprocess_args = [ - sys.executable, __file__, '--settings=%s' % options.settings - ] + subprocess_args = [sys.executable, __file__, "--settings=%s" % options.settings] if options.failfast: - subprocess_args.append('--failfast') + subprocess_args.append("--failfast") if options.verbosity: - subprocess_args.append('--verbosity=%s' % options.verbosity) + subprocess_args.append("--verbosity=%s" % options.verbosity) if not options.interactive: - subprocess_args.append('--noinput') + subprocess_args.append("--noinput") if options.tags: - subprocess_args.append('--tag=%s' % options.tags) + subprocess_args.append("--tag=%s" % options.tags) if options.exclude_tags: - subprocess_args.append('--exclude_tag=%s' % options.exclude_tags) + subprocess_args.append("--exclude_tag=%s" % options.exclude_tags) if options.shuffle is not False: if options.shuffle is None: - subprocess_args.append('--shuffle') + subprocess_args.append("--shuffle") else: - subprocess_args.append('--shuffle=%s' % options.shuffle) + subprocess_args.append("--shuffle=%s" % options.shuffle) return subprocess_args @@ -420,11 +455,11 @@ def bisect_tests(bisection_label, options, test_labels, start_at, start_after): if not test_labels: test_labels = collect_test_modules(start_at, start_after) - print('***** Bisecting test suite: %s' % ' '.join(test_labels)) + print("***** Bisecting test suite: %s" % " ".join(test_labels)) # Make sure the bisection point isn't in the test list # Also remove tests that need to be run in specific combinations - for label in [bisection_label, 'model_inheritance_same_model_name']: + for label in [bisection_label, "model_inheritance_same_model_name"]: try: test_labels.remove(label) except ValueError: @@ -437,13 +472,13 @@ def bisect_tests(bisection_label, options, test_labels, start_at, start_after): midpoint = len(test_labels) // 2 test_labels_a = test_labels[:midpoint] + [bisection_label] test_labels_b = test_labels[midpoint:] + [bisection_label] - print('***** Pass %da: Running the first half of the test suite' % iteration) - print('***** Test labels: %s' % ' '.join(test_labels_a)) + print("***** Pass %da: Running the first half of the test suite" % iteration) + print("***** Test labels: %s" % " ".join(test_labels_a)) failures_a = subprocess.run(subprocess_args + test_labels_a) - print('***** Pass %db: Running the second half of the test suite' % iteration) - print('***** Test labels: %s' % ' '.join(test_labels_b)) - print('') + print("***** Pass %db: Running the second half of the test suite" % iteration) + print("***** Test labels: %s" % " ".join(test_labels_b)) + print("") failures_b = subprocess.run(subprocess_args + test_labels_b) if failures_a.returncode and not failures_b.returncode: @@ -469,11 +504,11 @@ def paired_tests(paired_test, options, test_labels, start_at, 
start_after):
     if not test_labels:
         test_labels = collect_test_modules(start_at, start_after)
 
-    print('***** Trying paired execution')
+    print("***** Trying paired execution")
 
     # Make sure the constant member of the pair isn't in the test list
     # Also remove tests that need to be run in specific combinations
-    for label in [paired_test, 'model_inheritance_same_model_name']:
+    for label in [paired_test, "model_inheritance_same_model_name"]:
         try:
             test_labels.remove(label)
         except ValueError:
@@ -482,133 +517,169 @@ def paired_tests(paired_test, options, test_labels, start_at, start_after):
     subprocess_args = get_subprocess_args(options)
 
     for i, label in enumerate(test_labels):
-        print('***** %d of %d: Check test pairing with %s' % (
-            i + 1, len(test_labels), label))
+        print(
+            "***** %d of %d: Check test pairing with %s"
+            % (i + 1, len(test_labels), label)
+        )
         failures = subprocess.call(subprocess_args + [label, paired_test])
         if failures:
-            print('***** Found problem pair with %s' % label)
+            print("***** Found problem pair with %s" % label)
             return
 
-    print('***** No problem pair found')
+    print("***** No problem pair found")
 
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Run the Django test suite.")
     parser.add_argument(
-        'modules', nargs='*', metavar='module',
+        "modules",
+        nargs="*",
+        metavar="module",
         help='Optional path(s) to test modules; e.g. "i18n" or '
-             '"i18n.tests.TranslationTests.test_lazy_objects".',
+        '"i18n.tests.TranslationTests.test_lazy_objects".',
     )
     parser.add_argument(
-        '-v', '--verbosity', default=1, type=int, choices=[0, 1, 2, 3],
-        help='Verbosity level; 0=minimal output, 1=normal output, 2=all output',
+        "-v",
+        "--verbosity",
+        default=1,
+        type=int,
+        choices=[0, 1, 2, 3],
+        help="Verbosity level; 0=minimal output, 1=normal output, 2=all output",
    )
     parser.add_argument(
-        '--noinput', action='store_false', dest='interactive',
-        help='Tells Django to NOT prompt the user for input of any kind.',
+        "--noinput",
+        action="store_false",
+        dest="interactive",
+        help="Tells Django to NOT prompt the user for input of any kind.",
     )
     parser.add_argument(
-        '--failfast', action='store_true',
-        help='Tells Django to stop running the test suite after first failed test.',
+        "--failfast",
+        action="store_true",
+        help="Tells Django to stop running the test suite after first failed test.",
     )
     parser.add_argument(
-        '--keepdb', action='store_true',
-        help='Tells Django to preserve the test database between runs.',
+        "--keepdb",
+        action="store_true",
+        help="Tells Django to preserve the test database between runs.",
     )
     parser.add_argument(
-        '--settings',
+        "--settings",
         help='Python path to settings module, e.g. "myproject.settings". If '
-             'this isn\'t provided, either the DJANGO_SETTINGS_MODULE '
-             'environment variable or "test_sqlite" will be used.',
+        "this isn't provided, either the DJANGO_SETTINGS_MODULE "
+        'environment variable or "test_sqlite" will be used.',
     )
     parser.add_argument(
-        '--bisect',
-        help='Bisect the test suite to discover a test that causes a test '
-             'failure when combined with the named test.',
+        "--bisect",
+        help="Bisect the test suite to discover a test that causes a test "
+        "failure when combined with the named test.",
    )
     parser.add_argument(
-        '--pair',
-        help='Run the test suite in pairs with the named test to find problem pairs.',
+        "--pair",
+        help="Run the test suite in pairs with the named test to find problem pairs.",
     )
     parser.add_argument(
-        '--shuffle', nargs='?', default=False, type=int, metavar='SEED',
+        "--shuffle",
+        nargs="?",
+        default=False,
+        type=int,
+        metavar="SEED",
         help=(
-            'Shuffle the order of test cases to help check that tests are '
-            'properly isolated.'
+            "Shuffle the order of test cases to help check that tests are "
+            "properly isolated."
         ),
     )
     parser.add_argument(
-        '--reverse', action='store_true',
-        help='Sort test suites and test cases in opposite order to debug '
-             'test side effects not apparent with normal execution lineup.',
+        "--reverse",
+        action="store_true",
+        help="Sort test suites and test cases in opposite order to debug "
+        "test side effects not apparent with normal execution lineup.",
     )
     parser.add_argument(
-        '--selenium', action=ActionSelenium, metavar='BROWSERS',
-        help='A comma-separated list of browsers to run the Selenium tests against.',
+        "--selenium",
+        action=ActionSelenium,
+        metavar="BROWSERS",
+        help="A comma-separated list of browsers to run the Selenium tests against.",
    )
     parser.add_argument(
-        '--headless', action='store_true',
-        help='Run selenium tests in headless mode, if the browser supports the option.',
+        "--headless",
+        action="store_true",
+        help="Run selenium tests in headless mode, if the browser supports the option.",
    )
     parser.add_argument(
-        '--selenium-hub',
-        help='A URL for a selenium hub instance to use in combination with --selenium.',
+        "--selenium-hub",
+        help="A URL for a selenium hub instance to use in combination with --selenium.",
    )
     parser.add_argument(
-        '--external-host', default=socket.gethostname(),
-        help='The external host that can be reached by the selenium hub instance when running Selenium '
-             'tests via Selenium Hub.',
+        "--external-host",
+        default=socket.gethostname(),
+        help="The external host that can be reached by the selenium hub instance when running Selenium "
+        "tests via Selenium Hub.",
    )
     parser.add_argument(
-        '--debug-sql', action='store_true',
-        help='Turn on the SQL query logger within tests.',
+        "--debug-sql",
+        action="store_true",
+        help="Turn on the SQL query logger within tests.",
    )
     try:
-        default_parallel = int(os.environ['DJANGO_TEST_PROCESSES'])
+        default_parallel = int(os.environ["DJANGO_TEST_PROCESSES"])
     except KeyError:
         # actual_test_processes() converts this to "auto" later on.
         default_parallel = 0
     parser.add_argument(
-        '--parallel', nargs='?', const='auto', default=default_parallel,
-        type=parallel_type, metavar='N',
+        "--parallel",
+        nargs="?",
+        const="auto",
+        default=default_parallel,
+        type=parallel_type,
+        metavar="N",
         help=(
             'Run tests using up to N parallel processes. Use the value "auto" '
-            'to run one test process for each processor core.'
+            "to run one test process for each processor core."
         ),
     )
     parser.add_argument(
-        '--tag', dest='tags', action='append',
-        help='Run only tests with the specified tags. Can be used multiple times.',
+        "--tag",
+        dest="tags",
+        action="append",
+        help="Run only tests with the specified tags. Can be used multiple times.",
    )
     parser.add_argument(
-        '--exclude-tag', dest='exclude_tags', action='append',
-        help='Do not run tests with the specified tag. Can be used multiple times.',
+        "--exclude-tag",
+        dest="exclude_tags",
+        action="append",
+        help="Do not run tests with the specified tag. Can be used multiple times.",
    )
     parser.add_argument(
-        '--start-after', dest='start_after',
-        help='Run tests starting after the specified top-level module.',
+        "--start-after",
+        dest="start_after",
+        help="Run tests starting after the specified top-level module.",
    )
     parser.add_argument(
-        '--start-at', dest='start_at',
-        help='Run tests starting at the specified top-level module.',
+        "--start-at",
+        dest="start_at",
+        help="Run tests starting at the specified top-level module.",
    )
     parser.add_argument(
-        '--pdb', action='store_true',
-        help='Runs the PDB debugger on error or failure.'
+        "--pdb", action="store_true", help="Runs the PDB debugger on error or failure."
    )
     parser.add_argument(
-        '-b', '--buffer', action='store_true',
-        help='Discard output of passing tests.',
+        "-b",
+        "--buffer",
+        action="store_true",
+        help="Discard output of passing tests.",
    )
     parser.add_argument(
-        '--timing', action='store_true',
-        help='Output timings, including database set up and total run time.',
+        "--timing",
+        action="store_true",
+        help="Output timings, including database set up and total run time.",
    )
     parser.add_argument(
-        '-k', dest='test_name_patterns', action='append',
+        "-k",
+        dest="test_name_patterns",
+        action="append",
         help=(
-            'Only run test methods and classes matching test name pattern. '
-            'Same as unittest -k option. Can be used multiple times.'
+            "Only run test methods and classes matching test name pattern. "
+            "Same as unittest -k option. Can be used multiple times."
         ),
     )
@@ -616,36 +687,49 @@ def paired_tests(paired_test, options, test_labels, start_at, start_after):
     using_selenium_hub = options.selenium and options.selenium_hub
     if options.selenium_hub and not options.selenium:
-        parser.error('--selenium-hub and --external-host require --selenium to be used.')
+        parser.error(
+            "--selenium-hub and --external-host require --selenium to be used."
+        )
     if using_selenium_hub and not options.external_host:
-        parser.error('--selenium-hub and --external-host must be used together.')
+        parser.error("--selenium-hub and --external-host must be used together.")
 
     # Allow including a trailing slash on app_labels for tab completion convenience
     options.modules = [os.path.normpath(labels) for labels in options.modules]
 
-    mutually_exclusive_options = [options.start_at, options.start_after, options.modules]
-    enabled_module_options = [bool(option) for option in mutually_exclusive_options].count(True)
+    mutually_exclusive_options = [
+        options.start_at,
+        options.start_after,
+        options.modules,
+    ]
+    enabled_module_options = [
+        bool(option) for option in mutually_exclusive_options
+    ].count(True)
     if enabled_module_options > 1:
-        print('Aborting: --start-at, --start-after, and test labels are mutually exclusive.')
+        print(
+            "Aborting: --start-at, --start-after, and test labels are mutually exclusive."
+        )
         sys.exit(1)
-    for opt_name in ['start_at', 'start_after']:
+    for opt_name in ["start_at", "start_after"]:
         opt_val = getattr(options, opt_name)
         if opt_val:
-            if '.' in opt_val:
-                print('Aborting: --%s must be a top-level module.' % opt_name.replace('_', '-'))
+            if "." in opt_val:
+                print(
+                    "Aborting: --%s must be a top-level module."
+                    % opt_name.replace("_", "-")
+                )
                 sys.exit(1)
             setattr(options, opt_name, os.path.normpath(opt_val))
     if options.settings:
-        os.environ['DJANGO_SETTINGS_MODULE'] = options.settings
+        os.environ["DJANGO_SETTINGS_MODULE"] = options.settings
     else:
-        os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'test_sqlite')
-        options.settings = os.environ['DJANGO_SETTINGS_MODULE']
+        os.environ.setdefault("DJANGO_SETTINGS_MODULE", "test_sqlite")
+        options.settings = os.environ["DJANGO_SETTINGS_MODULE"]
     if options.selenium:
         if not options.tags:
-            options.tags = ['selenium']
-        elif 'selenium' not in options.tags:
-            options.tags.append('selenium')
+            options.tags = ["selenium"]
+        elif "selenium" not in options.tags:
+            options.tags.append("selenium")
     if options.selenium_hub:
         SeleniumTestCaseBase.selenium_hub = options.selenium_hub
         SeleniumTestCaseBase.external_host = options.external_host
@@ -654,25 +738,41 @@ def paired_tests(paired_test, options, test_labels, start_at, start_after):
     if options.bisect:
         bisect_tests(
-            options.bisect, options, options.modules, options.start_at,
+            options.bisect,
+            options,
+            options.modules,
+            options.start_at,
             options.start_after,
         )
     elif options.pair:
         paired_tests(
-            options.pair, options, options.modules, options.start_at,
+            options.pair,
+            options,
+            options.modules,
+            options.start_at,
             options.start_after,
         )
     else:
         time_keeper = TimeKeeper() if options.timing else NullTimeKeeper()
-        with time_keeper.timed('Total run'):
+        with time_keeper.timed("Total run"):
             failures = django_tests(
-                options.verbosity, options.interactive, options.failfast,
-                options.keepdb, options.reverse, options.modules,
-                options.debug_sql, options.parallel, options.tags,
+                options.verbosity,
+                options.interactive,
+                options.failfast,
+                options.keepdb,
+                options.reverse,
+                options.modules,
+                options.debug_sql,
+                options.parallel,
+                options.tags,
                 options.exclude_tags,
-                getattr(options, 'test_name_patterns', None),
-                options.start_at, options.start_after, options.pdb, options.buffer,
-                options.timing, options.shuffle,
+                getattr(options, "test_name_patterns", None),
+                options.start_at,
+                options.start_after,
+                options.pdb,
+                options.buffer,
+                options.timing,
+                options.shuffle,
             )
         time_keeper.print_results()
         if failures: