From 910af697a5b88aa0a70edec1c309ca3060cf0546 Mon Sep 17 00:00:00 2001 From: JonasBergHansen Date: Thu, 11 May 2023 10:44:54 +0200 Subject: [PATCH 1/8] added GTVConv layer --- docs/autogen.py | 1 + spektral/layers/convolutional/__init__.py | 1 + spektral/layers/convolutional/gtv_conv.py | 202 ++++++++++++++++++ .../convolutional/test_gtv_conv.py | 16 ++ 4 files changed, 220 insertions(+) create mode 100644 spektral/layers/convolutional/gtv_conv.py create mode 100644 tests/test_layers/convolutional/test_gtv_conv.py diff --git a/docs/autogen.py b/docs/autogen.py index 6c0bf5b8..46fba5a9 100644 --- a/docs/autogen.py +++ b/docs/autogen.py @@ -40,6 +40,7 @@ layers.GCSConv, layers.GINConv, layers.GraphSageConv, + layers.GTVConv, layers.TAGConv, layers.XENetConv, ], diff --git a/spektral/layers/convolutional/__init__.py b/spektral/layers/convolutional/__init__.py index 53a34199..858c6bad 100644 --- a/spektral/layers/convolutional/__init__.py +++ b/spektral/layers/convolutional/__init__.py @@ -14,6 +14,7 @@ from .general_conv import GeneralConv from .gin_conv import GINConv from .graphsage_conv import GraphSageConv +from .gtv_conv import GTVConv from .message_passing import MessagePassing from .tag_conv import TAGConv from .xenet_conv import XENetConv, XENetDenseConv diff --git a/spektral/layers/convolutional/gtv_conv.py b/spektral/layers/convolutional/gtv_conv.py new file mode 100644 index 00000000..cba5f431 --- /dev/null +++ b/spektral/layers/convolutional/gtv_conv.py @@ -0,0 +1,202 @@ +import tensorflow as tf +from tensorflow.keras import backend as K +from spektral.layers import ops +from spektral.layers.convolutional.conv import Conv + +class GTVConv(Conv): + r""" + A graph total variation convolutional layer (GTVConv) from the paper + + > [Clustering with Total Variation Graph Neural Networks](https://arxiv.org/abs/2211.06218)
+ > Jonas Berg Hansen and Filippo Maria Bianchi + + **Mode**: single, disjoint, batch + + This layer computes + $$ + \X' = \sigma\left[\left(\I - \delta\L_\hat{\mathbf{\Gamma}}\right) \tilde{\X} \right] + $$ + where + $$ + \begin{align} + \tilde{\X} &= \X \W\\[5pt] + \L_\hat{\mathbf{\Gamma}} &= \D_\mathbf{\hat{\Gamma}} - \hat{\mathbf{\Gamma}}\\[5pt] + [\hat{\mathbf{\Gamma}}]_{ij} &= \frac{[\mathbf{A}]_{ij}}{\max\{||\tilde{\x}_i-\tilde{\x}_j||_1, \epsilon\}}\\ + \end{align} + $$ + + **Input** + + - Node features of shape `(batch, n_nodes, n_node_features)`; + - Adjacency matrix of shape `(batch, n_nodes, n_nodes)`; + + **Output** + + - Node features with the same shape as the input, but with the last + dimension changed to `channels`. + + **Arguments** + + - `channels`: number of output channels; + - `delta_coeff`: step size for gradient descent of GTV + - `epsilon`: small number used to numerically stabilize the computation of new adjacency weights + - `activation`: activation function; + - `use_bias`: bool, add a bias vector to the output; + - `kernel_initializer`: initializer for the weights; + - `bias_initializer`: initializer for the bias vector; + - `kernel_regularizer`: regularization applied to the weights; + - `bias_regularizer`: regularization applied to the bias vector; + - `activity_regularizer`: regularization applied to the output; + - `kernel_constraint`: constraint applied to the weights; + - `bias_constraint`: constraint applied to the bias vector. + + """ + + def __init__( + self, + channels, + delta_coeff=1.0, + epsilon=0.001, + activation=None, + use_bias=True, + kernel_initializer="he_normal", + bias_initializer="zeros", + kernel_regularizer=None, + bias_regularizer=None, + activity_regularizer=None, + kernel_constraint=None, + bias_constraint=None, + **kwargs + ): + super().__init__( + activation=activation, + use_bias=use_bias, + kernel_initializer=kernel_initializer, + bias_initializer=bias_initializer, + kernel_regularizer=kernel_regularizer, + bias_regularizer=bias_regularizer, + activity_regularizer=activity_regularizer, + kernel_constraint=kernel_constraint, + bias_constraint=bias_constraint, + **kwargs + ) + + self.channels = channels + self.delta_coeff = delta_coeff + self.epsilon = epsilon + + def build(self, input_shape): + assert len(input_shape) >= 2 + input_dim = input_shape[0][-1] + self.kernel = self.add_weight( + shape=(input_dim, self.channels), + initializer=self.kernel_initializer, + name="kernel", + regularizer=self.kernel_regularizer, + constraint=self.kernel_constraint, + ) + if self.use_bias: + self.bias = self.add_weight( + shape=(self.channels,), + initializer=self.bias_initializer, + name="bias", + regularizer=self.bias_regularizer, + constraint=self.bias_constraint, + ) + self.built = True + + def call(self, inputs, mask=None): + x, a = inputs + + mode = ops.autodetect_mode(x, a) + + x = K.dot(x, self.kernel) + + if mode == ops.modes.SINGLE: + output = self._call_single(x, a) + + elif mode == ops.modes.BATCH: + output = self._call_batch(x, a) + + if self.use_bias: + output = K.bias_add(output, self.bias) + + if mask is not None: + output *= mask[0] + + output = self.activation(output) + + return output + + def _call_single(self, x, a): + if K.is_sparse(a): + index_i = a.indices[:, 0] + index_j = a.indices[:, 1] + + n_nodes = tf.shape(a, out_type=index_i.dtype)[0] + + # Compute absolute differences between neighbouring nodes + abs_diff = tf.math.abs(tf.transpose(tf.gather(x, index_i)) - + tf.transpose(tf.gather(x, index_j))) + abs_diff = 
tf.math.reduce_sum(abs_diff, axis=0) + + # Compute new adjacency matrix + gamma = tf.sparse.map_values(tf.multiply, + a, + 1 / tf.math.maximum(abs_diff, self.epsilon)) + + # Compute degree matrix from gamma matrix + d_gamma = tf.sparse.SparseTensor(tf.stack([tf.range(n_nodes)] * 2, axis=1), + tf.sparse.reduce_sum(gamma, axis=-1), + [n_nodes, n_nodes]) + + # Compute laplacian: L = D_gamma - Gamma + l = tf.sparse.add(d_gamma, tf.sparse.map_values( + tf.multiply, gamma, -1.)) + + # Compute adjusted laplacian: L_adjusted = I - delta*L + l = tf.sparse.add(tf.sparse.eye(n_nodes, dtype=x.dtype), tf.sparse.map_values( + tf.multiply, l, -self.delta_coeff)) + + # Aggregate features with adjusted laplacian + output = ops.modal_dot(l, x) + + else: + n_nodes = tf.shape(a)[-1] + + abs_diff = tf.math.abs(x[:, tf.newaxis, :] - x) + abs_diff = tf.reduce_sum(abs_diff, axis=-1) + + gamma = a / tf.math.maximum(abs_diff, self.epsilon) + + degrees = tf.math.reduce_sum(gamma, axis=-1) + l = -gamma + l = tf.linalg.set_diag(l, degrees - tf.linalg.diag_part(gamma)) + l = tf.eye(n_nodes, dtype=x.dtype) - self.delta_coeff * l + + output = tf.matmul(l, x) + + return output + + def _call_batch(self, x, a): + n_nodes = tf.shape(a)[-1] + + abs_diff = tf.reduce_sum(tf.math.abs(tf.expand_dims(x, 2) - + tf.expand_dims(x, 1)), axis = -1) + + gamma = a / tf.math.maximum(abs_diff, self.epsilon) + + degrees = tf.math.reduce_sum(gamma, axis=-1) + l = -gamma + l = tf.linalg.set_diag(l, degrees - tf.linalg.diag_part(gamma)) + l = tf.eye(n_nodes, dtype=x.dtype) - self.delta_coeff * l + + output = tf.matmul(l, x) + + return output + + @property + def config(self): + return {"channels": self.channels, + "delta_coeff": self.delta_coeff, + "epsilon": self.epsilon} \ No newline at end of file diff --git a/tests/test_layers/convolutional/test_gtv_conv.py b/tests/test_layers/convolutional/test_gtv_conv.py new file mode 100644 index 00000000..cbf60398 --- /dev/null +++ b/tests/test_layers/convolutional/test_gtv_conv.py @@ -0,0 +1,16 @@ +from core import MODES, run_layer + +from spektral import layers + +config = { + "layer": layers.GTVConv, + "modes": [MODES["SINGLE"], MODES["BATCH"]], + "kwargs": {"channels": 8, "delta_coeff": 1.0, "epsilon": 0.001, "activation": "relu"}, + "dense": True, + "sparse": True, + "edges": False, +} + + +def test_layer(): + run_layer(config) \ No newline at end of file From 493cc61cacc7536113e38ee5548a5d64b4e89bf7 Mon Sep 17 00:00:00 2001 From: JonasBergHansen Date: Thu, 11 May 2023 13:44:10 +0200 Subject: [PATCH 2/8] Add AsymCheegerCutPool layer + fix docstring --- docs/autogen.py | 1 + spektral/layers/convolutional/gtv_conv.py | 4 +- spektral/layers/pooling/__init__.py | 1 + .../layers/pooling/asym_cheeger_cut_pool.py | 222 ++++++++++++++++++ .../pooling/test_asym_cheeger_cut_pool.py | 14 ++ 5 files changed, 240 insertions(+), 2 deletions(-) create mode 100644 spektral/layers/pooling/asym_cheeger_cut_pool.py create mode 100644 tests/test_layers/pooling/test_asym_cheeger_cut_pool.py diff --git a/docs/autogen.py b/docs/autogen.py index 46fba5a9..42245750 100644 --- a/docs/autogen.py +++ b/docs/autogen.py @@ -51,6 +51,7 @@ "methods": [], "classes": [ layers.SRCPool, + layers.AsymCheegerCutPool, layers.DiffPool, layers.LaPool, layers.MinCutPool, diff --git a/spektral/layers/convolutional/gtv_conv.py b/spektral/layers/convolutional/gtv_conv.py index cba5f431..13bc412b 100644 --- a/spektral/layers/convolutional/gtv_conv.py +++ b/spektral/layers/convolutional/gtv_conv.py @@ -7,10 +7,10 @@ class GTVConv(Conv): r""" A 
graph total variation convolutional layer (GTVConv) from the paper - > [Clustering with Total Variation Graph Neural Networks](https://arxiv.org/abs/2211.06218)
+ > [Total Variation Graph Neural Networks](https://arxiv.org/abs/2211.06218)
> Jonas Berg Hansen and Filippo Maria Bianchi - **Mode**: single, disjoint, batch + **Mode**: single, disjoint, batch. This layer computes $$ diff --git a/spektral/layers/pooling/__init__.py b/spektral/layers/pooling/__init__.py index 2f21c014..ec138e6f 100644 --- a/spektral/layers/pooling/__init__.py +++ b/spektral/layers/pooling/__init__.py @@ -1,3 +1,4 @@ +from .asym_cheeger_cut_pool import AsymCheegerCutPool from .diff_pool import DiffPool from .dmon_pool import DMoNPool from .global_pool import ( diff --git a/spektral/layers/pooling/asym_cheeger_cut_pool.py b/spektral/layers/pooling/asym_cheeger_cut_pool.py new file mode 100644 index 00000000..5f7a8599 --- /dev/null +++ b/spektral/layers/pooling/asym_cheeger_cut_pool.py @@ -0,0 +1,222 @@ +import tensorflow as tf +from tensorflow.keras import Sequential +from tensorflow.keras.layers import Dense +import tensorflow.keras.backend as K +from spektral.layers import ops +from spektral.layers.pooling.src import SRCPool + +class AsymCheegerCutPool(SRCPool): + r""" + An Asymmetric Cheeger Cut Pooling layer from the paper + > [Total Variation Graph Neural Networks](https://arxiv.org/abs/2211.06218)
> Jonas Berg Hansen and Filippo Maria Bianchi + + **Mode**: single, batch. + + This layer learns a soft clustering of the input graph as follows: + $$ + \begin{align} + \S &= \textrm{MLP}(\X); \\ + \X' &= \S^\top \X \\ + \A' &= \S^\top \A \S; \\ + \end{align} + $$ + where \(\textrm{MLP}\) is a multi-layer perceptron with softmax output. + + The layer includes two auxiliary loss terms: + a graph total variation component given by + $$ + L_\text{GTV} = \frac{1}{2E} \sum_{k=1}^K \sum_{i=1}^N \sum_{j=i}^N a_{i,j} |s_{i,k} - s_{j,k}|, + $$ + where \(E\) is the number of edges, \(K\) is the number of clusters or output nodes, and \(N\) is the number of nodes. + + An asymmetric norm component given by + $$ + L_\text{AN} = \frac{N(K - 1) - \sum_{k=1}^K ||\s_{:,k} - \textrm{quant}_{K-1}(\s_{:,k})||_{1, K-1}}{N(K-1)}, + $$ + where \(\textrm{quant}_{K-1}(\s_{:,k})\) denotes the \(K-1\) quantile of \(\s_{:,k}\) and + \(||\cdot||_{1, K-1}\) is an asymmetric \(\ell_1\)-norm that weights positive entries by \(K - 1\) and negative entries by 1. + + The layer can be used without a supervised loss to compute node clustering by + minimizing the two auxiliary losses. + + **Input** + + - Node features of shape `(batch, n_nodes_in, n_node_features)`; + - Adjacency matrix of shape `(batch, n_nodes_in, n_nodes_in)`; + + **Output** + + - Reduced node features of shape `(batch, n_nodes_out, n_node_features)`; + - If `return_selection=True`, the selection matrix of shape + `(batch, n_nodes_in, n_nodes_out)`. + + **Arguments** + + - `k`: number of output nodes; + - `mlp_hidden`: list of integers, number of hidden units for each hidden layer in + the MLP used to compute cluster assignments (if `None`, the MLP has only one output + layer); + - `mlp_activation`: activation for the MLP layers; + - `totvar_coeff`: coefficient for the graph total variation loss component; + - `balance_coeff`: coefficient for the asymmetric norm loss component; + - `return_selection`: boolean, whether to return the selection matrix; + - `use_bias`: use bias in the MLP; + - `kernel_initializer`: initializer for the weights of the MLP; + - `bias_initializer`: initializer for the bias of the MLP; + - `kernel_regularizer`: regularization applied to the weights of the MLP; + - `bias_regularizer`: regularization applied to the bias of the MLP; + - `kernel_constraint`: constraint applied to the weights of the MLP; + - `bias_constraint`: constraint applied to the bias of the MLP. + """ + + def __init__( + self, + k, + mlp_hidden=None, + mlp_activation="relu", + totvar_coeff=1.0, + balance_coeff=1.0, + return_selection=False, + use_bias=True, + kernel_initializer="glorot_uniform", + bias_initializer="zeros", + kernel_regularizer=None, + bias_regularizer=None, + kernel_constraint=None, + bias_constraint=None, + **kwargs + ): + super().__init__( + k=k, + mlp_hidden=mlp_hidden, + mlp_activation=mlp_activation, + return_selection=return_selection, + use_bias=use_bias, + kernel_initializer=kernel_initializer, + bias_initializer=bias_initializer, + kernel_regularizer=kernel_regularizer, + bias_regularizer=bias_regularizer, + kernel_constraint=kernel_constraint, + bias_constraint=bias_constraint, + **kwargs + ) + + self.k = k + self.mlp_hidden = mlp_hidden if mlp_hidden else [] + self.mlp_activation = mlp_activation + self.totvar_coeff = totvar_coeff + self.balance_coeff = balance_coeff + + def build(self, input_shape): + layer_kwargs = dict( + kernel_initializer=self.kernel_initializer, + bias_initializer=self.bias_initializer, + kernel_regularizer=self.kernel_regularizer, + bias_regularizer=self.bias_regularizer, + kernel_constraint=self.kernel_constraint, + bias_constraint=self.bias_constraint, + ) + self.mlp = Sequential( + [ + Dense(channels, self.mlp_activation, **layer_kwargs) + for channels in self.mlp_hidden + ] + + [Dense(self.k, "softmax", **layer_kwargs)] + ) + + 
super().build(input_shape) + + def call(self, inputs, mask=None): + x, a, i = self.get_inputs(inputs) + return self.pool(x, a, i, mask=mask) + + def select(self, x, a, i, mask=None): + s = self.mlp(x) + if mask is not None: + s *= mask[0] + + # Total variation loss + tv_loss = self.totvar_loss(a, s) + if K.ndim(a) == 3: + tv_loss = K.mean(tv_loss) + self.add_loss(self.totvar_coeff * tv_loss) + + # Asymmetric l1-norm loss + bal_loss = self.balance_loss(s) + if K.ndim(a) == 3: + bal_loss = K.mean(bal_loss) + self.add_loss(self.balance_coeff * bal_loss) + + return s + + def reduce(self, x, s, **kwargs): + return ops.modal_dot(s, x, transpose_a=True) + + def connect(self, a, s, **kwargs): + a_pool = ops.matmul_at_b_a(s, a) + + return a_pool + + def reduce_index(self, i, s, **kwargs): + i_mean = tf.math.segment_mean(i, i) + i_pool = ops.repeat(i_mean, tf.ones_like(i_mean) * self.k) + + return i_pool + + def totvar_loss(self, a, s): + if K.is_sparse(a): + index_i = a.indices[:, 0] + index_j = a.indices[:, 1] + + n_edges = tf.cast(len(a.values), dtype=s.dtype) + + loss = tf.math.reduce_sum(a.values[:, tf.newaxis] * + tf.math.abs(tf.gather(s, index_i) - + tf.gather(s, index_j)), + axis=(-2, -1)) + + else: + n_edges = tf.cast(tf.math.count_nonzero( + a, axis=(-2, -1)), dtype=s.dtype) + n_nodes = tf.shape(a)[-1] + if K.ndim(a) == 3: + loss = tf.math.reduce_sum(a * tf.math.reduce_sum(tf.math.abs(s[:, tf.newaxis, ...] - + tf.repeat(s[..., tf.newaxis, :], + n_nodes, axis=-2)), axis=-1), + axis=(-2, -1)) + else: + loss = tf.math.reduce_sum(a * tf.math.reduce_sum(tf.math.abs(s - + tf.repeat(s[..., tf.newaxis, :], + n_nodes, axis=-2)), axis=-1), + axis=(-2, -1)) + + loss *= 1 / (2 * n_edges) + + return loss + + def balance_loss(self, s): + n_nodes = tf.cast(tf.shape(s, out_type=tf.int32)[-2], s.dtype) + + # k-quantile + idx = tf.cast(tf.math.floor(n_nodes / self.k) + 1, dtype=tf.int32) + med = tf.math.top_k(tf.linalg.matrix_transpose(s), + k=idx).values[..., -1] + # Asymmetric l1-norm + if K.ndim(s) == 2: + loss = s - med + else: + loss = s - med[:, tf.newaxis, ...] 
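+        # Asymmetric weighting of the quantile residuals: entries of s above the quantile contribute (k - 1) times their deviation, entries below it contribute their absolute deviation once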
+ loss = ((tf.cast(loss >= 0, loss.dtype) * (self.k - 1) * loss) + + (tf.cast(loss < 0, loss.dtype) * loss * -1.)) + loss = tf.math.reduce_sum(loss, axis=(-2, -1)) + loss = 1 / (n_nodes * (self.k - 1)) * (n_nodes * (self.k - 1) - loss) + + return loss + + def get_config(self): + config = { + "k": self.k, + "mlp_hidden": self.mlp_hidden, + "mlp_activation": self.mlp_activation, + "totvar_coeff": self.totvar_coeff, + "balance_coeff": self.balance_coeff + } + base_config = super().get_config() + return {**base_config, **config} \ No newline at end of file diff --git a/tests/test_layers/pooling/test_asym_cheeger_cut_pool.py b/tests/test_layers/pooling/test_asym_cheeger_cut_pool.py new file mode 100644 index 00000000..468c5f04 --- /dev/null +++ b/tests/test_layers/pooling/test_asym_cheeger_cut_pool.py @@ -0,0 +1,14 @@ +from spektral import layers +from tests.test_layers.pooling.core import MODES, run_layer + +config = { + "layer": layers.AsymCheegerCutPool, + "modes": [MODES["SINGLE"], MODES["BATCH"]], + "kwargs": {"k": 5, "return_selection": True, "mlp_hidden": [32], "totvar_coeff": 1.0, "balance_coeff": 1.0}, + "dense": True, + "sparse": True, +} + + +def test_layer(): + run_layer(config) From 9172cc90b112001d0970e519c20bdbf3942f4b84 Mon Sep 17 00:00:00 2001 From: JonasBergHansen Date: Tue, 16 May 2023 10:15:24 +0200 Subject: [PATCH 3/8] Add tvgnn example + change of default initializer --- examples/other/node_clustering_tvgnn.py | 129 ++++++++++++++++++++++ spektral/layers/convolutional/gtv_conv.py | 2 +- 2 files changed, 130 insertions(+), 1 deletion(-) create mode 100644 examples/other/node_clustering_tvgnn.py diff --git a/examples/other/node_clustering_tvgnn.py b/examples/other/node_clustering_tvgnn.py new file mode 100644 index 00000000..06a9a81f --- /dev/null +++ b/examples/other/node_clustering_tvgnn.py @@ -0,0 +1,129 @@ +""" +This example implements the node clustering experiment on citation networks +from the paper: + +Total Variation Graph Neural Networks (https://arxiv.org/abs/2211.06218) +Jonas Berg Hansen and Filippo Maria Bianchi +""" + +import numpy as np +from tqdm import tqdm +from sklearn.metrics.cluster import ( + completeness_score, + homogeneity_score, + normalized_mutual_info_score, +) +import tensorflow as tf +from tensorflow.keras import Model +from spektral.utils.sparse import sp_matrix_to_sp_tensor +from spektral.datasets.citation import Citation +from spektral.datasets import DBLP +from spektral.layers import GTVConv, AsymCheegerCutPool + +tf.random.set_seed(1) + +################################ +# CONFIG/HYPERPARAMETERS +################################ +dataset_id = "cora" +mp_channels = 512 +mp_layers = 2 +mp_activation = "elu" +delta_coeff = 0.311 +epsilon=1e-3 +mlp_hidden_channels = 256 +mlp_hidden_layers = 1 +mlp_activation = "relu" +totvar_coeff=0.785 +balance_coeff=0.514 +learning_rate = 1e-3 +epochs = 500 + +################################ +# LOAD DATASET +################################ +if dataset_id in ["cora", "citeseer", "pubmed"]: + dataset = Citation(dataset_id, normalize_x=True) +elif dataset_id == "dblp": + dataset = DBLP(normalize_x=True) +X = dataset.graphs[0].x +A = dataset.graphs[0].a +Y = dataset.graphs[0].y +y = np.argmax(Y, axis=-1) +n_clust = Y.shape[-1] + +################################ +# MODEL +################################ +class ClusteringModel(Model): + """ + Defines the general model structure + """ + + def __init__(self, aggr, pool): + super().__init__() + + self.mp = aggr + self.pool = pool + + def call(self, inputs): + x, 
a = inputs + + out = x + for _mp in self.mp: + out = _mp([out, a]) + + _, _, s_pool = self.pool([out, a]) + + return s_pool + +# Define the message-passing layers +MP_layers = [GTVConv( + mp_channels, + delta_coeff=delta_coeff, + epsilon=1e-3, + activation=mp_activation) +for _ in range(mp_layers)] + +# Define the pooling layer +pool_layer = AsymCheegerCutPool( + n_clust, + mlp_hidden=[mlp_hidden_channels for _ in range(mlp_hidden_layers)], + mlp_activation=mlp_activation, + totvar_coeff=totvar_coeff, + balance_coeff=balance_coeff, + return_selection=True) + +# Instantiate model and optimizer +model = ClusteringModel(aggr=MP_layers, pool=pool_layer) +opt = tf.keras.optimizers.Adam(learning_rate=learning_rate) + +################################ +# TRAINING +################################ +@tf.function(input_signature=None) +def train_step(model, inputs): + with tf.GradientTape() as tape: + _ = model(inputs, training=True) + loss = sum(model.losses) + gradients = tape.gradient(loss, model.trainable_variables) + opt.apply_gradients(zip(gradients, model.trainable_variables)) + return model.losses + +A = sp_matrix_to_sp_tensor(A) +inputs = [X, A] +loss_history = [] + +# Training loop +for _ in tqdm(range(epochs)): + outs = train_step(model, inputs) + +################################ +# INFERENCE/RESULTS +################################ +S_ = model(inputs, training=False) +s_out = np.argmax(S_, axis=-1) +nmi = normalized_mutual_info_score(y, s_out) +hom = homogeneity_score(y, s_out) +com = completeness_score(y, s_out) +print("Homogeneity: {:.3f}; Completeness: {:.3f}; NMI: {:.3f}".format(hom, com, nmi)) \ No newline at end of file diff --git a/spektral/layers/convolutional/gtv_conv.py b/spektral/layers/convolutional/gtv_conv.py index 13bc412b..ea469975 100644 --- a/spektral/layers/convolutional/gtv_conv.py +++ b/spektral/layers/convolutional/gtv_conv.py @@ -59,7 +59,7 @@ def __init__( epsilon=0.001, activation=None, use_bias=True, - kernel_initializer="he_normal", + kernel_initializer="glorot_uniform", bias_initializer="zeros", kernel_regularizer=None, bias_regularizer=None, From 356bde9c99fdff30b4dc36b1b1a50cbdb0e3fce4 Mon Sep 17 00:00:00 2001 From: Daniele Grattarola Date: Thu, 1 Jun 2023 22:36:46 +0100 Subject: [PATCH 4/8] Fix formatting and sort imports --- examples/other/node_clustering_tvgnn.py | 36 +++++----- spektral/layers/convolutional/gtv_conv.py | 72 +++++++++++-------- .../layers/pooling/asym_cheeger_cut_pool.py | 61 ++++++++++------ .../convolutional/test_gtv_conv.py | 9 ++- .../pooling/test_asym_cheeger_cut_pool.py | 8 ++- 5 files changed, 113 insertions(+), 73 deletions(-) diff --git a/examples/other/node_clustering_tvgnn.py b/examples/other/node_clustering_tvgnn.py index 06a9a81f..11ed8d1d 100644 --- a/examples/other/node_clustering_tvgnn.py +++ b/examples/other/node_clustering_tvgnn.py @@ -7,18 +7,19 @@ """ import numpy as np -from tqdm import tqdm +import tensorflow as tf from sklearn.metrics.cluster import ( completeness_score, homogeneity_score, normalized_mutual_info_score, ) -import tensorflow as tf from tensorflow.keras import Model -from spektral.utils.sparse import sp_matrix_to_sp_tensor -from spektral.datasets.citation import Citation +from tqdm import tqdm + from spektral.datasets import DBLP -from spektral.layers import GTVConv, AsymCheegerCutPool +from spektral.datasets.citation import Citation +from spektral.layers import AsymCheegerCutPool, GTVConv +from spektral.utils.sparse import sp_matrix_to_sp_tensor tf.random.set_seed(1) @@ -30,12 +31,12 @@ mp_layers = 
2 mp_activation = "elu" delta_coeff = 0.311 -epsilon=1e-3 +epsilon = 1e-3 mlp_hidden_channels = 256 mlp_hidden_layers = 1 mlp_activation = "relu" -totvar_coeff=0.785 -balance_coeff=0.514 +totvar_coeff = 0.785 +balance_coeff = 0.514 learning_rate = 1e-3 epochs = 500 @@ -77,13 +78,14 @@ def call(self, inputs): return s_pool + + # Define the message-passing layers -MP_layers = [GTVConv( - mp_channels, - delta_coeff=delta_coeff, - epsilon=1e-3, - activation=mp_activation) -for _ in range(mp_layers)] +MP_layers = [ + GTVConv( + mp_channels, delta_coeff=delta_coeff, epsilon=1e-3, activation=mp_activation + ) + for _ in range(mp_layers) +] # Define the pooling layer pool_layer = AsymCheegerCutPool( @@ -92,7 +94,8 @@ def call(self, inputs): mlp_activation=mlp_activation, totvar_coeff=totvar_coeff, balance_coeff=balance_coeff, - return_selection=True) + return_selection=True, +) # Instantiate model and optimizer model = ClusteringModel(aggr=MP_layers, pool=pool_layer) opt = tf.keras.optimizers.Adam(learning_rate=learning_rate) @@ -110,6 +113,7 @@ def train_step(model, inputs): opt.apply_gradients(zip(gradients, model.trainable_variables)) return model.losses + + A = sp_matrix_to_sp_tensor(A) inputs = [X, A] loss_history = [] @@ -126,4 +130,4 @@ def train_step(model, inputs): nmi = normalized_mutual_info_score(y, s_out) hom = homogeneity_score(y, s_out) com = completeness_score(y, s_out) -print("Homogeneity: {:.3f}; Completeness: {:.3f}; NMI: {:.3f}".format(hom, com, nmi)) \ No newline at end of file +print("Homogeneity: {:.3f}; Completeness: {:.3f}; NMI: {:.3f}".format(hom, com, nmi)) diff --git a/spektral/layers/convolutional/gtv_conv.py b/spektral/layers/convolutional/gtv_conv.py index ea469975..9b3c1a22 100644 --- a/spektral/layers/convolutional/gtv_conv.py +++ b/spektral/layers/convolutional/gtv_conv.py @@ -1,8 +1,10 @@ import tensorflow as tf from tensorflow.keras import backend as K + from spektral.layers import ops from spektral.layers.convolutional.conv import Conv + class GTVConv(Conv): r""" A graph total variation convolutional layer (GTVConv) from the paper @@ -132,43 +134,48 @@ def _call_single(self, x, a): if K.is_sparse(a): index_i = a.indices[:, 0] index_j = a.indices[:, 1] - + n_nodes = tf.shape(a, out_type=index_i.dtype)[0] - + # Compute absolute differences between neighbouring nodes - abs_diff = tf.math.abs(tf.transpose(tf.gather(x, index_i)) - - tf.transpose(tf.gather(x, index_j))) + abs_diff = tf.math.abs( + tf.transpose(tf.gather(x, index_i)) + - tf.transpose(tf.gather(x, index_j)) + ) abs_diff = tf.math.reduce_sum(abs_diff, axis=0) - + # Compute new adjacency matrix - gamma = tf.sparse.map_values(tf.multiply, - a, - 1 / tf.math.maximum(abs_diff, self.epsilon)) - + gamma = tf.sparse.map_values( + tf.multiply, a, 1 / tf.math.maximum(abs_diff, self.epsilon) + ) + # Compute degree matrix from gamma matrix - d_gamma = tf.sparse.SparseTensor(tf.stack([tf.range(n_nodes)] * 2, axis=1), - tf.sparse.reduce_sum(gamma, axis=-1), - [n_nodes, n_nodes]) - + d_gamma = tf.sparse.SparseTensor( + tf.stack([tf.range(n_nodes)] * 2, axis=1), + tf.sparse.reduce_sum(gamma, axis=-1), + [n_nodes, n_nodes], + ) + # Compute laplacian: L = D_gamma - Gamma - l = tf.sparse.add(d_gamma, tf.sparse.map_values( - tf.multiply, gamma, -1.)) + l = tf.sparse.add(d_gamma, tf.sparse.map_values(tf.multiply, gamma, -1.0)) + # Compute adjusted laplacian: L_adjusted = I - delta*L - l = tf.sparse.add(tf.sparse.eye(n_nodes, dtype=x.dtype), tf.sparse.map_values( - tf.multiply, l, -self.delta_coeff)) + l = tf.sparse.add( + tf.sparse.eye(n_nodes, dtype=x.dtype), + 
tf.sparse.map_values(tf.multiply, l, -self.delta_coeff), + ) + # Aggregate features with adjusted laplacian output = ops.modal_dot(l, x) - + else: n_nodes = tf.shape(a)[-1] - + abs_diff = tf.math.abs(x[:, tf.newaxis, :] - x) abs_diff = tf.reduce_sum(abs_diff, axis=-1) - + gamma = a / tf.math.maximum(abs_diff, self.epsilon) - + degrees = tf.math.reduce_sum(gamma, axis=-1) l = -gamma l = tf.linalg.set_diag(l, degrees - tf.linalg.diag_part(gamma)) @@ -180,9 +187,10 @@ def _call_single(self, x, a): def _call_batch(self, x, a): n_nodes = tf.shape(a)[-1] - - abs_diff = tf.reduce_sum(tf.math.abs(tf.expand_dims(x, 2) - - tf.expand_dims(x, 1)), axis = -1) + + abs_diff = tf.reduce_sum( + tf.math.abs(tf.expand_dims(x, 2) - tf.expand_dims(x, 1)), axis=-1 + ) gamma = a / tf.math.maximum(abs_diff, self.epsilon) @@ -192,11 +200,13 @@ def _call_batch(self, x, a): l = tf.eye(n_nodes, dtype=x.dtype) - self.delta_coeff * l output = tf.matmul(l, x) - + return output - + @property def config(self): - return {"channels": self.channels, - "delta_coeff": self.delta_coeff, - "epsilon": self.epsilon} \ No newline at end of file + return { + "channels": self.channels, + "delta_coeff": self.delta_coeff, + "epsilon": self.epsilon, + } diff --git a/spektral/layers/pooling/asym_cheeger_cut_pool.py b/spektral/layers/pooling/asym_cheeger_cut_pool.py index 5f7a8599..fe59e9af 100644 --- a/spektral/layers/pooling/asym_cheeger_cut_pool.py +++ b/spektral/layers/pooling/asym_cheeger_cut_pool.py @@ -1,10 +1,12 @@ import tensorflow as tf +import tensorflow.keras.backend as K from tensorflow.keras import Sequential from tensorflow.keras.layers import Dense -import tensorflow.keras.backend as K + from spektral.layers import ops from spektral.layers.pooling.src import SRCPool + class AsymCheegerCutPool(SRCPool): r""" An Asymmetric Cheeger Cut Pooling layer from the paper @@ -151,7 +153,7 @@ def reduce(self, x, s, **kwargs): def connect(self, a, s, **kwargs): a_pool = ops.matmul_at_b_a(s, a) - + return a_pool def reduce_index(self, i, s, **kwargs): @@ -159,7 +161,7 @@ def reduce_index(self, i, s, **kwargs): i_pool = ops.repeat(i_mean, tf.ones_like(i_mean) * self.k) return i_pool - + def totvar_loss(self, a, s): if K.is_sparse(a): index_i = a.indices[:, 0] @@ -167,25 +169,38 @@ def totvar_loss(self, a, s): n_edges = tf.cast(len(a.values), dtype=s.dtype) - loss = tf.math.reduce_sum(a.values[:, tf.newaxis] * - tf.math.abs(tf.gather(s, index_i) - - tf.gather(s, index_j)), - axis=(-2, -1)) + loss = tf.math.reduce_sum( + a.values[:, tf.newaxis] + * tf.math.abs(tf.gather(s, index_i) - tf.gather(s, index_j)), + axis=(-2, -1), + ) else: - n_edges = tf.cast(tf.math.count_nonzero( - a, axis=(-2, -1)), dtype=s.dtype) + n_edges = tf.cast(tf.math.count_nonzero(a, axis=(-2, -1)), dtype=s.dtype) n_nodes = tf.shape(a)[-1] if K.ndim(a) == 3: - loss = tf.math.reduce_sum(a * tf.math.reduce_sum(tf.math.abs(s[:, tf.newaxis, ...] - - tf.repeat(s[..., tf.newaxis, :], - n_nodes, axis=-2)), axis=-1), - axis=(-2, -1)) + loss = tf.math.reduce_sum( + a + * tf.math.reduce_sum( + tf.math.abs( + s[:, tf.newaxis, ...] 
+ - tf.repeat(s[..., tf.newaxis, :], n_nodes, axis=-2) + ), + axis=-1, + ), + axis=(-2, -1), + ) else: - loss = tf.math.reduce_sum(a * tf.math.reduce_sum(tf.math.abs(s - - tf.repeat(s[..., tf.newaxis, :], - n_nodes, axis=-2)), axis=-1), - axis=(-2, -1)) + loss = tf.math.reduce_sum( + a + * tf.math.reduce_sum( + tf.math.abs( + s - tf.repeat(s[..., tf.newaxis, :], n_nodes, axis=-2) + ), + axis=-1, + ), + axis=(-2, -1), + ) loss *= 1 / (2 * n_edges) @@ -196,15 +211,15 @@ def balance_loss(self, s): # k-quantile idx = tf.cast(tf.math.floor(n_nodes / self.k) + 1, dtype=tf.int32) - med = tf.math.top_k(tf.linalg.matrix_transpose(s), - k=idx).values[..., -1] + med = tf.math.top_k(tf.linalg.matrix_transpose(s), k=idx).values[..., -1] # Asymmetric l1-norm if K.ndim(s) == 2: loss = s - med else: loss = s - med[:, tf.newaxis, ...] - loss = ((tf.cast(loss >= 0, loss.dtype) * (self.k - 1) * loss) + - (tf.cast(loss < 0, loss.dtype) * loss * -1.)) + loss = (tf.cast(loss >= 0, loss.dtype) * (self.k - 1) * loss) + ( + tf.cast(loss < 0, loss.dtype) * loss * -1.0 + ) loss = tf.math.reduce_sum(loss, axis=(-2, -1)) loss = 1 / (n_nodes * (self.k - 1)) * (n_nodes * (self.k - 1) - loss) @@ -216,7 +231,7 @@ def get_config(self): "mlp_hidden": self.mlp_hidden, "mlp_activation": self.mlp_activation, "totvar_coeff": self.totvar_coeff, - "balance_coeff": self.balance_coeff + "balance_coeff": self.balance_coeff, } base_config = super().get_config() - return {**base_config, **config} \ No newline at end of file + return {**base_config, **config} diff --git a/tests/test_layers/convolutional/test_gtv_conv.py b/tests/test_layers/convolutional/test_gtv_conv.py index cbf60398..ef498348 100644 --- a/tests/test_layers/convolutional/test_gtv_conv.py +++ b/tests/test_layers/convolutional/test_gtv_conv.py @@ -5,7 +5,12 @@ config = { "layer": layers.GTVConv, "modes": [MODES["SINGLE"], MODES["BATCH"]], - "kwargs": {"channels": 8, "delta_coeff": 1.0, "epsilon": 0.001, "activation": "relu"}, + "kwargs": { + "channels": 8, + "delta_coeff": 1.0, + "epsilon": 0.001, + "activation": "relu", + }, "dense": True, "sparse": True, "edges": False, @@ -13,4 +18,4 @@ def test_layer(): - run_layer(config) \ No newline at end of file + run_layer(config) diff --git a/tests/test_layers/pooling/test_asym_cheeger_cut_pool.py b/tests/test_layers/pooling/test_asym_cheeger_cut_pool.py index 468c5f04..3baf3be5 100644 --- a/tests/test_layers/pooling/test_asym_cheeger_cut_pool.py +++ b/tests/test_layers/pooling/test_asym_cheeger_cut_pool.py @@ -4,7 +4,13 @@ config = { "layer": layers.AsymCheegerCutPool, "modes": [MODES["SINGLE"], MODES["BATCH"]], - "kwargs": {"k": 5, "return_selection": True, "mlp_hidden": [32], "totvar_coeff": 1.0, "balance_coeff": 1.0}, + "kwargs": { + "k": 5, + "return_selection": True, + "mlp_hidden": [32], + "totvar_coeff": 1.0, + "balance_coeff": 1.0, + }, "dense": True, "sparse": True, } From ed62ea19787cde0bdb7b609fb1ee24f067731679 Mon Sep 17 00:00:00 2001 From: Daniele Grattarola Date: Thu, 1 Jun 2023 22:41:32 +0100 Subject: [PATCH 5/8] Fix Python version on tests --- .github/workflows/examples.yml | 4 ++-- .github/workflows/style_check.yml | 4 ++-- .github/workflows/test.yml | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/.github/workflows/examples.yml b/.github/workflows/examples.yml index 1fdfe386..4756c2b3 100644 --- a/.github/workflows/examples.yml +++ b/.github/workflows/examples.yml @@ -9,10 +9,10 @@ jobs: steps: - uses: actions/checkout@v2 - - name: Set up Python 3.6 + - name: Set up Python 3.11 uses: 
actions/setup-python@v2 with: - python-version: 3.6 + python-version: 3.11 - name: Install dependencies run: | pip install ogb matplotlib diff --git a/.github/workflows/style_check.yml b/.github/workflows/style_check.yml index 9249371c..583de2f9 100644 --- a/.github/workflows/style_check.yml +++ b/.github/workflows/style_check.yml @@ -9,10 +9,10 @@ jobs: steps: - uses: actions/checkout@v2 - - name: Set up Python 3.6 + - name: Set up Python 3.11 uses: actions/setup-python@v2 with: - python-version: 3.6 + python-version: 3.11 - name: Lint Python code run: | pip install flake8 diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 3af84537..f2e0e804 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -8,7 +8,7 @@ jobs: runs-on: ${{ matrix.os }} strategy: matrix: - python-version: [3.7, 3.8, 3.9] + python-version: [3.8, 3.9, 3.10] os: [ubuntu-latest, macos-latest, windows-latest] steps: From 197adfd51ee1b9c970ea73fe9f76b5f3962a33a9 Mon Sep 17 00:00:00 2001 From: Daniele Grattarola Date: Thu, 1 Jun 2023 22:43:38 +0100 Subject: [PATCH 6/8] Format with latest Black version --- examples/graph_prediction/qm9_ecc.py | 1 + examples/other/node_clustering_tvgnn.py | 2 ++ spektral/layers/base.py | 2 -- spektral/layers/pooling/dmon_pool.py | 1 - spektral/layers/pooling/global_pool.py | 1 - 5 files changed, 3 insertions(+), 4 deletions(-) diff --git a/examples/graph_prediction/qm9_ecc.py b/examples/graph_prediction/qm9_ecc.py index 7c3822b7..864f2adc 100644 --- a/examples/graph_prediction/qm9_ecc.py +++ b/examples/graph_prediction/qm9_ecc.py @@ -40,6 +40,7 @@ loader_tr = DisjointLoader(dataset_tr, batch_size=batch_size, epochs=epochs) loader_te = DisjointLoader(dataset_te, batch_size=batch_size, epochs=1) + ################################################################################ # Build model ################################################################################ diff --git a/examples/other/node_clustering_tvgnn.py b/examples/other/node_clustering_tvgnn.py index 11ed8d1d..59ffdbaa 100644 --- a/examples/other/node_clustering_tvgnn.py +++ b/examples/other/node_clustering_tvgnn.py @@ -53,6 +53,7 @@ y = np.argmax(Y, axis=-1) n_clust = Y.shape[-1] + ################################ # MODEL ################################ @@ -101,6 +102,7 @@ def call(self, inputs): model = ClusteringModel(aggr=MP_layers, pool=pool_layer) opt = tf.keras.optimizers.Adam(learning_rate=learning_rate) + ################################ # TRAINING ################################ diff --git a/spektral/layers/base.py b/spektral/layers/base.py index 29ac48a7..19f6bef5 100644 --- a/spektral/layers/base.py +++ b/spektral/layers/base.py @@ -118,7 +118,6 @@ def __init__( kernel_constraint=None, **kwargs ): - super().__init__(**kwargs) self.trainable_kernel = trainable_kernel self.activation = activations.get(activation) @@ -184,7 +183,6 @@ class MinkowskiProduct(Layer): """ def __init__(self, activation=None, **kwargs): - super().__init__(**kwargs) self.activation = activations.get(activation) diff --git a/spektral/layers/pooling/dmon_pool.py b/spektral/layers/pooling/dmon_pool.py index 70f4baa8..5a7621c3 100644 --- a/spektral/layers/pooling/dmon_pool.py +++ b/spektral/layers/pooling/dmon_pool.py @@ -162,7 +162,6 @@ def reduce_index(self, i, s, **kwargs): return i_pool def modularity_loss(self, a, s, a_pool): - if K.is_sparse(a): n_edges = tf.cast(len(a.values), dtype=s.dtype) diff --git a/spektral/layers/pooling/global_pool.py b/spektral/layers/pooling/global_pool.py index 
f12ed430..5453df28 100644 --- a/spektral/layers/pooling/global_pool.py +++ b/spektral/layers/pooling/global_pool.py @@ -8,7 +8,6 @@ class GlobalPool(Layer): def __init__(self, **kwargs): - super().__init__(**kwargs) self.supports_masking = True self.pooling_op = None From eebde5bdf7469cd2f91b4905cec056861ff4f416 Mon Sep 17 00:00:00 2001 From: Daniele Grattarola Date: Thu, 1 Jun 2023 22:49:00 +0100 Subject: [PATCH 7/8] Fix test_convolution.py and test workflow --- .github/workflows/test.yml | 2 +- tests/test_utils/test_convolution.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index f2e0e804..41ae2581 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -8,7 +8,7 @@ jobs: runs-on: ${{ matrix.os }} strategy: matrix: - python-version: [3.8, 3.9, 3.10] + python-version: ["3.8", "3.9", "3.10", "3.11"] os: [ubuntu-latest, macos-latest, windows-latest] steps: diff --git a/tests/test_utils/test_convolution.py b/tests/test_utils/test_convolution.py index 8ca921fd..b3503ea6 100644 --- a/tests/test_utils/test_convolution.py +++ b/tests/test_utils/test_convolution.py @@ -6,7 +6,7 @@ from spektral.utils import convolution g = nx.generators.erdos_renyi_graph(10, 0.2) -adj_sp = nx.adj_matrix(g).astype("f") +adj_sp = nx.adjacenct_matrix(g).astype("f") adj = adj_sp.A.astype("f") degree = np.diag([d[1] for d in nx.degree(g)]) tol = 1e-6 From 2a51af90f2e1b60c8aa17063fd7ac60be8846152 Mon Sep 17 00:00:00 2001 From: Daniele Grattarola Date: Thu, 1 Jun 2023 22:52:02 +0100 Subject: [PATCH 8/8] typo --- tests/test_utils/test_convolution.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_utils/test_convolution.py b/tests/test_utils/test_convolution.py index b3503ea6..ca4f7484 100644 --- a/tests/test_utils/test_convolution.py +++ b/tests/test_utils/test_convolution.py @@ -6,7 +6,7 @@ from spektral.utils import convolution g = nx.generators.erdos_renyi_graph(10, 0.2) -adj_sp = nx.adjacenct_matrix(g).astype("f") +adj_sp = nx.adjacency_matrix(g).astype("f") adj = adj_sp.A.astype("f") degree = np.diag([d[1] for d in nx.degree(g)]) tol = 1e-6
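
For quick reference, a minimal usage sketch of the two layers introduced in this series, in single mode with a dense adjacency matrix. This is illustrative and not part of the patches: it relies only on the public API added above (GTVConv, AsymCheegerCutPool), and the toy graph, seed, and hyperparameter values are arbitrary placeholders.

import numpy as np

from spektral.layers import AsymCheegerCutPool, GTVConv

# Toy single-mode graph: 20 nodes, 4 features, symmetric dense adjacency
rng = np.random.default_rng(0)
x = rng.random((20, 4)).astype("float32")
a = (rng.random((20, 20)) < 0.2).astype("float32")
a = np.maximum(a, a.T)  # symmetrize the adjacency matrix

# One GTV message-passing step: x' = (I - delta * L_Gamma) (x W)
x_mp = GTVConv(8, delta_coeff=1.0, epsilon=1e-3, activation="relu")([x, a])

# Soft clustering into 3 clusters; the two auxiliary losses are registered
# on the layer via add_loss() and can be inspected through pool.losses
pool = AsymCheegerCutPool(3, mlp_hidden=[16], return_selection=True)
x_pool, a_pool, s = pool([x_mp, a])

print(x_pool.shape)  # (3, 8): pooled features S^T X
print(a_pool.shape)  # (3, 3): coarsened adjacency S^T A S
print(s.shape)       # (20, 3): soft cluster assignments
print(len(pool.losses))  # 2: total variation term and balance term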