Skip to content

Commit

Permalink
[API] Support multiple state filter on runs api (#5573)
Browse files Browse the repository at this point in the history
  • Loading branch information
roei3000b committed May 17, 2024
1 parent b7982e7 commit f87f68f
Show file tree
Hide file tree
Showing 9 changed files with 63 additions and 15 deletions.
6 changes: 5 additions & 1 deletion mlrun/db/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from typing import Optional, Union

import mlrun.alerts
import mlrun.common.runtimes.constants
import mlrun.common.schemas
import mlrun.model_monitoring

Expand Down Expand Up @@ -63,7 +64,10 @@ def list_runs(
uid: Optional[Union[str, list[str]]] = None,
project: Optional[str] = None,
labels: Optional[Union[str, list[str]]] = None,
state: Optional[str] = None,
state: Optional[
mlrun.common.runtimes.constants.RunStates
] = None, # Backward compatibility
states: Optional[list[mlrun.common.runtimes.constants.RunStates]] = None,
sort: bool = True,
last: int = 0,
iter: bool = False,
Expand Down
12 changes: 10 additions & 2 deletions mlrun/db/httpdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from mlrun_pipelines.utils import compile_pipeline

import mlrun
import mlrun.common.runtimes
import mlrun.common.schemas
import mlrun.common.types
import mlrun.model_monitoring.model_endpoint
Expand Down Expand Up @@ -752,7 +753,10 @@ def list_runs(
uid: Optional[Union[str, list[str]]] = None,
project: Optional[str] = None,
labels: Optional[Union[str, list[str]]] = None,
state: Optional[str] = None,
state: Optional[
mlrun.common.runtimes.constants.RunStates
] = None, # Backward compatibility
states: typing.Optional[list[mlrun.common.runtimes.constants.RunStates]] = None,
sort: bool = True,
last: int = 0,
iter: bool = False,
Expand Down Expand Up @@ -791,6 +795,7 @@ def list_runs(
of a label (i.e. list("key=value")) or by looking for the existence of a given
key (i.e. "key").
:param state: List only runs whose state is specified.
:param states: List only runs whose state is one of the provided states.
:param sort: Whether to sort the result according to their start time. Otherwise, results will be
returned by their internal order in the DB (order will not be guaranteed).
:param last: Deprecated - currently not used (will be removed in 1.8.0).
Expand Down Expand Up @@ -831,6 +836,7 @@ def list_runs(
and not uid
and not labels
and not state
and not states
and not last
and not start_time_from
and not start_time_to
Expand All @@ -849,7 +855,9 @@ def list_runs(
"name": name,
"uid": uid,
"label": labels or [],
"state": state,
"state": mlrun.utils.helpers.as_list(state)
if state is not None
else states or None,
"sort": bool2str(sort),
"iter": bool2str(iter),
"start_time_from": datetime_to_iso(start_time_from),
Expand Down
6 changes: 5 additions & 1 deletion mlrun/db/nopdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from typing import Optional, Union

import mlrun.alerts
import mlrun.common.runtimes.constants
import mlrun.common.schemas
import mlrun.errors

Expand Down Expand Up @@ -80,7 +81,10 @@ def list_runs(
uid: Optional[Union[str, list[str]]] = None,
project: Optional[str] = None,
labels: Optional[Union[str, list[str]]] = None,
state: Optional[str] = None,
state: Optional[
mlrun.common.runtimes.constants.RunStates
] = None, # Backward compatibility
states: Optional[list[mlrun.common.runtimes.constants.RunStates]] = None,
sort: bool = True,
last: int = 0,
iter: bool = False,
Expand Down
11 changes: 9 additions & 2 deletions mlrun/projects/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from mlrun_pipelines.models import PipelineNodeWrapper

import mlrun.common.helpers
import mlrun.common.runtimes.constants
import mlrun.common.schemas.artifact
import mlrun.common.schemas.model_monitoring.constants as mm_constants
import mlrun.db
Expand Down Expand Up @@ -3689,7 +3690,10 @@ def list_runs(
name: Optional[str] = None,
uid: Optional[Union[str, list[str]]] = None,
labels: Optional[Union[str, list[str]]] = None,
state: Optional[str] = None,
state: Optional[
mlrun.common.runtimes.constants.RunStates
] = None, # Backward compatibility
states: typing.Optional[list[mlrun.common.runtimes.constants.RunStates]] = None,
sort: bool = True,
last: int = 0,
iter: bool = False,
Expand Down Expand Up @@ -3724,6 +3728,7 @@ def list_runs(
of a label (i.e. list("key=value")) or by looking for the existence of a given
key (i.e. "key").
:param state: List only runs whose state is specified.
:param states: List only runs whose state is one of the provided states.
:param sort: Whether to sort the result according to their start time. Otherwise, results will be
returned by their internal order in the DB (order will not be guaranteed).
:param last: Deprecated - currently not used (will be removed in 1.8.0).
Expand All @@ -3740,7 +3745,9 @@ def list_runs(
uid,
self.metadata.name,
labels=labels,
state=state,
states=mlrun.utils.helpers.as_list(state)
if state is not None
else states or None,
sort=sort,
last=last,
iter=iter,
Expand Down
7 changes: 4 additions & 3 deletions server/api/api/endpoints/runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from fastapi.concurrency import run_in_threadpool
from sqlalchemy.orm import Session

import mlrun.common.runtimes.constants
import mlrun.common.schemas
import server.api.crud
import server.api.utils.auth.verifier
Expand Down Expand Up @@ -185,15 +186,15 @@ async def delete_run(
"/runs",
deprecated=True,
description="/runs is deprecated in 1.5.0 and will be removed in 1.8.0, "
"use /projects/{project}/runs/{uid} instead",
"use /projects/{project}/runs/ instead",
)
@router.get("/projects/{project}/runs")
async def list_runs(
project: str = None,
name: str = None,
uid: list[str] = Query([]),
labels: list[str] = Query([], alias="label"),
state: str = None,
states: list[str] = Query([], alias="state"),
last: int = 0,
sort: bool = True,
iter: bool = True,
Expand Down Expand Up @@ -251,7 +252,7 @@ async def _filter_runs(_runs):
uid=uid,
project=project,
labels=labels,
state=state,
states=states,
sort=sort,
last=last,
iter=iter,
Expand Down
10 changes: 7 additions & 3 deletions server/api/crud/runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,8 +165,10 @@ def list_runs(
uid: typing.Optional[typing.Union[str, list[str]]] = None,
project: str = "",
labels: typing.Optional[typing.Union[str, list[str]]] = None,
state: typing.Optional[str] = None,
states: typing.Optional[list[str]] = None, # Backward compatibility
state: typing.Optional[
mlrun.common.runtimes.constants.RunStates
] = None, # Backward compatibility
states: typing.Optional[typing.Union[str, list[str]]] = None,
sort: bool = True,
last: int = 0,
iter: bool = False,
Expand Down Expand Up @@ -231,7 +233,9 @@ def list_runs(
uid=uid,
project=project,
labels=labels,
states=[state] if state is not None else states or None,
states=mlrun.utils.helpers.as_list(state)
if state is not None
else states or None,
sort=sort,
last=last,
iter=iter,
Expand Down
2 changes: 1 addition & 1 deletion server/api/db/sqldb/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,7 @@ def list_runs(
uid: typing.Optional[typing.Union[str, list[str]]] = None,
project: str = "",
labels: typing.Optional[typing.Union[str, list[str]]] = None,
states: typing.Optional[list[str]] = None,
states: typing.Optional[list[mlrun.common.runtimes.constants.RunStates]] = None,
sort: bool = True,
last: int = 0,
iter: bool = False,
Expand Down
8 changes: 6 additions & 2 deletions server/api/rundb/sqldb.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from sqlalchemy.exc import SQLAlchemyError

import mlrun.alerts
import mlrun.common.runtimes.constants
import mlrun.common.schemas
import mlrun.common.schemas.artifact
import mlrun.db.factory
Expand Down Expand Up @@ -113,7 +114,8 @@ def list_runs(
uid: Optional[Union[str, list[str]]] = None,
project: Optional[str] = None,
labels: Optional[Union[str, list[str]]] = None,
state: Optional[str] = None,
state: Optional[mlrun.common.runtimes.constants.RunStates] = None,
states: Optional[list[mlrun.common.runtimes.constants.RunStates]] = None,
sort: bool = True,
last: int = 0,
iter: bool = False,
Expand All @@ -137,7 +139,9 @@ def list_runs(
uid=uid,
project=project,
labels=labels,
states=mlrun.utils.helpers.as_list(state) if state is not None else None,
states=mlrun.utils.helpers.as_list(state)
if state is not None
else states or None,
sort=sort,
last=last,
iter=iter,
Expand Down
16 changes: 16 additions & 0 deletions tests/rundb/test_httpdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,10 @@ def test_runs(create_server):
for i in range(count):
uid = f"uid_{i}"
run_as_dict["metadata"]["name"] = "run-name"
if i % 2 == 0:
run_as_dict["status"]["state"] = "completed"
else:
run_as_dict["status"]["state"] = "created"
db.store_run(run_as_dict, uid, prj)

# retrieve only the last run as it is partitioned by name
Expand All @@ -287,8 +291,20 @@ def test_runs(create_server):
)
assert len(runs) == 7, "bad number of runs"

# retrieve only created runs
runs = db.list_runs(project=prj, states=["created"])
assert len(runs) == 3, "bad number of runs"

# retrieve created and completed runs
runs = db.list_runs(project=prj, states=["created", "completed"])
assert len(runs) == 7, "bad number of runs"

# delete runs in created state
db.del_runs(project=prj, state="created")

# delete runs in completed state
db.del_runs(project=prj, state="completed")

runs = db.list_runs(project=prj)
assert not runs, "found runs in after delete"

Expand Down

0 comments on commit f87f68f

Please sign in to comment.