Merged
10 changes: 10 additions & 0 deletions dvc/command/repro.py
@@ -44,6 +44,7 @@ def run(self):
run_all=self.args.run_all,
jobs=self.args.jobs,
params=self.args.params,
pull=self.args.pull,
)

if len(stages) == 0:
@@ -198,4 +199,13 @@ def add_parser(subparsers, parent_parser):
repro_parser.add_argument(
"-j", "--jobs", type=int, help=argparse.SUPPRESS, metavar="<number>"
)
repro_parser.add_argument(
"--pull",
action="store_true",
default=False,
help=(
"Try automatically pulling missing cache for outputs restored "
"from the run-cache."
),
)
repro_parser.set_defaults(func=CmdRepro)
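
For context, a minimal standalone sketch (not DVC code) of the behaviour the new argparse option gives `dvc repro`: `--pull` is off by default, and `store_true` flips it on when the flag is passed.

```python
# Minimal standalone sketch (not DVC code): the new option defaults to False
# and `store_true` sets it to True when `--pull` is given on the command line.
import argparse

parser = argparse.ArgumentParser(prog="dvc repro")
parser.add_argument("--pull", action="store_true", default=False)

assert parser.parse_args([]).pull is False
assert parser.parse_args(["--pull"]).pull is True
```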
7 changes: 3 additions & 4 deletions dvc/repo/fetch.py
@@ -36,8 +36,6 @@ def fetch(
remote is configured
"""

used_run_cache = self.stage_cache.pull(remote) if run_cache else []

if isinstance(targets, str):
targets = [targets]

@@ -51,13 +49,14 @@
remote=remote,
jobs=jobs,
recursive=recursive,
used_run_cache=used_run_cache,
)

downloaded = 0
failed = 0

try:
if run_cache:
self.stage_cache.pull(remote)
downloaded += self.cloud.pull(
used, jobs, remote=remote, show_checksums=show_checksums,
)
@@ -75,7 +74,7 @@
if failed:
raise DownloadError(failed)

return downloaded + len(used_run_cache)
return downloaded


def _fetch_external(self, repo_url, repo_rev, files, jobs):
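
A hedged toy of the reworked accounting in `fetch()` (the names below are illustrative, not DVC's API): the run-cache pull is now a side effect inside the same error-handling block as the data pull, and its entries no longer inflate the returned download count, which is why `test_push_pull` further down now expects `fetched == 0` instead of 2.

```python
# Illustrative stand-in, not dvc/repo/fetch.py itself: run-cache entries are
# pulled for their side effect only; the return value counts data downloads.
def toy_fetch(data_objs, run_cache_objs, run_cache=True):
    pulled_run_cache = []
    if run_cache:
        pulled_run_cache = list(run_cache_objs)  # pulled, but not counted
    downloaded = len(data_objs)                  # only real data downloads
    return downloaded


assert toy_fetch(data_objs=[], run_cache_objs=["entry-a", "entry-b"]) == 0
```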
4 changes: 2 additions & 2 deletions dvc/stage/__init__.py
@@ -421,14 +421,14 @@ def commit(self):
out.commit()

@rwlocked(read=["deps"], write=["outs"])
def run(self, dry=False, no_commit=False, force=False, run_cache=True):
def run(self, dry=False, no_commit=False, force=False, **kwargs):
if (self.cmd or self.is_import) and not self.frozen and not dry:
self.remove_outs(ignore_remove=False, force=False)

if not self.frozen and self.is_import:
sync_import(self, dry, force)
elif not self.frozen and self.cmd:
run_stage(self, dry, force, run_cache)
run_stage(self, dry, force, **kwargs)
else:
args = (
("outputs", "frozen ") if self.frozen else ("data sources", "")
38 changes: 21 additions & 17 deletions dvc/stage/cache.py
@@ -156,33 +156,37 @@ def save(self, stage):
dump_yaml(tmp, cache)
self.tree.move(PathInfo(tmp), path)

def _restore(self, stage):
stage.save_deps()
cache = self._load(stage)
if not cache:
raise RunCacheNotFoundError(stage)

StageLoader.fill_from_lock(stage, cache)
for out in self._uncached_outs(stage, cache):
out.checkout()

if not stage.outs_cached():
raise RunCacheNotFoundError(stage)

def restore(self, stage, run_cache=True):
def restore(self, stage, run_cache=True, pull=False):
if stage.is_callback or stage.always_changed:
raise RunCacheNotFoundError(stage)

if not stage.already_cached():
if (
not stage.changed_stage()
and stage.deps_cached()
and all(bool(out.hash_info) for out in stage.outs)
):
cache = to_single_stage_lockfile(stage)
else:
if not run_cache: # backward compatibility
raise RunCacheNotFoundError(stage)
self._restore(stage)
stage.save_deps()
cache = self._load(stage)
if not cache:
raise RunCacheNotFoundError(stage)

cached_stage = self._create_stage(cache, wdir=stage.wdir)

if pull:
self.repo.cloud.pull(cached_stage.get_used_cache())

if not cached_stage.outs_cached():
raise RunCacheNotFoundError(stage)

logger.info(
"Stage '%s' is cached - skipping run, checking out outputs",
stage.addressing,
)
stage.checkout()
cached_stage.checkout()

@staticmethod
def _transfer(func, from_remote, to_remote):
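
The key behavioural change in `restore()` is the ordering: a cached stage is built from the run-cache entry, its outputs are optionally pulled from the remote when `pull=True` (via `self.repo.cloud.pull(cached_stage.get_used_cache())` in the hunk above), and only then is `outs_cached()` checked. A standalone toy of that ordering, using plain sets as stand-ins for the local cache, the remote, and the stage's outputs; none of this is DVC's internal API.

```python
# Standalone toy of the new restore() ordering -- sets stand in for caches.
class RunCacheNotFoundError(Exception):
    pass


def toy_restore(local_cache, remote_cache, needed_outs, pull=False):
    if pull:
        # analogous to: self.repo.cloud.pull(cached_stage.get_used_cache())
        local_cache |= remote_cache & needed_outs
    if not needed_outs <= local_cache:
        # outputs still missing -> fall back to actually running the stage
        raise RunCacheNotFoundError(needed_outs - local_cache)
    return "checked out"       # analogous to: cached_stage.checkout()


local = set()                  # cold local cache, outputs missing
remote = {"md5-of-bar"}        # but they exist on the remote
assert toy_restore(local, remote, {"md5-of-bar"}, pull=True) == "checked out"
```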
4 changes: 2 additions & 2 deletions dvc/stage/run.py
@@ -81,12 +81,12 @@ def cmd_run(stage, *args, **kwargs):
raise StageCmdFailedError(stage.cmd, retcode)


def run_stage(stage, dry=False, force=False, run_cache=False):
def run_stage(stage, dry=False, force=False, **kwargs):
if not (dry or force):
from .cache import RunCacheNotFoundError

try:
stage.repo.stage_cache.restore(stage, run_cache=run_cache)
stage.repo.stage_cache.restore(stage, **kwargs)
return
except RunCacheNotFoundError:
pass
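
In both `Stage.run()` above and `run_stage()` here, the explicit `run_cache=run_cache` parameter is replaced by plain `**kwargs` forwarding, so whatever restore-related keywords the caller passes (`run_cache`, and now `pull`) flow straight through to `StageCache.restore()`. A tiny self-contained illustration of that plumbing; the function bodies are stand-ins.

```python
# Stand-in functions showing the **kwargs pass-through this refactor relies on.
def restore(stage, run_cache=True, pull=False):
    return {"stage": stage, "run_cache": run_cache, "pull": pull}


def run_stage(stage, dry=False, force=False, **kwargs):
    # dry/force stay explicit; restore-related options are just forwarded
    return restore(stage, **kwargs)


assert run_stage("copy-foo-bar", pull=True) == {
    "stage": "copy-foo-bar",
    "run_cache": True,
    "pull": True,
}
```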
37 changes: 31 additions & 6 deletions tests/func/test_run_cache.py
@@ -2,6 +2,7 @@

from dvc.dvcfile import PIPELINE_LOCK
from dvc.utils import relpath
from dvc.utils.fs import remove


def _recurse_count_files(path):
@@ -15,7 +16,7 @@ def test_push_pull(tmp_dir, dvc, erepo_dir, run_copy, local_remote):
erepo_dir.add_remote(config=local_remote.config)
with erepo_dir.chdir():
assert not os.path.exists(erepo_dir.dvc.stage_cache.cache_dir)
assert erepo_dir.dvc.pull(run_cache=True)["fetched"] == 2
assert erepo_dir.dvc.pull(run_cache=True)["fetched"] == 0
assert os.listdir(erepo_dir.dvc.stage_cache.cache_dir)


@@ -32,7 +33,7 @@ def test_restore(tmp_dir, dvc, run_copy, mocker):

(stage,) = dvc.reproduce("copy-foo-bar")

mock_restore.assert_called_once_with(stage, run_cache=True)
mock_restore.assert_called_once_with(stage)
mock_run.assert_not_called()
assert (tmp_dir / "bar").exists() and not (tmp_dir / "foo").unlink()
assert (tmp_dir / PIPELINE_LOCK).exists()
@@ -103,7 +104,7 @@ def test_memory_for_multiple_runs_of_same_stage(
assert (tmp_dir / PIPELINE_LOCK).exists()
assert (tmp_dir / "bar").read_text() == "foobar"
mock_run.assert_not_called()
mock_restore.assert_called_once_with(stage, run_cache=True)
mock_restore.assert_called_once_with(stage)
mock_restore.reset_mock()

(tmp_dir / PIPELINE_LOCK).unlink()
@@ -112,7 +113,7 @@

assert (tmp_dir / "bar").read_text() == "foo"
mock_run.assert_not_called()
mock_restore.assert_called_once_with(stage, run_cache=True)
mock_restore.assert_called_once_with(stage)
assert (tmp_dir / "bar").exists() and not (tmp_dir / "foo").unlink()
assert (tmp_dir / PIPELINE_LOCK).exists()

@@ -141,12 +142,36 @@ def test_memory_runs_of_multiple_stages(tmp_dir, dvc, run_copy, mocker):
assert (tmp_dir / "foo.bak").read_text() == "foo"
assert (tmp_dir / PIPELINE_LOCK).exists()
mock_run.assert_not_called()
mock_restore.assert_called_once_with(stage, run_cache=True)
mock_restore.assert_called_once_with(stage)
mock_restore.reset_mock()

(stage,) = dvc.reproduce("backup-bar")

assert (tmp_dir / "bar.bak").read_text() == "bar"
assert (tmp_dir / PIPELINE_LOCK).exists()
mock_run.assert_not_called()
mock_restore.assert_called_once_with(stage, run_cache=True)
mock_restore.assert_called_once_with(stage)


def test_restore_pull(tmp_dir, dvc, run_copy, mocker, local_remote):
tmp_dir.gen("foo", "foo")
stage = run_copy("foo", "bar", name="copy-foo-bar")

dvc.push()

mock_restore = mocker.spy(dvc.stage_cache, "restore")
mock_run = mocker.patch("dvc.stage.run.cmd_run")
mock_checkout = mocker.spy(dvc.cache.local, "checkout")

# removing any information that `dvc` could use to re-generate from
(tmp_dir / "bar").unlink()
(tmp_dir / PIPELINE_LOCK).unlink()
remove(stage.outs[0].cache_path)

(stage,) = dvc.reproduce("copy-foo-bar", pull=True)

mock_restore.assert_called_once_with(stage, pull=True)
mock_run.assert_not_called()
mock_checkout.assert_called_once()
assert (tmp_dir / "bar").exists() and not (tmp_dir / "foo").unlink()
assert (tmp_dir / PIPELINE_LOCK).exists()
1 change: 1 addition & 0 deletions tests/unit/command/test_repro.py
@@ -19,6 +19,7 @@
"queue": False,
"run_all": False,
"jobs": None,
"pull": False,
}


12 changes: 6 additions & 6 deletions tests/unit/stage/test_cache.py
@@ -40,12 +40,12 @@ def test_stage_cache(tmp_dir, dvc, mocker):
assert os.path.isfile(cache_file)

run_spy = mocker.patch("dvc.stage.run.cmd_run")
checkout_spy = mocker.spy(stage, "checkout")
checkout_spy = mocker.spy(dvc.cache.local, "checkout")
with dvc.lock, dvc.state:
stage.run()

assert not run_spy.called
assert checkout_spy.call_count == 1
assert checkout_spy.call_count == 2

assert (tmp_dir / "out").exists()
assert (tmp_dir / "out_no_cache").exists()
@@ -93,12 +93,12 @@ def test_stage_cache_params(tmp_dir, dvc, mocker):
assert os.path.isfile(cache_file)

run_spy = mocker.patch("dvc.stage.run.cmd_run")
checkout_spy = mocker.spy(stage, "checkout")
checkout_spy = mocker.spy(dvc.cache.local, "checkout")
with dvc.lock, dvc.state:
stage.run()

assert not run_spy.called
assert checkout_spy.call_count == 1
assert checkout_spy.call_count == 2

assert (tmp_dir / "out").exists()
assert (tmp_dir / "out_no_cache").exists()
@@ -147,12 +147,12 @@ def test_stage_cache_wdir(tmp_dir, dvc, mocker):
assert os.path.isfile(cache_file)

run_spy = mocker.patch("dvc.stage.run.cmd_run")
checkout_spy = mocker.spy(stage, "checkout")
checkout_spy = mocker.spy(dvc.cache.local, "checkout")
with dvc.lock, dvc.state:
stage.run()

assert not run_spy.called
assert checkout_spy.call_count == 1
assert checkout_spy.call_count == 2

assert (tmp_dir / "wdir" / "out").exists()
assert (tmp_dir / "wdir" / "out_no_cache").exists()