|
| 1 | +import tensorflow as tf |
| 2 | +import tensorflow.keras.backend as K |
| 3 | +from tensorflow.keras import Sequential |
| 4 | +from tensorflow.keras.layers import Dense |
| 5 | + |
| 6 | +from spektral.layers import ops |
| 7 | +from spektral.layers.pooling.src import SRCPool |
| 8 | + |
| 9 | + |
| 10 | +class AsymCheegerCutPool(SRCPool): |
| 11 | + r""" |
| 12 | + An Asymmetric Cheeger Cut Pooling layer from the paper |
| 13 | + > [Total Variation Graph Neural Networks](https://arxiv.org/abs/2211.06218)<br> |
| 14 | + > Jonas Berg Hansen and Filippo Maria Bianchi |
| 15 | +
|
| 16 | + **Mode**: single, batch. |
| 17 | +
|
| 18 | + This layer learns a soft clustering of the input graph as follows: |
| 19 | + $$ |
| 20 | + \begin{align} |
| 21 | + \S &= \textrm{MLP}(\X); \\ |
| 22 | + \X' &= \S^\top \X \\ |
| 23 | + \A' &= \S^\top \A \S; \\ |
| 24 | + \end{align} |
| 25 | + $$ |
| 26 | + where \(\textrm{MLP}\) is a multi-layer perceptron with softmax output. |
| 27 | +
|
| 28 | + The layer includes two auxiliary loss terms/components: |
| 29 | + A graph total variation component given by |
| 30 | + $$ |
| 31 | + L_\text{GTV} = \frac{1}{2E} \sum_{k=1}^K \sum_{i=1}^N \sum_{j=i}^N a_{i,j} |s_{i,k} - s_{j,k}|, |
| 32 | + $$ |
| 33 | + where \(E\) is the number of edges/links, \(K\) is the number of clusters or output nodes, and \(N\) is the number of nodes. |
| 34 | + |
| 35 | + An asymmetrical norm component given by |
| 36 | + $$ |
| 37 | + L_\text{AN} = \frac{N(K - 1) - \sum_{k=1}^K ||\s_{:,k} - \textrm{quant}_{K-1} (\s_{:,k})||_{1, K-1}}{N(K-1)}, |
| 38 | + $$ |
| 39 | +
|
| 40 | + The layer can be used without a supervised loss to compute node clustering by |
| 41 | + minimizing the two auxiliary losses. |
| 42 | +
|
| 43 | + **Input** |
| 44 | +
|
| 45 | + - Node features of shape `(batch, n_nodes_in, n_node_features)`; |
| 46 | + - Adjacency matrix of shape `(batch, n_nodes_in, n_nodes_in)`; |
| 47 | +
|
| 48 | + **Output** |
| 49 | +
|
| 50 | + - Reduced node features of shape `(batch, n_nodes_out, n_node_features)`; |
| 51 | + - If `return_selection=True`, the selection matrix of shape |
| 52 | + `(batch, n_nodes_in, n_nodes_out)`. |
| 53 | +
|
| 54 | + **Arguments** |
| 55 | +
|
| 56 | + - `k`: number of output nodes; |
| 57 | + - `mlp_hidden`: list of integers, number of hidden units for each hidden layer in |
| 58 | + the MLP used to compute cluster assignments (if `None`, the MLP has only one output |
| 59 | + layer); |
| 60 | + - `mlp_activation`: activation for the MLP layers; |
| 61 | + - `totvar_coeff`: coefficient for graph total variation loss component; |
| 62 | + - `balance_coeff`: coefficient for asymmetric norm loss component; |
| 63 | + - `return_selection`: boolean, whether to return the selection matrix; |
| 64 | + - `use_bias`: use bias in the MLP; |
| 65 | + - `kernel_initializer`: initializer for the weights of the MLP; |
| 66 | + - `bias_regularizer`: regularization applied to the bias of the MLP; |
| 67 | + - `kernel_constraint`: constraint applied to the weights of the MLP; |
| 68 | + - `bias_constraint`: constraint applied to the bias of the MLP; |
| 69 | + """ |
| 70 | + |
| 71 | + def __init__( |
| 72 | + self, |
| 73 | + k, |
| 74 | + mlp_hidden=None, |
| 75 | + mlp_activation="relu", |
| 76 | + totvar_coeff=1.0, |
| 77 | + balance_coeff=1.0, |
| 78 | + return_selection=False, |
| 79 | + use_bias=True, |
| 80 | + kernel_initializer="glorot_uniform", |
| 81 | + bias_initializer="zeros", |
| 82 | + kernel_regularizer=None, |
| 83 | + bias_regularizer=None, |
| 84 | + kernel_constraint=None, |
| 85 | + bias_constraint=None, |
| 86 | + **kwargs |
| 87 | + ): |
| 88 | + super().__init__( |
| 89 | + k=k, |
| 90 | + mlp_hidden=mlp_hidden, |
| 91 | + mlp_activation=mlp_activation, |
| 92 | + return_selection=return_selection, |
| 93 | + use_bias=use_bias, |
| 94 | + kernel_initializer=kernel_initializer, |
| 95 | + bias_initializer=bias_initializer, |
| 96 | + kernel_regularizer=kernel_regularizer, |
| 97 | + bias_regularizer=bias_regularizer, |
| 98 | + kernel_constraint=kernel_constraint, |
| 99 | + bias_constraint=bias_constraint, |
| 100 | + **kwargs |
| 101 | + ) |
| 102 | + |
| 103 | + self.k = k |
| 104 | + self.mlp_hidden = mlp_hidden if mlp_hidden else [] |
| 105 | + self.mlp_activation = mlp_activation |
| 106 | + self.totvar_coeff = totvar_coeff |
| 107 | + self.balance_coeff = balance_coeff |
| 108 | + |
| 109 | + def build(self, input_shape): |
| 110 | + layer_kwargs = dict( |
| 111 | + kernel_initializer=self.kernel_initializer, |
| 112 | + bias_initializer=self.bias_initializer, |
| 113 | + kernel_regularizer=self.kernel_regularizer, |
| 114 | + bias_regularizer=self.bias_regularizer, |
| 115 | + kernel_constraint=self.kernel_constraint, |
| 116 | + bias_constraint=self.bias_constraint, |
| 117 | + ) |
| 118 | + self.mlp = Sequential( |
| 119 | + [ |
| 120 | + Dense(channels, self.mlp_activation, **layer_kwargs) |
| 121 | + for channels in self.mlp_hidden |
| 122 | + ] |
| 123 | + + [Dense(self.k, "softmax", **layer_kwargs)] |
| 124 | + ) |
| 125 | + |
| 126 | + super().build(input_shape) |
| 127 | + |
| 128 | + def call(self, inputs, mask=None): |
| 129 | + x, a, i = self.get_inputs(inputs) |
| 130 | + return self.pool(x, a, i, mask=mask) |
| 131 | + |
| 132 | + def select(self, x, a, i, mask=None): |
| 133 | + s = self.mlp(x) |
| 134 | + if mask is not None: |
| 135 | + s *= mask[0] |
| 136 | + |
| 137 | + # Total variation loss |
| 138 | + tv_loss = self.totvar_loss(a, s) |
| 139 | + if K.ndim(a) == 3: |
| 140 | + tv_loss = K.mean(tv_loss) |
| 141 | + self.add_loss(self.totvar_coeff * tv_loss) |
| 142 | + |
| 143 | + # Asymmetric l1-norm loss |
| 144 | + bal_loss = self.balance_loss(s) |
| 145 | + if K.ndim(a) == 3: |
| 146 | + bal_loss = K.mean(bal_loss) |
| 147 | + self.add_loss(self.balance_coeff * bal_loss) |
| 148 | + |
| 149 | + return s |
| 150 | + |
| 151 | + def reduce(self, x, s, **kwargs): |
| 152 | + return ops.modal_dot(s, x, transpose_a=True) |
| 153 | + |
| 154 | + def connect(self, a, s, **kwargs): |
| 155 | + a_pool = ops.matmul_at_b_a(s, a) |
| 156 | + |
| 157 | + return a_pool |
| 158 | + |
| 159 | + def reduce_index(self, i, s, **kwargs): |
| 160 | + i_mean = tf.math.segment_mean(i, i) |
| 161 | + i_pool = ops.repeat(i_mean, tf.ones_like(i_mean) * self.k) |
| 162 | + |
| 163 | + return i_pool |
| 164 | + |
| 165 | + def totvar_loss(self, a, s): |
| 166 | + if K.is_sparse(a): |
| 167 | + index_i = a.indices[:, 0] |
| 168 | + index_j = a.indices[:, 1] |
| 169 | + |
| 170 | + n_edges = tf.cast(len(a.values), dtype=s.dtype) |
| 171 | + |
| 172 | + loss = tf.math.reduce_sum( |
| 173 | + a.values[:, tf.newaxis] |
| 174 | + * tf.math.abs(tf.gather(s, index_i) - tf.gather(s, index_j)), |
| 175 | + axis=(-2, -1), |
| 176 | + ) |
| 177 | + |
| 178 | + else: |
| 179 | + n_edges = tf.cast(tf.math.count_nonzero(a, axis=(-2, -1)), dtype=s.dtype) |
| 180 | + n_nodes = tf.shape(a)[-1] |
| 181 | + if K.ndim(a) == 3: |
| 182 | + loss = tf.math.reduce_sum( |
| 183 | + a |
| 184 | + * tf.math.reduce_sum( |
| 185 | + tf.math.abs( |
| 186 | + s[:, tf.newaxis, ...] |
| 187 | + - tf.repeat(s[..., tf.newaxis, :], n_nodes, axis=-2) |
| 188 | + ), |
| 189 | + axis=-1, |
| 190 | + ), |
| 191 | + axis=(-2, -1), |
| 192 | + ) |
| 193 | + else: |
| 194 | + loss = tf.math.reduce_sum( |
| 195 | + a |
| 196 | + * tf.math.reduce_sum( |
| 197 | + tf.math.abs( |
| 198 | + s - tf.repeat(s[..., tf.newaxis, :], n_nodes, axis=-2) |
| 199 | + ), |
| 200 | + axis=-1, |
| 201 | + ), |
| 202 | + axis=(-2, -1), |
| 203 | + ) |
| 204 | + |
| 205 | + loss *= 1 / (2 * n_edges) |
| 206 | + |
| 207 | + return loss |
| 208 | + |
| 209 | + def balance_loss(self, s): |
| 210 | + n_nodes = tf.cast(tf.shape(s, out_type=tf.int32)[-2], s.dtype) |
| 211 | + |
| 212 | + # k-quantile |
| 213 | + idx = tf.cast(tf.math.floor(n_nodes / self.k) + 1, dtype=tf.int32) |
| 214 | + med = tf.math.top_k(tf.linalg.matrix_transpose(s), k=idx).values[..., -1] |
| 215 | + # Asymmetric l1-norm |
| 216 | + if K.ndim(s) == 2: |
| 217 | + loss = s - med |
| 218 | + else: |
| 219 | + loss = s - med[:, tf.newaxis, ...] |
| 220 | + loss = (tf.cast(loss >= 0, loss.dtype) * (self.k - 1) * loss) + ( |
| 221 | + tf.cast(loss < 0, loss.dtype) * loss * -1.0 |
| 222 | + ) |
| 223 | + loss = tf.math.reduce_sum(loss, axis=(-2, -1)) |
| 224 | + loss = 1 / (n_nodes * (self.k - 1)) * (n_nodes * (self.k - 1) - loss) |
| 225 | + |
| 226 | + return loss |
| 227 | + |
| 228 | + def get_config(self): |
| 229 | + config = { |
| 230 | + "k": self.k, |
| 231 | + "mlp_hidden": self.mlp_hidden, |
| 232 | + "mlp_activation": self.mlp_activation, |
| 233 | + "totvar_coeff": self.totvar_coeff, |
| 234 | + "balance_coeff": self.balance_coeff, |
| 235 | + } |
| 236 | + base_config = super().get_config() |
| 237 | + return {**base_config, **config} |
0 commit comments