Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[python] Add dynamic batching feature to python engine #371

Merged
merged 2 commits into from
Dec 9, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions engines/python/setup/djl_python/inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import io
import struct
import json
import re

from .np_util import from_nd_list
from .pair_list import PairList
Expand Down Expand Up @@ -81,6 +82,34 @@ def __str__(self):
cur_str += "\n{}: {}".format(key, self.get_data(key))
return cur_str

def is_batch(self) -> bool:
    """Tell whether this input carries more than one batched request.

    Returns:
        True when the value reported by ``get_batch_size()`` exceeds 1.
    """
    batch_size = self.get_batch_size()
    return batch_size > 1

def get_batch_size(self) -> int:
    """Return the batch size recorded in this input's properties.

    The ``batch_size`` property is looked up and converted with ``int``;
    when the property is absent the string ``"1"`` is used instead.
    """
    raw_size = self.properties.get("batch_size", "1")
    return int(raw_size)

def get_batches(self) -> list["Input"]:
    """Split a batched input into one ``Input`` per batch position.

    Content keys of a batched input must look like ``batch_<index>.<key>``.
    Each value is added to the item at position ``<index>`` under ``<key>``.

    Returns:
        ``[self]`` when the batch size is 1; otherwise a list holding
        ``get_batch_size()`` newly created ``Input`` objects.

    Raises:
        ValueError: if a content key does not match ``batch_<index>.<key>``.
        IndexError: if ``<index>`` is not smaller than the batch size.
    """
    batch_size = self.get_batch_size()
    if batch_size == 1:
        return [self]

    batch = []
    for _ in range(batch_size):
        item = Input()
        # NOTE(review): every item references the very same properties object
        # as the batched input (no copy is made) - confirm sharing is intended.
        item.properties = self.properties
        batch.append(item)

    pattern = re.compile(r"batch_(\d+)\.(.*)")
    for i in range(self.content.size()):
        key = self.content.key_at(i)
        m = pattern.match(key)
        if m is None:
            # Interpolate the offending key; the message previously contained
            # the literal word "key" because the braces were missing.
            raise ValueError(f"Unexpected key in batch input: {key}")
        index = int(m.group(1))
        batch[index].content.add(m.group(2), self.content.value_at(i))

    return batch

def get_function_name(self) -> str:
    """Return the value stored in this input's ``function_name`` attribute."""
    name = self.function_name
    return name

Expand Down
41 changes: 24 additions & 17 deletions engines/python/setup/djl_python/outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,43 +86,50 @@ def add_property(self, key, val):
self.properties[key] = val
return self

def add(self, value, key=None):
def add(self, value, key=None, batch_index=None):
if key is not None and type(key) is not str:
logging.warning(f"Output key should be str type, got {type(key)}")
key = str(key)

if batch_index is not None:
key = "" if key is None else key
key = f"batch_{batch_index}.{key}"

if type(value) is str:
self.content.add(key=key, value=value.encode("utf-8"))
elif type(value) is bytearray:
self.content.add(key=key, value=value)
elif type(value) is bytes:
self.content.add(key=key, value=bytearray(value))
else:
self.add_as_json(value, key=key)
self.content.add(key=key, value=self._encode_json(value))
return self

def add_as_numpy(self, np_list, key=None):
self.content.add(key=key, value=to_nd_list(np_list))
return self
def add_as_numpy(self, np_list, key=None, batch_index=None):
    """Convert *np_list* with ``to_nd_list`` and store it through ``add``.

    Args:
        np_list: value handed unchanged to ``to_nd_list``.
        key: optional content key, forwarded to ``add``.
        batch_index: optional batch position, forwarded to ``add``.

    Returns:
        Whatever ``add`` returns.
    """
    encoded = to_nd_list(np_list)
    return self.add(encoded, key=key, batch_index=batch_index)

def add_as_npz(self, np_list, key=None):
def add_as_npz(self, np_list, key=None, batch_index=None):
import numpy as np
import io
memory_file = io.BytesIO()
np.savez(memory_file, *np_list)
memory_file.seek(0)
self.content.add(key=key, value=memory_file.read(-1))
return self
return self.add(memory_file.read(-1), key=key, batch_index=batch_index)

def add_as_json(self, val, key=None):
json_value = json.dumps(val,
ensure_ascii=False,
allow_nan=False,
indent=2,
cls=_JSONEncoder,
separators=(",", ":")).encode("utf-8")
self.content.add(key=key, value=json_value)
return self
def add_as_json(self, val, key=None, batch_index=None):
    """Encode *val* with ``_encode_json`` and store it through ``add``.

    Args:
        val: object to serialize.
        key: optional content key, forwarded to ``add``.
        batch_index: optional batch position, forwarded to ``add``.

    Returns:
        Whatever ``add`` returns.
    """
    encoded = self._encode_json(val)
    return self.add(encoded, key=key, batch_index=batch_index)

@staticmethod
def _encode_json(val) -> bytearray:
    """Serialize *val* to JSON and return the UTF-8 bytes as a ``bytearray``.

    ``json.dumps`` is called with ``ensure_ascii=False`` and
    ``allow_nan=False`` and uses ``_JSONEncoder`` as the encoder class.

    Args:
        val: object to serialize.

    Returns:
        A ``bytearray`` holding the UTF-8 encoded JSON text. The annotation
        previously said ``bytes``, which did not match the returned object.
    """
    text = json.dumps(val,
                      ensure_ascii=False,
                      allow_nan=False,
                      indent=2,
                      cls=_JSONEncoder,
                      separators=(",", ":"))
    return bytearray(text.encode("utf-8"))

@staticmethod
def write_utf8(msg, val):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,25 @@
import ai.djl.inference.Predictor;
import ai.djl.modality.Input;
import ai.djl.modality.Output;
import ai.djl.ndarray.BytesSupplier;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.translate.TranslateException;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorContext;
import ai.djl.util.Pair;
import ai.djl.util.PairList;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

class PyPredictor<I, O> extends Predictor<I, O> {

private static final Pattern BATCH_PATTERN = Pattern.compile("batch_(\\d+)\\.(.*)");

private PyProcess process;
private int timeout;

Expand All @@ -50,10 +58,60 @@ public List<O> batchPredict(List<I> inputs) throws TranslateException {
// TODO: wait for restart
throw new TranslateException("Backend Python process is stopped.");
}
if (inputs.get(0) instanceof Input) {
List<O> ret = new ArrayList<>(inputs.size());
for (I input : inputs) {
ret.add((O) process.predict((Input) input, timeout, false));
Object first = inputs.get(0);
if (first instanceof Input) {
int size = inputs.size();
if (size == 1) {
Output output = process.predict((Input) first, timeout, false);
return Collections.singletonList((O) output);
}

Input batch = new Input();
List<O> ret = new ArrayList<>(size);
batch.setProperties(((Input) first).getProperties());
batch.addProperty("batch_size", String.valueOf(size));
for (int i = 0; i < size; ++i) {
Input in = (Input) inputs.get(i);
PairList<String, BytesSupplier> content = in.getContent();
String prefix = "batch_" + i;
for (Pair<String, BytesSupplier> pair : content) {
String key = pair.getKey();
key = key == null ? "data" : key;
batch.add(prefix + '.' + key, pair.getValue());
}
}
Output output = process.predict(batch, timeout, false);
if (output.getCode() >= 300) {
for (int i = 0; i < size; ++i) {
ret.add((O) output);
}
return ret;
}
if (output.getContent().size() != size) {
throw new TranslateException(
"Batch output size mismatch, expected: "
+ size
+ ", actual: "
+ output.getContent().size());
}
for (int i = 0; i < size; ++i) {
Output out = new Output();
out.setCode(output.getCode());
out.setMessage(output.getMessage());
out.setProperties(output.getProperties());
ret.add((O) out);
}

PairList<String, BytesSupplier> content = output.getContent();
for (Pair<String, BytesSupplier> pair : content) {
String key = pair.getKey();
Matcher m = BATCH_PATTERN.matcher(key);
if (!m.matches()) {
throw new TranslateException("Unexpected batch output key: " + key);
}
int index = Integer.parseInt(m.group(1));
Output out = (Output) ret.get(index);
out.add(m.group(2), pair.getValue());
}
return ret;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Arrays;
import java.util.List;
import java.util.Map;

Expand Down Expand Up @@ -160,6 +161,27 @@ public void testEchoModel() throws TranslateException, IOException, ModelExcepti
}
}

/**
 * Sends two inputs through {@code batchPredict} on the echo model and checks that each
 * output comes back at the same position as the input it belongs to.
 */
@Test
public void testBatchEcho() throws TranslateException, IOException, ModelException {
    Criteria<Input, Output> criteria =
            Criteria.builder()
                    .setTypes(Input.class, Output.class)
                    .optModelPath(Paths.get("src/test/resources/echo"))
                    .optEngine("Python")
                    .build();
    try (ZooModel<Input, Output> model = criteria.loadModel();
            Predictor<Input, Output> predictor = model.newPredictor()) {
        Input in1 = new Input();
        in1.add("test1");
        Input in2 = new Input();
        in2.add("test2");
        List<Input> batch = Arrays.asList(in1, in2);
        List<Output> out = predictor.batchPredict(batch);
        Assert.assertEquals(out.size(), 2);
        // Check both positions so a swapped or duplicated result cannot pass unnoticed.
        Assert.assertEquals(out.get(0).getAsString(0), "test1");
        Assert.assertEquals(out.get(1).getAsString(0), "test2");
    }
}

@Test
public void testResnet18() throws TranslateException, IOException, ModelException {
if (!Boolean.getBoolean("nightly")) {
Expand All @@ -181,9 +203,14 @@ public void testResnet18() throws TranslateException, IOException, ModelExceptio
input.addProperty("Content-Type", "image/jpeg");
Output output = predictor.predict(input);
String classification = output.getData().getAsString();
Type type = new TypeToken<List<Map<String, Double>>>() {}.getType();
List<Map<String, Double>> list = JsonUtils.GSON.fromJson(classification, type);
Assert.assertTrue(list.get(0).containsKey("tabby"));
Type type = new TypeToken<Map<String, Double>>() {}.getType();
Map<String, Double> map = JsonUtils.GSON.fromJson(classification, type);
Assert.assertTrue(map.containsKey("tabby"));

// Test batch predict
List<Input> batch = Arrays.asList(input, input);
List<Output> ret = predictor.batchPredict(batch);
Assert.assertEquals(ret.size(), 2);
}
}

Expand Down
13 changes: 11 additions & 2 deletions engines/python/src/test/resources/echo/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
Test Python model example.
"""

import logging
import sys
from djl_python import Input
from djl_python import Output
Expand All @@ -30,10 +31,18 @@ def handle(inputs: Input):
sys.exit()

data = inputs.get_as_bytes()
content_type = inputs.get_property("content-type")

outputs = Output()
outputs.add(data, key="data")
content_type = inputs.get_property("content-type")
if content_type:
outputs.add_property("content-type", content_type)

if inputs.is_batch():
logging.info(f"Dynamic batching size: {inputs.get_batch_size()}.")
batch = inputs.get_batches()
for i, item in enumerate(batch):
outputs.add(item.get_as_bytes(), key="data", batch_index=i)
else:
outputs.add(data, key="data")

return outputs
48 changes: 20 additions & 28 deletions engines/python/src/test/resources/resnet18/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
PyTorch resnet18 model example.
"""

import itertools
import json
import logging
import os
Expand Down Expand Up @@ -78,18 +77,28 @@ def inference(self, inputs):
outputs.add_as_numpy([data.detach().numpy()])
return outputs

image = inputs.get_as_image()
image = self.image_processing(image)
images = torch.stack([image]).to(self.device)
batch = inputs.get_batches()
images = []
for i, item in enumerate(batch):
image = self.image_processing(item.get_as_image())
images.append(image)
images = torch.stack(images).to(self.device)

data = self.model(images)
ps = F.softmax(data, dim=1)
probs, classes = torch.topk(ps, self.topK, dim=1)
probs = probs.tolist()
classes = classes.tolist()

outputs.add_as_json(
self.map_class_to_label(probs, self.mapping, classes))
for i in range(inputs.get_batch_size()):
item = data[i]
ps = F.softmax(item, dim=0)
probs, classes = torch.topk(ps, self.topK)
probs = probs.tolist()
classes = classes.tolist()
result = {
self.mapping[str(classes[i])]: probs[i]
for i in range(self.topK)
}
if inputs.is_batch():
outputs.add_as_json(result, batch_index=i)
else:
outputs.add_as_json(result)
except Exception as e:
logging.exception("resnet18 failed")
# error handling
Expand Down Expand Up @@ -117,23 +126,6 @@ def load_label_mapping(mapping_file_path):
mapping[key] = new_value
return mapping

@staticmethod
def map_class_to_label(probs, mapping=None, lbl_classes=None):
    """Pair class labels with their probabilities, one dict per row.

    Args:
        probs: list of rows, each row a sequence of probabilities.
        mapping: optional dict keyed by ``str(class_id)``; when given, the
            mapped value becomes the label, otherwise ``str(class_id)`` is
            the label.
        lbl_classes: optional rows of class ids aligned with *probs*. When
            omitted, every row uses ``range(len(probs[0]))``.

    Returns:
        A list with one ``{label: probability}`` dict per row of *probs*;
        an empty list when *probs* is empty.

    Raises:
        Exception: if *probs* is not a list, or *mapping* is neither None
            nor a dict.
    """
    # The original guard tested ``isinstance(probs, list)`` twice; a single
    # check is equivalent.
    if not isinstance(probs, list):
        raise Exception('Convert classes to list before doing mapping')
    if mapping is not None and not isinstance(mapping, dict):
        raise Exception('Mapping must be a dict')

    # Nothing to map; also avoids indexing ``probs[0]`` on an empty list.
    if not probs:
        return []

    if lbl_classes is None:
        lbl_classes = itertools.repeat(range(len(probs[0])), len(probs))

    def _label(lbl_class):
        name = str(lbl_class)
        return mapping[name] if mapping is not None else name

    return [{
        _label(lbl_class): prob
        for lbl_class, prob in zip(row_classes, row_probs)
    } for row_classes, row_probs in zip(lbl_classes, probs)]


_service = Resnet18()

Expand Down
2 changes: 2 additions & 0 deletions serving/src/main/conf/config.properties
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,6 @@ load_models=ALL
# private_key_file=conf/key.pem
# certificate_file=conf/certs.pem
# max_request_size=2000485760
# batch_size=1
# max_batch_delay=300