Skip to content

Commit

Permalink
Support async runs (#478)
Browse files Browse the repository at this point in the history
* Implement async runs

* Bump version

* Update default logging level

* Update logging

* Fix issue with current run reporting

* upload async result in asyncio task

* Revert
  • Loading branch information
rossgray authored Aug 1, 2024
1 parent f680b7e commit e545efb
Show file tree
Hide file tree
Showing 6 changed files with 249 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ RUN pip install --no-cache-dir -U pip setuptools wheel

# Install serving packages
RUN pip install --no-cache-dir -U fastapi==0.105.0 uvicorn==0.25.0 \
python-multipart==0.0.6 loguru==0.7.2
python-multipart==0.0.6 loguru==0.7.2 httpx==0.27.0

# Container commands
{container_commands}
Expand Down
11 changes: 3 additions & 8 deletions pipeline/container/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,15 +99,10 @@ def setup_logging():
logger.remove()

use_json_logging = os.environ.get("USE_JSON_LOGGING", False)
log_level = os.environ.get("LOG_LEVEL", "INFO")
if use_json_logging:
handler = dict(
sink=json_log_handler,
colorize=False,
)
handler = dict(sink=json_log_handler, colorize=False, level=log_level)
else:
handler = dict(
sink=default_log_handler,
colorize=True,
)
handler = dict(sink=default_log_handler, colorize=True, level=log_level)
logger.configure(handlers=[handler])
logging.basicConfig(handlers=[InterceptHandler()], level=0, force=True)
40 changes: 39 additions & 1 deletion pipeline/container/routes/v4/runs.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import asyncio
import traceback

from fastapi import APIRouter, Request, Response
from fastapi import APIRouter, Request, Response, status
from loguru import logger

from pipeline.cloud.http import StreamingResponseWithStatusCode
Expand All @@ -25,20 +25,58 @@
"description": "Invalid input data",
"model": run_schemas.ContainerRunResult,
},
202: {
"description": "Async run initiated",
"model": run_schemas.ContainerRunResult,
},
},
)
async def run(
run_create: run_schemas.ContainerRunCreate,
request: Request,
response: Response,
) -> run_schemas.ContainerRunResult:
"""Run this pipeline with the given inputs.
If `async_run=True` then this route will return an empty result immediately,
then make a POST call to the provided `callback_url` once the run is
complete.
"""
run_id = run_create.run_id
with logger.contextualize(run_id=run_id):
logger.info(f"Received run request; async_run={run_create.async_run}")
manager: Manager = request.app.state.manager
if result := _handle_pipeline_state_not_ready(manager):
return result

execution_queue: asyncio.Queue = request.app.state.execution_queue
# If async run, we put run on the queue and return immediately
if run_create.async_run is True:
# check request is valid
if not run_create.callback_url:
response.status_code = status.HTTP_400_BAD_REQUEST
return run_schemas.ContainerRunResult(
outputs=None,
inputs=None,
error=run_schemas.ContainerRunError(
type=run_schemas.ContainerRunErrorType.input_error,
message="callback_url is required for async runs",
traceback=None,
),
)

execution_queue.put_nowait((run_create, None))
# return empty result for now with a status code of 202 to indicate
# we have accepted the request and are processing it in the
# background
response.status_code = status.HTTP_202_ACCEPTED
return run_schemas.ContainerRunResult(
outputs=None,
error=None,
inputs=None,
)
# Otherwise, we put run on the queue then wait for the run to finish and
# return the result
response_queue: asyncio.Queue = asyncio.Queue()
execution_queue.put_nowait((run_create, response_queue))

Expand Down
37 changes: 36 additions & 1 deletion pipeline/container/services/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import uuid
from pathlib import Path

import httpx
from fastapi.concurrency import run_in_threadpool
from loguru import logger

Expand Down Expand Up @@ -34,11 +35,45 @@ async def execution_handler(execution_queue: asyncio.Queue, manager: Manager) ->
output = e

response_schema, status_code = _generate_run_result(output)
response_queue.put_nowait((response_schema, status_code))

if args.async_run is True:
# send response back to callback URL
assert args.callback_url is not None
# send result in an async task so it runs in parallel
# and we are free to process the next run
asyncio.create_task(
_send_async_result(
callback_url=args.callback_url,
response_schema=response_schema,
)
)
else:
response_queue.put_nowait((response_schema, status_code))
except Exception:
logger.exception("Got an error in the execution loop handler")


async def _send_async_result(
    callback_url: str, response_schema: run_schemas.ContainerRunResult
):
    """POST the result of an async run to the caller-supplied callback URL.

    This coroutine is scheduled with ``asyncio.create_task`` by the execution
    handler and is never awaited, so an exception raised here cannot be
    handled by anyone else. Every failure is therefore logged and swallowed
    rather than propagated.

    Args:
        callback_url: URL the result is POSTed to.
        response_schema: Run result; ``response_schema.dict()`` is sent as the
            JSON request body.
    """
    logger.info("Sending async result...")
    async with httpx.AsyncClient() as client:
        try:
            response = await client.post(
                callback_url, json=response_schema.dict(), timeout=10
            )
            response.raise_for_status()
        except httpx.HTTPStatusError as exc:
            # The callback endpoint answered, but with a 4xx/5xx status.
            logger.error(
                f"Error sending async result: "
                f"{exc.response.status_code} - {exc.response.text}"
            )
        except httpx.RequestError as exc:
            # Transport-level failure: connection error, timeout, etc.
            logger.exception(f"Error sending async result: {exc}")
        except Exception as exc:
            # Anything else (e.g. httpx.InvalidURL for a malformed
            # callback_url, or a TypeError while JSON-encoding the body) is
            # not a RequestError. Without this handler it would escape the
            # background task and only show up as "Task exception was never
            # retrieved". Cancellation is a BaseException and still
            # propagates.
            logger.exception(f"Unexpected error sending async result: {exc}")
        else:
            logger.info("Sending async result successful")


def _generate_run_result(run_output) -> tuple[run_schemas.ContainerRunResult, int]:
if isinstance(run_output, RunInputException):
return (
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "pipeline-ai"
version = "2.5.1"
version = "2.6.0"
description = "Pipelines for machine learning workloads."
authors = [
"Paul Hetherington <ph@mystic.ai>",
Expand Down
181 changes: 169 additions & 12 deletions tests/container/routes/test_runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,9 @@ async def test_when_pipeline_failed_to_load(self, client_failed_pipeline):
result = run_schemas.ContainerRunResult.parse_obj(response.json())
assert result.outputs is None
assert result.error is not None
error = run_schemas.ContainerRunError.parse_obj(result.error)
assert error.type == run_schemas.ContainerRunErrorType.startup_error
assert error.message == "Pipeline failed to load"
assert error.traceback is not None
assert result.error.type == run_schemas.ContainerRunErrorType.startup_error
assert result.error.message == "Pipeline failed to load"
assert result.error.traceback is not None

async def test_when_invalid_inputs(self, client):
payload = run_schemas.ContainerRunCreate(
Expand All @@ -72,9 +71,8 @@ async def test_when_invalid_inputs(self, client):
result = run_schemas.ContainerRunResult.parse_obj(response.json())
assert result.outputs is None
assert result.error is not None
error = run_schemas.ContainerRunError.parse_obj(result.error)
assert error.type == run_schemas.ContainerRunErrorType.input_error
assert error.message == "Inputs do not match graph inputs"
assert result.error.type == run_schemas.ContainerRunErrorType.input_error
assert result.error.message == "Inputs do not match graph inputs"

async def test_when_pipeline_raises_an_exception(self, client):
"""We've set up the fixture pipeline to only accept positive integers,
Expand All @@ -96,8 +94,167 @@ async def test_when_pipeline_raises_an_exception(self, client):
result = run_schemas.ContainerRunResult.parse_obj(response.json())
assert result.outputs is None
assert result.error is not None
error = run_schemas.ContainerRunError.parse_obj(result.error)
assert error.type == run_schemas.ContainerRunErrorType.pipeline_error
assert error.message == "ValueError('I can only sum positive integers')"
assert error.traceback is not None
assert error.traceback.startswith("Traceback (most recent call last):")
assert result.error.type == run_schemas.ContainerRunErrorType.pipeline_error
assert result.error.message == "ValueError('I can only sum positive integers')"
assert result.error.traceback is not None
assert result.error.traceback.startswith("Traceback (most recent call last):")


# Tests for POST /v4/runs with async_run=True. Every test patches
# pipeline.container.services.run._send_async_result, so no callback HTTP
# request is made; the tests instead inspect how that function was called.
class TestCreateAsyncRun:

async def test_success(self, client):
"""In the case of an async run, the API should respond immediately, and
then make an API call to the callback URL once the run is complete.
"""
with patch(
"pipeline.container.services.run._send_async_result"
) as mock_send_async_result:
callback_url = "https://example.com/callback"
payload = run_schemas.ContainerRunCreate(
run_id="run_123",
inputs=[
run_schemas.RunInput(type="integer", value=5),
run_schemas.RunInput(type="integer", value=4),
],
async_run=True,
callback_url=callback_url,
)
response = await client.post("/v4/runs", json=payload.dict())

# The immediate response is a 202 with neither outputs nor an error.
assert response.status_code == status.HTTP_202_ACCEPTED
result = run_schemas.ContainerRunResult.parse_obj(response.json())
assert result.error is None
assert result.outputs is None

# The real result (9 for inputs 5 and 4) is what should be handed to the
# callback sender.
# NOTE(review): this asserts on the mock right after the 202 response, so it
# relies on the queued run having been executed by then - presumably
# guaranteed by the client fixture; confirm.
expected_response_schema = run_schemas.ContainerRunResult(
inputs=None,
error=None,
outputs=[run_schemas.RunOutput(type="integer", value=9)],
)
mock_send_async_result.assert_called_once_with(
callback_url=callback_url, response_schema=expected_response_schema
)

# make a synchronous run afterwards to ensure execution handler is still
# working as expected
payload = run_schemas.ContainerRunCreate(
run_id="run_124",
inputs=[
run_schemas.RunInput(type="integer", value=5),
run_schemas.RunInput(type="integer", value=10),
],
)
response = await client.post("/v4/runs", json=payload.dict())

# Without async_run the result comes back directly in the 200 response.
assert response.status_code == status.HTTP_200_OK
result = run_schemas.ContainerRunResult.parse_obj(response.json())
assert result.error is None
assert result.outputs == [run_schemas.RunOutput(type="integer", value=15)]

async def test_when_pipeline_failed_to_load(self, client_failed_pipeline):
"""Should return an error immediately in this case"""

client = client_failed_pipeline
with patch(
"pipeline.container.services.run._send_async_result"
) as mock_send_async_result:

payload = run_schemas.ContainerRunCreate(
run_id="run_123",
inputs=[
run_schemas.RunInput(type="integer", value=5),
run_schemas.RunInput(type="integer", value=4),
],
async_run=True,
callback_url="https://example.com/callback",
)
response = await client.post("/v4/runs", json=payload.dict())

# Expected status here is 200 rather than 202: the startup error is reported
# in the response body itself.
assert response.status_code == status.HTTP_200_OK
result = run_schemas.ContainerRunResult.parse_obj(response.json())
assert result.outputs is None
assert result.error is not None
assert result.error.type == run_schemas.ContainerRunErrorType.startup_error
assert result.error.message == "Pipeline failed to load"
assert result.error.traceback is not None

# Since the error was returned directly, no callback should be attempted.
mock_send_async_result.assert_not_called()

async def test_when_invalid_inputs(self, client):
"""In the case of invalid inputs, the API will respond immediately, and
then make an API call to the callback URL with an error.
Perhaps it would be better to return the error immediately, but the
parsing of inputs is currently handled at run execution time, which
happens asynchronously.
"""
with patch(
"pipeline.container.services.run._send_async_result"
) as mock_send_async_result:
callback_url = "https://example.com/callback"
payload = run_schemas.ContainerRunCreate(
run_id="run_123",
# one input is missing
inputs=[
run_schemas.RunInput(type="integer", value=5),
],
async_run=True,
callback_url=callback_url,
)
response = await client.post("/v4/runs", json=payload.dict())

# The request is still accepted with an empty 202 body; the input error only
# appears in the payload passed to the callback sender.
assert response.status_code == status.HTTP_202_ACCEPTED
result = run_schemas.ContainerRunResult.parse_obj(response.json())
assert result.error is None
assert result.outputs is None

expected_response_schema = run_schemas.ContainerRunResult(
inputs=None,
error=run_schemas.ContainerRunError(
type=run_schemas.ContainerRunErrorType.input_error,
message="Inputs do not match graph inputs",
),
outputs=None,
)
mock_send_async_result.assert_called_once_with(
callback_url=callback_url, response_schema=expected_response_schema
)

async def test_when_pipeline_raises_an_exception(self, client):
"""We've set up the fixture pipeline to only accept positive integers,
so providing negative ones should result in a RunnableError.
(Note: in reality we could add options to our inputs to handle this and
return an input_error)
"""
with patch(
"pipeline.container.services.run._send_async_result"
) as mock_send_async_result:
callback_url = "https://example.com/callback"
payload = run_schemas.ContainerRunCreate(
run_id="run_123",
inputs=[
run_schemas.RunInput(type="integer", value=-5),
run_schemas.RunInput(type="integer", value=5),
],
async_run=True,
callback_url=callback_url,
)
response = await client.post("/v4/runs", json=payload.dict())

assert response.status_code == status.HTTP_202_ACCEPTED
result = run_schemas.ContainerRunResult.parse_obj(response.json())
assert result.error is None
assert result.outputs is None

# Inspect the recorded call field by field instead of comparing against a
# whole expected object - presumably because the full traceback text cannot
# be predicted; only its prefix is checked.
mock_calls = mock_send_async_result.call_args_list
assert len(mock_calls) == 1
assert mock_calls[0].kwargs["callback_url"] == callback_url
actual_error = mock_calls[0].kwargs["response_schema"].error
assert actual_error.type == run_schemas.ContainerRunErrorType.pipeline_error
assert (
actual_error.message == "ValueError('I can only sum positive integers')"
)
assert actual_error.traceback.startswith(
"Traceback (most recent call last):"
)

0 comments on commit e545efb

Please sign in to comment.