Skip to content

Commit

Permalink
[Pipelines] Support watch remote:local engine (#5593)
Browse files Browse the repository at this point in the history
  • Loading branch information
alonmr committed May 28, 2024
1 parent bbd0650 commit 1f6067d
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 23 deletions.
22 changes: 22 additions & 0 deletions mlrun/common/runtimes/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
import enum
import typing

import mlrun_pipelines.common.models

import mlrun.common.constants as mlrun_constants


Expand Down Expand Up @@ -193,6 +195,26 @@ def not_allowed_for_deletion_states():
# TODO: add aborting state once we have it
]

@staticmethod
def run_state_to_pipeline_run_status(run_state: str):
    """Translate an MLRun run state into the matching pipeline run status.

    :param run_state: one of the values listed by ``RunStates.all()``; an
                      empty or ``None`` value is accepted and treated as
                      "state not known yet".
    :returns: the corresponding ``mlrun_pipelines.common.models.RunStatuses``
              member.
    :raises ValueError: if ``run_state`` is non-empty but is not one of the
                        values in ``RunStates.all()``.
    """
    # Guard: no state reported at all -> report an unspecified pipeline status.
    if not run_state:
        return mlrun_pipelines.common.models.RunStatuses.runtime_state_unspecified

    # Guard: reject anything that is not a recognised run state before lookup.
    if run_state not in RunStates.all():
        raise ValueError(f"Invalid run state: {run_state}")

    pipeline_statuses = mlrun_pipelines.common.models.RunStatuses

    # Note that both ``created`` and ``pending`` collapse into the single
    # pipeline-level ``pending`` status.
    status_by_run_state = {
        RunStates.completed: pipeline_statuses.succeeded,
        RunStates.error: pipeline_statuses.failed,
        RunStates.running: pipeline_statuses.running,
        RunStates.created: pipeline_statuses.pending,
        RunStates.pending: pipeline_statuses.pending,
        RunStates.unknown: pipeline_statuses.runtime_state_unspecified,
        RunStates.aborted: pipeline_statuses.canceled,
        RunStates.aborting: pipeline_statuses.canceling,
        RunStates.skipped: pipeline_statuses.skipped,
    }
    return status_by_run_state[run_state]


# TODO: remove this class in 1.9.0 - use only MlrunInternalLabels
class RunLabels(enum.Enum):
Expand Down
61 changes: 42 additions & 19 deletions mlrun/projects/pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from mlrun_pipelines.helpers import new_pipe_metadata

import mlrun
import mlrun.common.runtimes.constants
import mlrun.common.schemas
import mlrun.utils.notifications
from mlrun.errors import err_to_str
Expand Down Expand Up @@ -371,7 +372,7 @@ def __init__(
engine: type["_PipelineRunner"],
project: "mlrun.projects.MlrunProject",
workflow: WorkflowSpec = None,
state: str = "",
state: mlrun_pipelines.common.models.RunStatuses = "",
exc: Exception = None,
):
"""
Expand Down Expand Up @@ -479,6 +480,7 @@ def get_run_status(
timeout=None,
expected_statuses=None,
notifiers: mlrun.utils.notifications.CustomNotificationPusher = None,
**kwargs,
):
pass

Expand Down Expand Up @@ -610,6 +612,7 @@ def get_run_status(
timeout=None,
expected_statuses=None,
notifiers: mlrun.utils.notifications.CustomNotificationPusher = None,
**kwargs,
):
if timeout is None:
timeout = 60 * 60
Expand Down Expand Up @@ -733,6 +736,7 @@ def get_run_status(
timeout=None,
expected_statuses=None,
notifiers: mlrun.utils.notifications.CustomNotificationPusher = None,
**kwargs,
):
pass

Expand Down Expand Up @@ -860,7 +864,7 @@ def _get_workflow_id_or_bail():
)
state = mlrun_pipelines.common.models.RunStatuses.failed
else:
state = mlrun_pipelines.common.models.RunStatuses.succeeded
state = mlrun_pipelines.common.models.RunStatuses.running
project.notifiers.push_pipeline_start_message(
project.metadata.name,
)
Expand All @@ -877,28 +881,47 @@ def _get_workflow_id_or_bail():
@staticmethod
def get_run_status(
project,
run,
run: _PipelineRunStatus,
timeout=None,
expected_statuses=None,
notifiers: mlrun.utils.notifications.CustomNotificationPusher = None,
inner_engine: type[_PipelineRunner] = None,
):
# ignore notifiers for remote notifications, as they are handled by the remote pipeline notifications,
# so overriding with CustomNotificationPusher with empty list of notifiers or only local notifiers
local_project_notifiers = list(
set(mlrun.utils.notifications.NotificationTypes.local()).intersection(
set(project.notifiers.notifications.keys())
inner_engine = inner_engine or _KFPRunner
if inner_engine.engine == _KFPRunner.engine:
# ignore notifiers for remote notifications, as they are handled by the remote pipeline notifications,
# so overriding with CustomNotificationPusher with empty list of notifiers or only local notifiers
local_project_notifiers = list(
set(mlrun.utils.notifications.NotificationTypes.local()).intersection(
set(project.notifiers.notifications.keys())
)
)
notifiers = mlrun.utils.notifications.CustomNotificationPusher(
local_project_notifiers
)
return _KFPRunner.get_run_status(
project,
run,
timeout,
expected_statuses,
notifiers=notifiers,
)

elif inner_engine.engine == _LocalRunner.engine:
mldb = mlrun.db.get_run_db(secrets=project._secrets)
pipeline_runner_run = mldb.read_run(run.run_id, project=project.name)
pipeline_runner_run = mlrun.run.RunObject.from_dict(pipeline_runner_run)
pipeline_runner_run.logs(db=mldb)
pipeline_runner_run.refresh()
run._state = mlrun.common.runtimes.constants.RunStates.run_state_to_pipeline_run_status(
pipeline_runner_run.status.state
)
run._exc = pipeline_runner_run.status.error

else:
raise mlrun.errors.MLRunInvalidArgumentError(
f"Unsupported inner runner engine: {inner_engine.engine}"
)
)
notifiers = mlrun.utils.notifications.CustomNotificationPusher(
local_project_notifiers
)
return _KFPRunner.get_run_status(
project,
run,
timeout,
expected_statuses,
notifiers=notifiers,
)


def create_pipeline(project, pipeline, functions, secrets=None, handler=None):
Expand Down
9 changes: 6 additions & 3 deletions mlrun/projects/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -2992,14 +2992,17 @@ def run(
)
workflow_spec.clear_tmp()
if (timeout or watch) and not workflow_spec.schedule:
run_status_kwargs = {}
status_engine = run._engine
# run's engine gets replaced with inner engine if engine is remote,
# so in that case we need to get the status from the remote engine manually
# TODO: support watch for remote:local
if workflow_engine.engine == "remote" and status_engine.engine != "local":
if workflow_engine.engine == "remote":
status_engine = _RemoteRunner
run_status_kwargs["inner_engine"] = run._engine

status_engine.get_run_status(project=self, run=run, timeout=timeout)
status_engine.get_run_status(
project=self, run=run, timeout=timeout, **run_status_kwargs
)
return run

def save_workflow(self, name, target, artifact_path=None, ttl=None):
Expand Down
2 changes: 1 addition & 1 deletion server/api/crud/workflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,7 @@ def get_workflow_id(
and state.casefold()
== mlrun_pipelines.common.models.RunStatuses.running.casefold()
):
workflow_id = ""
workflow_id = run_object.metadata.uid
else:
raise mlrun.errors.MLRunNotFoundError(
f"Workflow id of run {project}:{uid} not found"
Expand Down

0 comments on commit 1f6067d

Please sign in to comment.