Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

cache: cache in dvc exp if not stage output #660

Merged
merged 2 commits into from
Aug 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 14 additions & 12 deletions src/dvclive/live.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,31 +462,33 @@ def log_artifact(
def cache(self, path):
try:
if self._inside_dvc_exp:
msg = f"Skipping dvc add {path} because `dvc exp run` is running."
path_stage = None
for stage in self._dvc_repo.index.stages:
for out in stage.outs:
if out.fspath == str(Path(path).absolute()):
path_stage = stage
break
if not path_stage:
msg += (
"\nTo track it automatically during `dvc exp run`, "
"add it as an output of the pipeline stage."
if path_stage and path_stage.cmd:
msg = (
f"Skipping `dvc add {path}` because it is already being tracked"
" automatically as an output of `dvc exp run`."
)
logger.warning(msg)
elif path_stage.cmd:
msg += "\nIt is already being tracked automatically."
logger.info(msg)
else:
msg += (
"\nTo track it automatically during `dvc exp run`:"
return # skip caching
if path_stage:
msg = (
f"\nTo track '{path}' automatically during `dvc exp run`:"
f"\n1. Run `dvc exp remove {path_stage.addressing}` "
"to stop tracking it outside the pipeline."
"\n2. Add it as an output of the pipeline stage."
)
logger.warning(msg)
return
else:
msg = (
f"\nTo track '{path}' automatically during `dvc exp run`, "
"add it as an output of the pipeline stage."
)
logger.warning(msg)

stage = self._dvc_repo.add(str(path))

Expand Down
45 changes: 21 additions & 24 deletions tests/test_log_artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,44 +225,41 @@ def test_log_artifact_type_model_when_dvc_add_fails(tmp_dir, mocker, mocked_dvc_
}


def test_log_artifact_inside_exp(tmp_dir, mocked_dvc_repo):
data = tmp_dir / "data"
data.touch()
with Live() as live:
live._inside_dvc_exp = True
live.log_artifact("data")
mocked_dvc_repo.add.assert_not_called()


@pytest.mark.parametrize("tracked", ["data_source", "stage", None])
def test_log_artifact_inside_exp_logger(tmp_dir, mocker, dvc_repo, tracked):
def test_log_artifact_inside_exp(tmp_dir, mocker, dvc_repo, tracked):
logger = mocker.patch("dvclive.live.logger")
data = tmp_dir / "data"
data.touch()
if tracked == "data_source":
data = tmp_dir / "data"
data.touch()
dvc_repo.add(data)
elif tracked == "stage":
dvcyaml_path = tmp_dir / "dvc.yaml"
with open(dvcyaml_path, "w") as f:
f.write(dvcyaml)
with Live() as live:
live._inside_dvc_exp = True
live.log_artifact("data")
msg = "Skipping dvc add data because `dvc exp run` is running."
if tracked == "data_source":
msg += (
"\nTo track it automatically during `dvc exp run`:"
live = Live()
spy = mocker.spy(live._dvc_repo, "add")
live._inside_dvc_exp = True
live.log_artifact("data")
if tracked == "stage":
msg = (
"Skipping `dvc add data` because it is already being tracked"
" automatically as an output of `dvc exp run`."
)
logger.info.assert_called_with(msg)
spy.assert_not_called()
elif tracked == "data_source":
msg = (
"\nTo track 'data' automatically during `dvc exp run`:"
"\n1. Run `dvc exp remove data.dvc` "
"to stop tracking it outside the pipeline."
"\n2. Add it as an output of the pipeline stage."
)
logger.warning.assert_called_with(msg)
elif tracked == "stage":
msg += "\nIt is already being tracked automatically."
logger.info.assert_called_with(msg)
spy.assert_called_once()
else:
msg += (
"\nTo track it automatically during `dvc exp run`, "
msg = (
"\nTo track 'data' automatically during `dvc exp run`, "
"add it as an output of the pipeline stage."
)
logger.warning.assert_called_with(msg)
spy.assert_called_once()