Skip to content

Commit

Permalink
[python] Add dynamic batching feature to python engine (#371)
Browse files Browse the repository at this point in the history
* [python] Add dynamic batching feature to python engine

* Add unittest for batchPredict
  • Loading branch information
frankfliu committed Dec 9, 2022
1 parent 7d60a09 commit fdcaf55
Show file tree
Hide file tree
Showing 7 changed files with 178 additions and 54 deletions.
29 changes: 29 additions & 0 deletions engines/python/setup/djl_python/inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import io
import struct
import json
import re

from .np_util import from_nd_list
from .pair_list import PairList
Expand Down Expand Up @@ -81,6 +82,34 @@ def __str__(self):
cur_str += "\n{}: {}".format(key, self.get_data(key))
return cur_str

def is_batch(self) -> bool:
    """
    Tells whether this input carries more than one batched request.

    :return: ``True`` when the batch size is greater than 1
    """
    size = self.get_batch_size()
    return size > 1

def get_batch_size(self) -> int:
    """
    Reads the batch size from the ``batch_size`` property.

    :return: the property value converted to ``int``; 1 when the property
        is absent
    """
    raw_size = self.properties.get("batch_size", "1")
    return int(raw_size)

def get_batches(self) -> list["Input"]:
    """
    Splits a dynamically batched input into one ``Input`` per batch item.

    Content keys of a batched input are expected to have the form
    ``batch_<index>.<key>``; each value is added to the item at position
    ``<index>`` under ``<key>``.

    :return: ``[self]`` when the batch size is 1, otherwise a list of
        ``batch_size`` newly created ``Input`` objects
    :raises ValueError: if a content key does not start with
        ``batch_<index>.``
    """
    batch_size = self.get_batch_size()
    if batch_size == 1:
        return [self]

    batch = []
    for _ in range(batch_size):
        item = Input()
        # NOTE: every item references the same properties object as the
        # batched input; it is shared, not copied.
        item.properties = self.properties
        batch.append(item)

    pattern = re.compile(r"batch_(\d+)\.(.*)")
    content = self.content
    for i in range(content.size()):
        key = content.key_at(i)
        m = pattern.match(key)
        if m is None:
            # Interpolate the offending key so the failure is diagnosable.
            raise ValueError(f"Unexpected key in batch input: {key}")
        index = int(m.group(1))
        batch[index].content.add(m.group(2), content.value_at(i))

    return batch

def get_function_name(self) -> str:
    """
    Returns the function name stored on this input.

    :return: the value of the ``function_name`` attribute
    """
    name = self.function_name
    return name

Expand Down
41 changes: 24 additions & 17 deletions engines/python/setup/djl_python/outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,43 +86,50 @@ def add_property(self, key, val):
self.properties[key] = val
return self

def add(self, value, key=None):
def add(self, value, key=None, batch_index=None):
if key is not None and type(key) is not str:
logging.warning(f"Output key should be str type, got {type(key)}")
key = str(key)

if batch_index is not None:
key = "" if key is None else key
key = f"batch_{batch_index}.{key}"

if type(value) is str:
self.content.add(key=key, value=value.encode("utf-8"))
elif type(value) is bytearray:
self.content.add(key=key, value=value)
elif type(value) is bytes:
self.content.add(key=key, value=bytearray(value))
else:
self.add_as_json(value, key=key)
self.content.add(key=key, value=self._encode_json(value))
return self

def add_as_numpy(self, np_list, key=None):
self.content.add(key=key, value=to_nd_list(np_list))
return self
def add_as_numpy(self, np_list, key=None, batch_index=None):
return self.add(to_nd_list(np_list), key=key, batch_index=batch_index)

def add_as_npz(self, np_list, key=None):
def add_as_npz(self, np_list, key=None, batch_index=None):
import numpy as np
import io
memory_file = io.BytesIO()
np.savez(memory_file, *np_list)
memory_file.seek(0)
self.content.add(key=key, value=memory_file.read(-1))
return self
return self.add(memory_file.read(-1), key=key, batch_index=batch_index)

def add_as_json(self, val, key=None):
json_value = json.dumps(val,
ensure_ascii=False,
allow_nan=False,
indent=2,
cls=_JSONEncoder,
separators=(",", ":")).encode("utf-8")
self.content.add(key=key, value=json_value)
return self
def add_as_json(self, val, key=None, batch_index=None):
return self.add(self._encode_json(val),
key=key,
batch_index=batch_index)

@staticmethod
def _encode_json(val) -> bytearray:
    """
    Serializes a value to UTF-8 encoded JSON.

    :param val: the object to serialize; encoded with ``_JSONEncoder``
    :return: the JSON document as a ``bytearray`` (not ``bytes``)
    :raises ValueError: if ``val`` contains NaN or infinite floats, because
        ``allow_nan=False`` is passed to ``json.dumps``
    """
    text = json.dumps(val,
                      ensure_ascii=False,
                      allow_nan=False,
                      indent=2,
                      cls=_JSONEncoder,
                      separators=(",", ":"))
    # Wrapped in a bytearray so the result has the same type as the other
    # binary values this class stores.
    return bytearray(text.encode("utf-8"))

@staticmethod
def write_utf8(msg, val):
Expand Down
66 changes: 62 additions & 4 deletions engines/python/src/main/java/ai/djl/python/engine/PyPredictor.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,25 @@
import ai.djl.inference.Predictor;
import ai.djl.modality.Input;
import ai.djl.modality.Output;
import ai.djl.ndarray.BytesSupplier;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.translate.TranslateException;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorContext;
import ai.djl.util.Pair;
import ai.djl.util.PairList;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

class PyPredictor<I, O> extends Predictor<I, O> {

private static final Pattern BATCH_PATTERN = Pattern.compile("batch_(\\d+)\\.(.*)");

private PyProcess process;
private int timeout;

Expand All @@ -50,10 +58,60 @@ public List<O> batchPredict(List<I> inputs) throws TranslateException {
// TODO: wait for restart
throw new TranslateException("Backend Python process is stopped.");
}
if (inputs.get(0) instanceof Input) {
List<O> ret = new ArrayList<>(inputs.size());
for (I input : inputs) {
ret.add((O) process.predict((Input) input, timeout, false));
Object first = inputs.get(0);
if (first instanceof Input) {
int size = inputs.size();
if (size == 1) {
Output output = process.predict((Input) first, timeout, false);
return Collections.singletonList((O) output);
}

Input batch = new Input();
List<O> ret = new ArrayList<>(size);
batch.setProperties(((Input) first).getProperties());
batch.addProperty("batch_size", String.valueOf(size));
for (int i = 0; i < size; ++i) {
Input in = (Input) inputs.get(i);
PairList<String, BytesSupplier> content = in.getContent();
String prefix = "batch_" + i;
for (Pair<String, BytesSupplier> pair : content) {
String key = pair.getKey();
key = key == null ? "data" : key;
batch.add(prefix + '.' + key, pair.getValue());
}
}
Output output = process.predict(batch, timeout, false);
if (output.getCode() >= 300) {
for (int i = 0; i < size; ++i) {
ret.add((O) output);
}
return ret;
}
if (output.getContent().size() != size) {
throw new TranslateException(
"Batch output size mismatch, expected: "
+ size
+ ", actual: "
+ output.getContent().size());
}
for (int i = 0; i < size; ++i) {
Output out = new Output();
out.setCode(output.getCode());
out.setMessage(output.getMessage());
out.setProperties(output.getProperties());
ret.add((O) out);
}

PairList<String, BytesSupplier> content = output.getContent();
for (Pair<String, BytesSupplier> pair : content) {
String key = pair.getKey();
Matcher m = BATCH_PATTERN.matcher(key);
if (!m.matches()) {
throw new TranslateException("Unexpected batch output key: " + key);
}
int index = Integer.parseInt(m.group(1));
Output out = (Output) ret.get(index);
out.add(m.group(2), pair.getValue());
}
return ret;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Arrays;
import java.util.List;
import java.util.Map;

Expand Down Expand Up @@ -160,6 +161,27 @@ public void testEchoModel() throws TranslateException, IOException, ModelExcepti
}
}

/**
 * Sends two inputs to the echo model through {@code batchPredict} and checks that one output
 * per input comes back, with the second output echoing the second input.
 */
@Test
public void testBatchEcho() throws TranslateException, IOException, ModelException {
    Criteria<Input, Output> echoCriteria =
            Criteria.builder()
                    .setTypes(Input.class, Output.class)
                    .optModelPath(Paths.get("src/test/resources/echo"))
                    .optEngine("Python")
                    .build();
    try (ZooModel<Input, Output> echoModel = echoCriteria.loadModel();
            Predictor<Input, Output> echoPredictor = echoModel.newPredictor()) {
        Input first = new Input();
        first.add("test1");
        Input second = new Input();
        second.add("test2");

        // More than one input, so the request is expected to take the batched path.
        List<Input> requests = Arrays.asList(first, second);
        List<Output> responses = echoPredictor.batchPredict(requests);

        Assert.assertEquals(responses.size(), 2);
        Assert.assertEquals(responses.get(1).getAsString(0), "test2");
    }
}

@Test
public void testResnet18() throws TranslateException, IOException, ModelException {
if (!Boolean.getBoolean("nightly")) {
Expand All @@ -181,9 +203,14 @@ public void testResnet18() throws TranslateException, IOException, ModelExceptio
input.addProperty("Content-Type", "image/jpeg");
Output output = predictor.predict(input);
String classification = output.getData().getAsString();
Type type = new TypeToken<List<Map<String, Double>>>() {}.getType();
List<Map<String, Double>> list = JsonUtils.GSON.fromJson(classification, type);
Assert.assertTrue(list.get(0).containsKey("tabby"));
Type type = new TypeToken<Map<String, Double>>() {}.getType();
Map<String, Double> map = JsonUtils.GSON.fromJson(classification, type);
Assert.assertTrue(map.containsKey("tabby"));

// Test batch predict
List<Input> batch = Arrays.asList(input, input);
List<Output> ret = predictor.batchPredict(batch);
Assert.assertEquals(ret.size(), 2);
}
}

Expand Down
13 changes: 11 additions & 2 deletions engines/python/src/test/resources/echo/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
Test Python model example.
"""

import logging
import sys
from djl_python import Input
from djl_python import Output
Expand All @@ -30,10 +31,18 @@ def handle(inputs: Input):
sys.exit()

data = inputs.get_as_bytes()
content_type = inputs.get_property("content-type")

outputs = Output()
outputs.add(data, key="data")
content_type = inputs.get_property("content-type")
if content_type:
outputs.add_property("content-type", content_type)

if inputs.is_batch():
logging.info(f"Dynamic batching size: {inputs.get_batch_size()}.")
batch = inputs.get_batches()
for i, item in enumerate(batch):
outputs.add(item.get_as_bytes(), key="data", batch_index=i)
else:
outputs.add(data, key="data")

return outputs
48 changes: 20 additions & 28 deletions engines/python/src/test/resources/resnet18/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
PyTorch resnet18 model example.
"""

import itertools
import json
import logging
import os
Expand Down Expand Up @@ -78,18 +77,28 @@ def inference(self, inputs):
outputs.add_as_numpy([data.detach().numpy()])
return outputs

image = inputs.get_as_image()
image = self.image_processing(image)
images = torch.stack([image]).to(self.device)
batch = inputs.get_batches()
images = []
for i, item in enumerate(batch):
image = self.image_processing(item.get_as_image())
images.append(image)
images = torch.stack(images).to(self.device)

data = self.model(images)
ps = F.softmax(data, dim=1)
probs, classes = torch.topk(ps, self.topK, dim=1)
probs = probs.tolist()
classes = classes.tolist()

outputs.add_as_json(
self.map_class_to_label(probs, self.mapping, classes))
for i in range(inputs.get_batch_size()):
item = data[i]
ps = F.softmax(item, dim=0)
probs, classes = torch.topk(ps, self.topK)
probs = probs.tolist()
classes = classes.tolist()
result = {
self.mapping[str(classes[i])]: probs[i]
for i in range(self.topK)
}
if inputs.is_batch():
outputs.add_as_json(result, batch_index=i)
else:
outputs.add_as_json(result)
except Exception as e:
logging.exception("resnet18 failed")
# error handling
Expand Down Expand Up @@ -117,23 +126,6 @@ def load_label_mapping(mapping_file_path):
mapping[key] = new_value
return mapping

@staticmethod
def map_class_to_label(probs, mapping=None, lbl_classes=None):
    """
    Pairs every probability with a label, producing one dict per row.

    :param probs: list of rows, each row holding probabilities
    :param mapping: optional dict from ``str(class index)`` to label; when
        omitted, the stringified class index itself is used as the label
    :param lbl_classes: optional rows of class indices aligned with
        ``probs``; when omitted, positions ``0..len(probs[0]) - 1`` are
        used for every row
    :return: list of ``{label: probability}`` dicts, one per row of
        ``probs``; an empty list when ``probs`` is empty
    :raises Exception: if ``probs`` is not a list, or ``mapping`` is
        neither ``None`` nor a dict
    """
    # The original condition tested ``probs`` twice; one test is equivalent.
    # NOTE(review): the message suggests ``lbl_classes`` was meant to be
    # validated as well - confirm before tightening the check.
    if not isinstance(probs, list):
        raise Exception('Convert classes to list before doing mapping')
    if mapping is not None and not isinstance(mapping, dict):
        raise Exception('Mapping must be a dict')

    # Nothing to map; also avoids indexing probs[0] on an empty list below.
    if not probs:
        return []

    if lbl_classes is None:
        lbl_classes = itertools.repeat(range(len(probs[0])), len(probs))

    def label_of(lbl_class):
        name = str(lbl_class)
        return name if mapping is None else mapping[name]

    return [{
        label_of(lbl_class): prob
        for lbl_class, prob in zip(row_classes, row_probs)
    } for row_classes, row_probs in zip(lbl_classes, probs)]


_service = Resnet18()

Expand Down
2 changes: 2 additions & 0 deletions serving/src/main/conf/config.properties
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,6 @@ load_models=ALL
# private_key_file=conf/key.pem
# certificate_file=conf/certs.pem
# max_request_size=2000485760
# batch_size=1
# max_batch_delay=300

0 comments on commit fdcaf55

Please sign in to comment.