Skip to content

Commit

Permalink
[RemoteWorkflows] Fail fast when remote runner is failing (#5469)
Browse files Browse the repository at this point in the history
  • Loading branch information
liranbg committed Apr 30, 2024
1 parent 412dc91 commit c634bd8
Show file tree
Hide file tree
Showing 4 changed files with 84 additions and 5 deletions.
18 changes: 18 additions & 0 deletions mlrun/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -624,6 +624,11 @@ def iteration(self):
def iteration(self, iteration):
self._iteration = iteration

def is_workflow_runner(self):
if not self.labels:
return False
return self.labels.get("job-type", "") == "workflow-runner"


class HyperParamStrategies:
grid = "grid"
Expand Down Expand Up @@ -1068,6 +1073,19 @@ def __init__(
self.reason = reason
self.notifications = notifications or {}

def is_failed(self) -> Optional[bool]:
"""
This method returns whether a run has failed.
Returns none if state has yet to be defined. callee is responsible for handling None.
(e.g wait for state to be defined)
"""
if not self.state:
return None
return self.state.casefold() in [
mlrun.run.RunStatuses.failed.casefold(),
mlrun.run.RunStatuses.error.casefold(),
]


class RunTemplate(ModelObj):
"""Run template"""
Expand Down
27 changes: 22 additions & 5 deletions mlrun/projects/pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.
import abc
import builtins
import http
import importlib.util as imputil
import os
import tempfile
Expand Down Expand Up @@ -877,17 +878,33 @@ def run(
get_workflow_id_timeout=get_workflow_id_timeout,
)

def _get_workflow_id_or_bail():
try:
return run_db.get_workflow_id(
project=project.name,
name=workflow_response.name,
run_id=workflow_response.run_id,
engine=workflow_spec.engine,
)
except mlrun.errors.MLRunHTTPStatusError as get_wf_exc:
# fail fast on specific errors
if get_wf_exc.error_status_code in [
http.HTTPStatus.PRECONDITION_FAILED
]:
raise mlrun.errors.MLRunFatalFailureError(
original_exception=get_wf_exc
)

# raise for a retry (on other errors)
raise

# Getting workflow id from run:
response = retry_until_successful(
1,
get_workflow_id_timeout,
logger,
False,
run_db.get_workflow_id,
project=project.name,
name=workflow_response.name,
run_id=workflow_response.run_id,
engine=workflow_spec.engine,
_get_workflow_id_or_bail,
)
workflow_id = response.workflow_id
# After fetching the workflow_id the workflow executed successfully
Expand Down
13 changes: 13 additions & 0 deletions server/api/crud/workflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,19 @@ def get_workflow_id(

if workflow_id is None:
if (
run_object.metadata.is_workflow_runner()
and run_object.status.is_failed()
):
state = run_object.status.state
state_text = run_object.status.error
workflow_name = run_object.spec.parameters.get(
"workflow_name", "<unknown>"
)
raise mlrun.errors.MLRunPreconditionFailedError(
f"Failed to run workflow {workflow_name}, state: {state}, state_text: {state_text}"
)

elif (
engine == "local"
and state.casefold() == mlrun.run.RunStatuses.running.casefold()
):
Expand Down
31 changes: 31 additions & 0 deletions tests/api/api/test_workflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,37 @@ def test_bad_schedule_format(db: Session, client: TestClient):
assert resp.status_code == HTTPStatus.BAD_REQUEST


def test_get_workflow_fail_fast(db: Session, client: TestClient):
_create_proj_with_workflow(client)

right_id = "".join(random.choices("0123456789abcdef", k=40))
data = {
"metadata": {
"name": "run-name",
"labels": {
"job-type": "workflow-runner",
},
},
"spec": {
"parameters": {"workflow_name": "main"},
},
"status": {
"state": "failed",
"error": "some dummy error",
# workflow id is empty to simulate a failed remote runner
"results": {"workflow_id": None},
},
}
server.api.crud.Runs().store_run(db, data, right_id, project=PROJECT_NAME)
resp = client.get(
f"projects/{PROJECT_NAME}/workflows/{WORKFLOW_NAME}/runs/{right_id}"
)

# remote runner has failed, so the run should be failed
assert resp.status_code == HTTPStatus.PRECONDITION_FAILED
assert "some dummy error" in resp.json()["detail"]


def test_get_workflow_bad_id(db: Session, client: TestClient):
_create_proj_with_workflow(client)

Expand Down

0 comments on commit c634bd8

Please sign in to comment.