diff --git a/pytools/persistent_dict.py b/pytools/persistent_dict.py
index ed4fce98..5d1baf9f 100644
--- a/pytools/persistent_dict.py
+++ b/pytools/persistent_dict.py
@@ -472,9 +472,16 @@ def __init__(self,
         self.container_dir = container_dir
         self._make_container_dir()
 
-        # isolation_level=None: enable autocommit mode
-        # https://www.sqlite.org/lang_transaction.html#implicit_versus_explicit_transactions
-        self.conn = sqlite3.connect(self.filename, isolation_level=None)
+        from threading import Lock
+        self.mutex = Lock()
+
+        # * isolation_level=None: enable autocommit mode
+        #   https://www.sqlite.org/lang_transaction.html#implicit_versus_explicit_transactions
+        # * check_same_thread=False: thread-level concurrency is handled by the
+        #   mutex above
+        self.conn = sqlite3.connect(self.filename,
+                                    isolation_level=None,
+                                    check_same_thread=False)
 
         self._exec_sql(
             "CREATE TABLE IF NOT EXISTS dict "
@@ -515,8 +522,9 @@ def __init__(self,
         self._exec_sql("PRAGMA cache_size = -64000")
 
     def __del__(self) -> None:
-        if self.conn:
-            self.conn.close()
+        with self.mutex:
+            if self.conn:
+                self.conn.close()
 
     def _collision_check(self, key: K, stored_key: K) -> None:
         if stored_key != key:
@@ -550,21 +558,22 @@ def execute() -> sqlite3.Cursor:
     def _exec_sql_fn(self, fn: Callable[[], T]) -> Optional[T]:
         n = 0
 
-        while True:
-            n += 1
-            try:
-                return fn()
-            except sqlite3.OperationalError as e:
-                # If the database is busy, retry
-                if (hasattr(e, "sqlite_errorcode")
-                        and not e.sqlite_errorcode == sqlite3.SQLITE_BUSY):
-                    raise
-                if n % 20 == 0:
-                    from warnings import warn
-                    warn(f"PersistentDict: database '{self.filename}' busy, {n} "
-                         "retries", stacklevel=3)
-            else:
-                break
+        with self.mutex:
+            while True:
+                n += 1
+                try:
+                    return fn()
+                except sqlite3.OperationalError as e:
+                    # If the database is busy, retry
+                    if (hasattr(e, "sqlite_errorcode")
+                            and not e.sqlite_errorcode == sqlite3.SQLITE_BUSY):
+                        raise
+                    if n % 20 == 0:
+                        from warnings import warn
+                        warn(f"PersistentDict: database '{self.filename}' busy, {n} "
+                             "retries", stacklevel=3)
+                else:
+                    break
 
     def store_if_not_present(self, key: K, value: V) -> None:
         """Store (*key*, *value*) if *key* is not already present."""
@@ -716,9 +725,19 @@ def store(self, key: K, value: V, _skip_if_present: bool = False) -> None:
 
     def _fetch_uncached(self, keyhash: str) -> Tuple[K, V]:
         # This method is separate from fetch() to allow for LRU caching
-        c = self._exec_sql("SELECT key_value FROM dict WHERE keyhash=?",
-                           (keyhash,))
-        row = c.fetchone()
+
+        def fetch_inner() -> Optional[Tuple[Any]]:
+            assert self.conn is not None
+
+            # This is separate from fetch() so that the mutex covers the
+            # fetchone() call
+            c = self.conn.execute("SELECT key_value FROM dict WHERE keyhash=?",
+                                  (keyhash,))
+            res = c.fetchone()
+            assert res is None or isinstance(res, tuple)
+            return res
+
+        row = self._exec_sql_fn(fetch_inner)
         if row is None:
             raise KeyError
 
@@ -797,9 +816,19 @@ def store(self, key: K, value: V, _skip_if_present: bool = False) -> None:
 
     def fetch(self, key: K) -> V:
         keyhash = self.key_builder(key)
 
-        c = self._exec_sql("SELECT key_value FROM dict WHERE keyhash=?",
-                           (keyhash,))
-        row = c.fetchone()
+        def fetch_inner() -> Optional[Tuple[Any]]:
+            assert self.conn is not None
+
+            # This is separate from fetch() so that the mutex covers the
+            # fetchone() call
+            c = self.conn.execute("SELECT key_value FROM dict WHERE keyhash=?",
+                                  (keyhash,))
+            res = c.fetchone()
+            assert res is None or isinstance(res, tuple)
+            return res
+
+        row = self._exec_sql_fn(fetch_inner)
+
         if row is None:
             raise NoSuchEntryError(key)
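Aside (not part of the patch): the scheme above combines three ingredients, namely one SQLite connection opened with `check_same_thread=False`, a `threading.Lock` that serializes every use of that connection, and a retry loop for `SQLITE_BUSY`, which signals contention from other processes. Below is a minimal standalone sketch of that pattern; the `KVStore` class, its schema, and its method names are invented for illustration and do not exist in pytools.

```python
import sqlite3
import threading
from typing import Optional


class KVStore:
    """Illustrative store: one shared connection, serialized by one lock."""

    def __init__(self, filename: str) -> None:
        self.mutex = threading.Lock()

        # isolation_level=None enables autocommit; check_same_thread=False
        # allows use from multiple threads, which is safe here only because
        # every access below goes through self.mutex.
        self.conn = sqlite3.connect(filename,
                                    isolation_level=None,
                                    check_same_thread=False)
        with self.mutex:
            self.conn.execute(
                "CREATE TABLE IF NOT EXISTS dict (k TEXT PRIMARY KEY, v TEXT)")

    def put(self, k: str, v: str) -> None:
        with self.mutex:
            while True:
                try:
                    self.conn.execute(
                        "INSERT OR REPLACE INTO dict VALUES (?, ?)", (k, v))
                    return
                except sqlite3.OperationalError as e:
                    # Retry only while another process holds the file lock;
                    # sqlite_errorcode requires Python 3.11+, as in the patch
                    if getattr(e, "sqlite_errorcode", None) \
                            != sqlite3.SQLITE_BUSY:
                        raise

    def get(self, k: str) -> Optional[str]:
        with self.mutex:
            # fetchone() runs under the same lock acquisition as execute(),
            # mirroring fetch_inner() above: the cursor reads lazily from
            # the shared connection, so another thread must not use the
            # connection between execute() and fetchone()
            row = self.conn.execute(
                "SELECT v FROM dict WHERE k = ?", (k,)).fetchone()
        return None if row is None else row[0]
```

With this in place, any number of threads can share one `KVStore` instance; contention from other processes surfaces only as retried `SQLITE_BUSY` errors, which is the same division of labor the patch establishes between the mutex and the retry loop in `_exec_sql_fn`.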
diff --git a/pytools/test/test_persistent_dict.py b/pytools/test/test_persistent_dict.py
index b0e050ed..858f22fb 100644
--- a/pytools/test/test_persistent_dict.py
+++ b/pytools/test/test_persistent_dict.py
@@ -3,7 +3,7 @@
 import tempfile
 from dataclasses import dataclass
 from enum import Enum, IntEnum
-from typing import Any, Dict
+from typing import Any, Dict, Optional
 
 import pytest
 
@@ -905,22 +905,34 @@ def method(self):
 # }}}
 
-# {{{ basic concurrency test
+# {{{ basic concurrency tests
 
-def _mp_fn(tmpdir: str) -> None:
+def _conc_fn(tmpdir: Optional[str] = None,
+             pdict: Optional[PersistentDict[int, int]] = None) -> None:
     import time
-    pdict: PersistentDict[int, int] = PersistentDict("pytools-test",
-                                                     container_dir=tmpdir,
-                                                     safe_sync=False)
+
+    assert (pdict is None) ^ (tmpdir is None)
+
+    if pdict is None:
+        pdict = PersistentDict("pytools-test",
+                               container_dir=tmpdir,
+                               safe_sync=False)
 
     n = 10000
     s = 0
 
     start = time.time()
 
     for i in range(n):
-        if i % 100 == 0:
+        if i % 1000 == 0:
             print(f"i={i}")
-        pdict[i] = i
+
+        if isinstance(pdict, WriteOncePersistentDict):
+            try:
+                pdict[i] = i
+            except ReadOnlyEntryError:
+                pass
+        else:
+            pdict[i] = i
 
         try:
             s += pdict[i]
@@ -928,11 +940,12 @@ def _mp_fn(tmpdir: str) -> None:
             # Someone else already deleted the entry
             pass
 
-        try:
-            del pdict[i]
-        except NoSuchEntryError:
-            # Someone else already deleted the entry
-            pass
+        if not isinstance(pdict, WriteOncePersistentDict):
+            try:
+                del pdict[i]
+            except NoSuchEntryError:
+                # Someone else already deleted the entry
+                pass
 
     end = time.time()
 
@@ -940,13 +953,15 @@ def _mp_fn(tmpdir: str) -> None:
           f"{pdict.filename}: {end-start} s={s}")
 
 
-def test_concurrency() -> None:
+def test_concurrency_processes() -> None:
     from multiprocessing import Process
 
-    tmpdir = "_tmp/"  # must be the same across all processes in this test
+    tmpdir = "_tmp_proc/"  # must be the same across all processes in this test
 
     try:
-        p = [Process(target=_mp_fn, args=(tmpdir, )) for _ in range(4)]
+        # multiprocessing needs to pickle function arguments, so we can't pass
+        # the PersistentDict object (which is unpicklable) directly.
+        p = [Process(target=_conc_fn, args=(tmpdir, None)) for _ in range(4)]
         for pp in p:
             pp.start()
         for pp in p:
@@ -956,6 +971,56 @@ def test_concurrency() -> None:
     finally:
         shutil.rmtree(tmpdir)
 
+
+from threading import Thread
+
+
+class RaisingThread(Thread):
+    def run(self) -> None:
+        self._exc = None
+        try:
+            super().run()
+        except Exception as e:
+            self._exc = e
+
+    def join(self, timeout: Optional[float] = None) -> None:
+        super().join(timeout=timeout)
+        if self._exc:
+            raise self._exc
+
+
+def test_concurrency_threads() -> None:
+    tmpdir = "_tmp_threads/"  # must be the same across all threads in this test
+
+    try:
+        # Share this pdict object among all threads to test thread safety
+        pdict: PersistentDict[int, int] = PersistentDict("pytools-test",
+                                                         container_dir=tmpdir,
+                                                         safe_sync=False)
+        t = [RaisingThread(target=_conc_fn, args=(None, pdict)) for _ in range(4)]
+        for tt in t:
+            tt.start()
+        for tt in t:
+            tt.join()
+            # Threads will raise in join() if they encountered an exception
+    finally:
+        shutil.rmtree(tmpdir)
+
+    try:
+        # Share this pdict object among all threads to test thread safety
+        pdict2: WriteOncePersistentDict[int, int] = WriteOncePersistentDict(
+            "pytools-test",
+            container_dir=tmpdir,
+            safe_sync=False)
+        t = [RaisingThread(target=_conc_fn, args=(None, pdict2)) for _ in range(4)]
+        for tt in t:
+            tt.start()
+        for tt in t:
+            tt.join()
+            # Threads will raise in join() if they encountered an exception
+    finally:
+        shutil.rmtree(tmpdir)
+
 # }}}
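A note on the `RaisingThread` helper introduced above: a plain `threading.Thread` does not propagate exceptions from its target to the caller (by default they are only reported via `threading.excepthook`), so a crashing worker would never fail the test. The subclass captures the exception in `run()` and re-raises it from `join()`. A small standalone demonstration of the difference, using a made-up `boom` worker that is not part of the test suite:

```python
from threading import Thread


class RaisingThread(Thread):
    # Same pattern as in the test above: capture in run(), re-raise in join()
    def run(self) -> None:
        self._exc = None
        try:
            super().run()
        except Exception as e:
            self._exc = e

    def join(self, timeout=None) -> None:
        super().join(timeout=timeout)
        if self._exc:
            raise self._exc


def boom() -> None:
    raise ValueError("worker failed")


t1 = Thread(target=boom)
t1.start()
t1.join()  # returns normally; the traceback only goes to threading.excepthook

t2 = RaisingThread(target=boom)
t2.start()
try:
    t2.join()  # re-raises the worker's ValueError in the calling thread
except ValueError as e:
    print(f"caught: {e}")
```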