Skip to content
Permalink
Browse files

fix(service): fix bug in req Generator add doc_type

  • Loading branch information...
Larryjianfeng committed Jul 26, 2019
1 parent 5743e25 commit 80e234e154caea53b9103007ffc1cd669cd00bc7
Showing with 7 additions and 4 deletions.
  1. +6 −3 gnes/proto/__init__.py
  2. +1 −1 gnes/service/indexer.py
@@ -28,7 +28,7 @@

class RequestGenerator:
@staticmethod
-def index(data: List[bytes], batch_size: int = 0, start_id: int = 0, *args, **kwargs):
+def index(data: List[bytes], batch_size: int = 0, start_id: int = 0, doc_type: str = 'TEXT', *args, **kwargs):

for pi in batch_iterator(data, batch_size):
req = gnes_pb2.Request()
@@ -37,17 +37,19 @@ def index(data: List[bytes], batch_size: int = 0, start_id: int = 0, *args, **kw
d = req.index.docs.add()
d.raw_bytes = raw_bytes
d.weight = 1.0
+d.doc_type = doc_type
yield req
start_id += 1

@staticmethod
-def train(data: List[bytes], batch_size: int = 0, start_id: int = 0, *args, **kwargs):
+def train(data: List[bytes], batch_size: int = 0, start_id: int = 0, doc_type: str = 'TEXT', *args, **kwargs):
for pi in batch_iterator(data, batch_size):
req = gnes_pb2.Request()
req.request_id = str(start_id)
for raw_bytes in pi:
d = req.train.docs.add()
d.raw_bytes = raw_bytes
+d.doc_type = doc_type
yield req
start_id += 1
req = gnes_pb2.Request()
@@ -57,13 +59,14 @@ def train(data: List[bytes], batch_size: int = 0, start_id: int = 0, *args, **kw
start_id += 1

@staticmethod
-def query(query: bytes, top_k: int, start_id: int = 0, *args, **kwargs):
+def query(query: bytes, top_k: int, start_id: int = 0, doc_type: str = 'TEXT', *args, **kwargs):
if top_k <= 0:
raise ValueError('"top_k: %d" is not a valid number' % top_k)

req = gnes_pb2.Request()
req.request_id = str(start_id)
req.search.query.raw_bytes = query
+req.search.query.doc_type = doc_type
req.search.top_k = top_k
yield req

@@ -36,7 +36,7 @@ def _handler_index(self, msg: 'gnes_pb2.Message'):
weights = []

for d in msg.request.index.docs:
-if d.chunks:
+if len(d.chunks):
all_vecs.append(blob2array(d.chunk_embeddings))
doc_ids += [d.doc_id] * len(d.chunks)
if d.doc_type == 'TEXT':

0 comments on commit 80e234e

Please sign in to comment.
You can’t perform that action at this time.