From 3e16bb25eb174dc881f869fc60abd2892c4017c4 Mon Sep 17 00:00:00 2001 From: dberenbaum Date: Thu, 18 Apr 2024 08:55:37 -0400 Subject: [PATCH 1/4] post to studio in thread to avoid blocking --- src/dvclive/live.py | 4 +++- tests/test_post_to_studio.py | 45 +++++++++++++++++++++++++----------- 2 files changed, 34 insertions(+), 15 deletions(-) diff --git a/src/dvclive/live.py b/src/dvclive/live.py index 42e79e96..7d0c3415 100644 --- a/src/dvclive/live.py +++ b/src/dvclive/live.py @@ -7,6 +7,7 @@ import os import shutil import tempfile +import threading from pathlib import Path, PurePath from typing import Any, Dict, List, Optional, Set, Tuple, Union, TYPE_CHECKING, Literal @@ -882,7 +883,8 @@ def make_dvcyaml(self): @catch_and_warn(DvcException, logger) def post_to_studio(self, event: Literal["start", "data", "done"]): - post_to_studio(self, event) + thread = threading.Thread(target=post_to_studio, args=(self, event)) + thread.start() def end(self): """ diff --git a/tests/test_post_to_studio.py b/tests/test_post_to_studio.py index 4729ba5c..815cc2b7 100644 --- a/tests/test_post_to_studio.py +++ b/tests/test_post_to_studio.py @@ -9,7 +9,7 @@ from dvclive import Live from dvclive.env import DVC_EXP_BASELINE_REV, DVC_EXP_NAME, DVC_ROOT from dvclive.plots import Image, Metric -from dvclive.studio import _adapt_image, get_dvc_studio_config +from dvclive.studio import _adapt_image, get_dvc_studio_config, post_to_studio def get_studio_call(event_type, exp_name, **kwargs): @@ -46,7 +46,9 @@ def test_post_to_studio(tmp_dir, mocked_dvc_repo, mocked_studio_post): ) live.log_metric("foo", 1) - live.next_step() + live.step = 0 + live.make_summary() + post_to_studio(live, "data") mocked_post.assert_called_with( "https://0.0.0.0/api/live", @@ -58,8 +60,10 @@ def test_post_to_studio(tmp_dir, mocked_dvc_repo, mocked_studio_post): ), ) + live.step += 1 live.log_metric("foo", 2) - live.next_step() + live.make_summary() + post_to_studio(live, "data") mocked_post.assert_called_with( "https://0.0.0.0/api/live", @@ -72,7 +76,8 @@ def test_post_to_studio(tmp_dir, mocked_dvc_repo, mocked_studio_post): ) mocked_post.reset_mock() - live.end() + live.save_dvc_exp() + post_to_studio(live, "done") mocked_post.assert_called_with( "https://0.0.0.0/api/live", @@ -118,11 +123,15 @@ def test_post_to_studio_failed_data_request( error_response.status_code = 400 mocker.patch("requests.post", return_value=error_response) live.log_metric("foo", 1) - live.next_step() + live.step = 0 + live.make_summary() + post_to_studio(live, "data") mocked_post = mocker.patch("requests.post", return_value=valid_response) + live.step += 1 live.log_metric("foo", 2) - live.next_step() + live.make_summary() + post_to_studio(live, "data") mocked_post.assert_called_with( "https://0.0.0.0/api/live", **get_studio_call( @@ -241,7 +250,8 @@ def test_post_to_studio_shorten_names(tmp_dir, mocked_dvc_repo, mocked_studio_po live = Live() live.log_metric("eval/loss", 1) - live.next_step() + live.make_summary() + post_to_studio(live, "data") plots_path = Path(live.plots_dir) loss_path = (plots_path / Metric.subfolder / "eval/loss.tsv").as_posix() @@ -287,7 +297,8 @@ def test_post_to_studio_inside_subdir( live = Live() live.log_metric("foo", 1) - live.next_step() + live.make_summary() + post_to_studio(live, "data") foo_path = (Path(live.plots_dir) / Metric.subfolder / "foo.tsv").as_posix() @@ -317,7 +328,8 @@ def test_post_to_studio_inside_subdir_dvc_exp( live = Live() live.log_metric("foo", 1) - live.next_step() + live.make_summary() + post_to_studio(live, "data") foo_path = (Path(live.plots_dir) / Metric.subfolder / "foo.tsv").as_posix() @@ -439,8 +451,9 @@ def test_post_to_studio_no_repo(tmp_dir, monkeypatch, mocked_studio_post): ) live.log_metric("foo", 1) + live.make_summary() + post_to_studio(live, "data") - live.next_step() mocked_post.assert_called_with( "https://0.0.0.0/api/live", **get_studio_call( @@ -452,9 +465,11 @@ def test_post_to_studio_no_repo(tmp_dir, monkeypatch, mocked_studio_post): ), ) + live.step += 1 live.log_metric("foo", 2) + live.make_summary() + post_to_studio(live, "data") - live.next_step() mocked_post.assert_called_with( "https://0.0.0.0/api/live", **get_studio_call( @@ -466,7 +481,7 @@ def test_post_to_studio_no_repo(tmp_dir, monkeypatch, mocked_studio_post): ), ) - live.end() + post_to_studio(live, "done") mocked_post.assert_called_with( "https://0.0.0.0/api/live", **get_studio_call("done", baseline_sha="0" * 40, exp_name=live._exp_name), @@ -503,7 +518,8 @@ def test_post_to_studio_repeat_step(tmp_dir, mocked_dvc_repo, mocked_studio_post live.step = 0 live.log_metric("foo", 1) live.log_metric("bar", 0.1) - live.sync() + live.make_summary() + post_to_studio(live, "data") mocked_post.assert_called_with( "https://0.0.0.0/api/live", @@ -521,7 +537,8 @@ def test_post_to_studio_repeat_step(tmp_dir, mocked_dvc_repo, mocked_studio_post live.log_metric("foo", 2) live.log_metric("foo", 3) live.log_metric("bar", 0.2) - live.sync() + live.make_summary() + post_to_studio(live, "data") mocked_post.assert_called_with( "https://0.0.0.0/api/live", From 23f1fca2f0490fae47f1c6fe6dec10b948a7ee90 Mon Sep 17 00:00:00 2001 From: dberenbaum Date: Fri, 19 Apr 2024 11:03:26 -0400 Subject: [PATCH 2/4] queue for studio data posts --- src/dvclive/live.py | 25 ++++++++++++++++++------- tests/test_post_to_studio.py | 21 ++++++++++++++++----- 2 files changed, 34 insertions(+), 12 deletions(-) diff --git a/src/dvclive/live.py b/src/dvclive/live.py index 7d0c3415..6a8978ff 100644 --- a/src/dvclive/live.py +++ b/src/dvclive/live.py @@ -6,6 +6,7 @@ import math import os import shutil +import queue import tempfile import threading @@ -172,6 +173,7 @@ def __init__( self._studio_events_to_skip: Set[str] = set() self._dvc_studio_config: Dict[str, Any] = {} self._num_points_sent_to_studio: Dict[str, int] = {} + self._studio_queue = None self._init_studio() self._system_monitor: Optional[_SystemMonitor] = None # Monitoring thread @@ -297,7 +299,7 @@ def _init_studio(self): self._studio_events_to_skip.add("start") self._studio_events_to_skip.add("done") else: - self.post_to_studio("start") + post_to_studio(self, "start") def _init_report(self): if self._report_mode not in {None, "html", "notebook", "md"}: @@ -429,7 +431,7 @@ def sync(self): self.make_report() - self.post_to_studio("data") + self.post_data_to_studio() def next_step(self): """ @@ -881,10 +883,19 @@ def make_dvcyaml(self): """ make_dvcyaml(self) - @catch_and_warn(DvcException, logger) - def post_to_studio(self, event: Literal["start", "data", "done"]): - thread = threading.Thread(target=post_to_studio, args=(self, event)) - thread.start() + def post_data_to_studio(self): + if not self._studio_queue: + self._studio_queue = queue.Queue() + + def worker(): + while True: + item = self._studio_queue.get() + post_to_studio(item, "data") + self._studio_queue.task_done() + + threading.Thread(target=worker, daemon=True).start() + + self._studio_queue.put(self) def end(self): """ @@ -928,7 +939,7 @@ def end(self): self.save_dvc_exp() # Mark experiment as done - self.post_to_studio("done") + post_to_studio(self, "done") cleanup_dvclive_step_completed() diff --git a/tests/test_post_to_studio.py b/tests/test_post_to_studio.py index 815cc2b7..ecb80067 100644 --- a/tests/test_post_to_studio.py +++ b/tests/test_post_to_studio.py @@ -163,6 +163,7 @@ def test_post_to_studio_failed_start_request( live.next_step() assert mocked_post.call_count == 1 + assert live._studio_events_to_skip == {"start", "data", "done"} def test_post_to_studio_done_only_once(tmp_dir, mocked_dvc_repo, mocked_studio_post): @@ -219,7 +220,9 @@ def test_post_to_studio_dvc_studio_config( with Live() as live: live.log_metric("foo", 1) - live.next_step() + live.step = 0 + live.make_summary() + post_to_studio(live, "data") assert mocked_post.call_args.kwargs["headers"]["Authorization"] == "token token" @@ -240,7 +243,9 @@ def test_post_to_studio_skip_if_no_token( with Live() as live: live.log_metric("foo", 1) - live.next_step() + live.step = 0 + live.make_summary() + post_to_studio(live, "data") assert mocked_post.call_count == 0 @@ -279,7 +284,9 @@ def test_post_to_studio_inside_dvc_exp( with Live() as live: live.log_metric("foo", 1) - live.next_step() + live.step = 0 + live.make_summary() + post_to_studio(live, "data") call_types = [call.kwargs["json"]["type"] for call in mocked_post.call_args_list] assert "start" not in call_types @@ -382,7 +389,9 @@ def test_post_to_studio_images(tmp_dir, mocked_dvc_repo, mocked_studio_post): live = Live() live.log_image("foo.png", ImagePIL.new("RGB", (10, 10), (0, 0, 0))) - live.next_step() + live.step = 0 + live.make_summary() + post_to_studio(live, "data") foo_path = (Path(live.plots_dir) / Image.subfolder / "foo.png").as_posix() @@ -500,7 +509,9 @@ def test_post_to_studio_skip_if_no_repo_url( with Live() as live: live.log_metric("foo", 1) - live.next_step() + live.step = 0 + live.make_summary() + post_to_studio(live, "data") assert mocked_post.call_count == 0 From 79d07b11139e38dbc8a600f5ac5ce335afa6862d Mon Sep 17 00:00:00 2001 From: dberenbaum Date: Fri, 19 Apr 2024 11:09:16 -0400 Subject: [PATCH 3/4] fix test_post_to_studio_if_done_skipped --- tests/test_post_to_studio.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/tests/test_post_to_studio.py b/tests/test_post_to_studio.py index ecb80067..6aee0741 100644 --- a/tests/test_post_to_studio.py +++ b/tests/test_post_to_studio.py @@ -430,11 +430,13 @@ def test_post_to_studio_name(tmp_dir, mocked_dvc_repo, mocked_studio_post): def test_post_to_studio_if_done_skipped(tmp_dir, mocked_dvc_repo, mocked_studio_post): - live = Live() - live._studio_events_to_skip.add("start") - live._studio_events_to_skip.add("done") - live.log_metric("foo", 1) - live.end() + with Live() as live: + live._studio_events_to_skip.add("start") + live._studio_events_to_skip.add("done") + live.log_metric("foo", 1) + live.step = 0 + live.make_summary() + post_to_studio(live, "data") mocked_post, _ = mocked_studio_post call_types = [call.kwargs["json"]["type"] for call in mocked_post.call_args_list] From 254e2ff0b496d2ab0506dd95ae5625b17782edfc Mon Sep 17 00:00:00 2001 From: dberenbaum Date: Fri, 19 Apr 2024 13:52:26 -0400 Subject: [PATCH 4/4] catch and warn in src/dvclive/studio.py:post_to_studio --- src/dvclive/studio.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/dvclive/studio.py b/src/dvclive/studio.py index 31a61e5c..6ed398e7 100644 --- a/src/dvclive/studio.py +++ b/src/dvclive/studio.py @@ -7,9 +7,12 @@ from pathlib import PureWindowsPath from typing import TYPE_CHECKING, Literal, Mapping +from dvc.exceptions import DvcException from dvc_studio_client.config import get_studio_config from dvc_studio_client.post_live_metrics import post_live_metrics +from .utils import catch_and_warn + if TYPE_CHECKING: from dvclive.live import Live from dvclive.serialize import load_yaml @@ -96,6 +99,7 @@ def increment_num_points_sent_to_studio(live, plots): return live +@catch_and_warn(DvcException, logger) def post_to_studio(live: Live, event: Literal["start", "data", "done"]): # noqa: C901 if event in live._studio_events_to_skip: return