16
16
17
17
from typing import List , Union
18
18
19
- from .base import BaseService as BS , MessageHandler , ServiceError
19
+ from .base import BaseService as BS , MessageHandler
20
20
from ..proto import gnes_pb2 , array2blob , blob2array
21
21
22
22
@@ -35,11 +35,13 @@ def embed_chunks_in_docs(self, docs: Union[List['gnes_pb2.Document'], 'gnes_pb2.
35
35
docs = [docs ]
36
36
37
37
contents = []
38
- ids = []
39
- embeds = None
38
+ chunks = []
40
39
41
40
for d in docs :
42
- ids .append (len (d .chunks ))
41
+ if not d .chunks :
42
+ self .logger .warning ('document (doc_id=%s) contains no chunks!' % d .doc_id )
43
+ continue
44
+
43
45
for c in d .chunks :
44
46
if d .doc_type == gnes_pb2 .Document .TEXT :
45
47
contents .append (c .text )
@@ -48,34 +50,32 @@ def embed_chunks_in_docs(self, docs: Union[List['gnes_pb2.Document'], 'gnes_pb2.
48
50
else :
49
51
self .logger .warning (
50
52
'chunk content is in type: %s, dont kow how to handle that, ignored' % c .WhichOneof ('content' ))
53
+ chunks .append (c )
51
54
52
- if do_encoding :
53
- embeds = self ._model .encode (contents )
54
- if sum (ids ) != embeds .shape [0 ]:
55
- raise ServiceError (
56
- 'mismatched %d chunks and a %s shape embedding, '
57
- 'the first dimension must be the same' % (sum (ids ), embeds .shape ))
58
- idx = 0
59
- for d in docs :
60
- for c in d .chunks :
55
+ if do_encoding and contents :
56
+ try :
57
+ embeds = self ._model .encode (contents )
58
+ if len (chunks ) != embeds .shape [0 ]:
59
+ self .logger .error (
60
+ 'mismatched %d chunks and a %s shape embedding, '
61
+ 'the first dimension must be the same' % (len (chunks ), embeds .shape ))
62
+ for idx , c in enumerate (chunks ):
61
63
c .embedding .CopyFrom (array2blob (embeds [idx ]))
62
- idx += 1
64
+ except Exception as ex :
65
+ self .logger .error (ex , exc_info = True )
66
+ self .logger .warning ('encoder service throws an exception, '
67
+ 'the sequel pipeline may not work properly' )
63
68
64
- return contents , embeds
69
+ return contents
65
70
66
71
@handler .register (gnes_pb2 .Request .IndexRequest )
67
72
def _handler_index (self , msg : 'gnes_pb2.Message' ):
68
- _ , embeds = self .embed_chunks_in_docs (msg .request .index .docs )
69
- idx = 0
70
- for d in msg .request .index .docs :
71
- for c in d .chunks :
72
- c .embedding .CopyFrom (array2blob (embeds [idx ]))
73
- idx += 1
73
+ self .embed_chunks_in_docs (msg .request .index .docs )
74
74
75
75
@handler .register (gnes_pb2 .Request .TrainRequest )
76
76
def _handler_train (self , msg : 'gnes_pb2.Message' ):
77
77
if msg .request .train .docs :
78
- contents , _ = self .embed_chunks_in_docs (msg .request .train .docs , do_encoding = False )
78
+ contents = self .embed_chunks_in_docs (msg .request .train .docs , do_encoding = False )
79
79
self .train_data .extend (contents )
80
80
msg .response .train .status = gnes_pb2 .Response .PENDING
81
81
# raise BlockMessage
@@ -88,5 +88,4 @@ def _handler_train(self, msg: 'gnes_pb2.Message'):
88
88
89
89
@handler .register (gnes_pb2 .Request .QueryRequest )
90
90
def _handler_search (self , msg : 'gnes_pb2.Message' ):
91
- _ , embeds = self .embed_chunks_in_docs (msg .request .search .query , is_input_list = False )
92
- msg .request .search .query .chunk_embeddings .CopyFrom (array2blob (embeds ))
91
+ self .embed_chunks_in_docs (msg .request .search .query , is_input_list = False )
0 commit comments