
Commit

fix: keys handling
cristianmtr committed Jan 13, 2021
1 parent 69def5d commit d00290e
Showing 7 changed files with 55 additions and 37 deletions.
26 changes: 20 additions & 6 deletions jina/executors/indexers/__init__.py
@@ -162,17 +162,31 @@ def flush(self):
         except:
             pass
 
-    def _filter_nonexistent_keys(self, keys: Iterator, existent_keys: Iterator, check_path: str):
-        indices_to_drop = []
+    def _filter_nonexistent_keys_values(self, keys: Iterator, values: Iterator, existent_keys: Iterator, check_path: str) -> Tuple[List, List]:
+        keys = list(keys)
+        values = list(values)
+        indices_to_drop = self._get_indices_to_drop(keys, existent_keys, check_path)
+        keys = [keys[i] for i in range(len(keys)) if i not in indices_to_drop]
+        values = [values[i] for i in range(len(values)) if i not in indices_to_drop]
+        return keys, values
+
+    def _filter_nonexistent_keys(self, keys: Iterator, existent_keys: Iterator, check_path: str) -> List:
+        keys = list(keys)
+        indices_to_drop = self._get_indices_to_drop(keys, existent_keys, check_path)
+        keys = [keys[i] for i in range(len(keys)) if i not in indices_to_drop]
+        return keys
+
+    def _get_indices_to_drop(self, keys: List, existent_keys: Iterator, check_path: str):
+        indices_to_drop = []
         for key_index, key in enumerate(keys):
             if key not in existent_keys:
                 indices_to_drop.append(key_index)
         if indices_to_drop:
-            self.logger.warning(
-                f'Key(s) {[keys[i] for i in indices_to_drop]} were not found in {check_path}. Continuing anyway...')
-        keys = [keys[i] for i in range(len(keys)) if i not in indices_to_drop]
-        return keys
+            # TODO
+            pass
+            # self.logger.warning(
+            #     f'Key(s) {[keys[i] for i in indices_to_drop]} were not found in {check_path}. Continuing anyway...')
+        return indices_to_drop
 
 
 class BaseVectorIndexer(BaseIndexer):
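The refactor above splits the old _filter_nonexistent_keys into three helpers so that keys and their values can be dropped at the same positions and stay aligned. A minimal standalone sketch of that idea, assuming plain module-level functions with illustrative names rather than the real BaseIndexer methods:

from typing import Iterable, List, Sequence, Tuple


def get_indices_to_drop(keys: Sequence, existent_keys: Iterable) -> List[int]:
    """Return the positions of keys that are not present in existent_keys."""
    existent = set(existent_keys)
    return [i for i, key in enumerate(keys) if key not in existent]


def filter_nonexistent_keys(keys: Iterable, existent_keys: Iterable) -> List:
    """Drop keys that were never indexed."""
    keys = list(keys)
    drop = set(get_indices_to_drop(keys, existent_keys))
    return [k for i, k in enumerate(keys) if i not in drop]


def filter_nonexistent_keys_values(keys: Iterable, values: Iterable, existent_keys: Iterable) -> Tuple[List, List]:
    """Drop missing keys and the values at the same positions, keeping both lists aligned."""
    keys, values = list(keys), list(values)
    drop = set(get_indices_to_drop(keys, existent_keys))
    kept_keys = [k for i, k in enumerate(keys) if i not in drop]
    kept_values = [v for i, v in enumerate(values) if i not in drop]
    return kept_keys, kept_values


# updating keys 'a' and 'x' when only 'a' and 'b' exist keeps the aligned pair ('a', 1)
print(filter_nonexistent_keys_values(['a', 'x'], [1, 2], ['a', 'b']))  # (['a'], [1])

Building the drop set once and filtering both lists with it is what keeps a later zip(keys, values) from pairing a surviving key with the wrong value.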
17 changes: 9 additions & 8 deletions jina/executors/indexers/cache.py
@@ -67,8 +67,6 @@ def __init__(self, index_filename: str = None, *args, **kwargs):
             raise ValueError(f"Field '{self.field}' not in supported list of {self.supported_fields}")
 
     def add(self, doc_id: 'UniqueId', *args, **kwargs):
-        # TODO
-        self._size += 1
         self.query_handler.ids.append(doc_id)
 
         # optimization. don't duplicate ids
@@ -77,6 +75,7 @@ def add(self, doc_id: 'UniqueId', *args, **kwargs):
         if data is None:
             raise ValueError(f'Got None from CacheDriver')
         self.query_handler.content_hash.append(data)
+        self._size += 1
 
     def query(self, data, *args, **kwargs) -> Optional[bool]:
         """
@@ -99,13 +98,15 @@ def update(self, keys: Iterator['UniqueId'], values: Iterator[any], *args, **kwa
         """
         :param keys: list of Document.id
        :param values: list of either `id` or `content_hash` of :class:`Document"""
-        keys = self._filter_nonexistent_keys(keys, self.query_handler.ids, self.save_abspath)
+        # if we don't cache anything else, no need
+        if self.field != ID_KEY:
+            keys, values = self._filter_nonexistent_keys_values(keys, values, self.query_handler.ids, self.save_abspath)
 
-        for key, cached_field in zip(keys, values):
-            key_idx = self.query_handler.ids.index(key)
-            # optimization. don't duplicate ids
-            if self.field != ID_KEY:
-                self.query_handler.content_hash[key_idx] = cached_field
+            for key, cached_field in zip(keys, values):
+                key_idx = self.query_handler.ids.index(key)
+                # optimization. don't duplicate ids
+                if self.field != ID_KEY:
+                    self.query_handler.content_hash[key_idx] = cached_field
 
     def delete(self, keys: Iterator['UniqueId'], *args, **kwargs):
         """
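In the update above, filtering and the content_hash rewrite now only run when the cache stores more than the document ID. A hedged sketch of that control flow, assuming a toy handler with parallel ids/content_hash lists and a placeholder ID_KEY constant instead of the real jina types:

from typing import Iterable, List

ID_KEY = 'id'  # placeholder for the real constant


class ToyCacheHandler:
    """Simplified stand-in for the cache query handler: parallel lists of ids and hashes."""

    def __init__(self) -> None:
        self.ids: List[str] = []
        self.content_hash: List[bytes] = []


def update_cache(handler: ToyCacheHandler, field: str, keys: Iterable[str], values: Iterable[bytes]) -> None:
    """Mirror of the updated control flow: only touch content_hash when more than ids is cached."""
    keys, values = list(keys), list(values)
    if field != ID_KEY:
        # keep only keys that are actually cached, and keep values aligned with them
        existent = set(handler.ids)
        pairs = [(k, v) for k, v in zip(keys, values) if k in existent]
        for key, cached_field in pairs:
            key_idx = handler.ids.index(key)
            handler.content_hash[key_idx] = cached_field


handler = ToyCacheHandler()
handler.ids = ['doc1', 'doc2']
handler.content_hash = [b'old1', b'old2']
update_cache(handler, 'content_hash', ['doc2', 'missing'], [b'new2', b'ignored'])
print(handler.content_hash)  # [b'old1', b'new2']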
4 changes: 3 additions & 1 deletion jina/executors/indexers/keyvalue.py
@@ -55,6 +55,8 @@ def __init__(self, *args, **kwargs):
         self._page_size = mmap.ALLOCATIONGRANULARITY
 
     def add(self, keys: Iterator[int], values: Iterator[bytes], *args, **kwargs):
+        if len(keys) != len(values):
+            raise ValueError(f'Len of keys {len(keys)} did not match len of values {len(values)}')
         for key, value in zip(keys, values):
             l = len(value) #: the length
             p = int(self._start / self._page_size) * self._page_size #: offset of the page
@@ -78,7 +80,7 @@ def query(self, key: int) -> Optional[bytes]:
             return m[r:]
 
     def update(self, keys: Iterator[int], values: Iterator[bytes], *args, **kwargs):
-        keys = self._filter_nonexistent_keys(keys, self.query_handler.header.keys(), self.save_abspath)
+        keys, values = self._filter_nonexistent_keys_values(keys, values, self.query_handler.header.keys(), self.save_abspath)
         self._delete(keys)
         self.add(keys, values)
         return
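The new length guard in add exists because zip(keys, values) silently stops at the shorter input, so a mismatched call would drop records without any error. A small illustration of the failure mode and the check in isolation (a toy add_pairs helper, not the actual indexer):

from typing import Sequence


def add_pairs(keys: Sequence[int], values: Sequence[bytes]) -> dict:
    """Store key/value pairs, refusing silently-truncated input."""
    if len(keys) != len(values):
        raise ValueError(f'Len of keys {len(keys)} did not match len of values {len(values)}')
    return dict(zip(keys, values))


print(add_pairs([1, 2], [b'a', b'b']))        # {1: b'a', 2: b'b'}
# without the guard, zip would quietly drop key 3:
print(dict(zip([1, 2, 3], [b'a', b'b'])))     # {1: b'a', 2: b'b'}
try:
    add_pairs([1, 2, 3], [b'a', b'b'])
except ValueError as e:
    print(e)                                  # Len of keys 3 did not match len of values 2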
3 changes: 2 additions & 1 deletion jina/executors/indexers/vector.py
@@ -127,7 +127,8 @@ def add(self, keys: 'np.ndarray', vectors: 'np.ndarray', *args, **kwargs) -> Non
         self._size += keys.shape[0]
 
     def update(self, keys: Sequence[int], values: Sequence[bytes], *args, **kwargs) -> None:
-        keys = self._filter_nonexistent_keys(keys, self.ext2int_id.keys(), self.save_abspath)
+        # noinspection PyTypeChecker
+        keys, values = self._filter_nonexistent_keys_values(keys, values, self.ext2int_id.keys(), self.save_abspath)
         # could be empty
         # please do not use "if keys:", it wont work on both sequence and ndarray
         if getattr(keys, 'size', len(keys)):
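The retained comment above warns against "if keys:" because truthiness of a multi-element NumPy array is ambiguous and raises. The getattr(keys, 'size', len(keys)) idiom used in the hunk works for plain sequences and arrays alike; a quick sketch:

import numpy as np


def has_items(keys) -> bool:
    """True if keys is non-empty; works for lists, tuples and np.ndarray alike."""
    # ndarray exposes .size; plain sequences fall back to len()
    return bool(getattr(keys, 'size', len(keys)))


print(has_items([]))                   # False
print(has_items([1, 2, 3]))            # True
print(has_items(np.array([])))         # False
print(has_items(np.array([10, 20])))   # True
# by contrast, `if np.array([10, 20]):` raises "truth value of an array ... is ambiguous"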
38 changes: 19 additions & 19 deletions tests/integration/docidcache/test_crud_cache.py
@@ -37,8 +37,8 @@ def config_env(field, tmp_workspace, shards, indexers, polling):
 
 
 np.random.seed(0)
-d_embedding = np.random.random([9])
-c_embedding = np.random.random([9])
+d_embedding = np.array([1, 1, 1])
+c_embedding = np.array([2, 2, 2])
 
 
 def get_documents(chunks, same_content, nr=10, index_start=0):
@@ -51,7 +51,7 @@ def get_documents(chunks, same_content, nr=10, index_start=0):
                 d.embedding = d_embedding
             else:
                 d.text = f'hello world {i}'
-                d.embedding = np.random.random([9])
+                d.embedding = np.random.random(d_embedding.shape)
             for j in range(chunks):
                 with Document() as c:
                     c.id = next_chunk_id
@@ -60,7 +60,7 @@ def get_documents(chunks, same_content, nr=10, index_start=0):
                         c.embedding = c_embedding
                     else:
                         c.text = f'hello world from chunk {j}'
-                        c.embedding = np.random.random([9])
+                        c.embedding = np.random.random(d_embedding.shape)
 
                     next_chunk_id += 1
                     d.chunks.append(c)
@@ -74,7 +74,7 @@ def get_documents(chunks, same_content, nr=10, index_start=0):
 
 @pytest.mark.parametrize('chunks', [0, 3, 5])
 @pytest.mark.parametrize('same_content', [False, True])
-@pytest.mark.parametrize('nr', [0, 10, 100])
+@pytest.mark.parametrize('nr', [0, 10, 100, 201])
 def test_docs_generator(chunks, same_content, nr):
     chunk_content = None
     docs = list(get_documents(chunks=chunks, same_content=same_content, nr=nr))
@@ -126,18 +126,18 @@ def check_docs(chunk_content, chunks, same_content, docs, ids_used, index_start=
 
 @pytest.mark.parametrize('indexers, field, shards, chunks, same_content',
                          [
-                             # ('sequential', 'id', 1, 5, False),
-                             # ('sequential', 'id', 3, 5, False),
+                             ('sequential', 'id', 1, 5, False),
+                             ('sequential', 'id', 3, 5, False),
                              ('sequential', 'id', 3, 5, True),
-                             # ('sequential', 'content_hash', 1, 0, False),
-                             # ('sequential', 'content_hash', 1, 0, True),
-                             # ('sequential', 'content_hash', 1, 5, False),
-                             # ('sequential', 'content_hash', 1, 5, True),
-                             # ('sequential', 'content_hash', 3, 5, True),
-                             # ('parallel', 'id', 3, 5, False),
-                             # ('parallel', 'id', 3, 5, True),
-                             # ('parallel', 'content_hash', 3, 5, False),
-                             # ('parallel', 'content_hash', 3, 5, True)
+                             ('sequential', 'content_hash', 1, 0, False),
+                             ('sequential', 'content_hash', 1, 0, True),
+                             ('sequential', 'content_hash', 1, 5, False),
+                             ('sequential', 'content_hash', 1, 5, True),
+                             ('sequential', 'content_hash', 3, 5, True),
+                             ('parallel', 'id', 3, 5, False),
+                             ('parallel', 'id', 3, 5, True),
+                             ('parallel', 'content_hash', 3, 5, False),
+                             ('parallel', 'content_hash', 3, 5, True)
                          ])
 def test_cache_crud(
         tmp_path,
@@ -157,7 +157,7 @@ def validate_results(resp):
 
         return validate_results
 
-    print(f'{tmp_path=}')
+    print(f'tmp path = {tmp_path}')
 
    config_env(field, tmp_path, shards, indexers, polling='any')
    f = Flow.load_config(os.path.abspath('yml/crud_cache_flow.yml'))
@@ -166,7 +166,7 @@ def validate_results(resp):
 
     # initial data index
     with f:
-        f.index(docs, batch_size=3)
+        f.index(docs, batch_size=4)
 
     check_indexers_size(chunks, len(docs), field, tmp_path, same_content, shards, 'index')
 
@@ -176,7 +176,7 @@ def validate_results(resp):
 
     new_docs = list(get_documents(chunks=chunks, same_content=same_content, index_start=index_start_new_docs))
     with f:
-        f.index(new_docs, batch_size=3)
+        f.index(new_docs, batch_size=4)
 
     check_indexers_size(chunks, len(docs), field, tmp_path, same_content, shards, 'index2')
 
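The test changes above replace the random 9-dimensional embeddings with small fixed vectors and size the remaining random ones off d_embedding.shape, so same-content documents stay identical across runs and shards. A reduced sketch of that generator pattern, assuming a stand-in dataclass rather than jina's Document:

from dataclasses import dataclass
from typing import List

import numpy as np

d_embedding = np.array([1, 1, 1])
c_embedding = np.array([2, 2, 2])


@dataclass
class ToyDoc:
    """Stand-in for jina's Document: just an id, text and embedding."""
    id: int
    text: str
    embedding: np.ndarray


def get_toy_documents(nr: int = 10, same_content: bool = False) -> List[ToyDoc]:
    docs = []
    for i in range(nr):
        if same_content:
            # identical payload for every doc -> identical content hash -> deduplicated by the cache
            docs.append(ToyDoc(i, 'hello world', d_embedding))
        else:
            # unique payload, but the embedding shape still matches the fixed vectors
            docs.append(ToyDoc(i, f'hello world {i}', np.random.random(d_embedding.shape)))
    return docs


docs = get_toy_documents(nr=3, same_content=False)
print(len(docs), docs[0].embedding.shape)  # 3 (3,)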
2 changes: 1 addition & 1 deletion tests/integration/docidcache/yml/cp_cache_kv.yml
@@ -59,4 +59,4 @@ requests:
         with:
           executor: $JINA_KV_IDX_NAME
           top_k: $JINA_TOPK
-          traversal_paths: [r, m]
+          traversal_paths: [r]
2 changes: 1 addition & 1 deletion tests/integration/docidcache/yml/cp_cache_vector.yml
@@ -61,4 +61,4 @@ requests:
           executor: $JINA_VEC_IDX_NAME
           top_k: $JINA_TOPK
           fill_embedding: True
-          traversal_paths: [r, m]
+          traversal_paths: [r]
