diff --git a/.github/workflows/examples.yml b/.github/workflows/examples.yml
index a52d65ed..4756c2b3 100644
--- a/.github/workflows/examples.yml
+++ b/.github/workflows/examples.yml
@@ -9,10 +9,10 @@ jobs:
steps:
- uses: actions/checkout@v2
- - name: Set up Python 3.7
+ - name: Set up Python 3.11
uses: actions/setup-python@v2
with:
- python-version: 3.7
+ python-version: 3.11
- name: Install dependencies
run: |
pip install ogb matplotlib
diff --git a/.github/workflows/style_check.yml b/.github/workflows/style_check.yml
index 9ffd0b24..583de2f9 100644
--- a/.github/workflows/style_check.yml
+++ b/.github/workflows/style_check.yml
@@ -9,10 +9,10 @@ jobs:
steps:
- uses: actions/checkout@v2
- - name: Set up Python 3.7
+ - name: Set up Python 3.11
uses: actions/setup-python@v2
with:
- python-version: 3.7
+ python-version: 3.11
- name: Lint Python code
run: |
pip install flake8
diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index 3af84537..41ae2581 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -8,7 +8,7 @@ jobs:
runs-on: ${{ matrix.os }}
strategy:
matrix:
- python-version: [3.7, 3.8, 3.9]
+ python-version: ["3.8", "3.9", "3.10", "3.11"]
os: [ubuntu-latest, macos-latest, windows-latest]
steps:
diff --git a/docs/autogen.py b/docs/autogen.py
index 4819a051..5ef3cfa3 100644
--- a/docs/autogen.py
+++ b/docs/autogen.py
@@ -40,6 +40,7 @@
layers.GCSConv,
layers.GINConv,
layers.GraphSageConv,
+ layers.GTVConv,
layers.TAGConv,
layers.XENetConv,
layers.GINConvBatch,
@@ -52,6 +53,7 @@
"methods": [],
"classes": [
layers.SRCPool,
+ layers.AsymCheegerCutPool,
layers.DiffPool,
layers.LaPool,
layers.MinCutPool,
diff --git a/examples/other/node_clustering_tvgnn.py b/examples/other/node_clustering_tvgnn.py
new file mode 100644
index 00000000..59ffdbaa
--- /dev/null
+++ b/examples/other/node_clustering_tvgnn.py
@@ -0,0 +1,135 @@
+"""
+This example implements the node clustering experiment on citation networks
+from the paper:
+
+Total Variation Graph Neural Networks (https://arxiv.org/abs/2211.06218)
+Jonas Berg Hansen and Filippo Maria Bianchi
+"""
+
+import numpy as np
+import tensorflow as tf
+from sklearn.metrics.cluster import (
+ completeness_score,
+ homogeneity_score,
+ normalized_mutual_info_score,
+)
+from tensorflow.keras import Model
+from tqdm import tqdm
+
+from spektral.datasets import DBLP
+from spektral.datasets.citation import Citation
+from spektral.layers import AsymCheegerCutPool, GTVConv
+from spektral.utils.sparse import sp_matrix_to_sp_tensor
+
+tf.random.set_seed(1)
+
+################################
+# CONFIG/HYPERPARAMETERS
+################################
+dataset_id = "cora"
+mp_channels = 512
+mp_layers = 2
+mp_activation = "elu"
+delta_coeff = 0.311
+epsilon = 1e-3
+mlp_hidden_channels = 256
+mlp_hidden_layers = 1
+mlp_activation = "relu"
+totvar_coeff = 0.785
+balance_coeff = 0.514
+learning_rate = 1e-3
+epochs = 500
+
+################################
+# LOAD DATASET
+################################
+if dataset_id in ["cora", "citeseer", "pubmed"]:
+ dataset = Citation(dataset_id, normalize_x=True)
+elif dataset_id == "dblp":
+    dataset = DBLP(normalize_x=True)
+else:
+    raise ValueError("Unknown dataset ID: {}".format(dataset_id))
+X = dataset.graphs[0].x
+A = dataset.graphs[0].a
+Y = dataset.graphs[0].y
+y = np.argmax(Y, axis=-1)
+n_clust = Y.shape[-1]
+
+
+################################
+# MODEL
+################################
+class ClusteringModel(Model):
+ """
+ Defines the general model structure
+ """
+
+ def __init__(self, aggr, pool):
+ super().__init__()
+
+ self.mp = aggr
+ self.pool = pool
+
+ def call(self, inputs):
+ x, a = inputs
+
+ out = x
+ for _mp in self.mp:
+ out = _mp([out, a])
+
+ _, _, s_pool = self.pool([out, a])
+
+ return s_pool
+
+
+# Define the message-passing layers
+MP_layers = [
+ GTVConv(
+        mp_channels, delta_coeff=delta_coeff, epsilon=epsilon, activation=mp_activation
+ )
+ for _ in range(mp_layers)
+]
+
+# Define the pooling layer
+pool_layer = AsymCheegerCutPool(
+ n_clust,
+ mlp_hidden=[mlp_hidden_channels for _ in range(mlp_hidden_layers)],
+ mlp_activation=mlp_activation,
+ totvar_coeff=totvar_coeff,
+ balance_coeff=balance_coeff,
+ return_selection=True,
+)
+
+# Instantiate model and optimizer
+model = ClusteringModel(aggr=MP_layers, pool=pool_layer)
+opt = tf.keras.optimizers.Adam(learning_rate=learning_rate)
+
+
+################################
+# TRAINING
+################################
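+# Training is fully unsupervised: the only gradients come from the two
+# auxiliary clustering losses that AsymCheegerCutPool registers on the model
+# (collected below via model.losses).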
+@tf.function
+def train_step(model, inputs):
+ with tf.GradientTape() as tape:
+ _ = model(inputs, training=True)
+ loss = sum(model.losses)
+ gradients = tape.gradient(loss, model.trainable_variables)
+ opt.apply_gradients(zip(gradients, model.trainable_variables))
+ return model.losses
+
+
+A = sp_matrix_to_sp_tensor(A)
+inputs = [X, A]
+loss_history = []
+
+# Training loop
+for _ in tqdm(range(epochs)):
+    outs = train_step(model, inputs)
+    loss_history.append([loss.numpy() for loss in outs])
+
+################################
+# INFERENCE/RESULTS
+################################
+S_ = model(inputs, training=False)
+s_out = np.argmax(S_, axis=-1)
+nmi = normalized_mutual_info_score(y, s_out)
+hom = homogeneity_score(y, s_out)
+com = completeness_score(y, s_out)
+print("Homogeneity: {:.3f}; Completeness: {:.3f}; NMI: {:.3f}".format(hom, com, nmi))
diff --git a/spektral/layers/convolutional/__init__.py b/spektral/layers/convolutional/__init__.py
index 255088ec..e16f80fb 100644
--- a/spektral/layers/convolutional/__init__.py
+++ b/spektral/layers/convolutional/__init__.py
@@ -14,6 +14,7 @@
from .general_conv import GeneralConv
from .gin_conv import GINConv, GINConvBatch
from .graphsage_conv import GraphSageConv
+from .gtv_conv import GTVConv
from .message_passing import MessagePassing
from .tag_conv import TAGConv
from .xenet_conv import XENetConv, XENetConvBatch
diff --git a/spektral/layers/convolutional/gtv_conv.py b/spektral/layers/convolutional/gtv_conv.py
new file mode 100644
index 00000000..9b3c1a22
--- /dev/null
+++ b/spektral/layers/convolutional/gtv_conv.py
@@ -0,0 +1,212 @@
+import tensorflow as tf
+from tensorflow.keras import backend as K
+
+from spektral.layers import ops
+from spektral.layers.convolutional.conv import Conv
+
+
+class GTVConv(Conv):
+ r"""
+ A graph total variation convolutional layer (GTVConv) from the paper
+
+ > [Total Variation Graph Neural Networks](https://arxiv.org/abs/2211.06218)
+ > Jonas Berg Hansen and Filippo Maria Bianchi
+
+ **Mode**: single, disjoint, batch.
+
+ This layer computes
+ $$
+ \X' = \sigma\left[\left(\I - \delta\L_\hat{\mathbf{\Gamma}}\right) \tilde{\X} \right]
+ $$
+ where
+ $$
+ \begin{align}
+ \tilde{\X} &= \X \W\\[5pt]
+ \L_\hat{\mathbf{\Gamma}} &= \D_\mathbf{\hat{\Gamma}} - \hat{\mathbf{\Gamma}}\\[5pt]
+ [\hat{\mathbf{\Gamma}}]_{ij} &= \frac{[\mathbf{A}]_{ij}}{\max\{||\tilde{\x}_i-\tilde{\x}_j||_1, \epsilon\}}\\
+ \end{align}
+ $$
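+
+    For intuition, here is a minimal sketch of the dense single-mode
+    computation on toy tensors (all names below are illustrative and not part
+    of the layer's API):
+
+    ```python
+    import tensorflow as tf
+
+    n, f, c = 4, 3, 8                  # toy sizes (assumed)
+    x = tf.random.normal((n, f))       # node features
+    a = tf.ones((n, n)) - tf.eye(n)    # dense adjacency with zero diagonal
+    kernel = tf.random.normal((f, c))  # the layer's weight matrix W
+    delta, epsilon = 1.0, 1e-3
+
+    x_t = x @ kernel                                        # X~ = XW
+    dist = tf.reduce_sum(tf.abs(x_t[:, None] - x_t[None]), -1)
+    gamma = a / tf.maximum(dist, epsilon)                   # Gamma_hat
+    lap = tf.linalg.diag(tf.reduce_sum(gamma, -1)) - gamma  # L = D - Gamma_hat
+    out = (tf.eye(n) - delta * lap) @ x_t                   # (I - delta*L) X~
+    ```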
+
+ **Input**
+
+    - Node features of shape `([batch], n_nodes, n_node_features)`;
+    - Adjacency matrix of shape `([batch], n_nodes, n_nodes)`;
+
+ **Output**
+
+ - Node features with the same shape as the input, but with the last
+ dimension changed to `channels`.
+
+ **Arguments**
+
+ - `channels`: number of output channels;
+    - `delta_coeff`: step size for the gradient descent update of the graph total variation;
+    - `epsilon`: small number used to numerically stabilize the computation of the new adjacency weights;
+ - `activation`: activation function;
+ - `use_bias`: bool, add a bias vector to the output;
+ - `kernel_initializer`: initializer for the weights;
+ - `bias_initializer`: initializer for the bias vector;
+ - `kernel_regularizer`: regularization applied to the weights;
+ - `bias_regularizer`: regularization applied to the bias vector;
+ - `activity_regularizer`: regularization applied to the output;
+ - `kernel_constraint`: constraint applied to the weights;
+ - `bias_constraint`: constraint applied to the bias vector.
+
+ """
+
+ def __init__(
+ self,
+ channels,
+ delta_coeff=1.0,
+ epsilon=0.001,
+ activation=None,
+ use_bias=True,
+ kernel_initializer="glorot_uniform",
+ bias_initializer="zeros",
+ kernel_regularizer=None,
+ bias_regularizer=None,
+ activity_regularizer=None,
+ kernel_constraint=None,
+ bias_constraint=None,
+ **kwargs
+ ):
+ super().__init__(
+ activation=activation,
+ use_bias=use_bias,
+ kernel_initializer=kernel_initializer,
+ bias_initializer=bias_initializer,
+ kernel_regularizer=kernel_regularizer,
+ bias_regularizer=bias_regularizer,
+ activity_regularizer=activity_regularizer,
+ kernel_constraint=kernel_constraint,
+ bias_constraint=bias_constraint,
+ **kwargs
+ )
+
+ self.channels = channels
+ self.delta_coeff = delta_coeff
+ self.epsilon = epsilon
+
+ def build(self, input_shape):
+ assert len(input_shape) >= 2
+ input_dim = input_shape[0][-1]
+ self.kernel = self.add_weight(
+ shape=(input_dim, self.channels),
+ initializer=self.kernel_initializer,
+ name="kernel",
+ regularizer=self.kernel_regularizer,
+ constraint=self.kernel_constraint,
+ )
+ if self.use_bias:
+ self.bias = self.add_weight(
+ shape=(self.channels,),
+ initializer=self.bias_initializer,
+ name="bias",
+ regularizer=self.bias_regularizer,
+ constraint=self.bias_constraint,
+ )
+ self.built = True
+
+ def call(self, inputs, mask=None):
+ x, a = inputs
+
+ mode = ops.autodetect_mode(x, a)
+
+ x = K.dot(x, self.kernel)
+
+ if mode == ops.modes.SINGLE:
+ output = self._call_single(x, a)
+
+        elif mode == ops.modes.BATCH:
+            output = self._call_batch(x, a)
+
+        else:
+            raise ValueError("GTVConv only supports single, disjoint, and batch mode")
+
+ if self.use_bias:
+ output = K.bias_add(output, self.bias)
+
+ if mask is not None:
+ output *= mask[0]
+
+ output = self.activation(output)
+
+ return output
+
+ def _call_single(self, x, a):
+ if K.is_sparse(a):
+ index_i = a.indices[:, 0]
+ index_j = a.indices[:, 1]
+
+ n_nodes = tf.shape(a, out_type=index_i.dtype)[0]
+
+ # Compute absolute differences between neighbouring nodes
+ abs_diff = tf.math.abs(
+ tf.transpose(tf.gather(x, index_i))
+ - tf.transpose(tf.gather(x, index_j))
+ )
+ abs_diff = tf.math.reduce_sum(abs_diff, axis=0)
+
+ # Compute new adjacency matrix
+ gamma = tf.sparse.map_values(
+ tf.multiply, a, 1 / tf.math.maximum(abs_diff, self.epsilon)
+ )
+
+ # Compute degree matrix from gamma matrix
+ d_gamma = tf.sparse.SparseTensor(
+ tf.stack([tf.range(n_nodes)] * 2, axis=1),
+ tf.sparse.reduce_sum(gamma, axis=-1),
+ [n_nodes, n_nodes],
+ )
+
+        # Compute Laplacian: L = D_gamma - Gamma
+ l = tf.sparse.add(d_gamma, tf.sparse.map_values(tf.multiply, gamma, -1.0))
+
+ # Compute adjusted laplacian: L_adjusted = I - delta*L
+ l = tf.sparse.add(
+ tf.sparse.eye(n_nodes, dtype=x.dtype),
+ tf.sparse.map_values(tf.multiply, l, -self.delta_coeff),
+ )
+
+ # Aggregate features with adjusted laplacian
+ output = ops.modal_dot(l, x)
+
+ else:
+ n_nodes = tf.shape(a)[-1]
+
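+            # Pairwise l1 distances between the transformed node features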
+ abs_diff = tf.math.abs(x[:, tf.newaxis, :] - x)
+ abs_diff = tf.reduce_sum(abs_diff, axis=-1)
+
+ gamma = a / tf.math.maximum(abs_diff, self.epsilon)
+
+ degrees = tf.math.reduce_sum(gamma, axis=-1)
+ l = -gamma
+ l = tf.linalg.set_diag(l, degrees - tf.linalg.diag_part(gamma))
+ l = tf.eye(n_nodes, dtype=x.dtype) - self.delta_coeff * l
+
+ output = tf.matmul(l, x)
+
+ return output
+
+ def _call_batch(self, x, a):
+ n_nodes = tf.shape(a)[-1]
+
+ abs_diff = tf.reduce_sum(
+ tf.math.abs(tf.expand_dims(x, 2) - tf.expand_dims(x, 1)), axis=-1
+ )
+
+ gamma = a / tf.math.maximum(abs_diff, self.epsilon)
+
+ degrees = tf.math.reduce_sum(gamma, axis=-1)
+ l = -gamma
+ l = tf.linalg.set_diag(l, degrees - tf.linalg.diag_part(gamma))
+ l = tf.eye(n_nodes, dtype=x.dtype) - self.delta_coeff * l
+
+ output = tf.matmul(l, x)
+
+ return output
+
+ @property
+ def config(self):
+ return {
+ "channels": self.channels,
+ "delta_coeff": self.delta_coeff,
+ "epsilon": self.epsilon,
+ }
diff --git a/spektral/layers/pooling/__init__.py b/spektral/layers/pooling/__init__.py
index 2f21c014..ec138e6f 100644
--- a/spektral/layers/pooling/__init__.py
+++ b/spektral/layers/pooling/__init__.py
@@ -1,3 +1,4 @@
+from .asym_cheeger_cut_pool import AsymCheegerCutPool
from .diff_pool import DiffPool
from .dmon_pool import DMoNPool
from .global_pool import (
diff --git a/spektral/layers/pooling/asym_cheeger_cut_pool.py b/spektral/layers/pooling/asym_cheeger_cut_pool.py
new file mode 100644
index 00000000..fe59e9af
--- /dev/null
+++ b/spektral/layers/pooling/asym_cheeger_cut_pool.py
@@ -0,0 +1,237 @@
+import tensorflow as tf
+import tensorflow.keras.backend as K
+from tensorflow.keras import Sequential
+from tensorflow.keras.layers import Dense
+
+from spektral.layers import ops
+from spektral.layers.pooling.src import SRCPool
+
+
+class AsymCheegerCutPool(SRCPool):
+ r"""
+ An Asymmetric Cheeger Cut Pooling layer from the paper
+ > [Total Variation Graph Neural Networks](https://arxiv.org/abs/2211.06218)
+ > Jonas Berg Hansen and Filippo Maria Bianchi
+
+ **Mode**: single, batch.
+
+ This layer learns a soft clustering of the input graph as follows:
+ $$
+ \begin{align}
+ \S &= \textrm{MLP}(\X); \\
+ \X' &= \S^\top \X \\
+ \A' &= \S^\top \A \S; \\
+ \end{align}
+ $$
+ where \(\textrm{MLP}\) is a multi-layer perceptron with softmax output.
+
+    The layer adds two auxiliary loss terms to the model.
+    A graph total variation term, given by
+    $$
+    L_\text{GTV} = \frac{1}{2E} \sum_{k=1}^K \sum_{i=1}^N \sum_{j=i}^N a_{i,j} |s_{i,k} - s_{j,k}|,
+    $$
+    where \(E\) is the number of edges, \(K\) is the number of clusters (output
+    nodes), and \(N\) is the number of input nodes.
+
+    An asymmetric norm term, given by
+    $$
+    L_\text{AN} = \frac{N(K - 1) - \sum_{k=1}^K ||\s_{:,k} - \textrm{quant}_{K-1} (\s_{:,k})||_{1, K-1}}{N(K-1)},
+    $$
+    where \(\textrm{quant}_{K-1}\) denotes the \(\frac{K-1}{K}\)-quantile of each
+    column of \(\S\) and \(\Vert\cdot\Vert_{1, K-1}\) is an asymmetric
+    \(\ell_1\)-norm that weights positive deviations by \(K - 1\); this term
+    encourages clusters of balanced size.
+
+ The layer can be used without a supervised loss to compute node clustering by
+ minimizing the two auxiliary losses.
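+
+    A minimal unsupervised usage sketch (the node features `x` and adjacency
+    `a` are assumed to be given; the auxiliary losses are collected from the
+    layer's `losses` attribute):
+
+    ```python
+    pool = AsymCheegerCutPool(k=10, mlp_hidden=[64], return_selection=True)
+    x_pool, a_pool, s = pool([x, a])    # s: soft cluster assignments
+    clustering_loss = sum(pool.losses)  # totvar + balance terms
+    ```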
+
+ **Input**
+
+    - Node features of shape `([batch], n_nodes_in, n_node_features)`;
+    - Adjacency matrix of shape `([batch], n_nodes_in, n_nodes_in)`;
+
+ **Output**
+
+    - Reduced node features of shape `([batch], n_nodes_out, n_node_features)`;
+    - Reduced adjacency matrix of shape `([batch], n_nodes_out, n_nodes_out)`;
+    - If `return_selection=True`, the selection matrix of shape
+      `([batch], n_nodes_in, n_nodes_out)`.
+
+ **Arguments**
+
+ - `k`: number of output nodes;
+ - `mlp_hidden`: list of integers, number of hidden units for each hidden layer in
+ the MLP used to compute cluster assignments (if `None`, the MLP has only one output
+ layer);
+ - `mlp_activation`: activation for the MLP layers;
+ - `totvar_coeff`: coefficient for graph total variation loss component;
+ - `balance_coeff`: coefficient for asymmetric norm loss component;
+ - `return_selection`: boolean, whether to return the selection matrix;
+ - `use_bias`: use bias in the MLP;
+    - `kernel_initializer`: initializer for the weights of the MLP;
+    - `bias_initializer`: initializer for the bias of the MLP;
+    - `kernel_regularizer`: regularization applied to the weights of the MLP;
+    - `bias_regularizer`: regularization applied to the bias of the MLP;
+    - `kernel_constraint`: constraint applied to the weights of the MLP;
+    - `bias_constraint`: constraint applied to the bias of the MLP.
+ """
+
+ def __init__(
+ self,
+ k,
+ mlp_hidden=None,
+ mlp_activation="relu",
+ totvar_coeff=1.0,
+ balance_coeff=1.0,
+ return_selection=False,
+ use_bias=True,
+ kernel_initializer="glorot_uniform",
+ bias_initializer="zeros",
+ kernel_regularizer=None,
+ bias_regularizer=None,
+ kernel_constraint=None,
+ bias_constraint=None,
+ **kwargs
+ ):
+ super().__init__(
+ k=k,
+ mlp_hidden=mlp_hidden,
+ mlp_activation=mlp_activation,
+ return_selection=return_selection,
+ use_bias=use_bias,
+ kernel_initializer=kernel_initializer,
+ bias_initializer=bias_initializer,
+ kernel_regularizer=kernel_regularizer,
+ bias_regularizer=bias_regularizer,
+ kernel_constraint=kernel_constraint,
+ bias_constraint=bias_constraint,
+ **kwargs
+ )
+
+ self.k = k
+ self.mlp_hidden = mlp_hidden if mlp_hidden else []
+ self.mlp_activation = mlp_activation
+ self.totvar_coeff = totvar_coeff
+ self.balance_coeff = balance_coeff
+
+ def build(self, input_shape):
+ layer_kwargs = dict(
+ kernel_initializer=self.kernel_initializer,
+ bias_initializer=self.bias_initializer,
+ kernel_regularizer=self.kernel_regularizer,
+ bias_regularizer=self.bias_regularizer,
+ kernel_constraint=self.kernel_constraint,
+ bias_constraint=self.bias_constraint,
+ )
+ self.mlp = Sequential(
+ [
+ Dense(channels, self.mlp_activation, **layer_kwargs)
+ for channels in self.mlp_hidden
+ ]
+ + [Dense(self.k, "softmax", **layer_kwargs)]
+ )
+
+ super().build(input_shape)
+
+ def call(self, inputs, mask=None):
+ x, a, i = self.get_inputs(inputs)
+ return self.pool(x, a, i, mask=mask)
+
+ def select(self, x, a, i, mask=None):
+ s = self.mlp(x)
+ if mask is not None:
+ s *= mask[0]
+
+ # Total variation loss
+ tv_loss = self.totvar_loss(a, s)
+ if K.ndim(a) == 3:
+ tv_loss = K.mean(tv_loss)
+ self.add_loss(self.totvar_coeff * tv_loss)
+
+ # Asymmetric l1-norm loss
+ bal_loss = self.balance_loss(s)
+ if K.ndim(a) == 3:
+ bal_loss = K.mean(bal_loss)
+ self.add_loss(self.balance_coeff * bal_loss)
+
+ return s
+
+ def reduce(self, x, s, **kwargs):
+ return ops.modal_dot(s, x, transpose_a=True)
+
+ def connect(self, a, s, **kwargs):
+ a_pool = ops.matmul_at_b_a(s, a)
+
+ return a_pool
+
+ def reduce_index(self, i, s, **kwargs):
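+        # Each input graph is pooled to `k` nodes: compute one segment ID per
+        # graph, then repeat it `k` times to build the pooled index vector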
+ i_mean = tf.math.segment_mean(i, i)
+ i_pool = ops.repeat(i_mean, tf.ones_like(i_mean) * self.k)
+
+ return i_pool
+
+ def totvar_loss(self, a, s):
+ if K.is_sparse(a):
+ index_i = a.indices[:, 0]
+ index_j = a.indices[:, 1]
+
+            n_edges = tf.cast(tf.shape(a.values)[0], dtype=s.dtype)
+
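+            # Sum a_ij * |s_ik - s_jk| over the stored edges and all clusters k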
+ loss = tf.math.reduce_sum(
+ a.values[:, tf.newaxis]
+ * tf.math.abs(tf.gather(s, index_i) - tf.gather(s, index_j)),
+ axis=(-2, -1),
+ )
+
+ else:
+ n_edges = tf.cast(tf.math.count_nonzero(a, axis=(-2, -1)), dtype=s.dtype)
+ n_nodes = tf.shape(a)[-1]
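+            # Dense mode: accumulate a_ij * |s_ik - s_jk| via broadcasting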
+ if K.ndim(a) == 3:
+ loss = tf.math.reduce_sum(
+ a
+ * tf.math.reduce_sum(
+ tf.math.abs(
+ s[:, tf.newaxis, ...]
+ - tf.repeat(s[..., tf.newaxis, :], n_nodes, axis=-2)
+ ),
+ axis=-1,
+ ),
+ axis=(-2, -1),
+ )
+ else:
+ loss = tf.math.reduce_sum(
+ a
+ * tf.math.reduce_sum(
+ tf.math.abs(
+ s - tf.repeat(s[..., tf.newaxis, :], n_nodes, axis=-2)
+ ),
+ axis=-1,
+ ),
+ axis=(-2, -1),
+ )
+
+ loss *= 1 / (2 * n_edges)
+
+ return loss
+
+ def balance_loss(self, s):
+ n_nodes = tf.cast(tf.shape(s, out_type=tf.int32)[-2], s.dtype)
+
+        # quant_{K-1}: the (K-1)/K-quantile of each assignment column, via top_k
+ idx = tf.cast(tf.math.floor(n_nodes / self.k) + 1, dtype=tf.int32)
+ med = tf.math.top_k(tf.linalg.matrix_transpose(s), k=idx).values[..., -1]
+        # Asymmetric l1-norm: positive deviations from the quantile weighted by (K-1)
+ if K.ndim(s) == 2:
+ loss = s - med
+ else:
+ loss = s - med[:, tf.newaxis, ...]
+ loss = (tf.cast(loss >= 0, loss.dtype) * (self.k - 1) * loss) + (
+ tf.cast(loss < 0, loss.dtype) * loss * -1.0
+ )
+ loss = tf.math.reduce_sum(loss, axis=(-2, -1))
+ loss = 1 / (n_nodes * (self.k - 1)) * (n_nodes * (self.k - 1) - loss)
+
+ return loss
+
+ def get_config(self):
+ config = {
+ "k": self.k,
+ "mlp_hidden": self.mlp_hidden,
+ "mlp_activation": self.mlp_activation,
+ "totvar_coeff": self.totvar_coeff,
+ "balance_coeff": self.balance_coeff,
+ }
+ base_config = super().get_config()
+ return {**base_config, **config}
diff --git a/tests/test_layers/convolutional/test_gtv_conv.py b/tests/test_layers/convolutional/test_gtv_conv.py
new file mode 100644
index 00000000..ef498348
--- /dev/null
+++ b/tests/test_layers/convolutional/test_gtv_conv.py
@@ -0,0 +1,21 @@
+from core import MODES, run_layer
+
+from spektral import layers
+
+config = {
+ "layer": layers.GTVConv,
+ "modes": [MODES["SINGLE"], MODES["BATCH"]],
+ "kwargs": {
+ "channels": 8,
+ "delta_coeff": 1.0,
+ "epsilon": 0.001,
+ "activation": "relu",
+ },
+ "dense": True,
+ "sparse": True,
+ "edges": False,
+}
+
+
+def test_layer():
+ run_layer(config)
diff --git a/tests/test_layers/pooling/test_asym_cheeger_cut_pool.py b/tests/test_layers/pooling/test_asym_cheeger_cut_pool.py
new file mode 100644
index 00000000..3baf3be5
--- /dev/null
+++ b/tests/test_layers/pooling/test_asym_cheeger_cut_pool.py
@@ -0,0 +1,20 @@
+from spektral import layers
+from tests.test_layers.pooling.core import MODES, run_layer
+
+config = {
+ "layer": layers.AsymCheegerCutPool,
+ "modes": [MODES["SINGLE"], MODES["BATCH"]],
+ "kwargs": {
+ "k": 5,
+ "return_selection": True,
+ "mlp_hidden": [32],
+ "totvar_coeff": 1.0,
+ "balance_coeff": 1.0,
+ },
+ "dense": True,
+ "sparse": True,
+}
+
+
+def test_layer():
+ run_layer(config)