Skip to content

Commit

Permalink
feat(executors): add apis for other method_name
Browse files Browse the repository at this point in the history
  • Loading branch information
nan-wang committed May 20, 2020
1 parent 15f9f1f commit 20563b2
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 16 deletions.
49 changes: 39 additions & 10 deletions jina/executors/clients.py
Expand Up @@ -2,6 +2,9 @@
import grpc
from typing import Dict

if False:
from tensorflow_serving.apis import predict_pb2


class BaseClientExecutor(BaseExecutor):
"""
Expand Down Expand Up @@ -47,23 +50,29 @@ def get_output(self, response):
return np.array(response.result().outputs['output_feature'].float_val)
"""
def __init__(self, service_name, signature_name='serving_default', *args, **kwargs):
def __init__(self, model_name, signature_name='serving_default', method_name='Predict', *args, **kwargs):
"""
:param service_name: the name of the tf serving service
:param signature_name: the name of the tf serving signature
:param model_name: the name of the tf serving model. It must match the `MODEL_NAME` parameter when starting the
tf server.
:param signature_name: the name of the tf serving signature. It must match the key in the `signature_def_map`
when exporting the tf serving model.
:param method_name: the name of the tf serving method. This parameter corresponds to the `method_name` parameter
when building the signature map with ``build_signature_def()``. Currently, only ``Predict`` is supported.
The other methods including ``Classify``, ``Regression``, and ``MultiInference`` are under development.
"""
super().__init__(*args, **kwargs)
self.service_name = service_name
self.model_name = model_name
self.signature_name = signature_name
self.method_name = method_name

def post_init(self):
"""
Initialize the channel and stub for the gRPC client
"""
from tensorflow_serving.apis import prediction_service_pb2_grpc
self._channel = grpc.insecure_channel('{}:{}'.format(self.host, self.port))
from tensorflow_serving.apis import prediction_service_pb2_grpc
self._stub = prediction_service_pb2_grpc.PredictionServiceStub(self._channel)

def get_request(self, data):
Expand All @@ -73,7 +82,10 @@ def get_request(self, data):
"""
request = self.get_default_request()
input_dict = self.get_input(data)
return self.fill_request(request, input_dict)
if self.method_name == 'Predict':
return self.fill_request(request, input_dict)
else:
raise NotImplementedError

def get_input(self, data) -> Dict:
"""
Expand All @@ -86,7 +98,7 @@ def get_response(self, request: 'predict_pb2.PredictRequest'):
"""
Get the response from the tf server and postprocess the response
"""
_response = self._stub.Predict.future(request, self.timeout)
_response = getattr(self._stub, self.method_name).future(request, self.timeout)
if _response.exception():
self.logger.error('exception raised in encoding: {}'.format(_response.exception))
raise ValueError
Expand All @@ -102,12 +114,29 @@ def get_default_request(self) -> 'predict_pb2.PredictRequest':
"""
Construct the default gRPC request to the tf server.
"""
from tensorflow_serving.apis import predict_pb2
request = predict_pb2.PredictRequest()
request.model_spec.name = self.service_name
request = self._get_default_request()
request.model_spec.name = self.model_name
request.model_spec.signature_name = self.signature_name
return request

def _get_default_request(self):
if self.method_name == 'Predict':
from tensorflow_serving.apis import predict_pb2
request = predict_pb2.PredictRequest()
elif self.method_name == 'Classify':
from tensorflow_serving.apis import classification_pb2
request = classification_pb2.ClassificationRequest()
elif self.method_name == 'Regression':
from tensorflow_serving.apis import regression_pb2
request = regression_pb2.RegressionRequest()
elif self.method_name == 'MultiInference':
from tensorflow_serving.apis import inference_pb2
request = inference_pb2.MultiInferenceRequest()
else:
self.logger.error('unknonwn method_name: {}'.format(self.method_name))
raise NotImplementedError
return request

@staticmethod
def fill_request(request, data_dict):
import tensorflow as tf
Expand Down
29 changes: 23 additions & 6 deletions tests/executors/encoders/clients.py
Expand Up @@ -5,17 +5,34 @@


class MyTestCase(JinaTestCase):
@unittest.skip('add grpc mocking for this test')
def test_something(self):
encoder = UnaryTFServingClientEncoder(
host='0.0.0.0', port='8500', service_name='mnist',
input_name='images', output_name='scores',
signature_name='predict_images')
# @unittest.skip('add grpc mocking for this test')
def test_mnist_predict(self):
class MnistTFServingClientEncoder(UnaryTFServingClientEncoder):
def __init__(self, *args, **kwargs):
super().__init__(input_name='images', output_name='scores', model_name='mnist', *args, **kwargs)
self.host = '0.0.0.0'
self.port = '8500'
self.method_name = 'Predict'
self.signature_name = 'predict_images'
import numpy as np
encoder = MnistTFServingClientEncoder()
data = np.random.rand(1, 784)
result = encoder.encode(data)
self.assertEqual(result.shape, (10, ))

def test_mnist_classify(self):
class MnistTFServingClientEncoder(UnaryTFServingClientEncoder):
def __init__(self, *args, **kwargs):
super().__init__(input_name='inputs', output_name='scores', model_name='mnist', *args, **kwargs)
self.host = '0.0.0.0'
self.port = '8500'
self.method_name = 'Classify'
self.signature_name = 'classify_images'
import numpy as np
encoder = MnistTFServingClientEncoder()
data = np.random.rand(1, 784)
result = encoder.encode(data)
self.assertEqual(result.shape, (10, ))

if __name__ == '__main__':
unittest.main()

0 comments on commit 20563b2

Please sign in to comment.