feat: add driver and changes in document needed for sparse pipeline (#2297)

* feat: add driver and changes in document needed for sparse pipeline

* feat: add sparse unit test and sparse encoder driver

* feat: add sparse pipeline test

* feat: add sparse pipeline drivers

* fix: formatting

* fix: distances shape not correct

* feat: complete sparse pipeline test

* feat: complete validation test

* feat: add test for document and document set

* test: fix sparse testing (#2309)

Co-authored-by: bwanglzu <bo.wang@jina.ai>
Co-authored-by: David Buchaca Prats <davidbuchaca@gmail.com>
3 people committed Apr 15, 2021
1 parent a87e72c commit 8403440
Showing 13 changed files with 324 additions and 17 deletions.
19 changes: 19 additions & 0 deletions jina/drivers/encode.py
@@ -34,3 +34,22 @@ def _apply_all(self, docs: 'DocumentSet', *args, **kwargs) -> None:
                self.logger.error(msg)
                raise LengthMismatchException(msg)
            for doc, embedding in zip(docs_pts, embeds):
                doc.embedding = embedding


class ScipySparseEncodeDriver(FlatRecursiveMixin, BaseEncodeDriver):
    """Extract the content from documents and call the executor to encode it"""

    def _apply_all(self, docs: 'DocumentSet', *args, **kwargs) -> None:
        contents, docs_pts = docs.all_contents

        if docs_pts:
            embeds = self.exec_fn(contents)
            if len(docs_pts) != embeds.shape[0]:
                msg = (
                    f'mismatched {len(docs_pts)} docs from level {docs_pts[0].granularity} '
                    f'and a {embeds.shape} shape embedding, the first dimension must be the same'
                )
                self.logger.error(msg)
                raise LengthMismatchException(msg)
            for idx, doc in enumerate(docs_pts):
                doc.embedding = embeds.getrow(idx)
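
The row-wise assignment is what keeps the pipeline sparse end to end. A minimal sketch (plain scipy/numpy, illustrative and not part of this commit) of why the driver uses getrow() instead of iterating a stacked dense array:

# Sketch only: slicing a CSR matrix with getrow() yields a 1 x N sparse
# row, so per-document embeddings never get densified along the way.
import numpy as np
from scipy import sparse

embeds = sparse.csr_matrix(np.array([[0, 1, 0], [2, 0, 0]]))
row = embeds.getrow(0)
print(row.shape)      # (1, 3) -- still a sparse matrix
print(row.toarray())  # [[0 1 0]]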
16 changes: 14 additions & 2 deletions jina/drivers/index.py
@@ -39,12 +39,24 @@ class VectorIndexDriver(BaseIndexDriver):
    If `method` is not 'delete', documents without content are filtered out.
    """

    @staticmethod
    def _get_documents_embeddings(docs: 'DocumentSet'):
        return docs.all_embeddings

    def _apply_all(self, docs: 'DocumentSet', *args, **kwargs) -> None:
-        embed_vecs, docs_pts = docs.all_embeddings
+        embed_vecs, docs_pts = self._get_documents_embeddings(docs)
        if docs_pts:
            keys = [doc.id for doc in docs_pts]
            self.check_key_length(keys)
-            self.exec_fn(keys, np.stack(embed_vecs))
+            self.exec_fn(keys, embed_vecs)


class SparseVectorIndexDriver(VectorIndexDriver):
    """An alias to have coherent naming with the required SparseVectorSearchDriver"""

    @staticmethod
    def _get_documents_embeddings(docs: 'DocumentSet'):
        return docs.all_sparse_embeddings


class KVIndexDriver(BaseIndexDriver):
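
The change above is a small template-method refactor: _apply_all keeps the indexing flow and delegates embedding extraction to a static hook that SparseVectorIndexDriver overrides. A self-contained sketch of the pattern (illustrative names, not Jina code):

# Sketch of the hook pattern: the base class owns the flow, subclasses
# only swap how embeddings are pulled out of the document set.
class BaseDriver:
    @staticmethod
    def _get_embeddings(docs):
        return docs['dense']

    def apply(self, docs):
        return self._get_embeddings(docs)  # dispatches to the subclass hook


class SparseDriver(BaseDriver):
    @staticmethod
    def _get_embeddings(docs):
        return docs['sparse']


docs = {'dense': [1.0, 0.0], 'sparse': [(0, 1.0)]}
print(BaseDriver().apply(docs))    # [1.0, 0.0]
print(SparseDriver().apply(docs))  # [(0, 1.0)]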
46 changes: 36 additions & 10 deletions jina/drivers/search.py
@@ -108,7 +108,7 @@ def _apply_all(self, docs: 'DocumentSet', *args, **kwargs) -> None:


class VectorSearchDriver(FlatRecursiveMixin, QuerySetReader, BaseSearchDriver):
-    """Extract embeddings from the request for the executor to query.
+    """Extract dense embeddings from the request for the executor to query.

    :param top_k: top-k document ids to retrieve
    :param fill_embedding: fill in the embedding of the corresponding doc,
@@ -121,8 +121,21 @@ def __init__(self, top_k: int = 50, fill_embedding: bool = False, *args, **kwargs):
        self._top_k = top_k
        self._fill_embedding = fill_embedding

    @staticmethod
    def _get_documents_embeddings(docs: 'DocumentSet'):
        return docs.all_embeddings

    @staticmethod
    def _fill_matches(doc, op_name, topks, scores, topk_embed):
        for numpy_match_id, score, vec in zip(topks, scores, topk_embed):
            m = Document(id=numpy_match_id)
            m.score = NamedScore(op_name=op_name, value=score)
            r = doc.matches.append(m)
            if vec is not None:
                r.embedding = vec

    def _apply_all(self, docs: 'DocumentSet', *args, **kwargs) -> None:
-        embed_vecs, doc_pts = docs.all_embeddings
+        embed_vecs, doc_pts = self._get_documents_embeddings(docs)

        if not doc_pts:
            return
@@ -134,18 +147,31 @@ def _apply_all(self, docs: 'DocumentSet', *args, **kwargs) -> None:
            )

        idx, dist = self.exec_fn(embed_vecs, top_k=int(self.top_k))

        op_name = self.exec.__class__.__name__
        for doc, topks, scores in zip(doc_pts, idx, dist):

            topk_embed = (
                fill_fn(topks)
                if (self._fill_embedding and fill_fn)
                else [None] * len(topks)
            )
-            for numpy_match_id, score, vec in zip(topks, scores, topk_embed):
-                m = Document(id=numpy_match_id)
-                m.score = NamedScore(op_name=op_name, value=score)
-                r = doc.matches.append(m)
-                if vec is not None:
-                    r.embedding = vec
+            self._fill_matches(doc, op_name, topks, scores, topk_embed)


class SparseVectorSearchDriver(VectorSearchDriver):
    """
    Extract sparse embeddings from the request for the executor to query.
    """

    @staticmethod
    def _get_documents_embeddings(docs: 'DocumentSet'):
        return docs.all_sparse_embeddings

    @staticmethod
    def _fill_matches(doc, op_name, topks, scores, topk_embed):
        for id, (numpy_match_id, score) in enumerate(zip(topks, scores)):
            vec = topk_embed.getrow(id)
            m = Document(id=numpy_match_id)
            m.score = NamedScore(op_name=op_name, value=score)
            r = doc.matches.append(m)
            if vec is not None:
                r.embedding = vec
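
In the sparse variant the executor returns the top-k embeddings as one sparse matrix, so _fill_matches pairs ids and scores positionally and takes the k-th row as the match embedding. A minimal sketch of that consumption pattern (plain scipy/numpy, illustrative values, not part of this commit):

import numpy as np
from scipy import sparse

topks = np.array(['3', '7'])                      # top-k match ids
scores = np.array([0.1, 0.4])                     # corresponding distances
topk_embed = sparse.csr_matrix([[1, 0], [0, 2]])  # one sparse row per match

for k, (match_id, score) in enumerate(zip(topks, scores)):
    vec = topk_embed.getrow(k)  # the k-th match keeps a sparse embedding
    print(match_id, score, vec.toarray())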
4 changes: 4 additions & 0 deletions jina/executors/indexers/__init__.py
@@ -303,6 +303,10 @@ def delete(self, keys: Iterable[str], *args, **kwargs) -> None:
        raise NotImplementedError


class BaseSparseVectorIndexer(BaseVectorIndexer):
    """Alias to provide proper default drivers in resources"""


class BaseKVIndexer(BaseIndexer):
    """An abstract class for key-value indexer.
9 changes: 9 additions & 0 deletions jina/resources/executors.requests.BaseSparseVectorIndexer.yml
@@ -0,0 +1,9 @@
on:
  ControlRequest:
    - !ControlReqDriver {}
  SearchRequest:
    - !SparseVectorSearchDriver {}
  [IndexRequest, UpdateRequest]:
    - !SparseVectorIndexDriver {}
  DeleteRequest:
    - !DeleteDriver {}
10 changes: 10 additions & 0 deletions jina/types/document/__init__.py
@@ -27,6 +27,8 @@
from ...logging import default_logger
from ...proto import jina_pb2

if False:
    from scipy.sparse import coo_matrix

__all__ = ['Document', 'DocumentContentType', 'DocumentSourceType']
DIGEST_SIZE = 8
@@ -475,6 +477,14 @@ def embedding(self) -> 'np.ndarray':
        """
        return NdArray(self._pb_body.embedding).value

    @property
    def sparse_embedding(self) -> 'coo_matrix':
        """Return ``embedding`` of the content of a Document as a sparse array.

        :return: the embedding from the proto as a sparse array
        """
        return NdArray(self._pb_body.embedding, is_sparse=True).value

    @embedding.setter
    def embedding(self, value: Union['np.ndarray', 'jina_pb2.NdArrayProto', 'NdArray']):
        """Set the ``embedding`` of the content of a Document.
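
The `if False:` guard above is a type-checking-only import: it never executes, so scipy stays an optional runtime dependency, while `coo_matrix` remains resolvable for the quoted return annotation. A sketch of the same idiom using the modern `typing.TYPE_CHECKING` constant (the accessor function is hypothetical, for illustration only):

from typing import TYPE_CHECKING

if TYPE_CHECKING:  # False at runtime, True only for static type checkers
    from scipy.sparse import coo_matrix


def load_embedding() -> 'coo_matrix':  # hypothetical accessor, not Jina API
    ...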
29 changes: 29 additions & 0 deletions jina/types/sets/document.py
@@ -22,6 +22,7 @@

if False:
    from ..document import Document
    from scipy.sparse import coo_matrix

__all__ = ['DocumentSet']

@@ -162,6 +163,34 @@ def all_embeddings(self) -> Tuple['np.ndarray', 'DocumentSet']:
        """
        return self.extract_docs('embedding', stack_contents=True)

    @property
    def all_sparse_embeddings(self) -> Tuple['coo_matrix', 'DocumentSet']:
        """Return all sparse embeddings from every document in this set, stacked into a single sparse matrix.

        :return: the stacked sparse embeddings and the corresponding documents in a
            :class:`DocumentSet`; documents without an embedding are skipped.
        """
        import scipy.sparse

        embeddings = []
        docs_pts = []
        bad_docs = []
        for doc in self:
            embedding = doc.sparse_embedding
            if embedding is None:
                bad_docs.append(doc)
                continue
            embeddings.append(embedding)
            docs_pts.append(doc)

        if bad_docs:
            default_logger.warning(
                f'found {len(bad_docs)} docs at granularity {bad_docs[0].granularity} missing sparse_embedding'
            )

        return scipy.sparse.vstack(embeddings), docs_pts

    @property
    def all_contents(self) -> Tuple['np.ndarray', 'DocumentSet']:
        """Return all embeddings from every document in this set as a ndarray
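
A minimal sketch (plain scipy, not Jina code) of the stacking step at the end of all_sparse_embeddings: one sparse row per document, combined into a single matrix:

from scipy import sparse

rows = [sparse.coo_matrix([[0, 1, 0]]), sparse.coo_matrix([[2, 0, 0]])]
stacked = sparse.vstack(rows)
print(stacked.shape)  # (2, 3) -- one row per document, still sparse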
Empty file.
11 changes: 11 additions & 0 deletions tests/integration/sparse_pipeline/indexer.yml
@@ -0,0 +1,11 @@
!DummyCSRSparseIndexer
requests:
  on:
    ControlRequest:
      - !ControlReqDriver {}
    SearchRequest:
      - !SparseVectorSearchDriver
        with:
          fill_embedding: True
    IndexRequest:
      - !SparseVectorIndexDriver {}
114 changes: 114 additions & 0 deletions tests/integration/sparse_pipeline/test_sparse_pipeline.py
@@ -0,0 +1,114 @@
from typing import Any, Iterable
import os

import pytest
import numpy as np
from scipy import sparse

from jina import Flow, Document
from jina.types.sets import DocumentSet
from jina.executors.encoders import BaseEncoder
from jina.executors.indexers import BaseSparseVectorIndexer

from tests import validate_callback

cur_dir = os.path.dirname(os.path.abspath(__file__))


@pytest.fixture(scope='function')
def num_docs():
    return 10


@pytest.fixture(scope='function')
def docs_to_index(num_docs):
    docs = []
    for idx in range(1, num_docs + 1):
        doc = Document(id=str(idx), content=np.array([idx * 5]))
        docs.append(doc)
    return DocumentSet(docs)


class DummySparseEncoder(BaseEncoder):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def encode(self, data: Any, *args, **kwargs) -> Any:
        embed = sparse.csr_matrix(data)
        return embed


class DummyCSRSparseIndexer(BaseSparseVectorIndexer):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.keys = []
        self.vectors = {}

    def add(
        self, keys: Iterable[str], vectors: 'scipy.sparse.coo_matrix', *args, **kwargs
    ) -> None:
        assert isinstance(vectors, sparse.coo_matrix)
        self.keys.extend(keys)
        for i, key in enumerate(keys):
            self.vectors[key] = vectors.getrow(i)

    def query(self, vectors: 'scipy.sparse.coo_matrix', top_k: int, *args, **kwargs):
        assert isinstance(vectors, sparse.coo_matrix)
        distances = [item for item in range(0, min(top_k, len(self.keys)))]
        return [self.keys[:top_k]], np.array([distances])

    def query_by_key(self, keys: Iterable[str], *args, **kwargs):
        from scipy.sparse import vstack

        vectors = []
        for key in keys:
            vectors.append(self.vectors[key])

        return vstack(vectors)

    def save(self):
        # avoid creating a dump, do not pollute the workspace
        pass

    def close(self):
        # avoid creating a dump, do not pollute the workspace
        pass

    def get_create_handler(self):
        pass

    def get_write_handler(self):
        pass

    def get_add_handler(self):
        pass

    def get_query_handler(self):
        pass


def test_sparse_pipeline(mocker, docs_to_index):
    def validate(response):
        assert len(response.docs) == 1
        assert len(response.docs[0].matches) == 10
        for doc in response.docs:
            for i, match in enumerate(doc.matches):
                assert match.id == docs_to_index[i].id
                assert isinstance(match.embedding, sparse.coo_matrix)

    f = (
        Flow()
        .add(uses=DummySparseEncoder)
        .add(uses=os.path.join(cur_dir, 'indexer.yml'))
    )

    mock = mocker.Mock()
    error_mock = mocker.Mock()

    with f:
        f.index(inputs=docs_to_index)
        f.search(inputs=docs_to_index[0], on_done=mock, on_error=error_mock)

    mock.assert_called_once()
    validate_callback(mock, validate)
    error_mock.assert_not_called()

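For reference, a standalone sketch (scipy/numpy assumed installed, not part of this commit) of what DummySparseEncoder.encode produces for the fixture documents above:

import numpy as np
from scipy import sparse

# the fixture builds ten docs with content [idx * 5] for idx in 1..10
data = np.stack([np.array([idx * 5]) for idx in range(1, 11)])  # shape (10, 1)
embed = sparse.csr_matrix(data)
print(embed.shape)                # (10, 1)
print(embed.getrow(0).toarray())  # [[5]]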