Skip to content

Commit

Permalink
[Artifacts] Added Bokeh artifact (#1105)
Browse files Browse the repository at this point in the history
  • Loading branch information
guy1992l committed Aug 11, 2021
1 parent be56dd3 commit 9081f14
Show file tree
Hide file tree
Showing 6 changed files with 226 additions and 79 deletions.
1 change: 1 addition & 0 deletions extras-requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,4 @@ azure-storage-blob~=12.0, <12.7.0
adlfs~=0.7.1
azure-identity~=1.5
azure-keyvault-secrets~=4.2
bokeh~=2.3
2 changes: 1 addition & 1 deletion mlrun/artifacts/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,4 @@
from .dataset import DatasetArtifact, TableArtifact, update_dataset_meta
from .manager import ArtifactManager, ArtifactProducer, dict_to_artifact
from .model import ModelArtifact, get_model, update_model
from .plots import ChartArtifact, PlotArtifact
from .plots import BokehArtifact, ChartArtifact, PlotArtifact
64 changes: 56 additions & 8 deletions mlrun/artifacts/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,12 @@
from ..utils import dict_to_json
from .base import Artifact

plot_template = """<h3 style="text-align:center">{}</h3>
<img title="{}" src="data:image/png;base64,{}">"""


class PlotArtifact(Artifact):
_TEMPLATE = """
<h3 style="text-align:center">{}</h3>
<img title="{}" src="data:image/png;base64,{}">
"""
kind = "plot"

def __init__(
Expand Down Expand Up @@ -55,10 +56,11 @@ def get_body(self):
data = png_output.getvalue()

data_uri = base64.b64encode(data).decode("utf-8")
return plot_template.format(self.description or self.key, self.key, data_uri)
return self._TEMPLATE.format(self.description or self.key, self.key, data_uri)


chart_template = """
class ChartArtifact(Artifact):
_TEMPLATE = """
<html>
<head>
<script
Expand All @@ -82,8 +84,6 @@ def get_body(self):
</html>
"""


class ChartArtifact(Artifact):
kind = "chart"

def __init__(
Expand Down Expand Up @@ -121,7 +121,55 @@ def get_body(self):
self.options["title"] = self.title or self.key
data = [self.header] + self.rows
return (
chart_template.replace("$data$", dict_to_json(data))
self._TEMPLATE.replace("$data$", dict_to_json(data))
.replace("$opts$", dict_to_json(self.options))
.replace("$chart$", self.chart)
)


class BokehArtifact(Artifact):
"""
Bokeh artifact is an artifact for saving Bokeh generated figures. They will be stored in html format.
"""

kind = "bokeh"

def __init__(
self, figure, key: str = None, target_path: str = None,
):
"""
Initialize a Bokeh artifact with the given figure.
:param figure: Bokeh figure ('bokeh.plotting.Figure' object) to save as an artifact.
:param key: Key for the artifact to be stored in the database.
:param target_path: Path to save the artifact.
"""
super().__init__(key=key, target_path=target_path, viewer="bokeh")

# Validate input:
try:
from bokeh.plotting import Figure
except ImportError:
raise ImportError(
"Using 'BokehArtifact' requires bokeh package. Use pip install mlrun[bokeh] to install it"
)
if not isinstance(figure, Figure):
raise ValueError(
"BokehArtifact requires the figure parameter to be a "
"'bokeh.plotting.Figure' but received '{}'".format(type(figure))
)

# Continue initializing the bokeh artifact:
self._figure = figure
self.format = "html"

def get_body(self):
"""
Get the artifact's body - the bokeh figure's html code.
:return: The figure's html code.
"""
from bokeh.embed import file_html
from bokeh.resources import CDN

return file_html(self._figure, CDN, self.key)
209 changes: 154 additions & 55 deletions mlrun/frameworks/_common/loggers/mlrun_logger.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from typing import Dict, Union
from typing import Dict, List, Union

import numpy as np
from bokeh.plotting import figure

from mlrun import MLClientCtx
from mlrun.artifacts import Artifact, ChartArtifact
from mlrun.artifacts import Artifact, BokehArtifact
from mlrun.frameworks._common.loggers.logger import Logger
from mlrun.frameworks._common.model_handler import ModelHandler

Expand Down Expand Up @@ -102,32 +103,26 @@ def log_epoch_to_context(
self._context._results = child_context.results

# Log the epochs metrics results as chart artifacts:
for metrics_prefix, metrics_dictionary in zip(
for loop, metrics_dictionary in zip(
["training", "validation"],
[self._training_results, self._validation_results],
):
for metric_name, metric_epochs in metrics_dictionary.items():
# Create the chart artifact:
chart_name = "{}_{}_epoch_{}".format(
metrics_prefix, metric_name, len(metric_epochs)
)
chart_artifact = ChartArtifact(
key="{}.html".format(chart_name),
header=["iteration", "result"],
data=list(
np.array(
[list(np.arange(len(metric_epochs[-1]))), metric_epochs[-1]]
).transpose()
),
for metric_name in metrics_dictionary:
# Create the bokeh artifact:
artifact = self._generate_metric_results_artifact(
epoch=len(metrics_dictionary[metric_name]),
loop=loop,
name=metric_name,
results=metrics_dictionary[metric_name][-1],
)
# Log the artifact:
child_context.log_artifact(
chart_artifact,
local_path=chart_artifact.key,
artifact,
local_path=artifact.key,
artifact_path=child_context.artifact_path,
)
# Collect it for later adding it to the model logging as extra data:
self._artifacts[chart_name] = chart_artifact
self._artifacts[artifact.key.split(".")[0]] = artifact

# Commit and commit children for MLRun flag bug:
self._context.commit()
Expand All @@ -147,55 +142,33 @@ def log_run(self, model_handler: ModelHandler):
:param model_handler: The model handler object holding the model to save and log.
"""
# Create chart artifacts for summaries:
for metric_name, metric_results in self._training_summaries.items():
if metric_name in self._validation_summaries:
header = ["epoch", "training_result", "validation_result"]
data = list(
np.array(
[
list(np.arange(len(metric_results))),
metric_results,
self._validation_summaries[metric_name],
]
).transpose()
)
else:
header = ["epoch", "training_result"]
data = list(
np.array(
[list(np.arange(len(metric_results))), metric_results]
).transpose()
)
# Create the chart artifact:
chart_name = "{}_summary".format(metric_name)
chart_artifact = ChartArtifact(
key="{}.html".format(chart_name), header=header, data=data,
for metric_name in self._training_summaries:
# Create the bokeh artifact:
artifact = self._generate_summary_results_artifact(
name=metric_name,
training_results=self._training_summaries[metric_name],
validation_results=self._validation_summaries.get(metric_name, None),
)
# Log the artifact:
self._context.log_artifact(
chart_artifact, local_path=chart_artifact.key,
artifact, local_path=artifact.key,
)
# Collect it for later adding it to the model logging as extra data:
self._artifacts[chart_name] = chart_artifact
self._artifacts[artifact.key.split(".")[0]] = artifact

# Create chart artifacts for dynamic hyperparameters:
for parameter_name, parameter_values in self._dynamic_hyperparameters.items():
for parameter_name in self._dynamic_hyperparameters:
# Create the chart artifact:
chart_artifact = ChartArtifact(
key="{}.html".format(parameter_name),
header=["epoch", "value"],
data=list(
np.array(
[list(np.arange(len(parameter_values))), parameter_values]
).transpose(),
),
artifact = self._generate_dynamic_hyperparameter_values_artifact(
name=parameter_name,
values=self._dynamic_hyperparameters[parameter_name],
)
# Log the artifact:
self._context.log_artifact(
chart_artifact, local_path=chart_artifact.key,
artifact, local_path=artifact.key,
)
# Collect it for later adding it to the model logging as extra data:
self._artifacts[parameter_name] = chart_artifact
self._artifacts[artifact.key.split(".")[0]] = artifact

# Log the model:
model_handler.set_context(context=self._context)
Expand All @@ -208,3 +181,129 @@ def log_run(self, model_handler: ModelHandler):

# Commit:
self._context.commit()

@staticmethod
def _generate_metric_results_artifact(
epoch: int, loop: str, name: str, results: List[float]
) -> BokehArtifact:
"""
Generate a bokeh artifact for the results of the metric provided.
:param epoch: The epoch of the recorded resutls.
:param loop: The results loop, training or validation.
:param name: The metric name.
:param results: The metric results at the given epoch.
:return: The generated bokeh figure wrapped in MLRun artifact.
"""
# Parse the artifact's name:
artifact_name = "{}_{}_epoch_{}".format(loop, name, epoch)

# Initialize a bokeh figure:
metric_figure = figure(
title="{} Results for epoch {}".format(name, epoch),
x_axis_label="Batches",
y_axis_label="Results",
x_axis_type="linear",
)

# Draw the results:
metric_figure.line(x=list(np.arange(len(results))), y=results)

# Create the bokeh artifact:
artifact = BokehArtifact(
key="{}.html".format(artifact_name), figure=metric_figure
)

return artifact

@staticmethod
def _generate_summary_results_artifact(
name: str, training_results: List[float], validation_results: List[float]
) -> BokehArtifact:
"""
Generate a bokeh artifact for the results summary across all the epochs of training.
:param name: The metric name.
:param training_results: The metric training results summaries across the epochs.
:param validation_results: The metric validation results summaries across the epochs. If validation was not
performed, None should be passed.
:return: The generated bokeh figure wrapped in MLRun artifact.
"""
# Parse the artifact's name:
artifact_name = "{}_summary".format(name)

# Initialize a bokeh figure:
summary_figure = figure(
title="{} Summary".format(name),
x_axis_label="Epochs",
y_axis_label="Results",
x_axis_type="linear",
)

# Draw the results:
summary_figure.line(
x=list(np.arange(1, len(training_results) + 1)),
y=training_results,
legend_label="Training",
)
summary_figure.circle(
x=list(np.arange(1, len(training_results) + 1)),
y=training_results,
legend_label="Training",
)
if validation_results is not None:
summary_figure.line(
x=list(np.arange(1, len(validation_results) + 1)),
y=validation_results,
legend_label="Validation",
color="orangered",
)
summary_figure.circle(
x=list(np.arange(1, len(validation_results) + 1)),
y=validation_results,
legend_label="Validation",
color="orangered",
)

# Create the bokeh artifact:
artifact = BokehArtifact(
key="{}.html".format(artifact_name), figure=summary_figure
)

return artifact

@staticmethod
def _generate_dynamic_hyperparameter_values_artifact(
name: str, values: List[float]
) -> BokehArtifact:
"""
Generate a bokeh artifact for the values of the hyperparameter provided.
:param name: The hyperparameter name.
:param values: The hyperparameter values across the training.
:return: The generated bokeh figure wrapped in MLRun artifact.
"""
# Parse the artifact's name:
artifact_name = "{}.html".format(name)

# Initialize a bokeh figure:
hyperparameter_figure = figure(
title=name,
x_axis_label="Epochs",
y_axis_label="Values",
x_axis_type="linear",
)

# Draw the values:
hyperparameter_figure.line(x=list(np.arange(len(values))), y=values)
hyperparameter_figure.circle(x=list(np.arange(len(values))), y=values)

# Create the bokeh artifact:
artifact = BokehArtifact(
key="{}.html".format(artifact_name), figure=hyperparameter_figure
)

return artifact
2 changes: 2 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,8 @@ def load_deps(path):
# <12.7.0 from adlfs 0.6.3
"azure-blob-storage": ["azure-storage-blob~=12.0, <12.7.0", "adlfs~=0.7.1"],
"azure-key-vault": ["azure-identity~=1.5", "azure-keyvault-secrets~=4.2"],
# mlrun.frameworks requirements per framework: # TODO: should be added in a later PR
"bokeh": ["bokeh~=2.3"],
}
extras_require["complete"] = sorted(
{
Expand Down

0 comments on commit 9081f14

Please sign in to comment.