-
Notifications
You must be signed in to change notification settings - Fork 2.2k
/
test_sparse_pipeline.py
114 lines (84 loc) · 3.01 KB
/
test_sparse_pipeline.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
from typing import Any, Iterable
import os
import pytest
import numpy as np
from scipy import sparse
from jina import Flow, Document
from jina.types.sets import DocumentSet
from jina.executors.encoders import BaseEncoder
from jina.executors.indexers import BaseSparseVectorIndexer
from tests import validate_callback
# Absolute directory containing this test module; used to locate 'indexer.yml'.
cur_dir = os.path.split(os.path.abspath(__file__))[0]
@pytest.fixture(scope='function')
def num_docs():
    """Provide how many documents the ``docs_to_index`` fixture builds."""
    total = 10
    return total
@pytest.fixture(scope='function')
def docs_to_index(num_docs):
    """Build a ``DocumentSet`` of ``num_docs`` documents to index.

    Document ids are the strings ``'1'`` .. ``str(num_docs)``; document ``k``
    carries the one-element array ``[k * 5]`` as its content.
    """
    return DocumentSet(
        [
            Document(id=str(idx), content=np.array([idx * 5]))
            for idx in range(1, num_docs + 1)
        ]
    )
class DummySparseEncoder(BaseEncoder):
    """Encoder that hands its input back as a SciPy CSR sparse matrix."""

    def __init__(self, *args, **kwargs):
        # No extra state: everything is forwarded to the base encoder.
        super().__init__(*args, **kwargs)

    def encode(self, data: Any, *args, **kwargs) -> Any:
        """Return ``data`` wrapped in ``scipy.sparse.csr_matrix``.

        NOTE(review): ``data`` is presumably the dense batch of document
        contents supplied by the executor driver — confirm against the caller.
        """
        return sparse.csr_matrix(data)
class DummyCSRSparseIndexer(BaseSparseVectorIndexer):
    """In-memory sparse vector indexer used to exercise the sparse pipeline.

    Rows are kept in a dict keyed by document key; ``query`` returns the
    indexed keys in insertion order with dummy distances ``0, 1, 2, ...`` so
    the test can predict the matches.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Keys in insertion order; this order drives the results of ``query``.
        self.keys = []
        # key -> 1 x n sparse row taken from the matrix passed to ``add``.
        self.vectors = {}

    def add(
        self, keys: Iterable[str], vectors: 'scipy.sparse.coo_matrix', *args, **kwargs
    ) -> None:
        """Store each row of ``vectors`` under the key at the same position.

        :param keys: one key per row of ``vectors``
        :param vectors: sparse matrix; asserted to be in COO format
        """
        assert isinstance(vectors, sparse.coo_matrix)
        # Materialize once: ``keys`` is only typed as Iterable, and a one-shot
        # iterator would be exhausted by ``extend`` before the loop below.
        keys = list(keys)
        self.keys.extend(keys)
        # COO has no efficient row access; convert once instead of paying a
        # full-matrix pass for every ``getrow`` call.
        rows = vectors.tocsr()
        for i, key in enumerate(keys):
            self.vectors[key] = rows.getrow(i)

    def query(self, vectors: 'scipy.sparse.coo_matrix', top_k: int, *args, **kwargs):
        """Return the first ``top_k`` indexed keys with increasing dummy distances.

        The content of ``vectors`` is not inspected; a single result row is
        always returned.

        :return: ``([keys], distances)`` where ``distances`` is a 1 x k array
            holding ``0 .. k-1`` and ``k`` is the number of returned keys
        """
        assert isinstance(vectors, sparse.coo_matrix)
        matched = self.keys[:top_k]
        # Size the distances from the actual slice so both outputs always agree.
        distances = list(range(len(matched)))
        return [matched], np.array([distances])

    def query_by_key(self, keys: Iterable[str], *args, **kwargs):
        """Vertically stack the stored rows for ``keys``, in the given order.

        :raises KeyError: if a key was never added
        """
        return sparse.vstack([self.vectors[key] for key in keys])

    def save(self):
        # Avoid creating a dump so the workspace is not polluted.
        pass

    def close(self):
        # Avoid creating a dump so the workspace is not polluted.
        pass

    # NOTE(review): the handler hooks below presumably override base-class
    # methods that would open index files; they return None here — confirm
    # against BaseSparseVectorIndexer.
    def get_create_handler(self):
        pass

    def get_write_handler(self):
        pass

    def get_add_handler(self):
        pass

    def get_query_handler(self):
        pass
def test_sparse_pipeline(mocker, docs_to_index):
    """Index sparse-encoded docs through a Flow, then search with the first doc.

    Expects every indexed doc back as a match, in index order, each carrying a
    COO sparse embedding, and the error callback never to fire.
    """

    def validate(response):
        assert len(response.docs) == 1
        # Derive the expected count from the fixture instead of hard-coding it,
        # so changing ``num_docs`` cannot silently desynchronize this check.
        # NOTE(review): assumes the search top_k is >= the number of indexed docs.
        assert len(response.docs[0].matches) == len(docs_to_index)
        for doc in response.docs:
            for i, match in enumerate(doc.matches):
                assert match.id == docs_to_index[i].id
                assert isinstance(match.embedding, sparse.coo_matrix)

    f = (
        Flow()
        .add(uses=DummySparseEncoder)
        .add(uses=os.path.join(cur_dir, 'indexer.yml'))
    )
    mock = mocker.Mock()
    error_mock = mocker.Mock()
    with f:
        f.index(inputs=docs_to_index)
        f.search(inputs=docs_to_index[0], on_done=mock, on_error=error_mock)

    mock.assert_called_once()
    validate_callback(mock, validate)
    error_mock.assert_not_called()