diff --git a/bentoml/handlers/fastai_image_handler.py b/bentoml/handlers/fastai_image_handler.py index adfec8ace7..381bb5ee84 100644 --- a/bentoml/handlers/fastai_image_handler.py +++ b/bentoml/handlers/fastai_image_handler.py @@ -94,7 +94,9 @@ def __init__( div=True, cls=None, after_open=None, + **base_kwargs, ): + super(FastaiImageHandler, self).__init__(**base_kwargs) self.imread = _import_imageio_imread() self.fastai_vision = _import_fastai_vision() @@ -138,6 +140,9 @@ def request_schema(self): def pip_dependencies(self): return ['imageio', 'fastai'] + def handle_batch_request(self, requests, func): + raise NotImplementedError + def handle_request(self, request, func): input_streams = [] for filename in self.input_names: diff --git a/bentoml/handlers/image_handler.py b/bentoml/handlers/image_handler.py index 4fe834dce3..426a5c2553 100644 --- a/bentoml/handlers/image_handler.py +++ b/bentoml/handlers/image_handler.py @@ -91,8 +91,13 @@ class ImageHandler(BentoHandler): HTTP_METHODS = ["POST"] def __init__( - self, input_names=("image",), accept_image_formats=None, pilmode="RGB" + self, + input_names=("image",), + accept_image_formats=None, + pilmode="RGB", + **base_kwargs, ): + super(ImageHandler, self).__init__(**base_kwargs) self.imread = _import_imageio_imread() self.input_names = tuple(input_names) @@ -129,6 +134,9 @@ def request_schema(self): def pip_dependencies(self): return ['imageio'] + def handle_batch_request(self, requests, func): + raise NotImplementedError + def handle_request(self, request, func): """Handle http request that has image file/s. It will convert image into a ndarray for the function to consume. 
diff --git a/bentoml/handlers/json_handler.py b/bentoml/handlers/json_handler.py index c08aef81f3..945948b61a 100644 --- a/bentoml/handlers/json_handler.py +++ b/bentoml/handlers/json_handler.py @@ -45,6 +45,9 @@ def handle_request(self, request, func): json_output = api_func_result_to_json(result) return Response(response=json_output, status=200, mimetype="application/json") + def handle_batch_request(self, requests, func): + raise NotImplementedError + def handle_cli(self, args, func): parser = argparse.ArgumentParser() parser.add_argument("--input", required=True) diff --git a/bentoml/handlers/tensorflow_tensor_handler.py b/bentoml/handlers/tensorflow_tensor_handler.py index 9530db35cb..696411deb4 100644 --- a/bentoml/handlers/tensorflow_tensor_handler.py +++ b/bentoml/handlers/tensorflow_tensor_handler.py @@ -23,6 +23,7 @@ NestedConverter, tf_b64_2_bytes, tf_tensor_2_serializable, + concat_list, ) from bentoml.handlers.base_handlers import BentoHandler, api_func_result_to_json from bentoml.exceptions import BentoMLException, BadInput @@ -48,16 +49,17 @@ class TensorflowTensorHandler(BentoHandler): BentoMLException: BentoML currently doesn't support Content-Type """ + BATCH_MODE_SUPPORTED = True METHODS = (PREDICT, CLASSIFY, REGRESS) = ("predict", "classify", "regress") - def __init__(self, method=PREDICT): + def __init__(self, method=PREDICT, **base_kwargs): + super(TensorflowTensorHandler, self).__init__(**base_kwargs) self.method = method @property def config(self): - return { - "method": self.method, - } + base_config = super(TensorflowTensorHandler, self).config + return dict(base_config, method=self.method,) @property def request_schema(self): @@ -102,6 +104,59 @@ def _handle_raw_str(self, raw_str, output_format, func): return result_str + def handle_batch_request(self, requests, func): + """ + TODO(hrmthw): + 1. check content type + 1. specify batch dim + 1. 
output str format + """ + import tensorflow as tf + + bad_resp = Response(response="Bad Input", status=400) + instances_list = [None] * len(requests) + responses = [bad_resp] * len(requests) + + for i, request in enumerate(requests): + try: + raw_str = request.data.decode("utf-8") + parsed_json = json.loads(raw_str) + if parsed_json.get("instances") is not None: + instances = parsed_json.get("instances") + if instances is None: + continue + instances = decode_b64_if_needed(instances) + if not isinstance(instances, (list, tuple)): + instances = [instances] + instances_list[i] = instances + + elif parsed_json.get("inputs"): + responses[i] = Response( + response="Column format 'inputs' is not implemented", status=501 + ) + + except (json.JSONDecodeError, UnicodeDecodeError): + import traceback + + traceback.print_exc() + + merged_instances, slices = concat_list(instances_list) + + parsed_tensor = tf.constant(merged_instances) + merged_result = func(parsed_tensor) + merged_result = decode_tf_if_needed(merged_result) + assert isinstance(merged_result, (list, tuple)) + + results = [merged_result[s] for s in slices] + + for i, result in enumerate(results): + result_str = api_func_result_to_json(result) + responses[i] = Response( + response=result_str, status=200, mimetype="application/json" + ) + + return responses + def handle_request(self, request, func): """Handle http request that has jsonlized tensorflow tensor. It will convert it into a tf tensor for the function to consume. 
diff --git a/bentoml/handlers/utils.py b/bentoml/handlers/utils.py index e1182cdd50..6ee345c865 100644 --- a/bentoml/handlers/utils.py +++ b/bentoml/handlers/utils.py @@ -1,4 +1,4 @@ -TF_B64_KEY = 'b64' +TF_B64_KEY = "b64" def tf_b64_2_bytes(obj): @@ -14,19 +14,19 @@ def bytes_2_tf_b64(obj): import base64 if isinstance(obj, bytes): - return {TF_B64_KEY: base64.b64encode(obj).decode('utf-8')} + return {TF_B64_KEY: base64.b64encode(obj).decode("utf-8")} else: return obj def tf_tensor_2_serializable(obj): - ''' + """ To convert tf.Tensor -> json serializable np.ndarray -> json serializable bytes -> {'b64': } others -> themselves - ''' + """ import tensorflow as tf import numpy as np @@ -63,10 +63,10 @@ def tf_tensor_2_serializable(obj): class NestedConverter: - ''' + """ Generate a nested converter that supports object in list/tuple/dict from a single converter. - ''' + """ def __init__(self, converter): self.converter = converter @@ -82,3 +82,25 @@ def __call__(self, obj): return [self(v) for v in obj] else: return obj + + +def concat_list(lst): + """ + >>> lst = [ + [1], + [1, 2], + [1, 2, 3], + ] + >>> concat_list(lst) + [1, 1, 2, 1, 2, 3], [slice(0, 1), slice(1, 3), slice(3, 6)] + """ + slices = [slice(0)] * len(lst) + datas = [] + row_flag = 0 + for i, r in enumerate(lst): + j = -1 + for j, d in enumerate(r): + datas.append(d) + slices[i] = slice(row_flag, row_flag + j + 1) + row_flag += j + 1 + return datas, slices diff --git a/bentoml/marshal/marshal.py b/bentoml/marshal/marshal.py index 238becba32..a65383f75c 100644 --- a/bentoml/marshal/marshal.py +++ b/bentoml/marshal/marshal.py @@ -114,6 +114,8 @@ async def _func(requests): headers={self._MARSHAL_FLAG: 'true'}, ) as resp: resps = await split_aio_responses(resp) + if resps is None: + return [aiohttp.web.HTTPInternalServerError()] * len(requests) return resps self.batch_handlers[api_name] = _func diff --git a/bentoml/marshal/utils.py b/bentoml/marshal/utils.py index 7716710f39..39af4a1269 100644 --- 
a/bentoml/marshal/utils.py +++ b/bentoml/marshal/utils.py @@ -16,7 +16,7 @@ async def split_aio_responses(ori_response): try: merged_responses = pickle.loads(merged) except pickle.UnpicklingError: - raise + return None if ori_response.status != 200: return [web.Response(status=ori_response.status)] * len(merged_responses) diff --git a/tests/handlers/test_utils.py b/tests/handlers/test_utils.py new file mode 100644 index 0000000000..59c2444344 --- /dev/null +++ b/tests/handlers/test_utils.py @@ -0,0 +1,14 @@ +from bentoml.handlers.utils import concat_list + + +def test_concat(): + lst = [ + [1], + [1, 2], + [], + [1, 2, 3], + ] + datas, slices = concat_list(lst) + + for s, origin_data in zip(slices, lst): + assert origin_data == datas[s]