PersistentDict: concurrency improvements (#231)
matthiasdiener committed Jun 7, 2024
1 parent c15051b commit 93eab0f
Showing 3 changed files with 116 additions and 39 deletions.
100 changes: 61 additions & 39 deletions pytools/persistent_dict.py
@@ -463,21 +463,21 @@ def __init__(self, identifier: str,
         # https://www.sqlite.org/lang_transaction.html#implicit_versus_explicit_transactions
         self.conn = sqlite3.connect(self.filename, isolation_level=None)
 
-        self.conn.execute(
+        self._exec_sql(
             "CREATE TABLE IF NOT EXISTS dict "
             "(keyhash TEXT NOT NULL PRIMARY KEY, key_value TEXT NOT NULL)"
             )
 
         # https://www.sqlite.org/wal.html
         if enable_wal:
-            self.conn.execute("PRAGMA journal_mode = 'WAL'")
+            self._exec_sql("PRAGMA journal_mode = 'WAL'")
 
         # Note: the following configuration values were taken mostly from litedict:
         # https://github.com/litements/litedict/blob/377603fa597453ffd9997186a493ed4fd23e5399/litedict.py#L67-L70
 
         # Use in-memory temp store
         # https://www.sqlite.org/pragma.html#pragma_temp_store
-        self.conn.execute("PRAGMA temp_store = 'MEMORY'")
+        self._exec_sql("PRAGMA temp_store = 'MEMORY'")
 
         # fsync() can be extremely slow on some systems.
         # See https://github.com/inducer/pytools/issues/227 for context.
@@ -493,13 +493,13 @@ def __init__(self, identifier: str,
                  "Pass 'safe_sync=False' if occasional data loss is tolerable. "
                  "Pass 'safe_sync=True' to suppress this warning.",
                  stacklevel=3)
-            self.conn.execute("PRAGMA synchronous = 'NORMAL'")
+            self._exec_sql("PRAGMA synchronous = 'NORMAL'")
         else:
-            self.conn.execute("PRAGMA synchronous = 'OFF'")
+            self._exec_sql("PRAGMA synchronous = 'OFF'")
 
         # 64 MByte of cache
         # https://www.sqlite.org/pragma.html#pragma_cache_size
-        self.conn.execute("PRAGMA cache_size = -64000")
+        self._exec_sql("PRAGMA cache_size = -64000")
 
     def __del__(self) -> None:
         if self.conn:
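For reference, the pragma setup above can be reproduced on a bare sqlite3 connection. A minimal sketch (throwaway database path, independent of pytools) that applies the same settings and reads them back:

    import sqlite3

    conn = sqlite3.connect("/tmp/pragma-demo.sqlite3", isolation_level=None)

    # Same pragmas as in __init__ above.
    conn.execute("PRAGMA journal_mode = 'WAL'")
    conn.execute("PRAGMA temp_store = 'MEMORY'")
    conn.execute("PRAGMA synchronous = 'NORMAL'")
    conn.execute("PRAGMA cache_size = -64000")  # negative value: size in KiB

    # Read the settings back to verify they took effect.
    print(next(conn.execute("PRAGMA journal_mode"))[0])  # 'wal'
    print(next(conn.execute("PRAGMA synchronous"))[0])   # 1, i.e. NORMAL
    conn.close()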
@@ -522,6 +522,28 @@ def _collision_check(self, key: K, stored_key: K) -> None:
             stored_key == key  # pylint:disable=pointless-statement  # noqa: B015
             raise NoSuchEntryCollisionError(key)
 
+    def _exec_sql(self, *args: Any) -> sqlite3.Cursor:
+        return self._exec_sql_fn(lambda: self.conn.execute(*args))
+
+    def _exec_sql_fn(self, fn: Any) -> Any:
+        n = 0
+
+        while True:
+            n += 1
+            try:
+                return fn()
+            except sqlite3.OperationalError as e:
+                # If the database is busy, retry
+                if (hasattr(e, "sqlite_errorcode")
+                        and not e.sqlite_errorcode == sqlite3.SQLITE_BUSY):
+                    raise
+                if n % 20 == 0:
+                    from warnings import warn
+                    warn(f"PersistentDict: database '{self.filename}' busy, {n} "
+                         "retries")
+            else:
+                break
+
     def store_if_not_present(self, key: K, value: V) -> None:
         """Store (*key*, *value*) if *key* is not already present."""
         self.store(key, value, _skip_if_present=True)
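The retry loop added above spins while SQLite reports SQLITE_BUSY, i.e. while another connection holds a conflicting lock, and re-raises any other OperationalError. A standalone sketch of the same pattern (exec_with_retry is a hypothetical helper, not part of the pytools API; the commit retries immediately, so the sleep below is an added backoff assumption):

    import sqlite3
    import time

    def exec_with_retry(conn: sqlite3.Connection, sql: str,
                        params=()) -> sqlite3.Cursor:
        n = 0
        while True:
            n += 1
            try:
                return conn.execute(sql, params)
            except sqlite3.OperationalError as e:
                # e.sqlite_errorcode exists on Python 3.11+, hence the guard.
                if (hasattr(e, "sqlite_errorcode")
                        and e.sqlite_errorcode != sqlite3.SQLITE_BUSY):
                    raise
                if n % 20 == 0:
                    print(f"database busy, {n} retries")
                time.sleep(0.05)  # brief backoff before retrying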
@@ -548,30 +570,30 @@ def __setitem__(self, key: K, value: V) -> None:
 
     def __len__(self) -> int:
         """Return the number of entries in the dictionary."""
-        return next(self.conn.execute("SELECT COUNT(*) FROM dict"))[0]
+        return next(self._exec_sql("SELECT COUNT(*) FROM dict"))[0]
 
     def __iter__(self) -> Generator[K, None, None]:
         """Return an iterator over the keys in the dictionary."""
         return self.keys()
 
     def keys(self) -> Generator[K, None, None]:
         """Return an iterator over the keys in the dictionary."""
-        for row in self.conn.execute("SELECT key_value FROM dict ORDER BY rowid"):
+        for row in self._exec_sql("SELECT key_value FROM dict ORDER BY rowid"):
             yield pickle.loads(row[0])[0]
 
     def values(self) -> Generator[V, None, None]:
         """Return an iterator over the values in the dictionary."""
-        for row in self.conn.execute("SELECT key_value FROM dict ORDER BY rowid"):
+        for row in self._exec_sql("SELECT key_value FROM dict ORDER BY rowid"):
             yield pickle.loads(row[0])[1]
 
     def items(self) -> Generator[tuple[K, V], None, None]:
         """Return an iterator over the items in the dictionary."""
-        for row in self.conn.execute("SELECT key_value FROM dict ORDER BY rowid"):
+        for row in self._exec_sql("SELECT key_value FROM dict ORDER BY rowid"):
             yield pickle.loads(row[0])
 
     def nbytes(self) -> int:
         """Return the size of the dictionary in bytes."""
-        return next(self.conn.execute("SELECT page_size * page_count FROM "
+        return next(self._exec_sql("SELECT page_size * page_count FROM "
                 "pragma_page_size(), pragma_page_count()"))[0]
 
     def __repr__(self) -> str:
@@ -580,7 +602,7 @@ def __repr__(self) -> str:
 
     def clear(self) -> None:
         """Remove all entries from the dictionary."""
-        self.conn.execute("DELETE FROM dict")
+        self._exec_sql("DELETE FROM dict")
 
 
 class WriteOncePersistentDict(_PersistentDictBase[K, V]):
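Since keys(), values() and items() stream rows from a single SELECT, iteration does not load the whole table at once. A short usage sketch (scratch container_dir assumed):

    from pytools.persistent_dict import PersistentDict

    pdict: PersistentDict[str, int] = PersistentDict(
            "demo", container_dir="_demo_tmp", safe_sync=False)
    pdict["a"] = 1
    pdict["b"] = 2

    print(len(pdict))           # 2
    print(list(pdict))          # ['a', 'b'], in rowid (insertion) order
    print(dict(pdict.items()))  # {'a': 1, 'b': 2}
    pdict.clear()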
@@ -644,11 +666,11 @@ def store(self, key: K, value: V, _skip_if_present: bool = False) -> None:
         v = pickle.dumps((key, value))
 
         if _skip_if_present:
-            self.conn.execute("INSERT OR IGNORE INTO dict VALUES (?, ?)",
+            self._exec_sql("INSERT OR IGNORE INTO dict VALUES (?, ?)",
                            (keyhash, v))
         else:
             try:
-                self.conn.execute("INSERT INTO dict VALUES (?, ?)", (keyhash, v))
+                self._exec_sql("INSERT INTO dict VALUES (?, ?)", (keyhash, v))
             except sqlite3.IntegrityError as e:
                 if hasattr(e, "sqlite_errorcode"):
                     if e.sqlite_errorcode == sqlite3.SQLITE_CONSTRAINT_PRIMARYKEY:
@@ -662,7 +684,7 @@ def store(self, key: K, value: V, _skip_if_present: bool = False) -> None:
 
     def _fetch(self, keyhash: str) -> Tuple[K, V]:  # pylint:disable=method-hidden
         # This method is separate from fetch() to allow for LRU caching
-        c = self.conn.execute("SELECT key_value FROM dict WHERE keyhash=?",
+        c = self._exec_sql("SELECT key_value FROM dict WHERE keyhash=?",
                            (keyhash,))
         row = c.fetchone()
         if row is None:
@@ -730,17 +752,15 @@ def store(self, key: K, value: V, _skip_if_present: bool = False) -> None:
         keyhash = self.key_builder(key)
         v = pickle.dumps((key, value))
 
-        if _skip_if_present:
-            self.conn.execute("INSERT OR IGNORE INTO dict VALUES (?, ?)",
-                              (keyhash, v))
-        else:
-            self.conn.execute("INSERT OR REPLACE INTO dict VALUES (?, ?)",
-                              (keyhash, v))
+        mode = "IGNORE" if _skip_if_present else "REPLACE"
+
+        self._exec_sql(f"INSERT OR {mode} INTO dict VALUES (?, ?)",
+                       (keyhash, v))
 
     def fetch(self, key: K) -> V:
         keyhash = self.key_builder(key)
 
-        c = self.conn.execute("SELECT key_value FROM dict WHERE keyhash=?",
+        c = self._exec_sql("SELECT key_value FROM dict WHERE keyhash=?",
                            (keyhash,))
         row = c.fetchone()
         if row is None:
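The rewritten store() above folds both branches into a single statement: with IGNORE a conflicting insert keeps the existing row, with REPLACE it overwrites. A plain-sqlite3 illustration of the two conflict modes:

    import sqlite3

    conn = sqlite3.connect(":memory:")
    conn.execute("CREATE TABLE dict (keyhash TEXT PRIMARY KEY, key_value TEXT)")

    conn.execute("INSERT OR IGNORE INTO dict VALUES (?, ?)", ("k", "v1"))
    conn.execute("INSERT OR IGNORE INTO dict VALUES (?, ?)", ("k", "v2"))
    print(conn.execute("SELECT key_value FROM dict").fetchone())  # ('v1',): IGNORE kept the old row

    conn.execute("INSERT OR REPLACE INTO dict VALUES (?, ?)", ("k", "v3"))
    print(conn.execute("SELECT key_value FROM dict").fetchone())  # ('v3',): REPLACE overwrote it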
@@ -754,24 +774,26 @@ def remove(self, key: K) -> None:
         """Remove the entry associated with *key* from the dictionary."""
         keyhash = self.key_builder(key)
 
-        self.conn.execute("BEGIN EXCLUSIVE TRANSACTION")
-
-        try:
-            # This is split into SELECT/DELETE to allow for a collision check
-            c = self.conn.execute("SELECT key_value FROM dict WHERE keyhash=?",
-                                  (keyhash,))
-            row = c.fetchone()
-            if row is None:
-                raise NoSuchEntryError(key)
-
-            stored_key, _value = pickle.loads(row[0])
-            self._collision_check(key, stored_key)
-
-            self.conn.execute("DELETE FROM dict WHERE keyhash=?", (keyhash,))
-            self.conn.execute("COMMIT")
-        except Exception as e:
-            self.conn.execute("ROLLBACK")
-            raise e
+        def remove_inner() -> None:
+            self.conn.execute("BEGIN EXCLUSIVE TRANSACTION")
+            try:
+                # This is split into SELECT/DELETE to allow for a collision check
+                c = self.conn.execute("SELECT key_value FROM dict WHERE "
+                                      "keyhash=?", (keyhash,))
+                row = c.fetchone()
+                if row is None:
+                    raise NoSuchEntryError(key)
+
+                stored_key, _value = pickle.loads(row[0])
+                self._collision_check(key, stored_key)
+
+                self.conn.execute("DELETE FROM dict WHERE keyhash=?", (keyhash,))
+                self.conn.execute("COMMIT")
+            except Exception as e:
+                self.conn.execute("ROLLBACK")
+                raise e
+
+        self._exec_sql_fn(remove_inner)
 
     def __delitem__(self, key: K) -> None:
         """Remove the entry associated with *key* from the dictionary."""
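remove() now routes the whole SELECT/DELETE transaction through _exec_sql_fn, so a SQLITE_BUSY raised anywhere inside it restarts the entire transaction rather than a single statement. A generic sketch of that shape (retry_txn is a hypothetical helper, not pytools API):

    import sqlite3
    from typing import Any, Callable

    def retry_txn(conn: sqlite3.Connection, body: Callable[[], Any]) -> Any:
        while True:
            try:
                conn.execute("BEGIN EXCLUSIVE TRANSACTION")
                try:
                    result = body()
                    conn.execute("COMMIT")
                    return result
                except Exception:
                    # BEGIN succeeded, so a rollback is safe here.
                    conn.execute("ROLLBACK")
                    raise
            except sqlite3.OperationalError as e:
                # Retry the whole transaction while the database is busy.
                if (hasattr(e, "sqlite_errorcode")
                        and e.sqlite_errorcode != sqlite3.SQLITE_BUSY):
                    raise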
54 changes: 54 additions & 0 deletions pytools/test/test_persistent_dict.py
@@ -899,6 +899,60 @@ def method(self):
 # }}}
 
 
+# {{{ basic concurrency test
+
+def _mp_fn(tmpdir: str) -> None:
+    import time
+    pdict: PersistentDict[int, int] = PersistentDict("pytools-test",
+                                                     container_dir=tmpdir,
+                                                     safe_sync=False)
+    n = 10000
+    s = 0
+
+    start = time.time()
+
+    for i in range(n):
+        if i % 100 == 0:
+            print(f"i={i}")
+        pdict[i] = i
+
+        try:
+            s += pdict[i]
+        except NoSuchEntryError:
+            # Someone else already deleted the entry
+            pass
+
+        try:
+            del pdict[i]
+        except NoSuchEntryError:
+            # Someone else already deleted the entry
+            pass
+
+    end = time.time()
+
+    print(f"PersistentDict: time taken to write {n} entries to "
+          f"{pdict.filename}: {end-start} s={s}")
+
+
+def test_concurrency() -> None:
+    from multiprocessing import Process
+
+    tmpdir = "_tmp/"  # must be the same across all processes in this test
+
+    try:
+        p = [Process(target=_mp_fn, args=(tmpdir, )) for _ in range(4)]
+        for pp in p:
+            pp.start()
+        for pp in p:
+            pp.join()
+
+        assert all(pp.exitcode == 0 for pp in p), [pp.exitcode for pp in p]
+    finally:
+        shutil.rmtree(tmpdir)
+
+# }}}
+
+
 if __name__ == "__main__":
     if len(sys.argv) > 1:
         exec(sys.argv[1])
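The new test forks four processes that write, read, and delete the same keys in one shared on-disk dictionary, and passes only if every process exits cleanly. Given the exec() hook at the end of the test file, one way to run just this test directly (a pytest invocation works as well) should be:

    python pytools/test/test_persistent_dict.py "test_concurrency()"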
1 change: 1 addition & 0 deletions setup.py
@@ -2,6 +2,7 @@
 
 from setuptools import find_packages, setup
 
+
 ver_dic = {}
 version_file = open("pytools/version.py")
 try: