Skip to content

Commit

Permalink
Add category filter to list artifacts (#462)
Browse files Browse the repository at this point in the history
  • Loading branch information
Hedingber committed Oct 2, 2020
1 parent bfe974a commit ba066ca
Show file tree
Hide file tree
Showing 7 changed files with 148 additions and 25 deletions.
4 changes: 3 additions & 1 deletion mlrun/api/api/endpoints/artifacts.py
Expand Up @@ -5,6 +5,7 @@
from fastapi.concurrency import run_in_threadpool
from sqlalchemy.orm import Session

from mlrun.api import schemas
from mlrun.api.api import deps
from mlrun.api.api.utils import log_and_raise
from mlrun.api.utils.singletons.db import get_db
Expand Down Expand Up @@ -92,11 +93,12 @@ def list_artifacts(
name: str = None,
tag: str = None,
kind: str = None,
category: schemas.ArtifactCategories = None,
labels: List[str] = Query([], alias="label"),
db_session: Session = Depends(deps.get_db_session),
):
artifacts = get_db().list_artifacts(
db_session, name, project, tag, labels, kind=kind
db_session, name, project, tag, labels, kind=kind, category=category,
)
return {
"artifacts": artifacts,
Expand Down
1 change: 1 addition & 0 deletions mlrun/api/db/base.py
Expand Up @@ -92,6 +92,7 @@ def list_artifacts(
since=None,
until=None,
kind=None,
category: schemas.ArtifactCategories = None,
):
pass

Expand Down
1 change: 1 addition & 0 deletions mlrun/api/db/filedb/db.py
Expand Up @@ -81,6 +81,7 @@ def list_artifacts(
since=None,
until=None,
kind=None,
category: schemas.ArtifactCategories = None,
):
return self._transform_run_db_error(
self.db.list_artifacts, name, project, tag, labels, since, until
Expand Down
58 changes: 46 additions & 12 deletions mlrun/api/db/sqldb/db.py
Expand Up @@ -250,6 +250,7 @@ def list_artifacts(
since=None,
until=None,
kind=None,
category: schemas.ArtifactCategories = None,
):
project = project or config.default_project

Expand All @@ -263,7 +264,7 @@ def list_artifacts(
artifacts = ArtifactList(
artifact.struct
for artifact in self._find_artifacts(
session, project, uids, labels, since, until, name, kind
session, project, uids, labels, since, until, name, kind, category
)
)
return artifacts
Expand Down Expand Up @@ -828,6 +829,7 @@ def _find_artifacts(
until=None,
name=None,
kind=None,
category: schemas.ArtifactCategories = None,
):
"""
TODO: refactor this method
Expand All @@ -836,6 +838,10 @@ def _find_artifacts(
1. uids == "*" - in which we don't care about uids we just don't add any filter for this column
1. uids == "latest" - in which we find the relevant uid by finding the latest artifact using the updated column
"""
if category and kind:
message = "Category and Kind filters can't be given together"
logger.warning(message, kind=kind, category=category)
raise ValueError(message)
labels = label_set(labels)
query = self._query(session, Artifact, project=project)
if uids != "*":
Expand All @@ -857,21 +863,49 @@ def _find_artifacts(
query = query.filter(Artifact.key.ilike(f"%{name}%"))

if kind:
# see docstring of _post_query_runs_filter for why we're filtering it manually
filtered_artifacts = []
for artifact in query:
artifact_json = artifact.struct
if (
artifact_json
and isinstance(artifact_json, dict)
and kind in artifact_json.get("kind")
):
filtered_artifacts.append(artifact)
return filtered_artifacts
return self._filter_artifacts_by_kinds(query, [kind])

elif category:
return self._filter_artifacts_by_category(query, category)

else:
return query.all()

def _filter_artifacts_by_category(
self, artifacts, category: schemas.ArtifactCategories
):
kinds, exclude = category.to_kinds_filter()
return self._filter_artifacts_by_kinds(artifacts, kinds, exclude)

def _filter_artifacts_by_kinds(
self, artifacts, kinds: List[str], exclude: bool = False
):
"""
:param kinds - list of kinds to filter by
:param exclude - if true then the filter will be "all except" - get all artifacts excluding the ones who have
any of the given kinds
"""
# see docstring of _post_query_runs_filter for why we're filtering it manually
filtered_artifacts = []
for artifact in artifacts:
artifact_json = artifact.struct
if (
artifact_json
and isinstance(artifact_json, dict)
and (
(
not exclude
and any([kind == artifact_json.get("kind") for kind in kinds])
)
or (
exclude
and all([kind != artifact_json.get("kind") for kind in kinds])
)
)
):
filtered_artifacts.append(artifact)
return filtered_artifacts

def _find_functions(self, session, name, project, uid=None, labels=None):
query = self._query(session, Function, name=name, project=project)
if uid:
Expand Down
1 change: 1 addition & 0 deletions mlrun/api/schemas/__init__.py
@@ -1,5 +1,6 @@
# flake8: noqa - this is until we take care of the F401 violations with respect to __all__ & sphinx

from .artifact import ArtifactCategories
from .project import Project, ProjectOut, ProjectCreate, ProjectInDB, ProjectUpdate
from .schedule import (
SchedulesOutput,
Expand Down
27 changes: 27 additions & 0 deletions mlrun/api/schemas/artifact.py
@@ -0,0 +1,27 @@
import enum
import typing


class ArtifactCategories(str, enum.Enum):
model = "model"
dataset = "dataset"
other = "other"

def to_kinds_filter(self) -> typing.Tuple[typing.List[str], bool]:
# FIXME: these artifact definitions (or at least the kinds enum) should sit in a dedicated module
# import here to prevent import cycle
import mlrun.artifacts.dataset
import mlrun.artifacts.model

if self.value == ArtifactCategories.model.value:
return [mlrun.artifacts.model.ModelArtifact.kind], False
if self.value == ArtifactCategories.dataset.value:
return [mlrun.artifacts.dataset.DatasetArtifact.kind], False
if self.value == ArtifactCategories.other.value:
return (
[
mlrun.artifacts.model.ModelArtifact.kind,
mlrun.artifacts.dataset.DatasetArtifact.kind,
],
True,
)
81 changes: 69 additions & 12 deletions tests/api/db/test_artifacts.py
@@ -1,7 +1,10 @@
import pytest
from sqlalchemy.orm import Session
from mlrun.artifacts.plots import ChartArtifact, PlotArtifact
from mlrun.artifacts.dataset import DatasetArtifact
from mlrun.artifacts.model import ModelArtifact

from mlrun.api import schemas
from mlrun.api.db.base import DBInterface
from tests.api.db.conftest import dbs

Expand All @@ -13,8 +16,8 @@
def test_list_artifact_name_filter(db: DBInterface, db_session: Session):
artifact_name_1 = "artifact_name_1"
artifact_name_2 = "artifact_name_2"
artifact_1 = {"metadata": {"name": artifact_name_1}, "status": {"bla": "blabla"}}
artifact_2 = {"metadata": {"name": artifact_name_2}, "status": {"bla": "blabla"}}
artifact_1 = _generate_artifact(artifact_name_1)
artifact_2 = _generate_artifact(artifact_name_2)
uid = "artifact_uid"

db.store_artifact(
Expand Down Expand Up @@ -47,16 +50,8 @@ def test_list_artifact_kind_filter(db: DBInterface, db_session: Session):
artifact_kind_1 = ChartArtifact.kind
artifact_name_2 = "artifact_name_2"
artifact_kind_2 = PlotArtifact.kind
artifact_1 = {
"metadata": {"name": artifact_name_1},
"kind": artifact_kind_1,
"status": {"bla": "blabla"},
}
artifact_2 = {
"metadata": {"name": artifact_name_2},
"kind": artifact_kind_2,
"status": {"bla": "blabla"},
}
artifact_1 = _generate_artifact(artifact_name_1, artifact_kind_1)
artifact_2 = _generate_artifact(artifact_name_2, artifact_kind_2)
uid = "artifact_uid"

db.store_artifact(
Expand All @@ -75,3 +70,65 @@ def test_list_artifact_kind_filter(db: DBInterface, db_session: Session):
artifacts = db.list_artifacts(db_session, kind=artifact_kind_2)
assert len(artifacts) == 1
assert artifacts[0]["metadata"]["name"] == artifact_name_2


# running only on sqldb cause filedb is not really a thing anymore, will be removed soon
@pytest.mark.parametrize(
"db,db_session", [(dbs[0], dbs[0])], indirect=["db", "db_session"]
)
def test_list_artifact_category_filter(db: DBInterface, db_session: Session):
artifact_name_1 = "artifact_name_1"
artifact_kind_1 = ChartArtifact.kind
artifact_name_2 = "artifact_name_2"
artifact_kind_2 = PlotArtifact.kind
artifact_name_3 = "artifact_name_3"
artifact_kind_3 = ModelArtifact.kind
artifact_name_4 = "artifact_name_4"
artifact_kind_4 = DatasetArtifact.kind
artifact_1 = _generate_artifact(artifact_name_1, artifact_kind_1)
artifact_2 = _generate_artifact(artifact_name_2, artifact_kind_2)
artifact_3 = _generate_artifact(artifact_name_3, artifact_kind_3)
artifact_4 = _generate_artifact(artifact_name_4, artifact_kind_4)
uid = "artifact_uid"

db.store_artifact(
db_session, artifact_name_1, artifact_1, uid,
)
db.store_artifact(
db_session, artifact_name_2, artifact_2, uid,
)
db.store_artifact(
db_session, artifact_name_3, artifact_3, uid,
)
db.store_artifact(
db_session, artifact_name_4, artifact_4, uid,
)
artifacts = db.list_artifacts(db_session)
assert len(artifacts) == 4

artifacts = db.list_artifacts(db_session, category=schemas.ArtifactCategories.model)
assert len(artifacts) == 1
assert artifacts[0]["metadata"]["name"] == artifact_name_3

artifacts = db.list_artifacts(
db_session, category=schemas.ArtifactCategories.dataset
)
assert len(artifacts) == 1
assert artifacts[0]["metadata"]["name"] == artifact_name_4

artifacts = db.list_artifacts(db_session, category=schemas.ArtifactCategories.other)
assert len(artifacts) == 2
assert artifacts[0]["metadata"]["name"] == artifact_name_1
assert artifacts[1]["metadata"]["name"] == artifact_name_2


def _generate_artifact(name, kind=None):
artifact = {
"metadata": {"name": name},
"kind": kind,
"status": {"bla": "blabla"},
}
if kind:
artifact["kind"] = kind

return artifact

0 comments on commit ba066ca

Please sign in to comment.