Skip to content

Commit

Permalink
[Runtimes] Make abort-able list of runtime kinds and expose in frontend spec (#909)
Browse files Browse the repository at this point in the history
  • Loading branch information
Hedingber committed May 2, 2021
1 parent 4ee8b27 commit cd26eea
Show file tree
Hide file tree
Showing 7 changed files with 55 additions and 13 deletions.
6 changes: 5 additions & 1 deletion mlrun/api/api/endpoints/frontend_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import mlrun.api.schemas
import mlrun.api.utils.clients.iguazio
import mlrun.runtimes

router = fastapi.APIRouter()

Expand All @@ -15,7 +16,10 @@ def get_frontend_spec(session: typing.Optional[str] = fastapi.Cookie(None)):
jobs_dashboard_url = None
if session:
jobs_dashboard_url = _resolve_jobs_dashboard_url(session)
return mlrun.api.schemas.FrontendSpec(jobs_dashboard_url=jobs_dashboard_url)
return mlrun.api.schemas.FrontendSpec(
jobs_dashboard_url=jobs_dashboard_url,
abortable_function_kinds=mlrun.runtimes.RuntimeKinds.abortable_runtimes(),
)


def _resolve_jobs_dashboard_url(session: str) -> typing.Optional[str]:
Expand Down
8 changes: 3 additions & 5 deletions mlrun/api/crud/runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,10 @@ def update_run(
raise mlrun.errors.MLRunConflictError(
"Run is already in terminal state, can not be aborted"
)
if (
current_run.get("metadata", {}).get("labels", {}).get("kind")
== mlrun.runtimes.RuntimeKinds.dask
):
runtime_kind = current_run.get("metadata", {}).get("labels", {}).get("kind")
if runtime_kind not in mlrun.runtimes.RuntimeKinds.abortable_runtimes():
raise mlrun.errors.MLRunBadRequestError(
"Run of a dask function can not be aborted"
f"Run of kind {runtime_kind} can not be aborted"
)
# aborting the run meaning deleting its runtime resources
# TODO: runtimes crud interface should ideally expose some better API that will hold inside itself the
Expand Down
1 change: 1 addition & 0 deletions mlrun/api/schemas/frontend_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@

# Response schema of the frontend-spec endpoint: the values the API hands to
# the UI so it can adapt itself to this deployment.
class FrontendSpec(pydantic.BaseModel):
    # URL of the jobs dashboard. The endpoint only resolves it when a session
    # cookie is present, otherwise it passes None.
    # NOTE(review): no explicit default is given; this presumably relies on
    # pydantic treating a bare Optional field as defaulting to None - confirm
    # against the pydantic version in use.
    jobs_dashboard_url: typing.Optional[str]
    # Runtime kinds whose runs may be aborted; the endpoint fills it from
    # mlrun.runtimes.RuntimeKinds.abortable_runtimes().
    abortable_function_kinds: typing.List[str] = []
4 changes: 1 addition & 3 deletions mlrun/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -508,7 +508,7 @@ def new_function(
elif kind in RuntimeKinds.all():
runner = get_runtime_class(kind).from_dict(runtime)
else:
supported_runtimes = ",".join(RuntimeKinds.all() + ["local"])
supported_runtimes = ",".join(RuntimeKinds.all())
raise Exception(
f"unsupported runtime ({kind}) or missing command, supported runtimes: {supported_runtimes}"
)
Expand Down Expand Up @@ -712,8 +712,6 @@ def resolve_nuclio_subkind(kind):

if kind is None or kind in ["", "Function"]:
raise ValueError("please specify the function kind")
elif kind in ["local"]:
r = LocalRuntime()
elif kind in RuntimeKinds.all():
r = get_runtime_class(kind)()
else:
Expand Down
12 changes: 12 additions & 0 deletions mlrun/runtimes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ class RuntimeKinds(object):
remotespark = "remote-spark"
mpijob = "mpijob"
serving = "serving"
local = "local"

@staticmethod
def all():
Expand All @@ -109,6 +110,7 @@ def all():
RuntimeKinds.spark,
RuntimeKinds.remotespark,
RuntimeKinds.mpijob,
RuntimeKinds.local,
]

@staticmethod
Expand All @@ -121,6 +123,15 @@ def runtime_with_handlers():
RuntimeKinds.mpijob,
]

@staticmethod
def abortable_runtimes():
    """Return the runtime kinds whose runs can be aborted.

    The run-update logic rejects an abort request for any kind that is not in
    this list, and the frontend spec endpoint exposes the same list to the UI.
    A new list is built on every call, so callers may mutate the result.
    """
    abortable_kinds = [
        RuntimeKinds.job,
        RuntimeKinds.spark,
        RuntimeKinds.remotespark,
        RuntimeKinds.mpijob,
    ]
    return abortable_kinds

@staticmethod
def nuclio_runtimes():
return [
Expand Down Expand Up @@ -179,6 +190,7 @@ def get_runtime_class(kind: str):
RuntimeKinds.serving: ServingRuntime,
RuntimeKinds.dask: DaskCluster,
RuntimeKinds.job: KubejobRuntime,
RuntimeKinds.local: LocalRuntime,
RuntimeKinds.spark: SparkRuntime,
RuntimeKinds.remotespark: RemoteSparkRuntime,
}
Expand Down
20 changes: 20 additions & 0 deletions tests/api/api/test_frontend_spec.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,37 @@
import http
import unittest.mock

import deepdiff
import fastapi.testclient
import sqlalchemy.orm

import mlrun.api.crud
import mlrun.api.schemas
import mlrun.api.utils.clients.iguazio
import mlrun.errors
import mlrun.runtimes


def test_get_frontend_spec(
    db: sqlalchemy.orm.Session, client: fastapi.testclient.TestClient
) -> None:
    """Check that the frontend spec endpoint reports the abortable runtime kinds."""
    # Stub the grafana URL lookup on the iguazio client.
    # NOTE(review): this only takes effect if Client() hands back a shared
    # instance - confirm against the client implementation.
    grafana_url_stub = unittest.mock.Mock()
    mlrun.api.utils.clients.iguazio.Client().try_get_grafana_service_url = (
        grafana_url_stub
    )

    response = client.get("/api/frontend-spec")
    assert response.status_code == http.HTTPStatus.OK.value

    # Parse the payload back through the schema so the comparison is made on
    # the typed field rather than on raw JSON.
    frontend_spec = mlrun.api.schemas.FrontendSpec(**response.json())
    expected_kinds = mlrun.runtimes.RuntimeKinds.abortable_runtimes()
    kinds_diff = deepdiff.DeepDiff(
        frontend_spec.abortable_function_kinds, expected_kinds
    )
    assert kinds_diff == {}


def test_get_frontend_spec_jobs_dashboard_url_resolution(
db: sqlalchemy.orm.Session, client: fastapi.testclient.TestClient
) -> None:
mlrun.api.utils.clients.iguazio.Client().try_get_grafana_service_url = (
unittest.mock.Mock()
Expand Down
17 changes: 13 additions & 4 deletions tests/api/api/test_runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,20 @@ def test_run_with_nan_in_body(db: Session, client: TestClient) -> None:

def test_abort_run(db: Session, client: TestClient) -> None:
project = "some-project"
run_in_progress = {"status": {"state": mlrun.runtimes.constants.RunStates.running}}
run_in_progress = {
"metadata": {"labels": {"kind": mlrun.runtimes.RuntimeKinds.job}},
"status": {"state": mlrun.runtimes.constants.RunStates.running},
}
run_in_progress_uid = "in-progress-uid"
run_completed = {"status": {"state": mlrun.runtimes.constants.RunStates.completed}}
run_completed = {
"metadata": {"labels": {"kind": mlrun.runtimes.RuntimeKinds.job}},
"status": {"state": mlrun.runtimes.constants.RunStates.completed},
}
run_completed_uid = "completed-uid"
run_aborted = {"status": {"state": mlrun.runtimes.constants.RunStates.aborted}}
run_aborted = {
"metadata": {"labels": {"kind": mlrun.runtimes.RuntimeKinds.job}},
"status": {"state": mlrun.runtimes.constants.RunStates.aborted},
}
run_aborted_uid = "aborted-uid"
run_dask = {
"metadata": {"labels": {"kind": mlrun.runtimes.RuntimeKinds.dask}},
Expand All @@ -59,7 +68,7 @@ def test_abort_run(db: Session, client: TestClient) -> None:
# aborted is terminal state - should fail
response = client.patch(f"/api/run/{project}/{run_aborted_uid}", json=abort_body)
assert response.status_code == HTTPStatus.CONFLICT.value
# kind dask - should fail
# dask kind not abortable - should fail
response = client.patch(f"/api/run/{project}/{run_dask_uid}", json=abort_body)
assert response.status_code == HTTPStatus.BAD_REQUEST.value
# running is ok - should succeed
Expand Down

0 comments on commit cd26eea

Please sign in to comment.