-
Notifications
You must be signed in to change notification settings - Fork 335
/
gcn.py
87 lines (71 loc) · 2.64 KB
/
gcn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import tensorflow as tf
from spektral.layers.convolutional import gcn_conv
class GCN(tf.keras.Model):
    """
    Two-layer Graph Convolutional Network with the default hyperparameters of:
    > [Semi-Supervised Classification with Graph Convolutional Networks](https://arxiv.org/abs/1609.02907)<br>
    > Thomas N. Kipf and Max Welling

    **Mode**: single, disjoint, mixed, batch.

    **Input**
    - Node features, shape `([batch], n_nodes, n_node_features)`;
    - Weighted adjacency matrix, shape `([batch], n_nodes, n_nodes)`.

    **Output**
    - Softmax predictions, shape `([batch], n_nodes, n_labels)`.

    **Arguments**
    - `n_labels`: number of output channels;
    - `channels`: width of the first GCNConv layer;
    - `activation`: activation applied by the first GCNConv layer;
    - `output_activation`: activation applied by the second GCNConv layer;
    - `use_bias`: whether both GCNConv layers learn an additive bias;
    - `dropout_rate`: `rate` for the two `Dropout` layers;
    - `l2_reg`: l2 regularization strength (first GCNConv kernel only,
      as in the reference implementation);
    - `**kwargs`: forwarded to `Model.__init__`.
    """

    def __init__(
        self,
        n_labels,
        channels=16,
        activation="relu",
        output_activation="softmax",
        use_bias=False,
        dropout_rate=0.5,
        l2_reg=2.5e-4,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Keep every constructor argument on the instance so get_config()
        # can round-trip the model.
        self.n_labels = n_labels
        self.channels = channels
        self.activation = activation
        self.output_activation = output_activation
        self.use_bias = use_bias
        self.dropout_rate = dropout_rate
        self.l2_reg = l2_reg

        regularizer = tf.keras.regularizers.l2(l2_reg)
        self._d0 = tf.keras.layers.Dropout(dropout_rate)
        self._gcn0 = gcn_conv.GCNConv(
            channels,
            activation=activation,
            kernel_regularizer=regularizer,
            use_bias=use_bias,
        )
        self._d1 = tf.keras.layers.Dropout(dropout_rate)
        # Note: no kernel regularizer on the output layer, by design.
        self._gcn1 = gcn_conv.GCNConv(
            n_labels, activation=output_activation, use_bias=use_bias
        )

    def get_config(self):
        # Only the constructor arguments; the base Model config is omitted
        # on purpose (subclassed models may not support it on all TF versions).
        return {
            "n_labels": self.n_labels,
            "channels": self.channels,
            "activation": self.activation,
            "output_activation": self.output_activation,
            "use_bias": self.use_bias,
            "dropout_rate": self.dropout_rate,
            "l2_reg": self.l2_reg,
        }

    def call(self, inputs):
        if len(inputs) == 2:
            x, a = inputs
        else:
            # DisjointLoader also yields the graph-index vector; discard it.
            x, a, _ = inputs

        hidden = self._gcn0([self._d0(x), a])
        return self._gcn1([self._d1(hidden), a])