Skip to content

Commit

Permalink
Merge pull request #346 from ioam/spikes_element
Browse files Browse the repository at this point in the history
Added Spikes Element
  • Loading branch information
jlstevens committed Dec 11, 2015
2 parents 54dfec8 + 7716fba commit bafc20a
Show file tree
Hide file tree
Showing 14 changed files with 360 additions and 47 deletions.
90 changes: 90 additions & 0 deletions doc/Tutorials/Elements.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
" <dt>[``Scatter``](#Scatter)</dt><dd>Discontinuous collection of points indexed over a single dimension.</dd>\n",
" <dt>[``Points``](#Points)</dt><dd>Discontinuous collection of points indexed over two dimensions.</dd>\n",
" <dt>[``VectorField``](#VectorField)</dt><dd>Cyclic variable (and optional auxiliary data) distributed over two-dimensional space.</dd>\n",
" <dt>[``Spikes``](#Spikes)</dt><dd>A collection of horizontal or vertical lines at various locations with fixed height (1D) or variable height (2D).</dd>\n",
" <dt>[``SideHistogram``](#SideHistogram)</dt><dd>Histogram binning data contained by some other ``Element``.</dd>\n",
" </dl>\n",
"\n",
Expand Down Expand Up @@ -461,6 +462,95 @@
"points + points[0.3:0.7, 0.3:0.7].hist()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### ``Spikes`` <a id='Spikes'></a>"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Spikes represent any number of horizontal or vertical line segments with fixed or variable heights. There are a number of uses for this type, first of all they may be used as a rugplot to give an overview of a one-dimensional distribution. They may also be useful in more domain specific cases, such as visualizing spike trains for neurophysiology or spectrograms in physics and chemistry applications.\n",
"\n",
"In the simplest case a Spikes object therefore represents a 1D distribution:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"%%opts Spikes (alpha=0.4)\n",
"xs = np.random.rand(50)\n",
"ys = np.random.rand(50)\n",
"hv.Points((xs, ys)) * hv.Spikes(xs)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"When supplying two dimensions to the Spikes object the second dimension will be mapped onto the line height. Optionally you may also supply a cmap and color_index to map color onto one of the dimensions. This way we can for example plot a mass spectrogram:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"%%opts Spikes (cmap='Reds')\n",
"hv.Spikes(np.random.rand(20, 2), kdims=['Mass'], vdims=['Intensity'])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Another possibility is to draw a number of spike trains as you would encounter in neuroscience. Here we generate 10 separate random spike trains and distribute them evenly across the space by setting their ``position``. By also declaring some yticks each spike traing can be labeled individually:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"%%opts Spikes NdOverlay [show_legend=False]\n",
"hv.NdOverlay({i: hv.Spikes(np.random.randint(0, 100, 10), kdims=['Time'])(plot=dict(position=0.1*i))\n",
" for i in range(10)})(plot=dict(yticks=[((i+1)*0.1-0.05, i) for i in range(10)]))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Finally we may use ``Spikes`` to visualize marginal distributions as adjoined plots:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"%%opts Spikes (alpha=0.05) [spike_length=0.5] AdjointLayout [border_size=0]\n",
"points = hv.Points(np.random.randn(500, 2))\n",
"points << hv.Spikes(points['y']) << hv.Spikes(points['x'])"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand Down
2 changes: 1 addition & 1 deletion doc/reference_data
Submodule reference_data updated 162 files
21 changes: 21 additions & 0 deletions holoviews/element/chart.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,3 +344,24 @@ def __init__(self, data, **params):
if isinstance(data, list) and all(isinstance(d, np.ndarray) for d in data):
data = np.column_stack([d.flat if d.ndim > 1 else d for d in data])
super(VectorField, self).__init__(data, **params)



class Spikes(Chart):
"""
Spikes is a 1D or 2D Element, which represents a series of
vertical or horizontal lines distributed along some dimension. If
an additional dimension is supplied it will be used to specify the
height of the lines. The Element may therefore be used to
represent 1D distributions, spectrograms or spike trains in
electrophysiology.
"""

group = param.String(default='Spikes', constant=True)

kdims = param.List(default=[Dimension('x')])

vdims = param.List(default=[])

_1d = True

5 changes: 5 additions & 0 deletions holoviews/element/comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ def register(cls):
cls.equality_type_funcs[Trisurface] = cls.compare_trisurface
cls.equality_type_funcs[Histogram] = cls.compare_histogram
cls.equality_type_funcs[Bars] = cls.compare_bars
cls.equality_type_funcs[Spikes] = cls.compare_spikes

# Tables
cls.equality_type_funcs[ItemTable] = cls.compare_itemtables
Expand Down Expand Up @@ -500,6 +501,10 @@ def compare_vectorfield(cls, el1, el2, msg='VectorField'):
def compare_bars(cls, el1, el2, msg='Bars'):
cls.compare_columns(el1, el2, msg)

@classmethod
def compare_spikes(cls, el1, el2, msg='Spikes'):
cls.compare_columns(el1, el2, msg)

#=========#
# Rasters #
#=========#
Expand Down
10 changes: 6 additions & 4 deletions holoviews/plotting/bokeh/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from ...element import (Curve, Points, Scatter, Image, Raster, Path,
RGB, Histogram, Spread, HeatMap, Contours,
Path, Box, Bounds, Ellipse, Polygons,
ErrorBars, Text, HLine, VLine, Spline,
ErrorBars, Text, HLine, VLine, Spline, Spikes,
Table, ItemTable, Surface, Scatter3D, Trisurface)
from ...core.options import Options, Cycle, OptionTree
from ...interface import DFrame
Expand All @@ -14,7 +14,7 @@
from .callbacks import Callbacks
from .element import OverlayPlot, BokehMPLWrapper, BokehMPLRawWrapper
from .chart import (PointPlot, CurvePlot, SpreadPlot, ErrorPlot, HistogramPlot,
AdjointHistogramPlot)
SideHistogramPlot, SpikesPlot, SideSpikesPlot)
from .path import PathPlot, PolygonPlot
from .plot import GridPlot, LayoutPlot, AdjointLayoutPlot
from .raster import RasterPlot, RGBPlot, HeatmapPlot
Expand All @@ -37,6 +37,7 @@
Scatter: PointPlot,
ErrorBars: ErrorPlot,
Spread: SpreadPlot,
Spikes: SpikesPlot,

# Rasters
Image: RasterPlot,
Expand Down Expand Up @@ -80,8 +81,8 @@
'bokeh')


AdjointLayoutPlot.registry[Histogram] = AdjointHistogramPlot

AdjointLayoutPlot.registry[Histogram] = SideHistogramPlot
AdjointLayoutPlot.registry[Spikes] = SideSpikesPlot

try:
from ..mpl.seaborn import TimeSeriesPlot, BivariatePlot, DistributionPlot
Expand Down Expand Up @@ -114,6 +115,7 @@
options.Spread = Options('style', fill_color=Cycle(), fill_alpha=0.6, line_color='black')
options.Histogram = Options('style', fill_color="#036564", line_color="#033649")
options.Points = Options('style', color=Cycle())
options.Spikes = Options('style', color='black')

# Paths
options.Contours = Options('style', color=Cycle())
Expand Down
92 changes: 88 additions & 4 deletions holoviews/plotting/bokeh/chart.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from ...core import Dimension
from ...core.util import max_range
from ...element import Chart, Raster, Points, Polygons
from ...element import Chart, Raster, Points, Polygons, Spikes
from ..util import compute_sizes, get_sideplot_ranges
from .element import ElementPlot, line_properties, fill_properties
from .path import PathPlot, PolygonPlot
Expand Down Expand Up @@ -143,25 +143,31 @@ def get_data(self, element, ranges=None):
return (data, mapping)


class AdjointHistogramPlot(HistogramPlot):
class SideHistogramPlot(HistogramPlot):

style_opts = HistogramPlot.style_opts + ['cmap']

width = param.Integer(default=125)
height = param.Integer(default=125, doc="The height of the plot")

width = param.Integer(default=125, doc="The width of the plot")

show_title = param.Boolean(default=False, doc="""
Whether to display the plot title.""")

def get_data(self, element, ranges=None):
if self.invert_axes:
mapping = dict(top='left', bottom='right', left=0, right='top')
else:
mapping = dict(top='top', bottom=0, left='left', right='right')

data = dict(top=element.values, left=element.edges[:-1],
right=element.edges[1:])

dim = element.get_dimension(0).name
main = self.adjoined.main
range_item, main_range, dim = get_sideplot_ranges(self, element, main, ranges)
vals = element.dimension_values(dim)
if isinstance(range_item, (Raster, Points, Polygons)):
if isinstance(range_item, (Raster, Points, Polygons, Spikes)):
style = self.lookup_options(range_item, 'style')[self.cyclic_index]
else:
style = {}
Expand Down Expand Up @@ -203,3 +209,81 @@ def get_data(self, element, ranges=None):
err_xs.append((x, x))
err_ys.append((y - neg, y + pos))
return (dict(xs=err_xs, ys=err_ys), self._mapping)


class SpikesPlot(PathPlot):

color_index = param.Integer(default=1, doc="""
Index of the dimension from which the color will the drawn""")

spike_length = param.Number(default=0.5, doc="""
The length of each spike if Spikes object is one dimensional.""")

position = param.Number(default=0., doc="""
The position of the lower end of each spike.""")

style_opts = (['color', 'cmap', 'palette'] + line_properties)

def get_extents(self, element, ranges):
l, b, r, t = super(SpikesPlot, self).get_extents(element, ranges)
if len(element.dimensions()) == 1:
b, t = self.position, self.position+self.spike_length
return l, b, r, t


def get_data(self, element, ranges=None):
style = self.style[self.cyclic_index]
dims = element.dimensions(label=True)

pos = self.position
if len(dims) > 1:
xs, ys = zip(*(((x, x), (pos, pos+y))
for x, y in element.array()))
mapping = dict(xs=dims[0], ys=dims[1])
keys = (dims[0], dims[1])
else:
height = self.spike_length
xs, ys = zip(*(((x[0], x[0]), (pos, pos+height))
for x in element.array()))
mapping = dict(xs=dims[0], ys='heights')
keys = (dims[0], 'heights')

if self.invert_axes: keys = keys[::-1]
data = dict(zip(keys, (xs, ys)))

cmap = style.get('palette', style.get('cmap', None))
if self.color_index < len(dims) and cmap:
cdim = dims[self.color_index]
map_key = 'color_' + cdim
mapping['color'] = map_key
cmap = get_cmap(cmap)
colors = element.dimension_values(cdim)
crange = ranges.get(cdim, None)
data[map_key] = map_colors(colors, crange, cmap)

return data, mapping



class SideSpikesPlot(SpikesPlot):
"""
SpikesPlot with useful defaults for plotting adjoined rug plot.
"""

xaxis = param.ObjectSelector(default='top-bare',
objects=['top', 'bottom', 'bare', 'top-bare',
'bottom-bare', None], doc="""
Whether and where to display the xaxis, bare options allow suppressing
all axis labels including ticks and xlabel. Valid options are 'top',
'bottom', 'bare', 'top-bare' and 'bottom-bare'.""")

yaxis = param.ObjectSelector(default='right-bare',
objects=['left', 'right', 'bare', 'left-bare',
'right-bare', None], doc="""
Whether and where to display the yaxis, bare options allow suppressing
all axis labels including ticks and ylabel. Valid options are 'left',
'right', 'bare' 'left-bare' and 'right-bare'.""")

height = param.Integer(default=80, doc="Height of plot")

width = param.Integer(default=80, doc="Width of plot")
7 changes: 4 additions & 3 deletions holoviews/plotting/bokeh/element.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,9 @@ class ElementPlot(BokehPlot, GenericElementPlot):
{'ticks': '20pt', 'title': '15pt', 'ylabel': '5px', 'xlabel': '5px'}""")

invert_axes = param.Boolean(default=False, doc="""
Whether to invert the x- and y-axis""")

invert_xaxis = param.Boolean(default=False, doc="""
Whether to invert the plot x-axis.""")

Expand Down Expand Up @@ -137,9 +140,7 @@ class ElementPlot(BokehPlot, GenericElementPlot):
# instance attribute.
_update_handles = ['source', 'glyph']

def __init__(self, element, plot=None, invert_axes=False,
show_labels=['x', 'y'], **params):
self.invert_axes = invert_axes
def __init__(self, element, plot=None, show_labels=['x', 'y'], **params):
self.show_labels = show_labels
self.current_ranges = None
super(ElementPlot, self).__init__(element, **params)
Expand Down
14 changes: 8 additions & 6 deletions holoviews/plotting/bokeh/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,13 +353,15 @@ def _create_subplots(self, layout, positions, layout_dimensions, ranges, num=0):
if pos != 'main':
plot_type = AdjointLayoutPlot.registry.get(vtype, plot_type)
if pos == 'right':
side_opts = dict(height=main_plot.height, yaxis='right',
invert_axes=True, width=120, show_labels=['y'],
xticks=2, show_title=False)
yaxis = 'right-bare' if 'bare' in plot_type.yaxis else 'right'
side_opts = dict(height=main_plot.height, yaxis=yaxis,
width=plot_type.width, invert_axes=True,
show_labels=['y'], xticks=1, xaxis=main_plot.xaxis)
else:
side_opts = dict(width=main_plot.width, xaxis='top',
height=120, show_labels=['x'], yticks=2,
show_title=False)
xaxis = 'top-bare' if 'bare' in plot_type.xaxis else 'top'
side_opts = dict(width=main_plot.width, xaxis=xaxis,
height=plot_type.height, show_labels=['x'],
yticks=1, yaxis=main_plot.yaxis)

# Override the plotopts as required
# Customize plotopts depending on position.
Expand Down
5 changes: 4 additions & 1 deletion holoviews/plotting/mpl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ def grid_selector(grid):
VectorField: VectorFieldPlot,
ErrorBars: ErrorPlot,
Spread: SpreadPlot,
Spikes: SpikesPlot,

# General plots
GridSpace: GridPlot,
Expand Down Expand Up @@ -157,7 +158,8 @@ def grid_selector(grid):


MPLPlot.sideplots.update({Histogram: SideHistogramPlot,
GridSpace: GridPlot})
GridSpace: GridPlot,
Spikes: SideSpikesPlot})

options = Store.options(backend='matplotlib')

Expand All @@ -175,6 +177,7 @@ def grid_selector(grid):
options.Scatter3D = Options('style', facecolors=Cycle(), marker='o')
options.Scatter3D = Options('plot', fig_size=150)
options.Surface = Options('plot', fig_size=150)
options.Spikes = Options('style', color='black')
# Rasters
options.Image = Options('style', cmap='hot', interpolation='nearest')
options.Raster = Options('style', cmap='hot', interpolation='nearest')
Expand Down
Loading

0 comments on commit bafc20a

Please sign in to comment.