In [None]:
##### Copyright 2020 Google LLC.

In [None]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# NitroML: Benchmark Overview

This notebook allows users to analyze NitroML benchmark results for both running and completed pipelines.


<table class="tfo-notebook-buttons" align="left">
  <td>
    <a target="_blank" href="https://colab.research.google.com/github/googleinterns/nitroml/blob/master/notebooks/overview.ipynb"><img src="https://www.tensorflow.org/images/colab_logo_32px.png" />Run in Google Colab</a>
  </td>
  <td>
    <a target="_blank" href="https://github.com/googleinterns/nitroml/blob/master/notebooks/overview.ipynb"><img src="https://www.tensorflow.org/images/GitHub-Mark-32px.png" />View source on GitHub</a>
  </td>
</table>

In [0]:
import getpass
import re

import altair as alt
import tensorflow.compat.v2 as tf

from colabtools import interactive_forms
from colabtools import interactive_table
from nitroml.benchmark import results
from google3.third_party.ml_metadata.proto import metadata_store_pb2
from google3.third_party.ml_metadata.metadata_store import metadata_store

# Opt in to TensorFlow 2.x behavior (`tf` is the `tensorflow.compat.v2` module).
tf.enable_v2_behavior()

## Connect to the ML Metadata (MLMD) database

First, we need to connect to our MLMD database, which stores the results of our
benchmark runs.

In [0]:
#@title { run:"auto" }
#@test {"skip": true}

#@markdown ### **Step 1:** Choose your MLMD Database:
database = "SQLite (CNS)"  #@param [ "SQLite (Local)", "SQLite (CNS)", "MySQL"]
enable_upgrade_migration = False  #@param { type: "boolean" }

#@markdown ### **Step 2:** Configure the connection:
# Build the connection config for whichever backend was selected in the form.
config = metadata_store_pb2.ConnectionConfig()
if database == "SQLite (Local)":
  #@markdown #### If you selected `SQLite (Local)`:
  config.sqlite.filename_uri = "/tmp/nitroml/example/mlmd.sqlite"  #@param {type:"string"}
elif database == "SQLite (CNS)":
  #@markdown #### If you selected `SQLite (CNS)`:
  config.cns.cns_dir = "/cns/ok-d/home/user/nitroml/pipelines/mlmd"  #@param {type:"string"}
  config.cns.db_file = "mlmd.sqlite"  #@param {type:"string"}
  config.cns.use_mvcc = True  #@param {type:"boolean"}
else:
  #@markdown #### If you selected `MySQL`:
  #@markdown The MySQL password is requested interactively when this cell runs.
  config.mysql.socket = "/cloudsql/google.com:rube-mldb-v0:us-central1:rube-mldb-v0"  #@param {type:"string"}
  config.mysql.database = "nitroml_mldb_user"  #@param {type:"string"}
  config.mysql.user = "root"  #@param {type:"string"}
  # Never keep the password in a form field: form values are written back into
  # the notebook source and would be committed/shared along with it. Prompt for
  # it at run time instead so it only ever lives in kernel memory.
  config.mysql.password = getpass.getpass("MySQL password: ")

# Open the metadata store; `store` is used by the cells below to load results.
store = metadata_store.MetadataStore(
    config, enable_upgrade_migration=enable_upgrade_migration)

## Display benchmark results

Next, we load and visualize a `pd.DataFrame` containing our benchmark results.
These results contain contextual features such as the start time and pipeline
ID, and benchmark metrics as computed by the downstream Evaluators. If your
benchmark included an `EstimatorTrainer` component, its hyperparameters may also
be displayed in the table below.

In [0]:
#@title { run:"auto" }
#@test {"skip": true}

#@markdown ### Choose how to aggregate metrics:
mean = False  #@param { type: "boolean" }
stdev = False  #@param { type: "boolean" }
min_and_max = False  #@param { type: "boolean" }

# Translate the ticked form options into aggregator names, keeping the order
# mean -> std -> min -> max.
agg = [
    aggregator
    for selected, aggregators in (
        (mean, ("mean",)),
        (stdev, ("std",)),
        (min_and_max, ("min", "max")),
    )
    if selected
    for aggregator in aggregators
]

df = results.overview(store, metric_aggregators=agg)

# Render the overview as a paginated table, sorted by descending index.
interactive_table.Create(
    dataframe=df.sort_index(ascending=False),
    num_rows_per_page=25,
    max_columns=100,
)

## Visualize metrics over time

Finally, we plot our benchmark metrics over time. This allows us to monitor how
our model quality changed as we tweaked our pipelines.

If you used the `--runs_per_benchmark` flag, you can even display error bars.

In [0]:
#@title { run:"auto" }
#@test {"skip": true}

#@markdown ### Apply filters and select metrics to plot

df = results.overview(store)
benchmark_names = df[
    results.BENCHMARK_KEY].unique() if results.BENCHMARK_KEY in df else []
# Every column except the benchmark name and start time is offered as a metric.
benchmark_metrics = set(df.columns.values)
benchmark_metrics.difference_update({results.BENCHMARK_KEY, results.STARTED_AT})

benchmark_filter = ""  #@param {'type': 'string'}
benchmark_metric = "accuracy"  #@param []
# Populate the `benchmark_metric` dropdown with the metrics found in `df`.
interactive_forms.UpdateParam("benchmark_metric", None,
                              sorted(benchmark_metrics))
error_bars = "ci"  #@param [ "ci", "stdev", "stderr", "iqr"]

# Remove rows where the benchmark's metric column is NaN.
df = df.dropna(subset=[benchmark_metric])

# The filter is treated both as a plain substring and as a regular expression
# anchored at the start of the benchmark name. If the text is not a valid
# regex, fall back to substring matching only rather than failing the cell.
try:
  filter_pattern = re.compile(benchmark_filter)
except re.error:
  filter_pattern = None

benchmarks_to_display = [
    b for b in benchmark_names
    if benchmark_filter in b or
    (filter_pattern is not None and filter_pattern.match(b))
]

predicate = alt.FieldOneOfPredicate(
    field=results.BENCHMARK_KEY, oneOf=benchmarks_to_display)

# All three layers share the same data and the same x/color encodings.
chart_data = df.reset_index()
started_at = results.STARTED_AT + ":T"
benchmark_color = results.BENCHMARK_KEY + ":N"

# Mean of the selected metric per benchmark over time.
line = alt.Chart(chart_data).mark_line(point=True).encode(
    x=started_at,
    y=f"mean({benchmark_metric}):Q",
    color=alt.Color(
        benchmark_color,
        legend=alt.Legend(title="Benchmark", labelLimit=500)),
    tooltip=[benchmark_color] + [f"mean({benchmark_metric}):Q"],
).transform_filter(predicate).properties(
    width=800, height=450).interactive()
# Error band and error bars using the extent selected in the form.
band = alt.Chart(chart_data).mark_errorband(extent=error_bars).encode(
    x=started_at,
    y=alt.Y(benchmark_metric, scale=alt.Scale(zero=False)),
    color=benchmark_color,
).transform_filter(predicate)
bars = alt.Chart(chart_data).mark_errorbar(extent=error_bars).encode(
    x=started_at,
    y=benchmark_metric,
    color=benchmark_color,
).transform_filter(predicate)
line + band + bars

<IPython.core.display.Javascript at 0x7f1c0ed0ee48>

<IPython.core.display.Javascript at 0x7f1c0ee2d828>

<IPython.core.display.Javascript at 0x7f1c0ee2dc50>

<IPython.core.display.Javascript at 0x7f1c0ee2d828>

<IPython.core.display.Javascript at 0x7f1c0ed9ab00>