Skip to content
Permalink
Browse files

refactor(grpc-client): add response handler

  • Loading branch information...
numb3r3 committed Sep 12, 2019
1 parent 9c1a606 commit a7a0ec8b410a081c9de7744d0a108c39a9a21f46
Showing with 109 additions and 67 deletions.
  1. +77 −7 gnes/client/base.py
  2. +32 −60 gnes/client/stream.py
@@ -17,13 +17,51 @@
import grpc
import zmq
from termcolor import colored
from typing import Tuple, List, Union, Type

from ..helper import set_logger
from ..proto import gnes_pb2_grpc
from ..proto import send_message, gnes_pb2, recv_message
from ..service.base import build_socket


class ResponseHandler:
def __init__(self, h: 'ResponseHandler' = None):
self.routes = {k: v for k, v in h.routes.items()} if h else {}
self.logger = set_logger(self.__class__.__name__)
self._context = None

def register(self, resp_type: Union[List, Tuple, type]):
def decorator(f):
if isinstance(resp_type, list) or isinstance(resp_type, tuple):
for t in resp_type:
self.routes[t] = f
else:
self.routes[resp_type] = f
return f

return decorator

def call_routes(self, resp: 'gnes_pb2.Response'):
def get_default_fn(r_type):
self.logger.warning('cant find handler for response type: %s, fall back to the default handler' % r_type)
f = self.routes.get(r_type, self.routes[NotImplementedError])
return f

self.logger.info('received a response for request %d' % resp.request_id)
if resp.WhichOneof('body'):
body = getattr(resp, resp.WhichOneof('body'))
resp_type = type(body)

if resp_type in self.routes:
fn = self.routes.get(resp_type)
else:
fn = get_default_fn(type(resp))

self.logger.info('handling response with %s' % fn.__name__)
return fn(self._context, resp)


class ZmqClient:

def __init__(self, args):
@@ -63,11 +101,18 @@ def recv_message(self, timeout: int = -1) -> gnes_pb2.Message:


class GrpcClient:
"""
A Base Unary gRPC client which the other client application can build from.
"""

handler = ResponseHandler()

def __init__(self, args):
self.args = args
self.logger = set_logger(self.__class__.__name__, self.args.verbose)
self.logger.info('setting up channel...')
self.logger.info('setting up grpc insecure channel...')
# A gRPC channel provides a connection to a remote gRPC server.
self._channel = grpc.insecure_channel(
'%s:%d' % (self.args.grpc_host, self.args.grpc_port),
options={
@@ -77,19 +122,44 @@ def __init__(self, args):
)
self.logger.info('waiting channel to be ready...')
grpc.channel_ready_future(self._channel).result()
self.logger.info('making stub...')
self.logger.critical('gnes client ready!')

# create new stub
self.logger.info('create new stub...')
self._stub = gnes_pb2_grpc.GnesRPCStub(self._channel)
self.logger.critical('ready!')

def send_request(self, request):
# attache response handler
self.handler._context = self

def call(self, request):
resp = self._stub.call(request)
self.handler.call_routes(resp)
return resp

def stream_call(self, request_iterator):
response_stream = self._stub.StreamCall(request_iterator)
for resp in response_stream:
self.handler.call_routes(resp)

@handler.register(NotImplementedError)
def _handler_default(self, msg: 'gnes_pb2.Response'):
raise NotImplementedError

def close(self):
self._channel.close()
self._stub = None
@handler.register(gnes_pb2.Response)
def _handler_response_default(self, msg: 'gnes_pb2.Response'):
pass

def __enter__(self):
self.open()
return self

def __exit__(self, exc_type, exc_val, exc_tb):
self.close()

def open(self):
pass

def close(self):
self._channel.close()
self._stub = None
self.total_response = 0
@@ -13,91 +13,63 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import time
import threading
import queue
from concurrent import futures

from .base import GrpcClient
from .base import GrpcClient, ResponseHandler


class _SyncStream:

def __init__(self, stub, handle_response):
self._stub = stub
self._handle_response = handle_response
self._is_streaming = False
self._request_queue = queue.Queue()

def send_request(self, request):
self._request_queue.put(request)

def start(self):
self._is_streaming = True
response_stream = self._stub.StreamCall(self._request_generator())
for resp in response_stream:
self._handle_response(self, resp)

def stop(self):
self._is_streaming = False

def _request_generator(self):
while self._is_streaming:
try:
request = self._request_queue.get(block=True, timeout=1.0)
yield request
except queue.Empty:
pass


class UnarySyncClient(GrpcClient):
class SyncClient(GrpcClient):
handler = ResponseHandler(GrpcClient.handler)

def __init__(self, args):
super().__init__(args)
self._pool = futures.ThreadPoolExecutor(
max_workers=self.args.max_concurrency)
self._response_callbacks = []

def send_request(self, request):
# Send requests in seperate threads to support multiple outstanding rpcs
self._pool.submit(self._dispatch_request, request)
self._pool.submit(self.call, request)

def close(self):
self._pool.shutdown(wait=True)
super().close()

def _dispatch_request(self, request):
resp = self._stub.Call(request)
self._handle_response(self, resp)

def _handle_response(self, client, response):
for callback in self._response_callbacks:
callback(client, response)

def add_response_callback(self, callback):
"""callback will be invoked as callback(client, response)"""
self._response_callbacks.append(callback)


class StreamingClient(UnarySyncClient):
class StreamingClient(GrpcClient):
handler = ResponseHandler(GrpcClient.handler)

def __init__(self, args):
super().__init__(args)

self._streams = [
_SyncStream(self._stub, self._handle_response)
for _ in range(self.args.max_concurrency)
]
self._curr_stream = 0
self._request_queue = queue.Queue()
self._is_streaming = threading.Event()

self._dispatch_thread = threading.Thread(target=self._start)
self._dispatch_thread.setDaemon(1)
self._dispatch_thread.start()

def send_request(self, request):
# Use a round_robin scheduler to determine what stream to send on
self._streams[self._curr_stream].send_request(request)
self._curr_stream = (self._curr_stream + 1) % len(self._streams)
self._request_queue.put(request)

def _start(self):
self._is_streaming.set()
response_stream = self.stream_call(self._request_generator())

def _request_generator(self):
while self._is_streaming.is_set():
try:
request = self._request_queue.get(block=True, timeout=1.0)
yield request
except queue.Empty:
pass

def start(self):
for stream in self._streams:
self._pool.submit(stream.start)
@handler.register(NotImplementedError)
def _handler_default(self, resp: 'gnes_pb2.Response'):
raise NotImplementedError

def close(self):
for stream in self._streams:
stream.stop()
self._is_streaming.clear()
self._dispatch_thread.join()
super().close()

0 comments on commit a7a0ec8

Please sign in to comment.
You can’t perform that action at this time.