diff --git a/jina/drivers/encode.py b/jina/drivers/encode.py
index 1e3600ba8a3bd..ee9167a37633d 100644
--- a/jina/drivers/encode.py
+++ b/jina/drivers/encode.py
@@ -34,3 +34,22 @@ def _apply_all(self, docs: 'DocumentSet', *args, **kwargs) -> None:
                 raise LengthMismatchException(msg)
             for doc, embedding in zip(docs_pts, embeds):
                 doc.embedding = embedding
+
+
+class ScipySparseEncodeDriver(FlatRecursiveMixin, BaseEncodeDriver):
+    """Extract the content from documents, call the executor and set the resulting sparse embeddings"""
+
+    def _apply_all(self, docs: 'DocumentSet', *args, **kwargs) -> None:
+        contents, docs_pts = docs.all_contents
+
+        if docs_pts:
+            embeds = self.exec_fn(contents)
+            if len(docs_pts) != embeds.shape[0]:
+                msg = (
+                    f'mismatched {len(docs_pts)} docs from level {docs_pts[0].granularity} '
+                    f'and a {embeds.shape} shape embedding, the first dimension must be the same'
+                )
+                self.logger.error(msg)
+                raise LengthMismatchException(msg)
+            for idx, doc in enumerate(docs_pts):
+                doc.embedding = embeds.getrow(idx)
diff --git a/jina/drivers/index.py b/jina/drivers/index.py
index 32053e12919c1..885893ac4872d 100644
--- a/jina/drivers/index.py
+++ b/jina/drivers/index.py
@@ -39,12 +39,24 @@ class VectorIndexDriver(BaseIndexDriver):
     If `method` is not 'delete', documents without content are filtered out.
     """
 
+    @staticmethod
+    def _get_documents_embeddings(docs: 'DocumentSet'):
+        return docs.all_embeddings
+
     def _apply_all(self, docs: 'DocumentSet', *args, **kwargs) -> None:
-        embed_vecs, docs_pts = docs.all_embeddings
+        embed_vecs, docs_pts = self._get_documents_embeddings(docs)
         if docs_pts:
            keys = [doc.id for doc in docs_pts]
            self.check_key_length(keys)
-            self.exec_fn(keys, np.stack(embed_vecs))
+            self.exec_fn(keys, embed_vecs)
+
+
+class SparseVectorIndexDriver(VectorIndexDriver):
+    """A :class:`VectorIndexDriver` for sparse embeddings, named for coherence with :class:`SparseVectorSearchDriver`"""
+
+    @staticmethod
+    def _get_documents_embeddings(docs: 'DocumentSet'):
+        return docs.all_sparse_embeddings
 
 
 class KVIndexDriver(BaseIndexDriver):
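Both drivers above lean on the small row API that every scipy sparse matrix exposes: the encode driver slices one row per document with `getrow`, and the index driver hands the executor a matrix whose first dimension is the document axis. A minimal sketch of that round trip, using plain scipy and no Jina types:

```python
import numpy as np
from scipy import sparse

# stack three single-row sparse embeddings into one matrix (document axis first)
rows = [sparse.csr_matrix(np.array([[i, 0, i * 2]])) for i in range(1, 4)]
embeds = sparse.vstack(rows)
assert embeds.shape == (3, 3)  # first dimension == number of docs

# what ScipySparseEncodeDriver does per document: slice one row back out
row0 = embeds.getrow(0)
assert row0.shape == (1, 3)  # getrow keeps the result 2-D, as a (1 x n) sparse matrix
assert (row0.toarray() == np.array([[1, 0, 2]])).all()
```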
diff --git a/jina/drivers/search.py b/jina/drivers/search.py
index 79b1a17b0b3de..3507b978b1698 100644
--- a/jina/drivers/search.py
+++ b/jina/drivers/search.py
@@ -108,7 +108,7 @@ def _apply_all(self, docs: 'DocumentSet', *args, **kwargs) -> None:
 
 
 class VectorSearchDriver(FlatRecursiveMixin, QuerySetReader, BaseSearchDriver):
-    """Extract embeddings from the request for the executor to query.
+    """Extract dense embeddings from the request for the executor to query.
 
     :param top_k: top-k document ids to retrieve
     :param fill_embedding: fill in the embedding of the corresponding doc,
@@ -121,8 +121,21 @@ def __init__(self, top_k: int = 50, fill_embedding: bool = False, *args, **kwarg
         self._top_k = top_k
         self._fill_embedding = fill_embedding
 
+    @staticmethod
+    def _get_documents_embeddings(docs: 'DocumentSet'):
+        return docs.all_embeddings
+
+    @staticmethod
+    def _fill_matches(doc, op_name, topks, scores, topk_embed):
+        for numpy_match_id, score, vec in zip(topks, scores, topk_embed):
+            m = Document(id=numpy_match_id)
+            m.score = NamedScore(op_name=op_name, value=score)
+            r = doc.matches.append(m)
+            if vec is not None:
+                r.embedding = vec
+
     def _apply_all(self, docs: 'DocumentSet', *args, **kwargs) -> None:
-        embed_vecs, doc_pts = docs.all_embeddings
+        embed_vecs, doc_pts = self._get_documents_embeddings(docs)
         if not doc_pts:
             return
@@ -134,18 +147,31 @@ def _apply_all(self, docs: 'DocumentSet', *args, **kwargs) -> None:
         )
 
         idx, dist = self.exec_fn(embed_vecs, top_k=int(self.top_k))
-
         op_name = self.exec.__class__.__name__
         for doc, topks, scores in zip(doc_pts, idx, dist):
-
             topk_embed = (
                 fill_fn(topks)
                 if (self._fill_embedding and fill_fn)
                 else [None] * len(topks)
             )
-            for numpy_match_id, score, vec in zip(topks, scores, topk_embed):
-                m = Document(id=numpy_match_id)
-                m.score = NamedScore(op_name=op_name, value=score)
-                r = doc.matches.append(m)
-                if vec is not None:
-                    r.embedding = vec
+            self._fill_matches(doc, op_name, topks, scores, topk_embed)
+
+
+class SparseVectorSearchDriver(VectorSearchDriver):
+    """
+    Extract sparse embeddings from the request for the executor to query.
+    """
+
+    @staticmethod
+    def _get_documents_embeddings(docs: 'DocumentSet'):
+        return docs.all_sparse_embeddings
+
+    @staticmethod
+    def _fill_matches(doc, op_name, topks, scores, topk_embed):
+        for row_id, (numpy_match_id, score) in enumerate(zip(topks, scores)):
+            vec = topk_embed.getrow(row_id)
+            m = Document(id=numpy_match_id)
+            m.score = NamedScore(op_name=op_name, value=score)
+            r = doc.matches.append(m)
+            if vec is not None:
+                r.embedding = vec
diff --git a/jina/executors/indexers/__init__.py b/jina/executors/indexers/__init__.py
index e82b7fc640b3c..b62ad72feef16 100644
--- a/jina/executors/indexers/__init__.py
+++ b/jina/executors/indexers/__init__.py
@@ -303,6 +303,10 @@ def delete(self, keys: Iterable[str], *args, **kwargs) -> None:
         raise NotImplementedError
 
 
+class BaseSparseVectorIndexer(BaseVectorIndexer):
+    """Alias of :class:`BaseVectorIndexer`, used to bind the sparse default drivers in resources"""
+
+
 class BaseKVIndexer(BaseIndexer):
     """An abstract class for key-value indexer.
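The sparse variants reuse the entire driver logic and override only two hooks, `_get_documents_embeddings` and `_fill_matches`. A toy, self-contained sketch of this template-method pattern (the class and method names below are illustrative, not Jina APIs):

```python
import numpy as np
from scipy import sparse


class DenseStore:
    """Toy illustration of the hook pattern used by the drivers above."""

    @staticmethod
    def _extract(rows):
        return np.stack(rows)  # dense hook: stack rows into an ndarray

    def ingest(self, rows):
        matrix = self._extract(rows)  # shared logic only talks to the hook
        return matrix.shape


class SparseStore(DenseStore):
    @staticmethod
    def _extract(rows):
        return sparse.vstack(rows)  # sparse hook: vstack into one sparse matrix


dense_shape = DenseStore().ingest([np.array([1, 0]), np.array([0, 2])])
sparse_shape = SparseStore().ingest(
    [sparse.csr_matrix([[1, 0]]), sparse.csr_matrix([[0, 2]])]
)
assert dense_shape == sparse_shape == (2, 2)  # same shared logic, different storage
```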
diff --git a/jina/resources/executors.requests.BaseSparseVectorIndexer.yml b/jina/resources/executors.requests.BaseSparseVectorIndexer.yml
new file mode 100644
index 0000000000000..83ee0a7d4e95d
--- /dev/null
+++ b/jina/resources/executors.requests.BaseSparseVectorIndexer.yml
@@ -0,0 +1,9 @@
+on:
+  ControlRequest:
+    - !ControlReqDriver {}
+  SearchRequest:
+    - !SparseVectorSearchDriver {}
+  [IndexRequest, UpdateRequest]:
+    - !SparseVectorIndexDriver {}
+  DeleteRequest:
+    - !DeleteDriver {}
\ No newline at end of file
diff --git a/jina/types/document/__init__.py b/jina/types/document/__init__.py
index 69a0e73e926e8..b11aeebe19836 100644
--- a/jina/types/document/__init__.py
+++ b/jina/types/document/__init__.py
@@ -27,6 +27,8 @@
 from ...logging import default_logger
 from ...proto import jina_pb2
 
+if False:
+    from scipy.sparse import coo_matrix
 
 __all__ = ['Document', 'DocumentContentType', 'DocumentSourceType']
 DIGEST_SIZE = 8
@@ -475,6 +477,14 @@ def embedding(self) -> 'np.ndarray':
         """
         return NdArray(self._pb_body.embedding).value
 
+    @property
+    def sparse_embedding(self) -> 'coo_matrix':
+        """Return the ``embedding`` of the content of a Document as a sparse array.
+
+        :return: the embedding from the proto as a sparse array
+        """
+        return NdArray(self._pb_body.embedding, is_sparse=True).value
+
     @embedding.setter
     def embedding(self, value: Union['np.ndarray', 'jina_pb2.NdArrayProto', 'NdArray']):
         """Set the ``embedding`` of the content of a Document.
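With the new property, a sparse matrix assigned to `embedding` can be read back in sparse form without densifying. A minimal usage sketch, mirroring the unit test further down and assuming the setter accepts scipy sparse matrices as that test exercises:

```python
from jina import Document
from scipy import sparse

d = Document()
d.embedding = sparse.csr_matrix([[0, 0, 3, 0, 1]])  # setter accepts scipy sparse types

sp = d.sparse_embedding  # read back without densifying
assert isinstance(sp, sparse.coo_matrix)  # NdArray(..., is_sparse=True) yields a coo_matrix
assert sp.shape == (1, 5)
```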
diff --git a/jina/types/sets/document.py b/jina/types/sets/document.py
index 66ed19ee8f954..17d42d873bea9 100644
--- a/jina/types/sets/document.py
+++ b/jina/types/sets/document.py
@@ -22,6 +22,7 @@
 
 if False:
     from ..document import Document
+    from scipy.sparse import coo_matrix
 
 __all__ = ['DocumentSet']
 
@@ -162,6 +163,34 @@ def all_embeddings(self) -> Tuple['np.ndarray', 'DocumentSet']:
         """
         return self.extract_docs('embedding', stack_contents=True)
 
+    @property
+    def all_sparse_embeddings(self) -> Tuple['coo_matrix', 'DocumentSet']:
+        """Return all sparse embeddings from every document in this set, stacked
+        into a single sparse matrix.
+
+        :return: a tuple of the stacked embeddings as a :class:`coo_matrix` and the
+            corresponding documents in a :class:`DocumentSet`; documents without an embedding are left out
+        """
+        import scipy
+
+        embeddings = []
+        docs_pts = []
+        bad_docs = []
+        for doc in self:
+            embedding = doc.sparse_embedding
+            if embedding is None:
+                bad_docs.append(doc)
+                continue
+            embeddings.append(embedding)
+            docs_pts.append(doc)
+
+        if bad_docs:
+            default_logger.warning(
+                f'found {len(bad_docs)} docs at granularity {bad_docs[0].granularity} without sparse_embedding'
+            )
+
+        return scipy.sparse.vstack(embeddings), docs_pts
+
     @property
     def all_contents(self) -> Tuple['np.ndarray', 'DocumentSet']:
         """Return all embeddings from every document in this set as a ndarray
diff --git a/tests/integration/sparse_pipeline/__init__.py b/tests/integration/sparse_pipeline/__init__.py
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/tests/integration/sparse_pipeline/indexer.yml b/tests/integration/sparse_pipeline/indexer.yml
new file mode 100644
index 0000000000000..1464ec2064db5
--- /dev/null
+++ b/tests/integration/sparse_pipeline/indexer.yml
@@ -0,0 +1,11 @@
+!DummyCSRSparseIndexer
+requests:
+  on:
+    ControlRequest:
+      - !ControlReqDriver {}
+    SearchRequest:
+      - !SparseVectorSearchDriver
+        with:
+          fill_embedding: True
+    IndexRequest:
+      - !SparseVectorIndexDriver {}
\ No newline at end of file
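As an aside on how a config like the one above is consumed: in Jina 1.x an executor YAML is typically loaded via `BaseExecutor.load_config`, provided the custom class is imported first so its YAML tag resolves. A hedged sketch (assuming `DummyCSRSparseIndexer` from the test file below has already been imported):

```python
from jina.executors import BaseExecutor

# hypothetical standalone load of the config above; the !DummyCSRSparseIndexer
# tag only resolves if the class (defined in the test below) is already imported
indexer = BaseExecutor.load_config('indexer.yml')
assert indexer.__class__.__name__ == 'DummyCSRSparseIndexer'
```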
diff --git a/tests/integration/sparse_pipeline/test_sparse_pipeline.py b/tests/integration/sparse_pipeline/test_sparse_pipeline.py
new file mode 100644
index 0000000000000..5db64820a0e9b
--- /dev/null
+++ b/tests/integration/sparse_pipeline/test_sparse_pipeline.py
@@ -0,0 +1,114 @@
+from typing import Any, Iterable
+import os
+
+import pytest
+import numpy as np
+from scipy import sparse
+
+from jina import Flow, Document
+from jina.types.sets import DocumentSet
+from jina.executors.encoders import BaseEncoder
+from jina.executors.indexers import BaseSparseVectorIndexer
+
+from tests import validate_callback
+
+cur_dir = os.path.dirname(os.path.abspath(__file__))
+
+
+@pytest.fixture(scope='function')
+def num_docs():
+    return 10
+
+
+@pytest.fixture(scope='function')
+def docs_to_index(num_docs):
+    docs = []
+    for idx in range(1, num_docs + 1):
+        doc = Document(id=str(idx), content=np.array([idx * 5]))
+        docs.append(doc)
+    return DocumentSet(docs)
+
+
+class DummySparseEncoder(BaseEncoder):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    def encode(self, data: Any, *args, **kwargs) -> Any:
+        embed = sparse.csr_matrix(data)
+        return embed
+
+
+class DummyCSRSparseIndexer(BaseSparseVectorIndexer):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.keys = []
+        self.vectors = {}
+
+    def add(
+        self, keys: Iterable[str], vectors: 'scipy.sparse.coo_matrix', *args, **kwargs
+    ) -> None:
+        assert isinstance(vectors, sparse.coo_matrix)
+        self.keys.extend(keys)
+        for i, key in enumerate(keys):
+            self.vectors[key] = vectors.getrow(i)
+
+    def query(self, vectors: 'scipy.sparse.coo_matrix', top_k: int, *args, **kwargs):
+        assert isinstance(vectors, sparse.coo_matrix)
+        distances = list(range(0, min(top_k, len(self.keys))))
+        return [self.keys[:top_k]], np.array([distances])
+
+    def query_by_key(self, keys: Iterable[str], *args, **kwargs):
+        from scipy.sparse import vstack
+
+        vectors = []
+        for key in keys:
+            vectors.append(self.vectors[key])
+
+        return vstack(vectors)
+
+    def save(self):
+        # avoid creating a dump, do not pollute the workspace
+        pass
+
+    def close(self):
+        # avoid creating a dump, do not pollute the workspace
+        pass
+
+    def get_create_handler(self):
+        pass
+
+    def get_write_handler(self):
+        pass
+
+    def get_add_handler(self):
+        pass
+
+    def get_query_handler(self):
+        pass
+
+
+def test_sparse_pipeline(mocker, docs_to_index):
+    def validate(response):
+        assert len(response.docs) == 1
+        assert len(response.docs[0].matches) == 10
+        for doc in response.docs:
+            for i, match in enumerate(doc.matches):
+                assert match.id == docs_to_index[i].id
+                assert isinstance(match.embedding, sparse.coo_matrix)
+
+    f = (
+        Flow()
+        .add(uses=DummySparseEncoder)
+        .add(uses=os.path.join(cur_dir, 'indexer.yml'))
+    )
+
+    mock = mocker.Mock()
+    error_mock = mocker.Mock()
+
+    with f:
+        f.index(inputs=docs_to_index)
+        f.search(inputs=docs_to_index[0], on_done=mock, on_error=error_mock)
+
+    mock.assert_called_once()
+    validate_callback(mock, validate)
+    error_mock.assert_not_called()
diff --git a/tests/unit/drivers/test_encoder_driver.py b/tests/unit/drivers/test_encoder_driver.py
index 06b0abf8ce703..7d4c571070b5e 100644
--- a/tests/unit/drivers/test_encoder_driver.py
+++ b/tests/unit/drivers/test_encoder_driver.py
@@ -1,10 +1,11 @@
 from typing import Any
 
-import numpy as np
 import pytest
+import numpy as np
+from scipy import sparse
 
 from jina import Document, DocumentSet
-from jina.drivers.encode import EncodeDriver
+from jina.drivers.encode import EncodeDriver, ScipySparseEncodeDriver
 from jina.executors.encoders import BaseEncoder
 from jina.executors.decorators import batching
 
@@ -17,7 +18,7 @@ def num_docs():
 @pytest.fixture(scope='function')
 def docs_to_encode(num_docs):
     docs = []
-    for idx in range(num_docs):
+    for idx in range(1, num_docs + 1):
         doc = Document(content=np.array([idx]))
         docs.append(doc)
     return DocumentSet(docs)
@@ -62,3 +63,46 @@ def test_encode_driver(batch_size, docs_to_encode, num_docs):
     assert len(docs_to_encode) == num_docs
     for doc in docs_to_encode:
         assert doc.embedding == doc.blob
+
+
+def get_sparse_encoder(sparse_type):
+    class MockEncoder(BaseEncoder):
+        def __init__(self, *args, **kwargs):
+            super().__init__(*args, **kwargs)
+
+        def encode(self, data: Any, *args, **kwargs) -> Any:
+            # return a sparse matrix with the same number of rows as `data`, in the parametrized type
+            embed = sparse_type(data)
+            return embed
+
+    return MockEncoder()
+
+
+class SimpleScipySparseEncoderDriver(ScipySparseEncodeDriver):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    @property
+    def exec_fn(self):
+        return self._exec_fn
+
+
+@pytest.fixture(
+    params=[sparse.csr_matrix, sparse.coo_matrix, sparse.bsr_matrix, sparse.csc_matrix]
+)
+def sparse_type(request):
+    return request.param
+
+
+def test_sparse_encode_driver(sparse_type, docs_to_encode, num_docs):
+    driver = SimpleScipySparseEncoderDriver()
+    encoder = get_sparse_encoder(sparse_type)
+    driver.attach(executor=encoder, runtime=None)
+    assert len(docs_to_encode) == num_docs
+    for doc in docs_to_encode:
+        assert doc.embedding is None
+    driver._apply_all(docs_to_encode)
+    assert len(docs_to_encode) == num_docs
+    for doc in docs_to_encode:
+        assert isinstance(doc.embedding, sparse.coo_matrix)
+        assert doc.embedding == doc.blob
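The parametrization over csr/coo/bsr/csc works because the driver only relies on the contract shared by all scipy sparse formats: a 2-D `.shape` whose first axis is the row count, and `getrow()` for per-document slicing. A quick standalone check of that contract:

```python
import numpy as np
from scipy import sparse

data = np.array([[1, 0, 2], [0, 3, 0]])
for sparse_type in (sparse.csr_matrix, sparse.coo_matrix, sparse.bsr_matrix, sparse.csc_matrix):
    m = sparse_type(data)
    assert m.shape[0] == 2  # first dimension is the number of rows/docs
    row = m.getrow(1)       # available on every scipy sparse format
    assert row.shape == (1, 3)
    assert (row.toarray() == np.array([[0, 3, 0]])).all()
```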
diff --git a/tests/unit/types/document/test_document.py b/tests/unit/types/document/test_document.py
index e4b8419bd6bac..80ed24a0ca8d4 100644
--- a/tests/unit/types/document/test_document.py
+++ b/tests/unit/types/document/test_document.py
@@ -4,6 +4,7 @@
 import numpy as np
 import pytest
 from google.protobuf.json_format import MessageToDict
+from scipy.sparse import coo_matrix, bsr_matrix, csr_matrix, csc_matrix
 
 from jina import NdArray, Request
 from jina.proto.jina_pb2 import DocumentProto
@@ -13,8 +14,6 @@
 
 
 def scipy_sparse_list():
-    from scipy.sparse import coo_matrix, bsr_matrix, csr_matrix, csc_matrix
-
     return [coo_matrix, bsr_matrix, csr_matrix, csc_matrix]
 
 
@@ -742,6 +741,13 @@ def test_document_sparse_attributes_pytorch(torch_sparse_matrix):
     )
 
 
+def test_document_sparse_embedding(scipy_sparse_matrix):
+    d = Document()
+    d.embedding = scipy_sparse_matrix
+    assert d.sparse_embedding is not None
+    assert isinstance(d.sparse_embedding, coo_matrix)
+
+
 def test_siblings_needs_to_be_set_manually():
     document = Document()
     with document:
diff --git a/tests/unit/types/sets/test_documentset.py b/tests/unit/types/sets/test_documentset.py
index 0b954a26c13b9..0e0bf5f010a9a 100644
--- a/tests/unit/types/sets/test_documentset.py
+++ b/tests/unit/types/sets/test_documentset.py
@@ -2,6 +2,7 @@
 
 import pytest
 import numpy as np
+from scipy.sparse import coo_matrix
 
 from jina import Document
 from jina.types.sets import DocumentSet
@@ -35,6 +36,20 @@ def docset(docs):
     return DocumentSet(docs)
 
 
+@pytest.fixture
+def docset_with_scipy_sparse_embedding(docs):
+    embedding = coo_matrix(
+        (
+            np.array([1, 2, 3, 4, 5, 6]),
+            (np.array([0, 0, 1, 2, 2, 2]), np.array([0, 2, 2, 0, 1, 2])),
+        ),
+        shape=(4, 10),
+    )
+    for doc in docs:
+        doc.embedding = embedding
+    return DocumentSet(docs)
+
+
 def test_length(docset, docs):
     assert len(docs) == len(docset) == 3
 
@@ -390,3 +405,11 @@ def test_get_content_multiple_fields_merge(stack, num_rows):
     assert len(contents[1]) == batch_size
     for c in contents[0]:
         assert c.shape == (num_rows, embed_size)
+
+
+def test_all_sparse_embeddings(docset_with_scipy_sparse_embedding):
+    all_embeddings, doc_pts = docset_with_scipy_sparse_embedding.all_sparse_embeddings
+    assert all_embeddings is not None
+    assert doc_pts is not None
+    assert len(doc_pts) == 3
+    assert isinstance(all_embeddings, coo_matrix)
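For readers less familiar with the `(data, (row, col))` constructor used in the fixture above, this is what that matrix holds once densified (a sketch, pure scipy):

```python
import numpy as np
from scipy.sparse import coo_matrix

# same constructor call as the fixture: values 1..6 placed at (row, col) pairs
m = coo_matrix(
    (
        np.array([1, 2, 3, 4, 5, 6]),
        (np.array([0, 0, 1, 2, 2, 2]), np.array([0, 2, 2, 0, 1, 2])),
    ),
    shape=(4, 10),
)
print(m.toarray()[:, :3])  # remaining columns are all zero
# [[1 0 2]
#  [0 0 3]
#  [4 5 6]
#  [0 0 0]]
```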