Skip to content

Commit

Permalink
refactor: cache data field (#1878)
Browse files Browse the repository at this point in the history
* refactor: cache data field

* refactor: indexer interfaces

* fix: import vector indexer

* refactor: query result should be empty not none

* test: fix cache driver
  • Loading branch information
florian-hoenicke committed Feb 8, 2021
1 parent 31991f2 commit 2b1e6e6
Show file tree
Hide file tree
Showing 13 changed files with 96 additions and 93 deletions.
12 changes: 7 additions & 5 deletions jina/drivers/cache.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Any, Dict

from .index import BaseIndexDriver
from ..executors.indexers.cache import DATA_FIELD, CONTENT_HASH_KEY, ID_KEY
from ..executors.indexers.cache import CONTENT_HASH_KEY, ID_KEY

if False:
from .. import Document
Expand Down Expand Up @@ -38,16 +38,16 @@ def _apply_all(self, docs: 'DocumentSet', *args, **kwargs) -> None:
else:
self.on_miss(d, data)

def on_miss(self, req_doc: 'Document', data: Any) -> None:
def on_miss(self, req_doc: 'Document', value: str) -> None:
"""Function to call when document is missing, the default behavior is to add to cache when miss.
:param req_doc: the document in the request but missed in the cache
:param data: the data besides the `req_doc.id` to be passed through to the executors
:param value: the data besides the `req_doc.id` to be passed through to the executors
"""
if self.with_serialization:
self.exec_fn(req_doc.id, req_doc.SerializeToString(), **{DATA_FIELD: data})
self.exec_fn([req_doc.id], req_doc.SerializeToString(), [value])
else:
self.exec_fn(req_doc.id, **{DATA_FIELD: data})
self.exec_fn([req_doc.id], [value])

def on_hit(self, req_doc: 'Document', hit_result: Any) -> None:
"""Function to call when document is hit.
Expand All @@ -71,4 +71,6 @@ def __init__(self, tags: Dict, *args, **kwargs):
self._tags = tags

def on_hit(self, req_doc: 'Document', hit_result: Any) -> None:
"""Function to call when document is hit.
"""
req_doc.tags.update(self._tags)
22 changes: 10 additions & 12 deletions jina/drivers/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,15 +123,13 @@ def _apply_all(self, docs: 'DocumentSet', *args, **kwargs) -> None:
idx, dist = self.exec_fn(embed_vecs, top_k=int(self.top_k))

op_name = self.exec.__class__.__name__
# can be None if index is size 0
if idx is not None and dist is not None:
for doc, topks, scores in zip(doc_pts, idx, dist):

topk_embed = fill_fn(topks) if (self._fill_embedding and fill_fn) else [None] * len(topks)
for numpy_match_id, score, vec in zip(topks, scores, topk_embed):
m = Document(id=numpy_match_id)
m.score = NamedScore(op_name=op_name,
value=score)
r = doc.matches.append(m)
if vec is not None:
r.embedding = vec
for doc, topks, scores in zip(doc_pts, idx, dist):

topk_embed = fill_fn(topks) if (self._fill_embedding and fill_fn) else [None] * len(topks)
for numpy_match_id, score, vec in zip(topks, scores, topk_embed):
m = Document(id=numpy_match_id)
m.score = NamedScore(op_name=op_name,
value=score)
r = doc.matches.append(m)
if vec is not None:
r.embedding = vec
25 changes: 12 additions & 13 deletions jina/executors/indexers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,33 +203,32 @@ def query_by_key(self, keys: Iterable[str], *args, **kwargs) -> 'np.ndarray':
"""
raise NotImplementedError

def add(self, keys: Iterable[str], vectors: 'np.ndarray', *args, **kwargs):
def add(self, keys: Iterable[str], vectors: 'np.ndarray', *args, **kwargs) -> None:
"""Add new chunks and their vector representations
:param keys: a list of ``id``, i.e. ``doc.id`` in protobuf
:param vectors: vector representations in B x D
"""
raise NotImplementedError

def query(self, query_vectors: 'np.ndarray', top_k: int, *args, **kwargs) -> Tuple['np.ndarray', 'np.ndarray']:
def query(self, vectors: 'np.ndarray', top_k: int, *args, **kwargs) -> Tuple['np.ndarray', 'np.ndarray']:
"""Find k-NN using query vectors, return chunk ids and chunk scores
:param query_vectors: query vectors in ndarray, shape B x D
:param vectors: query vectors in ndarray, shape B x D
:param top_k: int, the number of nearest neighbour to return
:return: a tuple of two ndarray.
The first is ids in shape B x K (`dtype=int`), the second is scores in shape B x K (`dtype=float`)
:return: ids as ndarray (`dtype=int`) and scores as ndarray (`dtype=float`)
"""
raise NotImplementedError

def update(self, keys: Iterable[str], values: Iterable[bytes], *args, **kwargs):
def update(self, keys: Iterable[str], vectors: 'np.ndarray', *args, **kwargs) -> None:
"""Update vectors on the index.
:param keys: a list of ``id``, i.e. ``doc.id`` in protobuf
:param values: vector representations in B x D
:param vectors: vector representations in B x D
"""
raise NotImplementedError

def delete(self, keys: Iterable[str], *args, **kwargs):
def delete(self, keys: Iterable[str], *args, **kwargs) -> None:
"""Delete vectors from the index.
:param keys: a list of ``id``, i.e. ``doc.id`` in protobuf
Expand All @@ -245,38 +244,38 @@ class BaseKVIndexer(BaseIndexer):
It can be used to tell whether an indexer is key-value indexer, via ``isinstance(a, BaseKVIndexer)``
"""

def add(self, keys: Iterable[str], values: Iterable[bytes], *args, **kwargs):
def add(self, keys: Iterable[str], values: Iterable[bytes], *args, **kwargs) -> None:
"""Add the serialized documents to the index via document ids.
:param keys: a list of ``id``, i.e. ``doc.id`` in protobuf
:param values: serialized documents
"""
raise NotImplementedError

def query(self, key: Any) -> Optional[Any]:
def query(self, key: str) -> Optional[bytes]:
"""Find the serialized document to the index via document id.
:param key: document id
:return: the serialized document, or ``None`` if the key is not found
"""
raise NotImplementedError

def update(self, keys: Iterable[str], values: Iterable[bytes], *args, **kwargs):
def update(self, keys: Iterable[str], values: Iterable[bytes], *args, **kwargs) -> None:
"""Update the serialized documents on the index via document ids.
:param keys: a list of ``id``, i.e. ``doc.id`` in protobuf
:param values: serialized documents
"""
raise NotImplementedError

def delete(self, keys: Iterable[str], *args, **kwargs):
def delete(self, keys: Iterable[str], *args, **kwargs) -> None:
"""Delete the serialized documents from the index via document ids.
:param keys: a list of ``id``, i.e. ``doc.id`` in protobuf
"""
raise NotImplementedError

def __getitem__(self, key: Any) -> Optional[Any]:
def __getitem__(self, key: Any) -> Optional[bytes]:
return self.query(key)


Expand Down
35 changes: 17 additions & 18 deletions jina/executors/indexers/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
class BaseCache(BaseKVIndexer):
"""Base class of the cache inherited :class:`BaseKVIndexer`
The difference between a cache and a :class:`BaseKVIndexer` is the ``handler_mutex`` is released in cache, this allows one to query-while-indexing.
The difference between a cache and a :class:`BaseKVIndexer` is the ``handler_mutex`` is released in cache,
this allows one to query-while-indexing.
"""

def __init__(self, *args, **kwargs):
Expand Down Expand Up @@ -49,7 +50,7 @@ def close(self):
default_field = ID_KEY

def __init__(self, index_filename: Optional[str] = None, field: Optional[str] = None, *args, **kwargs):
""" Create a new DocCache
"""Create a new DocCache
:param index_filename: file name for storing the cache data
:param field: field to cache on (ID_KEY or CONTENT_HASH_KEY)
Expand All @@ -62,28 +63,26 @@ def __init__(self, index_filename: Optional[str] = None, field: Optional[str] =
if self.field not in self.supported_fields:
raise ValueError(f"Field '{self.field}' not in supported list of {self.supported_fields}")

def add(self, doc_id: str, *args, **kwargs):
"""Add a document to the cache depending on `self.field`.
def add(self, keys: Iterable[str], values: Iterable[str], *args, **kwargs) -> None:
"""Add a document to the cache depending.
:param doc_id: document id to be added
:param keys: document ids to be added
:param values: document cache values to be added
"""
if self.field != ID_KEY:
data = kwargs.get(DATA_FIELD, None)
else:
data = doc_id
self.query_handler.id_to_cache_val[doc_id] = data
self.query_handler.cache_val_to_id[data] = doc_id
self._size += 1

def query(self, data: str, *args, **kwargs) -> Optional[bool]:
for key, value in zip(keys, values):
self.query_handler.id_to_cache_val[key] = value
self.query_handler.cache_val_to_id[value] = key
self._size += 1

def query(self, key: str, *args, **kwargs) -> bool:
"""Check whether the data exists in the cache.
:param data: either the id or the content_hash of a Document
:param key: either the id or the content_hash of a Document
:return: status
"""
return data in self.query_handler.cache_val_to_id
return key in self.query_handler.cache_val_to_id

def update(self, keys: Iterable[str], values: Iterable[any], *args, **kwargs):
def update(self, keys: Iterable[str], values: Iterable[str], *args, **kwargs) -> None:
"""Update cached documents.
:param keys: list of Document.id
:param values: list of either `id` or `content_hash` of :class:`Document`"""
Expand All @@ -97,7 +96,7 @@ def update(self, keys: Iterable[str], values: Iterable[any], *args, **kwargs):
del self.query_handler.cache_val_to_id[old_value]
self.query_handler.cache_val_to_id[value] = key

def delete(self, keys: Iterable[str], *args, **kwargs):
def delete(self, keys: Iterable[str], *args, **kwargs) -> None:
"""Delete documents from the cache.
:param keys: list of Document.id
"""
Expand Down
20 changes: 13 additions & 7 deletions jina/executors/indexers/keyvalue.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,15 +40,21 @@ def __init__(self, path, key_length):
def close(self):
self._body.close()

def get_add_handler(self):
def get_add_handler(self) -> 'WriteHandler':
"""Get write file handler.
"""
# keep _start position as in pickle serialization
return self.WriteHandler(self.index_abspath, 'ab')

def get_create_handler(self):
def get_create_handler(self) -> 'WriteHandler':
"""Get write file handler.
"""
self._start = 0 # override _start position
return self.WriteHandler(self.index_abspath, 'wb')

def get_query_handler(self):
def get_query_handler(self) -> 'ReadHandler':
"""Get read file handler.
"""
return self.ReadHandler(self.index_abspath, self.key_length)

def __init__(self, *args, **kwargs):
Expand All @@ -57,7 +63,7 @@ def __init__(self, *args, **kwargs):
self._start = 0
self._page_size = mmap.ALLOCATIONGRANULARITY

def add(self, keys: Iterable[str], values: Iterable[bytes], *args, **kwargs):
def add(self, keys: Iterable[str], values: Iterable[bytes], *args, **kwargs) -> None:
"""Add the serialized documents to the index via document ids.
:param keys: a list of ``id``, i.e. ``doc.id`` in protobuf
Expand Down Expand Up @@ -93,7 +99,7 @@ def query(self, key: str) -> Optional[bytes]:
with mmap.mmap(self.query_handler.body, offset=p, length=l) as m:
return m[r:]

def update(self, keys: Iterable[str], values: Iterable[bytes], *args, **kwargs):
def update(self, keys: Iterable[str], values: Iterable[bytes], *args, **kwargs) -> None:
"""Update the serialized documents on the index via document ids.
:param keys: a list of ``id``, i.e. ``doc.id`` in protobuf
Expand All @@ -103,7 +109,7 @@ def update(self, keys: Iterable[str], values: Iterable[bytes], *args, **kwargs):
self._delete(keys)
self.add(keys, values)

def _delete(self, keys: Iterable[str]):
def _delete(self, keys: Iterable[str]) -> None:
self.query_handler.close()
self.handler_mutex = False
for key in keys:
Expand All @@ -118,7 +124,7 @@ def _delete(self, keys: Iterable[str]):
del self.query_handler.header[key]
self._size -= 1

def delete(self, keys: Iterable[str], *args, **kwargs):
def delete(self, keys: Iterable[str], *args, **kwargs) -> None:
"""Delete the serialized documents from the index via document ids.
:param keys: a list of ``id``, i.e. ``doc.id`` in protobuf
Expand Down
32 changes: 16 additions & 16 deletions jina/executors/indexers/vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@
__license__ = "Apache-2.0"

import gzip
import io
import os
from functools import lru_cache
from os import path
from typing import Optional, Iterable, Tuple, Dict, Sequence
from typing import Optional, Iterable, Tuple, Dict

import numpy as np

Expand Down Expand Up @@ -82,8 +83,8 @@ def index_abspath(self) -> str:
"""
return self.get_file_from_workspace(self.index_filename)

def get_add_handler(self):
"""Open a binary gzip file for adding new vectors
def get_add_handler(self) -> 'io.BufferedWriter':
"""Open a binary gzip file for appending new vectors
:return: a gzip file stream
"""
Expand All @@ -92,8 +93,8 @@ def get_add_handler(self):
else:
return open(self.index_abspath, 'ab')

def get_create_handler(self):
"""Create a new gzip file for adding new vectors
def get_create_handler(self) -> 'io.BufferedWriter':
"""Create a new gzip file for adding new vectors. The old vectors are replaced.
:return: a gzip file stream
"""
Expand Down Expand Up @@ -135,14 +136,14 @@ def _add(self, keys: 'np.ndarray', vectors: 'np.ndarray'):
self.key_bytes += keys.tobytes()
self._size += keys.shape[0]

def update(self, keys: Iterable[str], values: Sequence[bytes], *args, **kwargs) -> None:
def update(self, keys: Iterable[str], vectors: 'np.ndarray', *args, **kwargs) -> None:
"""Update the embeddings on the index via document ids.
:param keys: a list of ``id``, i.e. ``doc.id`` in protobuf
:param values: embeddings
:param vectors: embeddings
"""
# noinspection PyTypeChecker
keys, values = self._filter_nonexistent_keys_values(keys, values, self._ext2int_id.keys())
keys, values = self._filter_nonexistent_keys_values(keys, vectors, self._ext2int_id.keys())
np_keys = np.array(keys, (np.str_, self.key_length))

if np_keys.size:
Expand Down Expand Up @@ -209,7 +210,7 @@ def _raw_ndarray(self) -> Optional['np.ndarray']:
return np.memmap(self.index_abspath, dtype=self.dtype, mode='r',
shape=(self.size + deleted_keys, self.num_dim))

def query_by_key(self, keys: Sequence[str], *args, **kwargs) -> Optional['np.ndarray']:
def query_by_key(self, keys: Iterable[str], *args, **kwargs) -> 'np.ndarray':
"""
Search the index by the external key (passed during `.add(`).
Expand Down Expand Up @@ -325,8 +326,7 @@ def _get_sorted_top_k(dist: 'np.array', top_k: int) -> Tuple['np.ndarray', 'np.n

return idx, dist

def query(self, query_vectors: 'np.ndarray', top_k: int, *args, **kwargs) -> Tuple[
Optional['np.ndarray'], Optional['np.ndarray']]:
def query(self, vectors: 'np.ndarray', top_k: int, *args, **kwargs) -> Tuple['np.ndarray', 'np.ndarray']:
"""Find the top-k vectors with smallest ``metric`` and return their ids in ascending order.
:return: a tuple of two ndarray.
Expand All @@ -339,14 +339,14 @@ def query(self, query_vectors: 'np.ndarray', top_k: int, *args, **kwargs) -> Tup
"""
if self.size == 0:
return None, None
return np.array([]), np.array([])
if self.metric not in {'cosine', 'euclidean'} or self.backend == 'scipy':
dist = self._cdist(query_vectors, self.query_handler)
dist = self._cdist(vectors, self.query_handler)
elif self.metric == 'euclidean':
_query_vectors = _ext_A(query_vectors)
_query_vectors = _ext_A(vectors)
dist = self._euclidean(_query_vectors, self.query_handler)
elif self.metric == 'cosine':
_query_vectors = _ext_A(_norm(query_vectors))
_query_vectors = _ext_A(_norm(vectors))
dist = self._cosine(_query_vectors, self.query_handler)
else:
raise NotImplementedError(f'{self.metric} is not implemented')
Expand All @@ -355,7 +355,7 @@ def query(self, query_vectors: 'np.ndarray', top_k: int, *args, **kwargs) -> Tup
indices = self._int2ext_id[self.valid_indices][idx]
return indices, dist

def build_advanced_index(self, vecs: 'np.ndarray'):
def build_advanced_index(self, vecs: 'np.ndarray') -> 'np.ndarray':
return vecs

@batching(merge_over_axis=1, slice_on=2)
Expand Down

0 comments on commit 2b1e6e6

Please sign in to comment.