Merge pull request #338 from njgheorghita/sqlite-dict
Add lru sql dict tooling
njgheorghita committed Dec 28, 2020
2 parents c11553a + c73f3fd commit 341e84e
Showing 8 changed files with 648 additions and 8 deletions.
3 changes: 3 additions & 0 deletions ddht/tools/driver/alexandria.py
@@ -17,6 +17,7 @@
from ddht.v5_1.abc import NetworkAPI
from ddht.v5_1.alexandria.abc import AlexandriaClientAPI, AlexandriaNetworkAPI
from ddht.v5_1.alexandria.advertisement_db import AdvertisementDatabase
from ddht.v5_1.alexandria.broadcast_log import BroadcastLog
from ddht.v5_1.alexandria.client import AlexandriaClient
from ddht.v5_1.alexandria.content_storage import MemoryContentStorage
from ddht.v5_1.alexandria.network import AlexandriaNetwork
@@ -35,6 +36,7 @@ def __init__(self, node: NodeAPI) -> None:
        self.remote_advertisement_db = AdvertisementDatabase(
            sqlite3.connect(":memory:"),
        )
        self.broadcast_log = BroadcastLog(sqlite3.connect(":memory:"))
        self._lock = NamedLock()

    @property
@@ -82,6 +84,7 @@ async def network(
            commons_content_storage=self.commons_content_storage,
            pinned_content_storage=self.pinned_content_storage,
            local_advertisement_db=self.local_advertisement_db,
            broadcast_log=self.broadcast_log,
            remote_advertisement_db=self.remote_advertisement_db,
            max_advertisement_count=max_advertisement_count,
        )
275 changes: 275 additions & 0 deletions ddht/tools/lru_sql_dict.py
@@ -0,0 +1,275 @@
import sqlite3
from typing import (
    Any,
    Callable,
    Generic,
    ItemsView,
    Iterator,
    MutableMapping,
    NamedTuple,
    Optional,
    Tuple,
    TypeVar,
    ValuesView,
)

from eth_utils.toolz import first

#
# SQL Schema
#
# key (primary key): bytes
# value: bytes
# pref: bytes - None if node is head, else points to the previous node's key
# nref: bytes - None if node is tail, else points to the next node's key


CREATE_CACHE_QUERY = """
CREATE TABLE IF NOT EXISTS cache (
    key BLOB NOT NULL PRIMARY KEY,
    value BLOB NOT NULL,
    pref BLOB UNIQUE,
    nref BLOB UNIQUE
)
"""

TKey = TypeVar("TKey")
TValue = TypeVar("TValue")


class Node(Generic[TKey, TValue], NamedTuple):
    key: TKey
    value: TValue
    pref: TKey
    nref: TKey


class LRUSQLDict(MutableMapping[TKey, TValue]):
    """
    SQLite3-backed dictionary that implements an LRU cache.
    """

    def __init__(
        self,
        conn: sqlite3.Connection,
        key_encoder: Callable[[TKey], bytes],
        key_decoder: Callable[[bytes], TKey],
        value_encoder: Callable[[TValue], bytes],
        value_decoder: Callable[[bytes], TValue],
        cache_size: Optional[int] = None,
    ) -> None:
        self._conn = conn
        self.cache_size = cache_size
        self._key_encoder = key_encoder
        self._key_decoder = key_decoder
        self._value_encoder = value_encoder
        self._value_decoder = value_decoder

        self._execute(CREATE_CACHE_QUERY)

        # evict lru key/value if local db size > current cache size
        if self.cache_size:
            while self.__len__() > self.cache_size:
                self.__delitem__(self.tail.key)

    def __iter__(self) -> Iterator[TKey]:
        with self._conn:
            for key in self._conn.execute("SELECT key FROM cache"):
                yield self._key_decoder(first(key))

    def __len__(self) -> int:
        (result,) = self._fetch_single_query("SELECT COUNT(*) FROM cache;")
        # ignore b/c mypy cannot interpret the result as an integer
        return result  # type: ignore

    def __setitem__(self, key: TKey, value: TValue) -> None:
        # setting / updating a key/value will move the pair to the head of the lru cache
        try:
            self.__getitem__(key)
        except KeyError:
            self._insert_item(key, value)
        else:
            self._update_item(key, value)

    def __getitem__(self, key: TKey) -> TValue:
        # accessing a key/value will move the pair to the head of the lru cache
        serialized_key = self._key_encoder(key)
        lookup_result = self._fetch_single_query(
            "SELECT value FROM cache WHERE key=?;", (serialized_key,),
        )

        if not lookup_result:
            raise KeyError(key)

        (value,) = lookup_result

        deserialized_value = self._value_decoder(value)

        # update cache
        # TODO: rather than move kv pair to head by deleting and re-inserting,
        # change this to update all outdated references directly
        self.__delitem__(key)
        self._insert_item(key, deserialized_value)

        return deserialized_value

    def __delitem__(self, key: TKey) -> None:
        serialized_key = self._key_encoder(key)
        result = self._fetch_single_query(
            "SELECT key FROM cache WHERE key=?;", (serialized_key,),
        )

        if not result:
            raise KeyError(key)

        node_key = first(result)

        # delete key from cache
        self._execute(
            "DELETE FROM cache WHERE key=?;", (node_key,),
        )

        # update any nrefs/prefs in cache
        nref_result = self._fetch_single_query(
            "SELECT key FROM cache WHERE nref=?;", (node_key,),
        )
        pref_result = self._fetch_single_query(
            "SELECT key FROM cache WHERE pref=?;", (node_key,),
        )

        if nref_result and pref_result:
            nref_key = first(nref_result)
            pref_key = first(pref_result)
            self._execute(
                "UPDATE cache SET nref=? WHERE key=?;", (pref_key, nref_key),
            )
            self._execute(
                "UPDATE cache SET pref=? WHERE key=?;", (nref_key, pref_key),
            )
        elif nref_result:
            self._execute(
                "UPDATE cache SET nref=? WHERE key=?;", (None, first(nref_result)),
            )
        elif pref_result:
            self._execute(
                "UPDATE cache SET pref=? WHERE key=?;", (None, first(pref_result)),
            )

    def _insert_item(self, key: TKey, value: TValue) -> None:
        # insert new item into map and cache
        serialized_key = self._key_encoder(key)
        serialized_value = self._value_encoder(value)
        if self.is_empty:
            new_pref = None
            new_nref = None

            self._execute(
                "INSERT INTO cache VALUES (?, ?, ?, ?);",
                (serialized_key, serialized_value, new_pref, new_nref),
            )

        else:
            # evict lru key/value if local db size >= current cache size
            while self.is_full:
                self.__delitem__(self.tail.key)

            # get old head
            old_head_key = self._key_encoder(self.head.key)

            # add new head to cache
            self._execute(
                "INSERT INTO cache VALUES (?, ?, ?, ?);",
                (serialized_key, serialized_value, None, old_head_key),
            )

            # update old head in cache
            self._execute(
                "UPDATE cache SET pref=? WHERE key=?;", (serialized_key, old_head_key),
            )

    def _update_item(self, key: TKey, value: TValue) -> None:
        serialized_key = self._key_encoder(key)
        serialized_value = self._value_encoder(value)
        self._execute(
            "UPDATE cache SET value=? WHERE key=?;", (serialized_value, serialized_key)
        )

    def _fetch_single_query(self, query: str, args: Tuple[Any, ...] = ()) -> Any:
        with self._conn:
            cursor = self._conn.execute(query, args).fetchall()
        if len(cursor) > 1:
            raise Exception(
                f"Invalid db state. More than one result found for query: {query}."
            )
        if not cursor:
            return None
        return first(cursor)

    def _execute(self, query: str, args: Tuple[Any, ...] = ()) -> None:
        with self._conn:
            self._conn.execute(query, args)

    @property
    def is_full(self) -> bool:
        if not self.cache_size:
            return False
        return self.__len__() >= self.cache_size

    @property
    def is_empty(self) -> bool:
        return self.__len__() == 0

    @property
    def head(self) -> Node:
        head = self._fetch_single_query(
            "SELECT key,value,pref,nref FROM cache WHERE pref IS NULL;"
        )
        if not head:
            raise KeyError("No head found.")
        deserialized_key = self._key_decoder(head[0])
        deserialized_value = self._value_decoder(head[1])
        return Node(deserialized_key, deserialized_value, head[2], head[3])

    @property
    def tail(self) -> Node:
        tail = self._fetch_single_query(
            "SELECT key,value,pref,nref FROM cache WHERE nref IS NULL;"
        )
        if not tail:
            raise KeyError("No tail found.")
        deserialized_key = self._key_decoder(tail[0])
        deserialized_value = self._value_decoder(tail[1])
        return Node(deserialized_key, deserialized_value, tail[2], tail[3])

    # custom iterator type not compatible with supertype
    def values(self) -> ValuesView[TValue]:  # type: ignore
        for key in self.__iter__():
            result = self._fetch_single_query(
                "SELECT value FROM cache WHERE key=?;", (self._key_encoder(key),)
            )
            yield self._value_decoder(first(result))

    # custom iterator type not compatible with supertype
    def items(self) -> ItemsView[TKey, TValue]:  # type: ignore
        for key in self.__iter__():
            result = self._fetch_single_query(
                "SELECT value FROM cache WHERE key=?;", (self._key_encoder(key),)
            )
            yield key, self._value_decoder(first(result))

    def iter_lru_cache(self) -> Iterator[Tuple[TKey, TValue]]:
        if self.is_empty:
            raise IndexError("Cannot iterate over empty dict.")

        head = self.head
        yield head.key, head.value
        nref = head.nref

        for _ in range(self.__len__() - 1):
            result = self._fetch_single_query(
                "SELECT key,value,pref,nref FROM cache WHERE key=?;", (nref,),
            )
            deserialized_key = self._key_decoder(result[0])
            deserialized_value = self._value_decoder(result[1])
            node = Node(deserialized_key, deserialized_value, result[2], result[3])
            yield node.key, node.value
            nref = node.nref
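
For reference, a minimal usage sketch of the LRUSQLDict added above. The string codecs and the cache_size of 2 are illustrative choices, not part of this diff:

import sqlite3

from ddht.tools.lru_sql_dict import LRUSQLDict

# String keys/values, encoded to the bytes the BLOB columns store.
lru_dict = LRUSQLDict(
    sqlite3.connect(":memory:"),
    key_encoder=lambda key: key.encode("utf8"),
    key_decoder=lambda key: key.decode("utf8"),
    value_encoder=lambda value: value.encode("utf8"),
    value_decoder=lambda value: value.decode("utf8"),
    cache_size=2,
)

lru_dict["a"] = "1"
lru_dict["b"] = "2"
lru_dict["a"]  # lookups move the pair to the head of the cache
lru_dict["c"] = "3"  # at capacity, so the tail ("b") is evicted
assert "b" not in lru_dict
assert [key for key, _ in lru_dict.iter_lru_cache()] == ["c", "a"]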
10 changes: 10 additions & 0 deletions ddht/v5_1/alexandria/abc.py
@@ -101,6 +101,16 @@ def was_logged(
    ) -> bool:
        ...

    @property
    @abstractmethod
    def cache_size(self) -> Optional[int]:
        ...

    @property
    @abstractmethod
    def count(self) -> int:
        ...


class ContentValidatorAPI(ABC):
    @abstractmethod
11 changes: 11 additions & 0 deletions ddht/v5_1/alexandria/app.py
@@ -8,6 +8,7 @@
from ddht.v5_1.alexandria.abc import AdvertisementDatabaseAPI, ContentStorageAPI
from ddht.v5_1.alexandria.advertisement_db import AdvertisementDatabase
from ddht.v5_1.alexandria.boot_info import AlexandriaBootInfo
from ddht.v5_1.alexandria.broadcast_log import BroadcastLog
from ddht.v5_1.alexandria.content_storage import (
    FileSystemContentStorage,
    MemoryContentStorage,
@@ -108,13 +109,17 @@ async def run(self) -> None:
            sqlite3.connect(str(remote_advertisement_db_path)),
        )

        broadcast_log_db_path = xdg_alexandria_root / "broadcast_log.sqlite3"
        broadcast_log = BroadcastLog(sqlite3.connect(str(broadcast_log_db_path)))

        alexandria_network = AlexandriaNetwork(
            network=self.base_protocol_app.network,
            bootnodes=self._alexandria_boot_info.bootnodes,
            commons_content_storage=commons_content_storage,
            pinned_content_storage=pinned_content_storage,
            local_advertisement_db=local_advertisement_db,
            remote_advertisement_db=remote_advertisement_db,
            broadcast_log=broadcast_log,
            commons_content_storage_max_size=commons_content_storage_max_size,
            max_advertisement_count=max_advertisement_count,
        )
@@ -147,6 +152,12 @@ async def run(self) -> None:
            remote_advertisement_db.count(),
            max_advertisement_count,
        )
        self.logger.info(
            "BroadcastLog: storage=%s total=%d max=%d",
            broadcast_log_db_path,
            broadcast_log.count,
            broadcast_log.cache_size,
        )

        await alexandria_network.ready()
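
BroadcastLog itself is among the 8 changed files but is not expanded in this view. Going only by the API surface visible above (was_logged, cache_size, and count in abc.py), here is a hypothetical sketch of how an LRUSQLDict-backed broadcast log could look; the class name, key layout, and default cache size are guesses, not the code from this PR:

import sqlite3
from typing import Optional

from ddht.tools.lru_sql_dict import LRUSQLDict


class BroadcastLogSketch:
    # Hypothetical: records which advertisements were broadcast to which
    # nodes, bounded by an LRU cache so the table cannot grow without limit.
    def __init__(self, conn: sqlite3.Connection, cache_size: Optional[int] = 2048) -> None:
        self._db: LRUSQLDict[bytes, bytes] = LRUSQLDict(
            conn,
            key_encoder=lambda key: key,  # keys are already bytes
            key_decoder=lambda key: key,
            value_encoder=lambda value: value,
            value_decoder=lambda value: value,
            cache_size=cache_size,
        )

    @property
    def cache_size(self) -> Optional[int]:
        return self._db.cache_size

    @property
    def count(self) -> int:
        return len(self._db)

    def log(self, node_id: bytes, ad_hash: bytes) -> None:
        self._db[node_id + ad_hash] = b""

    def was_logged(self, node_id: bytes, ad_hash: bytes) -> bool:
        # Membership checks go through __getitem__, which also refreshes
        # the entry's position in the LRU ordering.
        return (node_id + ad_hash) in self._db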
