This repository has been archived by the owner on Feb 22, 2020. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 210
/
indexer.py
119 lines (99 loc) · 5.42 KB
/
indexer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
# Tencent is pleased to support the open source community by making GNES available.
#
# Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from .base import BaseService as BS, MessageHandler, ServiceError
from ..proto import gnes_pb2, blob2array
class IndexerService(BS):
    """Service that routes index and query messages to an indexer model.

    The model is loaded in :meth:`post_init` and stored in ``self._model``.
    Depending on whether it is a ``BaseChunkIndexer`` or a ``BaseDocIndexer``,
    incoming messages are handled at chunk level or at document level.
    """

    handler = MessageHandler(BS.handler)

    def post_init(self):
        """Load the indexer model via ``load_model`` and keep it in ``self._model``."""
        from ..indexer.base import BaseIndexer

        self._model = self.load_model(BaseIndexer)

    def _unsupported_indexer_error(self) -> ServiceError:
        """Build the error raised when the loaded model cannot serve a message.

        :return: a ``ServiceError`` naming the concrete class of ``self._model``.
        """
        # ``__bases__`` only exists on classes; reading it from the model
        # *instance* raised AttributeError and masked the intended ServiceError.
        # Reporting the concrete class name is safe for any object.
        return ServiceError(
            'unsupported indexer, dont know how to use %s to handle this message' % type(self._model).__name__)

    @handler.register(gnes_pb2.Request.IndexRequest)
    def _handler_index(self, msg: 'gnes_pb2.Message'):
        """Handle an index request by delegating to the chunk- or doc-level handler.

        :param msg: message carrying the index request in ``msg.request.index``.
        :raises ServiceError: if the model is neither a chunk nor a doc indexer.
        """
        from ..indexer.base import BaseChunkIndexer, BaseDocIndexer

        if isinstance(self._model, BaseChunkIndexer):
            is_changed = self._handler_chunk_index(msg)
        elif isinstance(self._model, BaseDocIndexer):
            is_changed = self._handler_doc_index(msg)
        else:
            raise self._unsupported_indexer_error()

        if self.args.as_response:
            msg.response.index.status = gnes_pb2.Response.SUCCESS

        # Flag the model as modified only when something was actually added.
        if is_changed:
            self.is_model_changed.set()

    def _handler_chunk_index(self, msg: 'gnes_pb2.Message') -> bool:
        """Add every embedded chunk of the request's documents to the model.

        Documents without chunks are skipped with a warning; chunks without
        embedding data are silently ignored.

        :param msg: message carrying the documents in ``msg.request.index.docs``.
        :return: ``True`` if at least one vector was added, ``False`` otherwise.
        """
        embed_info = []
        for d in msg.request.index.docs:
            if not d.chunks:
                self.logger.warning('document (doc_id=%s) contains no chunks!' % d.doc_id)
                continue
            embed_info.extend(
                (blob2array(c.embedding), d.doc_id, c.offset, c.weight)
                for c in d.chunks
                if c.embedding.data
            )

        if not embed_info:
            self.logger.warning('chunks contain no embedded vectors, the indexer will do nothing')
            return False

        # Transpose the per-chunk tuples into four parallel sequences.
        vecs, doc_ids, offsets, weights = zip(*embed_info)
        self._model.add(list(zip(doc_ids, offsets)), np.stack(vecs), weights)
        return True

    def _handler_doc_index(self, msg: 'gnes_pb2.Message') -> bool:
        """Add the request's documents, with their ids and weights, to the model.

        :param msg: message carrying the documents in ``msg.request.index.docs``.
        :return: ``True`` if the request contained documents, ``False`` otherwise.
        """
        docs = msg.request.index.docs
        if not docs:
            return False

        self._model.add([d.doc_id for d in docs],
                        list(docs),
                        [d.weight for d in docs])
        return True

    def _put_result_into_message(self, results, msg: 'gnes_pb2.Message'):
        """Replace the top-k results in ``msg.response.search`` with ``results``.

        Also records the number of results and the model's score ordering
        (``is_big_score_similar``) on the response.

        :param results: iterable of result entries to store in the response.
        :param msg: message whose search response is overwritten.
        """
        search = msg.response.search
        search.ClearField('topk_results')
        search.topk_results.extend(results)
        search.top_k = len(results)
        search.is_big_score_similar = self._model.is_big_score_similar

    @handler.register(gnes_pb2.Request.QueryRequest)
    def _handler_chunk_search(self, msg: 'gnes_pb2.Message'):
        """Answer a query request by searching the chunk indexer.

        An empty query (no chunks) yields an empty result list and a warning.

        :param msg: message carrying the query in ``msg.request.search``.
        :raises ServiceError: if the model is not a chunk indexer.
        """
        from ..indexer.base import BaseChunkIndexer

        if not isinstance(self._model, BaseChunkIndexer):
            raise self._unsupported_indexer_error()

        results = []
        if not msg.request.search.query.chunks:
            self.logger.warning('query contains no chunks!')
        else:
            results = self._model.query_and_score(msg.request.search.query.chunks, top_k=msg.request.search.top_k)
        self._put_result_into_message(results, msg)

    @handler.register(gnes_pb2.Response.QueryResponse)
    def _handler_doc_search(self, msg: 'gnes_pb2.Message'):
        """Re-score an existing top-k list with the doc indexer.

        :param msg: message whose ``msg.response.search.topk_results`` is
            passed to the model and then replaced by the model's output.
        :raises ServiceError: if the model is not a doc indexer, or if its
            ``is_big_score_similar`` disagrees with the incoming response.
        """
        from ..indexer.base import BaseDocIndexer

        if not isinstance(self._model, BaseDocIndexer):
            raise self._unsupported_indexer_error()

        # check if chunk_indexer and doc_indexer has the same sorting order
        # NOTE(review): if is_big_score_similar is a plain protobuf scalar it is
        # presumably never None, making the first clause always true - confirm.
        if msg.response.search.is_big_score_similar is not None and \
                msg.response.search.is_big_score_similar != self._model.is_big_score_similar:
            raise ServiceError(
                'is_big_score_similar is inconsistent. last topk-list: is_big_score_similar=%s, but '
                'this indexer: is_big_score_similar=%s' % (
                    msg.response.search.is_big_score_similar, self._model.is_big_score_similar))

        # assume the doc search will change the whatever sort order the message has
        msg.response.search.is_sorted = False
        results = self._model.query_and_score(msg.response.search.topk_results)
        self._put_result_into_message(results, msg)