Skip to content

Commit

Permalink
feat: Flax (#3123)
Browse files Browse the repository at this point in the history
- add a Jax data container.
- a MNIST Flax example.
  • Loading branch information
aarnphm committed Feb 23, 2023
1 parent 685624f commit e5af0c3
Show file tree
Hide file tree
Showing 33 changed files with 1,168 additions and 26 deletions.
1 change: 1 addition & 0 deletions .github/workflows/ci.yml
Expand Up @@ -11,6 +11,7 @@ env:
LINES: 120
COLUMNS: 120
BENTOML_DO_NOT_TRACK: True
PYTEST_PLUGINS: bentoml.testing.pytest.plugin

# https://docs.github.com/en/actions/using-workflows/workflow-syntax-for-github-actions#defaultsrun
defaults:
Expand Down
53 changes: 53 additions & 0 deletions .github/workflows/frameworks.yml
Expand Up @@ -12,6 +12,7 @@ env:
LINES: 120
COLUMNS: 120
BENTOML_DO_NOT_TRACK: True
PYTEST_PLUGINS: bentoml.testing.pytest.plugin

# https://docs.github.com/en/actions/using-workflows/workflow-syntax-for-github-actions#defaultsrun
defaults:
Expand All @@ -33,6 +34,7 @@ jobs:
pytorch: ${{ steps.filter.outputs.pytorch }}
pytorch_lightning: ${{ steps.filter.outputs.pytorch_lightning }}
sklearn: ${{ steps.filter.outputs.sklearn }}
flax: ${{ steps.filter.outputs.flax }}
tensorflow: ${{ steps.filter.outputs.tensorflow }}
torchscript: ${{ steps.filter.outputs.torchscript }}
transformers: ${{ steps.filter.outputs.transformers }}
Expand Down Expand Up @@ -94,6 +96,12 @@ jobs:
- src/bentoml/_internal/frameworks/pytorch.py
- src/bentoml/_internal/frameworks/common/pytorch.py
- tests/integration/frameworks/test_pytorch_unit.py
flax:
- *related
- src/bentoml/flax.py
- src/bentoml/_internal/frameworks/flax.py
- src/bentoml/_internal/frameworks/common/jax.py
- tests/integration/frameworks/models/flax.py
torchscript:
- *related
- src/bentoml/torchscript.py
Expand Down Expand Up @@ -224,6 +232,51 @@ jobs:
files: ./coverage.xml
token: ${{ secrets.CODECOV_TOKEN }}

flax_integration_tests:
needs: diff
if: ${{ (github.event_name == 'pull_request' && needs.diff.outputs.flax == 'true') || github.event_name == 'push' }}
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
with:
fetch-depth: 0 # fetch all tags and branches
- name: Setup python
uses: actions/setup-python@v4
with:
python-version: 3.8

- name: Get pip cache dir
id: cache-dir
run: |
echo ::set-output name=dir::$(pip cache dir)
- name: Cache pip dependencies
uses: actions/cache@v3
id: cache-pip
with:
path: ${{ steps.cache-dir.outputs.dir }}
key: ${{ runner.os }}-tests-${{ hashFiles('requirements/tests-requirements.txt') }}

- name: Install dependencies
run: |
pip install .
pip install flax jax jaxlib chex tensorflow
pip install -r requirements/tests-requirements.txt
- name: Run tests and generate coverage report
run: |
OPTS=(--cov-config pyproject.toml --cov src/bentoml --cov-append --framework flax)
coverage run -m pytest tests/integration/frameworks/test_frameworks.py "${OPTS[@]}"
- name: Generate coverage
run: coverage xml

- name: Upload test coverage to Codecov
uses: codecov/codecov-action@v3
with:
files: ./coverage.xml
token: ${{ secrets.CODECOV_TOKEN }}

fastai_integration_tests:
needs: diff
if: ${{ (github.event_name == 'pull_request' && needs.diff.outputs.fastai == 'true') || github.event_name == 'push' }}
Expand Down
21 changes: 21 additions & 0 deletions docs/source/reference/frameworks/flax.rst
@@ -0,0 +1,21 @@
====
Flax
====

.. admonition:: About this page

This is an API reference for FLax in BentoML. Please refer to
:doc:`/frameworks/flax` for more information about how to use Flax in BentoML.


.. note::

You can find more examples for **Flax** in our `bentoml/examples https://github.com/bentoml/BentoML/tree/main/examples`_ directory.

.. currentmodule:: bentoml.flax

.. autofunction:: bentoml.flax.save_model

.. autofunction:: bentoml.flax.load_model

.. autofunction:: bentoml.flax.get
1 change: 1 addition & 0 deletions docs/source/reference/frameworks/index.rst
Expand Up @@ -17,6 +17,7 @@ Framework APIs
onnx
sklearn
transformers
flax
tensorflow
xgboost
picklable_model
Expand Down
28 changes: 20 additions & 8 deletions examples/README.md
@@ -1,11 +1,16 @@
# BentoML Examples 🎨 [![Twitter Follow](https://img.shields.io/twitter/follow/bentomlai?style=social)](https://twitter.com/bentomlai) [![Slack](https://img.shields.io/badge/Slack-Join-4A154B?style=social)](https://l.linklyhq.com/l/ktO8)

BentoML is an open platform for machine learning in production. It simplifies model packaging and model management, optimizes model serving workloads to run at production scale, and accelerates the creation, deployment, and monitoring of prediction services.
BentoML is an open platform for machine learning in production. It simplifies
model packaging and model management, optimizes model serving workloads to run
at production scale, and accelerates the creation, deployment, and monitoring of
prediction services.

The repository contains a collection of example projects demonstrating [BentoML](https://github.com/bentoml/BentoML)
usage and best practices.
The repository contains a collection of example projects demonstrating
[BentoML](https://github.com/bentoml/BentoML) usage and best practices.

👉 [Pop into our Slack community!](https://join.slack.bentoml.org) We're happy to help with any issue you face or even just to meet you and hear what you're working on :)
👉 [Pop into our Slack community!](https://join.slack.bentoml.org) We're happy
to help with any issue you face or even just to meet you and hear what you're
working on :)

## Index

Expand Down Expand Up @@ -36,16 +41,23 @@ usage and best practices.
| [tensorflow2_keras](https://github.com/bentoml/BentoML/tree/main/examples/tensorflow2_keras) | TensorFlow, Keras | MNIST | Notebook |
| [tensorflow2_native](https://github.com/bentoml/BentoML/tree/main/examples/tensorflow2_native) | TensforFlow | MNIST | Notebook |
| [xgboost](https://github.com/bentoml/BentoML/tree/main/examples/xgboost) | XGBoost | DMatrix | |
| [flax/MNIST](https://github.com/bentoml/BentoML/tree/main/examples/flax/MNIST) | Flax | MNIST | gRPC, Testing |

## How to contribute

If you have issues running these projects or have suggestions for improvement, use [Github Issues 🐱](https://github.com/bentoml/BentoML/issues/new)
If you have issues running these projects or have suggestions for improvement,
use [Github Issues 🐱](https://github.com/bentoml/BentoML/issues/new)

If you are interested in contributing new projects to this repo, let's talk 🥰 - Join us on [Slack](https://join.slack.com/t/bentoml/shared_invite/enQtNjcyMTY3MjE4NTgzLTU3ZDc1MWM5MzQxMWQxMzJiNTc1MTJmMzYzMTYwMjQ0OGEwNDFmZDkzYWQxNzgxYWNhNjAxZjk4MzI4OGY1Yjg) and share your idea in #bentoml-contributors channel
If you are interested in contributing new projects to this repo, let's talk 🥰 -
Join us on
[Slack](https://join.slack.com/t/bentoml/shared_invite/enQtNjcyMTY3MjE4NTgzLTU3ZDc1MWM5MzQxMWQxMzJiNTc1MTJmMzYzMTYwMjQ0OGEwNDFmZDkzYWQxNzgxYWNhNjAxZjk4MzI4OGY1Yjg)
and share your idea in #bentoml-contributors channel

Before you create a Pull Request, make sure:

- Follow the basic structures and naming conventions of other existing example projects
- Follow the basic structures and naming conventions of other existing example
projects
- Ensure your project runs with the latest version of BentoML

For legacy version prior to v1.0.0, see the [0.13-LTS branch](https://github.com/bentoml/gallery/tree/0.13-LTS).
For legacy version prior to v1.0.0, see the
[0.13-LTS branch](https://github.com/bentoml/gallery/tree/0.13-LTS).
4 changes: 4 additions & 0 deletions examples/flax/MNIST/.bentoignore
@@ -0,0 +1,4 @@
__pycache__/
*.py[cod]
*$py.class
.ipynb_checkpoints
2 changes: 2 additions & 0 deletions examples/flax/MNIST/.gitignore
@@ -0,0 +1,2 @@
events.out*
*.msgpack
18 changes: 18 additions & 0 deletions examples/flax/MNIST/BUILD
@@ -0,0 +1,18 @@
load("@bazel_skylib//rules:write_file.bzl", "write_file")

write_file(
name = "_train_sh",
out = "_train.sh",
content = [
"#!/usr/bin/env bash\n",
"cd $BUILD_WORKING_DIRECTORY\n",
"python -m pip install -r requirements.txt\n",
"python train.py $@",
],
)

sh_binary(
name = "train",
srcs = ["_train.sh"],
data = ["train.py"],
)
34 changes: 34 additions & 0 deletions examples/flax/MNIST/README.md
@@ -0,0 +1,34 @@
# MNIST classifier

This project demonstrates a simple CNN for MNIST classifier served with BentoML.

### Instruction

Run training scripts:

```bash
# run with python3
pip install -r requirements.txt
python3 train.py --num-epochs 2

# run with bazel
bazel run :train -- --num-epochs 2
```

Serve with either gRPC or HTTP:

```bash
bentoml serve-grpc --production --enable-reflection
```

Run the test suite:

```bash
pytest tests
```

To run containerize do:

```bash
bentoml containerize mnist_flax --opt platform=linux/amd64
```
12 changes: 12 additions & 0 deletions examples/flax/MNIST/bentofile.yaml
@@ -0,0 +1,12 @@
service: "service.py:svc"
labels:
owner: bentoml-team
project: mnist-flax
experiemental: true
include:
- "*.py"
python:
lock_packages: false
extra_index_url:
- https://storage.googleapis.com/jax-releases/jax_releases.html
requirements_txt: ./requirements.txt
10 changes: 10 additions & 0 deletions examples/flax/MNIST/requirements-gpu.txt
@@ -0,0 +1,10 @@
-f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
jax[cuda]==0.4.4
flax>=0.6.1
optax>=0.1.3
bentoml[grpc,grpc-reflection]
tensorflow
tensorflow-datasets
Pillow
pytest
pytest-asyncio
10 changes: 10 additions & 0 deletions examples/flax/MNIST/requirements.txt
@@ -0,0 +1,10 @@
jax[cpu]==0.4.4
flax>=0.6.1
optax>=0.1.3
bentoml[grpc,grpc-reflection]
tensorflow;platform_system!="Darwin"
tensorflow-macos;platform_system=="Darwin"
tensorflow-datasets
Pillow
pytest
pytest-asyncio
Binary file added examples/flax/MNIST/samples/0.png
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/flax/MNIST/samples/1.png
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/flax/MNIST/samples/2.png
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/flax/MNIST/samples/3.png
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/flax/MNIST/samples/4.png
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/flax/MNIST/samples/5.png
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/flax/MNIST/samples/6.png
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/flax/MNIST/samples/7.png
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/flax/MNIST/samples/8.png
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/flax/MNIST/samples/9.png
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
24 changes: 24 additions & 0 deletions examples/flax/MNIST/service.py
@@ -0,0 +1,24 @@
from __future__ import annotations

import typing as t
from typing import TYPE_CHECKING

import jax.numpy as jnp
from PIL.Image import Image as PILImage

import bentoml

if TYPE_CHECKING:
from numpy.typing import NDArray

mnist_runner = bentoml.flax.get("mnist_flax").to_runner()

svc = bentoml.Service(name="mnist_flax", runners=[mnist_runner])


@svc.api(input=bentoml.io.Image(), output=bentoml.io.NumpyNdarray())
async def predict(f: PILImage) -> NDArray[t.Any]:
arr = jnp.array(f) / 255.0
arr = jnp.expand_dims(arr, (0, 3))
res = await mnist_runner.async_run(arr)
return res.argmax()
91 changes: 91 additions & 0 deletions examples/flax/MNIST/tests/conftest.py
@@ -0,0 +1,91 @@
from __future__ import annotations

import os
import sys
import typing as t
import contextlib
import subprocess
from typing import TYPE_CHECKING

import psutil
import pytest

import bentoml
from bentoml.testing.server import host_bento
from bentoml._internal.configuration.containers import BentoMLContainer

if TYPE_CHECKING:
from contextlib import ExitStack

from _pytest.main import Session
from _pytest.nodes import Item
from _pytest.config import Config
from _pytest.fixtures import FixtureRequest as _PytestFixtureRequest

class FixtureRequest(_PytestFixtureRequest):
param: str


PROJECT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))


def pytest_collection_modifyitems(
session: Session, config: Config, items: list[Item]
) -> None:
try:
m = bentoml.models.get("mnist_flax")
print(f"Model exists: {m}")
except bentoml.exceptions.NotFound:
subprocess.check_call(
[
sys.executable,
f"{os.path.join(PROJECT_DIR, 'train.py')}",
"--num-epochs",
"2", # 2 epochs for faster testing
"--lr",
"0.22", # speed up training time
"--enable-tensorboard",
]
)


@pytest.fixture(name="enable_grpc", params=[True, False], scope="session")
def fixture_enable_grpc(request: FixtureRequest) -> str:
return request.param


@pytest.fixture(scope="session", autouse=True)
def clean_context() -> t.Generator[contextlib.ExitStack, None, None]:
stack = contextlib.ExitStack()
yield stack
stack.close()


@pytest.fixture(
name="deployment_mode",
params=["container", "distributed", "standalone"],
scope="session",
)
def fixture_deployment_mode(request: FixtureRequest) -> str:
return request.param


@pytest.mark.usefixtures("change_test_dir")
@pytest.fixture(scope="module")
def host(
deployment_mode: t.Literal["container", "distributed", "standalone"],
clean_context: ExitStack,
enable_grpc: bool,
) -> t.Generator[str, None, None]:
if enable_grpc and psutil.WINDOWS:
pytest.skip("gRPC is not supported on Windows.")

with host_bento(
"service:svc",
deployment_mode=deployment_mode,
project_path=PROJECT_DIR,
bentoml_home=BentoMLContainer.bentoml_home.get(),
clean_context=clean_context,
use_grpc=enable_grpc,
) as _host:
yield _host

0 comments on commit e5af0c3

Please sign in to comment.