/
pipeline.py
632 lines (540 loc) · 22.3 KB
/
pipeline.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
from __future__ import annotations
import sys
import traceback as tb
from collections import OrderedDict, defaultdict
from typing import ClassVar, Tuple
import param
from .layout import Column, Row
from .pane import HoloViews, Markdown
from .param import Param
from .util import param_reprs
from .viewable import Viewer
from .widgets import Button, Select
class PipelineError(RuntimeError):
    """
    Exception a pipeline stage can raise to report a user-facing problem.

    When a stage raises a PipelineError, the Pipeline shows the exception's
    message on its error button; other exception types only reveal their
    details when the Pipeline runs in debug mode.
    """
def traverse(graph, v, visited):
    """
    Mark every vertex reachable from ``v`` as visited.

    Parameters
    ----------
    graph: dict
        Adjacency mapping from a vertex to an iterable of successor
        vertices. Vertices without an entry have no successors.
    v:
        The start vertex; it is always marked as visited.
    visited: list or dict
        Indexable by vertex and updated in place; ``True`` means reached.

    An explicit stack replaces recursion so that long chains of stages
    cannot exceed Python's recursion limit.
    """
    visited[v] = True
    stack = list(graph.get(v, []))
    while stack:
        node = stack.pop()
        # Skip vertices already reached (also protects against cycles).
        if not visited[node]:
            visited[node] = True
            stack.extend(graph.get(node, []))
def find_route(graph, current, target, _visited=None):
    """
    Find a route to the target node from the current node.

    Performs a depth-first search over ``graph`` (an adjacency mapping
    from node to a tuple of successor nodes).

    Parameters
    ----------
    graph: dict
        Adjacency mapping of the pipeline.
    current:
        Node to start from (excluded from the returned route).
    target:
        Node to reach.
    _visited: set, optional
        Internal bookkeeping of already explored nodes. Callers should
        not pass it; it prevents infinite recursion on cyclic graphs.

    Returns
    -------
    list or None
        The nodes to step through after ``current``, ending with
        ``target``, or None if ``target`` cannot be reached.
    """
    if _visited is None:
        _visited = set()
    _visited.add(current)
    next_nodes = graph.get(current)
    if next_nodes is None:
        return None
    if target in next_nodes:
        return [target]
    for node in next_nodes:
        # A node that was already explored either is on the current path
        # (a cycle) or was shown not to lead to the target.
        if node in _visited:
            continue
        route = find_route(graph, node, target, _visited)
        if route is not None:
            return [node] + route
    return None
def get_root(graph):
    """
    Search for the root node by finding nodes without inputs.

    Parameters
    ----------
    graph: dict
        Adjacency mapping from a node to a tuple of target nodes.

    Returns
    -------
    The single source key of ``graph`` that is never a target.

    Raises
    ------
    ValueError
        If there is more than one such node, or none at all (e.g. the
        graph is empty or every node has an incoming edge).
    """
    # Collect targets in a set so each membership test is O(1).
    targets = {t for ts in graph.values() for t in ts}
    roots = [src for src in graph if src not in targets]
    if len(roots) > 1:
        raise ValueError("Graph has more than one node with no "
                         "incoming edges. Ensure that the graph "
                         "only has a single source node.")
    elif not roots:
        raise ValueError("Graph has no source node. Ensure that the "
                         "graph is not cyclic and has a single "
                         "starting point.")
    return roots[0]
def is_traversable(root, graph, stages):
    """
    Return True when every stage in ``stages`` is reachable from ``root``.

    Stage names are translated to their positions in ``stages`` so that
    the traversal can record visits in a flat list of booleans.
    """
    position = stages.index
    index_graph = {}
    for source, targets in graph.items():
        key = position(source)
        index_graph[key] = tuple(position(t) for t in targets)
    seen = [False for _ in stages]
    traverse(index_graph, position(root), seen)
    return all(seen)
def get_depth(node, graph, depth=0):
    """
    Return the number of levels on the longest path starting at ``node``.

    ``depth`` is the level of ``node`` itself (0 for the root); a node
    without successors in ``graph`` contributes ``depth + 1``.
    """
    children = graph.get(node, [])
    if not children:
        return depth + 1
    return max(get_depth(child, graph, depth + 1) for child in children)
def get_breadths(node, graph, depth=0, breadths=None):
    """
    Group the nodes reachable from ``node`` by their depth level.

    Parameters
    ----------
    node:
        The node to start from; it is placed on level ``depth``.
    graph: dict
        Adjacency mapping from a node to a tuple of successor nodes.
    depth: int
        Level of ``node``.
    breadths: defaultdict(list), optional
        Accumulator used by the recursion; callers should omit it.

    Returns
    -------
    defaultdict(list)
        Mapping from level to the list of nodes on that level, each node
        listed once per level in depth-first discovery order.
    """
    if breadths is None:
        # Only the starting node is added here; every other node is added
        # by its parent below, so no level receives duplicate entries.
        breadths = defaultdict(list)
        breadths[depth].append(node)
    for sub in graph.get(node, []):
        if sub not in breadths[depth + 1]:
            breadths[depth + 1].append(sub)
            # A node already placed on this level has had its descendants
            # recorded, so it is only expanded the first time.
            get_breadths(sub, graph, depth + 1, breadths)
    return breadths
class Pipeline(Viewer):
    """
    A Pipeline represents a directed graph of stages, which each
    return a panel object to render. A pipeline therefore represents
    a UI workflow of multiple linear or branching stages.

    The Pipeline layout consists of a number of sub-components:

    * header:

      * title: The name of the current stage.
      * error: A field to display the error state.
      * network: A network diagram representing the pipeline.
      * buttons: All navigation buttons and selectors.
      * prev_button: The button to go to the previous stage.
      * prev_selector: The selector widget to select between
        previous branching stages.
      * next_button: The button to go to the next stage.
      * next_selector: The selector widget to select the next
        branching stages.

    * stage: The contents of the current pipeline stage.

    By default any outputs of one stage annotated with the
    param.output decorator are fed into the next stage. Additionally,
    if the inherit_params parameter is set any parameters which are
    declared on both the previous and next stage are also inherited.

    The stages are declared using the add_stage method and must each
    be given a unique name. By default any stages will simply be
    connected linearly, but an explicit graph can be declared using
    the define_graph method.
    """

    auto_advance = param.Boolean(default=False, doc="""
        Whether to automatically advance if the ready parameter is True.""")

    debug = param.Boolean(default=False, doc="""
        Whether to raise errors, useful for debugging while building
        an application.""")

    inherit_params = param.Boolean(default=True, doc="""
        Whether parameters should be inherited between pipeline
        stages.""")

    next_parameter = param.String(default=None, allow_refs=False, doc="""
        Parameter name to watch to switch between different branching
        stages""")

    ready_parameter = param.String(default=None, allow_refs=False, doc="""
        Parameter name to watch to check whether a stage is ready.""")

    show_header = param.Boolean(default=True, doc="""
        Whether to show the header with the title, network diagram,
        and buttons.""")

    next = param.Action(default=lambda x: x.param.trigger('next'))

    previous = param.Action(default=lambda x: x.param.trigger('previous'))

    _ignored_refs: ClassVar[Tuple[str, ...]] = ('next_parameter', 'ready_parameter')

    def __init__(self, stages=None, graph=None, **params):
        """
        Parameters
        ----------
        stages: list, optional
            Sequence of ``(name, stage)`` or ``(name, stage, kwargs)``
            tuples, each forwarded to ``add_stage``.
        graph: dict, optional
            Adjacency mapping forwarded to ``define_graph``; when empty
            the stages are connected linearly.
        **params:
            Values for the Pipeline parameters.

        Raises
        ------
        ImportError
            If holoviews cannot be imported.
        ValueError
            If a stage declaration is not a 2- or 3-tuple.
        """
        try:
            import holoviews as hv
        except Exception:
            raise ImportError('Pipeline requires holoviews to be installed') from None

        super().__init__(**params)

        # Initialize internal state
        self._stage = None            # name of the active stage
        self._stages = OrderedDict()  # name -> (stage class or instance, stage kwargs)
        self._states = {}             # name -> instantiated stage object
        self._state = None            # instantiated object of the active stage
        self._linear = True           # becomes False once an explicit graph is defined
        self._block = False           # set by _previous; _unblock ignores one event
        self._error = None            # name of the stage that last failed to load
        self._graph = {}              # adjacency mapping name -> tuple of names
        self._route = []              # navigation history; last entry is the active stage

        # Declare UI components
        self._progress_sel = hv.streams.Selection1D()
        self._progress_sel.add_subscriber(self._set_stage)
        self.prev_button = Param(self.param.previous).layout[0]
        self.prev_button.width = 125
        self.prev_selector = Select(width=125)
        self.next_button = Param(self.param.next).layout[0]
        self.next_button.width = 125
        self.next_selector = Select(width=125)
        self.prev_button.disabled = True
        self.next_selector.param.watch(self._update_progress, 'value')
        self.network = HoloViews(backend='bokeh')
        self.title = Markdown('# Header', margin=(0, 0, 0, 5))
        self.error = Row(width=100)
        self.buttons = Row(self.prev_button, self.next_button)
        self.header = Row(
            Column(self.title, self.error),
            self.network,
            self.buttons,
            visible=self.show_header,
            sizing_mode='stretch_width'
        )
        # Keep the header visibility in sync with the show_header parameter.
        self.param.watch(self._toggle_header, 'show_header')
        self.network.object = self._make_progress()
        self.stage = Row()
        self.layout = Column(self.header, self.stage, sizing_mode='stretch_width')

        # Initialize stages and the graph
        for spec in ([] if stages is None else stages):
            kwargs = {}
            if len(spec) == 2:
                name, stage = spec
            elif len(spec) == 3:
                name, stage, kwargs = spec
            else:
                raise ValueError(
                    'Pipeline stages must be declared as (name, stage) or '
                    '(name, stage, kwargs) tuples, got %r.' % (spec,)
                )
            self.add_stage(name, stage, **kwargs)
        self.define_graph({} if graph is None else graph)

    def __panel__(self):
        return self.layout

    def _toggle_header(self, event):
        """Show or hide the header when ``show_header`` changes."""
        self.header.visible = event.new

    def _validate(self, stage):
        """
        Ensure ``stage`` is a new Parameterized class or instance.

        Raises ValueError if the same object was already added or if it is
        not a Parameterized class/instance.
        """
        if any(stage is s for s, _ in self._stages.values()):
            raise ValueError('Stage %s is already in pipeline' % stage)
        elif not ((isinstance(stage, type) and issubclass(stage, param.Parameterized))
                  or isinstance(stage, param.Parameterized)):
            raise ValueError('Pipeline stages must be Parameterized classes or instances.')

    def __repr__(self):
        lines = ['Pipeline:']
        for i, (name, (stage, _)) in enumerate(self._stages.items()):
            if isinstance(stage, param.Parameterized):
                cls_name = type(stage).__name__
            else:
                cls_name = stage.__name__
            params = ', '.join(param_reprs(stage))
            lines.append(' [%d] %s: %s(%s)' % (i, name, cls_name, params))
        return '\n'.join(lines)

    def __str__(self):
        return self.__repr__()

    def __getitem__(self, index):
        return self._stages[index][0]

    def _get_previous_stages(self, name):
        """Return the stages with an edge pointing to ``name``, in graph order."""
        return [src for src, tgts in self._graph.items() if name in tgts]

    def _unblock(self, event):
        """
        Watcher on the active stage's ready parameter: enable or disable
        the next button and auto-advance when configured.
        """
        # Ignore events from stages that are no longer active, and swallow
        # a single event after _previous has set the block flag.
        if self._state is not event.obj or self._block:
            self._block = False
            return

        button = self.next_button
        if button.disabled and event.new:
            button.disabled = False
        elif not button.disabled and not event.new:
            button.disabled = True

        stage_kwargs = self._stages[self._stage][-1]
        if event.new and stage_kwargs.get('auto_advance', self.auto_advance):
            self._next()

    def _select_next(self, event):
        """Watcher on the active stage's next parameter: preselect that branch."""
        if self._state is not event.obj:
            return
        self.next_selector.value = event.new
        self._update_progress()

    def _init_stage(self):
        """
        Instantiate (or update) the active stage and return its panel.

        Outputs declared with ``param.output`` on already instantiated
        predecessor stages are passed in as parameter values, and, when
        inherit_params is enabled, matching parameter values are copied
        over as well.
        """
        stage, stage_kwargs = self._stages[self._stage]

        previous = self._get_previous_stages(self._stage)
        prev_states = [self._states[prev] for prev in previous if prev in self._states]

        outputs = []
        kwargs, results = {}, {}
        for state in prev_states:
            for name, (_, method, index) in state.param.outputs().items():
                if name not in stage.param:
                    continue
                # Call each output method at most once, even when it
                # provides several outputs.
                if method not in results:
                    results[method] = method()
                result = results[method]
                if index is not None:
                    result = result[index]
                kwargs[name] = result
                outputs.append(name)
            if stage_kwargs.get('inherit_params', self.inherit_params):
                ignored = [stage_kwargs.get(p) or getattr(self, p, None)
                           for p in ('ready_parameter', 'next_parameter')]
                params = [k for k in state.param.objects('existing')
                          if k not in ignored]
                kwargs.update({k: v for k, v in state.param.values().items()
                               if k in stage.param and k != 'name' and k in params})

        if isinstance(stage, param.Parameterized):
            stage.param.update(**kwargs)
            self._state = stage
        else:
            self._state = stage(**kwargs)

        # Hide widgets for parameters that are supplied by the previous stage
        for output in outputs:
            self._state.param[output].precedence = -1

        ready_param = stage_kwargs.get('ready_parameter', self.ready_parameter)
        if ready_param and ready_param in stage.param:
            self._state.param.watch(self._unblock, ready_param, onlychanged=False)

        next_param = stage_kwargs.get('next_parameter', self.next_parameter)
        if next_param and next_param in stage.param:
            self._state.param.watch(self._select_next, next_param, onlychanged=False)

        self._states[self._stage] = self._state
        return self._state.panel()

    def _set_stage(self, index):
        """Selection1D subscriber: navigate to the node clicked in the diagram."""
        if not index:
            return
        stage = self._progress_sel.source.iloc[index[0], 2]
        if stage in self.next_selector.options:
            self.next_selector.value = stage
            self.param.trigger('next')
        elif stage in self.prev_selector.options:
            self.prev_selector.value = stage
            self.param.trigger('previous')
        elif stage in self._route:
            # Step back until the clicked stage is active. Stop if a step
            # does not shorten the route (e.g. _previous caught an error),
            # which would otherwise loop forever.
            while self._stage != stage and len(self._route) > 1:
                remaining = len(self._route)
                self.param.trigger('previous')
                if len(self._route) >= remaining:
                    break
        else:
            # Try currently selected route
            route = find_route(self._graph, self._next_stage, stage)
            if route is None:
                # Try alternate route
                route = find_route(self._graph, self._stage, stage)
                if route is None:
                    raise ValueError('Could not find route to target node.')
            else:
                route = [self._next_stage] + route
            for r in route:
                if r not in self.next_selector.options:
                    break
                self.next_selector.value = r
                self.param.trigger('next')

    @property
    def _next_stage(self):
        return self.next_selector.value

    @property
    def _prev_stage(self):
        return self.prev_selector.value

    def _update_button(self):
        """Refresh the selectors and enable/disable the navigation buttons."""
        stage, kwargs = self._stages[self._stage]
        options = list(self._graph.get(self._stage, []))

        # Preselect the branch named by the stage's next parameter, falling
        # back to the first outgoing edge.
        next_param = kwargs.get('next_parameter', self.next_parameter)
        option = getattr(self._state, next_param) if next_param and next_param in stage.param else None
        if option is None:
            option = options[0] if options else None
        self.next_selector.options = options
        self.next_selector.value = option
        self.next_selector.disabled = not options

        previous = self._get_previous_stages(self._stage)
        self.prev_selector.options = previous
        self.prev_selector.value = self._route[-1] if previous else None
        self.prev_selector.disabled = not previous

        # Disable previous button
        self.prev_button.disabled = self._prev_stage is None

        # Disable next button
        if self._next_stage is None:
            self.next_button.disabled = True
        else:
            ready = kwargs.get('ready_parameter', self.ready_parameter)
            if ready and ready in stage.param:
                # Read the flag from the live stage object; for class-based
                # stages ``stage`` is the class and only holds the default.
                self.next_button.disabled = not getattr(self._state, ready)
            else:
                self.next_button.disabled = False

    def _get_error_button(self, e):
        """
        Return a button that alerts the error details when clicked.

        Only PipelineError messages are shown unless debug is enabled, in
        which case the tail of the current traceback is included. Must be
        called while the exception is being handled (uses sys.exc_info).
        """
        import json
        msg = str(e) if isinstance(e, PipelineError) else ""
        if self.debug:
            exc_type, exc_value, exc_tb = sys.exc_info()
            tb_list = tb.format_tb(exc_tb, None) + tb.format_exception_only(exc_type, exc_value)
            text = (("%s\n\nTraceback (innermost last):\n" + "%-20s %s") %
                    (msg, ''.join(tb_list[-5:-1]), tb_list[-1]))
        else:
            text = msg or "Undefined error, enable debug mode."
        button = Button(name='Error', button_type='danger', width=100,
                        align='center', margin=(0, 0, 0, 5))
        # Encode the text as a JSON string literal so quotes, backticks or
        # "${...}" in the error message cannot break or inject into the JS.
        button.js_on_click(code="alert({tb})".format(tb=json.dumps(text)))
        return button

    @param.depends('next', watch=True)
    def _next(self):
        """Advance to the stage selected in next_selector."""
        prev_state, prev_stage = self._state, self._stage
        self._stage = self._next_stage
        self.stage.loading = True
        try:
            self.stage[0] = self._init_stage()
        except Exception as e:
            # Roll back to the stage we came from and show the error.
            self._error = self._stage
            self._stage = prev_stage
            self._state = prev_state
            self.stage[0] = prev_state.panel()
            self.error[:] = [self._get_error_button(e)]
            if self.debug:
                raise
            return e
        else:
            self.error[:] = []
            self._error = None
            self._update_button()
            self._route.append(self._stage)
            stage_kwargs = self._stages[self._stage][-1]
            ready_param = stage_kwargs.get('ready_parameter', self.ready_parameter)
            if (ready_param and getattr(self._state, ready_param, False) and
                    stage_kwargs.get('auto_advance', self.auto_advance)):
                self._next()
        finally:
            self._update_progress()
            self.stage.loading = False

    @param.depends('previous', watch=True)
    def _previous(self):
        """Go back to the stage selected in prev_selector."""
        prev_state, prev_stage = self._state, self._stage
        self._stage = self._prev_stage
        try:
            if self._stage in self._states:
                self._state = self._states[self._stage]
                self.stage[0] = self._state.panel()
            else:
                self.stage[0] = self._init_stage()
            self._block = True
        except Exception as e:
            self.error[:] = [self._get_error_button(e)]
            self._error = self._stage
            self._stage = prev_stage
            self._state = prev_state
            if self.debug:
                raise
        else:
            self.error[:] = []
            self._error = None
            self._update_button()
            self._route.pop()
        finally:
            self._update_progress()

    def _update_progress(self, *args):
        """Refresh the title and the network diagram."""
        self.title.object = f'## Stage: {self._stage}'
        self.network.object = self._make_progress()

    def _make_progress(self):
        """Build the HoloViews network diagram of the pipeline graph."""
        import holoviews as hv
        import holoviews.plotting.bokeh  # noqa

        if self._graph:
            root = get_root(self._graph)
            breadths = get_breadths(root, self._graph)
            max_breadth = max(len(v) for v in breadths.values())
        else:
            max_breadth = 0
            breadths = {}
        # Nodes are placed at x = their level index, so the x-range ends at
        # the deepest level (0 when there is no graph).
        max_level = max(breadths, default=0)

        height = 80 + (max_breadth-1) * 20

        edges = [(src, t) for src, tgts in self._graph.items() for t in tgts]

        nodes = []
        for level, subnodes in breadths.items():
            step = 1./len(subnodes)
            for i, n in enumerate(subnodes[::-1]):
                if n == self._stage:
                    state = 'active'
                elif n == self._error:
                    state = 'error'
                elif n == self._next_stage:
                    state = 'next'
                else:
                    state = 'inactive'
                nodes.append((level, step/2.+i*step, n, state))

        cmap = {'inactive': 'white', 'active': '#5cb85c', 'error': 'red',
                'next': 'yellow'}

        def tap_renderer(plot, element):
            # Restrict the tap tool to the node glyphs.
            from bokeh.models import TapTool
            gr = plot.handles['glyph_renderer']
            tap = plot.state.select_one(TapTool)
            tap.renderers = [gr]

        nodes = hv.Nodes(nodes, ['x', 'y', 'Stage'], 'State').opts(
            alpha=0, default_tools=['tap'], hooks=[tap_renderer],
            hover_alpha=0, selection_alpha=0, nonselection_alpha=0,
            axiswise=True, size=10, backend='bokeh'
        )
        self._progress_sel.source = nodes
        graph = hv.Graph((edges, nodes)).opts(
            edge_hover_line_color='black', node_color='State', cmap=cmap,
            tools=[], default_tools=['hover'], selection_policy=None,
            node_hover_fill_color='gray', axiswise=True, backend='bokeh')
        labels = hv.Labels(nodes, ['x', 'y'], 'Stage').opts(
            yoffset=-.30, default_tools=[], axiswise=True, backend='bokeh'
        )
        plot = (graph * labels * nodes) if self._linear else (graph * nodes)
        plot.opts(
            xaxis=None, yaxis=None, min_width=400, responsive=True,
            show_frame=False, height=height, xlim=(-0.25, max_level+0.25),
            ylim=(0, 1), default_tools=['hover'], toolbar=None, backend='bokeh'
        )
        return plot

    #----------------------------------------------------------------
    # Public API
    #----------------------------------------------------------------

    def add_stage(self, name, stage, **kwargs):
        """
        Adds a new, named stage to the Pipeline.

        Arguments
        ---------
        name: str
            A string name for the Pipeline stage
        stage: param.Parameterized
            A Parameterized object which represents the Pipeline stage.
        **kwargs: dict
            Additional arguments declaring the behavior of the stage;
            each key must be a Pipeline parameter name.

        Raises
        ------
        ValueError
            If the stage is invalid or a keyword is not a Pipeline parameter.
        RuntimeError
            If a custom graph has already been defined.
        """
        self._validate(stage)
        for k in kwargs:
            if k not in self.param:
                raise ValueError("Keyword argument %s is not a valid parameter. " % k)

        if not self._linear and self._graph:
            raise RuntimeError("Cannot add stage after graph has been defined.")

        self._stages[name] = (stage, kwargs)
        if len(self._stages) == 1:
            self._stage = name
            self._route = [name]
            self._graph = {}
            self.stage[:] = [self._init_stage()]
        else:
            # Linear mode: connect the last stage (the first one without an
            # outgoing edge) to the newly added stage.
            previous = next(s for s in self._stages if s not in self._graph)
            self._graph[previous] = (name,)
        self._update_progress()
        self._update_button()

    def define_graph(self, graph, force=False):
        """
        Declares a custom graph structure for the Pipeline overriding
        the default linear flow. The graph should be defined as an
        adjacency mapping.

        Arguments
        ---------
        graph: dict
            Dictionary declaring the relationship between different
            pipeline stages. Should map from a single stage name to
            one or more stage names (a single name, or a tuple/list
            of names).
        force: bool
            Whether to allow replacing a previously defined custom graph.

        Raises
        ------
        ValueError
            If referenced stages are missing, the graph has no unique
            root, is not fully traversable, or a custom graph already
            exists and ``force`` is False.
        """
        stages = list(self._stages)
        if not stages:
            self._graph = {}
            return

        graph = {k: tuple(v) if isinstance(v, (tuple, list)) else (v,)
                 for k, v in graph.items()}

        not_found = []
        for source, targets in graph.items():
            if source not in stages:
                not_found.append(source)
            not_found += [t for t in targets if t not in stages]
        if not_found:
            raise ValueError(
                'Pipeline stage(s) %s not found, ensure all stages '
                'referenced in the graph have been added.' %
                (not_found[0] if len(not_found) == 1 else not_found)
            )

        if graph:
            if not (self._linear or force):
                raise ValueError("Graph has already been defined, "
                                 "cannot override existing graph.")
            self._linear = False
        else:
            graph = {s: (t,) for s, t in zip(stages[:-1], stages[1:])}

        root = get_root(graph)
        if not is_traversable(root, graph, stages):
            raise ValueError('Graph is not fully traversable from stage: %s.'
                             % root)

        # Compare names by value: equal names need not be the same object.
        reinit = root != self._stage
        self._stage = root
        self._graph = graph
        self._route = [root]
        if not self._linear:
            self.buttons[:] = [
                Column(self.prev_selector, self.prev_button),
                Column(self.next_selector, self.next_button)
            ]
        if reinit:
            self.stage[:] = [self._init_stage()]
        self._update_progress()
        self._update_button()
# Public names exported by ``from ... import *``.
__all__ = ("Pipeline",)