Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 10 additions & 6 deletions dvc/commands/experiments/ls.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,17 @@ def run(self):
num=self.args.num,
git_remote=self.args.git_remote,
)
tags = self.repo.scm.describe(exps)
remained = {baseline for baseline, tag in tags.items() if tag is None}
base = "refs/heads/"
ref_heads = self.repo.scm.describe(remained, base=base)

for baseline in exps:
tag = self.repo.scm.describe(baseline)
if not tag:
branch = self.repo.scm.describe(baseline, base="refs/heads")
if branch:
tag = branch.split("/")[-1]
name = tag if tag else baseline[:7]
name = (
tags[baseline]
or ref_heads[baseline][len(base) :]
or baseline[:7]
)
if not name_only:
print(f"{name}:")
for exp_name in exps[baseline]:
Expand Down
37 changes: 22 additions & 15 deletions dvc/repo/experiments/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,8 +211,9 @@ def reproduce_celery(

def _log_reproduced(self, revs: Iterable[str], tmp_dir: bool = False):
names = []
rev_names = self.get_exact_name(revs)
for rev in revs:
name = self.get_exact_name(rev)
name = rev_names[rev]
names.append(name if name else rev[:7])
ui.write("\nRan experiment(s): {}".format(", ".join(names)))
if tmp_dir:
Expand Down Expand Up @@ -411,25 +412,31 @@ def get_branch_by_rev(
raise MultipleBranchError(rev, ref_infos)
return str(ref_infos[0])

def get_exact_name(self, rev: str):
def get_exact_name(self, revs: Iterable[str]) -> Dict[str, Optional[str]]:
"""Returns preferred name for the specified revision.

Prefers tags, branches (heads), experiments in that orer.
"""
result: Dict[str, Optional[str]] = {}
exclude = f"{EXEC_NAMESPACE}/*"
ref = self.scm.describe(rev, base=EXPS_NAMESPACE, exclude=exclude)
if ref:
try:
name = ExpRefInfo.from_ref(ref).name
if name:
return name
except InvalidExpRefError:
pass
if rev in self.stash_revs:
return self.stash_revs[rev].name
if rev in self.celery_queue.failed_stash.stash_revs:
return self.celery_queue.failed_stash.stash_revs[rev].name
return None
ref_dict = self.scm.describe(
revs, base=EXPS_NAMESPACE, exclude=exclude
)
for rev in revs:
name: Optional[str] = None
ref = ref_dict[rev]
if ref:
try:
name = ExpRefInfo.from_ref(ref).name
except InvalidExpRefError:
pass
if not name:
if rev in self.stash_revs:
name = self.stash_revs[rev].name
elif rev in self.celery_queue.failed_stash.stash_revs:
name = self.celery_queue.failed_stash.stash_revs[rev].name
result[rev] = name
return result

def get_running_exps(self, fetch_refs: bool = True) -> Dict[str, Any]:
"""Return info for running experiments."""
Expand Down
62 changes: 44 additions & 18 deletions dvc/repo/experiments/show.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,9 @@ def _collect_experiment_commit(
repo: "Repo",
exp_rev: str,
status: ExpStatus = ExpStatus.Success,
sha_only=True,
param_deps=False,
running=None,
onerror: Optional[Callable] = None,
is_baseline: bool = False,
):
from dvc.dependency import ParamsDependency, RepoDependency

Expand Down Expand Up @@ -94,18 +92,6 @@ def _collect_experiment_commit(
)
res["metrics"] = vals

if not sha_only and rev != "workspace":
name: Optional[str] = None
if is_baseline:
for refspec in ["refs/tags", "refs/heads"]:
name = repo.scm.describe(rev, base=refspec)
if name:
name = name.replace(f"{refspec}/", "")
break
name = name or repo.experiments.get_exact_name(rev)
if name:
res["name"] = name

return res


Expand Down Expand Up @@ -154,6 +140,46 @@ def _collect_experiment_branch(
return res


def get_names(repo: "Repo", result: Dict[str, Dict[str, Any]]):

rev_set = set()
baseline_set = set()
for baseline in result:
for rev in result[baseline]:
if rev == "baseline":
rev = baseline
baseline_set.add(rev)
if rev != "workspace":
rev_set.add(rev)

names: Dict[str, Optional[str]] = {}
for base in ("refs/tags/", "refs/heads/"):
if rev_set:
names.update(
(rev, ref[len(base) :])
for rev, ref in repo.scm.describe(
baseline_set, base=base
).items()
if ref is not None
)
rev_set.difference_update(names.keys())

exact_name = repo.experiments.get_exact_name(rev_set)

for baseline, baseline_results in result.items():
for rev, rev_result in baseline_results.items():
name: Optional[str] = None
if rev == "baseline":
rev = baseline
if rev == "workspace":
continue
name = names.get(rev, None)
name = name or exact_name[rev]
if name:
rev_result["data"]["name"] = name


# flake8: noqa: C901
def show(
repo: "Repo",
all_branches=False,
Expand Down Expand Up @@ -194,12 +220,10 @@ def show(
res[rev]["baseline"] = _collect_experiment_commit(
repo,
rev,
sha_only=sha_only,
status=status,
param_deps=param_deps,
running=running,
onerror=onerror,
is_baseline=True,
)

if rev == "workspace":
Expand All @@ -220,7 +244,6 @@ def show(
repo,
exp_ref,
rev,
sha_only=sha_only,
param_deps=param_deps,
running=running,
onerror=onerror,
Expand Down Expand Up @@ -252,11 +275,14 @@ def show(
experiment = _collect_experiment_commit(
repo,
stash_rev,
sha_only=sha_only,
status=status,
param_deps=param_deps,
running=running,
onerror=onerror,
)
res[entry.baseline_rev][stash_rev] = experiment

if not sha_only:
get_names(repo, res)

return res
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ install_requires =
rich>=10.13.0
pyparsing>=2.4.7
typing-extensions>=3.7.4
scmrepo==0.1.1
scmrepo==0.1.2
dvc-render==0.0.12
dvc-task==0.1.4
dvclive>=0.10.0
Expand Down
6 changes: 3 additions & 3 deletions tests/func/experiments/test_experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def test_new_simple(tmp_dir, scm, dvc, exp_stage, mocker, name, workspace):
assert (tmp_dir / "metrics.yaml").read_text().strip() == "foo: 2"

exp_name = name if name else ref_info.name
assert dvc.experiments.get_exact_name(exp) == exp_name
assert dvc.experiments.get_exact_name([exp])[exp] == exp_name
assert resolve_rev(scm, exp_name) == exp


Expand Down Expand Up @@ -481,7 +481,7 @@ def test_subdir(tmp_dir, scm, dvc, workspace):
with fs.open("dir/metrics.yaml", mode="r", encoding="utf-8") as fobj:
assert fobj.read().strip() == "foo: 2"

assert dvc.experiments.get_exact_name(exp) == ref_info.name
assert dvc.experiments.get_exact_name([exp])[exp] == ref_info.name
assert resolve_rev(scm, ref_info.name) == exp


Expand Down Expand Up @@ -525,7 +525,7 @@ def test_subrepo(tmp_dir, scm, workspace):
with fs.open("dir/repo/metrics.yaml", mode="r", encoding="utf-8") as fobj:
assert fobj.read().strip() == "foo: 2"

assert subrepo.dvc.experiments.get_exact_name(exp) == ref_info.name
assert subrepo.dvc.experiments.get_exact_name([exp])[exp] == ref_info.name
assert resolve_rev(scm, ref_info.name) == exp


Expand Down
2 changes: 1 addition & 1 deletion tests/func/experiments/test_show.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ def test_show_checkpoint(

for i, rev in enumerate(checkpoints):
if i == 0:
name = dvc.experiments.get_exact_name(rev)
name = dvc.experiments.get_exact_name([rev])[rev]
name = f"{rev[:7]} [{name}]"
fs = "╓"
elif i == len(checkpoints) - 1:
Expand Down