Skip to content
Permalink
Browse files

refactor(grpc-client): implement async client via multi-threading

  • Loading branch information...
numb3r3 committed Sep 5, 2019
1 parent 242dc5c commit 06aab81331286193a0f39a11fb5df01d7b854267
Showing with 175 additions and 22 deletions.
  1. +11 −19 gnes/client/cli.py
  2. +164 −0 gnes/client/grpc.py
  3. +0 −3 gnes/service/frontend.py
@@ -24,40 +24,32 @@
from termcolor import colored

from ..proto import gnes_pb2_grpc, RequestGenerator
from .grpc import StreamingSyncClient


class CLIClient:
class CLIClient(StreamingSyncClient):
def __init__(self, args):
self.args = args
self._use_channel()

def _use_channel(self):
all_bytes = self.read_all()
with grpc.insecure_channel(
'%s:%d' % (self.args.grpc_host, self.args.grpc_port),
options=[('grpc.max_send_message_length', self.args.max_message_size * 1024 * 1024),
('grpc.max_receive_message_length', self.args.max_message_size * 1024 * 1024)]) as channel:
stub = gnes_pb2_grpc.GnesRPCStub(channel)
getattr(self, self.args.mode)(all_bytes, stub)
super().__init__(args)
self.start()
getattr(self, self.args.mode)(all_bytes)
self.stop()

def train(self, all_bytes: List[bytes], stub):
with ProgressBar(all_bytes, self.args.batch_size, task_name=self.args.mode) as p_bar:
for _ in stub.StreamCall(RequestGenerator.train(all_bytes,
doc_id_start=self.args.start_doc_id,
batch_size=self.args.batch_size)):
for req in RequestGenerator.train(all_bytes, doc_id_start=self.args.start_doc_id, batch_size=self.args.batch_size):
self.send_request(req)
p_bar.update()

def index(self, all_bytes: List[bytes], stub):
with ProgressBar(all_bytes, self.args.batch_size, task_name=self.args.mode) as p_bar:
for _ in stub.StreamCall(RequestGenerator.index(all_bytes,
doc_id_start=self.args.start_doc_id,
batch_size=self.args.batch_size)):
for req in RequestGenerator.index(all_bytes, doc_id_start=self.args.start_doc_id, batch_size=self.args.batch_size):
self.send_request(req)
p_bar.update()

def query(self, all_bytes: List[bytes], stub):
for idx, q in enumerate(all_bytes):
for req in RequestGenerator.query(q, request_id_start=idx, top_k=self.args.top_k):
resp = stub.Call(req)
resp = self._stub.Call(req)
print(resp)
print('query %d result: %s' % (idx, resp))
input('press any key to continue...')
@@ -0,0 +1,164 @@
# Tencent is pleased to support the open source community by making GNES available.
#
# Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import grpc
import time

import queue
from concurrent import futures

from gnes.proto import gnes_pb2_grpc

_TIMEOUT = 60 * 60 * 24


class BaseGrpcClient:
    """Base gRPC client: owns the insecure channel, the ``GnesRPCStub`` and a
    registry of response callbacks.

    Subclasses implement :meth:`send_request` with their own concurrency model
    (thread pool, Future API, or streaming).
    """

    def __init__(self, args):
        # ``args`` must provide grpc_host and grpc_port (argparse namespace).
        self.args = args

        self._address = '%s:%d' % (self.args.grpc_host, self.args.grpc_port)
        # BUG FIX: grpc.insecure_channel's first parameter is named ``target``;
        # the original passed the invalid keyword ``address=`` (TypeError).
        # Options are passed as the conventional list of (key, value) tuples;
        # -1 means "unlimited message size".
        self._channel = grpc.insecure_channel(
            self._address,
            options=[
                ('grpc.max_send_message_length', -1),
                ('grpc.max_receive_message_length', -1),
            ],
        )
        # waits for the channel to be ready before we start sending messages
        grpc.channel_ready_future(self._channel).result()

        self._stub = gnes_pb2_grpc.GnesRPCStub(self._channel)

        self._response_callbacks = []

    def add_response_callback(self, callback):
        """Register *callback*; invoked as callback(client, response, response_time)."""
        self._response_callbacks.append(callback)

    def send_request(self, request):
        """Non-blocking wrapper for a client's request operation."""
        raise NotImplementedError

    def start(self):
        # Hook for subclasses that need background workers; no-op by default.
        pass

    def stop(self):
        # Hook for subclasses to release workers/stub; no-op by default.
        pass

    def _handle_response(self, client, response, response_time):
        # Fan a completed response out to every registered callback.
        for callback in self._response_callbacks:
            callback(client, response, response_time)


class UnarySyncClient(BaseGrpcClient):
    """Issues unary ``Call`` RPCs from a thread pool so that up to
    ``max_concurrency`` requests can be in flight at once."""

    def __init__(self, args):
        super().__init__(args)
        self._pool = futures.ThreadPoolExecutor(
            max_workers=self.args.max_concurrency)

    def send_request(self, request):
        # Dispatch on a worker thread so multiple RPCs stay outstanding
        # (See src/proto/grpc/testing/control.proto)
        self._pool.submit(self._dispatch_request, request)

    def stop(self):
        # Drain every in-flight request before dropping the stub.
        self._pool.shutdown(wait=True)
        self._stub = None

    def _dispatch_request(self, request):
        # Blocking unary call; report the response with its wall-clock latency.
        begin = time.time()
        response = self._stub.Call(request)
        self._handle_response(self, response, time.time() - begin)


class UnaryAsyncClient(BaseGrpcClient):
    """Issues unary ``Call`` RPCs through the gRPC Future API — no worker
    threads; completion is reported via ``add_done_callback``."""

    def send_request(self, request):
        # Use the Future callback api to support multiple outstanding rpcs.
        # BUG FIX: the original passed the undefined attribute
        # ``self._request`` instead of the ``request`` argument, which raised
        # AttributeError on every call.
        start_time = time.time()
        response_future = self._stub.Call.future(request, _TIMEOUT)
        response_future.add_done_callback(
            lambda resp: self._response_received(start_time, resp))

    def _response_received(self, start_time, resp):
        # ``resp`` is the completed Future; .result() returns the response or
        # re-raises the RPC error.
        resp = resp.result()
        end_time = time.time()
        self._handle_response(self, resp, end_time - start_time)

    def stop(self):
        self._stub = None


class _SyncStream(object):

def __init__(self, stub, handle_response):
self._stub = stub
self._handle_response = handle_response
self._is_streaming = False
self._request_queue = queue.Queue()
self._send_time_queue = queue.Queue()

def send_request(self, request):
self._send_time_queue.put(time.time())
self._request_queue.put(request)

def start(self):
self._is_streaming = True
response_stream = self._stub.StreamCall(self._request_generator())
for resp in response_stream:
self._handle_response(
self, resp,
time.time() - self._send_time_queue.get_nowait())

def stop(self):
self._is_streaming = False

def _request_generator(self):
while self._is_streaming:
try:
request = self._request_queue.get(block=True, timeout=1.0)
yield request
except queue.Empty:
pass


class StreamingSyncClient(BaseGrpcClient):
    """Round-robins requests over ``max_concurrency`` concurrent
    ``StreamCall`` streams, each pumped on its own pool thread."""

    def __init__(self, args):
        super().__init__(args)

        # BUG FIX: start()/stop() referenced ``self._pool``, which was never
        # created — instantiate the thread pool that runs the blocking streams.
        self._pool = futures.ThreadPoolExecutor(
            max_workers=self.args.max_concurrency)
        # BUG FIX: the original passed three arguments — including the
        # undefined attribute ``self._request`` — to
        # _SyncStream.__init__(stub, handle_response), a 2-arg constructor.
        self._streams = [
            _SyncStream(self._stub, self._handle_response)
            for _ in range(self.args.max_concurrency)
        ]
        self._curr_stream = 0

    def send_request(self, request):
        # Use a round_robin scheduler to determine what stream to send on
        self._streams[self._curr_stream].send_request(request)
        self._curr_stream = (self._curr_stream + 1) % len(self._streams)

    def start(self):
        # Each stream's start() blocks, so run every stream on a pool thread.
        for stream in self._streams:
            self._pool.submit(stream.start)

    def stop(self):
        for stream in self._streams:
            stream.stop()
        self._pool.shutdown(wait=True)
        self._stub = None
@@ -93,12 +93,9 @@ def Search(self, request, context):
return self.Call(request, context)

def StreamCall(self, request_iterator, context):
    """Stream RPC handler: pump every request over ZMQ, then yield exactly
    as many replies as requests were sent.

    NOTE(review): indentation was reconstructed from an indentation-stripped
    diff; the receive loop is assumed to sit inside the ``with`` block since
    it uses ``zmq_client`` — confirm against the original file.
    """
    pending = 0
    with self.zmq_context as zmq_client:
        # Forward each incoming request downstream, counting them ...
        for req in request_iterator:
            zmq_client.send_message(self.add_envelope(req, zmq_client), self.args.timeout)
            pending += 1
        # ... then collect exactly that many envelope-stripped replies.
        for _ in range(pending):
            yield self.remove_envelope(zmq_client.recv_message(self.args.timeout))

class ZmqContext:

0 comments on commit 06aab81

Please sign in to comment.
You can’t perform that action at this time.