From 3414163f17cec0785be73a46839ad0075d193dd7 Mon Sep 17 00:00:00 2001 From: Piotr Picheta Date: Wed, 24 Apr 2024 14:05:24 +0200 Subject: [PATCH] Fixes #812 (#813) * Fixes #812 * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * named arguments in log_sklearn_plot * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * replaced sklearn_kwargs with kwargs --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- src/dvclive/live.py | 32 +++++++++++++++++++++----------- tests/plots/test_sklearn.py | 32 +++++++++++++++++++++++++++++++- 2 files changed, 52 insertions(+), 12 deletions(-) diff --git a/src/dvclive/live.py b/src/dvclive/live.py index 6a8978ff..f28637d3 100644 --- a/src/dvclive/live.py +++ b/src/dvclive/live.py @@ -624,6 +624,10 @@ def log_sklearn_plot( labels: Union[List, np.ndarray], predictions: Union[List, Tuple, np.ndarray], name: Optional[str] = None, + title: Optional[str] = None, + x_label: Optional[str] = None, + y_label: Optional[str] = None, + normalized: Optional[bool] = None, **kwargs, ): """ @@ -638,14 +642,17 @@ def log_sklearn_plot( "roc"): a supported plot type. labels (List | np.ndarray): array of ground truth labels. predictions (List | np.ndarray): array of predicted labels (for - `"confusion_matrix"`) or predicted probabilities (for other plots). + `"confusion_matrix"`) or predicted probabilities (for other plots). name (str): optional name of the output file. If not provided, `kind` will - be used as name. + be used as name. + title (str): optional title to be displayed. + x_label (str): optional label for the x axis. + y_label (str): optional label for the y axis. + normalized (bool): optional, `confusion_matrix` with values normalized to + `<0, 1>` range. kwargs: additional arguments to tune the result. Arguments are passed to the scikit-learn function (e.g. `drop_intermediate=True` for the `"roc"` - type). Plus extra arguments supported by the type of a plot are: - - `normalized`: default to `False`. `confusion_matrix` with values - normalized to `<0, 1>` range. + type). Raises: InvalidPlotTypeError: thrown if the provided `kind` does not correspond to any of the supported plots. @@ -654,9 +661,15 @@ def log_sklearn_plot( plot_config = { k: v - for k, v in kwargs.items() - if k in ("title", "x_label", "y_label", "normalized") + for k, v in { + "title": title, + "x_label": x_label, + "y_label": y_label, + "normalized": normalized, + }.items() + if v is not None } + name = name or kind if name in self._plots: plot = self._plots[name] @@ -666,11 +679,8 @@ def log_sklearn_plot( else: raise InvalidPlotTypeError(name) - sklearn_kwargs = { - k: v for k, v in kwargs.items() if k not in plot_config or k != "normalized" - } plot.step = self.step - plot.dump(val, **sklearn_kwargs) + plot.dump(val, **kwargs) logger.debug(f"Logged {name}") def _read_params(self): diff --git a/tests/plots/test_sklearn.py b/tests/plots/test_sklearn.py index 848765e7..65a20e6d 100644 --- a/tests/plots/test_sklearn.py +++ b/tests/plots/test_sklearn.py @@ -162,7 +162,7 @@ def test_custom_title(tmp_dir, y_true_y_pred_y_score): live = Live() out = tmp_dir / live.plots_dir / SKLearnPlot.subfolder - y_true, y_pred, _ = y_true_y_pred_y_score + y_true, y_pred, y_score = y_true_y_pred_y_score live.log_sklearn_plot( "confusion_matrix", @@ -174,8 +174,38 @@ def test_custom_title(tmp_dir, y_true_y_pred_y_score): live.log_sklearn_plot( "confusion_matrix", y_true, y_pred, name="val/cm", title="Val Confusion Matrix" ) + live.log_sklearn_plot( + "precision_recall", + y_true, + y_score, + name="val/prc", + title="Val Precision Recall", + ) assert (out / "train" / "cm.json").exists() assert (out / "val" / "cm.json").exists() + assert (out / "val" / "prc.json").exists() assert live._plots["train/cm"].plot_config["title"] == "Train Confusion Matrix" assert live._plots["val/cm"].plot_config["title"] == "Val Confusion Matrix" + assert live._plots["val/prc"].plot_config["title"] == "Val Precision Recall" + + +def test_custom_labels(tmp_dir, y_true_y_pred_y_score): + """https://github.com/iterative/dvclive/issues/453""" + live = Live() + out = tmp_dir / live.plots_dir / SKLearnPlot.subfolder + + y_true, _, y_score = y_true_y_pred_y_score + + live.log_sklearn_plot( + "precision_recall", + y_true, + y_score, + name="val/prc", + x_label="x_test", + y_label="y_test", + ) + assert (out / "val" / "prc.json").exists() + + assert live._plots["val/prc"].plot_config["x_label"] == "x_test" + assert live._plots["val/prc"].plot_config["y_label"] == "y_test"