/
plotter_bokeh.py
700 lines (565 loc) · 22.9 KB
/
plotter_bokeh.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
"""
"""
# TODO: marker alpha
import functools
import itertools
import numpy as np
from ..manage import auto_xyz_ds
from .core import (
Plotter,
AbstractLinePlot,
AbstractScatter,
AbstractHeatMap,
calc_row_col_datasets,
PLOTTER_DEFAULTS,
intercept_call_arg,
prettify,
)
@functools.lru_cache(maxsize=1)
def _init_bokeh_nb():
    """Switch bokeh to notebook output, using inline resources.

    Wrapped in a size-one cache so that only the first call has any effect
    and later calls return immediately.
    """
    import bokeh.plotting
    import bokeh.resources
    bokeh.plotting.output_notebook(resources=bokeh.resources.INLINE)
def bshow(figs, nb=True, interactive=False, **kwargs):
    """Display bokeh figure(s) or layouts.

    Parameters
    ----------
    figs : bokeh figure or layout
        What to show.
    nb : bool, optional
        Whether to show in a notebook, first initialising notebook output
        (only once per session). Default True.
    interactive : bool, optional
        In notebook mode, passed to bokeh's ``show`` as ``notebook_handle``
        so the output can later be refreshed with ``push_notebook``.
        Ignored when ``nb`` is False.
    kwargs
        Accepted but currently ignored.
    """
    from bokeh.plotting import show

    if not nb:
        show(figs)
        return

    _init_bokeh_nb()
    show(figs, notebook_handle=interactive)
# --------------------------------------------------------------------------- #
# Main lineplot interface for bokeh #
# --------------------------------------------------------------------------- #
class PlotterBokeh(Plotter):
    """Bokeh backend for the generic ``Plotter`` from ``.core``.

    Creates the bokeh figure and handles the axes, ranges, spans, gridlines,
    ticks, the ``ColumnDataSource`` objects holding the data, the legend,
    the colorbar and the hover tool. Subclasses draw the actual glyphs and
    define ``__call__``.
    """

    def __init__(self, ds, x, y, z=None, **kwargs):
        """Set up a bokeh plotter.

        Parameters
        ----------
        ds : xarray.Dataset
            Dataset to plot from.
        x, y, z
            What to plot, passed on to ``Plotter``.
        return_fig : bool, optional
            If True, :meth:`show` returns the bokeh figure rather than
            displaying it. Default False.
        interactive : bool, optional
            Passed to :func:`bshow` when showing; in notebook mode this
            requests a notebook handle so :meth:`update` can push new data.
            Default False.
        kwargs
            Other plot options, see ``xyzpy.plot.core.PLOTTER_DEFAULTS``.
        """
        # bokeh custom options / defaults
        kwargs.setdefault('return_fig', False)
        self._interactive = kwargs.pop('interactive', False)
        super().__init__(ds, x, y, z, **kwargs, backend='BOKEH')

    @staticmethod
    def _is_given(values):
        """Return whether ``values`` is set and non-empty.

        Plain truthiness raises ``ValueError`` for numpy arrays with more
        than one element, so check ``None`` and the length explicitly,
        falling back to truthiness for objects without a length (scalars).
        """
        if values is None:
            return False
        try:
            return len(values) > 0
        except TypeError:
            return bool(values)

    def prepare_axes(self):
        """Make the bokeh plot figure and set options.

        If ``add_to_axes`` was given, that existing figure is drawn onto
        instead and none of the figure options below are applied.
        """
        from bokeh.plotting import figure

        if self.add_to_axes is not None:
            self._plot = self.add_to_axes
            return

        # convert figsize to roughly matplotlib dimensions (80px per unit),
        # leaving extra room for the legend, axis titles and tick labels
        width = int(self.figsize[0] * 80 +
                    (100 if self._use_legend else 0) +
                    (20 if self._ytitle else 0) +
                    (20 if not self.yticklabels_hide else 0))
        height = int(self.figsize[1] * 80 +
                     (20 if self.title else 0) +
                     (20 if self._xtitle else 0) +
                     (20 if not self.xticklabels_hide else 0))

        # Currently axes scale type must be set at figure creation?
        # NOTE(review): ``toolbar_sticky`` may not be accepted by newer bokeh
        # releases -- verify against the installed version.
        self._plot = figure(
            width=width,
            height=height,
            x_axis_type=('log' if self.xlog else 'linear'),
            y_axis_type=('log' if self.ylog else 'linear'),
            y_axis_location=('right' if self.ytitle_right else 'left'),
            title=self.title,
            toolbar_location="above",
            toolbar_sticky=False,
            active_scroll="wheel_zoom",
        )

    def set_axes_labels(self):
        """Set the labels on the axes, where titles are given.
        """
        if self._xtitle:
            self._plot.xaxis.axis_label = self._xtitle
        if self._ytitle:
            self._plot.yaxis.axis_label = self._ytitle

    def set_axes_range(self):
        """Set the plot ranges of the axes, and the panning limits.

        Explicit ``xlims`` / ``ylims`` fix the start and end of each range,
        otherwise a default ``DataRange1d`` is used. Panning limits
        (``bounds``) are currently disabled.
        """
        from bokeh.models import DataRange1d

        self.calc_data_range()

        # NOTE: panning bounds (e.g. one data-range either side of the data)
        # were tried previously; both are currently left unbounded.
        xbounds = None
        self._plot.x_range = (DataRange1d(start=self._xlims[0],
                                          end=self._xlims[1],
                                          bounds=xbounds) if self._xlims else
                              DataRange1d(bounds=xbounds))

        ybounds = None
        self._plot.y_range = (DataRange1d(start=self._ylims[0],
                                          end=self._ylims[1],
                                          bounds=ybounds) if self._ylims else
                              DataRange1d(bounds=ybounds))

    def set_spans(self):
        """Add dashed grey horizontal / vertical line spans at each position
        in ``hlines`` / ``vlines``.
        """
        from bokeh.models import Span

        span_opts = {
            'level': 'glyph',
            'line_dash': 'dashed',
            'line_color': (127, 127, 127),
            'line_width': self.span_width,
        }
        # ``_is_given`` rather than truthiness so numpy arrays work too
        if self._is_given(self.hlines):
            for hl in self.hlines:
                self._plot.add_layout(Span(
                    location=hl, dimension='width', **span_opts))
        if self._is_given(self.vlines):
            for vl in self.vlines:
                self._plot.add_layout(Span(
                    location=vl, dimension='height', **span_opts))

    def set_gridlines(self):
        """Hide the gridlines, or else set their dash style.
        """
        if not self.gridlines:
            self._plot.xgrid.visible = False
            self._plot.ygrid.visible = False
        else:
            self._plot.xgrid.grid_line_dash = self.gridline_style
            self._plot.ygrid.grid_line_dash = self.gridline_style

    def set_tick_marks(self):
        """Set custom locations for the tick marks, and optionally hide the
        tick labels.
        """
        from bokeh.models import FixedTicker

        # ``_is_given`` rather than truthiness so numpy arrays work too
        if self._is_given(self.xticks):
            self._plot.xaxis[0].ticker = FixedTicker(ticks=self.xticks)
        if self._is_given(self.yticks):
            self._plot.yaxis[0].ticker = FixedTicker(ticks=self.yticks)

        # hide labels by giving them zero font size
        if self.xticklabels_hide:
            self._plot.xaxis.major_label_text_font_size = '0pt'
        if self.yticklabels_hide:
            self._plot.yaxis.major_label_text_font_size = '0pt'

    def set_sources_heatmap(self):
        """Set (or update) the single ``ColumnDataSource`` of a heatmap.

        The image is anchored at the data minimum, with width and height
        spanning the data range.
        """
        from bokeh.models import ColumnDataSource

        # initialize empty source
        if not hasattr(self, '_source'):
            self._source = ColumnDataSource(data=dict())

        # remove mask from data -> not necessary soon? / convert to nan?
        var = np.ma.getdata(self._heatmap_var)
        self._source.add([var], 'image')
        self._source.add([self._data_xmin], 'x')
        self._source.add([self._data_ymin], 'y')
        self._source.add([self._data_xmax - self._data_xmin], 'dw')
        self._source.add([self._data_ymax - self._data_ymin], 'dh')

    def set_sources(self):
        """Set the source dictionaries to be used by the plotter functions.

        This is separate to allow interactive updates of the data only: the
        sources are created once and their columns overwritten afterwards.
        Columns are 'x', 'y', 'z_coo' and, if present in the data, 'c' and
        the error-bar segment columns.
        """
        from bokeh.models import ColumnDataSource

        # check if heatmap
        if hasattr(self, '_heatmap_var'):
            return self.set_sources_heatmap()

        # 'copy' the zlabels iterator into src_zlbs
        self._zlbls, src_zlbs = itertools.tee(self._zlbls)

        # Initialise with empty dicts
        if not hasattr(self, "_sources"):
            self._sources = [ColumnDataSource(dict())
                             for _ in range(len(self._z_vals))]

        # range through all data and update the sources
        for i, (zlabel, data) in enumerate(zip(src_zlbs, self._gen_xy())):
            src = self._sources[i]
            src.add(data['x'], 'x')
            src.add(data['y'], 'y')
            src.add([zlabel] * len(data['x']), 'z_coo')

            # check for color for scatter plot
            if 'c' in data:
                src.add(data['c'], 'c')

            # y error bars: one vertical segment (y - ye, y + ye) per point
            if 'ye' in data:
                y_err_p = data['y'] + data['ye']
                y_err_m = data['y'] - data['ye']
                src.add(list(zip(data['x'], data['x'])), 'y_err_xs')
                src.add(list(zip(y_err_p, y_err_m)), 'y_err_ys')

            # x error bars: one horizontal segment (x - xe, x + xe) per point
            if 'xe' in data:
                x_err_p = data['x'] + data['xe']
                x_err_m = data['x'] - data['xe']
                src.add(list(zip(data['y'], data['y'])), 'x_err_ys')
                src.add(list(zip(x_err_p, x_err_m)), 'x_err_xs')

    def plot_legend(self, legend_items=None):
        """Add a click-to-hide legend to the plot (only once).

        Parameters
        ----------
        legend_items : list, optional
            ``(label, renderers)`` pairs; defaults to ``self._lgnd_items``.
            May be given manually, e.g. from a multi-plot.
        """
        if self._use_legend:
            from bokeh.models import Legend

            loc = {'best': 'top_left'}.get(self.legend_loc, self.legend_loc)
            where = {None: 'right'}.get(self.legend_where, self.legend_where)

            # might be manually specified, e.g. from multiplot
            if legend_items is None:
                legend_items = self._lgnd_items

            lg = Legend(items=legend_items)
            lg.location = loc
            lg.click_policy = 'hide'
            self._plot.add_layout(lg, where)

            # Don't repeatedly redraw legend
            self._use_legend = False

    def set_mappable(self):
        """Create the bokeh color mapper ``self.mappable``.

        ``self.cmap`` (presumably a matplotlib colormap) is sampled at 256
        points to build the palette, mapped over ``_zmin`` to ``_zmax``.
        """
        from bokeh.models import LogColorMapper, LinearColorMapper
        # import the submodule explicitly: ``import matplotlib`` alone is not
        # guaranteed to expose ``matplotlib.colors``
        from matplotlib.colors import rgb2hex

        mappr_fn = LogColorMapper if self.colormap_log else LinearColorMapper
        bokehpalette = [rgb2hex(m) for m in self.cmap(range(256))]
        self.mappable = mappr_fn(palette=bokehpalette,
                                 low=self._zmin, high=self._zmax)

    def plot_colorbar(self):
        """Add a colorbar for ``self.mappable``, if one is wanted.
        """
        if self._use_colorbar:
            where = {None: 'right'}.get(self.legend_where, self.legend_where)

            from bokeh.models import ColorBar, LogTicker, BasicTicker

            ticker = LogTicker if self.colormap_log else BasicTicker
            color_bar = ColorBar(color_mapper=self.mappable, location=(0, 0),
                                 ticker=ticker(desired_num_ticks=6),
                                 title=self._ctitle)
            self._plot.add_layout(color_bar, where)

    def set_tools(self):
        """Add a hover tool showing (x, y) and, if present, the z label.
        """
        from bokeh.models import HoverTool

        tooltips = [
            ("({}, {})".format(
                self.x_coo,
                self.y_coo if isinstance(self.y_coo, str) else None),
             "(@x, @y)"),
        ]
        if self.z_coo:
            tooltips.append((self.z_coo, "@z_coo"))

        self._plot.add_tools(HoverTool(tooltips=tooltips))

    def update(self):
        """Recompute the data sources and push them to the notebook.
        """
        from bokeh.io import push_notebook

        self.set_sources()
        push_notebook()

    def show(self, **kwargs):
        """Show the produced figure.

        Returns the bokeh figure itself if ``return_fig`` was set, otherwise
        displays it with :func:`bshow` (passing on ``kwargs``) and returns
        this plotter.
        """
        if self.return_fig:
            return self._plot
        bshow(self._plot, **kwargs)
        return self

    def prepare_plot(self):
        """Create the figure, set up axes, decorations and data sources.
        """
        self.prepare_axes()
        self.set_axes_labels()
        self.set_axes_range()
        self.set_spans()
        self.set_gridlines()
        self.set_tick_marks()
        self.set_sources()
def bokeh_multi_plot(fn):
    """Decorate a plotting function to plot a grid of values.

    The returned wrapper accepts three extra keyword-only arguments:

    - ``row`` / ``col`` : dimension names to split the dataset over, giving
      one subplot per row / column value. If both are ``None``, ``fn`` is
      simply called with the remaining arguments.
    - ``link`` : if True, all subplots share the x and y ranges of the first,
      so zooming and panning are linked.

    When plotting a grid, ``figsize`` is taken as the size of the *whole*
    grid. ``xlims``, ``ylims``, ``vmin`` and ``vmax`` default to values from
    a plotter over the full dataset, so all subplots share them. Any legend
    or colorbar is drawn once, in a separate panel to the right. With
    ``return_fig=True`` the bokeh layout is returned; otherwise it is shown
    and ``None`` is returned.

    Parameters
    ----------
    fn : callable
        Plotting function called as ``fn(ds, *args, **kwargs)``. It must
        accept ``call=False`` (returning a callable plotter object instead
        of plotting) and ``return_fig=True``.
    """
    @functools.wraps(fn)
    def multi_plotter(ds, *args, row=None, col=None, link=False, **kwargs):
        if (row is None) and (col is None):
            return fn(ds, *args, **kwargs)

        # Set some global parameters from a plotter over the whole dataset
        p = fn(ds, *args, **kwargs, call=False)
        p.prepare_data_multi_grid()
        kwargs.setdefault('xlims', [p._data_xmin, p._data_xmax])
        kwargs.setdefault('ylims', [p._data_ymin, p._data_ymax])
        kwargs.setdefault('vmin', p.vmin)
        kwargs.setdefault('vmax', p.vmax)

        # split the dataset into its respective rows and columns
        ds_r_c, nrows, ncols = calc_row_col_datasets(ds, row=row, col=col)

        # intercept figsize as meaning *total* size for whole grid
        figsize = kwargs.pop('figsize', None)
        if figsize is None:
            av_n = (ncols + nrows) / 2
            sub_size = 2 * (4 / av_n)**0.5
            figsize = (sub_size, sub_size)
        else:
            figsize = (figsize[0] / ncols, figsize[1] / nrows)
        kwargs['figsize'] = figsize

        # intercept return_fig for the full grid
        return_fig = kwargs.pop('return_fig', False)

        subplots = {}
        for i, ds_r in enumerate(ds_r_c):
            for j, sub_ds in enumerate(ds_r):
                skws = {'legend': False, 'colorbar': False}

                # only the last row keeps its x tick labels and title
                if i != nrows - 1:
                    skws['xticklabels_hide'] = True
                    skws['xtitle'] = ''

                # only the first column keeps its y tick labels and title
                if j != 0:
                    skws['yticklabels_hide'] = True
                    skws['ytitle'] = ''

                # label each column
                if (i == 0) and (col is not None):
                    col_val = prettify(ds[col].values[j])
                    skws['title'] = "{} = {}".format(col, col_val)
                    fx = 'fontsize_xtitle'
                    skws['fontsize_title'] = kwargs.get(
                        fx, PLOTTER_DEFAULTS[fx])

                # label each row (on the right of the last column)
                if (j == ncols - 1) and (row is not None):
                    skws['ytitle_right'] = True
                    row_val = prettify(ds[row].values[i])
                    skws['ytitle'] = "{} = {}".format(row, row_val)

                subplots[i, j] = fn(sub_ds, *args, return_fig=True,
                                    call=False, **{**kwargs, **skws})

        from bokeh.layouts import gridplot

        plts = [[subplots[i, j]() for j in range(ncols)]
                for i in range(nrows)]

        # link zooming and panning between all plots
        if link:
            x_range, y_range = plts[0][0].x_range, plts[0][0].y_range
            for fig_row in plts:
                for fig in fig_row:
                    fig.x_range = x_range
                    fig.y_range = y_range

        # the main grid
        p._plot = gridplot(plts)

        if p._use_legend or p._use_colorbar:
            from bokeh.models import Legend, GlyphRenderer, Range1d, ColorBar
            # alias so as not to shadow this function's ``row`` argument
            from bokeh.layouts import row as bk_row

            # plot dummy using last sub_ds
            skws = {'title': "", 'legend_loc': 'center_left',
                    'legend_where': 'left'}
            lgren = fn(sub_ds, *args, return_fig=True, **{**kwargs, **skws})

            # remove all but legend, colorbar and glyph renderers
            lgren.renderers = [
                r for r in lgren.renderers
                if isinstance(r, (Legend, GlyphRenderer, ColorBar))
            ]
            lgren.toolbar_location = None
            lgren.outline_line_color = None

            # size it - this is pretty hacky at the moment
            lgren.width = 120
            lgren.height = int(80 * figsize[1] * nrows + 100)
            lgren.x_range = Range1d(0, 0)
            lgren.y_range = Range1d(0, 0)

            # append to the right of the gridplot
            p._plot = bk_row([p._plot, lgren])

        if return_fig:
            return p._plot
        bshow(p._plot)

    return multi_plotter
class ILinePlot(PlotterBokeh, AbstractLinePlot):
    """Interactive (bokeh) line plot, with optional markers and error bars.
    """

    def __init__(self, ds, x, y, z=None, y_err=None, x_err=None, **kwargs):
        super().__init__(ds, x, y, z=z, y_err=y_err, x_err=x_err, **kwargs)

    def plot_lines(self):
        """Draw each data source as line / markers / error bars and collect
        the resulting renderers, per z label, for the legend.
        """
        self._lgnd_items = []
        for source in self._sources:
            color = next(self._cols)
            label = next(self._zlbls)
            renderers = []

            if self.lines:
                renderers.append(self._plot.line(
                    'x', 'y',
                    source=source,
                    color=color,
                    line_dash=next(self._lines),
                    line_width=next(self._lws) * 1.5,
                ))

            if self.markers:
                draw_marker = getattr(self._plot, next(self._mrkrs))
                renderers.append(draw_marker(
                    'x', 'y',
                    source=source,
                    name=label,
                    color=color,
                    fill_alpha=0.5,
                    line_width=0.5,
                    size=self._marker_size,
                ))

            # error bars come from the segment columns built in set_sources
            error_specs = ((self.y_err, 'y_err_xs', 'y_err_ys'),
                           (self.x_err, 'x_err_xs', 'x_err_ys'))
            for wanted, xs_col, ys_col in error_specs:
                if wanted:
                    renderers.append(self._plot.multi_line(
                        xs=xs_col, ys=ys_col, source=source, color=color,
                        line_width=self.errorbar_linewidth))

            self._lgnd_items.append((label, renderers))

    def __call__(self):
        """Prepare data and figure, draw everything, then show it.
        """
        self.prepare_data_single()
        # Bokeh preparation
        self.prepare_plot()
        for step in (self.plot_lines, self.plot_legend,
                     self.plot_colorbar, self.set_tools):
            step()
        return self.show(interactive=self._interactive)
@bokeh_multi_plot
@intercept_call_arg
def ilineplot(ds, x, y, z=None, y_err=None, x_err=None, **kwargs):
    """From ``ds`` plot lines of ``y`` as a function of ``x``, optionally for
    varying ``z``. Interactive.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset to plot from.
    x : str
        Dimension to plot along the x-axis.
    y : str or tuple[str]
        Variable(s) to plot along the y-axis. If tuple, plot each of the
        variables - instead of ``z``.
    z : str, optional
        Dimension to plot into the page.
    y_err : str, optional
        Variable to plot as y-error.
    x_err : str, optional
        Variable to plot as x-error.
    row : str, optional
        Dimension to vary over as a function of rows.
    col : str, optional
        Dimension to vary over as a function of columns.
    link : bool, optional
        When plotting a grid (``row`` and/or ``col`` given), link zooming
        and panning between all subplots.
    interactive : bool, optional
        Show with a notebook handle so the plot can be updated later.
    return_fig : bool, optional
        Return the bokeh figure (or grid layout) instead of showing it.
    plot_opts
        See ``xyzpy.plot.core.PLOTTER_DEFAULTS``.
    """
    return ILinePlot(ds, x, y, z, y_err=y_err, x_err=x_err, **kwargs)
class AutoILinePlot(ILinePlot):
    """Interactive raw data multi-line plot.

    The raw arrays are turned into a dataset with ``auto_xyz_ds``, then
    ``'y'`` is plotted against ``'x'`` for each ``'z'``.
    """

    def __init__(self, x, y_z, **lineplot_opts):
        super().__init__(auto_xyz_ds(x, y_z), 'x', 'y', z='z',
                         **lineplot_opts)
def auto_ilineplot(x, y_z, **lineplot_opts):
    """Array-accepting version of :func:`~xyzpy.ilineplot`: the arrays are
    first converted into a ``Dataset``, which is then plotted.
    """
    plotter = AutoILinePlot(x, y_z, **lineplot_opts)
    return plotter()
# --------------------------------------------------------------------------- #
class IScatter(PlotterBokeh, AbstractScatter):
    """Interactive (bokeh) scatter plot, always drawn with markers.
    """

    def __init__(self, ds, x, y, z=None, **kwargs):
        super().__init__(ds, x, y, z, markers=True, **kwargs)

    def plot_scatter(self):
        """Draw each data source as markers and collect the renderers, per
        z label, for the legend.
        """
        self._lgnd_items = []
        for source in self._sources:
            if 'c' in source.column_names:
                # per-point colors, mapped through the color mapper
                color = {'field': 'c', 'transform': self.mappable}
            else:
                color = next(self._cols)
            draw_marker = getattr(self._plot, next(self._mrkrs))
            label = next(self._zlbls)

            renderer = draw_marker(
                'x', 'y',
                source=source,
                name=label,
                color=color,
                fill_alpha=0.5,
                line_width=0.5,
                size=self._marker_size,
            )
            self._lgnd_items.append((label, [renderer]))

    def __call__(self):
        """Prepare data and figure, draw everything, then show it.
        """
        self.prepare_data_single()
        # Bokeh preparation
        self.prepare_plot()
        for step in (self.plot_scatter, self.plot_legend,
                     self.plot_colorbar, self.set_tools):
            step()
        return self.show(interactive=self._interactive)
@bokeh_multi_plot
@intercept_call_arg
def iscatter(ds, x, y, z=None, y_err=None, x_err=None, **kwargs):
    """From ``ds`` plot a scatter of ``y`` against ``x``, optionally for
    varying ``z``. Interactive.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset to plot from.
    x : str
        Quantity to plot along the x-axis.
    y : str or tuple[str]
        Quantity(s) to plot along the y-axis. If tuple, plot each of the
        variables - instead of ``z``.
    z : str, optional
        Dimension to plot into the page.
    y_err : str, optional
        Variable to use as y-error. NOTE: error bars are currently not drawn
        by the interactive scatter plot.
    x_err : str, optional
        Variable to use as x-error. NOTE: error bars are currently not drawn
        by the interactive scatter plot.
    row : str, optional
        Dimension to vary over as a function of rows.
    col : str, optional
        Dimension to vary over as a function of columns.
    link : bool, optional
        When plotting a grid (``row`` and/or ``col`` given), link zooming
        and panning between all subplots.
    interactive : bool, optional
        Show with a notebook handle so the plot can be updated later.
    return_fig : bool, optional
        Return the bokeh figure (or grid layout) instead of showing it.
    plot_opts
        See ``xyzpy.plot.core.PLOTTER_DEFAULTS``.
    """
    return IScatter(ds, x, y, z, y_err=y_err, x_err=x_err, **kwargs)
class AutoIScatter(IScatter):
    """Interactive raw data scatter plot.

    The raw arrays are turned into a dataset with ``auto_xyz_ds``, then
    ``'y'`` is plotted against ``'x'`` for each ``'z'``.
    """

    def __init__(self, x, y_z, **iscatter_opts):
        super().__init__(auto_xyz_ds(x, y_z), 'x', 'y', z='z',
                         **iscatter_opts)
def auto_iscatter(x, y_z, **iscatter_opts):
    """Array-accepting version of :func:`~xyzpy.iscatter`: the arrays are
    first converted into a ``Dataset``, which is then plotted.
    """
    plotter = AutoIScatter(x, y_z, **iscatter_opts)
    return plotter()
# --------------------------------------------------------------------------- #
_HEATMAP_ALT_DEFAULTS = (
('legend', False),
('colorbar', True),
('colormap', 'inferno'),
('gridlines', False),
('padding', 0),
('figsize', (5, 5)), # try to be square, maybe use aspect_ratio??
)
class IHeatMap(PlotterBokeh, AbstractHeatMap):
    """Interactive (bokeh) 2D heatmap of variable ``z`` over ``x`` and ``y``.

    Options not given explicitly are taken from ``_HEATMAP_ALT_DEFAULTS``.
    """

    def __init__(self, ds, x, y, z, **kwargs):
        # set some heatmap specific options, never overriding the caller's
        for k, default in _HEATMAP_ALT_DEFAULTS:
            kwargs.setdefault(k, default)
        super().__init__(ds, x, y, z, **kwargs)

    def plot_heatmap(self):
        """Draw the heatmap image from the single data source ``_source``.

        Calls ``calc_color_norm`` first, which presumably sets up
        ``self.mappable`` used as the image's color mapper.
        """
        self.calc_color_norm()
        self._plot.image(image='image', x='x', y='y', dw='dw', dh='dh',
                         source=self._source, color_mapper=self.mappable)

    def __call__(self):
        """Prepare data and figure, draw the heatmap, then show it.
        """
        # Core preparation
        self.prepare_data_single()
        # Bokeh preparation (figure, ranges, data source)
        self.prepare_plot()
        self.plot_heatmap()
        self.plot_colorbar()
        self.set_tools()
        return self.show(interactive=self._interactive)
@bokeh_multi_plot
@intercept_call_arg
def iheatmap(ds, x, y, z, **kwargs):
    """From ``ds`` plot variable ``z`` as a function of ``x`` and ``y`` using
    a 2D heatmap. Interactive.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset to plot from.
    x : str
        Dimension to plot along the x-axis.
    y : str
        Dimension to plot along the y-axis.
    z : str
        Variable to plot as colormap.
    row : str, optional
        Dimension to vary over as a function of rows.
    col : str, optional
        Dimension to vary over as a function of columns.
    link : bool, optional
        When plotting a grid (``row`` and/or ``col`` given), link zooming
        and panning between all subplots.
    interactive : bool, optional
        Show with a notebook handle so the plot can be updated later.
    return_fig : bool, optional
        Return the bokeh figure (or grid layout) instead of showing it.
    plot_opts
        See ``xyzpy.plot.core.PLOTTER_DEFAULTS``.
    """
    return IHeatMap(ds, x, y, z, **kwargs)
class AutoIHeatMap(IHeatMap):
    """Interactive heatmap of a raw array.

    The array is converted with ``auto_xyz_ds``; its variable ``'x'`` is
    then plotted over the dimensions ``'y'`` and ``'z'``.
    """

    def __init__(self, x, **iheatmap_opts):
        ds = auto_xyz_ds(x)
        # NOTE(review): relies on ``auto_xyz_ds`` naming the array variable
        # 'x' and its dimensions 'y' and 'z' -- confirm in ``..manage``.
        super().__init__(ds, x='y', y='z', z='x', **iheatmap_opts)
def auto_iheatmap(x, **iheatmap_opts):
    """Array-accepting version of :func:`~xyzpy.iheatmap`: the array is
    first converted into a ``Dataset``, which is then plotted.
    """
    plotter = AutoIHeatMap(x, **iheatmap_opts)
    return plotter()