Skip to content

Commit

Permalink
[Artifacts] Improve list artifacts querying [1.6.x] (#5658)
Browse files Browse the repository at this point in the history
  • Loading branch information
TomerShor committed May 30, 2024
1 parent 0439055 commit 72845e9
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 77 deletions.
160 changes: 83 additions & 77 deletions server/api/db/sqldb/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -644,6 +644,7 @@ def list_artifacts(
uid: str = None,
producer_id: str = None,
producer_uri: str = None,
most_recent: bool = False,
):
project = project or config.default_project

Expand All @@ -666,20 +667,33 @@ def list_artifacts(
uid=uid,
producer_id=producer_id,
best_iteration=best_iteration,
producer_uri=producer_uri,
most_recent=most_recent,
attach_tags=not as_records,
)
if as_records:
return artifact_records

artifacts = ArtifactList()
for artifact in artifact_records:
for artifact, artifact_tag in artifact_records:
artifact_struct = artifact.full_object

# set the tags in the artifact struct
artifacts_with_tag = self._add_tags_to_artifact_struct(
session, artifact_struct, artifact.id, cls=ArtifactV2, tag=tag
)
artifacts.extend(artifacts_with_tag)
# Producer URI usually points to a run and is used to filter artifacts by the run that produced them.
# When the artifact was produced by a workflow, the producer id is a workflow id.
if producer_uri:
artifact_struct.setdefault("spec", {}).setdefault("producer", {})
artifact_producer_uri = artifact_struct["spec"]["producer"].get(
"uri", None
)
# We check if the producer uri is a substring of the artifact producer uri because it
# may contain additional information (like the run iteration) that we don't want to filter by.
if (
artifact_producer_uri is not None
and producer_uri not in artifact_producer_uri
):
continue

self._set_tag_in_artifact_struct(artifact_struct, artifact_tag)
artifacts.append(artifact_struct)

return artifacts

Expand Down Expand Up @@ -1152,23 +1166,6 @@ def _list_artifacts_for_tagging(

return artifacts

def _add_tags_to_artifact_struct(
self, session, artifact_struct, artifact_id, cls=Artifact, tag=None
):
artifacts = []
if tag and tag != "*":
self._set_tag_in_artifact_struct(artifact_struct, tag)
artifacts.append(artifact_struct)
else:
tag_results = self._query(session, cls.Tag, obj_id=artifact_id).all()
if not tag_results:
return [artifact_struct]
for tag_object in tag_results:
artifact_with_tag = deepcopy(artifact_struct)
self._set_tag_in_artifact_struct(artifact_with_tag, tag_object.name)
artifacts.append(artifact_with_tag)
return artifacts

@staticmethod
def _set_tag_in_artifact_struct(artifact, tag):
if is_legacy_artifact(artifact):
Expand Down Expand Up @@ -1211,40 +1208,69 @@ def _delete_artifacts_tags(

def _find_artifacts(
self,
session,
project,
ids=None,
tag=None,
labels=None,
since=None,
until=None,
name=None,
kind=None,
session: Session,
project: str,
ids: typing.Union[list[str], str] = None,
tag: str = None,
labels: typing.Union[list[str], str] = None,
since: datetime = None,
until: datetime = None,
name: str = None,
kind: mlrun.common.schemas.ArtifactCategories = None,
category: mlrun.common.schemas.ArtifactCategories = None,
iter=None,
uid=None,
producer_id=None,
best_iteration=False,
most_recent=False,
producer_uri=None,
):
iter: int = None,
uid: str = None,
producer_id: str = None,
best_iteration: bool = False,
most_recent: bool = False,
attach_tags: bool = False,
) -> typing.Union[
list[tuple[ArtifactV2, str]],
list[ArtifactV2],
]:
"""
Find artifacts by the given filters.
:param session: DB session
:param project: Project name
:param ids: Artifact IDs to filter by
:param tag: Tag to filter by
:param labels: Labels to filter by
:param since: Filter artifacts that were updated after this time
:param until: Filter artifacts that were updated before this time
:param name: Artifact name to filter by
:param kind: Artifact kind to filter by
:param category: Artifact category to filter by (if kind is not given)
:param iter: Artifact iteration to filter by
:param uid: Artifact UID to filter by
:param producer_id: Artifact producer ID to filter by
:param best_iteration: Filter by best iteration artifacts
:param most_recent: Filter by most recent artifacts
:param attach_tags: Whether to return a list of tuples of (ArtifactV2, tag_name). If False, only ArtifactV2
:return: a list of tuples of (ArtifactV2, tag_name) or a list of ArtifactV2 (if attach_tags is False)
"""
if category and kind:
message = "Category and Kind filters can't be given together"
logger.warning(message, kind=kind, category=category)
raise ValueError(message)

query = self._query(session, ArtifactV2, project=project)
query = session.query(ArtifactV2, ArtifactV2.Tag.name)

if ids and ids != "*":
query = query.filter(ArtifactV2.id.in_(ids))
# join on tags
if tag and tag != "*":
obj_name = name or None
object_tag_uids = self._resolve_class_tag_uids(
session, ArtifactV2, project, tag, obj_name
# If a tag is given, we can just join (faster than outer join) and filter on the tag
query = query.join(ArtifactV2.Tag, ArtifactV2.Tag.obj_id == ArtifactV2.id)
query = query.filter(ArtifactV2.Tag.name == tag)
else:
# If no tag is given, we need to outer join to get all artifacts, even if they don't have tags
query = query.outerjoin(
ArtifactV2.Tag, ArtifactV2.Tag.obj_id == ArtifactV2.id
)
if not object_tag_uids:
return []
query = query.filter(ArtifactV2.uid.in_(object_tag_uids))
if project:
query = query.filter(ArtifactV2.project == project)
if ids and ids != "*":
query = query.filter(ArtifactV2.id.in_(ids))
if uid:
query = query.filter(ArtifactV2.uid == uid)
if name:
Expand All @@ -1271,27 +1297,13 @@ def _find_artifacts(
if most_recent:
query = self._attach_most_recent_artifact_query(session, query)

# Producer URI usually points to a run and is used to filter artifacts by the run that produced them when
# the artifact producer id is a workflow id (artifact was created as part of a workflow).
if producer_uri:
artifacts = []
for artifact in query:
artifact_struct = artifact.full_object
artifact_struct.setdefault("spec", {}).setdefault("producer", {})
artifact_producer_uri = artifact_struct["spec"]["producer"].get(
"uri", None
)
# We check if the producer uri is a substring of the artifact producer uri because the producer uri
# may contain additional information (like the run iteration) that we don't want to filter by.
if (
artifact_producer_uri is not None
and producer_uri in artifact_producer_uri
):
artifacts.append(artifact)
artifacts_and_tags = query.all()

return artifacts
if not attach_tags:
# we might have duplicate records due to the tagging mechanism, so we need to deduplicate
return list({artifact for artifact, _ in artifacts_and_tags})

return query.all()
return artifacts_and_tags

def _add_artifact_name_query(self, query, name=None):
if not name:
Expand Down Expand Up @@ -2261,27 +2273,21 @@ def _calculate_feature_sets_counters(self, session) -> Dict[str, int]:
}
return project_to_feature_set_count

def _calculate_models_counters(self, session) -> Dict[str, int]:
import mlrun.artifacts

# The kind filter is applied post the query to the DB (manually in python code), so counting should be that
# way as well, therefore we're doing it here, and can't do it with sql as the above
def _calculate_models_counters(self, session) -> dict[str, int]:
# We're using the "most_recent" which gives us only one version of each artifact key, which is what we want to
# count (artifact count, not artifact versions count)
model_artifacts = self._find_artifacts(
session,
None,
kind=mlrun.artifacts.model.ModelArtifact.kind,
kind=mlrun.common.schemas.ArtifactCategories.model,
most_recent=True,
)
project_to_models_count = collections.defaultdict(int)
for model_artifact in model_artifacts:
project_to_models_count[model_artifact.project] += 1
return project_to_models_count

def _calculate_files_counters(self, session) -> Dict[str, int]:
# The category filter is applied post the query to the DB (manually in python code), so counting should be that
# way as well, therefore we're doing it here, and can't do it with sql as the above
def _calculate_files_counters(self, session) -> dict[str, int]:
# We're using the "most_recent" flag which gives us only one version of each artifact key, which is what we
# want to count (artifact count, not artifact versions count)
file_artifacts = self._find_artifacts(
Expand Down
9 changes: 9 additions & 0 deletions tests/api/db/test_sqldb.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import Session

import mlrun.artifacts
import mlrun.common.schemas
import server.api.db.sqldb.models
from mlrun.lists import ArtifactList
Expand Down Expand Up @@ -84,15 +85,23 @@ def test_list_artifact_date(db: SQLDB, db_session: Session):
t3 = t2 - timedelta(days=7)
project = "p7"

# create artifacts in the db directly to avoid the store_artifact function which sets the updated field
artifacts_to_create = []
for key, updated, producer_id in [
("k1", t1, "p1"),
("k2", t2, "p2"),
("k3", t3, "p3"),
]:
artifact_struct = mlrun.artifacts.Artifact(
metadata=mlrun.artifacts.ArtifactMetadata(
key=key, project=project, tree=producer_id
),
spec=mlrun.artifacts.ArtifactSpec(),
)
db_artifact = ArtifactV2(
project=project, key=key, updated=updated, producer_id=producer_id
)
db_artifact.full_object = artifact_struct.to_dict()
artifacts_to_create.append(db_artifact)

db._upsert(db_session, artifacts_to_create)
Expand Down

0 comments on commit 72845e9

Please sign in to comment.