Skip to content

Commit

Permalink
implement micro batching support for tf tensor handler (#533)
Browse files Browse the repository at this point in the history
* implement micro batching support for tensorflow_tensor_handler

* fix bare except

* style: format with black

* make linter happy
  • Loading branch information
bojiang committed Feb 13, 2020
1 parent e007568 commit 769ff01
Show file tree
Hide file tree
Showing 8 changed files with 121 additions and 12 deletions.
5 changes: 5 additions & 0 deletions bentoml/handlers/fastai_image_handler.py
Expand Up @@ -94,7 +94,9 @@ def __init__(
div=True,
cls=None,
after_open=None,
**base_kwargs,
):
super(FastaiImageHandler, self).__init__(**base_kwargs)
self.imread = _import_imageio_imread()
self.fastai_vision = _import_fastai_vision()

Expand Down Expand Up @@ -138,6 +140,9 @@ def request_schema(self):
def pip_dependencies(self):
return ['imageio', 'fastai']

def handle_batch_request(self, requests, func):
    # Micro-batching is not supported for the fastai image handler:
    # this handler is explicitly opted out of the marshal/batching path,
    # so a batch call reaching it is a programming error.
    raise NotImplementedError

def handle_request(self, request, func):
input_streams = []
for filename in self.input_names:
Expand Down
10 changes: 9 additions & 1 deletion bentoml/handlers/image_handler.py
Expand Up @@ -91,8 +91,13 @@ class ImageHandler(BentoHandler):
HTTP_METHODS = ["POST"]

def __init__(
self, input_names=("image",), accept_image_formats=None, pilmode="RGB"
self,
input_names=("image",),
accept_image_formats=None,
pilmode="RGB",
**base_kwargs,
):
super(ImageHandler, self).__init__(**base_kwargs)
self.imread = _import_imageio_imread()

self.input_names = tuple(input_names)
Expand Down Expand Up @@ -129,6 +134,9 @@ def request_schema(self):
def pip_dependencies(self):
return ['imageio']

def handle_batch_request(self, requests, func):
    # Micro-batching is not supported for the image handler:
    # this handler is explicitly opted out of the marshal/batching path,
    # so a batch call reaching it is a programming error.
    raise NotImplementedError

def handle_request(self, request, func):
"""Handle http request that has image file/s. It will convert image into a
ndarray for the function to consume.
Expand Down
3 changes: 3 additions & 0 deletions bentoml/handlers/json_handler.py
Expand Up @@ -45,6 +45,9 @@ def handle_request(self, request, func):
json_output = api_func_result_to_json(result)
return Response(response=json_output, status=200, mimetype="application/json")

def handle_batch_request(self, requests, func):
    # Micro-batching is not supported for the JSON handler:
    # this handler is explicitly opted out of the marshal/batching path,
    # so a batch call reaching it is a programming error.
    raise NotImplementedError

def handle_cli(self, args, func):
parser = argparse.ArgumentParser()
parser.add_argument("--input", required=True)
Expand Down
63 changes: 59 additions & 4 deletions bentoml/handlers/tensorflow_tensor_handler.py
Expand Up @@ -23,6 +23,7 @@
NestedConverter,
tf_b64_2_bytes,
tf_tensor_2_serializable,
concat_list,
)
from bentoml.handlers.base_handlers import BentoHandler, api_func_result_to_json
from bentoml.exceptions import BentoMLException, BadInput
Expand All @@ -48,16 +49,17 @@ class TensorflowTensorHandler(BentoHandler):
BentoMLException: BentoML currently doesn't support Content-Type
"""

BATCH_MODE_SUPPORTED = True
METHODS = (PREDICT, CLASSIFY, REGRESS) = ("predict", "classify", "regress")

def __init__(self, method=PREDICT):
def __init__(self, method=PREDICT, **base_kwargs):
    """Create a TensorflowTensorHandler.

    Args:
        method: TF-Serving style API method name, one of METHODS
            ("predict", "classify", "regress"). Defaults to "predict".
        **base_kwargs: forwarded verbatim to the BentoHandler base
            __init__ — presumably micro-batching related options;
            confirm against the base class.
    """
    super(TensorflowTensorHandler, self).__init__(**base_kwargs)
    self.method = method

@property
def config(self):
    """Return this handler's config dict: the base handler's config
    extended with the TF-Serving ``method`` name.

    NOTE: uses ``super(TensorflowTensorHandler, self)`` rather than
    ``super(self.__class__, self)`` — the latter recurses infinitely
    as soon as this class is subclassed, because ``self.__class__``
    then names the subclass and ``super`` keeps resolving to this
    same property.
    """
    base_config = super(TensorflowTensorHandler, self).config
    return dict(base_config, method=self.method)

@property
def request_schema(self):
Expand Down Expand Up @@ -102,6 +104,59 @@ def _handle_raw_str(self, raw_str, output_format, func):

return result_str

def handle_batch_request(self, requests, func):
    """Handle a batch of HTTP requests, each carrying a JSON body in
    TF-Serving row format (``{"instances": [...]}``).

    All parsable instances are concatenated into one tensor, run through
    ``func`` in a single call, and the merged result is sliced back into
    one JSON response per request. Requests whose body cannot be decoded
    keep a 400 "Bad Input" response; the TF-Serving column format
    (``{"inputs": ...}``) is answered with 501.

    TODO(hrmthw):
      1. check content type
      2. specify batch dim
      3. output str format
    """
    import tensorflow as tf

    bad_resp = Response(response="Bad Input", status=400)
    instances_list = [None] * len(requests)
    responses = [bad_resp] * len(requests)

    for i, request in enumerate(requests):
        try:
            parsed_json = json.loads(request.data.decode("utf-8"))
        except (json.JSONDecodeError, UnicodeDecodeError):
            # NOTE: the stdlib exception is json.JSONDecodeError;
            # json.exceptions.JSONDecodeError does not exist and would
            # itself raise AttributeError here. Best-effort diagnostics,
            # the request keeps its 400 response.
            import traceback

            traceback.print_exc()
            continue

        instances = parsed_json.get("instances")
        if instances is not None:
            instances = decode_b64_if_needed(instances)
            if not isinstance(instances, (list, tuple)):
                instances = [instances]
            instances_list[i] = instances
        elif parsed_json.get("inputs"):
            responses[i] = Response(
                response="Column format 'inputs' is not implemented", status=501
            )

    # Substitute an empty row for every failed request so valid requests
    # can still be batched; concat_list keeps slices index-aligned.
    merged_instances, slices = concat_list(
        [ins if ins is not None else [] for ins in instances_list]
    )

    parsed_tensor = tf.constant(merged_instances)
    merged_result = func(parsed_tensor)
    merged_result = decode_tf_if_needed(merged_result)
    assert isinstance(merged_result, (list, tuple))

    for i, s in enumerate(slices):
        if instances_list[i] is None:
            # Keep the 400/501 already recorded for this request instead
            # of overwriting it with an (empty) 200 result.
            continue
        result_str = api_func_result_to_json(merged_result[s])
        responses[i] = Response(
            response=result_str, status=200, mimetype="application/json"
        )

    return responses

def handle_request(self, request, func):
"""Handle http request that has jsonlized tensorflow tensor. It will convert it
into a tf tensor for the function to consume.
Expand Down
34 changes: 28 additions & 6 deletions bentoml/handlers/utils.py
@@ -1,4 +1,4 @@
TF_B64_KEY = 'b64'
TF_B64_KEY = "b64"


def tf_b64_2_bytes(obj):
Expand All @@ -14,19 +14,19 @@ def bytes_2_tf_b64(obj):
import base64

if isinstance(obj, bytes):
return {TF_B64_KEY: base64.b64encode(obj).decode('utf-8')}
return {TF_B64_KEY: base64.b64encode(obj).decode("utf-8")}
else:
return obj


def tf_tensor_2_serializable(obj):
'''
"""
To convert
tf.Tensor -> json serializable
np.ndarray -> json serializable
bytes -> {'b64': <b64_str>}
others -> themselves
'''
"""
import tensorflow as tf
import numpy as np

Expand Down Expand Up @@ -63,10 +63,10 @@ def tf_tensor_2_serializable(obj):


class NestedConverter:
'''
"""
Generate a nested converter that supports object in list/tuple/dict
from a single converter.
'''
"""

def __init__(self, converter):
    # converter: the single-object callable that the NestedConverter
    # instance applies to each leaf value when called on a nested
    # list/tuple/dict structure.
    self.converter = converter
Expand All @@ -82,3 +82,25 @@ def __call__(self, obj):
return [self(v) for v in obj]
else:
return obj


def concat_list(lst):
    """Concatenate a list of rows into one flat list.

    Returns a ``(datas, slices)`` tuple where ``datas`` is the rows
    concatenated in order and ``slices[i]`` recovers row ``i`` from
    ``datas``. A row that is ``None`` or empty yields an empty slice
    (the original crashed on ``None`` rows, which the TF batch handler
    produces for unparsable requests).

    >>> datas, slices = concat_list([[1], [1, 2], [1, 2, 3]])
    >>> datas
    [1, 1, 2, 1, 2, 3]
    >>> [datas[s] for s in slices]
    [[1], [1, 2], [1, 2, 3]]
    """
    slices = [slice(0, 0)] * len(lst)
    datas = []
    offset = 0
    for i, row in enumerate(lst):
        if not row:
            # None or empty row: empty slice at the current offset, same
            # as the original behavior for empty rows.
            slices[i] = slice(offset, offset)
            continue
        datas.extend(row)
        slices[i] = slice(offset, offset + len(row))
        offset += len(row)
    return datas, slices
2 changes: 2 additions & 0 deletions bentoml/marshal/marshal.py
Expand Up @@ -114,6 +114,8 @@ async def _func(requests):
headers={self._MARSHAL_FLAG: 'true'},
) as resp:
resps = await split_aio_responses(resp)
if resps is None:
return [aiohttp.web.HTTPInternalServerError] * len(requests)
return resps

self.batch_handlers[api_name] = _func
Expand Down
2 changes: 1 addition & 1 deletion bentoml/marshal/utils.py
Expand Up @@ -16,7 +16,7 @@ async def split_aio_responses(ori_response):
try:
merged_responses = pickle.loads(merged)
except pickle.UnpicklingError:
raise
return None

if ori_response.status != 200:
return [web.Response(status=ori_response.status)] * len(merged_responses)
Expand Down
14 changes: 14 additions & 0 deletions tests/handlers/test_utils.py
@@ -0,0 +1,14 @@
from bentoml.handlers.utils import concat_list


def test_concat():
    """concat_list must flatten rows in order and return index-aligned
    slices that round-trip each row (including empty rows)."""
    lst = [
        [1],
        [1, 2],
        [],
        [1, 2, 3],
    ]
    datas, slices = concat_list(lst)

    # Exact flat contents: the round-trip check alone would still pass
    # if datas carried trailing junk, so pin the whole list too.
    assert datas == [1, 1, 2, 1, 2, 3]

    # Each slice recovers its original row.
    for s, origin_data in zip(slices, lst):
        assert origin_data == datas[s]

0 comments on commit 769ff01

Please sign in to comment.