Commit aa7866a

Authored Jun 1, 2023
Merge pull request #433 from JonasBergHansen/develop
Implementation of Total Variation Graph Neural Networks
2 parents 40d7541 + 2a51af9, commit aa7866a

File tree: 16 files changed, +636 -10 lines

‎.github/workflows/examples.yml

+2 -2

@@ -9,10 +9,10 @@ jobs:
 
     steps:
     - uses: actions/checkout@v2
-    - name: Set up Python 3.6
+    - name: Set up Python 3.11
       uses: actions/setup-python@v2
       with:
-        python-version: 3.6
+        python-version: 3.11
     - name: Install dependencies
       run: |
         pip install ogb matplotlib

‎.github/workflows/style_check.yml

+2 -2

@@ -9,10 +9,10 @@ jobs:
 
     steps:
     - uses: actions/checkout@v2
-    - name: Set up Python 3.6
+    - name: Set up Python 3.11
       uses: actions/setup-python@v2
       with:
-        python-version: 3.6
+        python-version: 3.11
     - name: Lint Python code
       run: |
         pip install flake8

‎.github/workflows/test.yml

+1 -1

@@ -8,7 +8,7 @@ jobs:
     runs-on: ${{ matrix.os }}
     strategy:
       matrix:
-        python-version: [3.7, 3.8, 3.9]
+        python-version: ["3.8", "3.9", "3.10", "3.11"]
         os: [ubuntu-latest, macos-latest, windows-latest]
 
     steps:

‎docs/autogen.py

+2

@@ -40,6 +40,7 @@
         layers.GCSConv,
         layers.GINConv,
         layers.GraphSageConv,
+        layers.GTVConv,
         layers.TAGConv,
         layers.XENetConv,
     ],
@@ -50,6 +51,7 @@
     "methods": [],
     "classes": [
         layers.SRCPool,
+        layers.AsymCheegerCutPool,
         layers.DiffPool,
         layers.LaPool,
         layers.MinCutPool,

‎examples/graph_prediction/qm9_ecc.py

+1

@@ -40,6 +40,7 @@
 loader_tr = DisjointLoader(dataset_tr, batch_size=batch_size, epochs=epochs)
 loader_te = DisjointLoader(dataset_te, batch_size=batch_size, epochs=1)
 
+
 ################################################################################
 # Build model
 ################################################################################

New file (+135 lines): node clustering example

@@ -0,0 +1,135 @@
"""
This example implements the node clustering experiment on citation networks
from the paper:

Total Variation Graph Neural Networks (https://arxiv.org/abs/2211.06218)
Jonas Berg Hansen and Filippo Maria Bianchi
"""

import numpy as np
import tensorflow as tf
from sklearn.metrics.cluster import (
    completeness_score,
    homogeneity_score,
    normalized_mutual_info_score,
)
from tensorflow.keras import Model
from tqdm import tqdm

from spektral.datasets import DBLP
from spektral.datasets.citation import Citation
from spektral.layers import AsymCheegerCutPool, GTVConv
from spektral.utils.sparse import sp_matrix_to_sp_tensor

tf.random.set_seed(1)

################################
# CONFIG/HYPERPARAMETERS
################################
dataset_id = "cora"
mp_channels = 512
mp_layers = 2
mp_activation = "elu"
delta_coeff = 0.311
epsilon = 1e-3
mlp_hidden_channels = 256
mlp_hidden_layers = 1
mlp_activation = "relu"
totvar_coeff = 0.785
balance_coeff = 0.514
learning_rate = 1e-3
epochs = 500

################################
# LOAD DATASET
################################
if dataset_id in ["cora", "citeseer", "pubmed"]:
    dataset = Citation(dataset_id, normalize_x=True)
elif dataset_id == "dblp":
    dataset = DBLP(normalize_x=True)
X = dataset.graphs[0].x
A = dataset.graphs[0].a
Y = dataset.graphs[0].y
y = np.argmax(Y, axis=-1)
n_clust = Y.shape[-1]


################################
# MODEL
################################
class ClusteringModel(Model):
    """
    Defines the general model structure
    """

    def __init__(self, aggr, pool):
        super().__init__()

        self.mp = aggr
        self.pool = pool

    def call(self, inputs):
        x, a = inputs

        out = x
        for _mp in self.mp:
            out = _mp([out, a])

        _, _, s_pool = self.pool([out, a])

        return s_pool


# Define the message-passing layers
MP_layers = [
    GTVConv(
        mp_channels, delta_coeff=delta_coeff, epsilon=1e-3, activation=mp_activation
    )
    for _ in range(mp_layers)
]

# Define the pooling layer
pool_layer = AsymCheegerCutPool(
    n_clust,
    mlp_hidden=[mlp_hidden_channels for _ in range(mlp_hidden_layers)],
    mlp_activation=mlp_activation,
    totvar_coeff=totvar_coeff,
    balance_coeff=balance_coeff,
    return_selection=True,
)

# Instantiate model and optimizer
model = ClusteringModel(aggr=MP_layers, pool=pool_layer)
opt = tf.keras.optimizers.Adam(learning_rate=learning_rate)


################################
# TRAINING
################################
@tf.function(input_signature=None)
def train_step(model, inputs):
    with tf.GradientTape() as tape:
        _ = model(inputs, training=True)
        loss = sum(model.losses)
    gradients = tape.gradient(loss, model.trainable_variables)
    opt.apply_gradients(zip(gradients, model.trainable_variables))
    return model.losses


A = sp_matrix_to_sp_tensor(A)
inputs = [X, A]
loss_history = []

# Training loop
for _ in tqdm(range(epochs)):
    outs = train_step(model, inputs)

################################
# INFERENCE/RESULTS
################################
S_ = model(inputs, training=False)
s_out = np.argmax(S_, axis=-1)
nmi = normalized_mutual_info_score(y, s_out)
hom = homogeneity_score(y, s_out)
com = completeness_score(y, s_out)
print("Homogeneity: {:.3f}; Completeness: {:.3f}; NMI: {:.3f}".format(hom, com, nmi))

‎spektral/layers/base.py

-2

@@ -118,7 +118,6 @@ def __init__(
         kernel_constraint=None,
         **kwargs
     ):
-
         super().__init__(**kwargs)
         self.trainable_kernel = trainable_kernel
         self.activation = activations.get(activation)
@@ -184,7 +183,6 @@ class MinkowskiProduct(Layer):
     """
 
     def __init__(self, activation=None, **kwargs):
-
         super().__init__(**kwargs)
         self.activation = activations.get(activation)
 
‎spektral/layers/convolutional/__init__.py

+1

@@ -14,6 +14,7 @@
 from .general_conv import GeneralConv
 from .gin_conv import GINConv
 from .graphsage_conv import GraphSageConv
+from .gtv_conv import GTVConv
 from .message_passing import MessagePassing
 from .tag_conv import TAGConv
 from .xenet_conv import XENetConv, XENetDenseConv

spektral/layers/convolutional/gtv_conv.py (new file, +212 lines)

@@ -0,0 +1,212 @@
import tensorflow as tf
from tensorflow.keras import backend as K

from spektral.layers import ops
from spektral.layers.convolutional.conv import Conv


class GTVConv(Conv):
    r"""
    A graph total variation convolutional layer (GTVConv) from the paper

    > [Total Variation Graph Neural Networks](https://arxiv.org/abs/2211.06218)<br>
    > Jonas Berg Hansen and Filippo Maria Bianchi

    **Mode**: single, disjoint, batch.

    This layer computes
    $$
        \X' = \sigma\left[\left(\I - \delta\L_\hat{\mathbf{\Gamma}}\right) \tilde{\X} \right]
    $$
    where
    $$
    \begin{align}
        \tilde{\X} &= \X \W\\[5pt]
        \L_\hat{\mathbf{\Gamma}} &= \D_\mathbf{\hat{\Gamma}} - \hat{\mathbf{\Gamma}}\\[5pt]
        [\hat{\mathbf{\Gamma}}]_{ij} &= \frac{[\mathbf{A}]_{ij}}{\max\{||\tilde{\x}_i-\tilde{\x}_j||_1, \epsilon\}}\\
    \end{align}
    $$

    **Input**

    - Node features of shape `(batch, n_nodes, n_node_features)`;
    - Adjacency matrix of shape `(batch, n_nodes, n_nodes)`;

    **Output**

    - Node features with the same shape as the input, but with the last
    dimension changed to `channels`.

    **Arguments**

    - `channels`: number of output channels;
    - `delta_coeff`: step size for gradient descent of GTV;
    - `epsilon`: small number used to numerically stabilize the computation of new adjacency weights;
    - `activation`: activation function;
    - `use_bias`: bool, add a bias vector to the output;
    - `kernel_initializer`: initializer for the weights;
    - `bias_initializer`: initializer for the bias vector;
    - `kernel_regularizer`: regularization applied to the weights;
    - `bias_regularizer`: regularization applied to the bias vector;
    - `activity_regularizer`: regularization applied to the output;
    - `kernel_constraint`: constraint applied to the weights;
    - `bias_constraint`: constraint applied to the bias vector.

    """

    def __init__(
        self,
        channels,
        delta_coeff=1.0,
        epsilon=0.001,
        activation=None,
        use_bias=True,
        kernel_initializer="glorot_uniform",
        bias_initializer="zeros",
        kernel_regularizer=None,
        bias_regularizer=None,
        activity_regularizer=None,
        kernel_constraint=None,
        bias_constraint=None,
        **kwargs
    ):
        super().__init__(
            activation=activation,
            use_bias=use_bias,
            kernel_initializer=kernel_initializer,
            bias_initializer=bias_initializer,
            kernel_regularizer=kernel_regularizer,
            bias_regularizer=bias_regularizer,
            activity_regularizer=activity_regularizer,
            kernel_constraint=kernel_constraint,
            bias_constraint=bias_constraint,
            **kwargs
        )

        self.channels = channels
        self.delta_coeff = delta_coeff
        self.epsilon = epsilon

    def build(self, input_shape):
        assert len(input_shape) >= 2
        input_dim = input_shape[0][-1]
        self.kernel = self.add_weight(
            shape=(input_dim, self.channels),
            initializer=self.kernel_initializer,
            name="kernel",
            regularizer=self.kernel_regularizer,
            constraint=self.kernel_constraint,
        )
        if self.use_bias:
            self.bias = self.add_weight(
                shape=(self.channels,),
                initializer=self.bias_initializer,
                name="bias",
                regularizer=self.bias_regularizer,
                constraint=self.bias_constraint,
            )
        self.built = True

    def call(self, inputs, mask=None):
        x, a = inputs

        mode = ops.autodetect_mode(x, a)

        x = K.dot(x, self.kernel)

        if mode == ops.modes.SINGLE:
            output = self._call_single(x, a)

        elif mode == ops.modes.BATCH:
            output = self._call_batch(x, a)

        if self.use_bias:
            output = K.bias_add(output, self.bias)

        if mask is not None:
            output *= mask[0]

        output = self.activation(output)

        return output

    def _call_single(self, x, a):
        if K.is_sparse(a):
            index_i = a.indices[:, 0]
            index_j = a.indices[:, 1]

            n_nodes = tf.shape(a, out_type=index_i.dtype)[0]

            # Compute absolute differences between neighbouring nodes
            abs_diff = tf.math.abs(
                tf.transpose(tf.gather(x, index_i))
                - tf.transpose(tf.gather(x, index_j))
            )
            abs_diff = tf.math.reduce_sum(abs_diff, axis=0)

            # Compute new adjacency matrix
            gamma = tf.sparse.map_values(
                tf.multiply, a, 1 / tf.math.maximum(abs_diff, self.epsilon)
            )

            # Compute degree matrix from gamma matrix
            d_gamma = tf.sparse.SparseTensor(
                tf.stack([tf.range(n_nodes)] * 2, axis=1),
                tf.sparse.reduce_sum(gamma, axis=-1),
                [n_nodes, n_nodes],
            )

            # Compute Laplacian: L = D_gamma - Gamma
            l = tf.sparse.add(d_gamma, tf.sparse.map_values(tf.multiply, gamma, -1.0))

            # Compute adjusted Laplacian: L_adjusted = I - delta*L
            l = tf.sparse.add(
                tf.sparse.eye(n_nodes, dtype=x.dtype),
                tf.sparse.map_values(tf.multiply, l, -self.delta_coeff),
            )

            # Aggregate features with adjusted Laplacian
            output = ops.modal_dot(l, x)

        else:
            n_nodes = tf.shape(a)[-1]

            abs_diff = tf.math.abs(x[:, tf.newaxis, :] - x)
            abs_diff = tf.reduce_sum(abs_diff, axis=-1)

            gamma = a / tf.math.maximum(abs_diff, self.epsilon)

            degrees = tf.math.reduce_sum(gamma, axis=-1)
            l = -gamma
            l = tf.linalg.set_diag(l, degrees - tf.linalg.diag_part(gamma))
            l = tf.eye(n_nodes, dtype=x.dtype) - self.delta_coeff * l

            output = tf.matmul(l, x)

        return output

    def _call_batch(self, x, a):
        n_nodes = tf.shape(a)[-1]

        abs_diff = tf.reduce_sum(
            tf.math.abs(tf.expand_dims(x, 2) - tf.expand_dims(x, 1)), axis=-1
        )

        gamma = a / tf.math.maximum(abs_diff, self.epsilon)

        degrees = tf.math.reduce_sum(gamma, axis=-1)
        l = -gamma
        l = tf.linalg.set_diag(l, degrees - tf.linalg.diag_part(gamma))
        l = tf.eye(n_nodes, dtype=x.dtype) - self.delta_coeff * l

        output = tf.matmul(l, x)

        return output

    @property
    def config(self):
        return {
            "channels": self.channels,
            "delta_coeff": self.delta_coeff,
            "epsilon": self.epsilon,
        }
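
For orientation, here is a minimal single-mode usage sketch of the new GTVConv layer. The toy graph and hyperparameter values are illustrative assumptions, not taken from this commit; the layer call itself follows the docstring above.

import numpy as np
import tensorflow as tf
from scipy.sparse import csr_matrix

from spektral.layers import GTVConv
from spektral.utils.sparse import sp_matrix_to_sp_tensor

# Toy single-mode graph: 4 nodes with 3 features and a symmetric adjacency matrix
x = tf.convert_to_tensor(np.random.rand(4, 3).astype("f"))
a = sp_matrix_to_sp_tensor(
    csr_matrix(
        np.array(
            [[0, 1, 0, 1],
             [1, 0, 1, 0],
             [0, 1, 0, 1],
             [1, 0, 1, 0]],
            dtype="f",
        )
    )
)

# Linear transform followed by propagation with the adjusted Laplacian
# I - delta * L_Gamma_hat, as described in the docstring above
conv = GTVConv(8, delta_coeff=0.311, epsilon=1e-3, activation="elu")
out = conv([x, a])
print(out.shape)  # (4, 8)

The same call works in disjoint mode on a block-diagonal sparse adjacency, while batch mode expects dense rank-3 node features and adjacencies.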

‎spektral/layers/pooling/__init__.py

+1

@@ -1,3 +1,4 @@
+from .asym_cheeger_cut_pool import AsymCheegerCutPool
 from .diff_pool import DiffPool
 from .dmon_pool import DMoNPool
 from .global_pool import (

spektral/layers/pooling/asym_cheeger_cut_pool.py (new file, +237 lines)

@@ -0,0 +1,237 @@
import tensorflow as tf
import tensorflow.keras.backend as K
from tensorflow.keras import Sequential
from tensorflow.keras.layers import Dense

from spektral.layers import ops
from spektral.layers.pooling.src import SRCPool


class AsymCheegerCutPool(SRCPool):
    r"""
    An Asymmetric Cheeger Cut Pooling layer from the paper
    > [Total Variation Graph Neural Networks](https://arxiv.org/abs/2211.06218)<br>
    > Jonas Berg Hansen and Filippo Maria Bianchi

    **Mode**: single, batch.

    This layer learns a soft clustering of the input graph as follows:
    $$
    \begin{align}
        \S &= \textrm{MLP}(\X); \\
        \X' &= \S^\top \X; \\
        \A' &= \S^\top \A \S; \\
    \end{align}
    $$
    where \(\textrm{MLP}\) is a multi-layer perceptron with softmax output.

    The layer includes two auxiliary loss terms/components:
    a graph total variation component given by
    $$
        L_\text{GTV} = \frac{1}{2E} \sum_{k=1}^K \sum_{i=1}^N \sum_{j=i}^N a_{i,j} |s_{i,k} - s_{j,k}|,
    $$
    where \(E\) is the number of edges/links, \(K\) is the number of clusters or output nodes, and \(N\) is the number of nodes,

    and an asymmetrical norm component given by
    $$
        L_\text{AN} = \frac{N(K - 1) - \sum_{k=1}^K ||\s_{:,k} - \textrm{quant}_{K-1} (\s_{:,k})||_{1, K-1}}{N(K-1)}.
    $$

    The layer can be used without a supervised loss to compute node clustering by
    minimizing the two auxiliary losses.

    **Input**

    - Node features of shape `(batch, n_nodes_in, n_node_features)`;
    - Adjacency matrix of shape `(batch, n_nodes_in, n_nodes_in)`;

    **Output**

    - Reduced node features of shape `(batch, n_nodes_out, n_node_features)`;
    - If `return_selection=True`, the selection matrix of shape
    `(batch, n_nodes_in, n_nodes_out)`.

    **Arguments**

    - `k`: number of output nodes;
    - `mlp_hidden`: list of integers, number of hidden units for each hidden layer in
    the MLP used to compute cluster assignments (if `None`, the MLP has only one output
    layer);
    - `mlp_activation`: activation for the MLP layers;
    - `totvar_coeff`: coefficient for the graph total variation loss component;
    - `balance_coeff`: coefficient for the asymmetric norm loss component;
    - `return_selection`: boolean, whether to return the selection matrix;
    - `use_bias`: use bias in the MLP;
    - `kernel_initializer`: initializer for the weights of the MLP;
    - `bias_regularizer`: regularization applied to the bias of the MLP;
    - `kernel_constraint`: constraint applied to the weights of the MLP;
    - `bias_constraint`: constraint applied to the bias of the MLP;
    """

    def __init__(
        self,
        k,
        mlp_hidden=None,
        mlp_activation="relu",
        totvar_coeff=1.0,
        balance_coeff=1.0,
        return_selection=False,
        use_bias=True,
        kernel_initializer="glorot_uniform",
        bias_initializer="zeros",
        kernel_regularizer=None,
        bias_regularizer=None,
        kernel_constraint=None,
        bias_constraint=None,
        **kwargs
    ):
        super().__init__(
            k=k,
            mlp_hidden=mlp_hidden,
            mlp_activation=mlp_activation,
            return_selection=return_selection,
            use_bias=use_bias,
            kernel_initializer=kernel_initializer,
            bias_initializer=bias_initializer,
            kernel_regularizer=kernel_regularizer,
            bias_regularizer=bias_regularizer,
            kernel_constraint=kernel_constraint,
            bias_constraint=bias_constraint,
            **kwargs
        )

        self.k = k
        self.mlp_hidden = mlp_hidden if mlp_hidden else []
        self.mlp_activation = mlp_activation
        self.totvar_coeff = totvar_coeff
        self.balance_coeff = balance_coeff

    def build(self, input_shape):
        layer_kwargs = dict(
            kernel_initializer=self.kernel_initializer,
            bias_initializer=self.bias_initializer,
            kernel_regularizer=self.kernel_regularizer,
            bias_regularizer=self.bias_regularizer,
            kernel_constraint=self.kernel_constraint,
            bias_constraint=self.bias_constraint,
        )
        self.mlp = Sequential(
            [
                Dense(channels, self.mlp_activation, **layer_kwargs)
                for channels in self.mlp_hidden
            ]
            + [Dense(self.k, "softmax", **layer_kwargs)]
        )

        super().build(input_shape)

    def call(self, inputs, mask=None):
        x, a, i = self.get_inputs(inputs)
        return self.pool(x, a, i, mask=mask)

    def select(self, x, a, i, mask=None):
        s = self.mlp(x)
        if mask is not None:
            s *= mask[0]

        # Total variation loss
        tv_loss = self.totvar_loss(a, s)
        if K.ndim(a) == 3:
            tv_loss = K.mean(tv_loss)
        self.add_loss(self.totvar_coeff * tv_loss)

        # Asymmetric l1-norm loss
        bal_loss = self.balance_loss(s)
        if K.ndim(a) == 3:
            bal_loss = K.mean(bal_loss)
        self.add_loss(self.balance_coeff * bal_loss)

        return s

    def reduce(self, x, s, **kwargs):
        return ops.modal_dot(s, x, transpose_a=True)

    def connect(self, a, s, **kwargs):
        a_pool = ops.matmul_at_b_a(s, a)

        return a_pool

    def reduce_index(self, i, s, **kwargs):
        i_mean = tf.math.segment_mean(i, i)
        i_pool = ops.repeat(i_mean, tf.ones_like(i_mean) * self.k)

        return i_pool

    def totvar_loss(self, a, s):
        if K.is_sparse(a):
            index_i = a.indices[:, 0]
            index_j = a.indices[:, 1]

            n_edges = tf.cast(len(a.values), dtype=s.dtype)

            loss = tf.math.reduce_sum(
                a.values[:, tf.newaxis]
                * tf.math.abs(tf.gather(s, index_i) - tf.gather(s, index_j)),
                axis=(-2, -1),
            )

        else:
            n_edges = tf.cast(tf.math.count_nonzero(a, axis=(-2, -1)), dtype=s.dtype)
            n_nodes = tf.shape(a)[-1]
            if K.ndim(a) == 3:
                loss = tf.math.reduce_sum(
                    a
                    * tf.math.reduce_sum(
                        tf.math.abs(
                            s[:, tf.newaxis, ...]
                            - tf.repeat(s[..., tf.newaxis, :], n_nodes, axis=-2)
                        ),
                        axis=-1,
                    ),
                    axis=(-2, -1),
                )
            else:
                loss = tf.math.reduce_sum(
                    a
                    * tf.math.reduce_sum(
                        tf.math.abs(
                            s - tf.repeat(s[..., tf.newaxis, :], n_nodes, axis=-2)
                        ),
                        axis=-1,
                    ),
                    axis=(-2, -1),
                )

        loss *= 1 / (2 * n_edges)

        return loss

    def balance_loss(self, s):
        n_nodes = tf.cast(tf.shape(s, out_type=tf.int32)[-2], s.dtype)

        # k-quantile
        idx = tf.cast(tf.math.floor(n_nodes / self.k) + 1, dtype=tf.int32)
        med = tf.math.top_k(tf.linalg.matrix_transpose(s), k=idx).values[..., -1]
        # Asymmetric l1-norm
        if K.ndim(s) == 2:
            loss = s - med
        else:
            loss = s - med[:, tf.newaxis, ...]
        loss = (tf.cast(loss >= 0, loss.dtype) * (self.k - 1) * loss) + (
            tf.cast(loss < 0, loss.dtype) * loss * -1.0
        )
        loss = tf.math.reduce_sum(loss, axis=(-2, -1))
        loss = 1 / (n_nodes * (self.k - 1)) * (n_nodes * (self.k - 1) - loss)

        return loss

    def get_config(self):
        config = {
            "k": self.k,
            "mlp_hidden": self.mlp_hidden,
            "mlp_activation": self.mlp_activation,
            "totvar_coeff": self.totvar_coeff,
            "balance_coeff": self.balance_coeff,
        }
        base_config = super().get_config()
        return {**base_config, **config}
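
Similarly, a minimal single-mode sketch of AsymCheegerCutPool; the chain graph and settings are illustrative assumptions. With `return_selection=True` and two inputs, the layer returns the pooled features, the pooled adjacency, and the selection matrix, as used in the clustering example above.

import numpy as np
import tensorflow as tf
from scipy.sparse import csr_matrix

from spektral.layers import AsymCheegerCutPool
from spektral.utils.sparse import sp_matrix_to_sp_tensor

# Toy single-mode graph: a chain of 6 nodes with 4 features each
x = tf.convert_to_tensor(np.random.rand(6, 4).astype("f"))
row = [0, 1, 1, 2, 2, 3, 3, 4, 4, 5]
col = [1, 0, 2, 1, 3, 2, 4, 3, 5, 4]
a = sp_matrix_to_sp_tensor(
    csr_matrix((np.ones(len(row), dtype="f"), (row, col)), shape=(6, 6))
)

pool = AsymCheegerCutPool(
    k=2,                  # number of clusters / output nodes
    mlp_hidden=[16],
    totvar_coeff=1.0,
    balance_coeff=1.0,
    return_selection=True,
)

x_pool, a_pool, s = pool([x, a])
print(x_pool.shape, s.shape)  # (2, 4) (6, 2)

# The total variation and balance terms are registered as layer losses,
# so they can be minimized without any supervised target
print(pool.losses)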

‎spektral/layers/pooling/dmon_pool.py

-1

@@ -162,7 +162,6 @@ def reduce_index(self, i, s, **kwargs):
         return i_pool
 
     def modularity_loss(self, a, s, a_pool):
-
         if K.is_sparse(a):
             n_edges = tf.cast(len(a.values), dtype=s.dtype)
 

‎spektral/layers/pooling/global_pool.py

-1

@@ -8,7 +8,6 @@
 
 class GlobalPool(Layer):
     def __init__(self, **kwargs):
-
         super().__init__(**kwargs)
         self.supports_masking = True
         self.pooling_op = None

New test for GTVConv (new file, +21 lines)

@@ -0,0 +1,21 @@
from core import MODES, run_layer

from spektral import layers

config = {
    "layer": layers.GTVConv,
    "modes": [MODES["SINGLE"], MODES["BATCH"]],
    "kwargs": {
        "channels": 8,
        "delta_coeff": 1.0,
        "epsilon": 0.001,
        "activation": "relu",
    },
    "dense": True,
    "sparse": True,
    "edges": False,
}


def test_layer():
    run_layer(config)

New test for AsymCheegerCutPool (new file, +20 lines)

@@ -0,0 +1,20 @@
from spektral import layers
from tests.test_layers.pooling.core import MODES, run_layer

config = {
    "layer": layers.AsymCheegerCutPool,
    "modes": [MODES["SINGLE"], MODES["BATCH"]],
    "kwargs": {
        "k": 5,
        "return_selection": True,
        "mlp_hidden": [32],
        "totvar_coeff": 1.0,
        "balance_coeff": 1.0,
    },
    "dense": True,
    "sparse": True,
}


def test_layer():
    run_layer(config)

‎tests/test_utils/test_convolution.py

+1 -1

@@ -6,7 +6,7 @@
 from spektral.utils import convolution
 
 g = nx.generators.erdos_renyi_graph(10, 0.2)
-adj_sp = nx.adj_matrix(g).astype("f")
+adj_sp = nx.adjacency_matrix(g).astype("f")
 adj = adj_sp.A.astype("f")
 degree = np.diag([d[1] for d in nx.degree(g)])
 tol = 1e-6

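For context: `nx.adj_matrix` was a deprecated alias that has been removed in recent NetworkX releases, which is presumably why the test now calls `nx.adjacency_matrix`. A quick standalone check of the replacement call (illustrative only):

import networkx as nx

g = nx.erdos_renyi_graph(10, 0.2, seed=0)
# adjacency_matrix returns a SciPy sparse adjacency ordered by g.nodes()
a = nx.adjacency_matrix(g).astype("f")
print(a.shape)  # (10, 10)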