Skip to content

Commit

Permalink
Add heatmap and baseline to %dws_history magic; minor pyflakes fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
jfischer committed Apr 12, 2020
1 parent cf5032a commit 9b06243
Show file tree
Hide file tree
Showing 4 changed files with 117 additions and 7 deletions.
119 changes: 114 additions & 5 deletions dataworkspaces/kits/jupyter.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@
import re
from os.path import join, basename, dirname, abspath, expanduser, curdir, exists
from notebook.notebookapp import list_running_servers
from typing import Optional
from typing import Optional, List, Any, Dict, Tuple, Callable
assert Dict # keep pyflakes happy
import shlex
import argparse

Expand All @@ -31,7 +32,6 @@
from dataworkspaces.api import take_snapshot, get_snapshot_history,\
make_lineage_table, make_lineage_graph,\
get_results
from dataworkspaces.utils.file_utils import get_subpath_from_absolute
from dataworkspaces.errors import ConfigurationError


Expand Down Expand Up @@ -241,6 +241,19 @@ def exit(self, status=0, message=None):
assert status==0, "Expecting a status of 0"
raise DwsMagicArgParseExit()

# Colormaps for heatmaps
# These were generated using seaborn:
# def tobyte(c):
# return int(round(255*c))
# MINIMIZE_COLORMAP = ['rgb(%s,%s,%s)'%(tobyte(c[0]),tobyte(c[1]),tobyte(c[2]))
# for c in seaborn.diverging_palette(150, 10, s=50, l=50, n=7)]
# MAXIMIZE_COLORMAP = ['rgb(%s,%s,%s)'%(tobyte(c[0]),tobyte(c[1]),tobyte(c[2]))
# for c in seaborn.diverging_palette(10, 150, s=50, l=50, n=7)]
# The two maps are just the reverse of each other with maximize having greener colors toward
# the high indexes and redder colors toward the low indexes, and minimize being the opposite.
# By pre-generating the colormaps, we avoid a dependency on seaborn.
# Seven-step diverging palette: green at index 0, neutral grey in the middle,
# red at index 6 (used where smaller metric values are better).
MINIMIZE_COLORMAP = [
    'rgb(84,128,107)',
    'rgb(138,168,153)',
    'rgb(193,210,201)',
    'rgb(242,242,242)',
    'rgb(232,190,192)',
    'rgb(212,136,140)',
    'rgb(193,84,89)',
]
# The maximize palette is the same colors in the opposite order (red -> green).
MAXIMIZE_COLORMAP = MINIMIZE_COLORMAP[::-1]

@magics_class
class DwsMagics(Magics):
Expand Down Expand Up @@ -323,7 +336,7 @@ def dws_info(self, line):
parser = DwsMagicParseArgs("dws_info",
description="Print some information about this workspace")
try:
args = parser.parse_magic_line(line)
parser.parse_magic_line(line)
except DwsMagicArgParseExit:
return # user asked for help
if self.disabled:
Expand Down Expand Up @@ -370,6 +383,17 @@ def dws_history(self, line):
help="Maximum number of snapshots to show")
parser.add_argument('--tail', default=False, action='store_true',
help="Just show the last 10 entries in reverse order")
parser.add_argument('--baseline', default=None, type=str,
help="Snapshot tag or hash to use as a basis for metrics comparison")
parser.add_argument('--heatmap', default=False, action='store_true',
help="Show a heatmap for metrics columns")
parser.add_argument('--maximize-metrics', default=None, type=str,
help="Metrics where larger values are better (e.g. accuracy)")
parser.add_argument('--minimize-metrics', default=None, type=str,
help="Metrics where smaller values are better (e.g. loss)")
# TODO: future feature
# parser.add_argument('--round-metrics', type=int, default=None,
# help="If specified, round metrics to this many decimal places")
try:
args = parser.parse_magic_line(line)
except DwsMagicArgParseExit:
Expand All @@ -378,6 +402,11 @@ def dws_history(self, line):
display(Markdown("DWS magic commands are disabled. To enable, set `DWS_MAGIC_DISABLE` to `False` and restart kernel."))
return
import pandas as pd # TODO: support case where pandas wasn't installed
import numpy as np
if args.heatmap:
if args.baseline is not None:
print("Cannot specify both --baseline and --heatmap", file=sys.stderr)
return
if args.max_count and args.tail:
max_count = args.max_count
elif args.tail:
Expand All @@ -390,20 +419,100 @@ def dws_history(self, line):
entries = []
index = []
columns = ['timestamp', 'hash', 'tags', 'message']
baseline_snapshot = None # type: Optional[int]
# not every snapshot has the same metrics, so we build an inclusive list
metrics = [] # type: List[str]
for s in history:
d = {'timestamp':s.timestamp[0:19],
'hash':s.hashval[0:8],
'tags':', '.join([tag for tag in s.tags]),
'message':s.message}
'message':s.message if s.message is not None else ''}
if s.metrics is not None:
for (m, v) in s.metrics.items():
d[m] = v
if m not in columns:
columns.append(m)
metrics.append(m)
entries.append(d)
index.append(s.snapshot_number)
if (args.baseline is not None):
if args.baseline in s.tags:
baseline_snapshot = s.snapshot_number
elif s.hashval[0:min(len(args.baseline),8)]==args.baseline[0:8]:
baseline_snapshot = s.snapshot_number
if (args.baseline is not None) and (baseline_snapshot is None):
print("Did not find a tag or hash corresponding to baseline '%s'"
% args.baseline, file=sys.stderr)
return
history_df = pd.DataFrame(entries, index=index, columns=columns)
return history_df
maximize_metrics = set(['accuracy', 'precision', 'recall'])
if args.maximize_metrics:
maximize_metrics = maximize_metrics.union(set(args.maximize_metrics.split(',')))
minimize_metrics = set(['loss'])
if args.minimize_metrics:
minimize_metrics = minimize_metrics.union(set(args.minimize_metrics.split(',')))
def truncate(v, l=30):
    """Return ``repr(v)`` shortened to at most ``l`` characters.

    If the repr already fits within ``l`` characters it is returned
    unchanged. Otherwise it is cut and a three-character ellipsis is
    appended so that the result is exactly ``l`` characters long.

    :param v: Any value; its ``repr()`` is what gets displayed.
    :param l: Maximum length of the returned string (default 30).
    :returns: The (possibly shortened) repr string.
    """
    s = repr(v)
    if len(s) <= l:
        # Fits as-is: do not replace readable characters with an ellipsis.
        return s
    # Reserve three characters for the ellipsis; clamp so that a tiny ``l``
    # cannot produce a negative slice bound (which would slice from the end).
    return s[0:max(l - 3, 0)] + '...'
def cleanup_dict_or_string_metric(val):
    """Shorten dict and string metric values for display.

    Dicts and strings are passed through ``truncate``; every other value
    is returned unchanged.
    """
    return truncate(val) if isinstance(val, (dict, str)) else val
element_styling_fns = [] # type: List[Tuple[str, Callable[[Any], None]]]
if args.heatmap:
heatmap_maximize_cols = [] # type: List[str]
heatmap_minimize_cols = [] # type: List[str]
color_templ="border: 1px solid darkgrey; background-color: %s; color: %s"
def _heatmap_styles(col, colormap):
    """Map each value of a metric column to a CSS style string.

    Values are bucketed into quantile bins, and each bin is mapped onto an
    index of ``colormap``. Missing values get a light grey background.

    ``labels=False`` is used rather than a fixed list of labels because
    ``duplicates='drop'`` can leave fewer bins than requested (many repeated
    metric values), and pandas raises a ValueError when the number of labels
    does not match the number of bins actually produced.
    """
    missing_style = color_templ % ('lightgrey', 'black')
    if col.dropna().empty:
        # Nothing to bin: every cell is rendered as "missing".
        return col.apply(lambda _: missing_style)
    codes, edges = pd.qcut(col, len(colormap), labels=False, retbins=True,
                           duplicates='drop')
    nbins = len(edges) - 1
    top = len(colormap) - 1

    def style_for(code):
        if pd.isna(code):
            return missing_style
        # Spread the bins that survived de-duplication across the whole
        # palette; with the full number of bins this is the identity mapping.
        idx = int(round(code * top / (nbins - 1))) if nbins > 1 else top // 2
        # The darker colors at either end of the palette need white text.
        text_color = 'white' if idx < 2 or idx > 4 else 'black'
        return color_templ % (colormap[idx], text_color)

    return codes.apply(style_for)

def color_max_metric_col(col):
    """Column styler for metrics where larger values are better."""
    return _heatmap_styles(col, MAXIMIZE_COLORMAP)

def color_min_metric_col(col):
    """Column styler for metrics where smaller values are better."""
    return _heatmap_styles(col, MINIMIZE_COLORMAP)
class BaselineElementStyle:
    """Callable that styles one metric cell relative to a baseline value.

    Instances are called with a single cell value and return a CSS string:
    green when the value is better than the baseline, red when it is worse,
    bold black when it is within 0.5% of the baseline, and grey when the
    value itself is missing. "Better" means larger when ``maximize`` is
    true and smaller otherwise.

    If the baseline itself is missing (the baseline snapshot did not record
    this metric), no comparison is possible and an empty style is returned
    for non-missing values, instead of marking every cell as "at baseline".
    """
    def __init__(self, metric:str, baseline, maximize:bool):
        self.metric=metric
        self.baseline=baseline
        # NaN compares false against everything, so remember explicitly
        # whether there is a usable baseline to compare with.
        self.has_baseline = not pd.isna(baseline)
        # Tolerance band: values within 0.5% of the baseline count as equal.
        self.baseline_round = abs(baseline*0.005) if self.has_baseline else 0.0
        self.maximize=maximize

    def __call__(self, val):
        if pd.isna(val):
            return 'color: grey'
        if not self.has_baseline:
            # Nothing to compare against: leave the cell unstyled.
            return ''
        if val>(self.baseline+self.baseline_round):
            return 'color: green' if self.maximize else 'color: red'
        if val<(self.baseline-self.baseline_round):
            return 'color: red' if self.maximize else 'color: green'
        # within baseline rounding
        return 'color: black; font-weight: bold'
for metric in metrics:
if history_df[metric].dtype.kind in ('f', 'i'):
# float or int
if baseline_snapshot is not None:
baseline_val = history_df.loc[baseline_snapshot][metric]
if metric in maximize_metrics:
element_styling_fns.append((metric, BaselineElementStyle(metric, baseline_val, maximize=True)),)
elif metric in minimize_metrics:
element_styling_fns.append((metric, BaselineElementStyle(metric, baseline_val, maximize=False)),)
elif args.heatmap:
if metric in maximize_metrics:
heatmap_maximize_cols.append(metric)
elif metric in minimize_metrics:
heatmap_minimize_cols.append(metric)
elif history_df[metric].dtype==np.dtype('object'):
history_df[metric] = history_df[metric].apply(cleanup_dict_or_string_metric)
result = history_df
def get_style(df_or_style):
return df_or_style.style if isinstance(df_or_style, pd.DataFrame) else df_or_style
for (metric, styling_fn) in element_styling_fns:
result = get_style(result).applymap(styling_fn, subset=[metric])
if args.heatmap:
result = get_style(result).apply(color_max_metric_col, subset=heatmap_maximize_cols)
result = get_style(result).apply(color_min_metric_col, subset=heatmap_minimize_cols)
return result

@line_magic
def dws_lineage_table(self, line):
Expand Down
2 changes: 0 additions & 2 deletions dataworkspaces/kits/scikit_learn.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,10 @@
from sklearn import metrics
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.utils import Bunch
from sklearn.base import _pprint
import sys
import numpy as np
import os
from os.path import join, abspath, expanduser, exists, isabs
import pickle
from tempfile import NamedTemporaryFile


Expand Down
1 change: 1 addition & 0 deletions tests/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ pyflakes:
cd $(DATAWORKSPACES)/backends; pyflakes *.py
cd $(DATAWORKSPACES)/resources; pyflakes *.py
cd $(DATAWORKSPACES)/commands; pyflakes *.py
cd $(DATAWORKSPACES)/kits; pyflakes *.py

# shortcut for static checks
check: mypy pyflakes
Expand Down
2 changes: 2 additions & 0 deletions tests/dws-test-environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,5 @@ dependencies:
- pip:
- tensorflow>=2.0
- pytest
- mypy
- pyflakes

0 comments on commit 9b06243

Please sign in to comment.