PersistentDict: concurrency improvements #231

Merged · 14 commits · Jun 7, 2024
100 changes: 61 additions & 39 deletions pytools/persistent_dict.py
@@ -463,21 +463,21 @@ def __init__(self, identifier: str,
# https://www.sqlite.org/lang_transaction.html#implicit_versus_explicit_transactions
self.conn = sqlite3.connect(self.filename, isolation_level=None)

self.conn.execute(
self._exec_sql(
"CREATE TABLE IF NOT EXISTS dict "
"(keyhash TEXT NOT NULL PRIMARY KEY, key_value TEXT NOT NULL)"
)

# https://www.sqlite.org/wal.html
if enable_wal:
self.conn.execute("PRAGMA journal_mode = 'WAL'")
self._exec_sql("PRAGMA journal_mode = 'WAL'")

# Note: the following configuration values were taken mostly from litedict:
# https://github.com/litements/litedict/blob/377603fa597453ffd9997186a493ed4fd23e5399/litedict.py#L67-L70

# Use in-memory temp store
# https://www.sqlite.org/pragma.html#pragma_temp_store
self.conn.execute("PRAGMA temp_store = 'MEMORY'")
self._exec_sql("PRAGMA temp_store = 'MEMORY'")

# fsync() can be extremely slow on some systems.
# See https://github.com/inducer/pytools/issues/227 for context.
@@ -493,13 +493,13 @@ def __init__(self, identifier: str,
"Pass 'safe_sync=False' if occasional data loss is tolerable. "
"Pass 'safe_sync=True' to suppress this warning.",
stacklevel=3)
self.conn.execute("PRAGMA synchronous = 'NORMAL'")
self._exec_sql("PRAGMA synchronous = 'NORMAL'")
else:
self.conn.execute("PRAGMA synchronous = 'OFF'")
self._exec_sql("PRAGMA synchronous = 'OFF'")

# 64 MByte of cache
# https://www.sqlite.org/pragma.html#pragma_cache_size
self.conn.execute("PRAGMA cache_size = -64000")
self._exec_sql("PRAGMA cache_size = -64000")

def __del__(self) -> None:
if self.conn:
@@ -522,6 +522,28 @@ def _collision_check(self, key: K, stored_key: K) -> None:
stored_key == key # pylint:disable=pointless-statement # noqa: B015
raise NoSuchEntryCollisionError(key)

def _exec_sql(self, *args: Any) -> sqlite3.Cursor:
return self._exec_sql_fn(lambda: self.conn.execute(*args))

def _exec_sql_fn(self, fn: Any) -> Any:
n = 0

while True:
n += 1
try:
return fn()
except sqlite3.OperationalError as e:
# If the database is busy, retry
if (hasattr(e, "sqlite_errorcode")
and not e.sqlite_errorcode == sqlite3.SQLITE_BUSY):
raise
if n % 20 == 0:
from warnings import warn
warn(f"PersistentDict: database '{self.filename}' busy, {n} "
"retries")
else:
break
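The retry wrapper above is the core of the concurrency change: every statement now funnels through _exec_sql_fn, which retries on SQLITE_BUSY instead of surfacing the error. A minimal standalone sketch of the same pattern (assuming Python 3.11+, where sqlite3 exposes sqlite_errorcode and SQLITE_BUSY; the function name is illustrative):

import sqlite3

def retry_while_busy(fn):
    """Call fn(), retrying for as long as SQLite reports SQLITE_BUSY."""
    n = 0
    while True:
        n += 1
        try:
            return fn()
        except sqlite3.OperationalError as e:
            # Re-raise anything that is not a busy error. (Like the diff,
            # treat a missing sqlite_errorcode attribute as "retry".)
            code = getattr(e, "sqlite_errorcode", None)
            if code is not None and code != sqlite3.SQLITE_BUSY:
                raise
            if n % 20 == 0:
                from warnings import warn
                warn(f"database busy, {n} retries")

# usage, mirroring _exec_sql:
# cur = retry_while_busy(lambda: conn.execute("SELECT COUNT(*) FROM dict"))

Note that the loop does not spin hot: sqlite3.connect applies a default busy timeout of five seconds, so each failing attempt already waits inside SQLite before raising.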

def store_if_not_present(self, key: K, value: V) -> None:
"""Store (*key*, *value*) if *key* is not already present."""
self.store(key, value, _skip_if_present=True)
@@ -548,30 +570,30 @@ def __setitem__(self, key: K, value: V) -> None:

def __len__(self) -> int:
"""Return the number of entries in the dictionary."""
return next(self.conn.execute("SELECT COUNT(*) FROM dict"))[0]
return next(self._exec_sql("SELECT COUNT(*) FROM dict"))[0]

def __iter__(self) -> Generator[K, None, None]:
"""Return an iterator over the keys in the dictionary."""
return self.keys()

def keys(self) -> Generator[K, None, None]:
"""Return an iterator over the keys in the dictionary."""
for row in self.conn.execute("SELECT key_value FROM dict ORDER BY rowid"):
for row in self._exec_sql("SELECT key_value FROM dict ORDER BY rowid"):
yield pickle.loads(row[0])[0]

def values(self) -> Generator[V, None, None]:
"""Return an iterator over the values in the dictionary."""
for row in self.conn.execute("SELECT key_value FROM dict ORDER BY rowid"):
for row in self._exec_sql("SELECT key_value FROM dict ORDER BY rowid"):
yield pickle.loads(row[0])[1]

def items(self) -> Generator[tuple[K, V], None, None]:
"""Return an iterator over the items in the dictionary."""
for row in self.conn.execute("SELECT key_value FROM dict ORDER BY rowid"):
for row in self._exec_sql("SELECT key_value FROM dict ORDER BY rowid"):
yield pickle.loads(row[0])

def nbytes(self) -> int:
"""Return the size of the dictionary in bytes."""
return next(self.conn.execute("SELECT page_size * page_count FROM "
return next(self._exec_sql("SELECT page_size * page_count FROM "
"pragma_page_size(), pragma_page_count()"))[0]

def __repr__(self) -> str:
@@ -580,7 +602,7 @@ def __repr__(self) -> str:

def clear(self) -> None:
"""Remove all entries from the dictionary."""
self.conn.execute("DELETE FROM dict")
self._exec_sql("DELETE FROM dict")


class WriteOncePersistentDict(_PersistentDictBase[K, V]):
@@ -644,11 +666,11 @@ def store(self, key: K, value: V, _skip_if_present: bool = False) -> None:
v = pickle.dumps((key, value))

if _skip_if_present:
self.conn.execute("INSERT OR IGNORE INTO dict VALUES (?, ?)",
self._exec_sql("INSERT OR IGNORE INTO dict VALUES (?, ?)",
(keyhash, v))
else:
try:
self.conn.execute("INSERT INTO dict VALUES (?, ?)", (keyhash, v))
self._exec_sql("INSERT INTO dict VALUES (?, ?)", (keyhash, v))
except sqlite3.IntegrityError as e:
if hasattr(e, "sqlite_errorcode"):
if e.sqlite_errorcode == sqlite3.SQLITE_CONSTRAINT_PRIMARYKEY:
@@ -662,7 +684,7 @@ def store(self, key: K, value: V, _skip_if_present: bool = False) -> None:

def _fetch(self, keyhash: str) -> Tuple[K, V]: # pylint:disable=method-hidden
# This method is separate from fetch() to allow for LRU caching
c = self.conn.execute("SELECT key_value FROM dict WHERE keyhash=?",
c = self._exec_sql("SELECT key_value FROM dict WHERE keyhash=?",
(keyhash,))
row = c.fetchone()
if row is None:
@@ -730,17 +752,15 @@ def store(self, key: K, value: V, _skip_if_present: bool = False) -> None:
keyhash = self.key_builder(key)
v = pickle.dumps((key, value))

if _skip_if_present:
self.conn.execute("INSERT OR IGNORE INTO dict VALUES (?, ?)",
(keyhash, v))
else:
self.conn.execute("INSERT OR REPLACE INTO dict VALUES (?, ?)",
mode = "IGNORE" if _skip_if_present else "REPLACE"

self._exec_sql(f"INSERT OR {mode} INTO dict VALUES (?, ?)",
(keyhash, v))
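The consolidated INSERT relies on SQLite's conflict clauses: since keyhash is the PRIMARY KEY, INSERT OR IGNORE keeps an existing row while INSERT OR REPLACE overwrites it, matching the two branches of the old code. A quick illustration against a throwaway in-memory database (names are illustrative):

import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE dict (keyhash TEXT PRIMARY KEY, key_value TEXT)")
conn.execute("INSERT INTO dict VALUES ('h', 'old')")

conn.execute("INSERT OR IGNORE INTO dict VALUES ('h', 'new')")
print(conn.execute("SELECT key_value FROM dict").fetchone())   # ('old',)

conn.execute("INSERT OR REPLACE INTO dict VALUES ('h', 'new')")
print(conn.execute("SELECT key_value FROM dict").fetchone())   # ('new',)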

def fetch(self, key: K) -> V:
keyhash = self.key_builder(key)

c = self.conn.execute("SELECT key_value FROM dict WHERE keyhash=?",
c = self._exec_sql("SELECT key_value FROM dict WHERE keyhash=?",
(keyhash,))
row = c.fetchone()
if row is None:
@@ -754,24 +774,26 @@ def remove(self, key: K) -> None:
"""Remove the entry associated with *key* from the dictionary."""
keyhash = self.key_builder(key)

self.conn.execute("BEGIN EXCLUSIVE TRANSACTION")

try:
# This is split into SELECT/DELETE to allow for a collision check
c = self.conn.execute("SELECT key_value FROM dict WHERE keyhash=?",
(keyhash,))
row = c.fetchone()
if row is None:
raise NoSuchEntryError(key)

stored_key, _value = pickle.loads(row[0])
self._collision_check(key, stored_key)

self.conn.execute("DELETE FROM dict WHERE keyhash=?", (keyhash,))
self.conn.execute("COMMIT")
except Exception as e:
self.conn.execute("ROLLBACK")
raise e
def remove_inner() -> None:
self.conn.execute("BEGIN EXCLUSIVE TRANSACTION")
try:
# This is split into SELECT/DELETE to allow for a collision check
c = self.conn.execute("SELECT key_value FROM dict WHERE "
"keyhash=?", (keyhash,))
row = c.fetchone()
if row is None:
raise NoSuchEntryError(key)

stored_key, _value = pickle.loads(row[0])
self._collision_check(key, stored_key)

self.conn.execute("DELETE FROM dict WHERE keyhash=?", (keyhash,))
self.conn.execute("COMMIT")
except Exception as e:
self.conn.execute("ROLLBACK")
raise e

self._exec_sql_fn(remove_inner)

def __delitem__(self, key: K) -> None:
"""Remove the entry associated with *key* from the dictionary."""
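A design note on remove(): the SELECT/DELETE pair must stay atomic for the collision check to mean anything, so the whole transaction is handed to _exec_sql_fn as a callable. When BEGIN EXCLUSIVE TRANSACTION raises SQLITE_BUSY because another process holds the write lock, the wrapper re-runs remove_inner from the top rather than retrying one statement in the middle of a transaction. The same shape works for any multi-statement unit; a rough sketch (names are illustrative):

def run_atomically(conn, body):
    # body() must be safe to re-run from scratch, since the busy-retry
    # wrapper may invoke this whole function more than once.
    conn.execute("BEGIN EXCLUSIVE TRANSACTION")
    try:
        result = body()
        conn.execute("COMMIT")
        return result
    except Exception:
        conn.execute("ROLLBACK")
        raise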
54 changes: 54 additions & 0 deletions pytools/test/test_persistent_dict.py
@@ -899,6 +899,60 @@ def method(self):
# }}}


# {{{ basic concurrency test

def _mp_fn(tmpdir: str) -> None:
import time
pdict: PersistentDict[int, int] = PersistentDict("pytools-test",
container_dir=tmpdir,
safe_sync=False)
n = 10000
s = 0

start = time.time()

for i in range(n):
if i % 100 == 0:
print(f"i={i}")
pdict[i] = i

try:
s += pdict[i]
except NoSuchEntryError:
# Someone else already deleted the entry
pass

try:
del pdict[i]
except NoSuchEntryError:
# Someone else already deleted the entry
pass

end = time.time()

print(f"PersistentDict: time taken to write {n} entries to "
f"{pdict.filename}: {end-start} s={s}")


def test_concurrency() -> None:
from multiprocessing import Process

tmpdir = "_tmp/" # must be the same across all processes in this test

try:
p = [Process(target=_mp_fn, args=(tmpdir, )) for _ in range(4)]
for pp in p:
pp.start()
for pp in p:
pp.join()

assert all(pp.exitcode == 0 for pp in p), [pp.exitcode for pp in p]
finally:
shutil.rmtree(tmpdir)

# }}}
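The test above forks four processes that hammer one on-disk dictionary with interleaved writes, reads, and deletes, and passes only if every process exits cleanly. To run just this test from a checkout (assuming the usual pytest setup):

python -m pytest pytools/test/test_persistent_dict.py::test_concurrency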


if __name__ == "__main__":
if len(sys.argv) > 1:
exec(sys.argv[1])
1 change: 1 addition & 0 deletions setup.py
@@ -2,6 +2,7 @@

from setuptools import find_packages, setup


ver_dic = {}
version_file = open("pytools/version.py")
try: