Skip to content

Commit

Permalink
Allow for logging to Studio when not inside a repo (#646)
Browse files Browse the repository at this point in the history
* post to studio even without git/dvc repo

* tests for no-git scenario

* studio: make no-repo paths relative to cwd

* make ruff happy

* don't require exp name

* don't require baseline rev

* refactor studio path formatting

* live: Set new defaults `report=None` and `save_dvc_exp=True`.

* frameworks: Drop model_file.

* update examples

* Write to root dvc.yaml (#687)

* add dvcyaml to root

* clean up dvcyaml implementation

* fix existing tests

* add new tests

* add unit tests for updating dvcyaml

* use posix paths

* don't resolve symlinks

* drop entire dvclive dir on cleanup

* fix studio tests

* revert cleanup changes

* unify rel_path util func

* cleanup test

* refactor tests

* add test for multiple dvclive instances

* put dvc_file logic into _init_dvc_file

---------

Co-authored-by: daavoo <daviddelaiglesiacastro@gmail.com>

* report: Drop "auto" logic.

Fall back to `None` when conditions are not met for other types.

* studio: Extract `post_to_studio` and decouple from `make_report` (#705)

* refactor(tests): Split `test_main` into separate files.

Rename test_frameworks to frameworks.

* fix matplotlib warning

* fix studio tests

* fix windows studio paths

* fix windows studio paths for plots

* skip fabric tests if not installed

* drop dvc repo

* drop dvcignore

* drop unrelated test_fabric.py file

* fix windows paths

* fix windows paths

* adapt plot paths even if no dvc repo

* default baseline rev to all zeros

* consolidate repro tests

* set null sha as variable

* add type hints to studio

* limit windows path handling to studio

* fix typing errors in studio module

* fix mypy in live module

* drop checking for dvc_file

---------

Co-authored-by: daavoo <daviddelaiglesiacastro@gmail.com>
  • Loading branch information
dberenbaum and daavoo committed Feb 15, 2024
1 parent b8526e9 commit 35ce2f3
Show file tree
Hide file tree
Showing 5 changed files with 131 additions and 85 deletions.
16 changes: 13 additions & 3 deletions src/dvclive/dvc.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from pathlib import Path
from typing import TYPE_CHECKING, Any, List, Optional

from dvclive import env
from dvclive.plots import Image, Metric
from dvclive.serialize import dump_yaml
from dvclive.utils import StrPath, rel_path
Expand Down Expand Up @@ -131,9 +132,14 @@ def _update_entries(old, new, key):
def get_exp_name(name, scm, baseline_rev) -> str:
from dvc.exceptions import InvalidArgumentError
from dvc.repo.experiments.refs import ExpRefInfo
from dvc.repo.experiments.utils import check_ref_format, get_random_exp_name
from dvc.repo.experiments.utils import (
check_ref_format,
gen_random_name,
get_random_exp_name,
)

if name:
name = name or os.getenv(env.DVC_EXP_NAME)
if name and scm and baseline_rev:
ref = ExpRefInfo(baseline_sha=baseline_rev, name=name)
if scm.get_ref(str(ref)):
logger.warning(f"Experiment conflicts with existing experiment '{name}'.")
Expand All @@ -144,7 +150,11 @@ def get_exp_name(name, scm, baseline_rev) -> str:
logger.warning(e)
else:
return name
return get_random_exp_name(scm, baseline_rev)
if scm and baseline_rev:
return get_random_exp_name(scm, baseline_rev)
if name:
return name
return gen_random_name()


def find_overlapping_stage(dvc_repo: "Repo", path: StrPath) -> Optional["Stage"]:
Expand Down
47 changes: 17 additions & 30 deletions src/dvclive/live.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@

ParamLike = Union[int, float, str, bool, List["ParamLike"], Dict[str, "ParamLike"]]

NULL_SHA: str = "0" * 40


class Live:
def __init__(
Expand Down Expand Up @@ -136,8 +138,8 @@ def __init__(
self._report_notebook = None
self._init_report()

self._baseline_rev: Optional[str] = None
self._exp_name: Optional[str] = exp_name
self._baseline_rev: str = os.getenv(env.DVC_EXP_BASELINE_REV, NULL_SHA)
self._exp_name: Optional[str] = exp_name or os.getenv(env.DVC_EXP_NAME)
self._exp_message: Optional[str] = exp_message
self._experiment_rev: Optional[str] = None
self._inside_dvc_exp: bool = False
Expand All @@ -156,7 +158,7 @@ def __init__(
else:
self._init_cleanup()

self._latest_studio_step = self.step if resume else -1
self._latest_studio_step: int = self.step if resume else -1
self._studio_events_to_skip: Set[str] = set()
self._dvc_studio_config: Dict[str, Any] = {}
self._init_studio()
Expand Down Expand Up @@ -189,28 +191,36 @@ def _init_cleanup(self):
os.remove(dvc_file)

@catch_and_warn(DvcException, logger)
def _init_dvc(self):
def _init_dvc(self): # noqa: C901
from dvc.scm import NoSCM

if os.getenv(env.DVC_ROOT, None):
self._inside_dvc_pipeline = True
self._init_dvc_pipeline()
self._dvc_repo = get_dvc_repo()

scm = self._dvc_repo.scm if self._dvc_repo else None
if isinstance(scm, NoSCM):
scm = None
if scm:
self._baseline_rev = scm.get_rev()
self._exp_name = get_exp_name(self._exp_name, scm, self._baseline_rev)
logger.info(f"Logging to experiment '{self._exp_name}'")

dvc_logger = logging.getLogger("dvc")
dvc_logger.setLevel(os.getenv(env.DVCLIVE_LOGLEVEL, "WARNING").upper())

self._dvc_file = self._init_dvc_file()

if (self._dvc_repo is None) or isinstance(self._dvc_repo.scm, NoSCM):
if not scm:
if self._save_dvc_exp:
logger.warning(
"Can't save experiment without a Git Repo."
"\nCreate a Git repo (`git init`) and commit (`git commit`)."
)
self._save_dvc_exp = False
return
if self._dvc_repo.scm.no_commits:
if scm.no_commits:
if self._save_dvc_exp:
logger.warning(
"Can't save experiment to an empty Git Repo."
Expand All @@ -230,12 +240,7 @@ def _init_dvc(self):
if self._inside_dvc_pipeline:
return

self._baseline_rev = self._dvc_repo.scm.get_rev()
if self._save_dvc_exp:
self._exp_name = get_exp_name(
self._exp_name, self._dvc_repo.scm, self._baseline_rev
)
logger.info(f"Logging to experiment '{self._exp_name}'")
mark_dvclive_only_started(self._exp_name)
self._include_untracked.append(self.dir)

Expand All @@ -249,8 +254,6 @@ def _init_dvc_file(self) -> str:
def _init_dvc_pipeline(self):
if os.getenv(env.DVC_EXP_BASELINE_REV, None):
# `dvc exp` execution
self._baseline_rev = os.getenv(env.DVC_EXP_BASELINE_REV, "")
self._exp_name = os.getenv(env.DVC_EXP_NAME, "")
self._inside_dvc_exp = True
if self._save_dvc_exp:
logger.info("Ignoring `save_dvc_exp` because `dvc exp run` is running")
Expand All @@ -275,22 +278,6 @@ def _init_studio(self):
logger.debug("Skipping `studio` report `start` and `done` events.")
self._studio_events_to_skip.add("start")
self._studio_events_to_skip.add("done")
elif self._dvc_repo is None:
logger.warning(
"Can't connect to Studio without a DVC Repo."
"\nYou can create a DVC Repo by calling `dvc init`."
)
self._studio_events_to_skip.add("start")
self._studio_events_to_skip.add("data")
self._studio_events_to_skip.add("done")
elif not self._save_dvc_exp:
logger.warning(
"Can't connect to Studio without creating a DVC experiment."
"\nIf you have a DVC Pipeline, run it with `dvc exp run`."
)
self._studio_events_to_skip.add("start")
self._studio_events_to_skip.add("data")
self._studio_events_to_skip.add("done")
else:
self.post_to_studio("start")

Expand Down Expand Up @@ -840,7 +827,7 @@ def make_dvcyaml(self):
make_dvcyaml(self)

@catch_and_warn(DvcException, logger)
def post_to_studio(self, event: str):
def post_to_studio(self, event: Literal["start", "data", "done"]):
post_to_studio(self, event)

def end(self):
Expand Down
33 changes: 20 additions & 13 deletions src/dvclive/studio.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,28 @@
# ruff: noqa: SLF001
from __future__ import annotations
import base64
import logging
import math
import os
from pathlib import PureWindowsPath
from typing import TYPE_CHECKING, Literal, Mapping

from dvc_studio_client.config import get_studio_config
from dvc_studio_client.post_live_metrics import post_live_metrics

if TYPE_CHECKING:
from dvclive.live import Live
from dvclive.serialize import load_yaml
from dvclive.utils import parse_metrics, rel_path
from dvclive.utils import parse_metrics, rel_path, StrPath

logger = logging.getLogger("dvclive")


def _get_unsent_datapoints(plot, latest_step):
def _get_unsent_datapoints(plot: Mapping, latest_step: int):
return [x for x in plot if int(x["step"]) > latest_step]


def _cast_to_numbers(datapoints):
def _cast_to_numbers(datapoints: Mapping):
for datapoint in datapoints:
for k, v in datapoint.items():
if k == "step":
Expand All @@ -33,31 +38,33 @@ def _cast_to_numbers(datapoints):
return datapoints


def _adapt_path(live, name):
def _adapt_path(live: Live, name: StrPath):
if live._dvc_repo is not None:
name = rel_path(name, live._dvc_repo.root_dir)
if os.name == "nt":
name = str(PureWindowsPath(name).as_posix())
return name


def _adapt_plot_datapoints(live, plot):
def _adapt_plot_datapoints(live: Live, plot: Mapping):
datapoints = _get_unsent_datapoints(plot, live._latest_studio_step)
return _cast_to_numbers(datapoints)


def _adapt_image(image_path):
def _adapt_image(image_path: StrPath):
with open(image_path, "rb") as fobj:
return base64.b64encode(fobj.read()).decode("utf-8")


def _adapt_images(live):
def _adapt_images(live: Live):
return {
_adapt_path(live, image.output_path): {"image": _adapt_image(image.output_path)}
for image in live._images.values()
if image.step > live._latest_studio_step
}


def get_studio_updates(live):
def get_studio_updates(live: Live):
if os.path.isfile(live.params_file):
params_file = live.params_file
params_file = _adapt_path(live, params_file)
Expand All @@ -82,14 +89,14 @@ def get_studio_updates(live):
return metrics, params, plots


def get_dvc_studio_config(live):
def get_dvc_studio_config(live: Live):
config = {}
if live._dvc_repo:
config = live._dvc_repo.config.get("studio")
return get_studio_config(dvc_studio_config=config)


def post_to_studio(live, event):
def post_to_studio(live: Live, event: Literal["start", "data", "done"]):
if event in live._studio_events_to_skip:
return

Expand All @@ -98,7 +105,7 @@ def post_to_studio(live, event):
kwargs["message"] = live._exp_message
elif event == "data":
metrics, params, plots = get_studio_updates(live)
kwargs["step"] = live.step
kwargs["step"] = live.step # type: ignore
kwargs["metrics"] = metrics
kwargs["params"] = params
kwargs["plots"] = plots
Expand All @@ -108,10 +115,10 @@ def post_to_studio(live, event):
response = post_live_metrics(
event,
live._baseline_rev,
live._exp_name,
live._exp_name, # type: ignore
"dvclive",
dvc_studio_config=live._dvc_studio_config,
**kwargs,
**kwargs, # type: ignore
)
if not response:
logger.warning(f"`post_to_studio` `{event}` failed.")
Expand Down
42 changes: 8 additions & 34 deletions tests/test_dvc.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,27 +29,24 @@ def test_get_dvc_repo_subdir(tmp_dir):
def test_exp_save_on_end(tmp_dir, save, mocked_dvc_repo):
live = Live(save_dvc_exp=save)
live.end()
assert live._baseline_rev is not None
assert live._exp_name is not None
if save:
assert live._baseline_rev is not None
assert live._exp_name is not None
mocked_dvc_repo.experiments.save.assert_called_with(
name=live._exp_name,
include_untracked=[live.dir, "dvc.yaml"],
force=True,
message=None,
)
else:
assert live._baseline_rev is not None
assert live._exp_name is None
mocked_dvc_repo.experiments.save.assert_not_called()


def test_exp_save_skip_on_env_vars(tmp_dir, monkeypatch, mocker):
def test_exp_save_skip_on_env_vars(tmp_dir, monkeypatch):
monkeypatch.setenv(DVC_EXP_BASELINE_REV, "foo")
monkeypatch.setenv(DVC_EXP_NAME, "bar")
monkeypatch.setenv(DVC_ROOT, tmp_dir)

mocker.patch("dvclive.live.get_dvc_repo", return_value=None)
live = Live()
live.end()

Expand All @@ -60,31 +57,6 @@ def test_exp_save_skip_on_env_vars(tmp_dir, monkeypatch, mocker):
assert live._inside_dvc_pipeline


def test_exp_save_run_on_dvc_repro(tmp_dir, mocker):
dvc_repo = mocker.MagicMock()
dvc_stage = mocker.MagicMock()
dvc_file = mocker.MagicMock()
dvc_repo.index.stages = [dvc_stage, dvc_file]
dvc_repo.scm.get_rev.return_value = "current_rev"
dvc_repo.scm.get_ref.return_value = None
dvc_repo.scm.no_commits = False
dvc_repo.config = {}
dvc_repo.root_dir = tmp_dir
mocker.patch("dvclive.live.get_dvc_repo", return_value=dvc_repo)
live = Live()
assert live._save_dvc_exp
assert live._baseline_rev is not None
assert live._exp_name is not None
live.end()

dvc_repo.experiments.save.assert_called_with(
name=live._exp_name,
include_untracked=[live.dir, "dvc.yaml"],
force=True,
message=None,
)


def test_exp_save_with_dvc_files(tmp_dir, mocker):
dvc_repo = mocker.MagicMock()
dvc_file = mocker.MagicMock()
Expand Down Expand Up @@ -166,7 +138,7 @@ def test_errors_on_git_add_are_catched(tmp_dir, mocked_dvc_repo, monkeypatch):
mocked_dvc_repo.scm.untracked_files.return_value = ["dvclive/metrics.json"]
mocked_dvc_repo.scm.add.side_effect = DvcException("foo")

with Live(dvcyaml=False) as live:
with Live() as live:
live.summary["foo"] = 1


Expand Down Expand Up @@ -204,10 +176,12 @@ def test_no_scm_repo(tmp_dir, mocker):
assert live._save_dvc_exp is False


def test_dvc_repro(tmp_dir, monkeypatch, mocker):
def test_dvc_repro(tmp_dir, monkeypatch, mocked_dvc_repo, mocked_studio_post):
monkeypatch.setenv(DVC_ROOT, "root")
mocker.patch("dvclive.live.get_dvc_repo", return_value=None)
live = Live(save_dvc_exp=True)
assert live._baseline_rev is not None
assert live._exp_name is not None
assert not live._studio_events_to_skip
assert not live._save_dvc_exp


Expand Down
Loading

0 comments on commit 35ce2f3

Please sign in to comment.