Skip to content

Commit

Permalink
Fix run name and state filters (#459)
Browse files Browse the repository at this point in the history
  • Loading branch information
Hedingber committed Sep 30, 2020
1 parent e2af81b commit 61c77c6
Show file tree
Hide file tree
Showing 2 changed files with 206 additions and 7 deletions.
66 changes: 59 additions & 7 deletions mlrun/api/db/sqldb/db.py
Expand Up @@ -105,6 +105,9 @@ def store_run(self, session, struct, uid, project="", iter=0):
start_time=run_start_time(struct) or datetime.now(timezone.utc),
)
labels = run_labels(struct)
new_state = run_state(struct)
if new_state:
run.state = new_state
update_labels(run, labels)
run.struct = struct
self._upsert(session, run, ignore=True)
Expand Down Expand Up @@ -151,18 +154,19 @@ def list_runs(
last=0,
iter=False,
):
# FIXME: Run has no "name"
project = project or config.default_project
query = self._find_runs(session, uid, project, labels, state)
query = self._find_runs(session, uid, project, labels)
if sort:
query = query.order_by(Run.start_time.desc())
if last:
query = query.limit(last)
if not iter:
query = query.filter(Run.iteration == 0)

filtered_runs = self._post_query_runs_filter(query, name, state)

runs = RunList()
for run in query:
for run in filtered_runs:
runs.append(run.struct)

return runs
Expand All @@ -177,11 +181,12 @@ def del_runs(
):
# NOTE: Run has no `name` column - the name is filtered post-query from the run struct
project = project or config.default_project
query = self._find_runs(session, None, project, labels, state)
query = self._find_runs(session, None, project, labels)
if days_ago:
since = datetime.now(timezone.utc) - timedelta(days=days_ago)
query = query.filter(Run.start_time >= since)
for run in query: # Can not use query.delete with join
filtered_runs = self._post_query_runs_filter(query, name, state)
for run in filtered_runs: # Can not use query.delete with join
session.delete(run)
session.commit()

Expand Down Expand Up @@ -737,11 +742,58 @@ def _upsert(self, session, obj, ignore=False):
if not ignore:
raise DBError(f"duplicate {cls} - {err}") from err

def _find_runs(self, session, uid, project, labels):
    """
    Build the base query for Run records matching uid, project and labels.

    State is deliberately not filtered at the SQL level here; callers apply
    name/state filtering afterwards with _post_query_runs_filter.

    :param session: DB session the query is bound to
    :param uid:     run uid to match, passed through to self._query
    :param project: project to match, passed through to self._query
    :param labels:  labels filter, normalized with label_set before use

    :return: the query returned by self._add_labels_filter
    """
    labels = label_set(labels)
    query = self._query(session, Run, uid=uid, project=project)
    return self._add_labels_filter(session, query, Run, labels)

def _post_query_runs_filter(self, query, name=None, state=None):
    """
    Filter the runs returned by query by name and/or state, in Python.

    This function is hacky and exists to cover for bugs we had in how we save our
    data in the DB. We're doing it the hacky way since:
    1. SQLDB is about to be replaced
    2. Schema + data migrations are complicated, and as long as we can avoid them
       we prefer to (also because of 1)

    name - the name is only saved in the json itself, therefore we can't use the
    SQL query filter and have to filter it ourselves
    state - the state is saved in a column, but there was a bug in which the state
    was only getting updated in the json itself, therefore, in field systems, most
    run records will have empty or outdated data in the state column

    Both filters are substring matches. A run whose json is missing (or is not a
    dict), or that lacks the relevant field, simply does not match - it never
    raises.

    :param query: iterable of Run records (must also support .all())
    :param name:  substring to look for in the json's metadata.name
    :param state: substring to look for in the json's status.state, falling back
                  to the record's state column when the json has no state

    :return: list of the matching Run records
    """
    if not name and not state:
        return query.all()

    filtered_runs = []
    for run in query:
        # normalize a missing / malformed struct to an empty dict so the lookups
        # below are uniform
        run_json = run.struct if isinstance(run.struct, dict) else {}
        if name:
            metadata = run_json.get("metadata")
            run_name = metadata.get("name") if isinstance(metadata, dict) else None
            # a run without a name can never match a name filter
            if not isinstance(run_name, str) or name not in run_name:
                continue
        if state:
            status = run_json.get("status")
            json_state = status.get("state") if isinstance(status, dict) else None
            # json_state has precedence over record state
            effective_state = json_state or run.state
            if not isinstance(effective_state, str) or state not in effective_state:
                continue
        filtered_runs.append(run)

    return filtered_runs

def _latest_uid_filter(self, session, query):
# Create a sub query of latest uid (by updated) per (project,key)
subq = (
Expand Down
147 changes: 147 additions & 0 deletions tests/api/db/test_runs.py
@@ -0,0 +1,147 @@
import pytest
from mlrun.config import config
from datetime import datetime, timezone
from sqlalchemy.orm import Session

from mlrun.api.db.base import DBInterface
from mlrun.api.db.sqldb.models import Run
from tests.api.db.conftest import dbs


# Exercised against sqldb only: filedb is effectively deprecated and will be removed soon
@pytest.mark.parametrize(
    "db,db_session", [(dbs[0], dbs[0])], indirect=["db", "db_session"]
)
def test_list_runs_name_filter(db: DBInterface, db_session: Session):
    """Check that list_runs filters by run name, including partial names."""
    names = ["run_name_1", "run_name_2"]
    uids = ["run_uid_1", "run_uid_2"]
    for run_name, run_uid in zip(names, uids):
        run_struct = {"metadata": {"name": run_name}, "status": {"bla": "blabla"}}
        db.store_run(db_session, run_struct, run_uid)

    # without a filter both runs are listed
    assert len(db.list_runs(db_session)) == 2

    # a full name selects only its own run
    for run_name in names:
        matched = db.list_runs(db_session, name=run_name)
        assert len(matched) == 1
        assert matched[0]["metadata"]["name"] == run_name

    # the shared prefix matches both runs
    assert len(db.list_runs(db_session, name="run_name")) == 2


# Exercised against sqldb only: filedb is effectively deprecated and will be removed soon
@pytest.mark.parametrize(
    "db,db_session", [(dbs[0], dbs[0])], indirect=["db", "db_session"]
)
def test_list_runs_state_filter(db: DBInterface, db_session: Session):
    """
    Check that list_runs filters by state for every combination of where the
    state is stored: nowhere, only in the json, only in the record column, in
    both with equal values, and in both with different values.
    """

    def insert_record(uid, struct, **columns):
        # Write the row directly (bypassing store_run) so the test controls the
        # record's columns independently of the json struct.
        record = Run(
            uid=uid,
            project=config.default_project,
            iteration=0,
            start_time=datetime.now(timezone.utc),
            **columns,
        )
        record.struct = struct
        db._upsert(db_session, record, ignore=True)

    # no state at all
    without_state_uid = "run_without_state_uid"
    db.store_run(
        db_session,
        {"metadata": {"uid": without_state_uid}, "bla": "blabla"},
        without_state_uid,
    )

    # state only in the json
    json_state = "some_json_state"
    json_state_uid = "run_with_json_state_uid"
    insert_record(
        json_state_uid,
        {"metadata": {"uid": json_state_uid}, "status": {"state": json_state}},
    )

    # state only in the record column
    record_state = "some_record_state"
    record_state_uid = "run_with_record_state_uid"
    insert_record(
        record_state_uid,
        {"metadata": {"uid": record_state_uid}, "bla": "blabla"},
        state=record_state,
    )

    # stored through store_run with the state in the json
    equal_state = "some_equal_json_and_record_state"
    equal_state_uid = "run_with_equal_json_and_record_state_uid"
    db.store_run(
        db_session,
        {"metadata": {"uid": equal_state_uid}, "status": {"state": equal_state}},
        equal_state_uid,
    )

    # json and record column disagree
    unequal_json_state = "some_unequal_json_state"
    unequal_record_state = "some_unequal_record_state"
    unequal_state_uid = "run_with_unequal_json_and_record_state_uid"
    insert_record(
        unequal_state_uid,
        {
            "metadata": {"uid": unequal_state_uid},
            "status": {"state": unequal_json_state},
        },
        state=unequal_record_state,
    )

    assert len(db.list_runs(db_session)) == 5

    # a partial state matches every run that has some state
    partial_matches = db.list_runs(db_session, state="some_")
    assert len(partial_matches) == 4
    assert without_state_uid not in [
        listed["metadata"]["uid"] for listed in partial_matches
    ]

    # each full state value selects exactly one run
    expected_single_match = [
        (json_state, json_state_uid),
        (record_state, record_state_uid),
        (equal_state, equal_state_uid),
        (unequal_json_state, unequal_state_uid),
    ]
    for wanted_state, wanted_uid in expected_single_match:
        matched = db.list_runs(db_session, state=wanted_state)
        assert len(matched) == 1
        assert matched[0]["metadata"]["uid"] == wanted_uid

    # the record state is ignored when the json carries a different state
    assert len(db.list_runs(db_session, state=unequal_record_state)) == 0

0 comments on commit 61c77c6

Please sign in to comment.