Skip to content

Commit

Permalink
fix: doccache compatibility (#2032)
Browse files Browse the repository at this point in the history
  • Loading branch information
cristianmtr committed Feb 24, 2021
1 parent fbb8099 commit f45ec5d
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 4 deletions.
11 changes: 9 additions & 2 deletions jina/executors/indexers/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@

import pickle
import tempfile
from typing import Optional, Iterable, List, Tuple
from typing import Optional, Iterable, List, Tuple, Union

from jina.executors.indexers import BaseKVIndexer
from jina.helper import deprecated_alias

DATA_FIELD = 'data'
ID_KEY = 'id'
Expand Down Expand Up @@ -71,11 +72,17 @@ def close(self):

default_fields = (ID_KEY,)

def __init__(self, index_filename: Optional[str] = None, fields: Optional[Tuple[str]] = None, *args, **kwargs):
@deprecated_alias(field=('fields', 0))
def __init__(self,
             index_filename: Optional[str] = None,
             fields: Optional[Union[str, Tuple[str, ...]]] = None,  # str for backwards compatibility
             *args, **kwargs):
    """Create the cache indexer.

    :param index_filename: file name handed to the parent constructor; when it is
        empty or ``None`` a new temporary file is created and its path is used
    :param fields: field name(s) to store on ``self.fields``. A single string is
        accepted for backwards compatibility and wrapped into a one-element tuple;
        when empty or ``None``, ``self.default_fields`` is used instead
    :param args: extra positional arguments forwarded to the parent constructor
    :param kwargs: extra keyword arguments forwarded to the parent constructor
    """
    if not index_filename:
        # create a new temp file if not exist; only its path is needed, so the
        # handle is closed right away (delete=False keeps the file on disk)
        with tempfile.NamedTemporaryFile(delete=False) as tmp_file:
            index_filename = tmp_file.name
    super().__init__(index_filename, *args, **kwargs)
    if isinstance(fields, str):
        # legacy callers passed one field name as a plain string
        fields = (fields,)
    # order shouldn't matter
    self.fields = sorted(fields or self.default_fields)

Expand Down
39 changes: 37 additions & 2 deletions tests/unit/drivers/test_cache_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,10 @@ def test_cache_driver_from_file(tmpdir, test_metas):
folder = os.path.join(test_metas["workspace"])
bin_full_path = os.path.join(folder, filename)
docs = DocumentSet(list(random_docs(10, embedding=False)))
pickle.dump({doc.id: BaseCacheDriver.hash_doc(doc, ['content_hash']) for doc in docs}, open(f'{bin_full_path}.bin.ids', 'wb'))
pickle.dump({BaseCacheDriver.hash_doc(doc, ['content_hash']): doc.id for doc in docs}, open(f'{bin_full_path}.bin.cache', 'wb'))
pickle.dump({doc.id: BaseCacheDriver.hash_doc(doc, ['content_hash']) for doc in docs},
open(f'{bin_full_path}.bin.ids', 'wb'))
pickle.dump({BaseCacheDriver.hash_doc(doc, ['content_hash']): doc.id for doc in docs},
open(f'{bin_full_path}.bin.cache', 'wb'))

driver = MockCacheDriver()
with DocCache(metas=test_metas, fields=(CONTENT_HASH_KEY,)) as executor:
Expand Down Expand Up @@ -266,3 +268,36 @@ def test_hash():
d2.tags['b'] = '23456'
assert BaseCacheDriver.hash_doc(d1, ['tags__a', 'tags__b']) == BaseCacheDriver.hash_doc(d1, ['tags__a', 'tags__b'])
assert BaseCacheDriver.hash_doc(d1, ['tags__a', 'tags__b']) != BaseCacheDriver.hash_doc(d2, ['tags__a', 'tags__b'])


def test_cache_legacy_field_type(tmp_path, test_metas):
    """Check that ``DocCache`` accepts the legacy ``field`` keyword given as a single string."""

    def _single_doc_set(doc_id, text):
        # one document with a refreshed content hash, wrapped in a DocumentSet
        doc = Document(id=doc_id)
        doc.text = text
        doc.update_content_hash()
        return DocumentSet([doc])

    cache_file = os.path.join(tmp_path, 'DocCache.bin')

    first_docs = _single_doc_set(1, 'blabla')
    # same id as the first document, different text
    second_docs = _single_doc_set(1, 'blabla2')
    # different id, same text as the first document
    third_docs = _single_doc_set(12312, 'blabla')

    driver = MockBaseCacheDriver()

    with DocCache(cache_file, metas=test_metas, field=CONTENT_HASH_KEY) as executor:
        driver.attach(executor=executor, runtime=None)
        assert executor.fields == [CONTENT_HASH_KEY]
        for doc_set in (first_docs, second_docs):
            driver._apply_all(doc_set)
        assert executor.size == 2

    # reload the saved executor and make sure the field setting survived
    with BaseExecutor.load(executor.save_abspath) as executor:
        driver.attach(executor=executor, runtime=None)
        assert executor.fields == [CONTENT_HASH_KEY]
        # presumably the mock driver raises on a cache hit -- verify against MockBaseCacheDriver
        with pytest.raises(NotImplementedError):
            driver._apply_all(third_docs)

0 comments on commit f45ec5d

Please sign in to comment.