Skip to content
This repository was archived by the owner on Feb 22, 2020. It is now read-only.

Commit 66aec9c

Browse files
author
hanhxiao
committed
feat(grpc): add StreamCall and decouple send and receive
1 parent 9973f60 commit 66aec9c

File tree

7 files changed

+38
-29
lines changed

7 files changed

+38
-29
lines changed

gnes/client/cli.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,10 +41,10 @@ def __init__(self, args):
4141
stub = gnes_pb2_grpc.GnesRPCStub(channel)
4242

4343
if args.mode == 'train':
44-
resp = stub.RequestStreamCall(RequestGenerator.train(all_bytes, args.batch_size))
44+
resp = list(stub.StreamCall(RequestGenerator.train(all_bytes, args.batch_size)))[-1]
4545
print(resp)
4646
elif args.mode == 'index':
47-
resp = stub.RequestStreamCall(RequestGenerator.train(all_bytes, args.batch_size))
47+
resp = list(stub.StreamCall(RequestGenerator.train(all_bytes, args.batch_size)))[-1]
4848
print(resp)
4949
elif args.mode == 'query':
5050
for idx, q in enumerate(all_bytes):

gnes/client/http.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ async def init(loop):
9191
return srv
9292

9393
def stub_call(req):
94-
res_f = stub.RequestStreamCall(req)
94+
res_f = list(stub.StreamCall(req))[-1]
9595
return json.loads(MessageToJson(res_f))
9696

9797
with grpc.insecure_channel(

gnes/proto/gnes.proto

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,7 @@ service GnesRPC {
207207
}
208208
rpc Call (Request) returns (Response) {
209209
}
210-
rpc RequestStreamCall (stream Request) returns (Response) {
210+
rpc StreamCall (stream Request) returns (stream Response) {
211211
}
212212
}
213213

gnes/proto/gnes_pb2.py

Lines changed: 4 additions & 4 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

gnes/proto/gnes_pb2_grpc.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,8 @@ def __init__(self, channel):
3434
request_serializer=gnes__pb2.Request.SerializeToString,
3535
response_deserializer=gnes__pb2.Response.FromString,
3636
)
37-
self.RequestStreamCall = channel.stream_unary(
38-
'/gnes.GnesRPC/RequestStreamCall',
37+
self.StreamCall = channel.stream_stream(
38+
'/gnes.GnesRPC/StreamCall',
3939
request_serializer=gnes__pb2.Request.SerializeToString,
4040
response_deserializer=gnes__pb2.Response.FromString,
4141
)
@@ -74,7 +74,7 @@ def Call(self, request, context):
7474
context.set_details('Method not implemented!')
7575
raise NotImplementedError('Method not implemented!')
7676

77-
def RequestStreamCall(self, request_iterator, context):
77+
def StreamCall(self, request_iterator, context):
7878
# missing associated documentation comment in .proto file
7979
pass
8080
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
@@ -104,8 +104,8 @@ def add_GnesRPCServicer_to_server(servicer, server):
104104
request_deserializer=gnes__pb2.Request.FromString,
105105
response_serializer=gnes__pb2.Response.SerializeToString,
106106
),
107-
'RequestStreamCall': grpc.stream_unary_rpc_method_handler(
108-
servicer.RequestStreamCall,
107+
'StreamCall': grpc.stream_stream_rpc_method_handler(
108+
servicer.StreamCall,
109109
request_deserializer=gnes__pb2.Request.FromString,
110110
response_serializer=gnes__pb2.Response.SerializeToString,
111111
),

gnes/service/grpc.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -131,10 +131,14 @@ def Index(self, request, context):
131131
def Search(self, request, context):
132132
return self.Call(request, context)
133133

134-
def RequestStreamCall(self, request_iterator, context):
135-
for request in request_iterator:
136-
ret = self.Call(request, context)
137-
return ret
134+
def StreamCall(self, request_iterator, context):
135+
num_result = 0
136+
with self.zmq_context as zmq_client:
137+
for request in request_iterator:
138+
zmq_client.send_message(self.add_envelope(request, zmq_client), self.args.timeout)
139+
num_result += 1
140+
for _ in range(num_result):
141+
yield self.remove_envelope(zmq_client.recv_message(self.args.timeout))
138142

139143

140144
class GRPCFrontend:

tests/test_stream_grpc.py

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -57,15 +57,10 @@ def test_grpc_frontend(self):
5757
('grpc.max_receive_message_length', 70 * 1024 * 1024)]) as channel:
5858
stub = gnes_pb2_grpc.GnesRPCStub(channel)
5959
with TimeContext('sync call'): # about 5s
60-
resp = stub.RequestStreamCall(RequestGenerator.train(self.all_bytes, 1))
60+
resp = list(stub.StreamCall(RequestGenerator.train(self.all_bytes, 1)))[-1]
6161

6262
self.assertEqual(resp.request_id, str(len(self.all_bytes))) # idx start with 0, but +1 for final FLUSH
6363

64-
# test async calls
65-
with TimeContext('async call'): # immeidiately returns 0.001 s
66-
resp = stub.RequestStreamCall.future(RequestGenerator.train(self.all_bytes, 1))
67-
self.assertEqual(resp.result().request_id, str(len(self.all_bytes)))
68-
6964
@unittest.mock.patch.dict(os.environ, {'http_proxy': '', 'https_proxy': ''})
7065
def test_async_block(self):
7166
args = set_grpc_frontend_parser().parse_args([
@@ -91,9 +86,19 @@ def test_async_block(self):
9186
options=[('grpc.max_send_message_length', 70 * 1024 * 1024),
9287
('grpc.max_receive_message_length', 70 * 1024 * 1024)]) as channel:
9388
stub = gnes_pb2_grpc.GnesRPCStub(channel)
94-
with TimeContext('sync call'): # about 5s
95-
resp = stub.RequestStreamCall.future(RequestGenerator.train(self.all_bytes, 1))
96-
97-
self.assertEqual(resp.result().request_id, str(len(self.all_bytes)))
98-
99-
self.assertEqual(resp.request_id, str(len(self.all_bytes2))) # idx start with 0, but +1 for final FLUSH
89+
id = 0
90+
with TimeContext('non-blocking call'): # about 26s = 32s (total) - 3*2s (overlap)
91+
resp = stub.StreamCall(RequestGenerator.train(self.all_bytes2, 1))
92+
for r in resp:
93+
self.assertEqual(r.request_id, str(id))
94+
id += 1
95+
96+
id = 0
97+
with TimeContext('blocking call'): # should be 32 s
98+
for r in RequestGenerator.train(self.all_bytes2, 1):
99+
resp = stub.Call(r)
100+
self.assertEqual(resp.request_id, str(id))
101+
id += 1
102+
# self.assertEqual(resp.result().request_id, str(len(self.all_bytes)))
103+
104+
# self.assertEqual(resp.request_id, str(len(self.all_bytes2))) # idx start with 0, but +1 for final FLUSH

0 commit comments

Comments
 (0)