[feature] Added support for PaddlePaddle model artifact (#1523)
cqvu committed Mar 24, 2021
1 parent dfd68c0 commit ce9bdd4
Showing 9 changed files with 270 additions and 1 deletion.
19 changes: 19 additions & 0 deletions .github/workflows/ci.yml
@@ -144,6 +144,25 @@ jobs:
      - name: Upload test coverage to Codecov
        uses: codecov/codecov-action@v1.0.12

  paddle_integration_tests:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v2
        with:
          fetch-depth: 0
      - name: Setup python
        uses: actions/setup-python@v2
        with:
          python-version: 3.8
      - name: Install test dependencies
        run: ./ci/install_test_deps.sh
      - name: Install paddle
        run: python -m pip install paddlepaddle
      - name: Run tests
        run: ./ci/test_project.sh tests/integration/projects/paddle
      - name: Upload test coverage to Codecov
        uses: codecov/codecov-action@v1.0.12

  api_server_integration_tests:
    name: API Server Integration Tests (${{ matrix.os }})
    runs-on: ${{ matrix.os }}
2 changes: 1 addition & 1 deletion README.md
@@ -94,7 +94,7 @@ Standardize model serving and deployment workflow for teams:
* Transformers - [Docs](https://docs.bentoml.org/en/latest/frameworks.html#transformers)
* Gluon - [Docs](https://docs.bentoml.org/en/latest/frameworks.html#gluon)
* Detectron - [Docs](https://docs.bentoml.org/en/latest/frameworks.html#detectron)

* Paddle - [Docs](https://docs.bentoml.org/en/latest/frameworks.html#paddle)


### Deployment Options
2 changes: 2 additions & 0 deletions bentoml/artifact/__init__.py
@@ -41,6 +41,7 @@
from bentoml.frameworks.spacy import SpacyModelArtifact
from bentoml.frameworks.tensorflow import TensorflowSavedModelArtifact
from bentoml.frameworks.xgboost import XgboostModelArtifact
from bentoml.frameworks.paddle import PaddlePaddleModelArtifact # noqa: E402

__all__ = [
"ArtifactCollection",
@@ -64,4 +65,5 @@
"TextFileArtifact",
"XgboostModelArtifact",
"PytorchLightningModelArtifact",
"PaddlePaddleModelArtifact",
]
108 changes: 108 additions & 0 deletions bentoml/frameworks/paddle.py
@@ -0,0 +1,108 @@
import os
import tempfile

from bentoml.exceptions import MissingDependencyException
from bentoml.service.artifacts import BentoServiceArtifact
from bentoml.service.env import BentoServiceEnv

try:
    import paddle
    import paddle.inference as paddle_infer
except ImportError:
    paddle = None


class PaddlePaddleModelArtifact(BentoServiceArtifact):
    """Abstraction for saving/loading PaddlePaddle models

    Args:
        name (string): name of the artifact

    Raises:
        MissingDependencyException: paddle package is required for
            PaddlePaddleModelArtifact

    Example usage:

    >>> import numpy as np
    >>> import pandas as pd
    >>>
    >>> from bentoml import env, artifacts, api, BentoService
    >>> from bentoml.adapters import DataframeInput
    >>> from bentoml.frameworks.paddle import PaddlePaddleModelArtifact
    >>>
    >>> @env(infer_pip_packages=True)
    >>> @artifacts([PaddlePaddleModelArtifact('model')])
    >>> class PaddleService(BentoService):
    >>>     @api(input=DataframeInput(), batch=True)
    >>>     def predict(self, df: pd.DataFrame):
    >>>         input_data = df.to_numpy().astype(np.float32)
    >>>         predictor = self.artifacts.model
    >>>
    >>>         input_names = predictor.get_input_names()
    >>>         input_handle = predictor.get_input_handle(input_names[0])
    >>>         input_handle.reshape(input_data.shape)
    >>>         input_handle.copy_from_cpu(input_data)
    >>>
    >>>         predictor.run()
    >>>
    >>>         output_names = predictor.get_output_names()
    >>>         output_handle = predictor.get_output_handle(output_names[0])
    >>>         output_data = output_handle.copy_to_cpu()
    >>>         return output_data
    >>>
    >>> service = PaddleService()
    >>>
    >>> service.pack('model', model_to_save)
    """

    def __init__(self, name: str):
        super(PaddlePaddleModelArtifact, self).__init__(name)
        self._model = None
        self._predictor = None
        self._model_path = None

        if paddle is None:
            raise MissingDependencyException(
                "paddlepaddle package is required to use PaddlePaddleModelArtifact"
            )

    def pack(self, model):  # pylint:disable=arguments-differ
        self._model = model
        return self

    def load(self, path):
        model = paddle.jit.load(self._file_path(path))
        model = paddle.jit.to_static(model, input_spec=model._input_spec())
        return self.pack(model)

    def _file_path(self, base_path):
        return os.path.join(base_path, self.name)

    def save(self, dst):
        self._save(dst)

    def _save(self, dst):
        # Override the model path if temp dir was set
        self._model_path = self._file_path(dst)
        paddle.jit.save(self._model, self._model_path)

    def get(self):
        # Create predictor, if one doesn't exist, when inference is run
        if not self._predictor:
            # If model isn't saved, save model to a temp dir
            # because predictor init requires the path to a saved model
            if self._model_path is None:
                self._model_path = tempfile.TemporaryDirectory().name
                self._save(self._model_path)

            config = paddle_infer.Config(
                self._model_path + ".pdmodel", self._model_path + ".pdiparams"
            )
            config.enable_memory_optim()
            predictor = paddle_infer.create_predictor(config)
            self._predictor = predictor
        return self._predictor

    def set_dependencies(self, env: BentoServiceEnv):
        env.add_pip_packages(['paddlepaddle'])
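
Note: get() relies on paddle.jit.save writing a <prefix>.pdmodel graph file and a <prefix>.pdiparams weights file, which paddle.inference.Config then consumes to build a predictor. Below is a minimal standalone sketch of that round trip outside BentoML; the _DemoNet layer, the /tmp/paddle_demo/net prefix, and the (1, 13) input are illustrative assumptions, not part of this commit.

import os

import numpy as np
import paddle
import paddle.inference as paddle_infer
from paddle import nn
from paddle.static import InputSpec


class _DemoNet(nn.Layer):
    def __init__(self):
        super(_DemoNet, self).__init__()
        self.fc = nn.Linear(13, 1)

    @paddle.jit.to_static(input_spec=[InputSpec(shape=[None, 13], dtype='float32')])
    def forward(self, x):
        return self.fc(x)


# Save the static graph; this produces net.pdmodel and net.pdiparams under /tmp/paddle_demo.
os.makedirs("/tmp/paddle_demo", exist_ok=True)
paddle.jit.save(_DemoNet(), "/tmp/paddle_demo/net")

# Mirror what PaddlePaddleModelArtifact.get() does with the saved prefix.
config = paddle_infer.Config("/tmp/paddle_demo/net.pdmodel", "/tmp/paddle_demo/net.pdiparams")
config.enable_memory_optim()
predictor = paddle_infer.create_predictor(config)

data = np.random.rand(1, 13).astype("float32")
input_handle = predictor.get_input_handle(predictor.get_input_names()[0])
input_handle.reshape(data.shape)
input_handle.copy_from_cpu(data)
predictor.run()
output = predictor.get_output_handle(predictor.get_output_names()[0]).copy_to_cpu()  # shape (1, 1)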
11 changes: 11 additions & 0 deletions docs/source/frameworks.rst
@@ -204,6 +204,17 @@ Detectron
.. autoclass:: bentoml.frameworks.detectron.DetectronModelArtifact


======
Paddle
======

Example Projects:

* Boston Housing Prediction - `Google Colab <https://colab.research.google.com/github/bentoml/gallery/blob/master/paddlepaddle/LinearRegression/LinearRegression.ipynb>`__ / `Notebook Source <https://github.com/bentoml/gallery/blob/master/paddlepaddle/LinearRegression/LinearRegression.ipynb>`__

.. autoclass:: bentoml.frameworks.paddle.PaddlePaddleModelArtifact
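
A minimal packaging sketch (``PaddleService`` is the example service rendered in the
class documentation above, and ``trained_model`` is an assumed ``paddle.nn.Layer``
whose ``forward`` has been converted with ``paddle.jit.to_static``):

.. code-block:: python

    # Pack the trained model and persist the service as a BentoML bundle.
    service = PaddleService()
    service.pack('model', trained_model)
    saved_path = service.save()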


.. spelling::

MLModel
Empty file.
58 changes: 58 additions & 0 deletions tests/integration/projects/paddle/model/model.py
@@ -0,0 +1,58 @@
import pathlib
import sys

import paddle
from paddle import nn
from paddle.static import InputSpec
from bentoml.frameworks.paddle import PaddlePaddleModelArtifact


BATCH_SIZE = 8
BATCH_NUM = 4
EPOCH_NUM = 5

IN_FEATURES = 13
OUT_FEATURES = 1


class Model(nn.Layer):
    def __init__(self):
        super(Model, self).__init__()
        self.fc = nn.Linear(IN_FEATURES, OUT_FEATURES)

    @paddle.jit.to_static(input_spec=[InputSpec(shape=[IN_FEATURES], dtype='float32')])
    def forward(self, x):
        return self.fc(x)


def train(model, loader, loss_fn, opt):
    model.train()
    for epoch_id in range(EPOCH_NUM):
        for batch_id, (feature, label) in enumerate(loader()):
            out = model(feature)
            loss = loss_fn(out, label)
            loss.backward()
            opt.step()
            opt.clear_grad()


def pack_models(path):
    model = Model()
    loss = nn.MSELoss()
    adam = paddle.optimizer.Adam(parameters=model.parameters())

    train_data = paddle.text.datasets.UCIHousing(mode="train")

    loader = paddle.io.DataLoader(
        train_data, batch_size=BATCH_SIZE, shuffle=True, drop_last=True, num_workers=2
    )

    train(model, loader, loss, adam)

    PaddlePaddleModelArtifact("model").pack(model).save(path)


if __name__ == "__main__":
    artifacts_path = sys.argv[1]
    pathlib.Path(artifacts_path).mkdir(parents=True, exist_ok=True)
    pack_models(artifacts_path)
43 changes: 43 additions & 0 deletions tests/integration/projects/paddle/service.py
@@ -0,0 +1,43 @@
import pathlib
import sys

import numpy as np

import bentoml
from bentoml.adapters import DataframeInput
from bentoml.frameworks.paddle import PaddlePaddleModelArtifact


@bentoml.env(infer_pip_packages=True)
@bentoml.artifacts([PaddlePaddleModelArtifact('model')])
class PaddleService(bentoml.BentoService):
    @bentoml.api(input=DataframeInput(), batch=True)
    def predict(self, df):
        input_data = df.to_numpy().astype(np.float32)

        predictor = self.artifacts.model
        input_names = predictor.get_input_names()
        input_handle = predictor.get_input_handle(input_names[0])

        input_handle.reshape(input_data.shape)
        input_handle.copy_from_cpu(input_data)

        predictor.run()

        output_names = predictor.get_output_names()
        output_handle = predictor.get_output_handle(output_names[0])
        output_data = output_handle.copy_to_cpu()
        return output_data


if __name__ == "__main__":
    artifacts_path = sys.argv[1]
    bento_dist_path = sys.argv[2]
    service = PaddleService()

    from model.model import Model  # noqa # pylint: disable=unused-import

    service.artifacts.load_all(artifacts_path)

    pathlib.Path(bento_dist_path).mkdir(parents=True, exist_ok=True)
    service.save_to_dir(bento_dist_path)
28 changes: 28 additions & 0 deletions tests/integration/projects/paddle/tests/test_service.py
@@ -0,0 +1,28 @@
import pandas as pd
import numpy as np

test_df = pd.DataFrame(
    [
        [
            -0.0405441,
            0.06636364,
            -0.32356227,
            -0.06916996,
            -0.03435197,
            0.05563625,
            -0.03475696,
            0.02682186,
            -0.37171335,
            -0.21419304,
            -0.33569506,
            0.10143217,
            -0.21172912,
        ]
    ]
)


def test_paddle_artifact_pack(service):
    pred = service.predict(test_df)
    assert isinstance(pred, np.ndarray), 'Run inference'
    assert pred.shape == (1, 1)
