Skip to content
This repository was archived by the owner on Feb 22, 2020. It is now read-only.

Commit a1a2b02

Browse files
author
hanhxiao
committed
refactor(proto): refactor request stream call
1 parent 216cecc commit a1a2b02

File tree

10 files changed

+99
-78
lines changed

10 files changed

+99
-78
lines changed

Diff for: gnes/cli/parser.py

+3
Original file line numberDiff line numberDiff line change
@@ -206,10 +206,13 @@ def _set_grpc_parser(parser=None):
206206

207207

208208
def set_grpc_frontend_parser(parser=None):
209+
from ..service.base import SocketType
209210
if not parser:
210211
parser = set_base_parser()
211212
_set_client_parser(parser)
212213
_set_grpc_parser(parser)
214+
parser.set_defaults(socket_in=SocketType.PULL_BIND,
215+
socket_out=SocketType.PUSH_BIND)
213216
parser.add_argument('--max_concurrency', type=int, default=10,
214217
help='maximum concurrent client allowed')
215218
parser.add_argument('--max_send_size', type=int, default=100,

Diff for: gnes/client/cli.py

+5-7
Original file line numberDiff line numberDiff line change
@@ -41,17 +41,15 @@ def __init__(self, args):
4141
stub = gnes_pb2_grpc.GnesRPCStub(channel)
4242

4343
if args.mode == 'train':
44-
for req in RequestGenerator.train(all_bytes, args.batch_size):
45-
resp = stub._Call(req)
46-
print(resp)
44+
resp = stub.RequestStreamCall(RequestGenerator.train(all_bytes, args.batch_size))
45+
print(resp)
4746
elif args.mode == 'index':
48-
for req in RequestGenerator.index(all_bytes, args.batch_size):
49-
resp = stub._Call(req)
50-
print(resp)
47+
resp = stub.RequestStreamCall(RequestGenerator.train(all_bytes, args.batch_size))
48+
print(resp)
5149
elif args.mode == 'query':
5250
for idx, q in enumerate(all_bytes):
5351
for req in RequestGenerator.query(q, args.top_k):
54-
resp = stub._Call(req)
52+
resp = stub.Call(req)
5553
print(resp)
5654
print('query %d result: %s' % (idx, resp))
5755
input('press any key to continue...')

Diff for: gnes/client/http.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -91,9 +91,7 @@ async def init(loop):
9191
return srv
9292

9393
def stub_call(req):
94-
res_f = None
95-
for r in req:
96-
res_f = stub._Call(r)
94+
res_f = stub.RequestStreamCall(req)
9795
return json.loads(MessageToJson(res_f))
9896

9997
with grpc.insecure_channel(

Diff for: gnes/proto/__init__.py

+11-3
Original file line numberDiff line numberDiff line change
@@ -28,33 +28,41 @@
2828

2929
class RequestGenerator:
3030
@staticmethod
31-
def index(data: List[bytes], batch_size: int = 0, *args, **kwargs):
31+
def index(data: List[bytes], batch_size: int = 0, start_id: int = 0, *args, **kwargs):
32+
3233
for pi in batch_iterator(data, batch_size):
3334
req = gnes_pb2.Request()
35+
req.request_id = start_id
3436
for raw_bytes in pi:
3537
d = req.index.docs.add()
3638
d.raw_bytes = raw_bytes
3739
d.weight = 1.0
3840
yield req
41+
start_id += 1
3942

4043
@staticmethod
41-
def train(data: List[bytes], batch_size: int = 0, *args, **kwargs):
44+
def train(data: List[bytes], batch_size: int = 0, start_id: int = 0, *args, **kwargs):
4245
for pi in batch_iterator(data, batch_size):
4346
req = gnes_pb2.Request()
47+
req.request_id = str(start_id)
4448
for raw_bytes in pi:
4549
d = req.train.docs.add()
4650
d.raw_bytes = raw_bytes
4751
yield req
52+
start_id += 1
4853
req = gnes_pb2.Request()
54+
req.request_id = str(start_id)
4955
req.train.flush = True
5056
yield req
57+
start_id += 1
5158

5259
@staticmethod
53-
def query(query: bytes, top_k: int, *args, **kwargs):
60+
def query(query: bytes, top_k: int, start_id: int = 0, *args, **kwargs):
5461
if top_k <= 0:
5562
raise ValueError('"top_k: %d" is not a valid number' % top_k)
5663

5764
req = gnes_pb2.Request()
65+
req.request_id = start_id
5866
req.search.query.raw_bytes = query
5967
req.search.top_k = top_k
6068
yield req

Diff for: gnes/proto/gnes.proto

+2-4
Original file line numberDiff line numberDiff line change
@@ -205,11 +205,9 @@ service GnesRPC {
205205
}
206206
rpc Query (Request) returns (Response) {
207207
}
208-
rpc _Call (Request) returns (Response) {
208+
rpc Call (Request) returns (Response) {
209209
}
210-
rpc TrainStream (stream Request) returns (Response) {
211-
}
212-
rpc IndexStream (stream Request) returns (Response) {
210+
rpc RequestStreamCall (stream Request) returns (Response) {
213211
}
214212
}
215213

Diff for: gnes/proto/gnes_pb2.py

+6-15
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Diff for: gnes/proto/gnes_pb2_grpc.py

+10-27
Original file line numberDiff line numberDiff line change
@@ -29,18 +29,13 @@ def __init__(self, channel):
2929
request_serializer=gnes__pb2.Request.SerializeToString,
3030
response_deserializer=gnes__pb2.Response.FromString,
3131
)
32-
self._Call = channel.unary_unary(
33-
'/gnes.GnesRPC/_Call',
32+
self.Call = channel.unary_unary(
33+
'/gnes.GnesRPC/Call',
3434
request_serializer=gnes__pb2.Request.SerializeToString,
3535
response_deserializer=gnes__pb2.Response.FromString,
3636
)
37-
self.TrainStream = channel.stream_unary(
38-
'/gnes.GnesRPC/TrainStream',
39-
request_serializer=gnes__pb2.Request.SerializeToString,
40-
response_deserializer=gnes__pb2.Response.FromString,
41-
)
42-
self.IndexStream = channel.stream_unary(
43-
'/gnes.GnesRPC/IndexStream',
37+
self.RequestStreamCall = channel.stream_unary(
38+
'/gnes.GnesRPC/RequestStreamCall',
4439
request_serializer=gnes__pb2.Request.SerializeToString,
4540
response_deserializer=gnes__pb2.Response.FromString,
4641
)
@@ -72,21 +67,14 @@ def Query(self, request, context):
7267
context.set_details('Method not implemented!')
7368
raise NotImplementedError('Method not implemented!')
7469

75-
def _Call(self, request, context):
76-
# missing associated documentation comment in .proto file
77-
pass
78-
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
79-
context.set_details('Method not implemented!')
80-
raise NotImplementedError('Method not implemented!')
81-
82-
def TrainStream(self, request_iterator, context):
70+
def Call(self, request, context):
8371
# missing associated documentation comment in .proto file
8472
pass
8573
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
8674
context.set_details('Method not implemented!')
8775
raise NotImplementedError('Method not implemented!')
8876

89-
def IndexStream(self, request_iterator, context):
77+
def RequestStreamCall(self, request_iterator, context):
9078
# missing associated documentation comment in .proto file
9179
pass
9280
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
@@ -111,18 +99,13 @@ def add_GnesRPCServicer_to_server(servicer, server):
11199
request_deserializer=gnes__pb2.Request.FromString,
112100
response_serializer=gnes__pb2.Response.SerializeToString,
113101
),
114-
'_Call': grpc.unary_unary_rpc_method_handler(
115-
servicer._Call,
116-
request_deserializer=gnes__pb2.Request.FromString,
117-
response_serializer=gnes__pb2.Response.SerializeToString,
118-
),
119-
'TrainStream': grpc.stream_unary_rpc_method_handler(
120-
servicer.TrainStream,
102+
'Call': grpc.unary_unary_rpc_method_handler(
103+
servicer.Call,
121104
request_deserializer=gnes__pb2.Request.FromString,
122105
response_serializer=gnes__pb2.Response.SerializeToString,
123106
),
124-
'IndexStream': grpc.stream_unary_rpc_method_handler(
125-
servicer.IndexStream,
107+
'RequestStreamCall': grpc.stream_unary_rpc_method_handler(
108+
servicer.RequestStreamCall,
126109
request_deserializer=gnes__pb2.Request.FromString,
127110
response_serializer=gnes__pb2.Response.SerializeToString,
128111
),

Diff for: gnes/service/grpc.py

+14-17
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
__all__ = ['GRPCFrontend']
3030

3131

32-
class ZmqContext(object):
32+
class ZmqContext:
3333
"""The zmq context class."""
3434

3535
def __init__(self, args):
@@ -111,32 +111,29 @@ def add_envelope(self, body: 'gnes_pb2.Request', zmq_client: 'ZmqClient'):
111111
msg.request.CopyFrom(body)
112112
return msg
113113

114-
def _Call(self, request, context):
114+
def remove_envelope(self, m: 'gnes_pb2.Message'):
115+
resp = m.response
116+
resp.request_id = m.envelope.request_id
117+
return resp
118+
119+
def Call(self, request, context):
115120
self.logger.info('received a new request: %s' % request.request_id or 'EMPTY_REQUEST_ID')
116121
with self.zmq_context as zmq_client:
117-
msg = self.add_envelope(request, zmq_client)
118-
zmq_client.send_message(msg, self.args.timeout)
119-
resp = zmq_client.recv_message(self.args.timeout)
120-
self.logger.info("received message done!")
121-
return resp.response
122+
zmq_client.send_message(self.add_envelope(request, zmq_client), self.args.timeout)
123+
return self.remove_envelope(zmq_client.recv_message(self.args.timeout))
122124

123125
def Train(self, request, context):
124-
return self._Call(request, context)
126+
return self.Call(request, context)
125127

126128
def Index(self, request, context):
127-
return self._Call(request, context)
129+
return self.Call(request, context)
128130

129131
def Search(self, request, context):
130-
return self._Call(request, context)
131-
132-
def TrainStream(self, request_iterator, context):
133-
for request in request_iterator:
134-
ret = self._Call(request, context)
135-
return ret
132+
return self.Call(request, context)
136133

137-
def IndexStream(self, request_iterator, context):
134+
def RequestStreamCall(self, request_iterator, context):
138135
for request in request_iterator:
139-
ret = self._Call(request, context)
136+
ret = self.Call(request, context)
140137
return ret
141138

142139

Diff for: shell/make-proto.sh

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
#!/usr/bin/env bash
22

33
SRC_DIR=../gnes/proto/
4-
#PLUGIN_PATH=/Volumes/TOSHIBA-4T/Documents/grpc/bins/opt/grpc_python_plugin
5-
PLUGIN_PATH=/user/local/grpc/bins/opt/grpc_python_plugin
4+
PLUGIN_PATH=/Volumes/TOSHIBA-4T/Documents/grpc/bins/opt/grpc_python_plugin
5+
#PLUGIN_PATH=/user/local/grpc/bins/opt/grpc_python_plugin
66

77
protoc -I ${SRC_DIR} --python_out=${SRC_DIR} --grpc_python_out=${SRC_DIR} --plugin=protoc-gen-grpc_python=${PLUGIN_PATH} ${SRC_DIR}gnes.proto
88

0 commit comments

Comments
 (0)