Skip to content
Permalink
Browse files

refactor(grpc): hide private class inside gRPCfrontend

  • Loading branch information...
hanxiao committed Aug 9, 2019
1 parent e516646 commit 5e3409e192d76e4d059dc592b6a33a7ab3d12dde
Showing with 74 additions and 83 deletions.
  1. +74 −83 gnes/service/grpc.py
@@ -28,33 +28,6 @@
__all__ = ['GRPCFrontend']


class ZmqContext:
"""The zmq context class."""

def __init__(self, args):
"""Database connection context.
Args:
servers: a list of config dicts for connecting to database
dbapi_name: the name of database engine
"""
self.args = args

self.tlocal = threading.local()
self.tlocal.client = None

def __enter__(self):
"""Enter the context."""
client = ZmqClient(self.args)
self.tlocal.client = client
return client

def __exit__(self, exc_type, exc_value, exc_traceback):
"""Exit the context."""
self.tlocal.client.close()
self.tlocal.client = None


class ZmqClient:

def __init__(self, args):
@@ -87,71 +60,17 @@ def recv_message(self, timeout: int = -1) -> gnes_pb2.Message:
return recv_message(self.receiver, timeout=timeout)


class GNESServicer(gnes_pb2_grpc.GnesRPCServicer):

def __init__(self, args):
self.args = args
self.logger = set_logger(self.__class__.__name__, args.verbose)
self.zmq_context = ZmqContext(args)

def add_envelope(self, body: 'gnes_pb2.Request', zmq_client: 'ZmqClient'):
msg = gnes_pb2.Message()
msg.envelope.client_id = zmq_client.identity if zmq_client.identity else ''
if body.request_id:
msg.envelope.request_id = body.request_id
else:
msg.envelope.request_id = str(uuid.uuid4())
self.logger.warning('request_id is missing, filled it with a random uuid!')
msg.envelope.part_id = 1
msg.envelope.num_part.append(1)
msg.envelope.timeout = 5000
r = msg.envelope.routes.add()
r.service = GRPCFrontend.__name__
r.timestamp.GetCurrentTime()
msg.request.CopyFrom(body)
return msg

def remove_envelope(self, m: 'gnes_pb2.Message'):
resp = m.response
resp.request_id = m.envelope.request_id
return resp

def Call(self, request, context):
self.logger.info('received a new request: %s' % request.request_id or 'EMPTY_REQUEST_ID')
with self.zmq_context as zmq_client:
zmq_client.send_message(self.add_envelope(request, zmq_client), self.args.timeout)
return self.remove_envelope(zmq_client.recv_message(self.args.timeout))

def Train(self, request, context):
return self.Call(request, context)

def Index(self, request, context):
return self.Call(request, context)

def Search(self, request, context):
return self.Call(request, context)

def StreamCall(self, request_iterator, context):
num_result = 0
with self.zmq_context as zmq_client:
for request in request_iterator:
zmq_client.send_message(self.add_envelope(request, zmq_client), self.args.timeout)
num_result += 1
for _ in range(num_result):
yield self.remove_envelope(zmq_client.recv_message(self.args.timeout))


class GRPCFrontend:

def __init__(self, args):
self.logger = set_logger(self.__class__.__name__, args.verbose)
self.server = grpc.server(
futures.ThreadPoolExecutor(max_workers=args.max_concurrency),
options=[('grpc.max_send_message_length', args.max_message_size * 1024 * 1024),
('grpc.max_receive_message_length', args.max_message_size * 1024 * 1024)])
self.logger.info('start a grpc server with %d workers' % args.max_concurrency)
gnes_pb2_grpc.add_GnesRPCServicer_to_server(GNESServicer(args), self.server)
gnes_pb2_grpc.add_GnesRPCServicer_to_server(self.GNESServicer(args), self.server)

# Start GRPC Server
self.bind_address = '{0}:{1}'.format(args.grpc_host, args.grpc_port)
self.server.add_insecure_port(self.bind_address)

@@ -162,3 +81,75 @@ def __enter__(self):

def __exit__(self, exc_type, exc_val, exc_tb):
self.server.stop(None)

class GNESServicer(gnes_pb2_grpc.GnesRPCServicer):

def __init__(self, args):
self.args = args
self.logger = set_logger(self.__class__.__name__, args.verbose)
self.zmq_context = self.ZmqContext(args)

def add_envelope(self, body: 'gnes_pb2.Request', zmq_client: 'ZmqClient'):
msg = gnes_pb2.Message()
msg.envelope.client_id = zmq_client.identity if zmq_client.identity else ''
if body.request_id:
msg.envelope.request_id = body.request_id
else:
msg.envelope.request_id = str(uuid.uuid4())
self.logger.warning('request_id is missing, filled it with a random uuid!')
msg.envelope.part_id = 1
msg.envelope.num_part.append(1)
msg.envelope.timeout = 5000
r = msg.envelope.routes.add()
r.service = GRPCFrontend.__name__
r.timestamp.GetCurrentTime()
msg.request.CopyFrom(body)
return msg

def remove_envelope(self, m: 'gnes_pb2.Message'):
resp = m.response
resp.request_id = m.envelope.request_id
return resp

def Call(self, request, context):
self.logger.info('received a new request: %s' % request.request_id or 'EMPTY_REQUEST_ID')
with self.zmq_context as zmq_client:
zmq_client.send_message(self.add_envelope(request, zmq_client), self.args.timeout)
return self.remove_envelope(zmq_client.recv_message(self.args.timeout))

def Train(self, request, context):
return self.Call(request, context)

def Index(self, request, context):
return self.Call(request, context)

def Search(self, request, context):
return self.Call(request, context)

def StreamCall(self, request_iterator, context):
num_result = 0
with self.zmq_context as zmq_client:
for request in request_iterator:
zmq_client.send_message(self.add_envelope(request, zmq_client), self.args.timeout)
num_result += 1
for _ in range(num_result):
yield self.remove_envelope(zmq_client.recv_message(self.args.timeout))

class ZmqContext:
"""The zmq context class."""

def __init__(self, args):
self.args = args
self.tlocal = threading.local()
self.tlocal.client = None

def __enter__(self):
"""Enter the context."""
client = ZmqClient(self.args)
self.tlocal.client = client
return client

def __exit__(self, exc_type, exc_value, exc_traceback):
"""Exit the context."""
self.tlocal.client.close()
self.tlocal.client = None

0 comments on commit 5e3409e

Please sign in to comment.
You can’t perform that action at this time.