diff --git a/src/dvclive/report.py b/src/dvclive/report.py index c332b9b7..d2900e6e 100644 --- a/src/dvclive/report.py +++ b/src/dvclive/report.py @@ -17,6 +17,9 @@ from dvclive import Live +# noqa pylint: disable=protected-access + + def get_scalar_renderers(metrics_path): renderers = [] for suffix in Metric.suffixes: @@ -55,17 +58,33 @@ def get_image_renderers(images_folder): return renderers -def get_plot_renderers(plots_folder): +def get_plot_renderers(plots_folder, live): renderers = [] for suffix in SKLearnPlot.suffixes: for file in Path(plots_folder).rglob(f"*{suffix}"): - name = file.stem + name = file.relative_to(plots_folder).with_suffix("").as_posix() + properties = {} + + if name in SKLEARN_PLOTS: + properties = SKLEARN_PLOTS[name].get_properties() + data_field = name + else: + # Plot with custom name + logged_plot = live._plots[name] + for default_name, plot_class in SKLEARN_PLOTS.items(): + if isinstance(logged_plot, plot_class): + properties = plot_class.get_properties() + data_field = default_name + break + data = json.loads(file.read_text()) - if name in data: - data = data[name] + + if data_field in data: + data = data[data_field] + for row in data: row["rev"] = "workspace" - properties = SKLEARN_PLOTS[name].get_properties() + renderers.append(VegaRenderer(data, name, **properties)) return renderers @@ -94,19 +113,21 @@ def get_params_renderers(dvclive_params): return [] -def make_report(dvclive: "Live"): - plots_path = Path(dvclive.plots_dir) +def make_report(live: "Live"): + plots_path = Path(live.plots_dir) renderers = [] - renderers.extend(get_params_renderers(dvclive.params_file)) - renderers.extend(get_metrics_renderers(dvclive.metrics_file)) + renderers.extend(get_params_renderers(live.params_file)) + renderers.extend(get_metrics_renderers(live.metrics_file)) renderers.extend(get_scalar_renderers(plots_path / Metric.subfolder)) renderers.extend(get_image_renderers(plots_path / Image.subfolder)) - renderers.extend(get_plot_renderers(plots_path / SKLearnPlot.subfolder)) - - if dvclive.report_mode == "html": - render_html(renderers, dvclive.report_file, refresh_seconds=5) - elif dvclive.report_mode == "md": - render_markdown(renderers, dvclive.report_file) + renderers.extend( + get_plot_renderers(plots_path / SKLearnPlot.subfolder, live) + ) + + if live.report_mode == "html": + render_html(renderers, live.report_file, refresh_seconds=5) + elif live.report_mode == "md": + render_markdown(renderers, live.report_file) else: - raise ValueError(f"Invalid `mode` {dvclive.report_mode}.") + raise ValueError(f"Invalid `mode` {live.report_mode}.") diff --git a/tests/test_report.py b/tests/test_report.py index 0fe69000..eccd7942 100644 --- a/tests/test_report.py +++ b/tests/test_report.py @@ -8,7 +8,7 @@ from dvclive.env import DVCLIVE_OPEN from dvclive.plots import Image as LiveImage from dvclive.plots import Metric -from dvclive.plots.sklearn import ConfusionMatrix, SKLearnPlot +from dvclive.plots.sklearn import ConfusionMatrix, Roc, SKLearnPlot from dvclive.report import ( get_image_renderers, get_metrics_renderers, @@ -31,6 +31,12 @@ def test_get_renderers(tmp_dir, mocker): live.next_step() live.log_sklearn_plot("confusion_matrix", [0, 0, 1, 1], [1, 0, 0, 1]) + live.log_sklearn_plot( + "confusion_matrix", [0, 0, 1, 1], [1, 0, 0, 1], name="train/cm" + ) + live.log_sklearn_plot( + "roc", [0, 0, 1, 1], [1, 0.1, 0, 1], name="roc_curve" + ) image_renderers = get_image_renderers( tmp_dir / live.plots_dir / LiveImage.subfolder @@ -63,16 +69,27 @@ def test_get_renderers(tmp_dir, mocker): assert scalar_renderers[0].name == "static/foo/bar" plot_renderers = get_plot_renderers( - tmp_dir / live.plots_dir / SKLearnPlot.subfolder + tmp_dir / live.plots_dir / SKLearnPlot.subfolder, live ) - assert len(plot_renderers) == 1 - assert plot_renderers[0].datapoints == [ - {"actual": "0", "rev": "workspace", "predicted": "1"}, - {"actual": "0", "rev": "workspace", "predicted": "0"}, - {"actual": "1", "rev": "workspace", "predicted": "0"}, - {"actual": "1", "rev": "workspace", "predicted": "1"}, - ] - assert plot_renderers[0].properties == ConfusionMatrix.get_properties() + assert len(plot_renderers) == 3 + for plot_renderer in plot_renderers: + if plot_renderer.name == "roc_curve": + assert plot_renderer.datapoints == [ + {"fpr": 0.0, "rev": "workspace", "threshold": 2.0, "tpr": 0.0}, + {"fpr": 0.5, "rev": "workspace", "threshold": 1.0, "tpr": 0.5}, + {"fpr": 1.0, "rev": "workspace", "threshold": 0.1, "tpr": 0.5}, + {"fpr": 1.0, "rev": "workspace", "threshold": 0.0, "tpr": 1.0}, + ] + assert plot_renderer.properties == Roc.get_properties() + + else: + assert plot_renderer.datapoints == [ + {"actual": "0", "rev": "workspace", "predicted": "1"}, + {"actual": "0", "rev": "workspace", "predicted": "0"}, + {"actual": "1", "rev": "workspace", "predicted": "0"}, + {"actual": "1", "rev": "workspace", "predicted": "1"}, + ] + assert plot_renderer.properties == ConfusionMatrix.get_properties() metrics_renderer = get_metrics_renderers(live.metrics_file)[0] assert metrics_renderer.datapoints == [{"step": 1, "foo": {"bar": 1}}]