-
Notifications
You must be signed in to change notification settings - Fork 334
/
global_pool.py
459 lines (363 loc) · 13.9 KB
/
global_pool.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras import constraints, initializers, regularizers
from tensorflow.keras.layers import Dense, Layer
from spektral.layers import ops
class GlobalPool(Layer):
    """
    Base class for the parameter-free global pooling layers.

    Subclasses must set, after calling this constructor:

    - `self.pooling_op`: a segment reduction, called as `pooling_op(X, I)` in
    disjoint mode;
    - `self.batch_pooling_op`: a dense reduction, called as
    `batch_pooling_op(X, axis=-2, keepdims=...)` in single and batch mode.

    The data mode is inferred in `build()` from the input shape:

    - a list of two shapes (node features, graph IDs) -> `"disjoint"`;
    - a rank-2 shape -> `"single"`;
    - any other shape -> `"batch"`.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.supports_masking = True
        # Filled in by concrete subclasses.
        self.pooling_op = None
        self.batch_pooling_op = None

    def build(self, input_shape):
        if isinstance(input_shape, list) and len(input_shape) == 2:
            self.data_mode = "disjoint"
        else:
            if len(input_shape) == 2:
                self.data_mode = "single"
            else:
                self.data_mode = "batch"
        super().build(input_shape)

    def call(self, inputs):
        if self.data_mode == "disjoint":
            X = inputs[0]
            I = inputs[1]
            # Graph IDs given as a column vector are reduced to shape (n_nodes, ).
            if K.ndim(I) == 2:
                I = I[:, 0]
            return self.pooling_op(X, I)
        # Reduce over the node axis; single mode keeps a leading axis of size 1.
        return self.batch_pooling_op(
            inputs, axis=-2, keepdims=(self.data_mode == "single")
        )

    def compute_output_shape(self, input_shape):
        if self.data_mode == "single":
            return (1,) + input_shape[-1:]
        elif self.data_mode == "batch":
            return input_shape[:-2] + input_shape[-1:]
        else:
            # Disjoint: input_shape is [X_shape, I_shape]. The output has one row
            # per graph, whose count is not known statically (it is NOT the
            # number of nodes, which is what X_shape's leading dim holds).
            x_shape = tf.TensorShape(input_shape[0])
            return tf.TensorShape([None]).concatenate(x_shape[-1:])
class GlobalSumPool(GlobalPool):
    """
    A global sum pooling layer. Each graph is reduced to a single vector by
    adding up the features of all of its nodes.

    **Mode**: single, disjoint, mixed, batch.

    **Input**

    - Node features of shape `([batch], n_nodes, n_node_features)`;
    - Graph IDs of shape `(n_nodes, )` (only in disjoint mode);

    **Output**

    - Pooled node features of shape `(batch, n_node_features)` (if single mode, shape will
    be `(1, n_node_features)`).

    **Arguments**

    None.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Segment reduction for disjoint mode, dense reduction otherwise.
        self.pooling_op, self.batch_pooling_op = tf.math.segment_sum, tf.reduce_sum
class GlobalAvgPool(GlobalPool):
    """
    An average pooling layer. Each graph is reduced to a single vector by
    taking the mean of the features of all of its nodes.

    **Mode**: single, disjoint, mixed, batch.

    **Input**

    - Node features of shape `([batch], n_nodes, n_node_features)`;
    - Graph IDs of shape `(n_nodes, )` (only in disjoint mode);

    **Output**

    - Pooled node features of shape `(batch, n_node_features)` (if single mode, shape will
    be `(1, n_node_features)`).

    **Arguments**

    None.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Segment reduction for disjoint mode, dense reduction otherwise.
        self.pooling_op, self.batch_pooling_op = tf.math.segment_mean, tf.reduce_mean
class GlobalMaxPool(GlobalPool):
    """
    A max pooling layer. Each graph is reduced to a single vector by taking
    the feature-wise maximum over all of its nodes.

    **Mode**: single, disjoint, mixed, batch.

    **Input**

    - Node features of shape `([batch], n_nodes, n_node_features)`;
    - Graph IDs of shape `(n_nodes, )` (only in disjoint mode);

    **Output**

    - Pooled node features of shape `(batch, n_node_features)` (if single mode, shape will
    be `(1, n_node_features)`).

    **Arguments**

    None.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Segment reduction for disjoint mode, dense reduction otherwise.
        self.pooling_op, self.batch_pooling_op = tf.math.segment_max, tf.reduce_max
class GlobalAttentionPool(GlobalPool):
    r"""
    A gated attention global pooling layer from the paper

    > [Gated Graph Sequence Neural Networks](https://arxiv.org/abs/1511.05493)<br>
    > Yujia Li et al.

    This layer computes:
    $$
        \X' = \sum\limits_{i=1}^{N} (\sigma(\X \W_1 + \b_1) \odot (\X \W_2 + \b_2))_i
    $$
    where \(\sigma\) is the sigmoid activation function.

    **Mode**: single, disjoint, mixed, batch.

    **Input**

    - Node features of shape `([batch], n_nodes, n_node_features)`;
    - Graph IDs of shape `(n_nodes, )` (only in disjoint mode);

    **Output**

    - Pooled node features of shape `(batch, channels)` (if single mode,
    shape will be `(1, channels)`).

    **Arguments**

    - `channels`: integer, number of output channels;
    - `kernel_initializer`: initializer for the kernel matrices;
    - `bias_initializer`: initializer for the bias vectors;
    - `kernel_regularizer`: regularization applied to the kernel matrices;
    - `bias_regularizer`: regularization applied to the bias vectors;
    - `kernel_constraint`: constraint applied to the kernel matrices;
    - `bias_constraint`: constraint applied to the bias vectors.
    """

    def __init__(
        self,
        channels,
        kernel_initializer="glorot_uniform",
        bias_initializer="zeros",
        kernel_regularizer=None,
        bias_regularizer=None,
        kernel_constraint=None,
        bias_constraint=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.channels = channels
        self.kernel_initializer = initializers.get(kernel_initializer)
        self.bias_initializer = initializers.get(bias_initializer)
        self.kernel_regularizer = regularizers.get(kernel_regularizer)
        self.bias_regularizer = regularizers.get(bias_regularizer)
        self.kernel_constraint = constraints.get(kernel_constraint)
        self.bias_constraint = constraints.get(bias_constraint)

    def build(self, input_shape):
        # GlobalPool.build infers self.data_mode from the input shape.
        super().build(input_shape)
        layer_kwargs = dict(
            kernel_initializer=self.kernel_initializer,
            bias_initializer=self.bias_initializer,
            kernel_regularizer=self.kernel_regularizer,
            bias_regularizer=self.bias_regularizer,
            kernel_constraint=self.kernel_constraint,
            bias_constraint=self.bias_constraint,
        )
        # Linear projection of node features: X W_2 + b_2.
        self.features_layer = Dense(
            self.channels, name="features_layer", **layer_kwargs
        )
        # Sigmoid gate: sigma(X W_1 + b_1).
        self.attention_layer = Dense(
            self.channels, activation="sigmoid", name="attn_layer", **layer_kwargs
        )
        self.built = True

    def call(self, inputs):
        if self.data_mode == "disjoint":
            X, I = inputs
            # Graph IDs given as a column vector are reduced to shape (n_nodes, ).
            if K.ndim(I) == 2:
                I = I[:, 0]
        else:
            X = inputs
        inputs_linear = self.features_layer(X)
        attn = self.attention_layer(X)
        masked_inputs = inputs_linear * attn
        if self.data_mode in {"single", "batch"}:
            # Sum over the node axis; single mode keeps a leading axis of size 1.
            output = K.sum(masked_inputs, axis=-2, keepdims=self.data_mode == "single")
        else:
            output = tf.math.segment_sum(masked_inputs, I)
        return output

    def compute_output_shape(self, input_shape):
        if self.data_mode == "single":
            return (1,) + (self.channels,)
        elif self.data_mode == "batch":
            return input_shape[:-2] + (self.channels,)
        else:
            # Disjoint: one row per graph. The number of graphs is not known
            # statically and differs from the number of nodes in X.
            return (None, self.channels)

    def get_config(self):
        # Serialize the Keras objects so the config is JSON-compatible and can
        # be fed back to __init__ (the `*.get` helpers accept serialized dicts).
        config = {
            "channels": self.channels,
            "kernel_initializer": initializers.serialize(self.kernel_initializer),
            "bias_initializer": initializers.serialize(self.bias_initializer),
            "kernel_regularizer": regularizers.serialize(self.kernel_regularizer),
            "bias_regularizer": regularizers.serialize(self.bias_regularizer),
            "kernel_constraint": constraints.serialize(self.kernel_constraint),
            "bias_constraint": constraints.serialize(self.bias_constraint),
        }
        base_config = super().get_config()
        return dict(list(base_config.items()) + list(config.items()))
class GlobalAttnSumPool(GlobalPool):
    r"""
    A node-attention global pooling layer. Pools a graph by learning attention
    coefficients to sum node features.

    This layer computes:
    $$
        \alpha = \textrm{softmax}( \X \a); \\
        \X' = \sum\limits_{i=1}^{N} \alpha_i \cdot \X_i
    $$
    where \(\a \in \mathbb{R}^F\) is a trainable vector. Note that the softmax
    is applied across nodes, and not across features.

    **Mode**: single, disjoint, mixed, batch.

    **Input**

    - Node features of shape `([batch], n_nodes, n_node_features)`;
    - Graph IDs of shape `(n_nodes, )` (only in disjoint mode);

    **Output**

    - Pooled node features of shape `(batch, n_node_features)` (if single mode, shape will
    be `(1, n_node_features)`).

    **Arguments**

    - `attn_kernel_initializer`: initializer for the attention weights;
    - `attn_kernel_regularizer`: regularization applied to the attention kernel
    matrix;
    - `attn_kernel_constraint`: constraint applied to the attention kernel
    matrix.
    """

    def __init__(
        self,
        attn_kernel_initializer="glorot_uniform",
        attn_kernel_regularizer=None,
        attn_kernel_constraint=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.attn_kernel_initializer = initializers.get(attn_kernel_initializer)
        self.attn_kernel_regularizer = regularizers.get(attn_kernel_regularizer)
        self.attn_kernel_constraint = constraints.get(attn_kernel_constraint)

    def build(self, input_shape):
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if len(input_shape) < 2:
            raise ValueError(
                "GlobalAttnSumPool expects node features of rank >= 2, or a list "
                "of [node features, graph IDs]; got input shape {}".format(
                    input_shape
                )
            )
        # GlobalPool.build infers self.data_mode from the input shape.
        super().build(input_shape)
        if self.data_mode == "disjoint":
            F = input_shape[0][-1]
        else:
            F = input_shape[-1]
        # Attention kernel: one scalar score per node, a = (F, 1).
        self.attn_kernel = self.add_weight(
            shape=(F, 1),
            initializer=self.attn_kernel_initializer,
            regularizer=self.attn_kernel_regularizer,
            constraint=self.attn_kernel_constraint,
            name="attn_kernel",
        )
        self.built = True

    def call(self, inputs):
        if self.data_mode == "disjoint":
            X, I = inputs
            # Graph IDs given as a column vector are reduced to shape (n_nodes, ).
            if K.ndim(I) == 2:
                I = I[:, 0]
        else:
            X = inputs
        # Per-node attention scores, with the trailing size-1 axis dropped.
        attn_coeff = K.dot(X, self.attn_kernel)
        attn_coeff = K.squeeze(attn_coeff, -1)
        if self.data_mode == "single":
            attn_coeff = K.softmax(attn_coeff)
            output = K.dot(attn_coeff[None, ...], X)
        elif self.data_mode == "batch":
            # Softmax over the last axis of (batch, n_nodes), i.e. across nodes.
            attn_coeff = K.softmax(attn_coeff)
            output = K.batch_dot(attn_coeff, X)
        else:
            # NOTE(review): presumably normalizes scores within each graph ID;
            # the segment count passed is the node count, an upper bound on
            # the number of graphs.
            attn_coeff = ops.unsorted_segment_softmax(attn_coeff, I, K.shape(X)[0])
            output = attn_coeff[:, None] * X
            output = tf.math.segment_sum(output, I)
        return output

    def get_config(self):
        # Serialize the Keras objects so the config is JSON-compatible and can
        # be fed back to __init__ (the `*.get` helpers accept serialized dicts).
        config = {
            "attn_kernel_initializer": initializers.serialize(
                self.attn_kernel_initializer
            ),
            "attn_kernel_regularizer": regularizers.serialize(
                self.attn_kernel_regularizer
            ),
            "attn_kernel_constraint": constraints.serialize(
                self.attn_kernel_constraint
            ),
        }
        base_config = super().get_config()
        return dict(list(base_config.items()) + list(config.items()))
class SortPool(Layer):
    r"""
    A SortPool layer as described by
    [Zhang et al](https://www.cse.wustl.edu/~muhan/papers/AAAI_2018_DGCNN.pdf).

    This layer takes a graph signal \(\mathbf{X}\) and returns the topmost k
    rows according to the last column.
    If \(\mathbf{X}\) has fewer than k rows, the result is zero-padded to k.

    **Mode**: single, disjoint, batch.

    **Input**

    - Node features of shape `([batch], n_nodes, n_node_features)`;
    - Graph IDs of shape `(n_nodes, )` (only in disjoint mode);

    **Output**

    - Pooled node features of shape `(batch, k, n_node_features)` (if single mode, shape will
    be `(k, n_node_features)`).

    **Arguments**

    - `k`: integer, number of nodes to keep;
    """

    def __init__(self, k, **kwargs):
        super().__init__(**kwargs)
        k = int(k)
        if k <= 0:
            raise ValueError("K must be a positive integer")
        self.k = k

    def build(self, input_shape):
        if isinstance(input_shape, list) and len(input_shape) == 2:
            self.data_mode = "disjoint"
            self.F = input_shape[0][-1]
        else:
            if len(input_shape) == 2:
                self.data_mode = "single"
            else:
                self.data_mode = "batch"
            self.F = input_shape[-1]

    def call(self, inputs):
        if self.data_mode == "disjoint":
            X, I = inputs
            # Accept graph IDs as a column vector, like the other pooling layers;
            # disjoint_signal_to_batch works on rank-1 segment IDs.
            if K.ndim(I) == 2:
                I = I[:, 0]
            X = ops.disjoint_signal_to_batch(X, I)
        else:
            X = inputs
            if self.data_mode == "single":
                # Add a batch axis so all modes share the code below.
                X = tf.expand_dims(X, 0)

        N = tf.shape(X)[-2]
        # Sort the nodes of each graph by their last feature, descending.
        sort_perm = tf.argsort(X[..., -1], direction="DESCENDING")
        X_sorted = tf.gather(X, sort_perm, axis=-2, batch_dims=1)

        def truncate():
            # Keep the top-k nodes.
            _X_out = X_sorted[..., : self.k, :]
            return _X_out

        def pad():
            # Zero-pad the node axis up to k rows.
            padding = [[0, 0], [0, self.k - N], [0, 0]]
            _X_out = tf.pad(X_sorted, padding)
            return _X_out

        X_out = tf.cond(tf.less_equal(self.k, N), truncate, pad)

        if self.data_mode == "single":
            X_out = tf.squeeze(X_out, [0])
            X_out.set_shape((self.k, self.F))
        elif self.data_mode == "batch" or self.data_mode == "disjoint":
            X_out.set_shape((None, self.k, self.F))

        return X_out

    def get_config(self):
        config = {
            "k": self.k,
        }
        base_config = super().get_config()
        return dict(list(base_config.items()) + list(config.items()))

    def compute_output_shape(self, input_shape):
        if self.data_mode == "single":
            return self.k, self.F
        elif self.data_mode == "batch":
            return input_shape[0], self.k, self.F
        elif self.data_mode == "disjoint":
            # input_shape is [X_shape, I_shape] here, so input_shape[0] is a
            # whole shape, not a batch size; the number of graphs is unknown.
            return None, self.k, self.F
# Registry of string identifiers accepted by `get`, mapped to layer classes.
layers = dict(
    sum=GlobalSumPool,
    avg=GlobalAvgPool,
    max=GlobalMaxPool,
    attn=GlobalAttentionPool,
    attn_sum=GlobalAttnSumPool,
    sort=SortPool,
)
def get(identifier):
    """
    Look up a global pooling layer class by its string identifier.

    :param identifier: one of the keys of the module-level `layers` dict.
    :return: the corresponding layer class (not an instance).
    :raises ValueError: if `identifier` is not a registered key.
    """
    if identifier in layers:
        return layers[identifier]
    available = list(layers.keys())
    raise ValueError(
        "Unknown identifier {}. Available: {}".format(identifier, available)
    )