diff --git a/mlrun/artifacts/base.py b/mlrun/artifacts/base.py index 323020c8300..52e393ee3c5 100644 --- a/mlrun/artifacts/base.py +++ b/mlrun/artifacts/base.py @@ -88,9 +88,10 @@ class ArtifactSpec(ModelObj): "db_key", "extra_data", "unpackaging_instructions", + "producer", ] - _extra_fields = ["annotations", "producer", "sources", "license", "encoding"] + _extra_fields = ["annotations", "sources", "license", "encoding"] _exclude_fields_from_uid_hash = [ # if the artifact is first created, it will not have a db_key, # exclude it so further updates of the artifacts will have the same hash diff --git a/mlrun/projects/project.py b/mlrun/projects/project.py index 4b50578d5f4..4484795a429 100644 --- a/mlrun/projects/project.py +++ b/mlrun/projects/project.py @@ -1375,14 +1375,7 @@ def register_artifacts(self): artifact_path = mlrun.utils.helpers.template_artifact_path( self.spec.artifact_path or mlrun.mlconf.artifact_path, self.metadata.name ) - # TODO: To correctly maintain the list of artifacts from an exported project, - # we need to maintain the different trees that generated them - producer = ArtifactProducer( - "project", - self.metadata.name, - self.metadata.name, - tag=self._get_hexsha() or str(uuid.uuid4()), - ) + project_tag = self._get_project_tag() for artifact_dict in self.spec.artifacts: if _is_imported_artifact(artifact_dict): import_from = artifact_dict["import_from"] @@ -1402,6 +1395,15 @@ def register_artifacts(self): artifact.src_path = path.join( self.spec.get_code_path(), artifact.src_path ) + producer = self._resolve_artifact_producer(artifact, project_tag) + # log the artifact only if it doesn't already exist + if ( + producer.name != self.metadata.name + and self._resolve_existing_artifact( + artifact, + ) + ): + continue artifact_manager.log_artifact( producer, artifact, artifact_path=artifact_path ) @@ -1498,12 +1500,20 @@ def log_artifact( artifact_path = mlrun.utils.helpers.template_artifact_path( artifact_path, self.metadata.name ) - producer = ArtifactProducer( - "project", - self.metadata.name, - self.metadata.name, - tag=self._get_hexsha() or str(uuid.uuid4()), - ) + producer = self._resolve_artifact_producer(item) + if producer.name != self.metadata.name: + # the artifact producer is retained, log it only if it doesn't already exist + if existing_artifact := self._resolve_existing_artifact( + item, + tag, + ): + artifact_key = item if isinstance(item, str) else item.key + logger.info( + "Artifact already exists, skipping logging", + key=artifact_key, + tag=tag, + ) + return existing_artifact item = am.log_artifact( producer, item, @@ -3383,7 +3393,12 @@ def get_artifact(self, key, tag=None, iter=None, tree=None): artifact = db.read_artifact( key, tag, iter=iter, project=self.metadata.name, tree=tree ) - return dict_to_artifact(artifact) + + # in tests, if an artifact is not found, the db returns None + # in real usage, the db should raise an exception + if artifact: + return dict_to_artifact(artifact) + return None def list_artifacts( self, @@ -3811,6 +3826,83 @@ def _validate_file_path(self, file_path: str, param_name: str): f"/<{param_name}>)." ) + def _resolve_artifact_producer( + self, + artifact: typing.Union[str, Artifact], + project_producer_tag: str = None, + ) -> typing.Optional[ArtifactProducer]: + """ + Resolve the artifact producer of the given artifact. + If the artifact's producer is a run, the artifact is registered with the original producer. + Otherwise, the artifact is registered with the current project as the producer. + + :param artifact: The artifact to resolve its producer. + :param project_producer_tag: The tag to use for the project as the producer. If not provided, a tag will be + generated for the project. + :return: A tuple of the resolved producer and the resolved artifact. + """ + + if not isinstance(artifact, str) and artifact.producer: + # if the artifact was imported from a yaml file, the producer can be a dict + if isinstance(artifact.spec.producer, ArtifactProducer): + producer_dict = artifact.spec.producer.get_meta() + else: + producer_dict = artifact.spec.producer + + if producer_dict.get("kind", "") == "run": + return ArtifactProducer( + name=producer_dict.get("name", ""), + kind=producer_dict.get("kind", ""), + project=producer_dict.get("project", ""), + tag=producer_dict.get("tag", ""), + ) + + # do not retain the artifact's producer, replace it with the project as the producer + project_producer_tag = project_producer_tag or self._get_project_tag() + return ArtifactProducer( + kind="project", + name=self.metadata.name, + project=self.metadata.name, + tag=project_producer_tag, + ) + + def _resolve_existing_artifact( + self, + item: typing.Union[str, Artifact], + tag: str = None, + ) -> typing.Optional[Artifact]: + """ + Check if there is and existing artifact with the given item and tag. + If there is, return the existing artifact. Otherwise, return None. + + :param item: The item (or key) to check if there is an existing artifact for. + :param tag: The tag to check if there is an existing artifact for. + :return: The existing artifact if there is one, otherwise None. + """ + try: + if isinstance(item, str): + existing_artifact = self.get_artifact(key=item, tag=tag) + else: + existing_artifact = self.get_artifact( + key=item.key, + tag=item.tag, + iter=item.iter, + tree=item.tree, + ) + if existing_artifact is not None: + return existing_artifact.from_dict(existing_artifact) + except mlrun.errors.MLRunNotFoundError: + logger.debug( + "No existing artifact was found", + key=item if isinstance(item, str) else item.key, + tag=tag if isinstance(item, str) else item.tag, + tree=None if isinstance(item, str) else item.tree, + ) + return None + + def _get_project_tag(self): + return self._get_hexsha() or str(uuid.uuid4()) + def _set_as_current_default_project(project: MlrunProject): mlrun.mlconf.default_project = project.metadata.name diff --git a/tests/artifacts/test_artifacts.py b/tests/artifacts/test_artifacts.py index 759d93220f2..d299eb6db27 100644 --- a/tests/artifacts/test_artifacts.py +++ b/tests/artifacts/test_artifacts.py @@ -22,6 +22,7 @@ import pandas as pd import pytest +import yaml import mlrun import mlrun.artifacts @@ -587,3 +588,32 @@ def test_register_artifacts(rundb_mock): artifact = project.get_artifact(artifact_key) assert artifact.tree == expected_tree + + +def test_producer_in_exported_artifact(): + project_name = "my-project" + project = mlrun.new_project(project_name, save=False) + + artifact = project.log_artifact( + "x", body="123", is_inline=True, artifact_path=results_dir + ) + + assert artifact.producer.get("kind") == "project" + assert artifact.producer.get("name") == project_name + + artifact_path = f"{results_dir}/x.yaml" + artifact.export(artifact_path) + + with open(artifact_path) as file: + exported_artifact = yaml.load(file, Loader=yaml.FullLoader) + assert "producer" in exported_artifact["spec"] + assert exported_artifact["spec"]["producer"]["kind"] == "project" + assert exported_artifact["spec"]["producer"]["name"] == project_name + + # remove the producer from the artifact and export it again + artifact.producer = None + artifact.export(artifact_path) + + with open(artifact_path) as file: + exported_artifact = yaml.load(file, Loader=yaml.FullLoader) + assert "producer" not in exported_artifact["spec"] diff --git a/tests/projects/test_project.py b/tests/projects/test_project.py index 0982fcfbd28..d3767041a36 100644 --- a/tests/projects/test_project.py +++ b/tests/projects/test_project.py @@ -27,6 +27,7 @@ import pytest import mlrun +import mlrun.artifacts import mlrun.common.schemas import mlrun.errors import mlrun.projects.project @@ -900,6 +901,94 @@ def test_import_artifact_using_relative_path(): assert artifact.spec.db_key == "y" +def test_import_artifact_retain_producer(rundb_mock): + base_path = tests.conftest.results + project_1 = mlrun.new_project( + name="project-1", context=f"{base_path}/project_1", save=False + ) + project_2 = mlrun.new_project( + name="project-2", context=f"{base_path}/project_2", save=False + ) + + # create an artifact with a 'run' producer + artifact = mlrun.artifacts.Artifact(key="x", body="123", is_inline=True) + run_name = "my-run" + run_tag = "some-tag" + + # we set the producer as dict so the export will work + artifact.producer = mlrun.artifacts.ArtifactProducer( + kind="run", + project=project_1.name, + name=run_name, + tag=run_tag, + ).get_meta() + + # export the artifact + artifact_path = f"{base_path}/my-artifact.yaml" + artifact.export(artifact_path) + + # import the artifact to another project + new_key = "y" + imported_artifact = project_2.import_artifact(artifact_path, new_key) + assert imported_artifact.producer == artifact.producer + + # set the artifact on the first project + project_1.set_artifact(artifact.key, artifact) + project_1.save() + + # load a new project from the first project's context + project_3 = mlrun.load_project(name="project-3", context=project_1.context) + + # make sure the artifact was registered with the original producer + # the db key should include the run since it's a run artifact + db_key = f"{run_name}_{new_key}" + loaded_artifact = project_3.get_artifact(db_key) + assert loaded_artifact.producer == artifact.producer + + +def test_replace_exported_artifact_producer(rundb_mock): + base_path = tests.conftest.results + project_1 = mlrun.new_project( + name="project-1", context=f"{base_path}/project_1", save=False + ) + project_2 = mlrun.new_project( + name="project-2", context=f"{base_path}/project_2", save=False + ) + + # create an artifact with a 'project' producer + key = "x" + artifact = mlrun.artifacts.Artifact(key=key, body="123", is_inline=True) + + # we set the producer as dict so the export will work + artifact.producer = mlrun.artifacts.ArtifactProducer( + kind="project", + project=project_1.name, + name=project_1.name, + ).get_meta() + + # export the artifact + artifact_path = f"{base_path}/my-artifact.yaml" + artifact.export(artifact_path) + + # import the artifact to another project + new_key = "y" + imported_artifact = project_2.import_artifact(artifact_path, new_key) + assert imported_artifact.producer != artifact.producer + assert imported_artifact.producer["name"] == project_2.name + + # set the artifact on the first project + project_1.set_artifact(artifact.key, artifact) + project_1.save() + + # load a new project from the first project's context + project_3 = mlrun.load_project(name="project-3", context=project_1.context) + + # make sure the artifact was registered with the new project producer + loaded_artifact = project_3.get_artifact(key) + assert loaded_artifact.producer != artifact.producer + assert loaded_artifact.producer["name"] == project_3.name + + @pytest.mark.parametrize( "relative_artifact_path,project_context,expected_path,expected_in_context", [