Skip to content

Commit

Permalink
[API] Fix get run with artifacts and iterations (#5587)
Browse files Browse the repository at this point in the history
  • Loading branch information
alonmr committed May 17, 2024
1 parent 9227bc2 commit 3e16327
Show file tree
Hide file tree
Showing 3 changed files with 162 additions and 8 deletions.
9 changes: 9 additions & 0 deletions server/api/crud/runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,9 +135,18 @@ def get_run(
else:
# Producer URI is the URI of the MLClientCtx object that produced the artifact
producer_uri = f"{project}/{run['metadata']['uid']}"
if iter:
producer_uri += f"-{iter}"

best_iteration = False
if not iter:
iter = None
best_iteration = True

artifacts = server.api.crud.Artifacts().list_artifacts(
db_session,
iter=iter,
best_iteration=best_iteration,
producer_id=producer_id,
producer_uri=producer_uri,
project=project,
Expand Down
12 changes: 10 additions & 2 deletions server/api/db/sqldb/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -936,7 +936,7 @@ def tag_artifacts(
session,
ArtifactV2,
project=project,
).filter(
).with_entities(ArtifactV2.id).filter(
ArtifactV2.key.in_(artifacts_keys),
).order_by(ArtifactV2.id.asc()).populate_existing().with_for_update().all()

Expand Down Expand Up @@ -1305,7 +1305,15 @@ def _find_artifacts(
for artifact in query:
artifact_struct = artifact.full_object
artifact_struct.setdefault("spec", {}).setdefault("producer", {})
if artifact_struct["spec"]["producer"].get("uri") == producer_uri:
artifact_producer_uri = artifact_struct["spec"]["producer"].get(
"uri", None
)
# We check whether the requested producer uri is a substring of the artifact's producer uri because the
# artifact's producer uri may contain additional information (like the run iteration) that we don't want to filter by.
if (
artifact_producer_uri is not None
and producer_uri in artifact_producer_uri
):
artifacts.append(artifact)

return artifacts
Expand Down
149 changes: 143 additions & 6 deletions tests/api/crud/test_runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,42 +460,179 @@ def test_get_workflow_run_restore_artifacts_metadata(

self._validate_run_artifacts(artifacts, db, project, run_uid)

def test_get_workflow_run_iteration_restore_artifacts_metadata(
    self, db: sqlalchemy.orm.Session
):
    """Check that getting a specific iteration of a workflow run returns its artifacts.

    Artifacts are generated and stored for a single iteration, a run of
    kind "job" labeled with the workflow uid is stored for the same
    iteration, and the run is then validated for that iteration.
    """
    project = "project-name"
    run_uid = str(uuid.uuid4())
    workflow_uid = str(uuid.uuid4())
    iteration = 3

    expected_artifacts = self._generate_artifacts(
        project, run_uid, workflow_uid, iter=iteration
    )

    # Persist every generated artifact under the same iteration as the run.
    for expected_artifact in expected_artifacts:
        db_key = expected_artifact["spec"]["db_key"]
        server.api.crud.Artifacts().store_artifact(
            db,
            db_key,
            expected_artifact,
            iter=iteration,
            project=project,
        )

    run_labels = {
        "kind": "job",
        "workflow": workflow_uid,
    }
    run_struct = {
        "metadata": {
            "name": "run-name",
            "uid": run_uid,
            "iter": iteration,
            "labels": run_labels,
        },
        "status": {
            "artifacts": expected_artifacts,
        },
    }
    server.api.crud.Runs().store_run(
        db,
        run_struct,
        run_uid,
        iter=iteration,
        project=project,
    )

    self._validate_run_artifacts(expected_artifacts, db, project, run_uid, iteration)

def test_get_workflow_run_best_iteration_restore_artifacts_metadata(
    self, db: sqlalchemy.orm.Session
):
    """Check that getting a parent workflow run returns best-iteration and parent artifacts.

    Three groups of artifacts are stored for the same run uid:

    * five artifacts of iteration 3, stored with ``best_iteration=True``
      and the "latest" tag,
    * three artifacts of iteration 5 using the "bad_key" key prefix,
    * one artifact without an iteration using the "parent_key" key prefix.

    The run itself is stored without an iteration, and validation expects
    only the first and last groups.
    """
    project = "project-name"
    run_uid = str(uuid.uuid4())
    workflow_uid = str(uuid.uuid4())

    # Phase 1: artifacts of the best iteration, written directly through the DB layer
    # so that they can be flagged as belonging to the best iteration.
    best_iter = 3
    best_iter_artifacts_len = 5
    best_iter_artifacts = self._generate_artifacts(
        project,
        run_uid,
        workflow_uid,
        artifacts_len=best_iter_artifacts_len,
        iter=best_iter,
    )
    for best_artifact in best_iter_artifacts:
        db_key = best_artifact["spec"]["db_key"]
        server.api.utils.singletons.db.get_db().store_artifact(
            db,
            db_key,
            best_artifact,
            None,
            iter=best_iter,
            tag="latest",
            project=project,
            best_iteration=True,
        )

    # Phase 2: artifacts of another iteration, which are not part of the expected result.
    other_iter = 5
    other_iter_artifacts_len = 3
    other_iter_artifacts = self._generate_artifacts(
        project,
        run_uid,
        workflow_uid,
        artifacts_len=other_iter_artifacts_len,
        iter=other_iter,
        key_prefix="bad_key",
    )
    for other_artifact in other_iter_artifacts:
        db_key = other_artifact["spec"]["db_key"]
        server.api.crud.Artifacts().store_artifact(
            db,
            db_key,
            other_artifact,
            iter=other_iter,
            project=project,
        )

    # Phase 3: an artifact generated and stored without any iteration.
    parent_artifacts_len = 1
    parent_artifacts = self._generate_artifacts(
        project,
        run_uid,
        workflow_uid,
        artifacts_len=parent_artifacts_len,
        key_prefix="parent_key",
    )
    for parent_artifact in parent_artifacts:
        db_key = parent_artifact["spec"]["db_key"]
        server.api.crud.Artifacts().store_artifact(
            db,
            db_key,
            parent_artifact,
            project=project,
        )

    run_struct = {
        "metadata": {
            "name": "run-name",
            "uid": run_uid,
            "labels": {
                "kind": "job",
                "workflow": workflow_uid,
            },
        },
    }
    server.api.crud.Runs().store_run(
        db,
        run_struct,
        run_uid,
        project=project,
    )

    expected_artifacts = best_iter_artifacts + parent_artifacts
    self._validate_run_artifacts(expected_artifacts, db, project, run_uid)

@staticmethod
def _generate_artifacts(project, run_uid, workflow_uid=None, artifacts_len=2):
def _generate_artifacts(
project,
run_uid,
workflow_uid=None,
artifacts_len=2,
iter=None,
key_prefix="key",
):
artifacts = []
i = 0
while len(artifacts) < artifacts_len:
artifact = {
"kind": "artifact",
"metadata": {
"key": f"key{i}",
"key": f"{key_prefix}{i}",
"tree": workflow_uid or run_uid,
"uid": f"uid{i}",
"project": project,
"iter": None,
"iter": iter,
},
"spec": {
"db_key": f"db_key{i}",
},
"status": {},
}
if workflow_uid:
producer_uri = f"{project}/{run_uid}"
if iter:
producer_uri += f"-{iter}"
artifact["spec"]["producer"] = {
"uri": f"{project}/{run_uid}",
"uri": producer_uri,
}
artifacts.append(artifact)
i += 1
return artifacts

@staticmethod
def _validate_run_artifacts(artifacts, db, project, run_uid):
run = server.api.crud.Runs().get_run(db, run_uid, 0, project)
def _validate_run_artifacts(artifacts, db, project, run_uid, iter=0):
run = server.api.crud.Runs().get_run(db, run_uid, iter, project)
assert "artifacts" in run["status"]
enriched_artifacts = list(run["status"]["artifacts"])

def sort_by_key(e):
return e["metadata"]["key"]

assert len(enriched_artifacts) == len(
artifacts
), "Number of artifacts is different"
enriched_artifacts.sort(key=sort_by_key)
artifacts.sort(key=sort_by_key)
for artifact, enriched_artifact in zip(artifacts, enriched_artifacts):
Expand Down

0 comments on commit 3e16327

Please sign in to comment.