Skip to content

Commit

Permalink
Add support for non-categorical HeatMap (#4180)
Browse files Browse the repository at this point in the history
  • Loading branch information
philippjfr committed Jan 13, 2020
1 parent 7d9f5cc commit cfb9704
Show file tree
Hide file tree
Showing 15 changed files with 349 additions and 303 deletions.
16 changes: 10 additions & 6 deletions examples/reference/elements/bokeh/HeatMap.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,9 @@
"metadata": {},
"outputs": [],
"source": [
"data = [(chr(65+i), chr(97+j), i*j) for i in range(5) for j in range(5) if i!=j]\n",
"hv.HeatMap(data).sort()"
"data = [(i, chr(97+j), i*j) for i in range(5) for j in range(5) if i!=j]\n",
"hm = hv.HeatMap(data).sort()\n",
"hm.opts(xticks=None)"
]
},
{
Expand All @@ -70,7 +71,9 @@
"source": [
"As the above example shows before aggregating the second value for the (0, 0) is ignored unless we aggregate the data first.\n",
"\n",
"To reveal the values of a ``HeatMap`` we can either enable a ``colorbar`` or add a hover tool. The hover tools even allows displaying any number of additional value dimensions, providing additional information a static plot could not capture:"
"To reveal the values of a ``HeatMap`` we can either enable a ``colorbar`` or add a hover tool. The hover tools even allows displaying any number of additional value dimensions, providing additional information a static plot could not capture. \n",
"\n",
"Note that a HeatMap allows mixtures of categorical, numeric and datetime values along the x- and y-axes:"
]
},
{
Expand All @@ -79,9 +82,10 @@
"metadata": {},
"outputs": [],
"source": [
"heatmap = hv.HeatMap((np.random.randint(0, 10, 100), np.random.randint(0, 10, 100),\n",
" np.random.randn(100), np.random.randn(100)), vdims=['z', 'z2']).redim.range(z=(-2, 2))\n",
"heatmap.opts(opts.HeatMap(tools=['hover'], colorbar=True, width=325, toolbar='above'))"
"heatmap = hv.HeatMap((np.random.randint(0, 10, 100), np.random.choice(['A', 'B', 'C', 'D', 'E'], 100), \n",
" np.random.randn(100), np.random.randn(100)), vdims=['z', 'z2']).sort().aggregate(function=np.mean)\n",
"\n",
"heatmap.opts(opts.HeatMap(tools=['hover'], colorbar=True, width=325, toolbar='above', clim=(-2, 2)))"
]
},
{
Expand Down
11 changes: 7 additions & 4 deletions examples/reference/elements/matplotlib/HeatMap.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
"source": [
"import numpy as np\n",
"import holoviews as hv\n",
"\n",
"hv.extension('matplotlib')"
]
},
Expand Down Expand Up @@ -69,7 +70,9 @@
"source": [
"As the above example shows before aggregating the second value for the (0, 0) is ignored unless we aggregate the data first.\n",
"\n",
"To reveal the values of a ``HeatMap`` we can enable a ``colorbar`` and if you wish to have interactive hover information, you can use the hover tool in the [Bokeh backend](../bokeh/HeatMap.ipynb):"
"To reveal the values of a ``HeatMap`` we can enable a ``colorbar`` and if you wish to have interactive hover information, you can use the hover tool in the [Bokeh backend](../bokeh/HeatMap.ipynb).\n",
"\n",
"Note that a HeatMap allows mixtures of categorical, numeric and datetime values along the x- and y-axes:"
]
},
{
Expand All @@ -78,10 +81,10 @@
"metadata": {},
"outputs": [],
"source": [
"heatmap = hv.HeatMap((np.random.randint(0, 10, 100), np.random.randint(0, 10, 100),\n",
" np.random.randn(100), np.random.randn(100)), vdims=['z', 'z2']).redim.range(z=(-2, 2))\n",
"heatmap = hv.HeatMap((np.random.randint(0, 10, 100), np.random.choice(['A', 'B', 'C', 'D', 'E'], 100), \n",
" np.random.randn(100), np.random.randn(100)), vdims=['z', 'z2']).sort().aggregate(function=np.mean)\n",
"\n",
"heatmap.opts(colorbar=True, fig_size=250)"
"heatmap.opts(colorbar=True, fig_size=250, clim=(-2, 2))"
]
},
{
Expand Down
22 changes: 20 additions & 2 deletions examples/reference/elements/plotly/HeatMap.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
"source": [
"import numpy as np\n",
"import holoviews as hv\n",
"from holoviews import opts\n",
"\n",
"hv.extension('plotly')"
]
},
Expand All @@ -39,7 +41,7 @@
"metadata": {},
"outputs": [],
"source": [
"data = [(chr(65+i), chr(97+j), i*j) for i in range(5) for j in range(5) if i!=j]\n",
"data = [(i, chr(97+j), i*j) for i in range(5) for j in range(5) if i!=j]\n",
"hv.HeatMap(data).opts(cmap='RdBu_r')"
]
},
Expand All @@ -65,7 +67,23 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"As the above example shows before aggregating the second value for the (0, 0) is ignored unless we aggregate the data first."
"As the above example shows before aggregating the second value for the (0, 0) is ignored unless we aggregate the data first.\n",
"\n",
"To reveal the values of a ``HeatMap`` we can either enable a ``colorbar`` or use the hover tool.\n",
"\n",
"Note that a HeatMap allows mixtures of categorical, numeric and datetime values along the x- and y-axes:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"heatmap = hv.HeatMap((np.random.randint(0, 10, 100), np.random.choice(['A', 'B', 'C', 'D', 'E'], 100), \n",
" np.random.randn(100), np.random.randn(100)), vdims=['z', 'z2']).sort().aggregate(function=np.mean)\n",
"\n",
"heatmap.opts(opts.HeatMap(colorbar=True, clim=(-2, 2)))"
]
},
{
Expand Down
5 changes: 4 additions & 1 deletion holoviews/core/data/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,10 @@ def _infer_interval_breaks(cls, coord, axis=0):
coord = coord.astype('datetime64')
if len(coord) == 0:
return np.array([], dtype=coord.dtype)
deltas = 0.5 * np.diff(coord, axis=axis)
if len(coord) > 1:
deltas = 0.5 * np.diff(coord, axis=axis)
else:
deltas = np.array([0.5])
first = np.take(coord, [0], axis=axis) - np.take(deltas, [0], axis=axis)
last = np.take(coord, [-1], axis=axis) + np.take(deltas, [-1], axis=axis)
trim_last = tuple(slice(None, -1) if n == axis else slice(None)
Expand Down
36 changes: 36 additions & 0 deletions holoviews/element/raster.py
Original file line number Diff line number Diff line change
Expand Up @@ -937,3 +937,39 @@ class HeatMap(Dataset, Element2D):
def __init__(self, data, kdims=None, vdims=None, **params):
super(HeatMap, self).__init__(data, kdims=kdims, vdims=vdims, **params)
self.gridded = categorical_aggregate2d(self)

@property
def _unique(self):
"""
Reports if the Dataset is unique.
"""
return self.gridded.label != 'non-unique'

def range(self, dim, data_range=True, dimension_range=True):
"""Return the lower and upper bounds of values along dimension.
Args:
dimension: The dimension to compute the range on.
data_range (bool): Compute range from data values
dimension_range (bool): Include Dimension ranges
Whether to include Dimension range and soft_range
in range calculation
Returns:
Tuple containing the lower and upper bound
"""
dim = self.get_dimension(dim)
if dim in self.kdims:
try:
self.gridded._binned = True
if self.gridded is self:
return super(HeatMap, self).range(dim, data_range, dimension_range)
else:
drange = self.gridded.range(dim, data_range, dimension_range)
except:
drange = None
finally:
self.gridded._binned = False
if drange is not None:
return drange
return super(HeatMap, self).range(dim, data_range, dimension_range)
32 changes: 25 additions & 7 deletions holoviews/element/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,10 @@ def _get_coords(self, obj):
xdim, ydim = obj.dimensions(label=True)[:2]
xcoords = obj.dimension_values(xdim, False)
ycoords = obj.dimension_values(ydim, False)
if xcoords.dtype.kind not in 'SUO':
xcoords = np.sort(xcoords)
if ycoords.dtype.kind not in 'SUO':
return xcoords, np.sort(ycoords)

# Determine global orderings of y-values using topological sort
grouped = obj.groupby(xdim, container_type=OrderedDict,
Expand All @@ -155,14 +159,14 @@ def _get_coords(self, obj):
elif not is_cyclic(orderings):
coords = list(itertools.chain(*sort_topologically(orderings)))
ycoords = coords if len(coords) == len(ycoords) else np.sort(ycoords)
return xcoords, ycoords
return np.asarray(xcoords), np.asarray(ycoords)


def _aggregate_dataset(self, obj, xcoords, ycoords):
def _aggregate_dataset(self, obj):
"""
Generates a gridded Dataset from a column-based dataset and
lists of xcoords and ycoords
"""
xcoords, ycoords = self._get_coords(obj)
dim_labels = obj.dimensions(label=True)
vdims = obj.dimensions()[2:]
xdim, ydim = dim_labels[:2]
Expand Down Expand Up @@ -195,6 +199,18 @@ def _aggregate_dataset(self, obj, xcoords, ycoords):
return agg.clone(grid_data, kdims=[xdim, ydim], vdims=vdims,
datatype=self.p.datatype)

def _aggregate_dataset_pandas(self, obj):
index_cols = [d.name for d in obj.kdims]
df = obj.data.set_index(index_cols).groupby(index_cols, sort=False).first()
label = 'unique' if len(df) == len(obj) else 'non-unique'
levels = self._get_coords(obj)
index = pd.MultiIndex.from_product(levels, names=df.index.names)
reindexed = df.reindex(index)
data = tuple(levels)
shape = tuple(d.shape[0] for d in data)
for vdim in obj.vdims:
data += (reindexed[vdim.name].values.reshape(shape).T,)
return obj.clone(data, datatype=self.p.datatype, label=label)

def _process(self, obj, key=None):
"""
Expand All @@ -210,10 +226,12 @@ def _process(self, obj, key=None):
raise ValueError("Must have at two dimensions to aggregate over"
"and one value dimension to aggregate on.")

dtype = 'dataframe' if pd else 'dictionary'
obj = Dataset(obj, datatype=[dtype])
xcoords, ycoords = self._get_coords(obj)
return self._aggregate_dataset(obj, xcoords, ycoords)
if pd:
obj = Dataset(obj, datatype=['dataframe'])
return self._aggregate_dataset_pandas(obj)
else:
obj = Dataset(obj, datatype=['dictionary'])
return self._aggregate_dataset(obj)


def circular_layout(nodes):
Expand Down
104 changes: 50 additions & 54 deletions holoviews/plotting/bokeh/heatmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@
import param
import numpy as np

from bokeh.models import Span
from bokeh.models.glyphs import AnnularWedge

from ...core.data import GridInterface
from ...core.util import is_nan, dimension_sanitizer
from ...core.spaces import HoloMap
from .element import ColorbarPlot, CompositeElementPlot
from .styles import line_properties, fill_properties, mpl_to_bokeh, text_properties
from .styles import line_properties, fill_properties, text_properties


class HeatMapPlot(ColorbarPlot):
Expand Down Expand Up @@ -54,8 +54,6 @@ class HeatMapPlot(ColorbarPlot):
['ymarks_' + p for p in line_properties] +
['cmap', 'color', 'dilate', 'visible'] + line_properties + fill_properties)

_categorical = True

@classmethod
def is_radial(cls, heatmap):
heatmap = heatmap.last if isinstance(heatmap, HoloMap) else heatmap
Expand All @@ -70,25 +68,58 @@ def get_data(self, element, ranges, style):
x, y, z = [dimension_sanitizer(d) for d in element.dimensions(label=True)[:3]]
if self.invert_axes: x, y = y, x
cmapper = self._get_colormapper(element.vdims[0], element, ranges, style)
if 'line_alpha' not in style: style['line_alpha'] = 0
if 'line_alpha' not in style and 'line_width' not in style:
style['line_alpha'] = 0
elif 'line_color' not in style:
style['line_color'] = 'white'

if not element._unique:
self.warning('HeatMap element index is not unique, ensure you '
'aggregate the data before displaying it, e.g. '
'using heatmap.aggregate(function=np.mean). '
'Duplicate index values have been dropped.')

if self.static_source:
return {}, {'x': x, 'y': y, 'fill_color': {'field': 'zvalues', 'transform': cmapper}}, style

aggregate = element.gridded
xdim, ydim = aggregate.dimensions()[:2]
xvals, yvals = (aggregate.dimension_values(x),
aggregate.dimension_values(y))

xtype = aggregate.interface.dtype(aggregate, xdim)
widths = None
if xtype.kind in 'SUO':
xvals = aggregate.dimension_values(xdim)
width = 1
else:
xvals = aggregate.dimension_values(xdim, flat=False)
edges = GridInterface._infer_interval_breaks(xvals, axis=1)
widths = np.diff(edges, axis=1).T.flatten()
xvals = xvals.T.flatten()
width = 'width'

ytype = aggregate.interface.dtype(aggregate, ydim)
heights = None
if ytype.kind in 'SUO':
yvals = aggregate.dimension_values(ydim)
height = 1
else:
yvals = aggregate.dimension_values(ydim, flat=False)
edges = GridInterface._infer_interval_breaks(yvals, axis=0)
heights = np.diff(edges, axis=0).T.flatten()
yvals = yvals.T.flatten()
height = 'height'

zvals = aggregate.dimension_values(2, flat=False)
zvals = zvals.T.flatten()

if self.invert_axes:
xdim, ydim = ydim, xdim
zvals = zvals.T.flatten()
else:
zvals = zvals.T.flatten()
if xvals.dtype.kind not in 'SU':
xvals = [xdim.pprint_value(xv) for xv in xvals]
if yvals.dtype.kind not in 'SU':
yvals = [ydim.pprint_value(yv) for yv in yvals]
width, height = height, width

data = {x: xvals, y: yvals, 'zvalues': zvals}
if widths is not None:
data['width'] = widths
if heights is not None:
data['height'] = heights

if 'hover' in self.handles and not self.static_source:
for vdim in element.vdims:
Expand All @@ -100,48 +131,13 @@ def get_data(self, element, ranges, style):
style = {k: v for k, v in style.items() if not
any(g in k for g in RadialHeatMapPlot._style_groups.values())}
return (data, {'x': x, 'y': y, 'fill_color': {'field': 'zvalues', 'transform': cmapper},
'height': 1, 'width': 1}, style)
'height': height, 'width': width}, style)

def _draw_markers(self, plot, element, marks, axis='x'):
if marks is None:
if marks is None or self.radial:
return
style = self.style[self.cyclic_index]
mark_opts = {k[7:]: v for k, v in style.items() if axis+'mark' in k}
mark_opts = {'line_'+k if k in ('color', 'alpha') else k: v
for k, v in mpl_to_bokeh(mark_opts).items()}
categories = list(element.dimension_values(0 if axis == 'x' else 1,
expanded=False))

if callable(marks):
positions = [i for i, x in enumerate(categories) if marks(x)]
elif isinstance(marks, int):
nth_mark = np.ceil(len(categories) / marks).astype(int)
positions = np.arange(len(categories)+1)[::nth_mark]
elif isinstance(marks, tuple):
positions = [categories.index(m) for m in marks if m in categories]
else:
positions = [m for m in marks if isinstance(m, int) and m < len(categories)]
if axis == 'y':
positions = [len(categories)-p for p in positions]

prev_markers = self.handles.get(axis+'marks', [])
new_markers = []
for i, p in enumerate(positions):
if i < len(prev_markers):
span = prev_markers[i]
span.update(**dict(mark_opts, location=p))
else:
dimension = 'height' if axis == 'x' else 'width'
span = Span(level='annotation', dimension=dimension,
location=p, **mark_opts)
plot.renderers.append(span)
span.visible = True
new_markers.append(span)
for pm in prev_markers:
if pm not in new_markers:
pm.visible = False
new_markers.append(pm)
self.handles[axis+'marks'] = new_markers
self.warning('Only radial HeatMaps supports marks, to make the'
'HeatMap quads for distinguishable set a line_width')

def _init_glyphs(self, plot, element, ranges, source):
super(HeatMapPlot, self)._init_glyphs(plot, element, ranges, source)
Expand Down
Loading

0 comments on commit cfb9704

Please sign in to comment.