Skip to content

Commit

Permalink
Publish current run from v4/container/state endpoint (#476)
Browse files Browse the repository at this point in the history
* Update dependencies

* Use async test client

* Fix logging typo

* Refactor execution handler

* Add run id to container state

* Slight rename

* Bump version
  • Loading branch information
rossgray authored Jul 30, 2024
1 parent fbf7f33 commit 32382ba
Show file tree
Hide file tree
Showing 13 changed files with 310 additions and 248 deletions.
3 changes: 2 additions & 1 deletion pipeline/cloud/schemas/pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,8 @@ class PipelineState(str, Enum):

class PipelineContainerState(BaseModel):
state: PipelineState
message: t.Optional[str]
message: str | None
current_run: str | None


class PipelineScalingInfo(BaseModel):
Expand Down
36 changes: 20 additions & 16 deletions pipeline/container/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,21 @@ def _get_url_or_path(input_schema: run_schemas.RunInput) -> str | None:


class Manager:
def __init__(self, pipeline_path: str):
self.pipeline_state: pipeline_schemas.PipelineState = (
pipeline_schemas.PipelineState.not_loaded
)
self.pipeline_state_message: str | None = None
self.current_run: str | None = None
try:
self._load(pipeline_path)
except Exception:
tb = traceback.format_exc()
logger.exception("Exception raised when loading pipeline")
self.pipeline_state = pipeline_schemas.PipelineState.load_failed
self.pipeline_state_message = tb
return

def _load(self, pipeline_path: str):
with logger.contextualize(pipeline_stage="loading"):
logger.info("Loading pipeline")
Expand Down Expand Up @@ -74,20 +89,6 @@ def _load(self, pipeline_path: str):

logger.info(f"Pipeline set to {self.pipeline_path}")

    def __init__(self, pipeline_path: str):
        """Create the manager and eagerly load the pipeline from *pipeline_path*.

        Loading failures are captured rather than propagated: the manager is
        left in the ``load_failed`` state with the traceback stored in
        ``pipeline_state_message`` so callers can surface the error.
        """
        # Begin in the not_loaded state until _load succeeds.
        self.pipeline_state: pipeline_schemas.PipelineState = (
            pipeline_schemas.PipelineState.not_loaded
        )
        self.pipeline_state_message: str | None = None
        try:
            self._load(pipeline_path)
        except Exception:
            # Record the full traceback so the failure reason is reportable.
            tb = traceback.format_exc()
            logger.exception("Exception raised when loading pipeline")
            self.pipeline_state = pipeline_schemas.PipelineState.load_failed
            self.pipeline_state_message = tb
            return

def startup(self):
if self.pipeline_state == pipeline_schemas.PipelineState.load_failed:
return
Expand All @@ -98,7 +99,7 @@ def startup(self):
self.pipeline._startup()
except Exception:
tb = traceback.format_exc()
logger.exception("Exception raised during pipeline execution")
logger.exception("Exception raised during pipeline startup")
self.pipeline_state = pipeline_schemas.PipelineState.startup_failed
self.pipeline_state_message = tb
else:
Expand Down Expand Up @@ -242,12 +243,15 @@ def run(
) -> t.Any:
with logger.contextualize(run_id=run_id):
logger.info("Running pipeline")
args = self._parse_inputs(input_data, self.pipeline)
self.current_run = run_id
try:
args = self._parse_inputs(input_data, self.pipeline)
result = self.pipeline.run(*args)
except RunInputException:
raise
except Exception as exc:
raise RunnableError(exception=exc, traceback=traceback.format_exc())
finally:
self.current_run = None
logger.info("Run successful")
return result
3 changes: 2 additions & 1 deletion pipeline/container/routes/v4/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
},
)
async def is_ready(request: Request, response: Response):
run_manager = request.app.state.manager
run_manager: Manager = request.app.state.manager
if run_manager.pipeline_state in [
pipeline_schemas.PipelineState.loading,
pipeline_schemas.PipelineState.not_loaded,
Expand All @@ -42,6 +42,7 @@ async def is_ready(request: Request, response: Response):
return pipeline_schemas.PipelineContainerState(
state=run_manager.pipeline_state,
message=run_manager.pipeline_state_message,
current_run=run_manager.current_run,
)


Expand Down
146 changes: 5 additions & 141 deletions pipeline/container/routes/v4/runs.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,5 @@
import asyncio
import io
import os
import shutil
import traceback
import uuid
from pathlib import Path

from fastapi import APIRouter, Request, Response
from loguru import logger
Expand All @@ -13,8 +8,7 @@
from pipeline.cloud.schemas import pipelines as pipeline_schemas
from pipeline.cloud.schemas import runs as run_schemas
from pipeline.container.manager import Manager
from pipeline.exceptions import RunInputException, RunnableError
from pipeline.objects.graph import File
from pipeline.exceptions import RunnableError

router = APIRouter(prefix="/runs", tags=["runs"])

Expand Down Expand Up @@ -45,12 +39,11 @@ async def run(
return result

execution_queue: asyncio.Queue = request.app.state.execution_queue

response_queue: asyncio.Queue = asyncio.Queue()
execution_queue.put_nowait((run_create, response_queue))
run_output = await response_queue.get()

response_schema, response.status_code = await response_queue.get()
logger.info("Returning run result")
response_schema, response.status_code = _generate_run_result(run_output)
return response_schema


Expand Down Expand Up @@ -83,9 +76,8 @@ async def stream_run(

response_queue: asyncio.Queue = asyncio.Queue()
execution_queue.put_nowait((run_create, response_queue))
run_output = await response_queue.get()

response_schema, response.status_code = _generate_run_result(run_output)
# wait for result
response_schema, response.status_code = await response_queue.get()

outputs = response_schema.outputs or []
if not outputs:
Expand Down Expand Up @@ -239,131 +231,3 @@ def _handle_pipeline_state_not_ready(
traceback=manager.pipeline_state_message,
),
)


def _generate_run_result(run_output) -> tuple[run_schemas.ContainerRunResult, int]:
    """Translate a raw execution result into a response schema + HTTP status.

    Error mapping:
      * RunInputException -> input_error, 400 (caller sent bad inputs)
      * RunnableError     -> pipeline_error, 200 (the run itself executed)
      * other Exception   -> unknown, 500
      * anything else     -> parsed outputs, 200
    """
    if isinstance(run_output, RunInputException):
        details = (
            run_schemas.ContainerRunErrorType.input_error,
            run_output.message,
            None,
            400,
        )
    elif isinstance(run_output, RunnableError):
        details = (
            run_schemas.ContainerRunErrorType.pipeline_error,
            repr(run_output.exception),
            run_output.traceback,
            200,
        )
    elif isinstance(run_output, Exception):
        details = (
            run_schemas.ContainerRunErrorType.unknown,
            str(run_output),
            None,
            500,
        )
    else:
        # Success path: serialise the pipeline outputs (files saved to disk).
        return (
            run_schemas.ContainerRunResult(
                outputs=_parse_run_outputs(run_output),
                error=None,
                inputs=None,
            ),
            200,
        )

    # Single construction site for all error responses.
    error_type, message, tb, status_code = details
    return (
        run_schemas.ContainerRunResult(
            outputs=None,
            inputs=None,
            error=run_schemas.ContainerRunError(
                type=error_type,
                message=message,
                traceback=tb,
            ),
        ),
        status_code,
    )


def _parse_run_outputs(run_outputs):
    """Convert raw pipeline outputs into RunOutput schemas.

    File-like outputs (single files, or pkl-typed lists made entirely of
    files) are persisted to disk via _save_run_file and referenced by
    schema; everything else is passed through by value.
    """
    parsed = []
    for item in run_outputs:
        item_type = run_schemas.RunIOType.from_object(item)

        # Single file output: save it and reference by file schema.
        if item_type == run_schemas.RunIOType.file:
            parsed.append(
                run_schemas.RunOutput(
                    type=item_type,
                    value=None,
                    file=_save_run_file(item),
                )
            )
            continue

        is_file_list = (
            item_type == run_schemas.RunIOType.pkl
            and isinstance(item, list)
            and all(isinstance(f, (File, io.BufferedIOBase)) for f in item)
        )
        if is_file_list:
            # List of files: save each one, wrap them all in an array output.
            saved_files = [
                run_schemas.RunOutput(
                    type=run_schemas.RunIOType.file,
                    value=None,
                    file=_save_run_file(f),
                )
                for f in item
            ]
            parsed.append(
                run_schemas.RunOutput(
                    type=run_schemas.RunIOType.array,
                    value=saved_files,
                    file=None,
                )
            )
        else:
            # Plain value: pass through unchanged.
            parsed.append(
                run_schemas.RunOutput(type=item_type, value=item, file=None)
            )
    return parsed


def _save_run_file(file: File | io.BufferedIOBase) -> run_schemas.RunOutputFile:
    """Persist a run output file under a unique /tmp/run_files directory.

    Returns a RunOutputFile schema describing the saved file (name, path and
    size in bytes; size is -1 when it cannot be determined).
    """
    # ensure we save the file somewhere unique on the file system so it doesn't
    # get overwritten by another run
    uuid_path = str(uuid.uuid4())
    output_path = Path(f"/tmp/run_files/{uuid_path}")
    output_path.mkdir(parents=True, exist_ok=True)

    if isinstance(file, File):
        # A File object always carries a concrete filesystem path.
        assert file.path
        file_name = file.path.name
        output_file = output_path / file_name
        file_size = os.stat(file.path).st_size
        # copy file to new output location
        shutil.copyfile(file.path, output_file)
    else:
        # Raw binary streams may not expose a name; fall back to a UUID.
        file_name = getattr(
            file,
            "name",
            str(uuid.uuid4()),
        )
        try:
            # Seeking to the end yields the stream size in bytes.
            file_size = file.seek(0, os.SEEK_END)
            # BUG FIX: rewind to the start, otherwise the read() below runs
            # at EOF and writes an empty file.
            file.seek(0)
        except Exception:
            file_size = -1
            logger.warning(f"Could not get size of type {type(file)}")
        # write file to new output location
        output_file = output_path / file_name
        output_file.write_bytes(file.read())

    file_schema = run_schemas.RunOutputFile(
        name=file_name, path=str(output_file), size=file_size, url=None
    )
    return file_schema
Empty file.
Loading

0 comments on commit 32382ba

Please sign in to comment.