Skip to content
Permalink
Browse files

fix(protobuffer): add doc_type as func argument in RequestGenerator

  • Loading branch information...
numb3r3 committed Jul 26, 2019
1 parent 45a2495 commit f9500c1fe09dcbe27bec8a7a690e9fb243d51cc4
Showing with 3 additions and 3 deletions.
  1. +3 −3 gnes/proto/__init__.py
@@ -28,7 +28,7 @@

class RequestGenerator:
@staticmethod
def index(data: List[bytes], batch_size: int = 0, start_id: int = 0, doc_type: int = 1, *args, **kwargs):
def index(data: List[bytes], batch_size: int = 0, start_id: int = 0, doc_type: int = gnes_pb2.Document.TEXT, *args, **kwargs):

for pi in batch_iterator(data, batch_size):
req = gnes_pb2.Request()
@@ -42,7 +42,7 @@ def index(data: List[bytes], batch_size: int = 0, start_id: int = 0, doc_type: i
start_id += 1

@staticmethod
def train(data: List[bytes], batch_size: int = 0, start_id: int = 0, doc_type: int = 1, *args, **kwargs):
def train(data: List[bytes], batch_size: int = 0, start_id: int = 0, doc_type: int = gnes_pb2.Document.TEXT, *args, **kwargs):
for pi in batch_iterator(data, batch_size):
req = gnes_pb2.Request()
req.request_id = str(start_id)
@@ -59,7 +59,7 @@ def train(data: List[bytes], batch_size: int = 0, start_id: int = 0, doc_type: i
start_id += 1

@staticmethod
def query(query: bytes, top_k: int, start_id: int = 0, doc_type: int = 1, *args, **kwargs):
def query(query: bytes, top_k: int, start_id: int = 0, doc_type: int = gnes_pb2.Document.TEXT, *args, **kwargs):
if top_k <= 0:
raise ValueError('"top_k: %d" is not a valid number' % top_k)

0 comments on commit f9500c1

Please sign in to comment.
You can’t perform that action at this time.