fix: missing keys indices
cristianmtr committed Jan 14, 2021
1 parent 5530733 commit 5c8a5e3
Showing 4 changed files with 36 additions and 15 deletions.
22 changes: 18 additions & 4 deletions jina/executors/indexers/__init__.py
@@ -162,17 +162,31 @@ def flush(self):
         except:
             pass
 
-    def _filter_nonexistent_keys(self, keys: Iterator, existent_keys: Iterator, check_path: str):
-        indices_to_drop = []
+    def _filter_nonexistent_keys_values(self, keys: Iterator, values: Iterator, existent_keys: Iterator, check_path: str) -> Tuple[List, List]:
+        keys = list(keys)
+        values = list(values)
+        if len(keys) != len(values):
+            raise ValueError(f'Keys of length {len(keys)} did not match values of length {len(values)}')
+        indices_to_drop = self._get_indices_to_drop(keys, existent_keys, check_path)
+        keys = [keys[i] for i in range(len(keys)) if i not in indices_to_drop]
+        values = [values[i] for i in range(len(values)) if i not in indices_to_drop]
+        return keys, values
+
+    def _filter_nonexistent_keys(self, keys: Iterator, existent_keys: Iterator, check_path: str) -> List:
+        keys = list(keys)
+        indices_to_drop = self._get_indices_to_drop(keys, existent_keys, check_path)
+        keys = [keys[i] for i in range(len(keys)) if i not in indices_to_drop]
+        return keys
+
+    def _get_indices_to_drop(self, keys: List, existent_keys: Iterator, check_path: str):
+        indices_to_drop = []
         for key_index, key in enumerate(keys):
             if key not in existent_keys:
                 indices_to_drop.append(key_index)
         if indices_to_drop:
             self.logger.warning(
                 f'Key(s) {[keys[i] for i in indices_to_drop]} were not found in {check_path}. Continuing anyway...')
-        keys = [keys[i] for i in range(len(keys)) if i not in indices_to_drop]
-        return keys
+        return indices_to_drop
 
 
 class BaseVectorIndexer(BaseIndexer):
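Note: a minimal standalone sketch of the new helper pair, assuming plain module-level functions with no logger; `existent_keys` stands in for whatever the concrete indexer passes (e.g. `self.query_handler.header.keys()`):

    from typing import Iterator, List, Tuple

    def get_indices_to_drop(keys: List, existent_keys) -> List[int]:
        # positions whose key is absent from the index
        return [i for i, key in enumerate(keys) if key not in existent_keys]

    def filter_nonexistent_keys_values(keys: Iterator, values: Iterator, existent_keys) -> Tuple[List, List]:
        keys, values = list(keys), list(values)
        if len(keys) != len(values):
            raise ValueError(f'Keys of length {len(keys)} did not match values of length {len(values)}')
        drop = set(get_indices_to_drop(keys, existent_keys))
        # both lists are filtered with the same index set, so pairs stay aligned
        keys = [k for i, k in enumerate(keys) if i not in drop]
        values = [v for i, v in enumerate(values) if i not in drop]
        return keys, values

    keys, values = filter_nonexistent_keys_values(['a', 'b', 'c'], [b'1', b'2', b'3'], existent_keys={'a', 'c'})
    assert keys == ['a', 'c'] and values == [b'1', b'3']

Filtering values at the same indices as keys is the point of the fix: previously only keys were filtered, so keys and values could drift out of alignment.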
22 changes: 13 additions & 9 deletions jina/executors/indexers/cache.py
@@ -67,7 +67,6 @@ def __init__(self, index_filename: str = None, *args, **kwargs):
raise ValueError(f"Field '{self.field}' not in supported list of {self.supported_fields}")

def add(self, doc_id: 'UniqueId', *args, **kwargs):
self._size += 1
self.query_handler.ids.append(doc_id)

# optimization. don't duplicate ids
@@ -76,6 +75,7 @@ def add(self, doc_id: 'UniqueId', *args, **kwargs):
             if data is None:
                 raise ValueError(f'Got None from CacheDriver')
             self.query_handler.content_hash.append(data)
+        self._size += 1
 
     def query(self, data, *args, **kwargs) -> Optional[bool]:
         """
@@ -98,13 +98,15 @@ def update(self, keys: Iterator['UniqueId'], values: Iterator[any], *args, **kwargs):
"""
:param keys: list of Document.id
:param values: list of either `id` or `content_hash` of :class:`Document"""
keys = self._filter_nonexistent_keys(keys, self.query_handler.ids, self.save_abspath)
# if we don't cache anything else, no need
if self.field != ID_KEY:
keys, values = self._filter_nonexistent_keys_values(keys, values, self.query_handler.ids, self.save_abspath)

for key, cached_field in zip(keys, values):
key_idx = self.query_handler.ids.index(key)
# optimization. don't duplicate ids
if self.field != ID_KEY:
self.query_handler.content_hash[key_idx] = cached_field
for key, cached_field in zip(keys, values):
key_idx = self.query_handler.ids.index(key)
# optimization. don't duplicate ids
if self.field != ID_KEY:
self.query_handler.content_hash[key_idx] = cached_field

def delete(self, keys: Iterator['UniqueId'], *args, **kwargs):
"""
@@ -122,11 +124,13 @@ def delete(self, keys: Iterator['UniqueId'], *args, **kwargs):

     def get_add_handler(self):
         # not needed, as we use the queryhandler
-        pass
+        # FIXME better way to silence warnings
+        return 1
 
     def get_query_handler(self) -> CacheHandler:
         return self.CacheHandler(self.index_abspath, self.logger)
 
     def get_create_handler(self):
         # not needed, as we use the queryhandler
-        pass
+        # FIXME better way to silence warnings
+        return 1
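The `self._size += 1` move is a correctness fix: `add` can raise between appending the id and appending the content hash. A toy reproduction of the ordering, not the real `DocIDCache` (the driver lookup is reduced to a plain argument):

    class ToyCache:
        def __init__(self):
            self.ids, self.content_hash, self._size = [], [], 0

        def add(self, doc_id, data):
            self.ids.append(doc_id)
            if data is None:
                raise ValueError('Got None from CacheDriver')
            self.content_hash.append(data)
            self._size += 1  # incremented only after everything succeeded

    cache = ToyCache()
    try:
        cache.add('doc1', None)
    except ValueError:
        pass
    assert cache._size == 0  # with the old ordering this would be 1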
4 changes: 3 additions & 1 deletion jina/executors/indexers/keyvalue.py
@@ -55,6 +55,8 @@ def __init__(self, *args, **kwargs):
         self._page_size = mmap.ALLOCATIONGRANULARITY
 
     def add(self, keys: Iterator[int], values: Iterator[bytes], *args, **kwargs):
+        if len(keys) != len(values):
+            raise ValueError(f'Len of keys {len(keys)} did not match len of values {len(values)}')
         for key, value in zip(keys, values):
             l = len(value)  #: the length
             p = int(self._start / self._page_size) * self._page_size  #: offset of the page
@@ -78,7 +80,7 @@ def query(self, key: int) -> Optional[bytes]:
             return m[r:]
 
     def update(self, keys: Iterator[int], values: Iterator[bytes], *args, **kwargs):
-        keys = self._filter_nonexistent_keys(keys, self.query_handler.header.keys(), self.save_abspath)
+        keys, values = self._filter_nonexistent_keys_values(keys, values, self.query_handler.header.keys(), self.save_abspath)
         self._delete(keys)
         self.add(keys, values)
         return
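The new length guard in `add` exists because `zip` truncates silently; a small sketch with hypothetical inputs showing the failure mode it catches:

    keys = [1, 2, 3]
    values = [b'a', b'b']  # one value short

    # zip() stops at the shorter input, so key 3 would silently get no entry:
    assert list(zip(keys, values)) == [(1, b'a'), (2, b'b')]

    # the added guard surfaces the mismatch instead of truncating:
    try:
        if len(keys) != len(values):
            raise ValueError(f'Len of keys {len(keys)} did not match len of values {len(values)}')
    except ValueError as e:
        print(e)  # Len of keys 3 did not match len of values 2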
3 changes: 2 additions & 1 deletion jina/executors/indexers/vector.py
@@ -127,7 +127,8 @@ def add(self, keys: 'np.ndarray', vectors: 'np.ndarray', *args, **kwargs) -> None:
         self._size += keys.shape[0]
 
     def update(self, keys: Sequence[int], values: Sequence[bytes], *args, **kwargs) -> None:
-        keys = self._filter_nonexistent_keys(keys, self.ext2int_id.keys(), self.save_abspath)
+        # noinspection PyTypeChecker
+        keys, values = self._filter_nonexistent_keys_values(keys, values, self.ext2int_id.keys(), self.save_abspath)
         # could be empty
         # please do not use "if keys:", it wont work on both sequence and ndarray
         if getattr(keys, 'size', len(keys)):
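The `please do not use "if keys:"` comment refers to numpy semantics: truth-testing a multi-element ndarray raises a ValueError, so the code probes `.size` and falls back to `len()`. A quick demonstration (assumes numpy is installed):

    import numpy as np

    def non_empty(keys) -> bool:
        # ndarrays expose .size; plain sequences fall back to len()
        return bool(getattr(keys, 'size', len(keys)))

    assert non_empty([1, 2]) and not non_empty([])
    assert non_empty(np.array([1, 2])) and not non_empty(np.array([]))

    # the naive `if keys:` is not an option for arrays:
    # bool(np.array([1, 2])) raises "ValueError: The truth value of an array
    # with more than one element is ambiguous."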
