Skip to content

Commit

Permalink
fix: HTTP status in pydantic response model (#1927)
Browse files Browse the repository at this point in the history
Co-authored-by: Joongi Kim <joongi@lablup.com>
  • Loading branch information
fregataa and achimnol committed Feb 28, 2024
1 parent c95d44a commit 719ff6d
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 18 deletions.
1 change: 1 addition & 0 deletions changes/1927.fix.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Allow passing HTTP status codes via the pydantic-based API response model objects
11 changes: 6 additions & 5 deletions src/ai/backend/manager/api/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@
from .session import query_userinfo
from .types import CORSOptions, WebMiddleware
from .utils import (
BaseResponseModel,
get_access_key_scopes,
get_user_uuid_scopes,
pydantic_params_api_handler,
Expand Down Expand Up @@ -112,7 +113,7 @@ class ListServeRequestModel(BaseModel):
name: str | None = Field(default=None)


-class SuccessResponseModel(BaseModel):
+class SuccessResponseModel(BaseResponseModel):
success: bool = Field(default=True)


Expand Down Expand Up @@ -192,7 +193,7 @@ class RouteInfoModel(BaseModel):
traffic_ratio: NonNegativeFloat


-class ServeInfoModel(BaseModel):
+class ServeInfoModel(BaseResponseModel):
endpoint_id: uuid.UUID = Field(description="Unique ID referencing the model service.")
name: str = Field(description="Name of the model service.")
desired_session_count: NonNegativeInt = Field(
Expand Down Expand Up @@ -817,7 +818,7 @@ class ScaleRequestModel(BaseModel):
to: int = Field(description="Ideal number of inference sessions")


-class ScaleResponseModel(BaseModel):
+class ScaleResponseModel(BaseResponseModel):
current_route_count: int
target_count: int

Expand Down Expand Up @@ -963,7 +964,7 @@ class TokenRequestModel(BaseModel):
)


-class TokenResponseModel(BaseModel):
+class TokenResponseModel(BaseResponseModel):
token: str


Expand Down Expand Up @@ -1044,7 +1045,7 @@ class ErrorInfoModel(BaseModel):
error: dict[str, Any]


-class ErrorListResponseModel(BaseModel):
+class ErrorListResponseModel(BaseResponseModel):
errors: list[ErrorInfoModel]
retries: int

Expand Down
23 changes: 13 additions & 10 deletions src/ai/backend/manager/api/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from collections import defaultdict
from typing import (
TYPE_CHECKING,
Annotated,
Any,
Awaitable,
Callable,
Expand All @@ -31,9 +32,9 @@
import sqlalchemy as sa
import trafaret as t
import yaml
-from aiohttp import web, web_response
+from aiohttp import web
from aiohttp.typedefs import Handler
-from pydantic import BaseModel, TypeAdapter, ValidationError
+from pydantic import BaseModel, Field, TypeAdapter, ValidationError

from ai.backend.common.logging import BraceStyleAdapter
from ai.backend.common.types import AccessKey
Expand Down Expand Up @@ -212,9 +213,13 @@ async def wrapped(request: web.Request, *args: P.args, **kwargs: P.kwargs) -> TA
return wrap


class BaseResponseModel(BaseModel):
    """Common base class for pydantic-based API response models.

    Carries the HTTP status code that the response serializer should use
    when converting the model into an actual HTTP response; defaults to 200.
    """

    # Constrained to the valid HTTP status-code range (100-599); strict mode
    # rejects non-int values instead of coercing them.
    status: int = Field(default=200, strict=True, ge=100, lt=600)


TParamModel = TypeVar("TParamModel", bound=BaseModel)
TQueryModel = TypeVar("TQueryModel", bound=BaseModel)
-TResponseModel = TypeVar("TResponseModel", bound=BaseModel)
+TResponseModel = TypeVar("TResponseModel", bound=BaseResponseModel)

TPydanticResponse: TypeAlias = TResponseModel | list
THandlerFuncWithoutParam: TypeAlias = Callable[
Expand All @@ -226,16 +231,14 @@ async def wrapped(request: web.Request, *args: P.args, **kwargs: P.kwargs) -> TA


 def ensure_stream_response_type(
-    response: TResponseModel | list | TAnyResponse,
+    response: BaseResponseModel | list[TResponseModel] | web.StreamResponse,
 ) -> web.StreamResponse:
     match response:
-        case BaseModel():
-            return web.json_response(response.model_dump(mode="json"))
+        case BaseResponseModel(status=status):
+            return web.json_response(response.model_dump(mode="json"), status=status)
         case list():
-            return web.json_response(
-                TypeAdapter(list[TResponseModel]).dump_python(response, mode="json")
-            )
+            return web.json_response(TypeAdapter(type(response)).dump_python(response, mode="json"))
-        case web_response.StreamResponse():
+        case web.StreamResponse():
             return response
         case _:
             raise RuntimeError(f"Unsupported response type ({type(response)})")
Expand Down
6 changes: 3 additions & 3 deletions src/ai/backend/manager/api/vfolder.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@
from .manager import ALL_ALLOWED, READ_ALLOWED, server_status_required
from .resource import get_watcher_info
from .utils import (
BaseResponseModel,
check_api_params,
get_user_scopes,
pydantic_params_api_handler,
Expand All @@ -129,9 +130,8 @@
P = ParamSpec("P")


-class SuccessResponseModel(BaseModel):
+class SuccessResponseModel(BaseResponseModel):
     success: bool = Field(default=True)
-    status: int = Field(default=200)


async def ensure_vfolder_status(
Expand Down Expand Up @@ -2325,7 +2325,7 @@ class IDRequestModel(BaseModel):
)


-class CompactVFolderInfoModel(BaseModel):
+class CompactVFolderInfoModel(BaseResponseModel):
id: uuid.UUID = Field(description="Unique ID referencing the vfolder.")
name: str = Field(description="Name of the vfolder.")

Expand Down

0 comments on commit 719ff6d

Please sign in to comment.