Skip to content
Permalink
Browse files

feat(grpc): add a general purpose grpc service

  • Loading branch information...
hanxiao committed Aug 9, 2019
1 parent 444580f commit 3a767b7cba092f52747e2ca2d1cc45892a606819
@@ -46,9 +46,9 @@ def route(args):


def frontend(args):
from ..service.grpc import GRPCFrontend
from gnes.service.frontend import FrontendService
import threading
with GRPCFrontend(args):
with FrontendService(args):
forever = threading.Event()
forever.wait()

@@ -227,15 +227,19 @@ def set_grpc_service_parser(parser=None):
parser = set_base_parser()
set_service_parser(parser)
_set_grpc_parser(parser)
parser.add_argument('--pb2_path',
type=str,
required=True,
help='the path of the python file protocol buffer compiler')
parser.add_argument('--pb2_grpc_path',
type=str,
required=True,
help='the path of the python file generated by the gRPC Python protocol compiler plugin')
parser.add_argument('--grpc_stub_name',
parser.add_argument('--stub_name',
type=str,
required=True,
help='the name of the gRPC Stub')
parser.add_argument('--grpc_api_name',
parser.add_argument('--api_name',
type=str,
required=True,
help='the api name for calling the stub')
@@ -0,0 +1,104 @@
import threading
import uuid
from concurrent.futures import ThreadPoolExecutor

import grpc

from ..client.base import ZmqClient
from ..helper import set_logger
from ..proto import gnes_pb2_grpc, gnes_pb2


class FrontendService:

def __init__(self, args):
self.logger = set_logger(self.__class__.__name__, args.verbose)
self.server = grpc.server(
ThreadPoolExecutor(max_workers=args.max_concurrency),
options=[('grpc.max_send_message_length', args.max_message_size * 1024 * 1024),
('grpc.max_receive_message_length', args.max_message_size * 1024 * 1024)])
self.logger.info('start a grpc server with %d workers' % args.max_concurrency)
gnes_pb2_grpc.add_GnesRPCServicer_to_server(self.GNESServicer(args), self.server)

self.bind_address = '{0}:{1}'.format(args.grpc_host, args.grpc_port)
self.server.add_insecure_port(self.bind_address)

def __enter__(self):
self.server.start()
self.logger.info('listening at: %s' % self.bind_address)
return self

def __exit__(self, exc_type, exc_val, exc_tb):
self.server.stop(None)

class GNESServicer(gnes_pb2_grpc.GnesRPCServicer):

def __init__(self, args):
self.args = args
self.logger = set_logger(FrontendService.__name__, args.verbose)
self.zmq_context = self.ZmqContext(args)

def add_envelope(self, body: 'gnes_pb2.Request', zmq_client: 'ZmqClient'):
msg = gnes_pb2.Message()
msg.envelope.client_id = zmq_client.identity if zmq_client.identity else ''
if body.request_id:
msg.envelope.request_id = body.request_id
else:
msg.envelope.request_id = str(uuid.uuid4())
self.logger.warning('request_id is missing, filled it with a random uuid!')
msg.envelope.part_id = 1
msg.envelope.num_part.append(1)
msg.envelope.timeout = 5000
r = msg.envelope.routes.add()
r.service = FrontendService.__name__
r.timestamp.GetCurrentTime()
msg.request.CopyFrom(body)
return msg

def remove_envelope(self, m: 'gnes_pb2.Message'):
resp = m.response
resp.request_id = m.envelope.request_id
return resp

def Call(self, request, context):
self.logger.info('received a new request: %s' % request.request_id or 'EMPTY_REQUEST_ID')
with self.zmq_context as zmq_client:
zmq_client.send_message(self.add_envelope(request, zmq_client), self.args.timeout)
return self.remove_envelope(zmq_client.recv_message(self.args.timeout))

def Train(self, request, context):
return self.Call(request, context)

def Index(self, request, context):
return self.Call(request, context)

def Search(self, request, context):
return self.Call(request, context)

def StreamCall(self, request_iterator, context):
num_result = 0
with self.zmq_context as zmq_client:
for request in request_iterator:
zmq_client.send_message(self.add_envelope(request, zmq_client), self.args.timeout)
num_result += 1
for _ in range(num_result):
yield self.remove_envelope(zmq_client.recv_message(self.args.timeout))

class ZmqContext:
"""The zmq context class."""

def __init__(self, args):
self.args = args
self.tlocal = threading.local()
self.tlocal.client = None

def __enter__(self):
"""Enter the context."""
client = ZmqClient(self.args)
self.tlocal.client = client
return client

def __exit__(self, exc_type, exc_value, exc_traceback):
"""Exit the context."""
self.tlocal.client.close()
self.tlocal.client = None
@@ -14,16 +14,14 @@
# limitations under the License.


import threading
import uuid
from concurrent import futures
import importlib.util
import os
import sys

import grpc

from .base import BaseService as BS, MessageHandler
from ..client.base import ZmqClient
from ..helper import set_logger
from ..proto import gnes_pb2, gnes_pb2_grpc
from ..proto import gnes_pb2


class GRPCService(BS):
@@ -35,111 +33,40 @@ def post_init(self):
options=[('grpc.max_send_message_length', self.args.max_message_size * 1024 * 1024),
('grpc.max_receive_message_length', self.args.max_message_size * 1024 * 1024)])

import importlib.util
spec = importlib.util.spec_from_file_location('gnes.contrib', self.args.pb2_grpc_path)
foo = importlib.util.module_from_spec(spec)
spec.loader.exec_module(foo)
self.stub = getattr(foo, self.args.grpc_stub_name)(self.channel)
foo = self.PathImport().add_modules(self.args.pb2_path, self.args.pb2_grpc_path)

# build stub
self.stub = getattr(foo, self.args.stub_name)(self.channel)

def close(self):
self.channel.close()
super().close()

@handler.register(NotImplementedError)
def _handler_default(self, msg: 'gnes_pb2.Message'):
yield getattr(self.stub, self.args.grpc_api_name)(msg)


class GRPCFrontend:

def __init__(self, args):
self.logger = set_logger(self.__class__.__name__, args.verbose)
self.server = grpc.server(
futures.ThreadPoolExecutor(max_workers=args.max_concurrency),
options=[('grpc.max_send_message_length', args.max_message_size * 1024 * 1024),
('grpc.max_receive_message_length', args.max_message_size * 1024 * 1024)])
self.logger.info('start a grpc server with %d workers' % args.max_concurrency)
gnes_pb2_grpc.add_GnesRPCServicer_to_server(self.GNESServicer(args), self.server)

self.bind_address = '{0}:{1}'.format(args.grpc_host, args.grpc_port)
self.server.add_insecure_port(self.bind_address)

def __enter__(self):
self.server.start()
self.logger.info('grpc service is listening at: %s' % self.bind_address)
return self

def __exit__(self, exc_type, exc_val, exc_tb):
self.server.stop(None)

class GNESServicer(gnes_pb2_grpc.GnesRPCServicer):

def __init__(self, args):
self.args = args
self.logger = set_logger(self.__class__.__name__, args.verbose)
self.zmq_context = self.ZmqContext(args)

def add_envelope(self, body: 'gnes_pb2.Request', zmq_client: 'ZmqClient'):
msg = gnes_pb2.Message()
msg.envelope.client_id = zmq_client.identity if zmq_client.identity else ''
if body.request_id:
msg.envelope.request_id = body.request_id
else:
msg.envelope.request_id = str(uuid.uuid4())
self.logger.warning('request_id is missing, filled it with a random uuid!')
msg.envelope.part_id = 1
msg.envelope.num_part.append(1)
msg.envelope.timeout = 5000
r = msg.envelope.routes.add()
r.service = GRPCFrontend.__name__
r.timestamp.GetCurrentTime()
msg.request.CopyFrom(body)
return msg

def remove_envelope(self, m: 'gnes_pb2.Message'):
resp = m.response
resp.request_id = m.envelope.request_id
return resp

def Call(self, request, context):
self.logger.info('received a new request: %s' % request.request_id or 'EMPTY_REQUEST_ID')
with self.zmq_context as zmq_client:
zmq_client.send_message(self.add_envelope(request, zmq_client), self.args.timeout)
return self.remove_envelope(zmq_client.recv_message(self.args.timeout))

def Train(self, request, context):
return self.Call(request, context)

def Index(self, request, context):
return self.Call(request, context)

def Search(self, request, context):
return self.Call(request, context)

def StreamCall(self, request_iterator, context):
num_result = 0
with self.zmq_context as zmq_client:
for request in request_iterator:
zmq_client.send_message(self.add_envelope(request, zmq_client), self.args.timeout)
num_result += 1
for _ in range(num_result):
yield self.remove_envelope(zmq_client.recv_message(self.args.timeout))

class ZmqContext:
"""The zmq context class."""

def __init__(self, args):
self.args = args
self.tlocal = threading.local()
self.tlocal.client = None

def __enter__(self):
"""Enter the context."""
client = ZmqClient(self.args)
self.tlocal.client = client
return client

def __exit__(self, exc_type, exc_value, exc_traceback):
"""Exit the context."""
self.tlocal.client.close()
self.tlocal.client = None
yield getattr(self.stub, self.args.api_name)(msg)

class PathImport:

@staticmethod
def get_module_name(absolute_path):
module_name = os.path.basename(absolute_path)
module_name = module_name.replace('.py', '')
return module_name

def add_modules(self, pb2_path, pb2_grpc_path):
(module, spec) = self.path_import(pb2_path)
sys.modules[spec.name] = module

(module, spec) = self.path_import(pb2_grpc_path)
sys.modules[spec.name] = module

return module

def path_import(self, absolute_path):
module_name = self.get_module_name(absolute_path)
spec = importlib.util.spec_from_file_location(module_name, absolute_path)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
sys.modules[spec.name] = module
return module, spec
@@ -1,10 +1,17 @@
#!/usr/bin/env bash

SRC_NAME=gnes.proto
SRC_DIR=../gnes/proto/

# generating test proto

#SRC_NAME=dummy.proto
#SRC_DIR=../tests/proto/

PLUGIN_PATH=/Volumes/TOSHIBA-4T/Documents/grpc/bins/opt/grpc_python_plugin
#PLUGIN_PATH=/user/local/grpc/bins/opt/grpc_python_plugin

protoc -I ${SRC_DIR} --python_out=${SRC_DIR} --grpc_python_out=${SRC_DIR} --plugin=protoc-gen-grpc_python=${PLUGIN_PATH} ${SRC_DIR}gnes.proto
protoc -I ${SRC_DIR} --python_out=${SRC_DIR} --grpc_python_out=${SRC_DIR} --plugin=protoc-gen-grpc_python=${PLUGIN_PATH} ${SRC_DIR}${SRC_NAME}

# fix import bug in google generator
sed -i '' -e '4s/.*/from\ \.\ import\ gnes_pb2\ as\ gnes__pb2/' ${SRC_DIR}gnes_pb2_grpc.py
#sed -i '' -e '4s/.*/from\ \.\ import\ gnes_pb2\ as\ gnes__pb2/' ${SRC_DIR}gnes_pb2_grpc.py
@@ -0,0 +1,36 @@
syntax = "proto3";
import "google/protobuf/timestamp.proto";


package dummy;


message Envelope {
// unique id of the sender of the message
string client_id = 1;

// unique id of the request
string request_id = 2;

// for multi-part message
uint32 part_id = 3;
repeated uint32 num_part = 4;

uint32 timeout = 5;

// list of string represent the route of the message
message route {
string service = 1;
google.protobuf.Timestamp timestamp = 2;
}
repeated route routes = 6;
}

message Message {
Envelope envelope = 1;
}

service DummyGRPCService {
rpc dummyAPI (Message) returns (Message) {}
}

0 comments on commit 3a767b7

Please sign in to comment.
You can’t perform that action at this time.