diff --git a/pyproject.toml b/pyproject.toml index f572a663..3121b8cb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,13 +31,13 @@ classifiers = [ ] dynamic = ["version"] dependencies = [ - "dvc>=3.33.3", + "dvc@git+https://github.com/iterative/dvc.git@refs/pull/10303/head", "dvc-render>=1.0.0,<2", - "dvc-studio-client>=0.17.1,<1", + "dvc-studio-client>=0.20,<1", "funcy", "gto", "ruamel.yaml", - "scmrepo" + "scmrepo>=3,<4" ] [project.optional-dependencies] diff --git a/src/dvclive/live.py b/src/dvclive/live.py index f2b12fb7..0cc0b17a 100644 --- a/src/dvclive/live.py +++ b/src/dvclive/live.py @@ -17,6 +17,7 @@ import PIL from dvc.exceptions import DvcException +from dvc.utils.studio import get_subrepo_relpath from funcy import set_in from ruamel.yaml.representer import RepresenterError @@ -141,6 +142,7 @@ def __init__( self._baseline_rev: str = os.getenv(env.DVC_EXP_BASELINE_REV, NULL_SHA) self._exp_name: Optional[str] = exp_name or os.getenv(env.DVC_EXP_NAME) self._exp_message: Optional[str] = exp_message + self._subdir: Optional[str] = None self._experiment_rev: Optional[str] = None self._inside_dvc_exp: bool = False self._inside_dvc_pipeline: bool = False @@ -240,6 +242,8 @@ def _init_dvc(self): # noqa: C901 if self._inside_dvc_pipeline: return + self._subdir = get_subrepo_relpath(self._dvc_repo) + if self._save_dvc_exp: mark_dvclive_only_started(self._exp_name) self._include_untracked.append(self.dir) diff --git a/src/dvclive/studio.py b/src/dvclive/studio.py index ca63e4b0..74de6d54 100644 --- a/src/dvclive/studio.py +++ b/src/dvclive/studio.py @@ -101,8 +101,11 @@ def post_to_studio(live: Live, event: Literal["start", "data", "done"]): return kwargs = {} - if event == "start" and live._exp_message: - kwargs["message"] = live._exp_message + if event == "start": + if message := live._exp_message: + kwargs["message"] = message + if subdir := live._subdir: + kwargs["subdir"] = subdir elif event == "data": metrics, params, plots = get_studio_updates(live) kwargs["step"] = live.step # type: ignore diff --git a/tests/conftest.py b/tests/conftest.py index 1c856e45..8b0e0dc7 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -4,6 +4,8 @@ import pytest from dvc_studio_client.env import DVC_STUDIO_TOKEN, DVC_STUDIO_URL, STUDIO_REPO_URL +from dvclive.utils import rel_path + @pytest.fixture() def tmp_dir(tmp_path, monkeypatch): @@ -19,12 +21,19 @@ def mocked_dvc_repo(tmp_dir, mocker): _dvc_repo.scm.get_ref.return_value = None _dvc_repo.scm.no_commits = False _dvc_repo.experiments.save.return_value = "e" * 40 - _dvc_repo.root_dir = tmp_dir + _dvc_repo.root_dir = _dvc_repo.scm.root_dir = tmp_dir + _dvc_repo.fs.relpath = rel_path _dvc_repo.config = {} mocker.patch("dvclive.live.get_dvc_repo", return_value=_dvc_repo) return _dvc_repo +@pytest.fixture() +def mocked_dvc_subrepo(tmp_dir, mocker, mocked_dvc_repo): + mocked_dvc_repo.root_dir = tmp_dir / "subdir" + return mocked_dvc_repo + + @pytest.fixture() def dvc_repo(tmp_dir): from dvc.repo import Repo diff --git a/tests/test_post_to_studio.py b/tests/test_post_to_studio.py index edc02a25..eae0e426 100644 --- a/tests/test_post_to_studio.py +++ b/tests/test_post_to_studio.py @@ -81,6 +81,18 @@ def test_post_to_studio(tmp_dir, mocked_dvc_repo, mocked_studio_post): ) +def test_post_to_studio_subrepo(tmp_dir, mocked_dvc_subrepo, mocked_studio_post): + live = Live() + live.log_param("fooparam", 1) + + mocked_post, _ = mocked_studio_post + + mocked_post.assert_called_with( + "https://0.0.0.0/api/live", + **get_studio_call("start", exp_name=live._exp_name, subdir="subdir"), + ) + + def test_post_to_studio_failed_data_request( tmp_dir, mocker, mocked_dvc_repo, mocked_studio_post ):