Skip to content

Commit

Permalink
Add branch to truss server error handling that passes through fastapi…
Browse files Browse the repository at this point in the history
….HTTPExceptions if raised in model code. (#886)
  • Loading branch information
marius-baseten committed Apr 3, 2024
1 parent d0d719c commit e7c260d
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 12 deletions.
46 changes: 34 additions & 12 deletions truss/templates/server/model_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,18 @@
from multiprocessing import Lock
from pathlib import Path
from threading import Thread
from typing import Any, AsyncGenerator, Dict, Optional, Set, Union
from typing import (
Any,
AsyncGenerator,
Callable,
Coroutine,
Dict,
NoReturn,
Optional,
Set,
TypeVar,
Union,
)

import pydantic
from anyio import Semaphore, to_thread
Expand All @@ -21,6 +32,7 @@
from fastapi import HTTPException
from pydantic import BaseModel
from shared.secrets_resolver import SecretsResolver
from typing_extensions import ParamSpec

MODEL_BASENAME = "model"

Expand Down Expand Up @@ -404,28 +416,38 @@ def _elapsed_ms(since_micro_seconds: float) -> int:
return int((time.perf_counter() - since_micro_seconds) * 1000)


def _handle_exception():
def _handle_exception(exception: Exception) -> NoReturn:
# Note that logger.exception logs the stacktrace, such that the user can
# debug this error from the logs.
logging.exception("Internal Server Error")
raise HTTPException(status_code=500, detail="Internal Server Error")
if isinstance(exception, HTTPException):
logging.exception("Model raised HTTPException")
raise exception
else:
logging.exception("Internal Server Error")
raise HTTPException(status_code=500, detail="Internal Server Error")


def _intercept_exceptions_sync(func):
def inner(*args, **kwargs):
_P = ParamSpec("_P")
_R = TypeVar("_R")


def _intercept_exceptions_sync(func: Callable[_P, _R]) -> Callable[_P, _R]:
def inner(*args: _P.args, **kwargs: _P.kwargs) -> _R:
try:
return func(*args, **kwargs)
except Exception:
_handle_exception()
except Exception as e:
_handle_exception(e)

return inner


def _intercept_exceptions_async(func):
async def inner(*args, **kwargs):
def _intercept_exceptions_async(
func: Callable[_P, Coroutine[Any, Any, _R]]
) -> Callable[_P, Coroutine[Any, Any, _R]]:
async def inner(*args: _P.args, **kwargs: _P.kwargs) -> _R:
try:
return await func(*args, **kwargs)
except Exception:
_handle_exception()
except Exception as e:
_handle_exception(e)

return inner
38 changes: 38 additions & 0 deletions truss/tests/test_model_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -727,6 +727,44 @@ async def predict(self, request):
assert "Internal Server Error" in response.json()["error"]


@pytest.mark.integration
def test_truss_with_user_errors():
"""Test that user-code raised `fastapi.HTTPExceptions` are passed through as is."""
model = """
import fastapi
class Model:
def predict(self, request):
raise fastapi.HTTPException(status_code=500, detail="My custom message.")
"""

config = "model_name: error-truss"

with ensure_kill_all(), tempfile.TemporaryDirectory(dir=".") as tmp_work_dir:
truss_dir = Path(tmp_work_dir, "truss")

create_truss(truss_dir, config, textwrap.dedent(model))

tr = TrussHandle(truss_dir)
container = tr.docker_run(
local_port=8090, detach=True, wait_for_server_ready=True
)
truss_server_addr = "http://localhost:8090"
full_url = f"{truss_server_addr}/v1/models/model:predict"

response = requests.post(full_url, json={})
assert response.status_code == 500
assert "error" in response.json()

assert_logs_contain_error(
container.logs(),
"HTTPException: 500: My custom message.",
"Model raised HTTPException",
)

assert "My custom message." in response.json()["error"]


@pytest.mark.integration
def test_slow_truss():
with ensure_kill_all():
Expand Down

0 comments on commit e7c260d

Please sign in to comment.