-
Notifications
You must be signed in to change notification settings - Fork 3k
/
graphconv.py
488 lines (417 loc) · 18.1 KB
/
graphconv.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
"""Torch modules for graph convolutions(GCN)."""
# pylint: disable= no-member, arguments-differ, invalid-name
import torch as th
from torch import nn
from torch.nn import init
from .... import function as fn
from ....base import DGLError
from ....convert import block_to_graph
from ....heterograph import DGLBlock
from ....transforms import reverse
from ....utils import expand_as_pair
class EdgeWeightNorm(nn.Module):
    r"""Normalize positive scalar edge weights in the manner of
    `GCN <https://arxiv.org/abs/1609.02907>`__.

    With ``norm='both'`` the normalization term is

    .. math::
      c_{ji} = (\sqrt{\sum_{k\in\mathcal{N}(j)}e_{jk}}\sqrt{\sum_{k\in\mathcal{N}(i)}e_{ki}})

    and with ``norm='right'`` it is

    .. math::
      c_{ji} = (\sum_{k\in\mathcal{N}(i)}e_{ki})

    where :math:`e_{ji}` is the scalar weight on the edge from node :math:`j`
    to node :math:`i`. The module returns the normalized weight
    :math:`e_{ji} / c_{ji}`.

    Parameters
    ----------
    norm : str, optional
        Which normalizer to use, as described above. Default is ``'both'``.
        ``'none'`` leaves both per-node factors at one, so the edge weights
        come back unscaled. The value is not validated: anything other than
        ``'both'`` and ``'none'`` is handled like ``'right'``.
    eps : float, optional
        Offset added to the per-node weight sums before they are inverted.
        Default is 0.

    Examples
    --------
    >>> import dgl
    >>> import numpy as np
    >>> import torch as th
    >>> from dgl.nn import EdgeWeightNorm, GraphConv

    >>> g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3]))
    >>> g = dgl.add_self_loop(g)
    >>> feat = th.ones(6, 10)
    >>> edge_weight = th.tensor([0.5, 0.6, 0.4, 0.7, 0.9, 0.1, 1, 1, 1, 1, 1, 1])
    >>> norm = EdgeWeightNorm(norm='both')
    >>> norm_edge_weight = norm(g, edge_weight)
    >>> conv = GraphConv(10, 2, norm='none', weight=True, bias=True)
    >>> res = conv(g, feat, edge_weight=norm_edge_weight)
    >>> print(res)
    tensor([[-1.1849, -0.7525],
            [-1.3514, -0.8582],
            [-1.2384, -0.7865],
            [-1.9949, -1.2669],
            [-1.3658, -0.8674],
            [-0.8323, -0.5286]], grad_fn=<AddBackward0>)
    """

    def __init__(self, norm="both", eps=0.0):
        super().__init__()
        self._norm = norm
        self._eps = eps

    def forward(self, graph, edge_weight):
        r"""

        Description
        -----------
        Compute the normalized edge weights used by the GCN model.

        Parameters
        ----------
        graph : DGLGraph
            The graph. A ``DGLBlock`` is converted with ``block_to_graph``
            before any computation.
        edge_weight : torch.Tensor
            Unnormalized scalar weights on the edges.
            The shape is expected to be :math:`(|E|)`.

        Returns
        -------
        torch.Tensor
            The normalized edge weight.

        Raises
        ------
        DGLError
            Case 1:
            The edge weight has more than one dimension. Only one scalar
            weight per edge is supported.

            Case 2:
            ``norm='both'`` and some edge weight is non-positive, which would
            require a square root of / division by a non-positive number.
        """
        with graph.local_scope():
            if isinstance(graph, DGLBlock):
                graph = block_to_graph(graph)

            if edge_weight.dim() > 1:
                raise DGLError(
                    "Currently the normalization is only defined "
                    "on scalar edge weight. Please customize the "
                    "normalization for your high-dimensional weights."
                )

            symmetric = self._norm == "both"
            if symmetric and th.any(edge_weight <= 0).item():
                raise DGLError(
                    'Non-positive edge weight detected with `norm="both"`. '
                    "This leads to square root of zero or negative values."
                )

            device = graph.device
            weight_dtype = edge_weight.dtype

            # Both per-node factors start at one, so a side that is not
            # normalized below leaves the edge weight unchanged.
            graph.srcdata["_src_out_w"] = th.ones(
                graph.number_of_src_nodes(), dtype=weight_dtype, device=device
            )
            graph.dstdata["_dst_in_w"] = th.ones(
                graph.number_of_dst_nodes(), dtype=weight_dtype, device=device
            )
            graph.edata["_edge_w"] = edge_weight

            if symmetric:
                # Source-side factor: sum the weights over the reversed graph
                # (presumably this yields each source node's outgoing weight
                # sum -- verify against `reverse`).
                flipped = reverse(graph)
                flipped.edata["_edge_w"] = edge_weight
                flipped.update_all(
                    fn.copy_e("_edge_w", "m"), fn.sum("m", "out_weight")
                )
                out_sums = flipped.dstdata["out_weight"] + self._eps
                graph.srcdata["_src_out_w"] = th.pow(out_sums, -0.5)

            if self._norm != "none":
                # Destination-side factor from the summed incoming weights.
                graph.update_all(
                    fn.copy_e("_edge_w", "m"), fn.sum("m", "in_weight")
                )
                in_sums = graph.dstdata["in_weight"] + self._eps
                graph.dstdata["_dst_in_w"] = (
                    th.pow(in_sums, -0.5) if symmetric else 1.0 / in_sums
                )

            def _scale_edges(edges):
                # Per edge: source factor * destination factor * raw weight.
                return {
                    "_norm_edge_weights": edges.src["_src_out_w"]
                    * edges.dst["_dst_in_w"]
                    * edges.data["_edge_w"]
                }

            graph.apply_edges(_scale_edges)
            return graph.edata["_norm_edge_weights"]
# pylint: disable=W0235
class GraphConv(nn.Module):
    r"""Graph convolutional layer from `Semi-Supervised Classification with Graph Convolutional
    Networks <https://arxiv.org/abs/1609.02907>`__

    The layer computes

    .. math::
      h_i^{(l+1)} = \sigma(b^{(l)} + \sum_{j\in\mathcal{N}(i)}\frac{1}{c_{ji}}h_j^{(l)}W^{(l)})

    where :math:`\mathcal{N}(i)` is the set of neighbors of node :math:`i`,
    :math:`c_{ji}` is the product of the square root of node degrees
    (i.e.,  :math:`c_{ji} = \sqrt{|\mathcal{N}(j)|}\sqrt{|\mathcal{N}(i)|}`),
    and :math:`\sigma` is an activation function.

    When a weight tensor on the edges is supplied, the weighted graph convolution is

    .. math::
      h_i^{(l+1)} = \sigma(b^{(l)} + \sum_{j\in\mathcal{N}(i)}\frac{e_{ji}}{c_{ji}}h_j^{(l)}W^{(l)})

    where :math:`e_{ji}` is the scalar weight on the edge from node :math:`j` to node :math:`i`.
    This is NOT equivalent to the weighted graph convolutional network formulation in the paper.

    To use a custom normalization term :math:`c_{ji}`, construct the module with
    ``norm='none'`` and pass pre-normalized :math:`e_{ji}` to the forward computation.
    :class:`~dgl.nn.pytorch.EdgeWeightNorm` normalizes scalar edge weights following the
    GCN paper.

    Parameters
    ----------
    in_feats : int
        Input feature size; i.e, the number of dimensions of :math:`h_j^{(l)}`.
    out_feats : int
        Output feature size; i.e., the number of dimensions of :math:`h_i^{(l+1)}`.
    norm : str, optional
        How to apply the normalizer.  Can be one of the following values:

        * ``right``, to divide the aggregated messages by each node's in-degrees,
          which is equivalent to averaging the received messages.

        * ``none``, where no normalization is applied.

        * ``both`` (default), where the messages are scaled with :math:`1/c_{ji}` above,
          equivalent to symmetric normalization.

        * ``left``, to divide the messages sent out from each node by its out-degrees,
          equivalent to random walk normalization.
    weight : bool, optional
        If True, apply a linear layer. Otherwise, aggregate the messages
        without a weight matrix.
    bias : bool, optional
        If True, adds a learnable bias to the output. Default: ``True``.
    activation : callable activation function/layer or None, optional
        If not None, applies an activation function to the updated node features.
        Default: ``None``.
    allow_zero_in_degree : bool, optional
        If there are 0-in-degree nodes in the graph, output for those nodes will be invalid
        since no message will be passed to those nodes. This is harmful for some applications
        causing silent performance regression. This module will raise a DGLError if it detects
        0-in-degree nodes in input graph. By setting ``True``, it will suppress the check
        and let the users handle it by themselves. Default: ``False``.

    Attributes
    ----------
    weight : torch.Tensor
        The learnable weight tensor (``None`` when built with ``weight=False``).
    bias : torch.Tensor
        The learnable bias tensor (``None`` when built with ``bias=False``).

    Note
    ----
    Zero in-degree nodes will lead to invalid output value. This is because no message
    will be passed to those nodes, so the aggregation function will be applied on empty input.
    A common practice to avoid this is to add a self-loop for each node in the graph if
    it is homogeneous, which can be achieved by:

    >>> g = ... # a DGLGraph
    >>> g = dgl.add_self_loop(g)

    Calling ``add_self_loop`` will not work for some graphs, for example, heterogeneous graph
    since the edge type can not be decided for self_loop edges. Set ``allow_zero_in_degree``
    to ``True`` for those cases to unblock the code and handle zero-in-degree nodes manually.
    A common practice to handle this is to filter out the nodes with zero-in-degree when
    using the output of the conv.

    Examples
    --------
    >>> import dgl
    >>> import numpy as np
    >>> import torch as th
    >>> from dgl.nn import GraphConv

    >>> # Case 1: Homogeneous graph
    >>> g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3]))
    >>> g = dgl.add_self_loop(g)
    >>> feat = th.ones(6, 10)
    >>> conv = GraphConv(10, 2, norm='both', weight=True, bias=True)
    >>> res = conv(g, feat)
    >>> print(res)
    tensor([[ 1.3326, -0.2797],
            [ 1.4673, -0.3080],
            [ 1.3326, -0.2797],
            [ 1.6871, -0.3541],
            [ 1.7711, -0.3717],
            [ 1.0375, -0.2178]], grad_fn=<AddBackward0>)
    >>> # allow_zero_in_degree example
    >>> g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3]))
    >>> conv = GraphConv(10, 2, norm='both', weight=True, bias=True, allow_zero_in_degree=True)
    >>> res = conv(g, feat)
    >>> print(res)
    tensor([[-0.2473, -0.4631],
            [-0.3497, -0.6549],
            [-0.3497, -0.6549],
            [-0.4221, -0.7905],
            [-0.3497, -0.6549],
            [ 0.0000,  0.0000]], grad_fn=<AddBackward0>)

    >>> # Case 2: Unidirectional bipartite graph
    >>> u = [0, 1, 0, 0, 1]
    >>> v = [0, 1, 2, 3, 2]
    >>> g = dgl.heterograph({('_U', '_E', '_V') : (u, v)})
    >>> u_fea = th.rand(2, 5)
    >>> v_fea = th.rand(4, 5)
    >>> conv = GraphConv(5, 2, norm='both', weight=True, bias=True)
    >>> res = conv(g, (u_fea, v_fea))
    >>> res
    tensor([[-0.2994,  0.6106],
            [-0.4482,  0.5540],
            [-0.5287,  0.8235],
            [-0.2994,  0.6106]], grad_fn=<AddBackward0>)
    """

    def __init__(
        self,
        in_feats,
        out_feats,
        norm="both",
        weight=True,
        bias=True,
        activation=None,
        allow_zero_in_degree=False,
    ):
        super().__init__()
        if norm not in ("none", "both", "right", "left"):
            raise DGLError(
                'Invalid norm value. Must be either "none", "both", "right" or "left".'
                ' But got "{}".'.format(norm)
            )

        self._in_feats = in_feats
        self._out_feats = out_feats
        self._norm = norm
        self._allow_zero_in_degree = allow_zero_in_degree

        # A disabled parameter is still registered (as None) so that
        # `self.weight` / `self.bias` always exist.
        if weight:
            self.weight = nn.Parameter(th.Tensor(in_feats, out_feats))
        else:
            self.register_parameter("weight", None)

        if bias:
            self.bias = nn.Parameter(th.Tensor(out_feats))
        else:
            self.register_parameter("bias", None)

        self.reset_parameters()
        self._activation = activation

    def reset_parameters(self):
        r"""

        Description
        -----------
        Reinitialize learnable parameters.

        Note
        ----
        The model parameters are initialized as in the
        `original implementation <https://github.com/tkipf/gcn/blob/master/gcn/layers.py>`__
        where the weight :math:`W^{(l)}` is initialized using Glorot uniform initialization
        and the bias is initialized to be zero. Parameters that are ``None`` are skipped.

        """
        if self.weight is not None:
            init.xavier_uniform_(self.weight)
        if self.bias is not None:
            init.zeros_(self.bias)

    def set_allow_zero_in_degree(self, set_value):
        r"""

        Description
        -----------
        Set the ``allow_zero_in_degree`` flag.

        Parameters
        ----------
        set_value : bool
            The value to be set to the flag.
        """
        self._allow_zero_in_degree = set_value

    def forward(self, graph, feat, weight=None, edge_weight=None):
        r"""

        Description
        -----------
        Compute graph convolution.

        Parameters
        ----------
        graph : DGLGraph
            The graph.
        feat : torch.Tensor or pair of torch.Tensor
            If a torch.Tensor is given, it represents the input feature of shape
            :math:`(N, D_{in})`
            where :math:`D_{in}` is size of input feature, :math:`N` is the number of nodes.
            If a pair of torch.Tensor is given, which is the case for bipartite graph, the pair
            must contain two tensors of shape :math:`(N_{in}, D_{in_{src}})` and
            :math:`(N_{out}, D_{in_{dst}})`.
        weight : torch.Tensor, optional
            Optional external weight tensor.
        edge_weight : torch.Tensor, optional
            Optional tensor on the edge. If given, each message is multiplied by
            the weight of the edge it travels along. Its first dimension must equal
            the number of edges (checked with an ``assert``).

        Returns
        -------
        torch.Tensor
            The output feature

        Raises
        ------
        DGLError
            Case 1:
            If there are 0-in-degree nodes in the input graph, it will raise DGLError
            since no message will be passed to those nodes. This will cause invalid output.
            The error can be ignored by setting ``allow_zero_in_degree`` parameter to ``True``.

            Case 2:
            External weight is provided while at the same time the module
            has defined its own weight parameter.

        Note
        ----
        * Input shape: :math:`(N, *, \text{in_feats})` where * means any number of additional
          dimensions, :math:`N` is the number of nodes.
        * Output shape: :math:`(N, *, \text{out_feats})` where all but the last dimension are
          the same shape as the input.
        * Weight shape: :math:`(\text{in_feats}, \text{out_feats})`.
        """

        def _degree_scale(degrees, ref):
            # Per-node scaling factor, cast to `ref`'s dtype/device and given
            # trailing singleton dims so it broadcasts over feature dims.
            # Degrees are clamped to at least 1 to avoid dividing by zero.
            degs = degrees.to(ref).clamp(min=1)
            scale = th.pow(degs, -0.5) if self._norm == "both" else 1.0 / degs
            return th.reshape(scale, scale.shape + (1,) * (ref.dim() - 1))

        with graph.local_scope():
            if not self._allow_zero_in_degree and (graph.in_degrees() == 0).any():
                raise DGLError(
                    "There are 0-in-degree nodes in the graph, "
                    "output for those nodes will be invalid. "
                    "This is harmful for some applications, "
                    "causing silent performance regression. "
                    "Adding self-loop on the input graph by "
                    "calling `g = dgl.add_self_loop(g)` will resolve "
                    "the issue. Setting ``allow_zero_in_degree`` "
                    "to be `True` when constructing this module will "
                    "suppress the check and let the code run."
                )

            # Message function: plain copy of the source feature, or the
            # source feature multiplied by the edge weight when one is given.
            if edge_weight is None:
                aggregate_fn = fn.copy_u("h", "m")
            else:
                assert edge_weight.shape[0] == graph.num_edges()
                graph.edata["_edge_weight"] = edge_weight
                aggregate_fn = fn.u_mul_e("h", "_edge_weight", "m")

            # (BarclayII) For RGCN on heterogeneous graphs we need to support GCN on bipartite.
            feat_src, feat_dst = expand_as_pair(feat, graph)

            if self._norm in ["left", "both"]:
                # Source-side normalization by out-degree, before aggregation.
                feat_src = feat_src * _degree_scale(graph.out_degrees(), feat_src)

            # Use the module's own weight unless an external one is supplied;
            # having both is rejected.
            if weight is None:
                weight = self.weight
            elif self.weight is not None:
                raise DGLError(
                    "External weight is provided while at the same time the"
                    " module has defined its own weight parameter. Please"
                    " create the module with flag weight=False."
                )

            # When the layer shrinks the feature size, multiply by W before
            # aggregating so the messages are smaller; otherwise aggregate
            # first and multiply afterwards.
            project_first = self._in_feats > self._out_feats
            if project_first and weight is not None:
                feat_src = th.matmul(feat_src, weight)
            graph.srcdata["h"] = feat_src
            graph.update_all(aggregate_fn, fn.sum(msg="m", out="h"))
            rst = graph.dstdata["h"]
            if not project_first and weight is not None:
                rst = th.matmul(rst, weight)

            if self._norm in ["right", "both"]:
                # Destination-side normalization by in-degree, after aggregation.
                rst = rst * _degree_scale(graph.in_degrees(), feat_dst)

            if self.bias is not None:
                rst = rst + self.bias

            if self._activation is not None:
                rst = self._activation(rst)

            return rst

    def extra_repr(self):
        """Set the extra representation of the module,
        which will come into effect when printing the model.
        """
        pieces = ["in={_in_feats}, out={_out_feats}", ", normalization={_norm}"]
        # The activation is only reported when it is stored in the instance
        # `__dict__`.
        if "_activation" in self.__dict__:
            pieces.append(", activation={_activation}")
        return "".join(pieces).format(**self.__dict__)