From 4ddf1f0d9fe92190bbba58a4d3bd2e5558a0d362 Mon Sep 17 00:00:00 2001 From: karajan1001 Date: Wed, 19 Oct 2022 17:03:13 +0800 Subject: [PATCH] exp show: Use batch call on `scm.describe` fix: #8451 1. Change the call of scm.describe from individual revision to a collection of them. Accelerate the run speed of `exp show` 2. Bump scmrepo to 0.1.2 --- dvc/commands/experiments/ls.py | 16 +++--- dvc/repo/experiments/__init__.py | 37 +++++++------ dvc/repo/experiments/show.py | 62 +++++++++++++++------- setup.cfg | 2 +- tests/func/experiments/test_experiments.py | 6 +-- tests/func/experiments/test_show.py | 2 +- 6 files changed, 81 insertions(+), 44 deletions(-) diff --git a/dvc/commands/experiments/ls.py b/dvc/commands/experiments/ls.py index 7b09a573d6..d279c98c29 100644 --- a/dvc/commands/experiments/ls.py +++ b/dvc/commands/experiments/ls.py @@ -16,13 +16,17 @@ def run(self): num=self.args.num, git_remote=self.args.git_remote, ) + tags = self.repo.scm.describe(exps) + remained = {baseline for baseline, tag in tags.items() if tag is None} + base = "refs/heads/" + ref_heads = self.repo.scm.describe(remained, base=base) + for baseline in exps: - tag = self.repo.scm.describe(baseline) - if not tag: - branch = self.repo.scm.describe(baseline, base="refs/heads") - if branch: - tag = branch.split("/")[-1] - name = tag if tag else baseline[:7] + name = ( + tags[baseline] + or ref_heads[baseline][len(base) :] + or baseline[:7] + ) if not name_only: print(f"{name}:") for exp_name in exps[baseline]: diff --git a/dvc/repo/experiments/__init__.py b/dvc/repo/experiments/__init__.py index 0b0347ccae..d2a7e2a40f 100644 --- a/dvc/repo/experiments/__init__.py +++ b/dvc/repo/experiments/__init__.py @@ -211,8 +211,9 @@ def reproduce_celery( def _log_reproduced(self, revs: Iterable[str], tmp_dir: bool = False): names = [] + rev_names = self.get_exact_name(revs) for rev in revs: - name = self.get_exact_name(rev) + name = rev_names[rev] names.append(name if name else rev[:7]) ui.write("\nRan experiment(s): {}".format(", ".join(names))) if tmp_dir: @@ -411,25 +412,31 @@ def get_branch_by_rev( raise MultipleBranchError(rev, ref_infos) return str(ref_infos[0]) - def get_exact_name(self, rev: str): + def get_exact_name(self, revs: Iterable[str]) -> Dict[str, Optional[str]]: """Returns preferred name for the specified revision. Prefers tags, branches (heads), experiments in that orer. """ + result: Dict[str, Optional[str]] = {} exclude = f"{EXEC_NAMESPACE}/*" - ref = self.scm.describe(rev, base=EXPS_NAMESPACE, exclude=exclude) - if ref: - try: - name = ExpRefInfo.from_ref(ref).name - if name: - return name - except InvalidExpRefError: - pass - if rev in self.stash_revs: - return self.stash_revs[rev].name - if rev in self.celery_queue.failed_stash.stash_revs: - return self.celery_queue.failed_stash.stash_revs[rev].name - return None + ref_dict = self.scm.describe( + revs, base=EXPS_NAMESPACE, exclude=exclude + ) + for rev in revs: + name: Optional[str] = None + ref = ref_dict[rev] + if ref: + try: + name = ExpRefInfo.from_ref(ref).name + except InvalidExpRefError: + pass + if not name: + if rev in self.stash_revs: + name = self.stash_revs[rev].name + elif rev in self.celery_queue.failed_stash.stash_revs: + name = self.celery_queue.failed_stash.stash_revs[rev].name + result[rev] = name + return result def get_running_exps(self, fetch_refs: bool = True) -> Dict[str, Any]: """Return info for running experiments.""" diff --git a/dvc/repo/experiments/show.py b/dvc/repo/experiments/show.py index 8cdf8b3537..9ffef26160 100644 --- a/dvc/repo/experiments/show.py +++ b/dvc/repo/experiments/show.py @@ -31,11 +31,9 @@ def _collect_experiment_commit( repo: "Repo", exp_rev: str, status: ExpStatus = ExpStatus.Success, - sha_only=True, param_deps=False, running=None, onerror: Optional[Callable] = None, - is_baseline: bool = False, ): from dvc.dependency import ParamsDependency, RepoDependency @@ -94,18 +92,6 @@ def _collect_experiment_commit( ) res["metrics"] = vals - if not sha_only and rev != "workspace": - name: Optional[str] = None - if is_baseline: - for refspec in ["refs/tags", "refs/heads"]: - name = repo.scm.describe(rev, base=refspec) - if name: - name = name.replace(f"{refspec}/", "") - break - name = name or repo.experiments.get_exact_name(rev) - if name: - res["name"] = name - return res @@ -154,6 +140,46 @@ def _collect_experiment_branch( return res +def get_names(repo: "Repo", result: Dict[str, Dict[str, Any]]): + + rev_set = set() + baseline_set = set() + for baseline in result: + for rev in result[baseline]: + if rev == "baseline": + rev = baseline + baseline_set.add(rev) + if rev != "workspace": + rev_set.add(rev) + + names: Dict[str, Optional[str]] = {} + for base in ("refs/tags/", "refs/heads/"): + if rev_set: + names.update( + (rev, ref[len(base) :]) + for rev, ref in repo.scm.describe( + baseline_set, base=base + ).items() + if ref is not None + ) + rev_set.difference_update(names.keys()) + + exact_name = repo.experiments.get_exact_name(rev_set) + + for baseline, baseline_results in result.items(): + for rev, rev_result in baseline_results.items(): + name: Optional[str] = None + if rev == "baseline": + rev = baseline + if rev == "workspace": + continue + name = names.get(rev, None) + name = name or exact_name[rev] + if name: + rev_result["data"]["name"] = name + + +# flake8: noqa: C901 def show( repo: "Repo", all_branches=False, @@ -194,12 +220,10 @@ def show( res[rev]["baseline"] = _collect_experiment_commit( repo, rev, - sha_only=sha_only, status=status, param_deps=param_deps, running=running, onerror=onerror, - is_baseline=True, ) if rev == "workspace": @@ -220,7 +244,6 @@ def show( repo, exp_ref, rev, - sha_only=sha_only, param_deps=param_deps, running=running, onerror=onerror, @@ -252,11 +275,14 @@ def show( experiment = _collect_experiment_commit( repo, stash_rev, - sha_only=sha_only, status=status, param_deps=param_deps, running=running, onerror=onerror, ) res[entry.baseline_rev][stash_rev] = experiment + + if not sha_only: + get_names(repo, res) + return res diff --git a/setup.cfg b/setup.cfg index 281cd25a5b..b5d592ee2d 100644 --- a/setup.cfg +++ b/setup.cfg @@ -63,7 +63,7 @@ install_requires = rich>=10.13.0 pyparsing>=2.4.7 typing-extensions>=3.7.4 - scmrepo==0.1.1 + scmrepo==0.1.2 dvc-render==0.0.12 dvc-task==0.1.4 dvclive>=0.10.0 diff --git a/tests/func/experiments/test_experiments.py b/tests/func/experiments/test_experiments.py index 48625bf31b..1291e06f8a 100644 --- a/tests/func/experiments/test_experiments.py +++ b/tests/func/experiments/test_experiments.py @@ -38,7 +38,7 @@ def test_new_simple(tmp_dir, scm, dvc, exp_stage, mocker, name, workspace): assert (tmp_dir / "metrics.yaml").read_text().strip() == "foo: 2" exp_name = name if name else ref_info.name - assert dvc.experiments.get_exact_name(exp) == exp_name + assert dvc.experiments.get_exact_name([exp])[exp] == exp_name assert resolve_rev(scm, exp_name) == exp @@ -481,7 +481,7 @@ def test_subdir(tmp_dir, scm, dvc, workspace): with fs.open("dir/metrics.yaml", mode="r", encoding="utf-8") as fobj: assert fobj.read().strip() == "foo: 2" - assert dvc.experiments.get_exact_name(exp) == ref_info.name + assert dvc.experiments.get_exact_name([exp])[exp] == ref_info.name assert resolve_rev(scm, ref_info.name) == exp @@ -525,7 +525,7 @@ def test_subrepo(tmp_dir, scm, workspace): with fs.open("dir/repo/metrics.yaml", mode="r", encoding="utf-8") as fobj: assert fobj.read().strip() == "foo: 2" - assert subrepo.dvc.experiments.get_exact_name(exp) == ref_info.name + assert subrepo.dvc.experiments.get_exact_name([exp])[exp] == ref_info.name assert resolve_rev(scm, ref_info.name) == exp diff --git a/tests/func/experiments/test_show.py b/tests/func/experiments/test_show.py index 546cad341b..21e59d0bef 100644 --- a/tests/func/experiments/test_show.py +++ b/tests/func/experiments/test_show.py @@ -245,7 +245,7 @@ def test_show_checkpoint( for i, rev in enumerate(checkpoints): if i == 0: - name = dvc.experiments.get_exact_name(rev) + name = dvc.experiments.get_exact_name([rev])[rev] name = f"{rev[:7]} [{name}]" fs = "╓" elif i == len(checkpoints) - 1: