-
Notifications
You must be signed in to change notification settings - Fork 5
/
gnn_layers.py
525 lines (449 loc) · 21 KB
/
gnn_layers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
import numpy as np
from scipy import sparse
from scipy.sparse.linalg import eigsh
import tensorflow as tf
from tensorflow.keras import Model
from . import utils
from scipy.special import comb
class Chebyshev(Model):
    """
    A graph convolutional layer using the Chebyshev approximation.

    The input is expanded into the Chebyshev basis T_0(L) x, ..., T_{K-1}(L) x of the rescaled graph
    Laplacian L (recursion T_k = 2 L T_{k-1} - T_{k-2}) and the K basis signals are mixed by a learnable
    kernel of shape (K * Fin, Fout).
    """

    def __init__(self, L, K, Fout=None, initializer=None, activation=None, use_bias=False,
                 use_bn=False, n_matmul_splits=1, **kwargs):
        """
        Initializes the graph convolutional layer, assuming the input has dimension (B, M, F)
        :param L: The graph Laplacian (MxM), as numpy array or scipy sparse matrix. It is copied before it is
                  rescaled, so the same L can safely be shared between several layers.
        :param K: Order of the polynomial to use, i.e. the number of basis terms T_0 ... T_{K-1}
        :param Fout: Number of features (channels) of the output, default to number of input channels
        :param initializer: initializer to use for weight initialisation, defaults to a truncated normal
        :param activation: the activation function to use after the layer, either a callable or the name of a
                           function in tf.keras.activations, defaults to linear
        :param use_bias: Use learnable bias weights
        :param use_bn: Apply batch norm before adding the bias
        :param n_matmul_splits: Number of splits to apply to axis 1 of the dense tensor in the
                                tf.sparse.sparse_dense_matmul operations to avoid the operation's size limitation
        :param kwargs: additional keyword arguments passed on to add_weight
        :raises ValueError: if activation is a string that is not found in tf.keras.activations
        """
        # This is necessary for every Layer
        super(Chebyshev, self).__init__()

        # save necessary params
        self.L = L
        self.K = K
        self.Fout = Fout
        self.use_bias = use_bias
        self.use_bn = use_bn
        if self.use_bn:
            # center/scale are disabled: this only normalizes, the optional bias is added afterwards
            self.bn = tf.keras.layers.BatchNormalization(axis=-1, momentum=0.9, epsilon=1e-5, center=False, scale=False)
        self.initializer = initializer

        # resolve the activation: None / callable are used as-is, strings are looked up in tf.keras.activations
        if activation is None or callable(activation):
            self.activation = activation
        elif hasattr(tf.keras.activations, activation):
            self.activation = getattr(tf.keras.activations, activation)
        else:
            raise ValueError(f"Could not find activation <{activation}> in tf.keras.activations...")
        self.n_matmul_splits = n_matmul_splits
        self.kwargs = kwargs

        # Rescale Laplacian and store as a TF sparse tensor. copy=True is essential: csr_matrix() does NOT copy
        # an input that is already in CSR format, and utils.rescale_L may rescale in place, which would silently
        # modify the caller's (possibly shared) Laplacian.
        L = sparse.csr_matrix(L, copy=True)
        # 1.02 adds a small safety margin on top of the largest-magnitude eigenvalue estimate
        lmax = 1.02 * eigsh(L, k=1, which='LM', return_eigenvectors=False)[0]
        L = utils.rescale_L(L, lmax=lmax, scale=0.75)
        L = L.tocoo()
        indices = np.column_stack((L.row, L.col))
        L = tf.SparseTensor(indices, L.data, L.shape)
        # put the indices into canonical row-major order
        self.sparse_L = tf.sparse.reorder(L)

    def build(self, input_shape):
        """
        Build the weights of the layer (kernel and optional bias) and cast the Laplacian to the backend float type
        :param input_shape: shape of the input (batch, nodes, channels), the channel dimension has to be defined
        """
        # get the input shape
        Fin = int(input_shape[-1])
        # get Fout if necessary
        Fout = Fin if self.Fout is None else self.Fout

        if self.initializer is None:
            # Filter: Fin*Fout filters of order K, i.e. one filterbank per output feature.
            stddev = 1 / np.sqrt(Fin * (self.K + 0.5) / 2)
            initializer = tf.initializers.TruncatedNormal(stddev=stddev)
        else:
            initializer = self.initializer
        self.kernel = self.add_weight("kernel", shape=[self.K * Fin, Fout],
                                      initializer=initializer, **self.kwargs)

        if self.use_bias:
            self.bias = self.add_weight("bias", shape=[1, 1, Fout])

        # cast the sparse L to the current backend float type (float16/32/64) so the matmuls in call() match
        self.sparse_L = tf.cast(self.sparse_L, tf.as_dtype(tf.keras.backend.floatx()))

    def call(self, input_tensor, training=False):
        """
        Calls the layer on an input tensor
        :param input_tensor: input of the layer shape (batch, nodes, channels), nodes and channels must be static
        :param training: whether we are training or not (only used by the optional batch norm)
        :return: the output of the layer, shape (batch, nodes, Fout)
        """
        # shapes, this fun is necessary since sparse_matmul_dense in TF only supports
        # the multiplication of 2d matrices, therefore one has to do some weird reshaping
        # this is not strictly necessary but leads to a huge performance gain...
        # See: https://arxiv.org/pdf/1903.11409.pdf
        _, M, Fin = input_tensor.get_shape()
        M, Fin = int(M), int(Fin)
        # get Fout if necessary
        Fout = Fin if self.Fout is None else self.Fout

        # Transform to Chebyshev basis
        x0 = tf.transpose(input_tensor, perm=[1, 2, 0])  # M x Fin x N
        x0 = tf.reshape(x0, [M, -1])  # M x Fin*N

        # list for stacking: T_0 x, T_1 x, ..., T_{K-1} x
        stack = [x0]
        if self.K > 1:
            x1 = utils.split_sparse_dense_matmul(self.sparse_L, x0, self.n_matmul_splits)
            stack.append(x1)
            for _ in range(2, self.K):
                # T_k = 2 L T_{k-1} - T_{k-2}
                x2 = 2 * utils.split_sparse_dense_matmul(self.sparse_L, x1, self.n_matmul_splits) - x0  # M x Fin*N
                stack.append(x2)
                x0, x1 = x1, x2

        x = tf.stack(stack, axis=0)
        x = tf.reshape(x, [self.K, M, Fin, -1])  # K x M x Fin x N
        x = tf.transpose(x, perm=[3, 1, 2, 0])  # N x M x Fin x K
        x = tf.reshape(x, [-1, Fin * self.K])  # N*M x Fin*K

        # Filter: Fin*Fout filters of order K, i.e. one filterbank per output feature.
        x = tf.matmul(x, self.kernel)  # N*M x Fout
        x = tf.reshape(x, [-1, M, Fout])  # N x M x Fout

        if self.use_bn:
            x = self.bn(x, training=training)
        if self.use_bias:
            x = tf.add(x, self.bias)
        if self.activation is not None:
            x = self.activation(x)
        return x
class Monomial(Model):
    """
    A graph convolutional layer using Monomials.

    The input is expanded into the monomial basis x, L x, L^2 x, ..., L^{K-1} x of the rescaled graph
    Laplacian L and the K basis signals are mixed by a learnable kernel of shape (K * Fin, Fout).
    """

    def __init__(self, L, K, Fout=None, initializer=None, activation=None, use_bias=False,
                 use_bn=False, n_matmul_splits=1, **kwargs):
        """
        Initializes the graph convolutional layer, assuming the input has dimension (B, M, F)
        :param L: The graph Laplacian (MxM), as numpy array or scipy sparse matrix. It is copied before it is
                  rescaled, so the same L can safely be shared between several layers.
        :param K: Order of the polynomial to use, i.e. the number of basis terms L^0 ... L^{K-1}
        :param Fout: Number of features (channels) of the output, default to number of input channels
        :param initializer: initializer to use for weight initialisation, defaults to a truncated normal
        :param activation: the activation function to use after the layer, either a callable or the name of a
                           function in tf.keras.activations, defaults to linear
        :param use_bias: Use learnable bias weights
        :param use_bn: Apply batch norm before adding the bias
        :param n_matmul_splits: Number of splits to apply to axis 1 of the dense tensor in the
                                tf.sparse.sparse_dense_matmul operations to avoid the operation's size limitation
        :param kwargs: additional keyword arguments passed on to add_weight
        :raises ValueError: if activation is a string that is not found in tf.keras.activations
        """
        # This is necessary for every Layer
        super(Monomial, self).__init__()

        # save necessary params
        self.L = L
        self.K = K
        self.Fout = Fout
        self.use_bias = use_bias
        self.use_bn = use_bn
        if self.use_bn:
            # center/scale are disabled: this only normalizes, the optional bias is added afterwards
            self.bn = tf.keras.layers.BatchNormalization(axis=-1, momentum=0.9, epsilon=1e-5, center=False, scale=False)
        self.initializer = initializer

        # resolve the activation: None / callable are used as-is, strings are looked up in tf.keras.activations
        if activation is None or callable(activation):
            self.activation = activation
        elif hasattr(tf.keras.activations, activation):
            self.activation = getattr(tf.keras.activations, activation)
        else:
            raise ValueError(f"Could not find activation <{activation}> in tf.keras.activations...")
        self.n_matmul_splits = n_matmul_splits
        self.kwargs = kwargs

        # Rescale Laplacian and store as a TF sparse tensor. copy=True is essential: csr_matrix() does NOT copy
        # an input that is already in CSR format, and utils.rescale_L may rescale in place, which would silently
        # modify the caller's (possibly shared) Laplacian.
        L = sparse.csr_matrix(L, copy=True)
        # 1.02 adds a small safety margin on top of the largest-magnitude eigenvalue estimate
        lmax = 1.02 * eigsh(L, k=1, which='LM', return_eigenvectors=False)[0]
        L = utils.rescale_L(L, lmax=lmax)
        L = L.tocoo()
        indices = np.column_stack((L.row, L.col))
        L = tf.SparseTensor(indices, L.data, L.shape)
        # put the indices into canonical row-major order
        self.sparse_L = tf.sparse.reorder(L)

    def build(self, input_shape):
        """
        Build the weights of the layer (kernel and optional bias) and cast the Laplacian to the backend float type
        :param input_shape: shape of the input (batch, nodes, channels), the channel dimension has to be defined
        """
        # get the input shape
        Fin = int(input_shape[-1])
        # get Fout if necessary
        Fout = Fin if self.Fout is None else self.Fout

        if self.initializer is None:
            # Filter: Fin*Fout filters of order K, i.e. one filterbank per output feature.
            initializer = tf.initializers.TruncatedNormal(stddev=0.1)
        else:
            initializer = self.initializer
        self.kernel = self.add_weight("kernel", shape=[self.K * Fin, Fout],
                                      initializer=initializer, **self.kwargs)

        if self.use_bias:
            self.bias = self.add_weight("bias", shape=[1, 1, Fout])

        # cast the sparse L to the current backend float type (float16/32/64) so the matmuls in call() match
        self.sparse_L = tf.cast(self.sparse_L, tf.as_dtype(tf.keras.backend.floatx()))

    def call(self, input_tensor, training=False):
        """
        Calls the layer on an input tensor
        :param input_tensor: input of the layer shape (batch, nodes, channels), nodes and channels must be static
        :param training: whether we are training or not (only used by the optional batch norm)
        :return: the output of the layer, shape (batch, nodes, Fout)
        """
        # shapes, this fun is necessary since sparse_matmul_dense in TF only supports
        # the multiplication of 2d matrices, therefore one has to do some weird reshaping
        # this is not strictly necessary but leads to a huge performance gain...
        # See: https://arxiv.org/pdf/1903.11409.pdf
        _, M, Fin = input_tensor.get_shape()
        M, Fin = int(M), int(Fin)
        # get Fout if necessary
        Fout = Fin if self.Fout is None else self.Fout

        # Transform to monomial basis.
        x0 = tf.transpose(input_tensor, perm=[1, 2, 0])  # M x Fin x N
        x0 = tf.reshape(x0, [M, -1])  # M x Fin*N

        # list for stacking: x, L x, ..., L^{K-1} x, each power computed from the previous one
        stack = [x0]
        for _ in range(1, self.K):
            x0 = utils.split_sparse_dense_matmul(self.sparse_L, x0, self.n_matmul_splits)  # M x Fin*N
            stack.append(x0)

        x = tf.stack(stack, axis=0)
        x = tf.reshape(x, [self.K, M, Fin, -1])  # K x M x Fin x N
        x = tf.transpose(x, perm=[3, 1, 2, 0])  # N x M x Fin x K
        x = tf.reshape(x, [-1, Fin * self.K])  # N*M x Fin*K

        # Filter: Fin*Fout filters of order K, i.e. one filterbank per output feature.
        x = tf.matmul(x, self.kernel)  # N*M x Fout
        x = tf.reshape(x, [-1, M, Fout])  # N x M x Fout

        if self.use_bn:
            x = self.bn(x, training=training)
        if self.use_bias:
            x = tf.add(x, self.bias)
        if self.activation is not None:
            x = self.activation(x)
        return x
class GCNN_ResidualLayer(Model):
    """
    A generic residual layer of the form
    in -> layer -> layer -> out + alpha*in
    with optional normalization after each of the two graph layers
    """

    def __init__(self, layer_type, layer_kwargs, activation=None, act_before=False, use_bn=False,
                 norm_type="batch_norm", bn_kwargs=None, alpha=1.0):
        """
        Initializes the residual layer with the given argument
        :param layer_type: The layer type, either "CHEBY" or "MONO" for Chebyshev or monomials
        :param layer_kwargs: A dictionary with the inputs for the layer, used for both graph layers
        :param activation: activation function to use for the res layer, either a callable or the name of a
                           function in tf.keras.activations
        :param act_before: apply the activation before the skip connection instead of after it
        :param use_bn: use normalization after each of the two graph layers
        :param norm_type: type of normalization, either batch_norm for tf.keras.layers.BatchNormalization or
                          layer_norm for tf.keras.layers.LayerNormalization
        :param bn_kwargs: An optional dictionary containing further keyword arguments for the normalization layer.
                          It is copied, the caller's dictionary is never modified.
        :param alpha: Coupling strength of the input -> layer(input) + alpha*input
        :raises ValueError: if activation is an unknown name or norm_type is not understood (with use_bn)
        :raises IOError: if layer_type is not understood
        """
        # This is necessary for every Layer
        super(GCNN_ResidualLayer, self).__init__()

        # save variables
        self.layer_type = layer_type
        self.layer_kwargs = layer_kwargs
        if activation is None or callable(activation):
            self.activation = activation
        elif hasattr(tf.keras.activations, activation):
            self.activation = getattr(tf.keras.activations, activation)
        else:
            raise ValueError(f"Could not find activation <{activation}> in tf.keras.activations...")
        self.act_before = act_before
        self.use_bn = use_bn
        self.norm_type = norm_type

        # set the default axis if necessary. Work on a copy so the caller's dictionary (which may be reused
        # for several residual layers) is never mutated.
        self.bn_kwargs = {"axis": -1} if bn_kwargs is None else dict(bn_kwargs)
        if "axis" not in self.bn_kwargs and norm_type != "moving_norm":
            self.bn_kwargs["axis"] = -1

        if self.layer_type == "CHEBY":
            self.layer1 = Chebyshev(**self.layer_kwargs)
            self.layer2 = Chebyshev(**self.layer_kwargs)
        elif self.layer_type == "MONO":
            self.layer1 = Monomial(**self.layer_kwargs)
            self.layer2 = Monomial(**self.layer_kwargs)
        else:
            raise IOError(f"Layertype not understood: {self.layer_type}")

        if use_bn:
            if norm_type == "layer_norm":
                self.bn1 = tf.keras.layers.LayerNormalization(**self.bn_kwargs)
                self.bn2 = tf.keras.layers.LayerNormalization(**self.bn_kwargs)
            elif norm_type == "batch_norm":
                self.bn1 = tf.keras.layers.BatchNormalization(**self.bn_kwargs)
                self.bn2 = tf.keras.layers.BatchNormalization(**self.bn_kwargs)
            else:
                raise ValueError(f"norm_type <{norm_type}> not understood!")

        self.alpha = alpha

    def call(self, input_tensor, training=False):
        """
        Calls the layer on an input tensor
        :param input_tensor: The input of the layer
        :param training: whether we are training or not, forwarded to the graph and normalization layers
        :return: the output of the layer, layer2(layer1(input)) + alpha*input with the configured activation
        """
        # 1st layer (+ optional normalization)
        x = self.layer1(input_tensor, training=training)
        if self.use_bn:
            x = self.bn1(x, training=training)

        # 2nd layer (+ optional normalization)
        x = self.layer2(x, training=training)
        if self.use_bn:
            x = self.bn2(x, training=training)

        # skip connection, alpha is applied in every branch (it used to be ignored without an activation)
        skip = self.alpha * input_tensor
        if self.activation is None:
            return x + skip
        if self.act_before:
            return self.activation(x) + skip
        return self.activation(x + skip)
class Bernstein(Model):
    """
    A graph convolutional layer using the Bernstein approximation
    see https://arxiv.org/abs/2106.10994

    The input is expanded into the K+1 Bernstein basis signals
    C(K, i) / 2^K * (2I - L)^(K-i) L^i x, for i = 0..K,
    of the rescaled graph Laplacian L and mixed by a learnable kernel of shape ((K+1) * Fin, Fout).
    """

    def __init__(self, L, K, Fout=None, initializer=None, activation=None, use_bias=False,
                 use_bn=False, n_matmul_splits=1, **kwargs):
        """
        Initializes the graph convolutional layer, assuming the input has dimension (B, M, F)
        :param L: The graph Laplacian (MxM), as numpy array or scipy sparse matrix. It is copied before it is
                  rescaled, so the same L can safely be shared between several layers.
        :param K: Order of the polynomial to use (the layer uses K+1 basis terms)
        :param Fout: Number of features (channels) of the output, default to number of input channels
        :param initializer: initializer to use for weight initialisation, defaults to a truncated normal
        :param activation: the activation function to use after the layer, either a callable or the name of a
                           function in tf.keras.activations, defaults to linear
        :param use_bias: Use learnable bias weights
        :param use_bn: Apply batch norm before adding the bias
        :param n_matmul_splits: Number of splits to apply to axis 1 of the dense tensor in the
                                tf.sparse.sparse_dense_matmul operations to avoid the operation's size limitation
        :param kwargs: additional keyword arguments passed on to add_weight
        :raises ValueError: if activation is a string that is not found in tf.keras.activations
        """
        # This is necessary for every Layer
        super(Bernstein, self).__init__()

        # save necessary params
        self.L = L
        self.K = K
        self.Fout = Fout
        self.use_bias = use_bias
        self.use_bn = use_bn
        if self.use_bn:
            # center/scale are disabled: this only normalizes, the optional bias is added afterwards
            self.bn = tf.keras.layers.BatchNormalization(axis=-1, momentum=0.9, epsilon=1e-5, center=False, scale=False)
        self.initializer = initializer

        # resolve the activation: None / callable are used as-is, strings are looked up in tf.keras.activations
        if activation is None or callable(activation):
            self.activation = activation
        elif hasattr(tf.keras.activations, activation):
            self.activation = getattr(tf.keras.activations, activation)
        else:
            raise ValueError(f"Could not find activation <{activation}> in tf.keras.activations...")
        self.n_matmul_splits = n_matmul_splits
        self.kwargs = kwargs

        # Rescale Laplacian and store as a TF sparse tensor. copy=True is essential: csr_matrix() does NOT copy
        # an input that is already in CSR format, and utils.rescale_L may rescale in place, which would silently
        # modify the caller's (possibly shared) Laplacian.
        # NOTE(review): the same scale=0.75 rescaling as the Chebyshev layer is used here, whereas the BernNet
        # basis is formulated for a spectrum in [0, 2]; kept as-is, confirm this is intended.
        L = sparse.csr_matrix(L, copy=True)
        # 1.02 adds a small safety margin on top of the largest-magnitude eigenvalue estimate
        lmax = 1.02 * eigsh(L, k=1, which='LM', return_eigenvectors=False)[0]
        L = utils.rescale_L(L, lmax=lmax, scale=0.75)
        L = L.tocoo()
        indices = np.column_stack((L.row, L.col))
        L = tf.SparseTensor(indices, L.data, L.shape)
        # put the indices into canonical row-major order
        self.sparse_L = tf.sparse.reorder(L)

    def build(self, input_shape):
        """
        Build the weights of the layer (kernel and optional bias) and cast the Laplacian to the backend float type
        :param input_shape: shape of the input (batch, nodes, channels), the channel dimension has to be defined
        """
        # get the input shape
        Fin = int(input_shape[-1])
        # get Fout if necessary
        Fout = Fin if self.Fout is None else self.Fout

        if self.initializer is None:
            # Filter: Fin*Fout filters of order K, i.e. one filterbank per output feature.
            stddev = np.sqrt(6 / (Fin + Fout))
            initializer = tf.initializers.TruncatedNormal(stddev=stddev)
        else:
            initializer = self.initializer
        self.kernel = self.add_weight("kernel", shape=[(self.K + 1) * Fin, Fout],
                                      initializer=initializer, **self.kwargs)

        if self.use_bias:
            self.bias = self.add_weight("bias", shape=[1, 1, Fout])

        # cast the sparse L to the current backend float type (float16/32/64) so the matmuls in call() match
        self.sparse_L = tf.cast(self.sparse_L, tf.as_dtype(tf.keras.backend.floatx()))

    def call(self, input_tensor, training=False, *args, **kwargs):
        """
        Calls the layer on an input tensor
        :param input_tensor: input of the layer shape (batch, nodes, channels), nodes and channels must be static
        :param training: whether we are training or not (only used by the optional batch norm)
        :param args: further arguments (ignored)
        :param kwargs: further keyword arguments (ignored)
        :return: the output of the layer, shape (batch, nodes, Fout)
        """
        # shapes, this fun is necessary since sparse_matmul_dense in TF only supports
        # the multiplication of 2d matrices, therefore one has to do some weird reshaping
        # this is not strictly necessary but leads to a huge performance gain...
        # See: https://arxiv.org/pdf/1903.11409.pdf
        _, M, Fin = input_tensor.get_shape()
        M, Fin = int(M), int(Fin)
        # get Fout if necessary
        Fout = Fin if self.Fout is None else self.Fout

        # Transform to Bernstein basis
        x0 = tf.transpose(input_tensor, perm=[1, 2, 0])  # M x Fin x N
        x0 = tf.reshape(x0, [M, -1])  # M x Fin*N

        # powers L^i x0 for i = 0..K, each computed from the previous one (K sparse matmuls in total)
        powers = [x0]
        for _ in range(self.K):
            powers.append(utils.split_sparse_dense_matmul(self.sparse_L, powers[-1], self.n_matmul_splits))

        # basis term i: C(K, i) / 2^K * (2I - L)^(K-i) L^i x0. The scaled term is taken from xi AFTER the
        # (2I - L) loop, so for i == K (where the loop body never runs) the term is L^K x0 itself.
        stack = []
        for i, xi in enumerate(powers):
            theta = float(comb(self.K, i)) / (2 ** self.K)
            for _ in range(self.K - i):
                xi = 2 * xi - utils.split_sparse_dense_matmul(self.sparse_L, xi, self.n_matmul_splits)
            stack.append(theta * xi)

        x = tf.stack(stack, axis=0)
        x = tf.reshape(x, [(self.K + 1), M, Fin, -1])  # K+1 x M x Fin x N
        x = tf.transpose(x, perm=[3, 1, 2, 0])  # N x M x Fin x K+1
        x = tf.reshape(x, [-1, Fin * (self.K + 1)])  # N*M x Fin*(K+1)

        # Filter: Fin*Fout filters of order K, i.e. one filterbank per output feature.
        x = tf.matmul(x, self.kernel)  # N*M x Fout
        x = tf.reshape(x, [-1, M, Fout])  # N x M x Fout

        if self.use_bn:
            x = self.bn(x, training=training)
        if self.use_bias:
            x = tf.add(x, self.bias)
        if self.activation is not None:
            x = self.activation(x)
        return x