PersistentDict: thread safety (#234)
Co-authored-by: Andreas Kloeckner <inform@tiker.net>
matthiasdiener and inducer committed Jul 2, 2024
1 parent 7c1bee6 · commit a3015f1
Showing 2 changed files with 136 additions and 42 deletions.
pytools/persistent_dict.py (81 changes: 55 additions & 26 deletions)
@@ -472,9 +472,16 @@ def __init__(self,
         self.container_dir = container_dir
         self._make_container_dir()

-        # isolation_level=None: enable autocommit mode
-        # https://www.sqlite.org/lang_transaction.html#implicit_versus_explicit_transactions
-        self.conn = sqlite3.connect(self.filename, isolation_level=None)
+        from threading import Lock
+        self.mutex = Lock()
+
+        # * isolation_level=None: enable autocommit mode
+        #   https://www.sqlite.org/lang_transaction.html#implicit_versus_explicit_transactions
+        # * check_same_thread=False: thread-level concurrency is handled by the
+        #   mutex above
+        self.conn = sqlite3.connect(self.filename,
+                                    isolation_level=None,
+                                    check_same_thread=False)

         self._exec_sql(
             "CREATE TABLE IF NOT EXISTS dict "
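The pairing of check_same_thread=False with the new mutex is the heart of the change: by default, sqlite3 raises ProgrammingError when a connection is used from a thread other than the one that created it, and disabling that check is only safe when the caller serializes all access itself. A minimal, self-contained sketch of the pattern (the table and names are illustrative, not from pytools):

import sqlite3
from threading import Lock, Thread

# One connection shared by all threads; every use is serialized by a lock.
conn = sqlite3.connect(":memory:",
                       isolation_level=None,     # autocommit mode
                       check_same_thread=False)  # safe only with the lock below
mutex = Lock()

with mutex:
    conn.execute("CREATE TABLE kv (k INTEGER PRIMARY KEY, v INTEGER)")

def worker(tid: int) -> None:
    for i in range(100):
        with mutex:  # never touch conn without holding the mutex
            conn.execute("INSERT OR REPLACE INTO kv VALUES (?, ?)", (i, tid))

threads = [Thread(target=worker, args=(t,)) for t in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()

with mutex:
    print(conn.execute("SELECT COUNT(*) FROM kv").fetchone())  # (100,)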
@@ -515,8 +522,9 @@ def __init__(self,
         self._exec_sql("PRAGMA cache_size = -64000")

     def __del__(self) -> None:
-        if self.conn:
-            self.conn.close()
+        with self.mutex:
+            if self.conn:
+                self.conn.close()

     def _collision_check(self, key: K, stored_key: K) -> None:
         if stored_key != key:
@@ -550,21 +558,22 @@ def execute() -> sqlite3.Cursor:
     def _exec_sql_fn(self, fn: Callable[[], T]) -> Optional[T]:
         n = 0

-        while True:
-            n += 1
-            try:
-                return fn()
-            except sqlite3.OperationalError as e:
-                # If the database is busy, retry
-                if (hasattr(e, "sqlite_errorcode")
-                        and not e.sqlite_errorcode == sqlite3.SQLITE_BUSY):
-                    raise
-                if n % 20 == 0:
-                    from warnings import warn
-                    warn(f"PersistentDict: database '{self.filename}' busy, {n} "
-                         "retries", stacklevel=3)
-            else:
-                break
+        with self.mutex:
+            while True:
+                n += 1
+                try:
+                    return fn()
+                except sqlite3.OperationalError as e:
+                    # If the database is busy, retry
+                    if (hasattr(e, "sqlite_errorcode")
+                            and not e.sqlite_errorcode == sqlite3.SQLITE_BUSY):
+                        raise
+                    if n % 20 == 0:
+                        from warnings import warn
+                        warn(f"PersistentDict: database '{self.filename}' busy, {n} "
+                             "retries", stacklevel=3)
+                else:
+                    break

     def store_if_not_present(self, key: K, value: V) -> None:
         """Store (*key*, *value*) if *key* is not already present."""
@@ -716,9 +725,19 @@ def store(self, key: K, value: V, _skip_if_present: bool = False) -> None:

     def _fetch_uncached(self, keyhash: str) -> Tuple[K, V]:
         # This method is separate from fetch() to allow for LRU caching
-        c = self._exec_sql("SELECT key_value FROM dict WHERE keyhash=?",
-                           (keyhash,))
-        row = c.fetchone()
+
+        def fetch_inner() -> Optional[Tuple[Any]]:
+            assert self.conn is not None
+
+            # This is separate from fetch() so that the mutex covers the
+            # fetchone() call
+            c = self.conn.execute("SELECT key_value FROM dict WHERE keyhash=?",
+                                  (keyhash,))
+            res = c.fetchone()
+            assert res is None or isinstance(res, tuple)
+            return res
+
+        row = self._exec_sql_fn(fetch_inner)
         if row is None:
             raise KeyError

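The fetch_inner closure exists because _exec_sql_fn invokes its argument while holding self.mutex: wrapping execute() and fetchone() in one callable makes the whole two-step read atomic, whereas fetching after the lock is released would touch the shared connection unguarded. A reduced illustration of the difference (illustrative, not from the commit):

import sqlite3
from threading import Lock

mutex = Lock()
conn = sqlite3.connect(":memory:", check_same_thread=False)
conn.execute("CREATE TABLE dict (keyhash TEXT, key_value TEXT)")
conn.execute("INSERT INTO dict VALUES ('abc', 'payload')")

# Unsafe shape: only execute() runs under the lock; fetchone() would use
# the shared connection after the lock is released.
with mutex:
    c = conn.execute("SELECT key_value FROM dict WHERE keyhash=?", ("abc",))
row = c.fetchone()

# Safe shape, as in the commit: wrap both steps in a closure and run it
# entirely under the lock.
def fetch_inner():
    c = conn.execute("SELECT key_value FROM dict WHERE keyhash=?", ("abc",))
    return c.fetchone()

with mutex:
    row = fetch_inner()
print(row)  # ('payload',)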
@@ -797,9 +816,19 @@ def store(self, key: K, value: V, _skip_if_present: bool = False) -> None:
     def fetch(self, key: K) -> V:
         keyhash = self.key_builder(key)

-        c = self._exec_sql("SELECT key_value FROM dict WHERE keyhash=?",
-                           (keyhash,))
-        row = c.fetchone()
+        def fetch_inner() -> Optional[Tuple[Any]]:
+            assert self.conn is not None
+
+            # This is separate from fetch() so that the mutex covers the
+            # fetchone() call
+            c = self.conn.execute("SELECT key_value FROM dict WHERE keyhash=?",
+                                  (keyhash,))
+            res = c.fetchone()
+            assert res is None or isinstance(res, tuple)
+            return res
+
+        row = self._exec_sql_fn(fetch_inner)

         if row is None:
             raise NoSuchEntryError(key)

pytools/test/test_persistent_dict.py (97 changes: 81 additions & 16 deletions)
@@ -3,7 +3,7 @@
 import tempfile
 from dataclasses import dataclass
 from enum import Enum, IntEnum
-from typing import Any, Dict
+from typing import Any, Dict, Optional

 import pytest

@@ -905,48 +905,63 @@ def method(self):
 # }}}


-# {{{ basic concurrency test
+# {{{ basic concurrency tests

-def _mp_fn(tmpdir: str) -> None:
+def _conc_fn(tmpdir: Optional[str] = None,
+             pdict: Optional[PersistentDict[int, int]] = None) -> None:
     import time
-    pdict: PersistentDict[int, int] = PersistentDict("pytools-test",
-                                                     container_dir=tmpdir,
-                                                     safe_sync=False)
+
+    assert (pdict is None) ^ (tmpdir is None)
+
+    if pdict is None:
+        pdict = PersistentDict("pytools-test",
+                               container_dir=tmpdir,
+                               safe_sync=False)
     n = 10000
     s = 0

     start = time.time()

     for i in range(n):
-        if i % 100 == 0:
+        if i % 1000 == 0:
             print(f"i={i}")
-        pdict[i] = i
+
+        if isinstance(pdict, WriteOncePersistentDict):
+            try:
+                pdict[i] = i
+            except ReadOnlyEntryError:
+                pass
+        else:
+            pdict[i] = i

         try:
             s += pdict[i]
         except NoSuchEntryError:
             # Someone else already deleted the entry
             pass

-        try:
-            del pdict[i]
-        except NoSuchEntryError:
-            # Someone else already deleted the entry
-            pass
+        if not isinstance(pdict, WriteOncePersistentDict):
+            try:
+                del pdict[i]
+            except NoSuchEntryError:
+                # Someone else already deleted the entry
+                pass

     end = time.time()

     print(f"PersistentDict: time taken to write {n} entries to "
           f"{pdict.filename}: {end-start} s={s}")


-def test_concurrency() -> None:
+def test_concurrency_processes() -> None:
     from multiprocessing import Process

-    tmpdir = "_tmp/"  # must be the same across all processes in this test
+    tmpdir = "_tmp_proc/"  # must be the same across all processes in this test

     try:
-        p = [Process(target=_mp_fn, args=(tmpdir, )) for _ in range(4)]
+        # multiprocessing needs to pickle function arguments, so we can't pass
+        # the PersistentDict object (which is unpicklable) directly.
+        p = [Process(target=_conc_fn, args=(tmpdir, None)) for _ in range(4)]
         for pp in p:
             pp.start()
         for pp in p:
@@ -956,6 +971,56 @@ def test_concurrency() -> None:
     finally:
         shutil.rmtree(tmpdir)

+
+from threading import Thread
+
+
+class RaisingThread(Thread):
+    def run(self) -> None:
+        self._exc = None
+        try:
+            super().run()
+        except Exception as e:
+            self._exc = e
+
+    def join(self, timeout: Optional[float] = None) -> None:
+        super().join(timeout=timeout)
+        if self._exc:
+            raise self._exc
+
+
+def test_concurrency_threads() -> None:
+    tmpdir = "_tmp_threads/"  # must be the same across all threads in this test
+
+    try:
+        # Share this pdict object among all threads to test thread safety
+        pdict: PersistentDict[int, int] = PersistentDict("pytools-test",
+                                                         container_dir=tmpdir,
+                                                         safe_sync=False)
+        t = [RaisingThread(target=_conc_fn, args=(None, pdict)) for _ in range(4)]
+        for tt in t:
+            tt.start()
+        for tt in t:
+            tt.join()
+        # Threads will raise in join() if they encountered an exception
+    finally:
+        shutil.rmtree(tmpdir)
+
+    try:
+        # Share this pdict object among all threads to test thread safety
+        pdict2: WriteOncePersistentDict[int, int] = WriteOncePersistentDict(
+            "pytools-test",
+            container_dir=tmpdir,
+            safe_sync=False)
+        t = [RaisingThread(target=_conc_fn, args=(None, pdict2)) for _ in range(4)]
+        for tt in t:
+            tt.start()
+        for tt in t:
+            tt.join()
+        # Threads will raise in join() if they encountered an exception
+    finally:
+        shutil.rmtree(tmpdir)
+
 # }}}

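A note on RaisingThread: a plain threading.Thread does not propagate exceptions from its target; they are reported via threading.excepthook and join() returns normally, so a failing worker would silently pass the test. Storing the exception in run() and re-raising it in join() closes that gap. A small hypothetical test sketching the behavior, relying on the pytest import already in this module (not part of the commit):

def _boom() -> None:
    raise ValueError("worker failed")

def test_raising_thread_propagates() -> None:
    t = RaisingThread(target=_boom)
    t.start()
    with pytest.raises(ValueError, match="worker failed"):
        t.join()  # re-raises the exception captured in run()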

Expand Down

0 comments on commit a3015f1

Please sign in to comment.