diff --git a/server/api/db/sqldb/db.py b/server/api/db/sqldb/db.py index 05da98a42131..b5d8557b4120 100644 --- a/server/api/db/sqldb/db.py +++ b/server/api/db/sqldb/db.py @@ -661,6 +661,7 @@ def list_artifacts( uid: str = None, producer_id: str = None, producer_uri: str = None, + most_recent: bool = False, ): project = project or config.default_project @@ -683,20 +684,33 @@ def list_artifacts( uid=uid, producer_id=producer_id, best_iteration=best_iteration, - producer_uri=producer_uri, + most_recent=most_recent, + attach_tags=not as_records, ) if as_records: return artifact_records artifacts = ArtifactList() - for artifact in artifact_records: + for artifact, artifact_tag in artifact_records: artifact_struct = artifact.full_object - # set the tags in the artifact struct - artifacts_with_tag = self._add_tags_to_artifact_struct( - session, artifact_struct, artifact.id, cls=ArtifactV2, tag=tag - ) - artifacts.extend(artifacts_with_tag) + # Producer URI usually points to a run and is used to filter artifacts by the run that produced them. + # When the artifact was produced by a workflow, the producer id is a workflow id. + if producer_uri: + artifact_struct.setdefault("spec", {}).setdefault("producer", {}) + artifact_producer_uri = artifact_struct["spec"]["producer"].get( + "uri", None + ) + # We check if the producer uri is a substring of the artifact producer uri because it + # may contain additional information (like the run iteration) that we don't want to filter by. + if ( + artifact_producer_uri is not None + and producer_uri not in artifact_producer_uri + ): + continue + + self._set_tag_in_artifact_struct(artifact_struct, artifact_tag) + artifacts.append(artifact_struct) return artifacts @@ -1169,23 +1183,6 @@ def _list_artifacts_for_tagging( return artifacts - def _add_tags_to_artifact_struct( - self, session, artifact_struct, artifact_id, cls=Artifact, tag=None - ): - artifacts = [] - if tag and tag != "*": - self._set_tag_in_artifact_struct(artifact_struct, tag) - artifacts.append(artifact_struct) - else: - tag_results = self._query(session, cls.Tag, obj_id=artifact_id).all() - if not tag_results: - return [artifact_struct] - for tag_object in tag_results: - artifact_with_tag = deepcopy(artifact_struct) - self._set_tag_in_artifact_struct(artifact_with_tag, tag_object.name) - artifacts.append(artifact_with_tag) - return artifacts - @staticmethod def _set_tag_in_artifact_struct(artifact, tag): artifact["metadata"]["tag"] = tag @@ -1225,40 +1222,69 @@ def _delete_artifacts_tags( def _find_artifacts( self, - session, - project, - ids=None, - tag=None, - labels=None, - since=None, - until=None, - name=None, - kind=None, + session: Session, + project: str, + ids: typing.Union[list[str], str] = None, + tag: str = None, + labels: typing.Union[list[str], str] = None, + since: datetime = None, + until: datetime = None, + name: str = None, + kind: mlrun.common.schemas.ArtifactCategories = None, category: mlrun.common.schemas.ArtifactCategories = None, - iter=None, - uid=None, - producer_id=None, - best_iteration=False, - most_recent=False, - producer_uri=None, - ): + iter: int = None, + uid: str = None, + producer_id: str = None, + best_iteration: bool = False, + most_recent: bool = False, + attach_tags: bool = False, + ) -> typing.Union[ + list[tuple[ArtifactV2, str]], + list[ArtifactV2], + ]: + """ + Find artifacts by the given filters. + + :param session: DB session + :param project: Project name + :param ids: Artifact IDs to filter by + :param tag: Tag to filter by + :param labels: Labels to filter by + :param since: Filter artifacts that were updated after this time + :param until: Filter artifacts that were updated before this time + :param name: Artifact name to filter by + :param kind: Artifact kind to filter by + :param category: Artifact category to filter by (if kind is not given) + :param iter: Artifact iteration to filter by + :param uid: Artifact UID to filter by + :param producer_id: Artifact producer ID to filter by + :param best_iteration: Filter by best iteration artifacts + :param most_recent: Filter by most recent artifacts + :param attach_tags: Whether to return a list of tuples of (ArtifactV2, tag_name). If False, only ArtifactV2 + + :return: a list of tuples of (ArtifactV2, tag_name) or a list of ArtifactV2 (if attach_tags is False) + """ if category and kind: message = "Category and Kind filters can't be given together" logger.warning(message, kind=kind, category=category) raise ValueError(message) - query = self._query(session, ArtifactV2, project=project) + query = session.query(ArtifactV2, ArtifactV2.Tag.name) - if ids and ids != "*": - query = query.filter(ArtifactV2.id.in_(ids)) + # join on tags if tag and tag != "*": - obj_name = name or None - object_tag_uids = self._resolve_class_tag_uids( - session, ArtifactV2, project, tag, obj_name + # If a tag is given, we can just join (faster than outer join) and filter on the tag + query = query.join(ArtifactV2.Tag, ArtifactV2.Tag.obj_id == ArtifactV2.id) + query = query.filter(ArtifactV2.Tag.name == tag) + else: + # If no tag is given, we need to outer join to get all artifacts, even if they don't have tags + query = query.outerjoin( + ArtifactV2.Tag, ArtifactV2.Tag.obj_id == ArtifactV2.id ) - if not object_tag_uids: - return [] - query = query.filter(ArtifactV2.uid.in_(object_tag_uids)) + if project: + query = query.filter(ArtifactV2.project == project) + if ids and ids != "*": + query = query.filter(ArtifactV2.id.in_(ids)) if uid: query = query.filter(ArtifactV2.uid == uid) if name: @@ -1285,27 +1311,13 @@ def _find_artifacts( if most_recent: query = self._attach_most_recent_artifact_query(session, query) - # Producer URI usually points to a run and is used to filter artifacts by the run that produced them when - # the artifact producer id is a workflow id (artifact was created as part of a workflow). - if producer_uri: - artifacts = [] - for artifact in query: - artifact_struct = artifact.full_object - artifact_struct.setdefault("spec", {}).setdefault("producer", {}) - artifact_producer_uri = artifact_struct["spec"]["producer"].get( - "uri", None - ) - # We check if the producer uri is a substring of the artifact producer uri because the producer uri - # may contain additional information (like the run iteration) that we don't want to filter by. - if ( - artifact_producer_uri is not None - and producer_uri in artifact_producer_uri - ): - artifacts.append(artifact) + artifacts_and_tags = query.all() - return artifacts + if not attach_tags: + # we might have duplicate records due to the tagging mechanism, so we need to deduplicate + return list({artifact for artifact, _ in artifacts_and_tags}) - return query.all() + return artifacts_and_tags def _add_artifact_name_query(self, query, name=None): if not name: @@ -2374,16 +2386,12 @@ def _calculate_feature_sets_counters(self, session) -> dict[str, int]: return project_to_feature_set_count def _calculate_models_counters(self, session) -> dict[str, int]: - import mlrun.artifacts - - # The kind filter is applied post the query to the DB (manually in python code), so counting should be that - # way as well, therefore we're doing it here, and can't do it with sql as the above # We're using the "most_recent" which gives us only one version of each artifact key, which is what we want to # count (artifact count, not artifact versions count) model_artifacts = self._find_artifacts( session, None, - kind=mlrun.artifacts.model.ModelArtifact.kind, + kind=mlrun.common.schemas.ArtifactCategories.model, most_recent=True, ) project_to_models_count = collections.defaultdict(int) @@ -2392,8 +2400,6 @@ def _calculate_models_counters(self, session) -> dict[str, int]: return project_to_models_count def _calculate_files_counters(self, session) -> dict[str, int]: - # The category filter is applied post the query to the DB (manually in python code), so counting should be that - # way as well, therefore we're doing it here, and can't do it with sql as the above # We're using the "most_recent" flag which gives us only one version of each artifact key, which is what we # want to count (artifact count, not artifact versions count) file_artifacts = self._find_artifacts( diff --git a/tests/api/db/test_sqldb.py b/tests/api/db/test_sqldb.py index 9aeab691ff35..04b26674e7c9 100644 --- a/tests/api/db/test_sqldb.py +++ b/tests/api/db/test_sqldb.py @@ -23,6 +23,7 @@ from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.orm import Session +import mlrun.artifacts import mlrun.common.schemas import server.api.db.sqldb.models from mlrun.lists import ArtifactList @@ -84,15 +85,23 @@ def test_list_artifact_date(db: SQLDB, db_session: Session): t3 = t2 - timedelta(days=7) project = "p7" + # create artifacts in the db directly to avoid the store_artifact function which sets the updated field artifacts_to_create = [] for key, updated, producer_id in [ ("k1", t1, "p1"), ("k2", t2, "p2"), ("k3", t3, "p3"), ]: + artifact_struct = mlrun.artifacts.Artifact( + metadata=mlrun.artifacts.ArtifactMetadata( + key=key, project=project, tree=producer_id + ), + spec=mlrun.artifacts.ArtifactSpec(), + ) db_artifact = ArtifactV2( project=project, key=key, updated=updated, producer_id=producer_id ) + db_artifact.full_object = artifact_struct.to_dict() artifacts_to_create.append(db_artifact) db._upsert(db_session, artifacts_to_create)