From fdcaf553f3ccf2031dd8f11f3010704e32d3a5e5 Mon Sep 17 00:00:00 2001 From: Frank Liu Date: Fri, 9 Dec 2022 14:29:29 -0800 Subject: [PATCH] [python] Add dynamic batching feature to python engine (#371) * [python] Add dynamic batching feature to python engine * Add unittest for batchPredict --- engines/python/setup/djl_python/inputs.py | 29 ++++++++ engines/python/setup/djl_python/outputs.py | 41 +++++++----- .../ai/djl/python/engine/PyPredictor.java | 66 +++++++++++++++++-- .../ai/djl/python/engine/PyEngineTest.java | 33 +++++++++- .../python/src/test/resources/echo/model.py | 13 +++- .../src/test/resources/resnet18/model.py | 48 ++++++-------- serving/src/main/conf/config.properties | 2 + 7 files changed, 178 insertions(+), 54 deletions(-) diff --git a/engines/python/setup/djl_python/inputs.py b/engines/python/setup/djl_python/inputs.py index b8f6737db..ed0212646 100644 --- a/engines/python/setup/djl_python/inputs.py +++ b/engines/python/setup/djl_python/inputs.py @@ -14,6 +14,7 @@ import io import struct import json +import re from .np_util import from_nd_list from .pair_list import PairList @@ -81,6 +82,34 @@ def __str__(self): cur_str += "\n{}: {}".format(key, self.get_data(key)) return cur_str + def is_batch(self) -> bool: + return self.get_batch_size() > 1 + + def get_batch_size(self) -> int: + return int(self.properties.get("batch_size", "1")) + + def get_batches(self) -> list["Input"]: + batch_size = self.get_batch_size() + if batch_size == 1: + return [self] + + batch = [] + for i in range(batch_size): + item = Input() + item.properties = self.properties + batch.append(item) + + p = re.compile("batch_(\\d+)\\.(.*)") + for i in range(self.content.size()): + key = self.content.key_at(i) + m = p.match(key) + if m is None: + raise ValueError(f"Unexpected key in batch input: key") + index = int(m.group(1)) + batch[index].content.add(m.group(2), self.content.value_at(i)) + + return batch + def get_function_name(self) -> str: return self.function_name diff --git a/engines/python/setup/djl_python/outputs.py b/engines/python/setup/djl_python/outputs.py index 1109150b7..562c32b40 100644 --- a/engines/python/setup/djl_python/outputs.py +++ b/engines/python/setup/djl_python/outputs.py @@ -86,11 +86,15 @@ def add_property(self, key, val): self.properties[key] = val return self - def add(self, value, key=None): + def add(self, value, key=None, batch_index=None): if key is not None and type(key) is not str: logging.warning(f"Output key should be str type, got {type(key)}") key = str(key) + if batch_index is not None: + key = "" if key is None else key + key = f"batch_{batch_index}.{key}" + if type(value) is str: self.content.add(key=key, value=value.encode("utf-8")) elif type(value) is bytearray: @@ -98,31 +102,34 @@ def add(self, value, key=None): elif type(value) is bytes: self.content.add(key=key, value=bytearray(value)) else: - self.add_as_json(value, key=key) + self.content.add(key=key, value=self._encode_json(value)) return self - def add_as_numpy(self, np_list, key=None): - self.content.add(key=key, value=to_nd_list(np_list)) - return self + def add_as_numpy(self, np_list, key=None, batch_index=None): + return self.add(to_nd_list(np_list), key=key, batch_index=batch_index) - def add_as_npz(self, np_list, key=None): + def add_as_npz(self, np_list, key=None, batch_index=None): import numpy as np import io memory_file = io.BytesIO() np.savez(memory_file, *np_list) memory_file.seek(0) - self.content.add(key=key, value=memory_file.read(-1)) - return self + return self.add(memory_file.read(-1), key=key, batch_index=batch_index) - def add_as_json(self, val, key=None): - json_value = json.dumps(val, - ensure_ascii=False, - allow_nan=False, - indent=2, - cls=_JSONEncoder, - separators=(",", ":")).encode("utf-8") - self.content.add(key=key, value=json_value) - return self + def add_as_json(self, val, key=None, batch_index=None): + return self.add(self._encode_json(val), + key=key, + batch_index=batch_index) + + @staticmethod + def _encode_json(val) -> bytes: + return bytearray( + json.dumps(val, + ensure_ascii=False, + allow_nan=False, + indent=2, + cls=_JSONEncoder, + separators=(",", ":")).encode("utf-8")) @staticmethod def write_utf8(msg, val): diff --git a/engines/python/src/main/java/ai/djl/python/engine/PyPredictor.java b/engines/python/src/main/java/ai/djl/python/engine/PyPredictor.java index 0d7a5caf3..d927042ff 100644 --- a/engines/python/src/main/java/ai/djl/python/engine/PyPredictor.java +++ b/engines/python/src/main/java/ai/djl/python/engine/PyPredictor.java @@ -17,17 +17,25 @@ import ai.djl.inference.Predictor; import ai.djl.modality.Input; import ai.djl.modality.Output; +import ai.djl.ndarray.BytesSupplier; import ai.djl.ndarray.NDList; import ai.djl.ndarray.NDManager; import ai.djl.translate.TranslateException; import ai.djl.translate.Translator; import ai.djl.translate.TranslatorContext; +import ai.djl.util.Pair; +import ai.djl.util.PairList; import java.util.ArrayList; +import java.util.Collections; import java.util.List; +import java.util.regex.Matcher; +import java.util.regex.Pattern; class PyPredictor extends Predictor { + private static final Pattern BATCH_PATTERN = Pattern.compile("batch_(\\d+)\\.(.*)"); + private PyProcess process; private int timeout; @@ -50,10 +58,60 @@ public List batchPredict(List inputs) throws TranslateException { // TODO: wait for restart throw new TranslateException("Backend Python process is stopped."); } - if (inputs.get(0) instanceof Input) { - List ret = new ArrayList<>(inputs.size()); - for (I input : inputs) { - ret.add((O) process.predict((Input) input, timeout, false)); + Object first = inputs.get(0); + if (first instanceof Input) { + int size = inputs.size(); + if (size == 1) { + Output output = process.predict((Input) first, timeout, false); + return Collections.singletonList((O) output); + } + + Input batch = new Input(); + List ret = new ArrayList<>(size); + batch.setProperties(((Input) first).getProperties()); + batch.addProperty("batch_size", String.valueOf(size)); + for (int i = 0; i < size; ++i) { + Input in = (Input) inputs.get(i); + PairList content = in.getContent(); + String prefix = "batch_" + i; + for (Pair pair : content) { + String key = pair.getKey(); + key = key == null ? "data" : key; + batch.add(prefix + '.' + key, pair.getValue()); + } + } + Output output = process.predict(batch, timeout, false); + if (output.getCode() >= 300) { + for (int i = 0; i < size; ++i) { + ret.add((O) output); + } + return ret; + } + if (output.getContent().size() != size) { + throw new TranslateException( + "Batch output size mismatch, expected: " + + size + + ", actual: " + + output.getContent().size()); + } + for (int i = 0; i < size; ++i) { + Output out = new Output(); + out.setCode(output.getCode()); + out.setMessage(output.getMessage()); + out.setProperties(output.getProperties()); + ret.add((O) out); + } + + PairList content = output.getContent(); + for (Pair pair : content) { + String key = pair.getKey(); + Matcher m = BATCH_PATTERN.matcher(key); + if (!m.matches()) { + throw new TranslateException("Unexpected batch output key: " + key); + } + int index = Integer.parseInt(m.group(1)); + Output out = (Output) ret.get(index); + out.add(m.group(2), pair.getValue()); } return ret; } diff --git a/engines/python/src/test/java/ai/djl/python/engine/PyEngineTest.java b/engines/python/src/test/java/ai/djl/python/engine/PyEngineTest.java index 8516feffb..cf8553c38 100644 --- a/engines/python/src/test/java/ai/djl/python/engine/PyEngineTest.java +++ b/engines/python/src/test/java/ai/djl/python/engine/PyEngineTest.java @@ -43,6 +43,7 @@ import java.nio.file.Files; import java.nio.file.Path; import java.nio.file.Paths; +import java.util.Arrays; import java.util.List; import java.util.Map; @@ -160,6 +161,27 @@ public void testEchoModel() throws TranslateException, IOException, ModelExcepti } } + @Test + public void testBatchEcho() throws TranslateException, IOException, ModelException { + Criteria criteria = + Criteria.builder() + .setTypes(Input.class, Output.class) + .optModelPath(Paths.get("src/test/resources/echo")) + .optEngine("Python") + .build(); + try (ZooModel model = criteria.loadModel(); + Predictor predictor = model.newPredictor()) { + Input in1 = new Input(); + in1.add("test1"); + Input in2 = new Input(); + in2.add("test2"); + List batch = Arrays.asList(in1, in2); + List out = predictor.batchPredict(batch); + Assert.assertEquals(out.size(), 2); + Assert.assertEquals(out.get(1).getAsString(0), "test2"); + } + } + @Test public void testResnet18() throws TranslateException, IOException, ModelException { if (!Boolean.getBoolean("nightly")) { @@ -181,9 +203,14 @@ public void testResnet18() throws TranslateException, IOException, ModelExceptio input.addProperty("Content-Type", "image/jpeg"); Output output = predictor.predict(input); String classification = output.getData().getAsString(); - Type type = new TypeToken>>() {}.getType(); - List> list = JsonUtils.GSON.fromJson(classification, type); - Assert.assertTrue(list.get(0).containsKey("tabby")); + Type type = new TypeToken>() {}.getType(); + Map map = JsonUtils.GSON.fromJson(classification, type); + Assert.assertTrue(map.containsKey("tabby")); + + // Test batch predict + List batch = Arrays.asList(input, input); + List ret = predictor.batchPredict(batch); + Assert.assertEquals(ret.size(), 2); } } diff --git a/engines/python/src/test/resources/echo/model.py b/engines/python/src/test/resources/echo/model.py index bca83913c..60ad1a773 100644 --- a/engines/python/src/test/resources/echo/model.py +++ b/engines/python/src/test/resources/echo/model.py @@ -14,6 +14,7 @@ Test Python model example. """ +import logging import sys from djl_python import Input from djl_python import Output @@ -30,10 +31,18 @@ def handle(inputs: Input): sys.exit() data = inputs.get_as_bytes() - content_type = inputs.get_property("content-type") + outputs = Output() - outputs.add(data, key="data") + content_type = inputs.get_property("content-type") if content_type: outputs.add_property("content-type", content_type) + if inputs.is_batch(): + logging.info(f"Dynamic batching size: {inputs.get_batch_size()}.") + batch = inputs.get_batches() + for i, item in enumerate(batch): + outputs.add(item.get_as_bytes(), key="data", batch_index=i) + else: + outputs.add(data, key="data") + return outputs diff --git a/engines/python/src/test/resources/resnet18/model.py b/engines/python/src/test/resources/resnet18/model.py index bb02b9bbf..0edeccf5c 100644 --- a/engines/python/src/test/resources/resnet18/model.py +++ b/engines/python/src/test/resources/resnet18/model.py @@ -14,7 +14,6 @@ PyTorch resnet18 model example. """ -import itertools import json import logging import os @@ -78,18 +77,28 @@ def inference(self, inputs): outputs.add_as_numpy([data.detach().numpy()]) return outputs - image = inputs.get_as_image() - image = self.image_processing(image) - images = torch.stack([image]).to(self.device) + batch = inputs.get_batches() + images = [] + for i, item in enumerate(batch): + image = self.image_processing(item.get_as_image()) + images.append(image) + images = torch.stack(images).to(self.device) data = self.model(images) - ps = F.softmax(data, dim=1) - probs, classes = torch.topk(ps, self.topK, dim=1) - probs = probs.tolist() - classes = classes.tolist() - - outputs.add_as_json( - self.map_class_to_label(probs, self.mapping, classes)) + for i in range(inputs.get_batch_size()): + item = data[i] + ps = F.softmax(item, dim=0) + probs, classes = torch.topk(ps, self.topK) + probs = probs.tolist() + classes = classes.tolist() + result = { + self.mapping[str(classes[i])]: probs[i] + for i in range(self.topK) + } + if inputs.is_batch(): + outputs.add_as_json(result, batch_index=i) + else: + outputs.add_as_json(result) except Exception as e: logging.exception("resnet18 failed") # error handling @@ -117,23 +126,6 @@ def load_label_mapping(mapping_file_path): mapping[key] = new_value return mapping - @staticmethod - def map_class_to_label(probs, mapping=None, lbl_classes=None): - if not (isinstance(probs, list) and isinstance(probs, list)): - raise Exception('Convert classes to list before doing mapping') - if mapping is not None and not isinstance(mapping, dict): - raise Exception('Mapping must be a dict') - - if lbl_classes is None: - lbl_classes = itertools.repeat(range(len(probs[0])), len(probs)) - - results = [{(mapping[str(lbl_class)] - if mapping is not None else str(lbl_class)): prob - for lbl_class, prob in zip(*row)} - for row in zip(lbl_classes, probs)] - - return results - _service = Resnet18() diff --git a/serving/src/main/conf/config.properties b/serving/src/main/conf/config.properties index b28ef85fa..ca09a4eb1 100644 --- a/serving/src/main/conf/config.properties +++ b/serving/src/main/conf/config.properties @@ -16,4 +16,6 @@ load_models=ALL # private_key_file=conf/key.pem # certificate_file=conf/certs.pem # max_request_size=2000485760 +# batch_size=1 +# max_batch_delay=300