PC-1062 - Improve error handling in stream endpoint (#443)
* WIP

* Add error handling when exception raised in streaming pipeline

* Report status codes correctly

* VB
rossgray committed Mar 20, 2024
1 parent 3a1c358 commit 7a06c4a
Showing 7 changed files with 238 additions and 47 deletions.
45 changes: 45 additions & 0 deletions pipeline/cloud/http.py
@@ -7,7 +7,9 @@
 
 import httpx
 import requests
+from fastapi.responses import StreamingResponse
 from requests_toolbelt import MultipartEncoder, MultipartEncoderMonitor
+from starlette.types import Send
 from tqdm import tqdm
 
 from pipeline import current_configuration
@@ -286,3 +288,46 @@ def stream(
     client = _get_client()
     with client.stream(method=method, url=endpoint, json=json_data) as response:
         yield response
+
+
+class StreamingResponseWithStatusCode(StreamingResponse):
+    """
+    Variation of StreamingResponse that can dynamically decide the HTTP status
+    code, based on the returns from the content iterator (parameter 'content').
+    Expects the content to yield tuples of (content: str, status_code: int),
+    instead of just content as it was in the original StreamingResponse. The
+    parameter status_code in the constructor is ignored, but kept for
+    compatibility with StreamingResponse.
+    See
+    https://github.com/tiangolo/fastapi/discussions/10138#discussioncomment-8216436
+    for inspiration
+    """
+
+    async def stream_response(self, send: Send) -> None:
+        first_chunk_content, self.status_code = await anext(self.body_iterator)
+        if not isinstance(first_chunk_content, bytes):
+            first_chunk_content = first_chunk_content.encode(self.charset)
+
+        await send(
+            {
+                "type": "http.response.start",
+                "status": self.status_code,
+                "headers": self.raw_headers,
+            }
+        )
+        await send(
+            {
+                "type": "http.response.body",
+                "body": first_chunk_content,
+                "more_body": True,
+            }
+        )
+        # ignore status code after response has started
+        async for chunk_content, _ in self.body_iterator:
+            chunk = chunk_content
+            if not isinstance(chunk, bytes):
+                chunk = chunk.encode(self.charset)
+            await send({"type": "http.response.body", "body": chunk, "more_body": True})
+
+        await send({"type": "http.response.body", "body": b"", "more_body": False})
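The first tuple the content iterator yields decides the HTTP status for the whole response; once `http.response.start` has gone out, later status codes can only be ignored. A minimal usage sketch follows — the endpoint name and strings are invented for illustration, only `StreamingResponseWithStatusCode` comes from this module:

```python
from fastapi import FastAPI

from pipeline.cloud.http import StreamingResponseWithStatusCode

app = FastAPI()


@app.get("/demo")  # hypothetical endpoint, for illustration only
async def demo():
    async def content():
        # the status code of the first tuple becomes the response status
        yield "upstream failed: ", 502
        # headers are already sent by now, so this status code is ignored
        yield "details in body", 200

    return StreamingResponseWithStatusCode(content(), media_type="text/plain")
```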
5 changes: 4 additions & 1 deletion pipeline/cloud/pipelines.py
@@ -145,7 +145,10 @@ def _stream_pipeline(
         try:
             result = ClusterRunResult.parse_obj(result_json)
         except ValidationError:
-            _print(f"Unexpected result from streaming run:\n{result_json}")
+            _print(
+                f"Unexpected result from streaming run:\n"
+                f"Status code = {response.status_code}\n{result_json}"
+            )
             return
     except Exception:
         http.raise_if_http_status_error(response)
2 changes: 2 additions & 0 deletions pipeline/container/manager.py
@@ -241,6 +241,8 @@ def run(
         args = self._parse_inputs(input_data, self.pipeline)
         try:
             result = self.pipeline.run(*args)
+        except RunInputException:
+            raise
         except Exception as exc:
             raise RunnableError(exception=exc, traceback=traceback.format_exc())
         return result
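The two added lines rely on `except` clauses being tried top to bottom: the bare re-raise must come before the catch-all `except Exception`, or input errors would be wrapped as `RunnableError` like any other failure. A self-contained sketch of that dispatch, using simplified stand-ins for the repo's exception classes:

```python
import traceback


class RunInputException(Exception):
    """Simplified stand-in for the repo's input-validation error."""


class RunnableError(Exception):
    """Simplified stand-in wrapping a pipeline failure."""

    def __init__(self, exception: Exception, traceback: str):
        self.exception = exception
        self.traceback = traceback


def run_pipeline(fn, *args):
    try:
        return fn(*args)
    except RunInputException:
        raise  # input errors keep their own type
    except Exception as exc:
        # anything else is reported as a pipeline error
        raise RunnableError(exception=exc, traceback=traceback.format_exc())
```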
127 changes: 83 additions & 44 deletions pipeline/container/routes/v4/runs.py
@@ -2,13 +2,14 @@
 import io
 import os
 import shutil
+import traceback
 import uuid
 from pathlib import Path
 
 from fastapi import APIRouter, Request, Response
-from fastapi.responses import StreamingResponse
 from loguru import logger
 
+from pipeline.cloud.http import StreamingResponseWithStatusCode
 from pipeline.cloud.schemas import pipelines as pipeline_schemas
 from pipeline.cloud.schemas import runs as run_schemas
 from pipeline.container.manager import Manager
@@ -93,65 +94,103 @@ async def stream_run(
     if not any([output.type == run_schemas.RunIOType.stream for output in outputs]):
         raise TypeError("No streaming outputs found")
 
-    return StreamingResponse(
+    return StreamingResponseWithStatusCode(
         _stream_run_outputs(response_schema, request),
         media_type="application/json",
+        # hint to disable buffering
+        headers={"X-Accel-Buffering": "no"},
     )
 
 
+def _fetch_next_outputs(outputs: list[run_schemas.RunOutput]):
+    next_outputs = []
+    have_new_streamed_outputs = False
+    for output in outputs:
+        if output.type == run_schemas.RunIOType.stream:
+            if output.value is None:
+                raise Exception("Stream value was None")
+
+            try:
+                next_value = output.value.__next__()
+                next_outputs.append(
+                    run_schemas.RunOutput(
+                        type=run_schemas.RunIOType.from_object(next_value),
+                        value=next_value,
+                        file=None,
+                    )
+                )
+                have_new_streamed_outputs = True
+            except StopIteration:
+                # if no data left for this stream, return None value
+                next_outputs.append(
+                    run_schemas.RunOutput(
+                        type=run_schemas.RunIOType.none,
+                        value=None,
+                        file=None,
+                    )
+                )
+            except Exception as exc:
+                logger.exception("Pipeline error caught during run streaming")
+                raise RunnableError(exception=exc, traceback=traceback.format_exc())
+        else:
+            next_outputs.append(output)
+
+    if not have_new_streamed_outputs:
+        return
+    return next_outputs
+
+
 async def _stream_run_outputs(
     response_schema: run_schemas.ContainerRunResult, request: Request
 ):
-    outputs = response_schema.outputs or []
+    """Generator returning output data for list of outputs.
+    We iterate over all outputs until we no longer have any streamed data to
+    output
+    """
+    outputs = response_schema.outputs or []
     while True:
-        next_outputs = []
-        have_new_streamed_outputs = False
-        # iterate over all outputs, until we no longer have any streamed data to
-        # output
-        for output in outputs:
-            if output.type == run_schemas.RunIOType.stream:
-                try:
-                    if output.value is None:
-                        raise Exception("Stream value was None")
-
-                    next_value = output.value.__next__()
-                    next_outputs.append(
-                        run_schemas.RunOutput(
-                            type=run_schemas.RunIOType.from_object(next_value),
-                            value=next_value,
-                            file=None,
-                        )
-                    )
-                    have_new_streamed_outputs = True
-                except StopIteration:
-                    # if no data left for this stream, return None value
-                    next_outputs.append(
-                        run_schemas.RunOutput(
-                            type=run_schemas.RunIOType.from_object(next_value),
-                            value=None,
-                            file=None,
-                        )
-                    )
-            else:
-                next_outputs.append(output)
-
-        if not have_new_streamed_outputs:
-            return
-
-        new_response_schema = run_schemas.ContainerRunResult(
-            inputs=response_schema.inputs,
-            outputs=next_outputs,
-            error=response_schema.error,
-        )
+        status_code = 200
+        try:
+            next_outputs = _fetch_next_outputs(outputs)
+            if not next_outputs:
+                return
+
+            new_response_schema = run_schemas.ContainerRunResult(
+                inputs=response_schema.inputs,
+                outputs=next_outputs,
+                error=response_schema.error,
+            )
+        except RunnableError as e:
+            # if we get a pipeline error, return a run error then finish
+            new_response_schema = run_schemas.ContainerRunResult(
+                outputs=None,
+                inputs=response_schema.inputs,
+                error=run_schemas.ContainerRunError(
+                    type=run_schemas.ContainerRunErrorType.pipeline_error,
+                    message=repr(e.exception),
+                    traceback=e.traceback,
+                ),
+            )
+        except Exception as e:
+            logger.exception("Unexpected error during run streaming")
+            status_code = 500
+            new_response_schema = run_schemas.ContainerRunResult(
+                outputs=None,
+                inputs=response_schema.inputs,
+                error=run_schemas.ContainerRunError(
+                    type=run_schemas.ContainerRunErrorType.unknown,
+                    message=repr(e),
+                    traceback=None,
+                ),
+            )
 
         # serialise response to str and add newline separator
-        yield new_response_schema.json() + "\n"
+        yield f"{new_response_schema.json()}\n", status_code
 
-        # if request is disconnected terminate all iterators
-        if await request.is_disconnected():
+        # if there was an error or request is disconnected terminate all iterators
+        if new_response_schema.error or await request.is_disconnected():
             for output in outputs:
                 if (
                     output.type == run_schemas.RunIOType.stream
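With these changes, `_stream_run_outputs` yields `(json_line, status_code)` tuples: each line is a serialised `ContainerRunResult`, pipeline errors still stream out with status 200 (the error travels in the body), and only unexpected failures set 500 — which can take effect only if they occur before the first chunk. A rough client-side sketch; the host, port, path and payload are assumptions, not taken from this diff:

```python
import json

import httpx

payload = {"inputs": []}  # placeholder request body

with httpx.stream(
    "POST", "http://localhost:14300/v4/runs/stream", json=payload
) as response:
    # the status code comes from the first tuple the server yielded
    print(response.status_code)
    for line in response.iter_lines():
        result = json.loads(line)
        if result.get("error"):
            print("run failed:", result["error"]["message"])
            break  # the server stops streaming after an error
        print("outputs:", result["outputs"])
```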
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "pipeline-ai"
-version = "2.1.0"
+version = "2.1.1"
 description = "Pipelines for machine learning workloads."
 authors = [
     "Paul Hetherington <ph@mystic.ai>",
47 changes: 47 additions & 0 deletions tests/cloud/test_http.py
@@ -0,0 +1,47 @@
+import pytest
+from fastapi import FastAPI
+from fastapi.testclient import TestClient
+
+from pipeline.cloud.http import StreamingResponseWithStatusCode
+
+
+@pytest.fixture
+def dummy_app():
+    app = FastAPI()
+
+    @app.get("/stream/{status_code}")
+    def stream(status_code: int):
+        """Dummy endpoint to return a streaming response with status code.
+        This mimics behaviour when proxying a streaming response, where the
+        upstream status code is unknown ahead of time (i.e. we can't just set
+        the status code on the response itself as it comes from the content
+        stream).
+        """
+        content_stream = iter(
+            [
+                ("Hello ", status_code),
+                ("World", status_code),
+            ]
+        )
+        return StreamingResponseWithStatusCode(
+            content=content_stream,
+            headers={"X-Accel-Buffering": "no"},
+        )
+
+    return app
+
+
+@pytest.mark.parametrize("status_code", [200, 404, 500])
+def test_streaming_response_with_status_code(status_code, dummy_app):
+    """Test custom response class by setting up a dummy API"""
+    client = TestClient(dummy_app)
+    with client.stream("GET", f"/stream/{status_code}") as response:
+        assert response.status_code == status_code
+        # check headers are sent correctly
+        assert response.headers["X-Accel-Buffering"] == "no"
+        # check full response content
+        response_data = ""
+        for chunk in response.iter_text():
+            response_data += chunk
+        assert response_data == "Hello World"
57 changes: 56 additions & 1 deletion tests/container/routes/test_runs.py
@@ -40,7 +40,8 @@ async def test_stream_run_outputs():
     ]
 
     output_values = []
-    for result in results:
+    for result, status_code in results:
+        assert status_code == 200
         outputs = json.loads(result)["outputs"]
         values = [o["value"] for o in outputs]
         output_values.append(values)
@@ -67,3 +68,57 @@ async def test_stream_run_outputs():
             None,
         ],
     ]
+
+
+async def test_stream_run_outputs_when_exception_raised():
+    """Test streaming outputs when pipeline raises an exception.
+    Error should be reported back to the user.
+    """
+
+    def error_stream():
+        yield 1
+        raise Exception("dummy error")
+
+    stream_output_one = run_schemas.RunOutput(
+        type=run_schemas.RunIOType.stream, value=Stream(error_stream()), file=None
+    )
+    stream_output_two = run_schemas.RunOutput(
+        type=run_schemas.RunIOType.stream,
+        value=Stream(iter(["hello", "world"])),
+        file=None,
+    )
+    static_output = run_schemas.RunOutput(
+        type=run_schemas.RunIOType.string, value="static output", file=None
+    )
+    container_run_result = run_schemas.ContainerRunResult(
+        inputs=None,
+        outputs=[stream_output_one, static_output, stream_output_two],
+        error=None,
+    )
+
+    results = [
+        (result, status_code)
+        async for result, status_code in _stream_run_outputs(
+            container_run_result, DummyRequest()
+        )
+    ]
+    data = [json.loads(result) for result, _ in results]
+    status_codes = [status_code for _, status_code in results]
+    # even if pipeline_error, status code should be 200
+    assert all(status_code == 200 for status_code in status_codes)
+
+    # exception was raised on 2nd iteration, so we expect there to be a valid
+    # output followed by an error
+    assert len(results) == 2
+
+    assert data[0]["outputs"] == [
+        {"type": "integer", "value": 1, "file": None},
+        {"type": "string", "value": "static output", "file": None},
+        {"type": "string", "value": "hello", "file": None},
+    ]
+
+    error = data[1].get("error")
+    assert error is not None
+    assert error["message"] == "Exception('dummy error')"
+    assert error["type"] == "pipeline_error"