From 7a06c4ac9d50838d0c068e332a8c4fe97a1bd2a2 Mon Sep 17 00:00:00 2001 From: Ross Date: Wed, 20 Mar 2024 17:18:46 +0000 Subject: [PATCH] PC-1062 - Improve error handling in stream endpoint (#443) * WIP * Add error handling when exception raised in streaming pipeline * Report status codes correctly * VB --- pipeline/cloud/http.py | 45 ++++++++++ pipeline/cloud/pipelines.py | 5 +- pipeline/container/manager.py | 2 + pipeline/container/routes/v4/runs.py | 127 +++++++++++++++++---------- pyproject.toml | 2 +- tests/cloud/test_http.py | 47 ++++++++++ tests/container/routes/test_runs.py | 57 +++++++++++- 7 files changed, 238 insertions(+), 47 deletions(-) create mode 100644 tests/cloud/test_http.py diff --git a/pipeline/cloud/http.py b/pipeline/cloud/http.py index fa87d473..df9ed488 100644 --- a/pipeline/cloud/http.py +++ b/pipeline/cloud/http.py @@ -7,7 +7,9 @@ import httpx import requests +from fastapi.responses import StreamingResponse from requests_toolbelt import MultipartEncoder, MultipartEncoderMonitor +from starlette.types import Send from tqdm import tqdm from pipeline import current_configuration @@ -286,3 +288,46 @@ def stream( client = _get_client() with client.stream(method=method, url=endpoint, json=json_data) as response: yield response + + +class StreamingResponseWithStatusCode(StreamingResponse): + """ + Variation of StreamingResponse that can dynamically decide the HTTP status + code, based on the returns from the content iterator (parameter 'content'). + Expects the content to yield tuples of (content: str, status_code: int), + instead of just content as it was in the original StreamingResponse. The + parameter status_code in the constructor is ignored, but kept for + compatibility with StreamingResponse. 
+ + See + https://github.com/tiangolo/fastapi/discussions/10138#discussioncomment-8216436 + for inspiration + """ + + async def stream_response(self, send: Send) -> None: + first_chunk_content, self.status_code = await anext(self.body_iterator) + if not isinstance(first_chunk_content, bytes): + first_chunk_content = first_chunk_content.encode(self.charset) + + await send( + { + "type": "http.response.start", + "status": self.status_code, + "headers": self.raw_headers, + } + ) + await send( + { + "type": "http.response.body", + "body": first_chunk_content, + "more_body": True, + } + ) + + # ignore status code after response has started + async for chunk_content, _ in self.body_iterator: + if not isinstance(chunk_content, bytes): + chunk = chunk_content.encode(self.charset) + await send({"type": "http.response.body", "body": chunk, "more_body": True}) + + await send({"type": "http.response.body", "body": b"", "more_body": False}) diff --git a/pipeline/cloud/pipelines.py b/pipeline/cloud/pipelines.py index 3302759b..fb1bbae8 100644 --- a/pipeline/cloud/pipelines.py +++ b/pipeline/cloud/pipelines.py @@ -145,7 +145,10 @@ def _stream_pipeline( try: result = ClusterRunResult.parse_obj(result_json) except ValidationError: - _print(f"Unexpected result from streaming run:\n{result_json}") + _print( + f"Unexpected result from streaming run:\n" + f"Status code = {response.status_code}\n{result_json}" + ) return except Exception: http.raise_if_http_status_error(response) diff --git a/pipeline/container/manager.py b/pipeline/container/manager.py index bb9fb815..ee36bf2c 100644 --- a/pipeline/container/manager.py +++ b/pipeline/container/manager.py @@ -241,6 +241,8 @@ def run( args = self._parse_inputs(input_data, self.pipeline) try: result = self.pipeline.run(*args) + except RunInputException: + raise except Exception as exc: raise RunnableError(exception=exc, traceback=traceback.format_exc()) return result diff --git a/pipeline/container/routes/v4/runs.py 
b/pipeline/container/routes/v4/runs.py index 59489fc7..7450eddf 100644 --- a/pipeline/container/routes/v4/runs.py +++ b/pipeline/container/routes/v4/runs.py @@ -2,13 +2,14 @@ import io import os import shutil +import traceback import uuid from pathlib import Path from fastapi import APIRouter, Request, Response -from fastapi.responses import StreamingResponse from loguru import logger +from pipeline.cloud.http import StreamingResponseWithStatusCode from pipeline.cloud.schemas import pipelines as pipeline_schemas from pipeline.cloud.schemas import runs as run_schemas from pipeline.container.manager import Manager @@ -93,7 +94,7 @@ async def stream_run( if not any([output.type == run_schemas.RunIOType.stream for output in outputs]): raise TypeError("No streaming outputs found") - return StreamingResponse( + return StreamingResponseWithStatusCode( _stream_run_outputs(response_schema, request), media_type="application/json", # hint to disable buffering @@ -101,57 +102,95 @@ async def stream_run( ) +def _fetch_next_outputs(outputs: list[run_schemas.RunOutput]): + next_outputs = [] + have_new_streamed_outputs = False + for output in outputs: + if output.type == run_schemas.RunIOType.stream: + if output.value is None: + raise Exception("Stream value was None") + + try: + next_value = output.value.__next__() + next_outputs.append( + run_schemas.RunOutput( + type=run_schemas.RunIOType.from_object(next_value), + value=next_value, + file=None, + ) + ) + have_new_streamed_outputs = True + except StopIteration: + # if no data left for this stream, return None value + next_outputs.append( + run_schemas.RunOutput( + type=run_schemas.RunIOType.none, + value=None, + file=None, + ) + ) + except Exception as exc: + logger.exception("Pipeline error caught during run streaming") + raise RunnableError(exception=exc, traceback=traceback.format_exc()) + else: + next_outputs.append(output) + + if not have_new_streamed_outputs: + return + return next_outputs + + async def 
_stream_run_outputs( response_schema: run_schemas.ContainerRunResult, request: Request ): - outputs = response_schema.outputs or [] + """Generator returning output data for list of outputs. + We iterate over all outputs until we no longer have any streamed data to + output + """ + outputs = response_schema.outputs or [] while True: - next_outputs = [] - have_new_streamed_outputs = False - # iterate over all outputs, until we no longer have any streamed data to - # output - for output in outputs: - if output.type == run_schemas.RunIOType.stream: - try: - if output.value is None: - raise Exception("Stream value was None") - - next_value = output.value.__next__() - next_outputs.append( - run_schemas.RunOutput( - type=run_schemas.RunIOType.from_object(next_value), - value=next_value, - file=None, - ) - ) - have_new_streamed_outputs = True - except StopIteration: - # if no data left for this stream, return None value - next_outputs.append( - run_schemas.RunOutput( - type=run_schemas.RunIOType.from_object(next_value), - value=None, - file=None, - ) - ) - else: - next_outputs.append(output) - - if not have_new_streamed_outputs: - return + status_code = 200 + try: + next_outputs = _fetch_next_outputs(outputs) + if not next_outputs: + return + + new_response_schema = run_schemas.ContainerRunResult( + inputs=response_schema.inputs, + outputs=next_outputs, + error=response_schema.error, + ) - new_response_schema = run_schemas.ContainerRunResult( - inputs=response_schema.inputs, - outputs=next_outputs, - error=response_schema.error, - ) + except RunnableError as e: + # if we get a pipeline error, return a run error then finish + new_response_schema = run_schemas.ContainerRunResult( + outputs=None, + inputs=response_schema.inputs, + error=run_schemas.ContainerRunError( + type=run_schemas.ContainerRunErrorType.pipeline_error, + message=repr(e.exception), + traceback=e.traceback, + ), + ) + except Exception as e: + logger.exception("Unexpected error during run streaming") + 
status_code = 500 + new_response_schema = run_schemas.ContainerRunResult( + outputs=None, + inputs=response_schema.inputs, + error=run_schemas.ContainerRunError( + type=run_schemas.ContainerRunErrorType.unknown, + message=repr(e), + traceback=None, + ), + ) # serialise response to str and add newline separator - yield new_response_schema.json() + "\n" + yield f"{new_response_schema.json()}\n", status_code - # if request is disconnected terminate all iterators - if await request.is_disconnected(): + # if there was an error or request is disconnected terminate all iterators + if new_response_schema.error or await request.is_disconnected(): for output in outputs: if ( output.type == run_schemas.RunIOType.stream diff --git a/pyproject.toml b/pyproject.toml index 2a7e7a9a..14ca3141 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "pipeline-ai" -version = "2.1.0" +version = "2.1.1" description = "Pipelines for machine learning workloads." authors = [ "Paul Hetherington ", diff --git a/tests/cloud/test_http.py b/tests/cloud/test_http.py new file mode 100644 index 00000000..0a297eef --- /dev/null +++ b/tests/cloud/test_http.py @@ -0,0 +1,47 @@ +import pytest +from fastapi import FastAPI +from fastapi.testclient import TestClient + +from pipeline.cloud.http import StreamingResponseWithStatusCode + + +@pytest.fixture +def dummy_app(): + app = FastAPI() + + @app.get("/stream/{status_code}") + def stream(status_code: int): + """Dummy endpoint to return a streaming response with status code. + + This mimics behaviour when proxying a streaming response, where the + upstream status code is unknown ahead of time (i.e. we can't just set + the status code on the response itself as it comes from the content + stream). 
+ """ + content_stream = iter( + [ + ("Hello ", status_code), + ("World", status_code), + ] + ) + return StreamingResponseWithStatusCode( + content=content_stream, + headers={"X-Accel-Buffering": "no"}, + ) + + return app + + +@pytest.mark.parametrize("status_code", [200, 404, 500]) +def test_streaming_response_with_status_code(status_code, dummy_app): + """Test custom response class by setting up a dummy API""" + client = TestClient(dummy_app) + with client.stream("GET", f"/stream/{status_code}") as response: + assert response.status_code == status_code + # check headers are sent correctly + assert response.headers["X-Accel-Buffering"] == "no" + # check full response content + response_data = "" + for chunk in response.iter_text(): + response_data += chunk + assert response_data == "Hello World" diff --git a/tests/container/routes/test_runs.py b/tests/container/routes/test_runs.py index 90158362..c7af202b 100644 --- a/tests/container/routes/test_runs.py +++ b/tests/container/routes/test_runs.py @@ -40,7 +40,8 @@ async def test_stream_run_outputs(): ] output_values = [] - for result in results: + for result, status_code in results: + assert status_code == 200 outputs = json.loads(result)["outputs"] values = [o["value"] for o in outputs] output_values.append(values) @@ -67,3 +68,57 @@ async def test_stream_run_outputs(): None, ], ] + + +async def test_stream_run_outputs_when_exception_raised(): + """Test streaming outputs when pipeline raises an exception. + + Error should be reported back to the user. 
+ """ + + def error_stream(): + yield 1 + raise Exception("dummy error") + + stream_output_one = run_schemas.RunOutput( + type=run_schemas.RunIOType.stream, value=Stream(error_stream()), file=None + ) + stream_output_two = run_schemas.RunOutput( + type=run_schemas.RunIOType.stream, + value=Stream(iter(["hello", "world"])), + file=None, + ) + static_output = run_schemas.RunOutput( + type=run_schemas.RunIOType.string, value="static output", file=None + ) + container_run_result = run_schemas.ContainerRunResult( + inputs=None, + outputs=[stream_output_one, static_output, stream_output_two], + error=None, + ) + + results = [ + (result, status_code) + async for result, status_code in _stream_run_outputs( + container_run_result, DummyRequest() + ) + ] + data = [json.loads(result) for result, _ in results] + status_codes = [status_code for _, status_code in results] + # even if pipeline_error, status code should be 200 + assert all(status_code == 200 for status_code in status_codes) + + # exception was raised on 2nd iteration, so we expect there to be a valid + # output followed by an error + assert len(results) == 2 + + assert data[0]["outputs"] == [ + {"type": "integer", "value": 1, "file": None}, + {"type": "string", "value": "static output", "file": None}, + {"type": "string", "value": "hello", "file": None}, + ] + + error = data[1].get("error") + assert error is not None + assert error["message"] == "Exception('dummy error')" + assert error["type"] == "pipeline_error"