Skip to content

Commit

Permalink
Merge branch 'develop'
Browse files Browse the repository at this point in the history
  • Loading branch information
danielegrattarola committed Jun 1, 2023
2 parents 39fe897 + aa7866a commit 2d0e0cf
Show file tree
Hide file tree
Showing 11 changed files with 634 additions and 5 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/examples.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@ jobs:

steps:
- uses: actions/checkout@v2
- name: Set up Python 3.7
- name: Set up Python 3.11
uses: actions/setup-python@v2
with:
python-version: 3.7
python-version: 3.11
- name: Install dependencies
run: |
pip install ogb matplotlib
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/style_check.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@ jobs:

steps:
- uses: actions/checkout@v2
- name: Set up Python 3.7
- name: Set up Python 3.11
uses: actions/setup-python@v2
with:
python-version: 3.7
python-version: 3.11
- name: Lint Python code
run: |
pip install flake8
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ jobs:
runs-on: ${{ matrix.os }}
strategy:
matrix:
python-version: [3.7, 3.8, 3.9]
python-version: ["3.8", "3.9", "3.10", "3.11"]
os: [ubuntu-latest, macos-latest, windows-latest]

steps:
Expand Down
2 changes: 2 additions & 0 deletions docs/autogen.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
layers.GCSConv,
layers.GINConv,
layers.GraphSageConv,
layers.GTVConv,
layers.TAGConv,
layers.XENetConv,
layers.GINConvBatch,
Expand All @@ -52,6 +53,7 @@
"methods": [],
"classes": [
layers.SRCPool,
layers.AsymCheegerCutPool,
layers.DiffPool,
layers.LaPool,
layers.MinCutPool,
Expand Down
135 changes: 135 additions & 0 deletions examples/other/node_clustering_tvgnn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
"""
This example implements the node clustering experiment on citation networks
from the paper:
Total Variation Graph Neural Networks (https://arxiv.org/abs/2211.06218)
Jonas Berg Hansen and Filippo Maria Bianchi
"""

import numpy as np
import tensorflow as tf
from sklearn.metrics.cluster import (
completeness_score,
homogeneity_score,
normalized_mutual_info_score,
)
from tensorflow.keras import Model
from tqdm import tqdm

from spektral.datasets import DBLP
from spektral.datasets.citation import Citation
from spektral.layers import AsymCheegerCutPool, GTVConv
from spektral.utils.sparse import sp_matrix_to_sp_tensor

tf.random.set_seed(1)

################################
# CONFIG/HYPERPARAMETERS
################################
dataset_id = "cora"
mp_channels = 512
mp_layers = 2
mp_activation = "elu"
delta_coeff = 0.311
epsilon = 1e-3
mlp_hidden_channels = 256
mlp_hidden_layers = 1
mlp_activation = "relu"
totvar_coeff = 0.785
balance_coeff = 0.514
learning_rate = 1e-3
epochs = 500

################################
# LOAD DATASET
################################
if dataset_id in ["cora", "citeseer", "pubmed"]:
dataset = Citation(dataset_id, normalize_x=True)
elif dataset_id == "dblp":
dataset = DBLP(normalize_x=True)
X = dataset.graphs[0].x
A = dataset.graphs[0].a
Y = dataset.graphs[0].y
y = np.argmax(Y, axis=-1)
n_clust = Y.shape[-1]


################################
# MODEL
################################
class ClusteringModel(Model):
"""
Defines the general model structure
"""

def __init__(self, aggr, pool):
super().__init__()

self.mp = aggr
self.pool = pool

def call(self, inputs):
x, a = inputs

out = x
for _mp in self.mp:
out = _mp([out, a])

_, _, s_pool = self.pool([out, a])

return s_pool


# Define the message-passing layers
MP_layers = [
GTVConv(
mp_channels, delta_coeff=delta_coeff, epsilon=1e-3, activation=mp_activation
)
for _ in range(mp_layers)
]

# Define the pooling layer
pool_layer = AsymCheegerCutPool(
n_clust,
mlp_hidden=[mlp_hidden_channels for _ in range(mlp_hidden_layers)],
mlp_activation=mlp_activation,
totvar_coeff=totvar_coeff,
balance_coeff=balance_coeff,
return_selection=True,
)

# Instantiate model and optimizer
model = ClusteringModel(aggr=MP_layers, pool=pool_layer)
opt = tf.keras.optimizers.Adam(learning_rate=learning_rate)


################################
# TRAINING
################################
@tf.function(input_signature=None)
def train_step(model, inputs):
with tf.GradientTape() as tape:
_ = model(inputs, training=True)
loss = sum(model.losses)
gradients = tape.gradient(loss, model.trainable_variables)
opt.apply_gradients(zip(gradients, model.trainable_variables))
return model.losses


A = sp_matrix_to_sp_tensor(A)
inputs = [X, A]
loss_history = []

# Training loop
for _ in tqdm(range(epochs)):
outs = train_step(model, inputs)

################################
# INFERENCE/RESULTS
################################
S_ = model(inputs, training=False)
s_out = np.argmax(S_, axis=-1)
nmi = normalized_mutual_info_score(y, s_out)
hom = homogeneity_score(y, s_out)
com = completeness_score(y, s_out)
print("Homogeneity: {:.3f}; Completeness: {:.3f}; NMI: {:.3f}".format(hom, com, nmi))
1 change: 1 addition & 0 deletions spektral/layers/convolutional/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from .general_conv import GeneralConv
from .gin_conv import GINConv, GINConvBatch
from .graphsage_conv import GraphSageConv
from .gtv_conv import GTVConv
from .message_passing import MessagePassing
from .tag_conv import TAGConv
from .xenet_conv import XENetConv, XENetConvBatch

0 comments on commit 2d0e0cf

Please sign in to comment.