From 77efa2cd8075e6de3aa371df2eb426770d0b9252 Mon Sep 17 00:00:00 2001 From: Sami Jawhar Date: Thu, 22 Oct 2020 11:19:43 -0500 Subject: [PATCH 1/2] Add normalized confusion plot, and add text labels --- dvc/repo/plots/template.py | 147 +++++++++++++++++++++++--- tests/func/metrics/plots/test_show.py | 42 +++++++- 2 files changed, 173 insertions(+), 16 deletions(-) diff --git a/dvc/repo/plots/template.py b/dvc/repo/plots/template.py index a0a1ac0a0a..9361b1278b 100644 --- a/dvc/repo/plots/template.py +++ b/dvc/repo/plots/template.py @@ -137,22 +137,140 @@ class DefaultConfusionTemplate(Template): "$schema": "https://vega.github.io/schema/vega-lite/v4.json", "data": {"values": Template.anchor("data")}, "title": Template.anchor("title"), - "mark": "rect", - "encoding": { - "x": { - "field": Template.anchor("x"), - "type": "nominal", - "sort": "ascending", - "title": Template.anchor("x_label"), + "facet": {"field": "rev", "type": "nominal"}, + "spec": { + "transform": [ + { + "aggregate": [{"op": "count", "as": "xy_count"}], + "groupby": [Template.anchor("y"), Template.anchor("x")], + }, + { + "joinaggregate": [ + {"op": "max", "field": "xy_count", "as": "max_count"} + ], + "groupby": [], + }, + { + "calculate": "datum.xy_count / datum.max_count", + "as": "percent_of_max", + }, + ], + "encoding": { + "x": { + "field": Template.anchor("x"), + "type": "nominal", + "sort": "ascending", + "title": Template.anchor("x_label"), + }, + "y": { + "field": Template.anchor("y"), + "type": "nominal", + "sort": "ascending", + "title": Template.anchor("y_label"), + }, }, - "y": { - "field": Template.anchor("y"), - "type": "nominal", - "sort": "ascending", - "title": Template.anchor("y_label"), + "layer": [ + { + "mark": "rect", + "width": 300, + "height": 300, + "encoding": { + "color": { + "field": "xy_count", + "type": "quantitative", + "title": "", + "scale": {"domainMin": 0, "nice": True}, + } + }, + }, + { + "mark": "text", + "encoding": { + "text": {"field": "xy_count", "type": "quantitative"}, + "color": { + "condition": { + "test": "datum.percent_of_max > 0.5", + "value": "white", + }, + "value": "black", + }, + }, + }, + ], + }, + } + + +class NormalizedConfusionTemplate(Template): + DEFAULT_NAME = "confusion_normalized" + DEFAULT_CONTENT = { + "$schema": "https://vega.github.io/schema/vega-lite/v4.json", + "data": {"values": Template.anchor("data")}, + "title": Template.anchor("title"), + "facet": {"field": "rev", "type": "nominal"}, + "spec": { + "transform": [ + { + "aggregate": [{"op": "count", "as": "xy_count"}], + "groupby": [Template.anchor("y"), Template.anchor("x")], + }, + { + "joinaggregate": [ + {"op": "sum", "field": "xy_count", "as": "sum_y"} + ], + "groupby": [Template.anchor("y")], + }, + { + "calculate": "datum.xy_count / datum.sum_y", + "as": "percent_of_y", + }, + ], + "encoding": { + "x": { + "field": Template.anchor("x"), + "type": "nominal", + "sort": "ascending", + "title": Template.anchor("x_label"), + }, + "y": { + "field": Template.anchor("y"), + "type": "nominal", + "sort": "ascending", + "title": Template.anchor("y_label"), + }, }, - "color": {"aggregate": "count", "type": "quantitative"}, - "facet": {"field": "rev", "type": "nominal"}, + "layer": [ + { + "mark": "rect", + "width": 300, + "height": 300, + "encoding": { + "color": { + "field": "percent_of_y", + "type": "quantitative", + "title": "", + "scale": {"domain": [0, 1]}, + } + }, + }, + { + "mark": "text", + "encoding": { + "text": { + "field": "percent_of_y", + "type": "quantitative", + "format": ".2f", + }, + "color": { + "condition": { + "test": "datum.percent_of_y > 0.5", + "value": "white", + }, + "value": "black", + }, + }, + }, + ], }, } @@ -219,6 +337,7 @@ class PlotTemplates: TEMPLATES = [ DefaultLinearTemplate, DefaultConfusionTemplate, + NormalizedConfusionTemplate, DefaultScatterTemplate, SmoothLinearTemplate, ] diff --git a/tests/func/metrics/plots/test_show.py b/tests/func/metrics/plots/test_show.py index 0e4f05e532..39b51675ef 100644 --- a/tests/func/metrics/plots/test_show.py +++ b/tests/func/metrics/plots/test_show.py @@ -193,8 +193,46 @@ def test_plot_confusion(tmp_dir, dvc, run_copy_metrics): {"predicted": "B", "actual": "A", "rev": "workspace"}, {"predicted": "A", "actual": "A", "rev": "workspace"}, ] - assert plot_content["encoding"]["x"]["field"] == "predicted" - assert plot_content["encoding"]["y"]["field"] == "actual" + assert plot_content["spec"]["transform"][0]["groupby"] == [ + "actual", + "predicted", + ] + assert plot_content["spec"]["encoding"]["x"]["field"] == "predicted" + assert plot_content["spec"]["encoding"]["y"]["field"] == "actual" + + +def test_plot_confusion_normalized(tmp_dir, dvc, run_copy_metrics): + confusion_matrix = [ + {"predicted": "B", "actual": "A"}, + {"predicted": "A", "actual": "A"}, + ] + _write_json(tmp_dir, confusion_matrix, "metric_t.json") + run_copy_metrics( + "metric_t.json", + "metric.json", + plots_no_cache=["metric.json"], + commit="first run", + ) + + props = { + "template": "confusion_normalized", + "x": "predicted", + "y": "actual", + } + plot_string = dvc.plots.show(props=props)["metric.json"] + + plot_content = json.loads(plot_string) + assert plot_content["data"]["values"] == [ + {"predicted": "B", "actual": "A", "rev": "workspace"}, + {"predicted": "A", "actual": "A", "rev": "workspace"}, + ] + assert plot_content["spec"]["transform"][0]["groupby"] == [ + "actual", + "predicted", + ] + assert plot_content["spec"]["transform"][1]["groupby"] == ["actual"] + assert plot_content["spec"]["encoding"]["x"]["field"] == "predicted" + assert plot_content["spec"]["encoding"]["y"]["field"] == "actual" def test_plot_multiple_revs_default(tmp_dir, scm, dvc, run_copy_metrics): From f354bc627323397de01ed675a930ded03a24c9cd Mon Sep 17 00:00:00 2001 From: Sami Jawhar Date: Sun, 25 Oct 2020 14:03:22 +0000 Subject: [PATCH 2/2] Impute missing XY combinations --- dvc/repo/plots/template.py | 24 ++++++++++++++++++++++++ tests/func/metrics/plots/test_show.py | 2 +- 2 files changed, 25 insertions(+), 1 deletion(-) diff --git a/dvc/repo/plots/template.py b/dvc/repo/plots/template.py index 9361b1278b..1e98de2ef4 100644 --- a/dvc/repo/plots/template.py +++ b/dvc/repo/plots/template.py @@ -144,6 +144,18 @@ class DefaultConfusionTemplate(Template): "aggregate": [{"op": "count", "as": "xy_count"}], "groupby": [Template.anchor("y"), Template.anchor("x")], }, + { + "impute": "xy_count", + "groupby": ["rev", Template.anchor("y")], + "key": Template.anchor("x"), + "value": 0, + }, + { + "impute": "xy_count", + "groupby": ["rev", Template.anchor("x")], + "key": Template.anchor("y"), + "value": 0, + }, { "joinaggregate": [ {"op": "max", "field": "xy_count", "as": "max_count"} @@ -214,6 +226,18 @@ class NormalizedConfusionTemplate(Template): "aggregate": [{"op": "count", "as": "xy_count"}], "groupby": [Template.anchor("y"), Template.anchor("x")], }, + { + "impute": "xy_count", + "groupby": ["rev", Template.anchor("y")], + "key": Template.anchor("x"), + "value": 0, + }, + { + "impute": "xy_count", + "groupby": ["rev", Template.anchor("x")], + "key": Template.anchor("y"), + "value": 0, + }, { "joinaggregate": [ {"op": "sum", "field": "xy_count", "as": "sum_y"} diff --git a/tests/func/metrics/plots/test_show.py b/tests/func/metrics/plots/test_show.py index 39b51675ef..d80910ace2 100644 --- a/tests/func/metrics/plots/test_show.py +++ b/tests/func/metrics/plots/test_show.py @@ -230,7 +230,7 @@ def test_plot_confusion_normalized(tmp_dir, dvc, run_copy_metrics): "actual", "predicted", ] - assert plot_content["spec"]["transform"][1]["groupby"] == ["actual"] + assert plot_content["spec"]["transform"][1]["groupby"] == ["rev", "actual"] assert plot_content["spec"]["encoding"]["x"]["field"] == "predicted" assert plot_content["spec"]["encoding"]["y"]["field"] == "actual"