Skip to content

Commit

Permalink
Finish this PR
Browse files Browse the repository at this point in the history
1. move env name to dvc/env.py.
2. add some tests for it.
  • Loading branch information
karajan1001 committed Jul 27, 2021
1 parent 15036f3 commit 205ce57
Show file tree
Hide file tree
Showing 5 changed files with 118 additions and 31 deletions.
1 change: 1 addition & 0 deletions dvc/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@
DVCLIVE_HTML = "DVCLIVE_HTML"
DVCLIVE_RESUME = "DVCLIVE_RESUME"
DVC_IGNORE_ISATTY = "DVC_IGNORE_ISATTY"
DVC_EXP_AUTO_PUSH = "DVC_EXP_AUTO_PUSH"
38 changes: 25 additions & 13 deletions dvc/repo/experiments/executor/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from funcy import cached_property

from dvc.env import DVC_EXP_AUTO_PUSH
from dvc.exceptions import DvcException
from dvc.path_info import PathInfo
from dvc.repo import Repo
Expand Down Expand Up @@ -451,6 +452,28 @@ def _repro_args(cls, dvc):
kwargs = {}
return args, kwargs

@staticmethod
def _auto_push(git_remote: str, dvc: "Repo", scm: "Git"):
from dvc.repo.experiments.push import push

if git_remote == scm.root_dir:
logger.warning(
f"{DVC_EXP_AUTO_PUSH} {git_remote} is the running "
"repository auto push will not work"
)
return

branch = scm.get_ref(EXEC_BRANCH, follow=False)
branch_name = ExpRefInfo.from_ref(branch).name
push(
dvc,
git_remote,
branch_name,
push_cache=True,
run_cache=True,
)
logger.info({"pushed": branch_name})

@classmethod
def checkpoint_callback(
cls,
Expand All @@ -467,20 +490,9 @@ def checkpoint_callback(
scm, exp_hash, exp_name=name, force=force, checkpoint=True
)

git_remote = os.environ.get("DVC_EXP_AUTO_PUSH", None)
git_remote = os.environ.get(DVC_EXP_AUTO_PUSH, None)
if git_remote:
from dvc.repo.experiments.push import push

branch = scm.get_ref(EXEC_BRANCH, follow=False)
branch_name = ExpRefInfo.from_ref(branch).name
push(
dvc,
git_remote,
branch_name,
push_cache=True,
run_cache=True,
)
logger.info({"pushed": branch_name})
cls._auto_push(git_remote, dvc, scm)
logger.info("Checkpoint experiment iteration '%s'.", exp_rev[:7])
except UnchangedExperimentError:
pass
Expand Down
18 changes: 18 additions & 0 deletions tests/func/experiments/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,3 +88,21 @@ def checkpoint_stage(tmp_dir, scm, dvc, mocker):
scm.commit("init")
stage.iterations = DEFAULT_ITERATIONS
return stage


@pytest.fixture
def git_upstream(tmp_dir, erepo_dir):
url = "file://{}".format(erepo_dir.resolve().as_posix())
tmp_dir.scm.gitpython.repo.create_remote("upstream", url)
erepo_dir.remote = "upstream"
erepo_dir.url = url
return erepo_dir


@pytest.fixture
def git_downstream(tmp_dir, erepo_dir):
url = "file://{}".format(tmp_dir.resolve().as_posix())
erepo_dir.scm.gitpython.repo.create_remote("upstream", url)
erepo_dir.remote = "upstream"
erepo_dir.url = url
return erepo_dir
74 changes: 74 additions & 0 deletions tests/func/experiments/test_checkpoints.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
import logging
import os

import pytest
from funcy import first

from dvc.exceptions import DvcException
from dvc.repo.experiments import MultipleBranchError
from dvc.repo.experiments.base import EXEC_APPLY, EXEC_CHECKPOINT
from dvc.repo.experiments.utils import exp_refs_by_rev


@pytest.mark.parametrize("workspace", [True, False])
Expand Down Expand Up @@ -188,3 +192,73 @@ def test_resume_non_head_checkpoint(
)
new_head = first(results)
assert orig_branch != dvc.experiments.get_branch_by_rev(new_head)


@pytest.mark.parametrize("use_url", [True, False])
def test_auto_push_during_iterations(
tmp_dir, scm, dvc, checkpoint_stage, git_upstream, local_remote, use_url
):
# set up remote repo
remote = git_upstream.url if use_url else git_upstream.remote
git_upstream.scm.fetch_refspecs(str(tmp_dir), ["master:master"])

# without auto push
results = dvc.experiments.run(checkpoint_stage.addressing)
exp = first(results)
ref_info = first(exp_refs_by_rev(scm, exp))
assert git_upstream.scm.get_ref(str(ref_info)) is None

# add auto push
os.environ["DVC_EXP_AUTO_PUSH"] = remote
results = dvc.experiments.run(checkpoint_stage.addressing)
assert (tmp_dir / "foo").read_text() == "4"
exp = first(results)
ref_info = first(exp_refs_by_rev(scm, exp))
assert git_upstream.scm.get_ref(str(ref_info)) == exp

# check the data
with git_upstream.dvc.config.edit() as conf:
conf["remote"]["local"] = local_remote.config
conf["core"]["remote"] = "local"

git_upstream.dvc.experiments.apply(ref_info.name)
git_upstream.dvc.experiments.apply(exp)
git_upstream.dvc.pull()
assert (git_upstream / "foo").read_text() == "4"

# resume the remote checkpoint
os.environ.pop("DVC_EXP_AUTO_PUSH")
with git_upstream.chdir():
git_upstream.dvc.experiments.run(checkpoint_stage.addressing)
assert (git_upstream / "foo").read_text() == "6"


def test_auto_push_error_url(dvc, scm, checkpoint_stage, local_remote):
os.environ["DVC_EXP_AUTO_PUSH"] = "none"
assert (
dvc.experiments.run(checkpoint_stage.addressing, params=["foo=2"])
== {}
)


def test_auto_push_no_remote(dvc, scm, checkpoint_stage, git_upstream):
os.environ["DVC_EXP_AUTO_PUSH"] = git_upstream.url
assert (
dvc.experiments.run(checkpoint_stage.addressing, params=["foo=2"])
== {}
)


def test_auto_push_self_remote(tmp_dir, dvc, scm, checkpoint_stage, caplog):
root_dir = str(tmp_dir)
os.environ["DVC_EXP_AUTO_PUSH"] = root_dir
assert (
dvc.experiments.run(checkpoint_stage.addressing, params=["foo=2"])
!= {}
)

with caplog.at_level(logging.WARNING, logger="dvc"):
assert (
f"DVC_EXP_AUTO_PUSH {root_dir} is the running "
"repository auto push will not work" in caplog.messages
)
18 changes: 0 additions & 18 deletions tests/func/experiments/test_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,24 +7,6 @@
from dvc.repo.experiments.utils import exp_refs_by_rev


@pytest.fixture
def git_upstream(tmp_dir, erepo_dir):
url = f"file://{erepo_dir.resolve().as_posix()}"
tmp_dir.scm.gitpython.repo.create_remote("upstream", url)
erepo_dir.remote = "upstream"
erepo_dir.url = url
return erepo_dir


@pytest.fixture
def git_downstream(tmp_dir, erepo_dir):
url = f"file://{tmp_dir.resolve().as_posix()}"
erepo_dir.scm.gitpython.repo.create_remote("upstream", url)
erepo_dir.remote = "upstream"
erepo_dir.url = url
return erepo_dir


@pytest.mark.parametrize("use_url", [True, False])
def test_push(tmp_dir, scm, dvc, git_upstream, exp_stage, use_url):
from dvc.exceptions import InvalidArgumentError
Expand Down

0 comments on commit 205ce57

Please sign in to comment.