From c4d803f6f0e934281a9fe0e688c0a143380a92ef Mon Sep 17 00:00:00 2001 From: dberenbaum Date: Mon, 5 Dec 2022 15:54:24 -0500 Subject: [PATCH 1/3] lightning: log params --- src/dvclive/lightning.py | 2 +- tests/test_lightning.py | 13 ++++++++++--- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/src/dvclive/lightning.py b/src/dvclive/lightning.py index 96fe5e1e..38c42f1d 100644 --- a/src/dvclive/lightning.py +++ b/src/dvclive/lightning.py @@ -39,7 +39,7 @@ def name(self): @rank_zero_only def log_hyperparams(self, params, *args, **kwargs): - pass + self.experiment.log_params(params) @property # type: ignore @rank_zero_experiment diff --git a/tests/test_lightning.py b/tests/test_lightning.py index 03c3d3a2..fca25ebe 100644 --- a/tests/test_lightning.py +++ b/tests/test_lightning.py @@ -10,6 +10,7 @@ from dvclive.lightning import DVCLiveLogger from dvclive.plots.metric import Metric +from dvclive.serialize import load_yaml from dvclive.utils import parse_metrics # pylint: disable=redefined-outer-name, unused-argument @@ -30,11 +31,13 @@ def __len__(self): class LitXOR(LightningModule): - def __init__(self): + def __init__(self, latent_dims=4): super().__init__() - self.layer_1 = nn.Linear(2, 4) - self.layer_2 = nn.Linear(4, 2) + self.save_hyperparameters() + + self.layer_1 = nn.Linear(2, latent_dims) + self.layer_2 = nn.Linear(latent_dims, 2) def forward(self, *args, **kwargs): x = args[0] @@ -113,6 +116,10 @@ def test_lightning_integration(tmp_dir, mocker): assert os.path.join(scalars, "train", "step", "loss.tsv") in logs assert os.path.join(scalars, "epoch.tsv") in logs + params_file = dvclive_logger.experiment.params_file + assert os.path.exists(params_file) + assert load_yaml(params_file) == {"latent_dims": 4} + def test_lightning_default_dir(tmp_dir): model = LitXOR() From e58f50ecd0c9ddc67e417fdbbde9ce04f6c98a1a Mon Sep 17 00:00:00 2001 From: dberenbaum Date: Tue, 6 Dec 2022 13:58:49 -0500 Subject: [PATCH 2/3] lightning: clean params for logging --- src/dvclive/lightning.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/dvclive/lightning.py b/src/dvclive/lightning.py index 38c42f1d..af659a50 100644 --- a/src/dvclive/lightning.py +++ b/src/dvclive/lightning.py @@ -2,6 +2,11 @@ from pytorch_lightning.loggers.logger import Logger, rank_zero_experiment from pytorch_lightning.utilities import rank_zero_only +from pytorch_lightning.utilities.logger import ( + _convert_params, + _sanitize_params, + _sanitize_callable_params, +) from torch import is_tensor from dvclive import Live @@ -39,6 +44,9 @@ def name(self): @rank_zero_only def log_hyperparams(self, params, *args, **kwargs): + params = _convert_params(params) + params = _sanitize_callable_params(params) + params = _sanitize_params(params) self.experiment.log_params(params) @property # type: ignore From 068128b138fb72c42fb6df3b5c6c006074038680 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 6 Dec 2022 18:59:16 +0000 Subject: [PATCH 3/3] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/dvclive/lightning.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/dvclive/lightning.py b/src/dvclive/lightning.py index af659a50..4160d8f1 100644 --- a/src/dvclive/lightning.py +++ b/src/dvclive/lightning.py @@ -3,9 +3,9 @@ from pytorch_lightning.loggers.logger import Logger, rank_zero_experiment from pytorch_lightning.utilities import rank_zero_only from pytorch_lightning.utilities.logger import ( - _convert_params, - _sanitize_params, - _sanitize_callable_params, + _convert_params, + _sanitize_callable_params, + _sanitize_params, ) from torch import is_tensor