Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

post to studio in thread to avoid blocking #814

Merged
merged 4 commits into from
Apr 19, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 19 additions & 6 deletions src/dvclive/live.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@
import math
import os
import shutil
import queue
import tempfile
import threading

from pathlib import Path, PurePath
from typing import Any, Dict, List, Optional, Set, Tuple, Union, TYPE_CHECKING, Literal
Expand Down Expand Up @@ -171,6 +173,7 @@ def __init__(
self._studio_events_to_skip: Set[str] = set()
self._dvc_studio_config: Dict[str, Any] = {}
self._num_points_sent_to_studio: Dict[str, int] = {}
self._studio_queue = None
self._init_studio()

self._system_monitor: Optional[_SystemMonitor] = None # Monitoring thread
Expand Down Expand Up @@ -296,7 +299,7 @@ def _init_studio(self):
self._studio_events_to_skip.add("start")
self._studio_events_to_skip.add("done")
else:
self.post_to_studio("start")
post_to_studio(self, "start")

def _init_report(self):
if self._report_mode not in {None, "html", "notebook", "md"}:
Expand Down Expand Up @@ -428,7 +431,7 @@ def sync(self):

self.make_report()

self.post_to_studio("data")
self.post_data_to_studio()

def next_step(self):
"""
Expand Down Expand Up @@ -880,9 +883,19 @@ def make_dvcyaml(self):
"""
make_dvcyaml(self)

@catch_and_warn(DvcException, logger)
def post_to_studio(self, event: Literal["start", "data", "done"]):
post_to_studio(self, event)
def post_data_to_studio(self):
if not self._studio_queue:
self._studio_queue = queue.Queue()

def worker():
while True:
item = self._studio_queue.get()
post_to_studio(item, "data")
self._studio_queue.task_done()

threading.Thread(target=worker, daemon=True).start()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

where is post_to_studio defined now?

One thing to test - Ctrl-C handling for this. Just to make sure that our threads and queue wait doesn't mess up apps in any way. Thanks @dberenbaum for iterating on this!

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should we wrap the call to post_to_studio in @catch_and_warn(DvcException, logger) ?

Copy link
Contributor Author

@dberenbaum dberenbaum Apr 19, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

where is post_to_studio defined now?

def post_to_studio(live: Live, event: Literal["start", "data", "done"]): # noqa: C901

One thing to test - Ctrl-C handling for this. Just to make sure that our threads and queue wait doesn't mess up apps in any way. Thanks @dberenbaum for iterating on this!

No, thanks for keeping me honest here. It's a pretty simple change but could easily see it breaking some things. This is one reason I wanted to keep start and done separate per your question above. The queue is set up in daemon mode which can exit without making all expected calls (without it I noticed tests were hanging). By separating start and end, we ensure those are called and returned in most scenarios. If a particular live data call gets missed, I don't think it's as critical. (edit: it also simplifies testing to have those calls in the main thread)

should we wrap the call to post_to_studio in @catch_and_warn(DvcException, logger) ?

Sure, pushed an update with that change


self._studio_queue.put(self)

def end(self):
"""
Expand Down Expand Up @@ -926,7 +939,7 @@ def end(self):
self.save_dvc_exp()

# Mark experiment as done
self.post_to_studio("done")
post_to_studio(self, "done")

cleanup_dvclive_step_completed()

Expand Down
78 changes: 54 additions & 24 deletions tests/test_post_to_studio.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from dvclive import Live
from dvclive.env import DVC_EXP_BASELINE_REV, DVC_EXP_NAME, DVC_ROOT
from dvclive.plots import Image, Metric
from dvclive.studio import _adapt_image, get_dvc_studio_config
from dvclive.studio import _adapt_image, get_dvc_studio_config, post_to_studio


def get_studio_call(event_type, exp_name, **kwargs):
Expand Down Expand Up @@ -46,7 +46,9 @@ def test_post_to_studio(tmp_dir, mocked_dvc_repo, mocked_studio_post):
)

live.log_metric("foo", 1)
live.next_step()
live.step = 0
live.make_summary()
post_to_studio(live, "data")

mocked_post.assert_called_with(
"https://0.0.0.0/api/live",
Expand All @@ -58,8 +60,10 @@ def test_post_to_studio(tmp_dir, mocked_dvc_repo, mocked_studio_post):
),
)

live.step += 1
live.log_metric("foo", 2)
live.next_step()
live.make_summary()
post_to_studio(live, "data")

mocked_post.assert_called_with(
"https://0.0.0.0/api/live",
Expand All @@ -72,7 +76,8 @@ def test_post_to_studio(tmp_dir, mocked_dvc_repo, mocked_studio_post):
)

mocked_post.reset_mock()
live.end()
live.save_dvc_exp()
post_to_studio(live, "done")

mocked_post.assert_called_with(
"https://0.0.0.0/api/live",
Expand Down Expand Up @@ -118,11 +123,15 @@ def test_post_to_studio_failed_data_request(
error_response.status_code = 400
mocker.patch("requests.post", return_value=error_response)
live.log_metric("foo", 1)
live.next_step()
live.step = 0
live.make_summary()
post_to_studio(live, "data")

mocked_post = mocker.patch("requests.post", return_value=valid_response)
live.step += 1
live.log_metric("foo", 2)
live.next_step()
live.make_summary()
post_to_studio(live, "data")
mocked_post.assert_called_with(
"https://0.0.0.0/api/live",
**get_studio_call(
Expand Down Expand Up @@ -154,6 +163,7 @@ def test_post_to_studio_failed_start_request(
live.next_step()

assert mocked_post.call_count == 1
assert live._studio_events_to_skip == {"start", "data", "done"}


def test_post_to_studio_done_only_once(tmp_dir, mocked_dvc_repo, mocked_studio_post):
Expand Down Expand Up @@ -210,7 +220,9 @@ def test_post_to_studio_dvc_studio_config(

with Live() as live:
live.log_metric("foo", 1)
live.next_step()
live.step = 0
live.make_summary()
post_to_studio(live, "data")

assert mocked_post.call_args.kwargs["headers"]["Authorization"] == "token token"

Expand All @@ -231,7 +243,9 @@ def test_post_to_studio_skip_if_no_token(

with Live() as live:
live.log_metric("foo", 1)
live.next_step()
live.step = 0
live.make_summary()
post_to_studio(live, "data")

assert mocked_post.call_count == 0

Expand All @@ -241,7 +255,8 @@ def test_post_to_studio_shorten_names(tmp_dir, mocked_dvc_repo, mocked_studio_po

live = Live()
live.log_metric("eval/loss", 1)
live.next_step()
live.make_summary()
post_to_studio(live, "data")

plots_path = Path(live.plots_dir)
loss_path = (plots_path / Metric.subfolder / "eval/loss.tsv").as_posix()
Expand Down Expand Up @@ -269,7 +284,9 @@ def test_post_to_studio_inside_dvc_exp(

with Live() as live:
live.log_metric("foo", 1)
live.next_step()
live.step = 0
live.make_summary()
post_to_studio(live, "data")

call_types = [call.kwargs["json"]["type"] for call in mocked_post.call_args_list]
assert "start" not in call_types
Expand All @@ -287,7 +304,8 @@ def test_post_to_studio_inside_subdir(

live = Live()
live.log_metric("foo", 1)
live.next_step()
live.make_summary()
post_to_studio(live, "data")

foo_path = (Path(live.plots_dir) / Metric.subfolder / "foo.tsv").as_posix()

Expand Down Expand Up @@ -317,7 +335,8 @@ def test_post_to_studio_inside_subdir_dvc_exp(

live = Live()
live.log_metric("foo", 1)
live.next_step()
live.make_summary()
post_to_studio(live, "data")

foo_path = (Path(live.plots_dir) / Metric.subfolder / "foo.tsv").as_posix()

Expand Down Expand Up @@ -370,7 +389,9 @@ def test_post_to_studio_images(tmp_dir, mocked_dvc_repo, mocked_studio_post):

live = Live()
live.log_image("foo.png", ImagePIL.new("RGB", (10, 10), (0, 0, 0)))
live.next_step()
live.step = 0
live.make_summary()
post_to_studio(live, "data")

foo_path = (Path(live.plots_dir) / Image.subfolder / "foo.png").as_posix()

Expand Down Expand Up @@ -409,11 +430,13 @@ def test_post_to_studio_name(tmp_dir, mocked_dvc_repo, mocked_studio_post):


def test_post_to_studio_if_done_skipped(tmp_dir, mocked_dvc_repo, mocked_studio_post):
live = Live()
live._studio_events_to_skip.add("start")
live._studio_events_to_skip.add("done")
live.log_metric("foo", 1)
live.end()
with Live() as live:
live._studio_events_to_skip.add("start")
live._studio_events_to_skip.add("done")
live.log_metric("foo", 1)
live.step = 0
live.make_summary()
post_to_studio(live, "data")

mocked_post, _ = mocked_studio_post
call_types = [call.kwargs["json"]["type"] for call in mocked_post.call_args_list]
Expand All @@ -439,8 +462,9 @@ def test_post_to_studio_no_repo(tmp_dir, monkeypatch, mocked_studio_post):
)

live.log_metric("foo", 1)
live.make_summary()
post_to_studio(live, "data")

live.next_step()
mocked_post.assert_called_with(
"https://0.0.0.0/api/live",
**get_studio_call(
Expand All @@ -452,9 +476,11 @@ def test_post_to_studio_no_repo(tmp_dir, monkeypatch, mocked_studio_post):
),
)

live.step += 1
live.log_metric("foo", 2)
live.make_summary()
post_to_studio(live, "data")

live.next_step()
mocked_post.assert_called_with(
"https://0.0.0.0/api/live",
**get_studio_call(
Expand All @@ -466,7 +492,7 @@ def test_post_to_studio_no_repo(tmp_dir, monkeypatch, mocked_studio_post):
),
)

live.end()
post_to_studio(live, "done")
mocked_post.assert_called_with(
"https://0.0.0.0/api/live",
**get_studio_call("done", baseline_sha="0" * 40, exp_name=live._exp_name),
Expand All @@ -485,7 +511,9 @@ def test_post_to_studio_skip_if_no_repo_url(

with Live() as live:
live.log_metric("foo", 1)
live.next_step()
live.step = 0
live.make_summary()
post_to_studio(live, "data")

assert mocked_post.call_count == 0

Expand All @@ -503,7 +531,8 @@ def test_post_to_studio_repeat_step(tmp_dir, mocked_dvc_repo, mocked_studio_post
live.step = 0
live.log_metric("foo", 1)
live.log_metric("bar", 0.1)
live.sync()
live.make_summary()
post_to_studio(live, "data")

mocked_post.assert_called_with(
"https://0.0.0.0/api/live",
Expand All @@ -521,7 +550,8 @@ def test_post_to_studio_repeat_step(tmp_dir, mocked_dvc_repo, mocked_studio_post
live.log_metric("foo", 2)
live.log_metric("foo", 3)
live.log_metric("bar", 0.2)
live.sync()
live.make_summary()
post_to_studio(live, "data")

mocked_post.assert_called_with(
"https://0.0.0.0/api/live",
Expand Down
Loading