diff --git a/src/dvclive/live.py b/src/dvclive/live.py index 42e79e96..6a8978ff 100644 --- a/src/dvclive/live.py +++ b/src/dvclive/live.py @@ -6,7 +6,9 @@ import math import os import shutil +import queue import tempfile +import threading from pathlib import Path, PurePath from typing import Any, Dict, List, Optional, Set, Tuple, Union, TYPE_CHECKING, Literal @@ -171,6 +173,7 @@ def __init__( self._studio_events_to_skip: Set[str] = set() self._dvc_studio_config: Dict[str, Any] = {} self._num_points_sent_to_studio: Dict[str, int] = {} + self._studio_queue = None self._init_studio() self._system_monitor: Optional[_SystemMonitor] = None # Monitoring thread @@ -296,7 +299,7 @@ def _init_studio(self): self._studio_events_to_skip.add("start") self._studio_events_to_skip.add("done") else: - self.post_to_studio("start") + post_to_studio(self, "start") def _init_report(self): if self._report_mode not in {None, "html", "notebook", "md"}: @@ -428,7 +431,7 @@ def sync(self): self.make_report() - self.post_to_studio("data") + self.post_data_to_studio() def next_step(self): """ @@ -880,9 +883,19 @@ def make_dvcyaml(self): """ make_dvcyaml(self) - @catch_and_warn(DvcException, logger) - def post_to_studio(self, event: Literal["start", "data", "done"]): - post_to_studio(self, event) + def post_data_to_studio(self): + if not self._studio_queue: + self._studio_queue = queue.Queue() + + def worker(): + while True: + item = self._studio_queue.get() + post_to_studio(item, "data") + self._studio_queue.task_done() + + threading.Thread(target=worker, daemon=True).start() + + self._studio_queue.put(self) def end(self): """ @@ -926,7 +939,7 @@ def end(self): self.save_dvc_exp() # Mark experiment as done - self.post_to_studio("done") + post_to_studio(self, "done") cleanup_dvclive_step_completed() diff --git a/src/dvclive/studio.py b/src/dvclive/studio.py index 31a61e5c..6ed398e7 100644 --- a/src/dvclive/studio.py +++ b/src/dvclive/studio.py @@ -7,9 +7,12 @@ from pathlib import PureWindowsPath from typing import TYPE_CHECKING, Literal, Mapping +from dvc.exceptions import DvcException from dvc_studio_client.config import get_studio_config from dvc_studio_client.post_live_metrics import post_live_metrics +from .utils import catch_and_warn + if TYPE_CHECKING: from dvclive.live import Live from dvclive.serialize import load_yaml @@ -96,6 +99,7 @@ def increment_num_points_sent_to_studio(live, plots): return live +@catch_and_warn(DvcException, logger) def post_to_studio(live: Live, event: Literal["start", "data", "done"]): # noqa: C901 if event in live._studio_events_to_skip: return diff --git a/tests/test_post_to_studio.py b/tests/test_post_to_studio.py index 4729ba5c..6aee0741 100644 --- a/tests/test_post_to_studio.py +++ b/tests/test_post_to_studio.py @@ -9,7 +9,7 @@ from dvclive import Live from dvclive.env import DVC_EXP_BASELINE_REV, DVC_EXP_NAME, DVC_ROOT from dvclive.plots import Image, Metric -from dvclive.studio import _adapt_image, get_dvc_studio_config +from dvclive.studio import _adapt_image, get_dvc_studio_config, post_to_studio def get_studio_call(event_type, exp_name, **kwargs): @@ -46,7 +46,9 @@ def test_post_to_studio(tmp_dir, mocked_dvc_repo, mocked_studio_post): ) live.log_metric("foo", 1) - live.next_step() + live.step = 0 + live.make_summary() + post_to_studio(live, "data") mocked_post.assert_called_with( "https://0.0.0.0/api/live", @@ -58,8 +60,10 @@ def test_post_to_studio(tmp_dir, mocked_dvc_repo, mocked_studio_post): ), ) + live.step += 1 live.log_metric("foo", 2) - live.next_step() + live.make_summary() + post_to_studio(live, "data") mocked_post.assert_called_with( "https://0.0.0.0/api/live", @@ -72,7 +76,8 @@ def test_post_to_studio(tmp_dir, mocked_dvc_repo, mocked_studio_post): ) mocked_post.reset_mock() - live.end() + live.save_dvc_exp() + post_to_studio(live, "done") mocked_post.assert_called_with( "https://0.0.0.0/api/live", @@ -118,11 +123,15 @@ def test_post_to_studio_failed_data_request( error_response.status_code = 400 mocker.patch("requests.post", return_value=error_response) live.log_metric("foo", 1) - live.next_step() + live.step = 0 + live.make_summary() + post_to_studio(live, "data") mocked_post = mocker.patch("requests.post", return_value=valid_response) + live.step += 1 live.log_metric("foo", 2) - live.next_step() + live.make_summary() + post_to_studio(live, "data") mocked_post.assert_called_with( "https://0.0.0.0/api/live", **get_studio_call( @@ -154,6 +163,7 @@ def test_post_to_studio_failed_start_request( live.next_step() assert mocked_post.call_count == 1 + assert live._studio_events_to_skip == {"start", "data", "done"} def test_post_to_studio_done_only_once(tmp_dir, mocked_dvc_repo, mocked_studio_post): @@ -210,7 +220,9 @@ def test_post_to_studio_dvc_studio_config( with Live() as live: live.log_metric("foo", 1) - live.next_step() + live.step = 0 + live.make_summary() + post_to_studio(live, "data") assert mocked_post.call_args.kwargs["headers"]["Authorization"] == "token token" @@ -231,7 +243,9 @@ def test_post_to_studio_skip_if_no_token( with Live() as live: live.log_metric("foo", 1) - live.next_step() + live.step = 0 + live.make_summary() + post_to_studio(live, "data") assert mocked_post.call_count == 0 @@ -241,7 +255,8 @@ def test_post_to_studio_shorten_names(tmp_dir, mocked_dvc_repo, mocked_studio_po live = Live() live.log_metric("eval/loss", 1) - live.next_step() + live.make_summary() + post_to_studio(live, "data") plots_path = Path(live.plots_dir) loss_path = (plots_path / Metric.subfolder / "eval/loss.tsv").as_posix() @@ -269,7 +284,9 @@ def test_post_to_studio_inside_dvc_exp( with Live() as live: live.log_metric("foo", 1) - live.next_step() + live.step = 0 + live.make_summary() + post_to_studio(live, "data") call_types = [call.kwargs["json"]["type"] for call in mocked_post.call_args_list] assert "start" not in call_types @@ -287,7 +304,8 @@ def test_post_to_studio_inside_subdir( live = Live() live.log_metric("foo", 1) - live.next_step() + live.make_summary() + post_to_studio(live, "data") foo_path = (Path(live.plots_dir) / Metric.subfolder / "foo.tsv").as_posix() @@ -317,7 +335,8 @@ def test_post_to_studio_inside_subdir_dvc_exp( live = Live() live.log_metric("foo", 1) - live.next_step() + live.make_summary() + post_to_studio(live, "data") foo_path = (Path(live.plots_dir) / Metric.subfolder / "foo.tsv").as_posix() @@ -370,7 +389,9 @@ def test_post_to_studio_images(tmp_dir, mocked_dvc_repo, mocked_studio_post): live = Live() live.log_image("foo.png", ImagePIL.new("RGB", (10, 10), (0, 0, 0))) - live.next_step() + live.step = 0 + live.make_summary() + post_to_studio(live, "data") foo_path = (Path(live.plots_dir) / Image.subfolder / "foo.png").as_posix() @@ -409,11 +430,13 @@ def test_post_to_studio_name(tmp_dir, mocked_dvc_repo, mocked_studio_post): def test_post_to_studio_if_done_skipped(tmp_dir, mocked_dvc_repo, mocked_studio_post): - live = Live() - live._studio_events_to_skip.add("start") - live._studio_events_to_skip.add("done") - live.log_metric("foo", 1) - live.end() + with Live() as live: + live._studio_events_to_skip.add("start") + live._studio_events_to_skip.add("done") + live.log_metric("foo", 1) + live.step = 0 + live.make_summary() + post_to_studio(live, "data") mocked_post, _ = mocked_studio_post call_types = [call.kwargs["json"]["type"] for call in mocked_post.call_args_list] @@ -439,8 +462,9 @@ def test_post_to_studio_no_repo(tmp_dir, monkeypatch, mocked_studio_post): ) live.log_metric("foo", 1) + live.make_summary() + post_to_studio(live, "data") - live.next_step() mocked_post.assert_called_with( "https://0.0.0.0/api/live", **get_studio_call( @@ -452,9 +476,11 @@ def test_post_to_studio_no_repo(tmp_dir, monkeypatch, mocked_studio_post): ), ) + live.step += 1 live.log_metric("foo", 2) + live.make_summary() + post_to_studio(live, "data") - live.next_step() mocked_post.assert_called_with( "https://0.0.0.0/api/live", **get_studio_call( @@ -466,7 +492,7 @@ def test_post_to_studio_no_repo(tmp_dir, monkeypatch, mocked_studio_post): ), ) - live.end() + post_to_studio(live, "done") mocked_post.assert_called_with( "https://0.0.0.0/api/live", **get_studio_call("done", baseline_sha="0" * 40, exp_name=live._exp_name), @@ -485,7 +511,9 @@ def test_post_to_studio_skip_if_no_repo_url( with Live() as live: live.log_metric("foo", 1) - live.next_step() + live.step = 0 + live.make_summary() + post_to_studio(live, "data") assert mocked_post.call_count == 0 @@ -503,7 +531,8 @@ def test_post_to_studio_repeat_step(tmp_dir, mocked_dvc_repo, mocked_studio_post live.step = 0 live.log_metric("foo", 1) live.log_metric("bar", 0.1) - live.sync() + live.make_summary() + post_to_studio(live, "data") mocked_post.assert_called_with( "https://0.0.0.0/api/live", @@ -521,7 +550,8 @@ def test_post_to_studio_repeat_step(tmp_dir, mocked_dvc_repo, mocked_studio_post live.log_metric("foo", 2) live.log_metric("foo", 3) live.log_metric("bar", 0.2) - live.sync() + live.make_summary() + post_to_studio(live, "data") mocked_post.assert_called_with( "https://0.0.0.0/api/live",