diff --git a/src/dvclive/catalyst.py b/src/dvclive/catalyst.py index 4d01080e..2e6c1ef7 100644 --- a/src/dvclive/catalyst.py +++ b/src/dvclive/catalyst.py @@ -27,3 +27,6 @@ def on_epoch_end(self, runner) -> None: ) utils.save_checkpoint(checkpoint, self.model_file) self.live.next_step() + + def on_experiment_end(self, runner): # pylint: disable=unused-argument + self.live.end() diff --git a/src/dvclive/fastai.py b/src/dvclive/fastai.py index b6f7cd5d..0d934dfc 100644 --- a/src/dvclive/fastai.py +++ b/src/dvclive/fastai.py @@ -23,3 +23,6 @@ def after_epoch(self): if self.model_file: self.learn.save(self.model_file) self.live.next_step() + + def after_fit(self): + self.live.end() diff --git a/src/dvclive/huggingface.py b/src/dvclive/huggingface.py index 4b35d3fd..fc315f1e 100644 --- a/src/dvclive/huggingface.py +++ b/src/dvclive/huggingface.py @@ -42,3 +42,12 @@ def on_epoch_end( tokenizer = kwargs.get("tokenizer") if tokenizer: tokenizer.save_pretrained(self.model_file) + + def on_train_end( + self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl, + **kwargs + ): + self.live.end() diff --git a/src/dvclive/keras.py b/src/dvclive/keras.py index fb0357ad..104056f4 100644 --- a/src/dvclive/keras.py +++ b/src/dvclive/keras.py @@ -54,3 +54,8 @@ def on_epoch_end( else: self.model.save(self.model_file) self.live.next_step() + + def on_train_end( + self, logs: Optional[Dict] = None + ): # pylint: disable=unused-argument + self.live.end() diff --git a/src/dvclive/lightning.py b/src/dvclive/lightning.py index 8114b317..6ca46f97 100644 --- a/src/dvclive/lightning.py +++ b/src/dvclive/lightning.py @@ -69,3 +69,7 @@ def log_metrics(self, metrics: Dict[str, Any], step: Optional[int] = None): metric_name = standardize_metric_name(metric_name, __name__) self.experiment.log_metric(name=metric_name, val=metric_val) self.experiment.next_step() + + @rank_zero_only + def finalize(self, status: str) -> None: + self.experiment.end() diff --git a/src/dvclive/live.py b/src/dvclive/live.py index 565eb26d..c2f2a2f0 100644 --- a/src/dvclive/live.py +++ b/src/dvclive/live.py @@ -43,7 +43,7 @@ def __init__( ): self._dir: str = dir self._resume: bool = resume or env2bool(env.DVCLIVE_RESUME) - + self._ended: bool = False self.studio_url = os.getenv(env.STUDIO_REPO_URL, None) self.studio_token = os.getenv(env.STUDIO_TOKEN, None) self.rev = None @@ -243,8 +243,10 @@ def make_report(self): def end(self): self.make_summary() if self.report_mode == "studio": - if not post_to_studio(self, "done", logger): - logger.warning("`post_to_studio` `done` event failed.") + if not self._ended: + if not post_to_studio(self, "done", logger): + logger.warning("`post_to_studio` `done` event failed.") + self._ended = True else: self.make_report() diff --git a/src/dvclive/xgb.py b/src/dvclive/xgb.py index ee421b5c..d2ea5523 100644 --- a/src/dvclive/xgb.py +++ b/src/dvclive/xgb.py @@ -26,3 +26,7 @@ def after_iteration(self, model, epoch, evals_log): if self.model_file: model.save_model(self.model_file) self.live.next_step() + + def after_training(self, model): + self.live.end() + return model diff --git a/tests/test_catalyst.py b/tests/test_catalyst.py index 7fd11e26..c2ea44f7 100644 --- a/tests/test_catalyst.py +++ b/tests/test_catalyst.py @@ -49,13 +49,17 @@ def runner_params(): } -def test_catalyst_callback(tmp_dir, runner, runner_params): +def test_catalyst_callback(tmp_dir, runner, runner_params, mocker): + callback = DVCLiveCallback() + live = callback.live + spy = mocker.spy(live, "end") + runner.train( **runner_params, num_epochs=2, callbacks=[ dl.AccuracyCallback(input_key="logits", target_key="targets"), - DVCLiveCallback(), + callback, ], logdir="./logs", valid_loader="valid", @@ -64,6 +68,7 @@ def test_catalyst_callback(tmp_dir, runner, runner_params): verbose=True, load_best_on_end=True, ) + spy.assert_called_once() assert os.path.exists("dvclive") diff --git a/tests/test_fastai.py b/tests/test_fastai.py index 18cb6df9..e30f0db8 100644 --- a/tests/test_fastai.py +++ b/tests/test_fastai.py @@ -38,13 +38,16 @@ def data_loader(): return xor_loader -def test_fastai_callback(tmp_dir, data_loader): +def test_fastai_callback(tmp_dir, data_loader, mocker): learn = tabular_learner(data_loader, metrics=accuracy) learn.remove_cb(ProgressCallback) learn.model_dir = os.path.abspath("./") callback = DVCLiveCallback("model") live = callback.live + + spy = mocker.spy(live, "end") learn.fit_one_cycle(2, cbs=[callback]) + spy.assert_called_once() assert os.path.exists(live.dir) diff --git a/tests/test_huggingface.py b/tests/test_huggingface.py index 843c795c..49efcce9 100644 --- a/tests/test_huggingface.py +++ b/tests/test_huggingface.py @@ -101,7 +101,7 @@ def args(): ) -def test_huggingface_integration(tmp_dir, model, args, data): +def test_huggingface_integration(tmp_dir, model, args, data, mocker): trainer = Trainer( model, args, @@ -110,8 +110,11 @@ def test_huggingface_integration(tmp_dir, model, args, data): compute_metrics=compute_metrics, ) callback = DVCLiveCallback() + live = callback.live + spy = mocker.spy(live, "end") trainer.add_callback(callback) trainer.train() + spy.assert_called_once() live = callback.live assert os.path.exists(live.dir) diff --git a/tests/test_keras.py b/tests/test_keras.py index 8262f498..89b1a651 100644 --- a/tests/test_keras.py +++ b/tests/test_keras.py @@ -35,10 +35,12 @@ def make(): yield make -def test_keras_callback(tmp_dir, xor_model, capture_wrap): +def test_keras_callback(tmp_dir, xor_model, capture_wrap, mocker): model, x, y = xor_model() callback = DVCLiveCallback() + live = callback.live + spy = mocker.spy(live, "end") model.fit( x, y, @@ -47,6 +49,7 @@ def test_keras_callback(tmp_dir, xor_model, capture_wrap): validation_split=0.2, callbacks=[callback], ) + spy.assert_called_once() assert os.path.exists("dvclive") logs, _ = parse_metrics(callback.live) diff --git a/tests/test_lightning.py b/tests/test_lightning.py index 5126f61d..6e18f4ca 100644 --- a/tests/test_lightning.py +++ b/tests/test_lightning.py @@ -84,11 +84,13 @@ def val_dataloader(self): pass -def test_lightning_integration(tmp_dir): +def test_lightning_integration(tmp_dir, mocker): # init model model = LitXOR() # init logger dvclive_logger = DVCLiveLogger("test_run", dir="logs") + live = dvclive_logger.experiment + spy = mocker.spy(live, "end") trainer = Trainer( logger=dvclive_logger, max_epochs=2, @@ -96,6 +98,7 @@ def test_lightning_integration(tmp_dir): log_every_n_steps=1, ) trainer.fit(model) + spy.assert_called_once() assert os.path.exists("logs") assert not os.path.exists("DvcLiveLogger") diff --git a/tests/test_studio.py b/tests/test_studio.py index ea9f820f..74e6d0dc 100644 --- a/tests/test_studio.py +++ b/tests/test_studio.py @@ -167,3 +167,23 @@ def test_post_to_studio_failed_start_request(tmp_dir, mocker, monkeypatch): live.next_step() assert mocked_post.call_count == 1 + + +@pytest.mark.studio +def test_post_to_studio_end_only_once(tmp_dir, mocker, monkeypatch): + mocker.patch("scmrepo.git.Git") + + valid_response = mocker.MagicMock() + valid_response.status_code = 200 + mocked_post = mocker.patch("requests.post", return_value=valid_response) + monkeypatch.setenv(env.STUDIO_ENDPOINT, "https://0.0.0.0") + monkeypatch.setenv(env.STUDIO_REPO_URL, "STUDIO_REPO_URL") + monkeypatch.setenv(env.STUDIO_TOKEN, "STUDIO_TOKEN") + + with Live() as live: + live.log_metric("foo", 1) + live.next_step() + + assert mocked_post.call_count == 3 + live.end() + assert mocked_post.call_count == 3 diff --git a/tests/test_xgboost.py b/tests/test_xgboost.py index 902dc83e..4264f1dd 100644 --- a/tests/test_xgboost.py +++ b/tests/test_xgboost.py @@ -26,8 +26,10 @@ def iris_data(): return xgb.DMatrix(x, y) -def test_xgb_integration(tmp_dir, train_params, iris_data): +def test_xgb_integration(tmp_dir, train_params, iris_data, mocker): callback = DVCLiveCallback("eval_data") + live = callback.live + spy = mocker.spy(live, "end") xgb.train( train_params, iris_data, @@ -35,6 +37,7 @@ def test_xgb_integration(tmp_dir, train_params, iris_data): num_boost_round=5, evals=[(iris_data, "eval_data")], ) + spy.assert_called_once() assert os.path.exists("dvclive")