Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: add driver and changes in document needed for sparse pipeline (#…
…2297) * feat: add driver and changes in document needed for sparse pipeline * feat: add sparse unit test and sparse encoder driver * feat: add sparse pipeline test * feat: add sparse pipeline drivers * fix: formatting * fix: distances shape not correct * feat: complete sparse pipeline test * feat: complete validation test * feat: add test for document and document set * test: fix sparse testing (#2309) Co-authored-by: bwanglzu <bo.wang@jina.ai> Co-authored-by: David Buchaca Prats <davidbuchaca@gmail.com>
- Loading branch information
1 parent
a87e72c
commit 8403440
Showing
13 changed files
with
324 additions
and
17 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
9 changes: 9 additions & 0 deletions
9
jina/resources/executors.requests.BaseSparseVectorIndexer.yml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
on: | ||
ControlRequest: | ||
- !ControlReqDriver {} | ||
SearchRequest: | ||
- !SparseVectorSearchDriver {} | ||
[IndexRequest, UpdateRequest]: | ||
- !SparseVectorIndexDriver {} | ||
DeleteRequest: | ||
- !DeleteDriver {} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
!DummyCSRSparseIndexer | ||
requests: | ||
on: | ||
ControlRequest: | ||
- !ControlReqDriver {} | ||
SearchRequest: | ||
- !SparseVectorSearchDriver | ||
with: | ||
fill_embedding: True | ||
IndexRequest: | ||
- !SparseVectorIndexDriver {} |
114 changes: 114 additions & 0 deletions
114
tests/integration/sparse_pipeline/test_sparse_pipeline.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,114 @@ | ||
from typing import Any, Iterable | ||
import os | ||
|
||
import pytest | ||
import numpy as np | ||
from scipy import sparse | ||
|
||
from jina import Flow, Document | ||
from jina.types.sets import DocumentSet | ||
from jina.executors.encoders import BaseEncoder | ||
from jina.executors.indexers import BaseSparseVectorIndexer | ||
|
||
from tests import validate_callback | ||
|
||
cur_dir = os.path.dirname(os.path.abspath(__file__)) | ||
|
||
|
||
@pytest.fixture(scope='function') | ||
def num_docs(): | ||
return 10 | ||
|
||
|
||
@pytest.fixture(scope='function') | ||
def docs_to_index(num_docs): | ||
docs = [] | ||
for idx in range(1, num_docs + 1): | ||
doc = Document(id=str(idx), content=np.array([idx * 5])) | ||
docs.append(doc) | ||
return DocumentSet(docs) | ||
|
||
|
||
class DummySparseEncoder(BaseEncoder): | ||
def __init__(self, *args, **kwargs): | ||
super().__init__(*args, **kwargs) | ||
|
||
def encode(self, data: Any, *args, **kwargs) -> Any: | ||
embed = sparse.csr_matrix(data) | ||
return embed | ||
|
||
|
||
class DummyCSRSparseIndexer(BaseSparseVectorIndexer): | ||
def __init__(self, *args, **kwargs): | ||
super().__init__(*args, **kwargs) | ||
self.keys = [] | ||
self.vectors = {} | ||
|
||
def add( | ||
self, keys: Iterable[str], vectors: 'scipy.sparse.coo_matrix', *args, **kwargs | ||
) -> None: | ||
assert isinstance(vectors, sparse.coo_matrix) | ||
self.keys.extend(keys) | ||
for i, key in enumerate(keys): | ||
self.vectors[key] = vectors.getrow(i) | ||
|
||
def query(self, vectors: 'scipy.sparse.coo_matrix', top_k: int, *args, **kwargs): | ||
assert isinstance(vectors, sparse.coo_matrix) | ||
distances = [item for item in range(0, min(top_k, len(self.keys)))] | ||
return [self.keys[:top_k]], np.array([distances]) | ||
|
||
def query_by_key(self, keys: Iterable[str], *args, **kwargs): | ||
from scipy.sparse import coo_matrix, vstack | ||
|
||
vectors = [] | ||
for key in keys: | ||
vectors.append(self.vectors[key]) | ||
|
||
return vstack(vectors) | ||
|
||
def save(self): | ||
# avoid creating dump, do not polute workspace | ||
pass | ||
|
||
def close(self): | ||
# avoid creating dump, do not polute workspace | ||
pass | ||
|
||
def get_create_handler(self): | ||
pass | ||
|
||
def get_write_handler(self): | ||
pass | ||
|
||
def get_add_handler(self): | ||
pass | ||
|
||
def get_query_handler(self): | ||
pass | ||
|
||
|
||
def test_sparse_pipeline(mocker, docs_to_index): | ||
def validate(response): | ||
assert len(response.docs) == 1 | ||
assert len(response.docs[0].matches) == 10 | ||
for doc in response.docs: | ||
for i, match in enumerate(doc.matches): | ||
assert match.id == docs_to_index[i].id | ||
assert isinstance(match.embedding, sparse.coo_matrix) | ||
|
||
f = ( | ||
Flow() | ||
.add(uses=DummySparseEncoder) | ||
.add(uses=os.path.join(cur_dir, 'indexer.yml')) | ||
) | ||
|
||
mock = mocker.Mock() | ||
error_mock = mocker.Mock() | ||
|
||
with f: | ||
f.index(inputs=docs_to_index) | ||
f.search(inputs=docs_to_index[0], on_done=mock, on_error=error_mock) | ||
|
||
mock.assert_called_once() | ||
validate_callback(mock, validate) | ||
error_mock.assert_not_called() |
Oops, something went wrong.