Skip to content

Commit

Permalink
[Artifacts] Improve list artifact tags db query (#5648)
Browse files — browse the repository at this point in the history
  • Loading branch information
TomerShor committed May 29, 2024
1 parent 39bbdf8 commit 4436666
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 61 deletions.
21 changes: 3 additions & 18 deletions server/api/api/endpoints/artifacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,28 +97,13 @@ async def list_artifact_tags(
mlrun.common.schemas.AuthorizationAction.read,
auth_info,
)
tag_tuples = await run_in_threadpool(
tags = await run_in_threadpool(
server.api.crud.Artifacts().list_artifact_tags, db_session, project, category
)
artifact_key_to_tag = {tag_tuple[1]: tag_tuple[2] for tag_tuple in tag_tuples}
allowed_artifact_keys = await server.api.utils.auth.verifier.AuthVerifier().filter_project_resources_by_permissions(
mlrun.common.schemas.AuthorizationResourceTypes.artifact,
list(artifact_key_to_tag.keys()),
lambda artifact_key: (
project,
artifact_key,
),
auth_info,
)
tags = [
tag_tuple[2]
for tag_tuple in tag_tuples
if tag_tuple[1] in allowed_artifact_keys
]

return {
"project": project,
# Remove duplicities
"tags": list(set(tags)),
"tags": tags,
}


Expand Down
77 changes: 43 additions & 34 deletions server/api/db/sqldb/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -836,24 +836,28 @@ def del_artifacts(

def list_artifact_tags(
self, session, project, category: mlrun.common.schemas.ArtifactCategories = None
) -> list[tuple[str, str, str]]:
) -> list[str]:
"""
:return: a list of Tuple of (project, artifact.key, tag)
List all tags for artifacts in the DB
:param session: DB session
:param project: Project name
:param category: Artifact category to filter by
:return: a list of distinct tags
"""
artifacts = self.list_artifacts(session, project=project, category=category)
results = []
for artifact in artifacts:
# we want to return only artifacts that have tags when listing tags
if artifact["metadata"].get("tag"):
results.append(
(
project,
artifact["spec"].get("db_key"),
artifact["metadata"].get("tag"),
)
)
query = (
self._query(session, ArtifactV2.Tag.name)
.select_from(ArtifactV2)
.join(ArtifactV2.Tag, ArtifactV2.Tag.obj_id == ArtifactV2.id)
.filter(ArtifactV2.project == project)
.group_by(ArtifactV2.Tag.name)
)
if category:
query = self._add_artifact_category_query(category, query)

return results
# the query returns a list of tuples, we need to extract the tag from each tuple
return [tag for (tag,) in query]

@retry_on_conflict
def overwrite_artifacts_with_tag(
Expand Down Expand Up @@ -1219,20 +1223,6 @@ def _delete_artifacts_tags(
if commit:
session.commit()

def _add_artifact_name_query(self, query, name=None):
    """
    Narrow an artifact query by artifact key
    :param query: Query to add the key filter to
    :param name: Key to match. A leading "~" requests a case-insensitive substring match,
                 any other value must equal the key exactly. A falsy value adds no filter.
    :return: the query, filtered when a name was given
    """
    if name and name.startswith("~"):
        # "_" and "%" are escaped because the value is embedded in a LIKE pattern
        escaped = self._escape_characters_for_like_query(name)
        # drop the leading "~" marker and match the remainder anywhere in the key
        pattern = f"%{escaped[1:]}%"
        return query.filter(ArtifactV2.key.ilike(pattern, escape="\\"))

    if name:
        return query.filter(ArtifactV2.key == name)

    return query

def _find_artifacts(
self,
session,
Expand Down Expand Up @@ -1291,11 +1281,7 @@ def _find_artifacts(
if kind:
query = query.filter(ArtifactV2.kind == kind)
elif category:
kinds, exclude = category.to_kinds_filter()
if exclude:
query = query.filter(ArtifactV2.kind.notin_(kinds))
else:
query = query.filter(ArtifactV2.kind.in_(kinds))
query = self._add_artifact_category_query(category, query)
if most_recent:
query = self._attach_most_recent_artifact_query(session, query)

Expand All @@ -1321,6 +1307,29 @@ def _find_artifacts(

return query.all()

def _add_artifact_name_query(self, query, name=None):
    """
    Apply an artifact key filter to the given query
    :param query: Query to filter
    :param name: Artifact key to filter by. When it starts with "~" the rest of the value is
                 matched as a case-insensitive substring, otherwise the key must match exactly.
                 When empty or None the query is returned as is.
    :return: the resulting query
    """
    if not name:
        return query

    substring_search = name.startswith("~")
    if not substring_search:
        return query.filter(ArtifactV2.key == name)

    # the value goes into a LIKE pattern, so its special chars (_,%) must be escaped first
    escaped_name = self._escape_characters_for_like_query(name)
    # strip the "~" prefix and wrap with wildcards to find substring matches
    like_pattern = "%" + escaped_name[1:] + "%"
    key_matches = ArtifactV2.key.ilike(like_pattern, escape="\\")
    return query.filter(key_matches)

@staticmethod
def _add_artifact_category_query(category, query):
    """
    Restrict an artifact query to the kinds selected by a category
    :param category: Artifact category; its to_kinds_filter() provides the kinds and whether
                     they should be excluded rather than included
    :param query: Query to filter
    :return: the filtered query
    """
    kinds, exclude = category.to_kinds_filter()
    # the category either lists the kinds to keep or the kinds to leave out
    kind_condition = (
        ArtifactV2.kind.notin_(kinds) if exclude else ArtifactV2.kind.in_(kinds)
    )
    return query.filter(kind_condition)

def _get_existing_artifact(
self,
session,
Expand Down
4 changes: 2 additions & 2 deletions tests/api/db/test_artifacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -1062,9 +1062,9 @@ async def test_project_file_counter(self, db: DBInterface, db_session: Session):
artifacts = db.list_artifacts(db_session, project=project, tag="latest")
assert len(artifacts) == 5

# query all artifacts tags, should return 15+5=20 tags
# query all artifacts tags, should return 4 tags = 3 tags + latest
tags = db.list_artifact_tags(db_session, project=project)
assert len(tags) == 20
assert len(tags) == 4

# files counters should return the most recent artifacts, for each key -> 5 artifacts
project_to_files_count = db._calculate_files_counters(db_session)
Expand Down
13 changes: 6 additions & 7 deletions tests/api/db/test_sqldb.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,25 +57,24 @@ def test_list_artifact_tags(db: SQLDB, db_session: Session):

tags = db.list_artifact_tags(db_session, "p1")
expected_tags = [
("p1", "k1", "t1"),
("p1", "k1", "latest"),
("p1", "k1", "t2"),
("p1", "k2", "t3"),
("p1", "k2", "latest"),
"t1",
"latest",
"t2",
"t3",
]
assert deepdiff.DeepDiff(tags, expected_tags, ignore_order=True) == {}

# filter by category
model_tags = db.list_artifact_tags(
db_session, "p1", mlrun.common.schemas.ArtifactCategories.model
)
expected_tags = [("p1", "k2", "t3"), ("p1", "k2", "latest")]
expected_tags = ["t3", "latest"]
assert deepdiff.DeepDiff(expected_tags, model_tags, ignore_order=True) == {}

model_tags = db.list_artifact_tags(
db_session, "p2", mlrun.common.schemas.ArtifactCategories.dataset
)
expected_tags = [("p2", "k3", "t4"), ("p2", "k3", "latest")]
expected_tags = ["t4", "latest"]
assert deepdiff.DeepDiff(expected_tags, model_tags, ignore_order=True) == {}


Expand Down

0 comments on commit 4436666

Please sign in to comment.