# computational_graph.py
import heapq
from chainer import function_node
from chainer import variable
# Default dot node attributes used for variable nodes
# (filled light-grey octagons).
_var_style = {'shape': 'octagon', 'fillcolor': '#E0E0E0', 'style': 'filled'}
# Default dot node attributes used for function nodes
# (filled cornflower-blue records).
_func_style = {'shape': 'record', 'fillcolor': '#6495ED', 'style': 'filled'}
class DotNode(object):
    """Node of the computational graph, with utilities for dot language.

    This class represents a node of computational graph,
    with some utilities for dot language.

    Args:
        node: :class:`VariableNode` object or :class:`FunctionNode` object.
        attribute (dict): Attributes for the node. They are applied last, so
            they override the default ``label`` and ``shape`` entries.
        show_name (bool): If ``True``, the ``name`` attribute of the node is
            added to the label. Default is ``True``.

    """

    def __init__(self, node, attribute=None, show_name=True):
        assert isinstance(node, (variable.VariableNode,
                                 function_node.FunctionNode))
        self.node = node
        # The object identity doubles as the dot node identifier.
        self.id_ = id(node)
        self.attribute = {'label': node.label}
        if isinstance(node, variable.VariableNode):
            if show_name and node.name is not None:
                self.attribute['label'] = '{}: {}'.format(
                    node.name, self.attribute['label'])
            self.attribute.update({'shape': 'oval'})
        else:
            self.attribute.update({'shape': 'box'})
        if attribute is not None:
            # Caller-supplied style wins over the defaults set above.
            self.attribute.update(attribute)

    @staticmethod
    def _escape(value):
        """Returns ``value`` as text that is safe inside a dot quoted string.

        Every double quote is emitted as ``\\"``. Quotes that are already
        escaped are first unescaped so that they are not escaped twice.
        """
        text = '%s' % (value,)
        return text.replace('\\"', '"').replace('"', '\\"')

    @property
    def label(self):
        """The text that represents properties of the node.

        Returns:
            string: The text that represents the id and attributes of this
                node, as a dot node statement ``<id> [k="v",...];``.

        """
        attributes = ['%s="%s"' % (key, self._escape(value))
                      for key, value in self.attribute.items()]
        return '%s [%s];' % (self.id_, ','.join(attributes))
class ComputationalGraph(object):
    """Class that represents computational graph.

    .. note::

        We assume that the computational graph is directed and acyclic.

    Args:
        nodes (list): List of nodes. Each node is either
            :class:`VariableNode` object or :class:`FunctionNode` object.
        edges (list): List of edges. Each edge consists of pair of nodes.
        variable_style (dict): Dot node style for variable.
        function_style (dict): Dot node style for function.
        rankdir (str): Direction of the graph that must be
            TB (top to bottom), BT (bottom to top), LR (left to right)
            or RL (right to left).
        remove_variable (bool): If ``True``, :class:`~chainer.Variable`\\ s are
            removed from the resulting computational graph. Only
            :class:`~chainer.Function`\\ s are shown in the output.
        show_name (bool): If ``True``, the ``name`` attribute of each node is
            added to the label of the node. Default is ``True``.

    Raises:
        ValueError: If ``rankdir`` is not one of TB, BT, LR and RL.

    .. note::

        The default behavior of :class:`~chainer.ComputationalGraph` has been
        changed from v1.23.0, so that it outputs the richest representation of
        a graph as default, namely, styles are set and names of functions and
        variables are shown. To reproduce the same result as previous versions
        (<= v1.22.0), please specify `variable_style=None`,
        `function_style=None`, and `show_name=False` explicitly.

    """

    def __init__(self, nodes, edges, variable_style=_var_style,
                 function_style=_func_style, rankdir='TB',
                 remove_variable=False, show_name=True):
        self.nodes = nodes
        self.edges = edges
        self.variable_style = variable_style
        self.function_style = function_style
        if rankdir not in ('TB', 'BT', 'LR', 'RL'):
            raise ValueError('rankdir must be in TB, BT, LR or RL.')
        self.rankdir = rankdir
        self.remove_variable = remove_variable
        self.show_name = show_name

    def _to_dot(self):
        """Converts graph in dot format.

        The `label` property of :class:`DotNode` is used as the description
        of each node. Each distinct edge is emitted only once.

        Returns:
            str: The graph in dot format.

        Raises:
            TypeError: If ``remove_variable`` is ``False`` and an edge does
                not connect a variable node with a function node.

        """
        nodes, edges = self.nodes, self.edges
        if self.remove_variable:
            # Work on locals so that rendering does not rewrite the graph
            # stored on this object.
            nodes, edges = _skip_variable(nodes, edges)

        parts = ['digraph graphname{rankdir=%s;' % self.rankdir]

        for node in nodes:
            assert isinstance(node, (variable.VariableNode,
                                     function_node.FunctionNode))
            if isinstance(node, variable.VariableNode):
                if not self.remove_variable:
                    parts.append(DotNode(
                        node, self.variable_style, self.show_name).label)
            else:
                parts.append(DotNode(
                    node, self.function_style, self.show_name).label)

        # A set gives constant-time duplicate detection.
        drawn_edges = set()
        for head, tail in edges:
            if (isinstance(head, variable.VariableNode) and
                    isinstance(tail, function_node.FunctionNode)):
                head_attr = self.variable_style
                tail_attr = self.function_style
            elif (isinstance(head, function_node.FunctionNode) and
                    isinstance(tail, variable.VariableNode)):
                head_attr = self.function_style
                tail_attr = self.variable_style
            else:
                if not self.remove_variable:
                    raise TypeError('head and tail should be the set of '
                                    'VariableNode and Function')
                # With variables removed, both ends are drawn as functions.
                head_attr = self.function_style
                tail_attr = self.function_style
            head_node = DotNode(head, head_attr, self.show_name)
            tail_node = DotNode(tail, tail_attr, self.show_name)
            edge = (head_node.id_, tail_node.id_)
            if edge in drawn_edges:
                continue
            parts.append('%s -> %s;' % edge)
            drawn_edges.add(edge)

        parts.append('}')
        return ''.join(parts)

    def dump(self, format='dot'):
        """Dumps graph as a text.

        Args:
            format(str): The graph language name of the output.
                Currently, it must be 'dot'.

        Returns:
            str: The graph in specified format.

        Raises:
            NotImplementedError: If ``format`` is not ``'dot'``.

        """
        if format == 'dot':
            return self._to_dot()
        raise NotImplementedError('Currently, only dot format is supported.')
def _skip_variable(nodes, edges):
func_edges = []
for edge_i, edge in enumerate(edges):
head, tail = edge
if isinstance(head, variable.VariableNode):
if head.creator_node is not None:
head = head.creator_node
else:
continue
if isinstance(tail, variable.VariableNode):
for node in nodes:
if isinstance(node, function_node.FunctionNode):
for input_var in node.inputs:
if input_var is tail:
tail = node
break
if isinstance(tail, function_node.FunctionNode):
break
else:
continue
func_edges.append((head, tail))
return nodes, func_edges
def build_computational_graph(
        outputs, remove_split=True, variable_style=_var_style,
        function_style=_func_style, rankdir='TB', remove_variable=False,
        show_name=True):
    """Builds a graph of functions and variables backward-reachable from outputs.

    Args:
        outputs(list): nodes from which the graph is constructed.
            Each element of outputs must be either :class:`~chainer.Variable`
            object, :class:`~chainer.variable.VariableNode` object, or
            :class:`~chainer.Function` object. A single such object is also
            accepted and is treated as a one-element list.
        remove_split(bool): It must be ``True``. This argument is left for
            backward compatibility.
        variable_style(dict): Dot node style for variable.
            Possible keys are 'shape', 'color', 'fillcolor', 'style', and etc.
        function_style(dict): Dot node style for function.
        rankdir (str): Direction of the graph that must be
            TB (top to bottom), BT (bottom to top), LR (left to right)
            or RL (right to left).
        remove_variable (bool): If ``True``, :class:`~chainer.Variable`\\ s are
            removed from the resulting computational graph. Only
            :class:`~chainer.Function`\\ s are shown in the output.
        show_name (bool): If ``True``, the ``name`` attribute of each node is
            added to the label of the node. Default is ``True``.

    Returns:
        ComputationalGraph: A graph consisting of nodes and edges that
        are backward-reachable from at least one of ``outputs``.

        If ``unchain_backward`` was called in some variable in the
        computational graph before this function, backward step is
        stopped at this variable.

        For example, suppose that computational graph is as follows::

                |--> f ---> y
            x --+
                |--> g ---> z

        Let ``outputs = [y, z]``.
        Then the full graph is emitted.

        Next, let ``outputs = [y]``. Note that ``z`` and ``g``
        are not backward-reachable from ``y``.
        The resulting graph would be following::

            x ---> f ---> y

        See :class:`TestGraphBuilder` for details.

    Raises:
        ValueError: If ``remove_split`` is false.

    .. note::

        The default behavior of :class:`~chainer.ComputationalGraph` has been
        changed from v1.23.0, so that it outputs the richest representation of
        a graph as default, namely, styles are set and names of functions and
        variables are shown. To reproduce the same result as previous versions
        (<= v1.22.0), please specify `variable_style=None`,
        `function_style=None`, and `show_name=False` explicitly.

    """
    if not remove_split:
        raise ValueError('remove_split=False is not supported anymore')

    # Accept a lone node so that callers need not wrap it in a list.
    if isinstance(outputs, (variable.Variable, variable.VariableNode,
                            function_node.FunctionNode)):
        outputs = [outputs]

    cands = []
    seen_edges = set()
    nodes = set()
    # One-element list so that the nested function can update the counter.
    push_count = [0]

    def add_cand(cand):
        # Negated rank makes the min-heap pop the highest rank first. The
        # running counter breaks ties so that the nodes themselves are never
        # compared.
        heapq.heappush(cands, (-cand.rank, push_count[0], cand))
        push_count[0] += 1

    for o in outputs:
        if isinstance(o, variable.Variable):
            o = o.node
        add_cand(o)
        nodes.add(o)

    while cands:
        _, _, cand = heapq.heappop(cands)
        if isinstance(cand, variable.VariableNode):
            creator = cand.creator_node
            if creator is not None and (creator, cand) not in seen_edges:
                add_cand(creator)
                seen_edges.add((creator, cand))
                nodes.add(creator)
                nodes.add(cand)
        elif isinstance(cand, function_node.FunctionNode):
            for input_ in cand.inputs:
                if input_ is not cand and (input_, cand) not in seen_edges:
                    add_cand(input_)
                    seen_edges.add((input_, cand))
                    nodes.add(input_)
                    nodes.add(cand)
    return ComputationalGraph(
        list(nodes), list(seen_edges), variable_style,
        function_style, rankdir, remove_variable, show_name)