diff --git a/coverage/sqldata.py b/coverage/sqldata.py index e136c7f6a..714d5348a 100644 --- a/coverage/sqldata.py +++ b/coverage/sqldata.py @@ -18,6 +18,7 @@ import sqlite3 import sys import threading +import weakref import zlib from coverage.debug import NoDebugging, SimpleReprMixin, clipped_repr @@ -213,7 +214,6 @@ def __init__(self, basename=None, suffix=None, no_disk=False, warn=None, debug=N self._file_map = {} # Maps thread ids to SqliteDb objects. self._dbs = {} - self._pid = os.getpid() # Synchronize the operations used during collection. self._lock = threading.Lock() @@ -227,6 +227,21 @@ def __init__(self, basename=None, suffix=None, no_disk=False, warn=None, debug=N self._current_context_id = None self._query_context_ids = None + if hasattr(os, "fork"): + os.register_at_fork(after_in_child=functools.partial(self._at_fork, weakref.ref(self))) + + @staticmethod + def _at_fork(self_ref): + """A hook run in new child processes after a fork.""" + self = self_ref() + if self is None: + return + + # Looks like we forked! Have to start a new data file. + self._lock = threading.Lock() + self._reset() + self._choose_filename() + def _locked(method): # pylint: disable=no-self-argument """A decorator for methods that should hold self._lock.""" @functools.wraps(method) @@ -780,11 +795,6 @@ def write(self): def _start_using(self): """Call this before using the database at all.""" - if self._pid != os.getpid(): - # Looks like we forked! Have to start a new data file. - self._reset() - self._choose_filename() - self._pid = os.getpid() if not self._have_used: self.erase() self._have_used = True