Skip to content

Commit

Permalink
feat(client): disallow media type other than json for HTTP clients (#…
Browse files Browse the repository at this point in the history
…4520)

* feat(client): disallow media type other than json for HTTP clients

Signed-off-by: Frost Ming <me@frostming.com>

* fix: check media type on server side

Signed-off-by: Frost Ming <me@frostming.com>

---------

Signed-off-by: Frost Ming <me@frostming.com>
  • Loading branch information
frostming committed Feb 26, 2024
1 parent f5c3322 commit fd70379
Show file tree
Hide file tree
Showing 6 changed files with 76 additions and 38 deletions.
57 changes: 55 additions & 2 deletions src/_bentoml_impl/client/__init__.py
@@ -1,7 +1,7 @@
from .base import AbstractClient
from .http import AsyncHTTPClient
from .http import AsyncHTTPClient as _AsyncHTTPClient
from .http import HTTPClient
from .http import SyncHTTPClient
from .http import SyncHTTPClient as _SyncHTTPClient
from .proxy import RemoteProxy

__all__ = [
Expand All @@ -11,3 +11,56 @@
"AbstractClient",
"RemoteProxy",
]


class SyncHTTPClient(_SyncHTTPClient):
    """A synchronous client for BentoML service.

    This public class accepts only ``url``, ``token`` and ``timeout`` and
    forwards them unchanged to the underlying implementation class.

    Args:
        url (str): URL of the BentoML service.
        token (str, optional): Authentication token. Defaults to None.
        timeout (float, optional): Timeout for the client. Defaults to 30.

    Example::

        with SyncHTTPClient("http://localhost:3000") as client:
            resp = client.call("classify", input_series=[[1,2,3,4]])
            assert resp == [0]
            # Or using named method directly
            resp = client.classify(input_series=[[1,2,3,4]])
            assert resp == [0]
    """

    # NOTE: the ``token`` annotation is quoted on purpose. A bare
    # ``str | None`` is evaluated when the ``def`` statement runs and raises
    # ``TypeError`` on Python < 3.10 unless the module enables
    # ``from __future__ import annotations``; the string form is inert on
    # every supported interpreter.
    def __init__(
        self, url: str, *, token: "str | None" = None, timeout: float = 30
    ) -> None:
        super().__init__(url, token=token, timeout=timeout)


class AsyncHTTPClient(_AsyncHTTPClient):
    """An asynchronous client for BentoML service.

    This public class accepts only ``url``, ``token`` and ``timeout`` and
    forwards them unchanged to the underlying implementation class.

    Args:
        url (str): URL of the BentoML service.
        token (str, optional): Authentication token. Defaults to None.
        timeout (float, optional): Timeout for the client. Defaults to 30.

    Example::

        async with AsyncHTTPClient("http://localhost:3000") as client:
            resp = await client.call("classify", input_series=[[1,2,3,4]])
            assert resp == [0]
            # Or using named method directly
            resp = await client.classify(input_series=[[1,2,3,4]])
            assert resp == [0]
            # Streaming
            resp = client.stream(prompt="hello")
            async for data in resp:
                print(data)
    """

    # NOTE: the ``token`` annotation is quoted on purpose. A bare
    # ``str | None`` is evaluated when the ``def`` statement runs and raises
    # ``TypeError`` on Python < 3.10 unless the module enables
    # ``from __future__ import annotations``; the string form is inert on
    # every supported interpreter.
    def __init__(
        self, url: str, *, token: "str | None" = None, timeout: float = 30
    ) -> None:
        super().__init__(url, token=token, timeout=timeout)
28 changes: 2 additions & 26 deletions src/_bentoml_impl/client/http.py
Expand Up @@ -340,14 +340,7 @@ def _get_stream(
class SyncHTTPClient(HTTPClient[httpx.Client]):
"""A synchronous client for BentoML service.
Example:
with SyncHTTPClient("http://localhost:3000") as client:
resp = client.call("classify", input_series=[[1,2,3,4]])
assert resp == [0]
# Or using named method directly
resp = client.classify(input_series=[[1,2,3,4]])
assert resp == [0]
.. note:: For internal use only
"""

client_cls = httpx.Client
Expand Down Expand Up @@ -449,24 +442,7 @@ def _parse_file_response(
class AsyncHTTPClient(HTTPClient[httpx.AsyncClient]):
"""An asynchronous client for BentoML service.
Example:
async with AsyncHTTPClient("http://localhost:3000") as client:
resp = await client.call("classify", input_series=[[1,2,3,4]])
assert resp == [0]
# Or using named method directly
resp = await client.classify(input_series=[[1,2,3,4]])
assert resp == [0]
.. note::
If the endpoint returns an async generator, it should be awaited before iterating.
Example:
resp = await client.stream(prompt="hello")
async for data in resp:
print(data)
.. note:: For internal use only
"""

client_cls = httpx.AsyncClient
Expand Down
7 changes: 4 additions & 3 deletions src/_bentoml_impl/loader.py
Expand Up @@ -106,12 +106,13 @@ def import_service(
from _bentoml_sdk import Service

if bento_path is None:
bento_path = pathlib.Path(".").absolute()
bento_path = pathlib.Path(".")
bento_path = bento_path.absolute()

# patch python path if needed
if bento_path != pathlib.Path("."):
if bento_path != pathlib.Path(".").absolute():
# a project
extra_python_path = str(bento_path.absolute())
extra_python_path = str(bento_path)
sys.path.insert(0, extra_python_path)
else:
# a project under current directory
Expand Down
13 changes: 11 additions & 2 deletions src/_bentoml_impl/server/app.py
Expand Up @@ -5,6 +5,7 @@
import inspect
import sys
import typing as t
from http import HTTPStatus
from pathlib import Path

import anyio
Expand Down Expand Up @@ -61,6 +62,7 @@ class ServiceAppFactory(BaseAppFactory):
def __init__(
self,
service: Service[t.Any],
is_main: bool = False,
enable_metrics: bool = Provide[
BentoMLContainer.api_server_config.metrics.enabled
],
Expand All @@ -74,6 +76,7 @@ def __init__(

self.service = service
self.enable_metrics = enable_metrics
self.is_main = is_main
timeout = traffic.get("timeout")
max_concurrency = traffic.get("max_concurrency")
self.enable_access_control = enable_access_control
Expand Down Expand Up @@ -175,7 +178,7 @@ async def handle_bentoml_exception(self, req: Request, exc: Exception) -> Respon
else:
return JSONResponse("", status_code=status)

def __call__(self, is_main: bool = False) -> Starlette:
def __call__(self) -> Starlette:
app = super().__call__()

app.add_exception_handler(
Expand All @@ -184,7 +187,7 @@ def __call__(self, is_main: bool = False) -> Starlette:
app.add_exception_handler(BentoMLException, self.handle_bentoml_exception)
app.add_exception_handler(Exception, self.handle_uncaught_exception)
app.add_route("/schema.json", self.schema_view, name="schema")
if is_main:
if self.is_main:
if BentoMLContainer.new_index:
assets = Path(__file__).parent / "assets"
app.mount("/assets", StaticFiles(directory=assets), name="assets")
Expand Down Expand Up @@ -407,6 +410,12 @@ async def api_endpoint(self, name: str, request: Request) -> Response:

media_type = request.headers.get("Content-Type", "application/json")
media_type = media_type.split(";")[0].strip()
if self.is_main and media_type == "application/vnd.bentoml+pickle":
# Disallow pickle media type for main service for security reasons
raise BentoMLException(
"Pickle media type is not allowed for main service",
error_code=HTTPStatus.UNSUPPORTED_MEDIA_TYPE,
)

method = self.service.apis[name]
func = getattr(self._service_instance, name)
Expand Down
1 change: 0 additions & 1 deletion src/_bentoml_impl/server/serving.py
Expand Up @@ -260,7 +260,6 @@ def serve_http(
"$(CIRCUS.WID)",
"--prometheus-dir",
prometheus_dir,
"--main",
*ssl_args,
*timeout_args,
]
Expand Down
8 changes: 4 additions & 4 deletions src/_bentoml_impl/worker/service.py
Expand Up @@ -95,7 +95,6 @@
type=click.INT,
help="Specify the timeout for API server",
)
@click.option("--main", "is_main", type=click.BOOL, default=False, is_flag=True)
def main(
bento_identifier: str,
service_name: str,
Expand All @@ -114,7 +113,6 @@ def main(
ssl_ciphers: str | None,
development_mode: bool,
timeout: int,
is_main: bool = False,
):
"""
Start a HTTP server worker for given service.
Expand Down Expand Up @@ -165,8 +163,10 @@ def main(
BentoMLContainer.prometheus_multiproc_dir.set(prometheus_dir)
component_context.component_name = service.name

app_factory = ServiceAppFactory(service)
asgi_app = app_factory(is_main=is_main)
asgi_app = ServiceAppFactory(
service, is_main=component_context.component_type == "entry_service"
)()

uvicorn_extra_options: dict[str, t.Any] = {}
if ssl_version is not None:
uvicorn_extra_options["ssl_version"] = ssl_version
Expand Down

0 comments on commit fd70379

Please sign in to comment.