Skip to content

Commit

Permalink
Add artifact kind filter in list artifacts (#460)
Browse files Browse the repository at this point in the history
  • Loading branch information
Hedingber committed Sep 30, 2020
1 parent 61c77c6 commit 1ce1e56
Show file tree
Hide file tree
Showing 5 changed files with 94 additions and 12 deletions.
5 changes: 4 additions & 1 deletion mlrun/api/api/endpoints/artifacts.py
Expand Up @@ -91,10 +91,13 @@ def list_artifacts(
project: str = config.default_project,
name: str = None,
tag: str = None,
kind: str = None,
labels: List[str] = Query([], alias="label"),
db_session: Session = Depends(deps.get_db_session),
):
artifacts = get_db().list_artifacts(db_session, name, project, tag, labels)
artifacts = get_db().list_artifacts(
db_session, name, project, tag, labels, kind=kind
)
return {
"artifacts": artifacts,
}
Expand Down
10 changes: 9 additions & 1 deletion mlrun/api/db/base.py
Expand Up @@ -83,7 +83,15 @@ def read_artifact(self, session, key, tag="", iter=None, project=""):

@abstractmethod
def list_artifacts(
self, session, name="", project="", tag="", labels=None, since=None, until=None
self,
session,
name="",
project="",
tag="",
labels=None,
since=None,
until=None,
kind=None,
):
pass

Expand Down
10 changes: 9 additions & 1 deletion mlrun/api/db/filedb/db.py
Expand Up @@ -72,7 +72,15 @@ def read_artifact(self, session, key, tag="", iter=None, project=""):
)

def list_artifacts(
self, session, name="", project="", tag="", labels=None, since=None, until=None
self,
session,
name="",
project="",
tag="",
labels=None,
since=None,
until=None,
kind=None,
):
return self._transform_run_db_error(
self.db.list_artifacts, name, project, tag, labels, since, until
Expand Down
41 changes: 32 additions & 9 deletions mlrun/api/db/sqldb/db.py
Expand Up @@ -249,6 +249,7 @@ def list_artifacts(
labels=None,
since=None,
until=None,
kind=None,
):
project = project or config.default_project

Expand All @@ -259,13 +260,13 @@ def list_artifacts(
if tag:
uids = self._resolve_tag(session, Artifact, project, tag)

arts = ArtifactList(
obj.struct
for obj in self._find_artifacts(
session, project, uids, labels, since, until, name
artifacts = ArtifactList(
artifact.struct
for artifact in self._find_artifacts(
session, project, uids, labels, since, until, name, kind
)
)
return arts
return artifacts

def del_artifact(self, session, key, tag="", project=""):
project = project or config.default_project
Expand Down Expand Up @@ -312,8 +313,8 @@ def _delete_artifact_labels(

def del_artifacts(self, session, name="", project="", tag="*", labels=None):
project = project or config.default_project
for obj in self._find_artifacts(session, project, tag, labels, name=name):
self.del_artifact(session, obj.key, "", project)
for artifact in self._find_artifacts(session, project, tag, labels, name=name):
self.del_artifact(session, artifact.key, "", project)

def store_function(
self, session, function, name, project="", tag="", versioned=False
Expand Down Expand Up @@ -818,7 +819,15 @@ def _latest_uid_filter(self, session, query):
)

def _find_artifacts(
self, session, project, uids, labels=None, since=None, until=None, name=None
self,
session,
project,
uids,
labels=None,
since=None,
until=None,
name=None,
kind=None,
):
"""
TODO: refactor this method
Expand Down Expand Up @@ -847,7 +856,21 @@ def _find_artifacts(
if name is not None:
query = query.filter(Artifact.key.ilike(f"%{name}%"))

return query
if kind:
# see docstring of _post_query_runs_filter for why we're filtering it manually
filtered_artifacts = []
for artifact in query:
artifact_json = artifact.struct
if (
artifact_json
and isinstance(artifact_json, dict)
and kind in artifact_json.get("kind")
):
filtered_artifacts.append(artifact)
return filtered_artifacts

else:
return query.all()

def _find_functions(self, session, name, project, uid=None, labels=None):
query = self._query(session, Function, name=name, project=project)
Expand Down
40 changes: 40 additions & 0 deletions tests/api/db/test_artifacts.py
@@ -1,5 +1,6 @@
import pytest
from sqlalchemy.orm import Session
from mlrun.artifacts.plots import ChartArtifact, PlotArtifact

from mlrun.api.db.base import DBInterface
from tests.api.db.conftest import dbs
Expand Down Expand Up @@ -35,3 +36,42 @@ def test_list_artifact_name_filter(db: DBInterface, db_session: Session):

artifacts = db.list_artifacts(db_session, name="artifact_name")
assert len(artifacts) == 2


# running only on sqldb cause filedb is not really a thing anymore, will be removed soon
@pytest.mark.parametrize(
"db,db_session", [(dbs[0], dbs[0])], indirect=["db", "db_session"]
)
def test_list_artifact_kind_filter(db: DBInterface, db_session: Session):
artifact_name_1 = "artifact_name_1"
artifact_kind_1 = ChartArtifact.kind
artifact_name_2 = "artifact_name_2"
artifact_kind_2 = PlotArtifact.kind
artifact_1 = {
"metadata": {"name": artifact_name_1},
"kind": artifact_kind_1,
"status": {"bla": "blabla"},
}
artifact_2 = {
"metadata": {"name": artifact_name_2},
"kind": artifact_kind_2,
"status": {"bla": "blabla"},
}
uid = "artifact_uid"

db.store_artifact(
db_session, artifact_name_1, artifact_1, uid,
)
db.store_artifact(
db_session, artifact_name_2, artifact_2, uid,
)
artifacts = db.list_artifacts(db_session)
assert len(artifacts) == 2

artifacts = db.list_artifacts(db_session, kind=artifact_kind_1)
assert len(artifacts) == 1
assert artifacts[0]["metadata"]["name"] == artifact_name_1

artifacts = db.list_artifacts(db_session, kind=artifact_kind_2)
assert len(artifacts) == 1
assert artifacts[0]["metadata"]["name"] == artifact_name_2

0 comments on commit 1ce1e56

Please sign in to comment.