diff --git a/src/dvclive/lightning.py b/src/dvclive/lightning.py index 96fe5e1e..4160d8f1 100644 --- a/src/dvclive/lightning.py +++ b/src/dvclive/lightning.py @@ -2,6 +2,11 @@ from pytorch_lightning.loggers.logger import Logger, rank_zero_experiment from pytorch_lightning.utilities import rank_zero_only +from pytorch_lightning.utilities.logger import ( + _convert_params, + _sanitize_callable_params, + _sanitize_params, +) from torch import is_tensor from dvclive import Live @@ -39,7 +44,10 @@ def name(self): @rank_zero_only def log_hyperparams(self, params, *args, **kwargs): - pass + params = _convert_params(params) + params = _sanitize_callable_params(params) + params = _sanitize_params(params) + self.experiment.log_params(params) @property # type: ignore @rank_zero_experiment diff --git a/tests/test_lightning.py b/tests/test_lightning.py index 03c3d3a2..fca25ebe 100644 --- a/tests/test_lightning.py +++ b/tests/test_lightning.py @@ -10,6 +10,7 @@ from dvclive.lightning import DVCLiveLogger from dvclive.plots.metric import Metric +from dvclive.serialize import load_yaml from dvclive.utils import parse_metrics # pylint: disable=redefined-outer-name, unused-argument @@ -30,11 +31,13 @@ def __len__(self): class LitXOR(LightningModule): - def __init__(self): + def __init__(self, latent_dims=4): super().__init__() - self.layer_1 = nn.Linear(2, 4) - self.layer_2 = nn.Linear(4, 2) + self.save_hyperparameters() + + self.layer_1 = nn.Linear(2, latent_dims) + self.layer_2 = nn.Linear(latent_dims, 2) def forward(self, *args, **kwargs): x = args[0] @@ -113,6 +116,10 @@ def test_lightning_integration(tmp_dir, mocker): assert os.path.join(scalars, "train", "step", "loss.tsv") in logs assert os.path.join(scalars, "epoch.tsv") in logs + params_file = dvclive_logger.experiment.params_file + assert os.path.exists(params_file) + assert load_yaml(params_file) == {"latent_dims": 4} + def test_lightning_default_dir(tmp_dir): model = LitXOR()