Skip to content

Commit

Permalink
LangServe: Raise invalid request exceptions (#581)
Browse files Browse the repository at this point in the history
* This PR removes dependency on httpx sse
* After this PR invalid requests will be returned as having a status
code of 422 using a non streaming request
  • Loading branch information
eyurtsev committed Apr 1, 2024
1 parent 2c7fcef commit 014acfc
Show file tree
Hide file tree
Showing 6 changed files with 244 additions and 222 deletions.
138 changes: 20 additions & 118 deletions langserve/api_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1085,8 +1085,6 @@ async def stream(
Originates from the client side. This config must be validated.
server_config: optional server configuration that will be merged
"""
err_event = {}
validation_exception: Optional[BaseException] = None
run_id = None
try:
config, input_ = await self._get_config_and_input(
Expand All @@ -1096,26 +1094,12 @@ async def stream(
server_config=server_config,
)
run_id = config["run_id"]
except BaseException as e:
validation_exception = e
if isinstance(e, RequestValidationError):
err_event = {
"event": "error",
"data": json.dumps(
{"status_code": 422, "message": repr(e.errors())}
),
}
else:
err_event = {
"event": "error",
# Do not expose the error message to the client since
# the message may contain sensitive information.
"data": json.dumps(
{"status_code": 500, "message": "Internal Server Error"}
),
}
except BaseException:
# Exceptions will be properly translated by default FastAPI middleware
# to either 422 (on input validation) or 500 internal server errors.
raise

if self._token_feedback_enabled and not validation_exception:
if self._token_feedback_enabled:
# Create task to create a presigned feedback token
feedback_key: Optional[str] = self._token_feedback_config["key_configs"][0][
"key"
Expand All @@ -1133,15 +1117,6 @@ async def stream(

async def _stream() -> AsyncIterator[dict]:
"""Stream the output of the runnable."""
if validation_exception:
yield err_event
if isinstance(validation_exception, RequestValidationError):
return
else:
raise AssertionError(
"Internal server error"
) from validation_exception

try:
config_w_callbacks = config.copy()
event_aggregator = AsyncEventAggregatorCallback()
Expand Down Expand Up @@ -1199,65 +1174,28 @@ async def stream_log(
View documentation for endpoint at the end of the file.
It's attached to _stream_log_docs endpoint.
"""
err_event = {}
validation_exception: Optional[BaseException] = None
try:
config, input_ = await self._get_config_and_input(
request,
config_hash,
endpoint="stream_log",
server_config=server_config,
)
except BaseException as e:
validation_exception = e
if isinstance(e, RequestValidationError):
err_event = {
"event": "error",
"data": json.dumps(
{"status_code": 422, "message": repr(e.errors())}
),
}
else:
err_event = {
"event": "error",
# Do not expose the error message to the client since
# the message may contain sensitive information.
"data": json.dumps(
{"status_code": 500, "message": "Internal Server Error"}
),
}

except BaseException:
# Exceptions will be properly translated by default FastAPI middleware
# to either 422 (on input validation) or 500 internal server errors.
raise
try:
body = await request.json()
with _with_validation_error_translation():
stream_log_request = StreamLogParameters(**body)
except json.JSONDecodeError:
# Body as text
validation_exception = RequestValidationError(errors=["Invalid JSON body"])
err_event = {
"event": "error",
"data": json.dumps(
{"status_code": 422, "message": "Invalid JSON body"}
),
}
except RequestValidationError as e:
validation_exception = e
err_event = {
"event": "error",
"data": json.dumps({"status_code": 422, "message": repr(e.errors())}),
}
raise RequestValidationError(errors=["Invalid JSON body"])
except RequestValidationError:
raise

async def _stream_log() -> AsyncIterator[dict]:
"""Stream the output of the runnable."""
if validation_exception:
yield err_event
if isinstance(validation_exception, RequestValidationError):
return
else:
raise AssertionError(
"Internal server error"
) from validation_exception

try:
async for chunk in self._runnable.astream_log(
input_,
Expand Down Expand Up @@ -1314,8 +1252,6 @@ async def astream_events(
server_config: Optional[RunnableConfig] = None,
) -> EventSourceResponse:
"""Stream events from the runnable."""
err_event = {}
validation_exception: Optional[BaseException] = None
run_id = None
try:
config, input_ = await self._get_config_and_input(
Expand All @@ -1325,48 +1261,23 @@ async def astream_events(
server_config=server_config,
)
run_id = config["run_id"]
except BaseException as e:
validation_exception = e
if isinstance(e, RequestValidationError):
err_event = {
"event": "error",
"data": json.dumps(
{"status_code": 422, "message": repr(e.errors())}
),
}
else:
err_event = {
"event": "error",
# Do not expose the error message to the client since
# the message may contain sensitive information.
"data": json.dumps(
{"status_code": 500, "message": "Internal Server Error"}
),
}
except BaseException:
# Exceptions will be properly translated by default FastAPI middleware
# to either 422 (on input validation) or 500 internal server errors.
raise

try:
body = await request.json()
with _with_validation_error_translation():
stream_events_request = StreamEventsParameters(**body)
except json.JSONDecodeError:
# Body as text
validation_exception = RequestValidationError(errors=["Invalid JSON body"])
err_event = {
"event": "error",
"data": json.dumps(
{"status_code": 422, "message": "Invalid JSON body"}
),
}
except RequestValidationError as e:
validation_exception = e
err_event = {
"event": "error",
"data": json.dumps({"status_code": 422, "message": repr(e.errors())}),
}
raise RequestValidationError(errors=["Invalid JSON body"])
except RequestValidationError:
raise

feedback_key: Optional[str]

if self._token_feedback_enabled and not validation_exception:
if self._token_feedback_enabled:
# Create task to create a presigned feedback token
feedback_key: str = self._token_feedback_config["key_configs"][0]["key"]
feedback_coro = run_in_executor(
Expand All @@ -1387,15 +1298,6 @@ async def _stream_events() -> AsyncIterator[dict]:
"Please upgrade langchain-core>=0.1.14 to use astream_events"
)

if validation_exception:
yield err_event
if isinstance(validation_exception, RequestValidationError):
return
else:
raise AssertionError(
"Internal server error"
) from validation_exception

has_sent_metadata = False

try:
Expand Down

0 comments on commit 014acfc

Please sign in to comment.