Skip to content

Commit

Permalink
live: Add log_plot.
Browse files Browse the repository at this point in the history
Create DVC plots from datapoints (list of dictionaries) and plot config.

Closes #271
Closes #453

```
datapoints = [{"foo": 1, "bar": 2}, {"foo": 3, "bar": 4}]
with Live() as live:
    live.log_plot("foo_default", datapoints, x="foo", y="bar")
    live.log_plot(
        "foo_scatter",
        datapoints,
        x="foo",
        y="bar",
        template="scatter",
    )
```
  • Loading branch information
daavoo committed Apr 25, 2023
1 parent 0f59e9f commit 5abd7bd
Show file tree
Hide file tree
Showing 9 changed files with 285 additions and 110 deletions.
2 changes: 1 addition & 1 deletion src/dvclive/dvc.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def make_dvcyaml(live):
if live._plots:
for plot in live._plots.values():
plot_path = plot.output_path.relative_to(live.dir)
plots.append({plot_path.as_posix(): plot.get_properties()})
plots.append({plot_path.as_posix(): plot.plot_config})
if plots:
dvcyaml["plots"] = plots

Expand Down
75 changes: 58 additions & 17 deletions src/dvclive/live.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
InvalidPlotTypeError,
InvalidReportModeError,
)
from .plots import PLOT_TYPES, SKLEARN_PLOTS, Image, Metric, NumpyEncoder
from .plots import PLOT_TYPES, SKLEARN_PLOTS, CustomPlot, Image, Metric, NumpyEncoder
from .report import BLANK_NOTEBOOK_REPORT, make_report
from .serialize import dump_json, dump_yaml, load_yaml
from .studio import get_studio_updates
Expand Down Expand Up @@ -261,15 +261,15 @@ def log_metric(self, name: str, val: Union[int, float], timestamp: bool = False)
raise InvalidDataTypeError(name, type(val))

if name in self._metrics:
data = self._metrics[name]
metric = self._metrics[name]
else:
data = Metric(name, self.plots_dir)
self._metrics[name] = data
metric = Metric(name, self.plots_dir)
self._metrics[name] = metric

data.step = self.step
data.dump(val, timestamp=timestamp)
metric.step = self.step
metric.dump(val, timestamp=timestamp)

self.summary = set_in(self.summary, data.summary_keys, val)
self.summary = set_in(self.summary, metric.summary_keys, val)
logger.debug(f"Logged {name}: {val}")

def log_image(self, name: str, val):
Expand All @@ -282,29 +282,70 @@ def log_image(self, name: str, val):
val = ImagePIL.open(val)

if name in self._images:
data = self._images[name]
image = self._images[name]
else:
data = Image(name, self.plots_dir)
self._images[name] = data
image = Image(name, self.plots_dir)
self._images[name] = image

data.step = self.step
data.dump(val)
image.step = self.step
image.dump(val)
logger.debug(f"Logged {name}: {val}")

def log_plot(
    self,
    name: str,
    datapoints: List[Dict],
    x: str,
    y: str,
    template: Optional[str] = None,
    title: Optional[str] = None,
    x_label: Optional[str] = None,
    y_label: Optional[str] = None,
):
    """Log a custom plot from a list of datapoint dictionaries.

    Args:
        name: Key under which the plot is tracked in ``self._plots``.
        datapoints: List of dictionaries holding the values to plot.
        x: Field name stored as the ``x`` entry of the plot config.
        y: Field name stored as the ``y`` entry of the plot config.
        template: Optional template name stored in the plot config.
        title: Optional plot title stored in the plot config.
        x_label: Optional x-axis label stored in the plot config.
        y_label: Optional y-axis label stored in the plot config.

    Raises:
        InvalidDataTypeError: If ``CustomPlot.could_log`` rejects
            ``datapoints``.
    """
    if not CustomPlot.could_log(datapoints):
        raise InvalidDataTypeError(name, type(datapoints))

    # The plot configuration is only used the first time ``name`` is seen;
    # later calls reuse the plot object that is already registered.
    if name not in self._plots:
        self._plots[name] = CustomPlot(
            name,
            self.plots_dir,
            x=x,
            y=y,
            template=template,
            title=title,
            x_label=x_label,
            y_label=y_label,
        )
    plot = self._plots[name]

    plot.step = self.step
    plot.dump(datapoints)
    logger.debug(f"Logged {name}")

def log_sklearn_plot(self, kind, labels, predictions, name=None, **kwargs):
    """Log one of the plots registered in ``SKLEARN_PLOTS``.

    Args:
        kind: Key into ``SKLEARN_PLOTS`` selecting the plot class.
        labels: First element of the ``(labels, predictions)`` tuple
            handed to the plot's ``dump``.
        predictions: Second element of that tuple.
        name: Key under which the plot is tracked; defaults to ``kind``.
        **kwargs: ``title``, ``x_label``, ``y_label`` and ``normalized``
            are treated as plot configuration and passed to the plot
            constructor. Every other keyword is forwarded to the plot's
            ``dump``.

    Raises:
        InvalidPlotTypeError: If ``name`` is not already tracked and
            ``kind`` is unknown or its ``could_log`` rejects the data.
    """
    val = (labels, predictions)

    # Keys that describe how the plot is rendered, not how it is computed.
    plot_config = {
        k: v
        for k, v in kwargs.items()
        if k in ("title", "x_label", "y_label", "normalized")
    }
    name = name or kind
    if name in self._plots:
        plot = self._plots[name]
    elif kind in SKLEARN_PLOTS and SKLEARN_PLOTS[kind].could_log(val):
        plot = SKLEARN_PLOTS[kind](name, self.plots_dir, **plot_config)
        self._plots[plot.name] = plot
    else:
        raise InvalidPlotTypeError(name)

    # Forward only the keywords that were NOT consumed as plot config.
    # The earlier condition (`k not in plot_config or k != "normalized"`)
    # only dropped "normalized", so title/x_label/y_label leaked into dump().
    # NOTE(review): presumably dump() passes these on to sklearn functions
    # that do not accept such keywords -- confirm against the plot classes.
    sklearn_kwargs = {k: v for k, v in kwargs.items() if k not in plot_config}
    plot.step = self.step
    plot.dump(val, **sklearn_kwargs)
    logger.debug(f"Logged {name}")

def _read_params(self):
Expand Down
3 changes: 2 additions & 1 deletion src/dvclive/plots/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .custom import CustomPlot
from .image import Image
from .metric import Metric
from .sklearn import Calibration, ConfusionMatrix, Det, PrecisionRecall, Roc
Expand All @@ -10,4 +11,4 @@
"precision_recall": PrecisionRecall,
"roc": Roc,
}
PLOT_TYPES = (*SKLEARN_PLOTS.values(), Metric, Image)
PLOT_TYPES = (*SKLEARN_PLOTS.values(), Metric, Image, CustomPlot)
53 changes: 53 additions & 0 deletions src/dvclive/plots/custom.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
from pathlib import Path
from typing import Optional

from dvclive.serialize import dump_json

from .base import Data


class CustomPlot(Data):
    """Plot backed by a JSON file of datapoints plus a plot configuration."""

    suffixes = [".json"]
    subfolder = "custom"

    def __init__(
        self,
        name: str,
        output_folder: str,
        x: str,
        y: str,
        template: Optional[str],
        title: Optional[str] = None,
        x_label: Optional[str] = None,
        y_label: Optional[str] = None,
    ) -> None:
        """Store the plot name and build the plot configuration.

        Entries whose value is ``None`` are left out of the configuration.
        """
        super().__init__(name, output_folder)
        # Drop any ".json" occurrence; output_path appends the suffix itself.
        self.name = self.name.replace(".json", "")
        candidates = (
            ("template", template),
            ("x", x),
            ("y", y),
            ("title", title),
            ("x_label", x_label),
            ("y_label", y_label),
        )
        self._plot_config = {
            key: value for key, value in candidates if value is not None
        }

    @property
    def output_path(self) -> Path:
        """Return the JSON path for this plot, creating its parent folder."""
        target = Path(f"{self.output_folder / self.name}.json")
        target.parent.mkdir(exist_ok=True, parents=True)
        return target

    @staticmethod
    def could_log(val: object) -> bool:
        """Return True when ``val`` is a list whose items are all dicts."""
        if not isinstance(val, list):
            return False
        return all(isinstance(item, dict) for item in val)

    @property
    def plot_config(self):
        """Return the plot configuration dictionary built in ``__init__``."""
        return self._plot_config

    def dump(self, val, **kwargs) -> None:  # noqa: ARG002
        """Write ``val`` as JSON to ``output_path``; extra kwargs are ignored."""
        dump_json(val, self.output_path)
133 changes: 52 additions & 81 deletions src/dvclive/plots/sklearn.py
Original file line number Diff line number Diff line change
@@ -1,47 +1,29 @@
import copy
from pathlib import Path

from dvclive.serialize import dump_json

from .base import Data
from .custom import CustomPlot


class SKLearnPlot(Data):
suffixes = [".json"]
class SKLearnPlot(CustomPlot):
subfolder = "sklearn"

def __init__(self, name: str, output_folder: str, **kwargs) -> None: # noqa: ARG002
super().__init__(name, output_folder)
self.name = self.name.replace(".json", "")

@property
def output_path(self) -> Path:
_path = Path(f"{self.output_folder / self.name}.json")
_path.parent.mkdir(exist_ok=True, parents=True)
return _path

@staticmethod
def could_log(val: object) -> bool:
if isinstance(val, tuple) and len(val) == 2: # noqa: PLR2004
return True
return False

def get_properties(self):
raise NotImplementedError


class Roc(SKLearnPlot):
DEFAULT_PROPERTIES = {
"template": "simple",
"x": "fpr",
"y": "tpr",
"title": "Receiver operating characteristic (ROC)",
"x_label": "False Positive Rate",
"y_label": "True Positive Rate",
}

def get_properties(self):
return copy.deepcopy(self.DEFAULT_PROPERTIES)
def __init__(self, name: str, output_folder: str, **plot_config) -> None:
    """Fill in ROC defaults for missing config keys and fix the x/y fields."""
    defaults = {
        "template": "simple",
        "title": "Receiver operating characteristic (ROC)",
        "x_label": "False Positive Rate",
        "y_label": "True Positive Rate",
    }
    for key, fallback in defaults.items():
        plot_config.setdefault(key, fallback)
    # x and y are always set here, overriding any caller-supplied value.
    plot_config.update(x="fpr", y="tpr")
    super().__init__(name, output_folder, **plot_config)

def dump(self, val, **kwargs) -> None:
from sklearn import metrics
Expand All @@ -59,17 +41,14 @@ def dump(self, val, **kwargs) -> None:


class PrecisionRecall(SKLearnPlot):
DEFAULT_PROPERTIES = {
"template": "simple",
"x": "recall",
"y": "precision",
"title": "Precision-Recall Curve",
"x_label": "Recall",
"y_label": "Precision",
}

def get_properties(self):
return copy.deepcopy(self.DEFAULT_PROPERTIES)
def __init__(self, name: str, output_folder: str, **plot_config) -> None:
    """Fill in precision-recall defaults for missing config keys and fix x/y."""
    defaults = {
        "template": "simple",
        "title": "Precision-Recall Curve",
        "x_label": "Recall",
        "y_label": "Precision",
    }
    for key, fallback in defaults.items():
        plot_config.setdefault(key, fallback)
    # x and y are always set here, overriding any caller-supplied value.
    plot_config.update(x="recall", y="precision")
    super().__init__(name, output_folder, **plot_config)

def dump(self, val, **kwargs) -> None:
from sklearn import metrics
Expand All @@ -88,17 +67,16 @@ def dump(self, val, **kwargs) -> None:


class Det(SKLearnPlot):
DEFAULT_PROPERTIES = {
"template": "simple",
"x": "fpr",
"y": "fnr",
"title": "Detection error tradeoff (DET)",
"x_label": "False Positive Rate",
"y_label": "False Negative Rate",
}

def get_properties(self):
return copy.deepcopy(self.DEFAULT_PROPERTIES)
def __init__(self, name: str, output_folder: str, **plot_config) -> None:
    """Fill in DET defaults for missing config keys and fix the x/y fields."""
    defaults = {
        "template": "simple",
        "title": "Detection error tradeoff (DET)",
        "x_label": "False Positive Rate",
        "y_label": "False Negative Rate",
    }
    for key, fallback in defaults.items():
        plot_config.setdefault(key, fallback)
    # x and y are always set here, overriding any caller-supplied value.
    plot_config.update(x="fpr", y="fnr")
    super().__init__(name, output_folder, **plot_config)

def dump(self, val, **kwargs) -> None:
from sklearn import metrics
Expand All @@ -117,24 +95,18 @@ def dump(self, val, **kwargs) -> None:


class ConfusionMatrix(SKLearnPlot):
DEFAULT_PROPERTIES = {
"template": "confusion",
"x": "actual",
"y": "predicted",
"title": "Confusion Matrix",
"x_label": "True Label",
"y_label": "Predicted Label",
}

def __init__(self, name: str, output_folder: str, **kwargs) -> None:
super().__init__(name, output_folder)
self.normalized = kwargs.get("normalized") or False

def get_properties(self):
properties = copy.deepcopy(self.DEFAULT_PROPERTIES)
if self.normalized:
properties["template"] = "confusion_normalized"
return properties
def __init__(self, name: str, output_folder: str, **plot_config) -> None:
    """Fill in confusion-matrix defaults and pick the template.

    A truthy ``normalized`` entry forces the ``confusion_normalized``
    template; otherwise a caller-supplied template is kept, defaulting to
    ``confusion``. ``normalized`` itself is removed from the config.
    """
    normalized = plot_config.pop("normalized", None)
    if normalized:
        plot_config["template"] = "confusion_normalized"
    else:
        plot_config.setdefault("template", "confusion")

    defaults = {
        "title": "Confusion Matrix",
        "x_label": "True Label",
        "y_label": "Predicted Label",
    }
    for key, fallback in defaults.items():
        plot_config.setdefault(key, fallback)
    # x and y are always set here, overriding any caller-supplied value.
    plot_config.update(x="actual", y="predicted")
    super().__init__(name, output_folder, **plot_config)

def dump(self, val, **kwargs) -> None: # noqa: ARG002
cm = [
Expand All @@ -145,17 +117,16 @@ def dump(self, val, **kwargs) -> None: # noqa: ARG002


class Calibration(SKLearnPlot):
DEFAULT_PROPERTIES = {
"template": "simple",
"x": "prob_pred",
"y": "prob_true",
"title": "Calibration Curve",
"x_label": "Mean Predicted Probability",
"y_label": "Fraction of Positives",
}

def get_properties(self):
return copy.deepcopy(self.DEFAULT_PROPERTIES)
def __init__(self, name: str, output_folder: str, **plot_config) -> None:
    """Fill in calibration-curve defaults for missing config keys and fix x/y."""
    defaults = {
        "template": "simple",
        "title": "Calibration Curve",
        "x_label": "Mean Predicted Probability",
        "y_label": "Fraction of Positives",
    }
    for key, fallback in defaults.items():
        plot_config.setdefault(key, fallback)
    # x and y are always set here, overriding any caller-supplied value.
    plot_config.update(x="prob_pred", y="prob_true")
    super().__init__(name, output_folder, **plot_config)

def dump(self, val, **kwargs) -> None:
from sklearn import calibration
Expand Down
Loading

0 comments on commit 5abd7bd

Please sign in to comment.