Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 37 additions & 16 deletions src/dvclive/report.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@
from dvclive import Live


# noqa pylint: disable=protected-access


def get_scalar_renderers(metrics_path):
renderers = []
for suffix in Metric.suffixes:
Expand Down Expand Up @@ -55,17 +58,33 @@ def get_image_renderers(images_folder):
return renderers


def get_plot_renderers(plots_folder):
def get_plot_renderers(plots_folder, live):
renderers = []
for suffix in SKLearnPlot.suffixes:
for file in Path(plots_folder).rglob(f"*{suffix}"):
name = file.stem
name = file.relative_to(plots_folder).with_suffix("").as_posix()
properties = {}

if name in SKLEARN_PLOTS:
properties = SKLEARN_PLOTS[name].get_properties()
data_field = name
else:
# Plot with custom name
logged_plot = live._plots[name]
for default_name, plot_class in SKLEARN_PLOTS.items():
if isinstance(logged_plot, plot_class):
properties = plot_class.get_properties()
data_field = default_name
break

data = json.loads(file.read_text())
if name in data:
data = data[name]

if data_field in data:
data = data[data_field]

for row in data:
row["rev"] = "workspace"
properties = SKLEARN_PLOTS[name].get_properties()

renderers.append(VegaRenderer(data, name, **properties))
return renderers

Expand Down Expand Up @@ -94,19 +113,21 @@ def get_params_renderers(dvclive_params):
return []


def make_report(dvclive: "Live"):
plots_path = Path(dvclive.plots_dir)
def make_report(live: "Live"):
plots_path = Path(live.plots_dir)

renderers = []
renderers.extend(get_params_renderers(dvclive.params_file))
renderers.extend(get_metrics_renderers(dvclive.metrics_file))
renderers.extend(get_params_renderers(live.params_file))
renderers.extend(get_metrics_renderers(live.metrics_file))
renderers.extend(get_scalar_renderers(plots_path / Metric.subfolder))
renderers.extend(get_image_renderers(plots_path / Image.subfolder))
renderers.extend(get_plot_renderers(plots_path / SKLearnPlot.subfolder))

if dvclive.report_mode == "html":
render_html(renderers, dvclive.report_file, refresh_seconds=5)
elif dvclive.report_mode == "md":
render_markdown(renderers, dvclive.report_file)
renderers.extend(
get_plot_renderers(plots_path / SKLearnPlot.subfolder, live)
)

if live.report_mode == "html":
render_html(renderers, live.report_file, refresh_seconds=5)
elif live.report_mode == "md":
render_markdown(renderers, live.report_file)
else:
raise ValueError(f"Invalid `mode` {dvclive.report_mode}.")
raise ValueError(f"Invalid `mode` {live.report_mode}.")
37 changes: 27 additions & 10 deletions tests/test_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from dvclive.env import DVCLIVE_OPEN
from dvclive.plots import Image as LiveImage
from dvclive.plots import Metric
from dvclive.plots.sklearn import ConfusionMatrix, SKLearnPlot
from dvclive.plots.sklearn import ConfusionMatrix, Roc, SKLearnPlot
from dvclive.report import (
get_image_renderers,
get_metrics_renderers,
Expand All @@ -31,6 +31,12 @@ def test_get_renderers(tmp_dir, mocker):
live.next_step()

live.log_sklearn_plot("confusion_matrix", [0, 0, 1, 1], [1, 0, 0, 1])
live.log_sklearn_plot(
"confusion_matrix", [0, 0, 1, 1], [1, 0, 0, 1], name="train/cm"
)
live.log_sklearn_plot(
"roc", [0, 0, 1, 1], [1, 0.1, 0, 1], name="roc_curve"
)

image_renderers = get_image_renderers(
tmp_dir / live.plots_dir / LiveImage.subfolder
Expand Down Expand Up @@ -63,16 +69,27 @@ def test_get_renderers(tmp_dir, mocker):
assert scalar_renderers[0].name == "static/foo/bar"

plot_renderers = get_plot_renderers(
tmp_dir / live.plots_dir / SKLearnPlot.subfolder
tmp_dir / live.plots_dir / SKLearnPlot.subfolder, live
)
assert len(plot_renderers) == 1
assert plot_renderers[0].datapoints == [
{"actual": "0", "rev": "workspace", "predicted": "1"},
{"actual": "0", "rev": "workspace", "predicted": "0"},
{"actual": "1", "rev": "workspace", "predicted": "0"},
{"actual": "1", "rev": "workspace", "predicted": "1"},
]
assert plot_renderers[0].properties == ConfusionMatrix.get_properties()
assert len(plot_renderers) == 3
for plot_renderer in plot_renderers:
if plot_renderer.name == "roc_curve":
assert plot_renderer.datapoints == [
{"fpr": 0.0, "rev": "workspace", "threshold": 2.0, "tpr": 0.0},
{"fpr": 0.5, "rev": "workspace", "threshold": 1.0, "tpr": 0.5},
{"fpr": 1.0, "rev": "workspace", "threshold": 0.1, "tpr": 0.5},
{"fpr": 1.0, "rev": "workspace", "threshold": 0.0, "tpr": 1.0},
]
assert plot_renderer.properties == Roc.get_properties()

else:
assert plot_renderer.datapoints == [
{"actual": "0", "rev": "workspace", "predicted": "1"},
{"actual": "0", "rev": "workspace", "predicted": "0"},
{"actual": "1", "rev": "workspace", "predicted": "0"},
{"actual": "1", "rev": "workspace", "predicted": "1"},
]
assert plot_renderer.properties == ConfusionMatrix.get_properties()

metrics_renderer = get_metrics_renderers(live.metrics_file)[0]
assert metrics_renderer.datapoints == [{"step": 1, "foo": {"bar": 1}}]
Expand Down