-
Notifications
You must be signed in to change notification settings - Fork 2.2k
/
search.py
136 lines (106 loc) · 4.9 KB
/
search.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
__copyright__ = "Copyright (c) 2020 Jina AI Limited. All rights reserved."
__license__ = "Apache-2.0"
from typing import Optional, Tuple

from . import BaseExecutableDriver, QuerySetReader
from ..types.document import Document
from ..types.score import NamedScore
if False:
from ..types.sets import DocumentSet
class BaseSearchDriver(BaseExecutableDriver):
    """Drivers inherited from this Driver will bind :meth:`query` by default """

    def __init__(
            self,
            executor: Optional[str] = None,
            method: str = 'query',
            traversal_paths: Tuple[str, ...] = ('r', 'c'),
            *args,
            **kwargs):
        """
        :param executor: the name of the executor to bind to; ``None`` binds to the default executor
        :param method: the executor method invoked by this driver, defaults to ``'query'``
        :param traversal_paths: document locations to traverse, ``'r'`` for root docs and ``'c'`` for chunks
        """
        super().__init__(
            executor,
            method,
            traversal_paths=traversal_paths,
            *args,
            **kwargs
        )
class KVSearchDriver(BaseSearchDriver):
    """Populate doc/chunk-level top-k results from a key-value store such as
    :class:`jina.executors.indexers.meta.BinaryPbIndexer`

    .. warning::
        Every chunk/chunk's top-K result triggers one query, so the total number
        of queries grows with ``level``:

        - ``level=chunk``: D x C x K
        - ``level=doc``: D x K
        - ``level=all``: D x C x K

        where D is the number of queries, C the number of chunks per query/doc,
        and K the top-k. This can be costly.
    """

    def __init__(self, is_merge: bool = True, *args, **kwargs):
        """
        :param is_merge: when set to true the retrieved docs are merged into current message using :meth:`MergeFrom`,
            otherwise, it overrides the current message using :meth:`CopyFrom`
        """
        super().__init__(*args, **kwargs)
        self._is_merge = is_merge

    def _apply_all(self, docs: 'DocumentSet', *args, **kwargs) -> None:
        #: positions of docs that got no hit; can happen especially with shards
        missing = []
        for pos, doc in enumerate(docs):
            payload = self.exec_fn(int(doc.id))
            if not payload:
                missing.append(pos)
                continue
            hit = Document(payload)
            # TODO: this isn't perfect though, merge applies recursively on all children
            # it will duplicate embedding.shape if embedding is already there
            if self._is_merge:
                doc.MergeFrom(hit)
            else:
                doc.CopyFrom(hit)
        # drop the misses back-to-front so earlier indices stay valid
        for pos in reversed(missing):
            del docs[pos]
class VectorFillDriver(QuerySetReader, BaseSearchDriver):
    """ Fill in the embedding by their doc id
    """

    def __init__(self, executor: Optional[str] = None, method: str = 'query_by_id', *args, **kwargs):
        """
        :param executor: the name of the executor to bind to; ``None`` binds to the default executor
        :param method: the executor method used to look up embeddings, defaults to ``'query_by_id'``
        """
        super().__init__(executor, method, *args, **kwargs)

    def _apply_all(self, docs: 'DocumentSet', *args, **kwargs) -> None:
        # one batched lookup for all doc ids, then assign embeddings positionally
        embeds = self.exec_fn([int(d.id) for d in docs])
        for doc, embedding in zip(docs, embeds):
            doc.embedding = embedding
class VectorSearchDriver(QuerySetReader, BaseSearchDriver):
    """Take the chunk-level embeddings from the request and run a vector-index
    query with them via the bound executor
    """

    def __init__(self, top_k: int = 50, fill_embedding: bool = False, *args, **kwargs):
        """
        :param top_k: top-k doc id to retrieve
        :param fill_embedding: fill in the embedding of the corresponding doc,
            this requires the executor to implement :meth:`query_by_id`
        :param args:
        :param kwargs:
        """
        super().__init__(*args, **kwargs)
        self._top_k = top_k
        self._fill_embedding = fill_embedding

    def _apply_all(self, docs: 'DocumentSet', *args, **kwargs) -> None:
        embed_vecs, doc_pts, bad_docs = docs.all_embeddings
        if not doc_pts:
            return

        fill_fn = getattr(self.exec, 'query_by_id', None)
        if self._fill_embedding and not fill_fn:
            self.logger.warning(f'"fill_embedding=True" but {self.exec} does not have "query_by_id" method')
        if bad_docs:
            self.logger.warning(f'these bad docs can not be added: {bad_docs}')

        # NOTE: `self.top_k` (not `self._top_k`) so a queryset override is honored
        idx, dist = self.exec_fn(embed_vecs, top_k=int(self.top_k))
        op_name = self.exec.__class__.__name__
        # exec_fn may yield None results when the index holds nothing
        if idx is None or dist is None:
            return

        can_fill = self._fill_embedding and fill_fn
        for query_doc, match_ids, match_dists in zip(doc_pts, idx, dist):
            match_embeds = fill_fn(match_ids) if can_fill else [None] * len(match_ids)
            for match_id, match_dist, match_embed in zip(match_ids, match_dists, match_embeds):
                match = Document(id=int(match_id))
                match.score = NamedScore(op_name=op_name,
                                         value=match_dist)
                appended = query_doc.matches.append(match)
                if match_embed is not None:
                    appended.embedding = match_embed