refactor: cache runtime optimization (#1853)
* refactor: cache runtime optimization

* refactor: cache driver test
florian-hoenicke committed Feb 4, 2021
1 parent 5c07c3e commit a5c2213
Showing 5 changed files with 36 additions and 48 deletions.
6 changes: 3 additions & 3 deletions jina/drivers/cache.py
@@ -33,10 +33,10 @@ def _apply_all(self, docs: 'DocumentSet', *args, **kwargs) -> None:
             if self.field == CONTENT_HASH_KEY:
                 data = d.content_hash
             result = self.exec[data]
-            if result is None:
-                self.on_miss(d, data)
-            else:
+            if result:
                 self.on_hit(d, result)
+            else:
+                self.on_miss(d, data)
 
     def on_miss(self, req_doc: 'Document', data: Any) -> None:
         """Function to call when document is missing, the default behavior is to add to cache when miss.
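
Note: the driver-side change pairs with the indexer refactor below. `self.exec[data]` (backed by the indexer's `query`) now returns a plain boolean instead of `True`/`None`, so the truthiness check `if result:` dispatches hits and misses correctly for either return style. A minimal, self-contained sketch of that dispatch, using a plain dict in place of the real executor (the `FakeCacheDriver` name and the sample data are illustrative, not part of the commit):

# Minimal sketch of the hit/miss dispatch after this change (illustrative only;
# the real driver works on jina Document / DocumentSet objects).
class FakeCacheDriver:
    def __init__(self, cache: dict):
        self.cache = cache  # stands in for the executor behind self.exec[data]

    def on_hit(self, doc, result):
        print(f'hit: {doc} -> {result}')

    def on_miss(self, doc, data):
        print(f'miss: {doc}, caching {data}')
        self.cache[data] = doc

    def apply(self, docs):
        for d, data in docs:
            result = self.cache.get(data)
            # `if result:` treats both None and False as a miss, matching the
            # indexer's query() now returning True/False instead of True/None
            if result:
                self.on_hit(d, result)
            else:
                self.on_miss(d, data)


driver = FakeCacheDriver({})
driver.apply([('doc1', 'hash-a'), ('doc2', 'hash-a')])  # doc1 misses, doc2 hits
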
2 changes: 1 addition & 1 deletion jina/executors/indexers/__init__.py
@@ -201,7 +201,7 @@ def _filter_nonexistent_keys_values(self, keys: Iterable, values: Iterable, exis
         keys = list(keys)
         values = list(values)
         if len(keys) != len(values):
-            raise ValueError(f'Keys of length {len(keys)} did not match values of lenth {len(values)}')
+            raise ValueError(f'Keys of length {len(keys)} did not match values of length {len(values)}')
         indices_to_drop = self._get_indices_to_drop(keys, existent_keys, check_path)
         keys = [keys[i] for i in range(len(keys)) if i not in indices_to_drop]
         values = [values[i] for i in range(len(values)) if i not in indices_to_drop]
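
Note: the only change here is the 'lenth' -> 'length' typo in the error message, but the surrounding context shows the filtering pattern the indexer relies on: compute the positions of keys the index does not know about, then drop those positions from both the key list and the value list. A rough standalone sketch of that idea, with a simplified stand-in for `_get_indices_to_drop` (whose real signature also takes a check path):

from typing import Iterable, List, Set, Tuple


def get_indices_to_drop(keys: List[str], existent_keys: Set[str]) -> Set[int]:
    # positions whose key is not known to the indexer
    return {i for i, k in enumerate(keys) if k not in existent_keys}


def filter_nonexistent_keys_values(keys: Iterable, values: Iterable,
                                   existent_keys: Set[str]) -> Tuple[list, list]:
    keys = list(keys)
    values = list(values)
    if len(keys) != len(values):
        raise ValueError(f'Keys of length {len(keys)} did not match values of length {len(values)}')
    indices_to_drop = get_indices_to_drop(keys, existent_keys)
    keys = [keys[i] for i in range(len(keys)) if i not in indices_to_drop]
    values = [values[i] for i in range(len(values)) if i not in indices_to_drop]
    return keys, values


assert filter_nonexistent_keys_values(['a', 'b', 'c'], [1, 2, 3], {'a', 'c'}) == (['a', 'c'], [1, 3])
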
65 changes: 27 additions & 38 deletions jina/executors/indexers/cache.py
@@ -8,10 +8,6 @@
 ID_KEY = 'id'
 CONTENT_HASH_KEY = 'content_hash'
 
-# noinspection PyUnreachableCode
-if False:
-    from jina.types.document import UniqueId
-
 
 class BaseCache(BaseKVIndexer):
     """Base class of the cache inherited :class:`BaseKVIndexer`
@@ -37,18 +33,17 @@ class CacheHandler:
         def __init__(self, path, logger):
             self.path = path
             try:
-                # TODO maybe mmap?
-                self.ids = pickle.load(open(path + '.ids', 'rb'))
-                self.content_hash = pickle.load(open(path + '.cache', 'rb'))
+                self.id_to_cache_val = pickle.load(open(path + '.ids', 'rb'))
+                self.cache_val_to_id = pickle.load(open(path + '.cache', 'rb'))
             except FileNotFoundError as e:
                 logger.warning(
                     f'File path did not exist : {path}.ids or {path}.cache: {e!r}. Creating new CacheHandler...')
-                self.ids = []
-                self.content_hash = []
+                self.id_to_cache_val = dict()
+                self.cache_val_to_id = dict()
 
         def close(self):
-            pickle.dump(self.ids, open(self.path + '.ids', 'wb'))
-            pickle.dump(self.content_hash, open(self.path + '.cache', 'wb'))
+            pickle.dump(self.id_to_cache_val, open(self.path + '.ids', 'wb'))
+            pickle.dump(self.cache_val_to_id, open(self.path + '.cache', 'wb'))
 
     supported_fields = [ID_KEY, CONTENT_HASH_KEY]
     default_field = ID_KEY
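
Note: the heart of the runtime optimization is this storage change in the nested `CacheHandler`: the two parallel lists (`ids`, `content_hash`) become a pair of dicts, one mapping a document id to its cached value and one mapping the cached value back to the id, still pickled to the `.ids` and `.cache` files. A minimal sketch of that persistence round trip using only the standard library (the paths and sample values are made up):

import os
import pickle
import tempfile

# forward and reverse maps, mirroring id_to_cache_val / cache_val_to_id
id_to_cache_val = {'doc-1': 'hash-a', 'doc-2': 'hash-b'}
cache_val_to_id = {v: k for k, v in id_to_cache_val.items()}

path = os.path.join(tempfile.mkdtemp(), 'cache.bin')
with open(path + '.ids', 'wb') as f:
    pickle.dump(id_to_cache_val, f)
with open(path + '.cache', 'wb') as f:
    pickle.dump(cache_val_to_id, f)

# reloading restores O(1) lookups in both directions
with open(path + '.cache', 'rb') as f:
    restored = pickle.load(f)
assert restored['hash-a'] == 'doc-1'
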
@@ -67,19 +62,17 @@ def __init__(self, index_filename: Optional[str] = None, field: Optional[str] =
         if self.field not in self.supported_fields:
             raise ValueError(f"Field '{self.field}' not in supported list of {self.supported_fields}")
 
-    def add(self, doc_id: 'UniqueId', *args, **kwargs):
+    def add(self, doc_id: str, *args, **kwargs):
         """Add a document to the cache depending on `self.field`.
         :param doc_id: document id to be added
         """
-        self.query_handler.ids.append(doc_id)
-
-        # optimization. don't duplicate ids
         if self.field != ID_KEY:
             data = kwargs.get(DATA_FIELD, None)
             if data is None:
                 raise ValueError(f'Got None from CacheDriver')
-            self.query_handler.content_hash.append(data)
+        else:
+            data = doc_id
+        self.query_handler.id_to_cache_val[doc_id] = data
+        self.query_handler.cache_val_to_id[data] = doc_id
         self._size += 1
 
     def query(self, data, *args, **kwargs) -> Optional[bool]:
@@ -88,39 +81,35 @@ def query(self, data, *args, **kwargs) -> Optional[bool]:
         :param data: either the id or the content_hash of a Document
         :return: status
         """
-        if self.field == ID_KEY:
-            status = (data in self.query_handler.ids) or None
-        else:
-            status = (data in self.query_handler.content_hash) or None
-
-        return status
+        return data in self.query_handler.cache_val_to_id
 
-    def update(self, keys: Iterable['UniqueId'], values: Iterable[any], *args, **kwargs):
+    def update(self, keys: Iterable[str], values: Iterable[any], *args, **kwargs):
         """Update cached documents.
         :param keys: list of Document.id
         :param values: list of either `id` or `content_hash` of :class:`Document`"""
-        # if we don't cache anything else, no need
-        if self.field != ID_KEY:
-            keys, values = self._filter_nonexistent_keys_values(keys, values, self.query_handler.ids, self.save_abspath)
-
-            for key, cached_field in zip(keys, values):
-                key_idx = self.query_handler.ids.index(key)
-                # optimization. don't duplicate ids
-                if self.field != ID_KEY:
-                    self.query_handler.content_hash[key_idx] = cached_field
+        for key, value in zip(keys, values):
+            if key not in self.query_handler.id_to_cache_val:
+                continue
+            old_value = self.query_handler.id_to_cache_val[key]
+            self.query_handler.id_to_cache_val[key] = value
+            del self.query_handler.cache_val_to_id[old_value]
+            self.query_handler.cache_val_to_id[value] = key
 
-    def delete(self, keys: Iterable['UniqueId'], *args, **kwargs):
+    def delete(self, keys: Iterable[str], *args, **kwargs):
         """Delete documents from the cache.
         :param keys: list of Document.id
         """
-        keys = self._filter_nonexistent_keys(keys, self.query_handler.ids, self.save_abspath)
-
         for key in keys:
-            key_idx = self.query_handler.ids.index(key)
-            self.query_handler.ids = [query_id for idx, query_id in enumerate(self.query_handler.ids) if idx != key_idx]
-            if self.field != ID_KEY:
-                self.query_handler.content_hash = [cached_field for idx, cached_field in
-                                                   enumerate(self.query_handler.content_hash) if idx != key_idx]
+            if key not in self.query_handler.id_to_cache_val:
+                continue
+            value = self.query_handler.id_to_cache_val[key]
+            del self.query_handler.id_to_cache_val[key]
+            del self.query_handler.cache_val_to_id[value]
             self._size -= 1
 
     def get_add_handler(self):
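
Note: taken together, the new methods replace list scans with dict operations. The old `query` did an `in` check against a Python list, and `update`/`delete` called `list.index()` and rebuilt whole lists, all O(n) per document; the new code uses constant-time dict membership tests, assignments, and deletions, at the price of keeping the forward and reverse maps in sync. A condensed, standalone sketch of the new bookkeeping (names shortened; this is not the actual `DocCache` class):

class TwoWayCache:
    """Dict-backed sketch of the refactored cache bookkeeping."""

    def __init__(self):
        self.id_to_val = {}   # Document.id -> cached value (id or content_hash)
        self.val_to_id = {}   # cached value -> Document.id
        self.size = 0

    def add(self, doc_id: str, value: str):
        self.id_to_val[doc_id] = value
        self.val_to_id[value] = doc_id
        self.size += 1

    def query(self, value: str) -> bool:
        # O(1) membership test instead of scanning a list
        return value in self.val_to_id

    def update(self, doc_id: str, value: str):
        if doc_id not in self.id_to_val:
            return
        old_value = self.id_to_val[doc_id]
        self.id_to_val[doc_id] = value
        del self.val_to_id[old_value]
        self.val_to_id[value] = doc_id

    def delete(self, doc_id: str):
        if doc_id not in self.id_to_val:
            return
        value = self.id_to_val.pop(doc_id)
        del self.val_to_id[value]
        self.size -= 1


cache = TwoWayCache()
cache.add('doc-1', 'hash-a')
assert cache.query('hash-a') is True
cache.update('doc-1', 'hash-b')
assert cache.query('hash-a') is False and cache.query('hash-b') is True
cache.delete('doc-1')
assert cache.size == 0
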
1 change: 0 additions & 1 deletion tests/integration/docidcache/test_crud_cache.py
@@ -141,7 +141,6 @@ def check_docs(chunk_content, chunks, same_content, docs, ids_used, index_start=
 
 def check_indexers_size(chunks, nr_docs, field, tmp_path, same_content, shards, post_op):
     cache_indexer_path = tmp_path / 'cache.bin'
-    cache_full_size = 0
     with BaseIndexer.load(cache_indexer_path) as cache:
         assert isinstance(cache, DocCache)
         cache_full_size = cache.size
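
Note: the deleted line was a redundant pre-initialization. A `with` block does not introduce a new scope in Python, so a name bound inside it, like `cache_full_size = cache.size`, remains bound after the block exits. A tiny illustration, independent of the jina test code:

from contextlib import nullcontext

with nullcontext():
    cache_full_size = 42  # bound inside the with block

# `with` does not create a new scope, so the name is still visible here
assert cache_full_size == 42
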
10 changes: 5 additions & 5 deletions tests/unit/drivers/test_cache_driver.py
@@ -76,8 +76,8 @@ def test_cache_driver_from_file(tmpdir, test_metas):
     folder = os.path.join(test_metas["workspace"])
     bin_full_path = os.path.join(folder, filename)
     docs = list(random_docs(10, embedding=False))
-    pickle.dump([doc.id for doc in docs], open(f'{bin_full_path}.ids', 'wb'))
-    pickle.dump([doc.content_hash for doc in docs], open(f'{bin_full_path}.cache', 'wb'))
+    pickle.dump({doc.id: doc.content_hash for doc in docs}, open(f'{bin_full_path}.ids', 'wb'))
+    pickle.dump({doc.content_hash: doc.id for doc in docs}, open(f'{bin_full_path}.cache', 'wb'))
 
     driver = MockCacheDriver()
     with DocCache(filename, metas=test_metas, field=CONTENT_HASH_KEY) as executor:
@@ -150,18 +150,18 @@ def test_cache_content_driver_same_content(tmpdir, test_metas):
 
     with BaseExecutor.load(filename) as executor:
         assert executor.query(doc1.content_hash) is True
-        assert executor.query(old_doc.content_hash) is None
+        assert executor.query(old_doc.content_hash) is False
 
     # delete
     with BaseExecutor.load(filename) as executor:
         executor.delete([doc1.id])
 
     with BaseExecutor.load(filename) as executor:
-        assert executor.query(doc1.content_hash) is None
+        assert executor.query(doc1.content_hash) is False
 
 
 def test_cache_content_driver_same_id(tmp_path, test_metas):
-    filename = tmp_path / 'DocCache.bin'
+    filename = os.path.join(tmp_path, 'DocCache.bin')
     doc1 = Document(id=1)
     doc1.text = 'blabla'
     doc1.update_content_hash()
