Skip to content

Commit

Permalink
dvc: Decouple dvc_api_available from get_dvc_repo.
Browse files Browse the repository at this point in the history
Closes #473
  • Loading branch information
daavoo committed Mar 2, 2023
1 parent b38478b commit 1f6b9bd
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 54 deletions.
15 changes: 10 additions & 5 deletions src/dvclive/dvc.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,14 +73,19 @@ def make_checkpoint():
sleep(_CHECKPOINT_SLEEP)


def get_dvc_repo():
def dvc_api_available() -> bool:
# noqa pylint: disable=unused-import
try:
from dvc.exceptions import NotDvcRepoError
from dvc.repo import Repo
from dvc.scm import SCMError
import dvc # noqa: F401
except ImportError:
return None
return False
return True


def get_dvc_repo():
from dvc.exceptions import NotDvcRepoError
from dvc.repo import Repo
from dvc.scm import SCMError

try:
return Repo()
Expand Down
90 changes: 41 additions & 49 deletions src/dvclive/live.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,13 @@
from pathlib import Path
from typing import Any, Dict, List, Optional, Set, Union

from dvc_studio_client.env import STUDIO_TOKEN
from dvc_studio_client.post_live_metrics import post_live_metrics
from ruamel.yaml.representer import RepresenterError

from . import env
from .dvc import (
dvc_api_available,
get_dvc_repo,
get_random_exp_name,
make_checkpoint,
Expand All @@ -29,13 +32,6 @@
open_file_in_browser,
)

try:
from dvc_studio_client.env import STUDIO_TOKEN
from dvc_studio_client.post_live_metrics import post_live_metrics
except ImportError:
post_live_metrics = None
STUDIO_TOKEN = None

logging.basicConfig()
logger = logging.getLogger("dvclive")
logger.setLevel(os.getenv(env.DVCLIVE_LOGLEVEL, "INFO").upper())
Expand Down Expand Up @@ -110,11 +106,6 @@ def _init_cleanup(self):
os.remove(f)

def _init_dvc(self):
self._dvc_repo = get_dvc_repo()

if self._dvc_repo is not None:
self._baseline_rev = self._dvc_repo.scm.get_rev()

if os.getenv(env.DVC_EXP_BASELINE_REV, None):
# `dvc exp` execution
self._baseline_rev = os.getenv(env.DVC_EXP_BASELINE_REV, "")
Expand All @@ -124,50 +115,56 @@ def _init_dvc(self):
logger.warning(
"Ignoring `_save_dvc_exp` because `dvc exp run` is running"
)
elif self._save_dvc_exp:
if self._dvc_repo is not None:
# `DVCLive Only` or `dvc repro` execution
self._exp_name = get_random_exp_name(
self._dvc_repo.scm, self._baseline_rev
)
mark_dvclive_only_started()
else:
self._save_dvc_exp = False
return

self._dvc_repo = get_dvc_repo()
if self._dvc_repo is None:
if self._save_dvc_exp:
if not dvc_api_available():
logger.warning(
"Can't save experiment without the DVC Python API."
"\nYou can install it by calling `pip install dvc`."
)
logger.warning(
"Can't save experiment without a DVC Repo."
"\nYou can create a DVC Repo by calling `dvc init`."
)
self._save_dvc_exp = False
return

def _init_studio(self):
if post_live_metrics is not None:
if not os.getenv(STUDIO_TOKEN, None):
logger.debug("Skipping `studio` report.")
self._studio_events_to_skip.add("start")
self._studio_events_to_skip.add("data")
self._studio_events_to_skip.add("done")
return
self._baseline_rev = self._dvc_repo.scm.get_rev()
if self._save_dvc_exp:
self._exp_name = get_random_exp_name(self._dvc_repo.scm, self._baseline_rev)
mark_dvclive_only_started()

if not (self._dvc_repo or self._inside_dvc_exp):
logger.debug("`studio` report can't be used without a DVC Repo.")
def _init_studio(self):
if not os.getenv(STUDIO_TOKEN, None):
logger.debug("Missing env var `STUDIO_TOKEN`, skipping `studio` report.")
self._studio_events_to_skip.add("start")
self._studio_events_to_skip.add("data")
self._studio_events_to_skip.add("done")
return

if self._inside_dvc_exp:
elif self._inside_dvc_exp:
logger.debug("Skipping `studio` report `start` and `done` events.")
self._studio_events_to_skip.add("start")
self._studio_events_to_skip.add("done")
else:
response = False
if post_live_metrics is not None:
response = post_live_metrics(
"start", self._baseline_rev, self._exp_name, "dvclive"
)
else:
logger.debug(
"`dvc_studio_client` is not installed.\n"
"You can install it with `pip install dvc-studio-client`."
elif self._dvc_repo is None:
if not dvc_api_available():
logger.warning(
"Can't send updates to Studio without the DVC Python API."
"\nYou can install it by calling `pip install dvc`."
)
logger.warning(
"Can't send updates to Studio without a DVC Repo."
"\nYou can create a DVC Repo by calling `dvc init`."
)
self._studio_events_to_skip.add("start")
self._studio_events_to_skip.add("data")
self._studio_events_to_skip.add("done")
else:
response = post_live_metrics(
"start", self._baseline_rev, self._exp_name, "dvclive"
)
if not response:
logger.debug(
"`studio` report `start` event failed. "
Expand Down Expand Up @@ -384,15 +381,10 @@ def end(self):
logger.warning("`post_to_studio` `done` event failed.")
self._studio_events_to_skip.add("done")
self._studio_events_to_skip.add("data")

else:
self.make_report()

if (
self._dvc_repo is not None
and not self._inside_dvc_exp
and self._save_dvc_exp
):
if self._save_dvc_exp:
from dvc.exceptions import DvcException

try:
Expand Down
6 changes: 6 additions & 0 deletions tests/test_dvc.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,3 +197,9 @@ def test_exp_save_dvcexception_is_ignored(tmp_dir, mocker):

with Live(save_dvc_exp=True):
pass


def test_exp_save_skipped_if_not_dvc_api_available(tmp_dir, mocker):
mocker.patch("dvclive.live.dvc_api_available", return_value=False)
live = Live(save_dvc_exp=True)
assert not live._save_dvc_exp

0 comments on commit 1f6b9bd

Please sign in to comment.