Skip to content
This repository was archived by the owner on Feb 22, 2020. It is now read-only.

Commit f9500c1

Browse files
author
felix
committed
fix(protobuffer): add doc_type as func argument in RequestGenerator
1 parent 45a2495 commit f9500c1

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

gnes/proto/__init__.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828

2929
class RequestGenerator:
3030
@staticmethod
31-
def index(data: List[bytes], batch_size: int = 0, start_id: int = 0, doc_type: int = 1, *args, **kwargs):
31+
def index(data: List[bytes], batch_size: int = 0, start_id: int = 0, doc_type: int = gnes_pb2.Document.TEXT, *args, **kwargs):
3232

3333
for pi in batch_iterator(data, batch_size):
3434
req = gnes_pb2.Request()
@@ -42,7 +42,7 @@ def index(data: List[bytes], batch_size: int = 0, start_id: int = 0, doc_type: i
4242
start_id += 1
4343

4444
@staticmethod
45-
def train(data: List[bytes], batch_size: int = 0, start_id: int = 0, doc_type: int = 1, *args, **kwargs):
45+
def train(data: List[bytes], batch_size: int = 0, start_id: int = 0, doc_type: int = gnes_pb2.Document.TEXT, *args, **kwargs):
4646
for pi in batch_iterator(data, batch_size):
4747
req = gnes_pb2.Request()
4848
req.request_id = str(start_id)
@@ -59,7 +59,7 @@ def train(data: List[bytes], batch_size: int = 0, start_id: int = 0, doc_type: i
5959
start_id += 1
6060

6161
@staticmethod
62-
def query(query: bytes, top_k: int, start_id: int = 0, doc_type: int = 1, *args, **kwargs):
62+
def query(query: bytes, top_k: int, start_id: int = 0, doc_type: int = gnes_pb2.Document.TEXT, *args, **kwargs):
6363
if top_k <= 0:
6464
raise ValueError('"top_k: %d" is not a valid number' % top_k)
6565

0 commit comments

Comments
 (0)