From bbe4dde5f23e075a748a490e8afad174b8f5e049 Mon Sep 17 00:00:00 2001 From: Frank Liu Date: Sat, 10 Dec 2022 18:08:09 -0800 Subject: [PATCH] [python] Adds .npz input support for Python engine --- engines/python/setup/djl_python/np_util.py | 9 +++++ .../ai/djl/python/engine/PyEngineTest.java | 34 +++++++++++++++++++ .../src/test/resources/resnet18/model.py | 9 +++-- 3 files changed, 50 insertions(+), 2 deletions(-) diff --git a/engines/python/setup/djl_python/np_util.py b/engines/python/setup/djl_python/np_util.py index ab258c2fb..8394a8c22 100644 --- a/engines/python/setup/djl_python/np_util.py +++ b/engines/python/setup/djl_python/np_util.py @@ -11,6 +11,7 @@ # BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for # the specific language governing permissions and limitations under the License. +import io import struct import numpy as np @@ -126,6 +127,14 @@ def from_nd_list(encoded: bytearray) -> list: :param encoded: bytearray :return: list of numpy array """ + if len(encoded) >= 4 and encoded[0] == 80 and encoded[1] == 75: + # Assume the input is npz format (PK) + result = [] + npz = np.load(io.BytesIO(encoded)) + for item in npz.items(): + result.append(item[1]) + return result + idx = 0 num_ele, idx = get_int(encoded, idx) result = [] diff --git a/engines/python/src/test/java/ai/djl/python/engine/PyEngineTest.java b/engines/python/src/test/java/ai/djl/python/engine/PyEngineTest.java index cf8553c38..533fafb84 100644 --- a/engines/python/src/test/java/ai/djl/python/engine/PyEngineTest.java +++ b/engines/python/src/test/java/ai/djl/python/engine/PyEngineTest.java @@ -211,6 +211,40 @@ public void testResnet18() throws TranslateException, IOException, ModelExceptio List batch = Arrays.asList(input, input); List ret = predictor.batchPredict(batch); Assert.assertEquals(ret.size(), 2); + + // Test npz input + NDArray ones = model.getNDManager().ones(new Shape(1, 3, 224, 224)); + NDList list = new NDList(); + list.add(ones); + byte[] buf = list.encode(true); + + input = new Input(); + input.add("data", buf); + input.addProperty("Content-Type", "tensor/npz"); + output = predictor.predict(input); + String contentType = output.getProperty("Content-Type", ""); + Assert.assertEquals(contentType, "tensor/npz"); + NDList nd = output.getDataAsNDList(model.getNDManager()); + Assert.assertEquals(nd.get(0).toArray().length, 1000); + } + } + + @Test + public void testResnet18BinaryMode() throws TranslateException, IOException, ModelException { + if (!Boolean.getBoolean("nightly")) { + return; + } + Criteria criteria = + Criteria.builder() + .setTypes(NDList.class, NDList.class) + .optModelPath(Paths.get("src/test/resources/resnet18")) + .optEngine("Python") + .build(); + try (ZooModel model = criteria.loadModel(); + Predictor predictor = model.newPredictor()) { + NDArray x = model.getNDManager().ones(new Shape(1, 3, 224, 224)); + NDList ret = predictor.predict(new NDList(x)); + Assert.assertEquals(ret.get(0).getShape().get(1), 1000); } } diff --git a/engines/python/src/test/resources/resnet18/model.py b/engines/python/src/test/resources/resnet18/model.py index a6ee751b1..53e7abff6 100644 --- a/engines/python/src/test/resources/resnet18/model.py +++ b/engines/python/src/test/resources/resnet18/model.py @@ -77,8 +77,13 @@ def inference(self, inputs): images = torch.from_numpy(inputs.get_as_numpy()[0]).to( self.device) data = self.model(images).to(torch.device('cpu')) - outputs.add_property("Content-Type", "tensor/ndlist") - outputs.add_as_numpy([data.detach().numpy()]) + accept = inputs.get_property("Accept") + if accept == "tensor/npz" or content_type == "tensor/npz": + outputs.add_property("Content-Type", "tensor/npz") + outputs.add_as_numpy([data.detach().numpy()]) + else: + outputs.add_property("Content-Type", "tensor/ndlist") + outputs.add_as_npz([data.detach().numpy()]) return outputs batch = inputs.get_batches()