Skip to content

Commit

Permalink
[Project] Retain producers of exported artifacts (#5283)
Browse files Browse the repository at this point in the history
  • Loading branch information
TomerShor committed Mar 19, 2024
1 parent 0f21baf commit 3a28f9d
Show file tree
Hide file tree
Showing 4 changed files with 228 additions and 16 deletions.
3 changes: 2 additions & 1 deletion mlrun/artifacts/base.py
Expand Up @@ -88,9 +88,10 @@ class ArtifactSpec(ModelObj):
"db_key",
"extra_data",
"unpackaging_instructions",
"producer",
]

_extra_fields = ["annotations", "producer", "sources", "license", "encoding"]
_extra_fields = ["annotations", "sources", "license", "encoding"]
_exclude_fields_from_uid_hash = [
# if the artifact is first created, it will not have a db_key,
# exclude it so further updates of the artifacts will have the same hash
Expand Down
122 changes: 107 additions & 15 deletions mlrun/projects/project.py
Expand Up @@ -1375,14 +1375,7 @@ def register_artifacts(self):
artifact_path = mlrun.utils.helpers.template_artifact_path(
self.spec.artifact_path or mlrun.mlconf.artifact_path, self.metadata.name
)
# TODO: To correctly maintain the list of artifacts from an exported project,
# we need to maintain the different trees that generated them
producer = ArtifactProducer(
"project",
self.metadata.name,
self.metadata.name,
tag=self._get_hexsha() or str(uuid.uuid4()),
)
project_tag = self._get_project_tag()
for artifact_dict in self.spec.artifacts:
if _is_imported_artifact(artifact_dict):
import_from = artifact_dict["import_from"]
Expand All @@ -1402,6 +1395,15 @@ def register_artifacts(self):
artifact.src_path = path.join(
self.spec.get_code_path(), artifact.src_path
)
producer = self._resolve_artifact_producer(artifact, project_tag)
# log the artifact only if it doesn't already exist
if (
producer.name != self.metadata.name
and self._resolve_existing_artifact(
artifact,
)
):
continue
artifact_manager.log_artifact(
producer, artifact, artifact_path=artifact_path
)
Expand Down Expand Up @@ -1498,12 +1500,20 @@ def log_artifact(
artifact_path = mlrun.utils.helpers.template_artifact_path(
artifact_path, self.metadata.name
)
producer = ArtifactProducer(
"project",
self.metadata.name,
self.metadata.name,
tag=self._get_hexsha() or str(uuid.uuid4()),
)
producer = self._resolve_artifact_producer(item)
if producer.name != self.metadata.name:
# the artifact producer is retained, log it only if it doesn't already exist
if existing_artifact := self._resolve_existing_artifact(
item,
tag,
):
artifact_key = item if isinstance(item, str) else item.key
logger.info(
"Artifact already exists, skipping logging",
key=artifact_key,
tag=tag,
)
return existing_artifact
item = am.log_artifact(
producer,
item,
Expand Down Expand Up @@ -3383,7 +3393,12 @@ def get_artifact(self, key, tag=None, iter=None, tree=None):
artifact = db.read_artifact(
key, tag, iter=iter, project=self.metadata.name, tree=tree
)
return dict_to_artifact(artifact)

# in tests, if an artifact is not found, the db returns None
# in real usage, the db should raise an exception
if artifact:
return dict_to_artifact(artifact)
return None

def list_artifacts(
self,
Expand Down Expand Up @@ -3811,6 +3826,83 @@ def _validate_file_path(self, file_path: str, param_name: str):
f"<project.spec.get_code_path()>/<{param_name}>)."
)

def _resolve_artifact_producer(
self,
artifact: typing.Union[str, Artifact],
project_producer_tag: str = None,
) -> typing.Optional[ArtifactProducer]:
"""
Resolve the artifact producer of the given artifact.
If the artifact's producer is a run, the artifact is registered with the original producer.
Otherwise, the artifact is registered with the current project as the producer.
:param artifact: The artifact to resolve its producer.
:param project_producer_tag: The tag to use for the project as the producer. If not provided, a tag will be
generated for the project.
:return: A tuple of the resolved producer and the resolved artifact.
"""

if not isinstance(artifact, str) and artifact.producer:
# if the artifact was imported from a yaml file, the producer can be a dict
if isinstance(artifact.spec.producer, ArtifactProducer):
producer_dict = artifact.spec.producer.get_meta()
else:
producer_dict = artifact.spec.producer

if producer_dict.get("kind", "") == "run":
return ArtifactProducer(
name=producer_dict.get("name", ""),
kind=producer_dict.get("kind", ""),
project=producer_dict.get("project", ""),
tag=producer_dict.get("tag", ""),
)

# do not retain the artifact's producer, replace it with the project as the producer
project_producer_tag = project_producer_tag or self._get_project_tag()
return ArtifactProducer(
kind="project",
name=self.metadata.name,
project=self.metadata.name,
tag=project_producer_tag,
)

def _resolve_existing_artifact(
self,
item: typing.Union[str, Artifact],
tag: str = None,
) -> typing.Optional[Artifact]:
"""
Check if there is and existing artifact with the given item and tag.
If there is, return the existing artifact. Otherwise, return None.
:param item: The item (or key) to check if there is an existing artifact for.
:param tag: The tag to check if there is an existing artifact for.
:return: The existing artifact if there is one, otherwise None.
"""
try:
if isinstance(item, str):
existing_artifact = self.get_artifact(key=item, tag=tag)
else:
existing_artifact = self.get_artifact(
key=item.key,
tag=item.tag,
iter=item.iter,
tree=item.tree,
)
if existing_artifact is not None:
return existing_artifact.from_dict(existing_artifact)
except mlrun.errors.MLRunNotFoundError:
logger.debug(
"No existing artifact was found",
key=item if isinstance(item, str) else item.key,
tag=tag if isinstance(item, str) else item.tag,
tree=None if isinstance(item, str) else item.tree,
)
return None

def _get_project_tag(self):
return self._get_hexsha() or str(uuid.uuid4())


def _set_as_current_default_project(project: MlrunProject):
mlrun.mlconf.default_project = project.metadata.name
Expand Down
30 changes: 30 additions & 0 deletions tests/artifacts/test_artifacts.py
Expand Up @@ -22,6 +22,7 @@

import pandas as pd
import pytest
import yaml

import mlrun
import mlrun.artifacts
Expand Down Expand Up @@ -587,3 +588,32 @@ def test_register_artifacts(rundb_mock):

artifact = project.get_artifact(artifact_key)
assert artifact.tree == expected_tree


def test_producer_in_exported_artifact():
project_name = "my-project"
project = mlrun.new_project(project_name, save=False)

artifact = project.log_artifact(
"x", body="123", is_inline=True, artifact_path=results_dir
)

assert artifact.producer.get("kind") == "project"
assert artifact.producer.get("name") == project_name

artifact_path = f"{results_dir}/x.yaml"
artifact.export(artifact_path)

with open(artifact_path) as file:
exported_artifact = yaml.load(file, Loader=yaml.FullLoader)
assert "producer" in exported_artifact["spec"]
assert exported_artifact["spec"]["producer"]["kind"] == "project"
assert exported_artifact["spec"]["producer"]["name"] == project_name

# remove the producer from the artifact and export it again
artifact.producer = None
artifact.export(artifact_path)

with open(artifact_path) as file:
exported_artifact = yaml.load(file, Loader=yaml.FullLoader)
assert "producer" not in exported_artifact["spec"]
89 changes: 89 additions & 0 deletions tests/projects/test_project.py
Expand Up @@ -27,6 +27,7 @@
import pytest

import mlrun
import mlrun.artifacts
import mlrun.common.schemas
import mlrun.errors
import mlrun.projects.project
Expand Down Expand Up @@ -900,6 +901,94 @@ def test_import_artifact_using_relative_path():
assert artifact.spec.db_key == "y"


def test_import_artifact_retain_producer(rundb_mock):
base_path = tests.conftest.results
project_1 = mlrun.new_project(
name="project-1", context=f"{base_path}/project_1", save=False
)
project_2 = mlrun.new_project(
name="project-2", context=f"{base_path}/project_2", save=False
)

# create an artifact with a 'run' producer
artifact = mlrun.artifacts.Artifact(key="x", body="123", is_inline=True)
run_name = "my-run"
run_tag = "some-tag"

# we set the producer as dict so the export will work
artifact.producer = mlrun.artifacts.ArtifactProducer(
kind="run",
project=project_1.name,
name=run_name,
tag=run_tag,
).get_meta()

# export the artifact
artifact_path = f"{base_path}/my-artifact.yaml"
artifact.export(artifact_path)

# import the artifact to another project
new_key = "y"
imported_artifact = project_2.import_artifact(artifact_path, new_key)
assert imported_artifact.producer == artifact.producer

# set the artifact on the first project
project_1.set_artifact(artifact.key, artifact)
project_1.save()

# load a new project from the first project's context
project_3 = mlrun.load_project(name="project-3", context=project_1.context)

# make sure the artifact was registered with the original producer
# the db key should include the run since it's a run artifact
db_key = f"{run_name}_{new_key}"
loaded_artifact = project_3.get_artifact(db_key)
assert loaded_artifact.producer == artifact.producer


def test_replace_exported_artifact_producer(rundb_mock):
base_path = tests.conftest.results
project_1 = mlrun.new_project(
name="project-1", context=f"{base_path}/project_1", save=False
)
project_2 = mlrun.new_project(
name="project-2", context=f"{base_path}/project_2", save=False
)

# create an artifact with a 'project' producer
key = "x"
artifact = mlrun.artifacts.Artifact(key=key, body="123", is_inline=True)

# we set the producer as dict so the export will work
artifact.producer = mlrun.artifacts.ArtifactProducer(
kind="project",
project=project_1.name,
name=project_1.name,
).get_meta()

# export the artifact
artifact_path = f"{base_path}/my-artifact.yaml"
artifact.export(artifact_path)

# import the artifact to another project
new_key = "y"
imported_artifact = project_2.import_artifact(artifact_path, new_key)
assert imported_artifact.producer != artifact.producer
assert imported_artifact.producer["name"] == project_2.name

# set the artifact on the first project
project_1.set_artifact(artifact.key, artifact)
project_1.save()

# load a new project from the first project's context
project_3 = mlrun.load_project(name="project-3", context=project_1.context)

# make sure the artifact was registered with the new project producer
loaded_artifact = project_3.get_artifact(key)
assert loaded_artifact.producer != artifact.producer
assert loaded_artifact.producer["name"] == project_3.name


@pytest.mark.parametrize(
"relative_artifact_path,project_context,expected_path,expected_in_context",
[
Expand Down

0 comments on commit 3a28f9d

Please sign in to comment.