Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Service.to_asgi() method #4572

Merged
merged 1 commit into from Mar 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
8 changes: 3 additions & 5 deletions src/_bentoml_impl/worker/service.py
Expand Up @@ -137,8 +137,6 @@ def main(
from bentoml._internal.context import server_context
from bentoml._internal.log import configure_server_logging

from ..server.app import ServiceAppFactory

if runner_map:
BentoMLContainer.remote_runner_mapping.set(
t.cast(t.Dict[str, str], json.loads(runner_map))
Expand All @@ -163,9 +161,9 @@ def main(
BentoMLContainer.prometheus_multiproc_dir.set(prometheus_dir)
server_context.service_name = service.name

asgi_app = ServiceAppFactory(
service, is_main=server_context.service_type == "entry_service"
)()
asgi_app = service.to_asgi(
is_main=server_context.service_type == "entry_service", init=False
)

uvicorn_extra_options: dict[str, t.Any] = {}
if ssl_version is not None:
Expand Down
4 changes: 2 additions & 2 deletions src/_bentoml_sdk/decorators.py
Expand Up @@ -107,7 +107,7 @@ def wrapper(func: t.Callable[t.Concatenate[t.Any, P], R]) -> APIMethod[P, R]:

def mount_asgi_app(
app: ASGIApp, *, path: str = "/", name: str | None = None
) -> t.Callable[[T], T]:
) -> t.Callable[[R], R]:
"""Mount an ASGI app to the service.

Args:
Expand All @@ -119,7 +119,7 @@ def mount_asgi_app(
from ._internals import make_fastapi_class_views
from .service import Service

def decorator(obj: T) -> T:
def decorator(obj: R) -> R:
lazy_fastapi = LazyType["FastAPI"]("fastapi.FastAPI")

if isinstance(obj, Service):
Expand Down
9 changes: 9 additions & 0 deletions src/_bentoml_sdk/service/factory.py
Expand Up @@ -198,6 +198,15 @@ def import_string(self) -> str:
)
return self._import_str

def to_asgi(self, is_main: bool = True, init: bool = False) -> ext.ASGIApp:
from _bentoml_impl.server.app import ServiceAppFactory

self.inject_config()
factory = ServiceAppFactory(self, is_main=is_main)
if init:
factory.create_instance()
return factory()

def mount_asgi_app(
self, app: ext.ASGIApp, path: str = "/", name: str | None = None
) -> None:
Expand Down
46 changes: 14 additions & 32 deletions tests/unit/bentoml_io/test_decorators.py
Expand Up @@ -8,58 +8,40 @@ async def test_mount_asgi_app():
import httpx
from fastapi import FastAPI

from _bentoml_impl.server.app import ServiceAppFactory

app = FastAPI()

@bentoml.mount_asgi_app(app, path="/test")
@bentoml.service
@bentoml.service(metrics={"enabled": False})
class TestService:
@app.get("/hello")
def hello(self):
return {"message": "Hello, world!"}

TestService.inject_config()

factory = ServiceAppFactory(TestService, is_main=True, enable_metrics=False)
factory.create_instance()
try:
async with httpx.AsyncClient(
app=factory(), base_url="http://testserver"
) as client:
response = await client.get("/test/hello")
assert response.status_code == 200
assert response.json()["message"] == "Hello, world!"
finally:
await factory.destroy_instance()
async with httpx.AsyncClient(
app=TestService.to_asgi(init=True), base_url="http://testserver"
) as client:
response = await client.get("/test/hello")
assert response.status_code == 200
assert response.json()["message"] == "Hello, world!"


@pytest.mark.asyncio
async def test_mount_asgi_app_later():
import httpx
from fastapi import FastAPI

from _bentoml_impl.server.app import ServiceAppFactory

app = FastAPI()

@bentoml.service
@bentoml.service(metrics={"enabled": False})
@bentoml.mount_asgi_app(app, path="/test")
class TestService:
@app.get("/hello")
def hello(self):
return {"message": "Hello, world!"}

TestService.inject_config()

factory = ServiceAppFactory(TestService, is_main=True, enable_metrics=False)
factory.create_instance()
try:
async with httpx.AsyncClient(
app=factory(), base_url="http://testserver"
) as client:
response = await client.get("/test/hello")
assert response.status_code == 200
assert response.json()["message"] == "Hello, world!"
finally:
await factory.destroy_instance()
async with httpx.AsyncClient(
app=TestService.to_asgi(init=True), base_url="http://testserver"
) as client:
response = await client.get("/test/hello")
assert response.status_code == 200
assert response.json()["message"] == "Hello, world!"