# Libraries

In [18]:
import warnings

import numpy as np
from scipy.sparse import SparseEfficiencyWarning
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.layers import Dropout, Input
from tensorflow.keras.losses import CategoricalCrossentropy
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.regularizers import l2
from tensorflow.random import set_seed

from spektral.data.loaders import SingleLoader
from spektral.datasets.citation import Citation
from spektral.layers import GATConv
from spektral.transforms import LayerPreprocess

# Suppress SciPy sparse efficiency warnings
warnings.filterwarnings('ignore', category=RuntimeWarning, module='scipy.sparse')

# Seed

In [4]:
# Seed TensorFlow's global RNG so TF-side randomness (initializers, dropout)
# repeats across Restart & Run All. NumPy is not seeded here: the only NumPy
# work in this notebook (mask_to_weights) is deterministic.
SEED = 0
set_seed(SEED)

# Dataset

In [5]:
# Load the Cora citation graph. Node features are normalized, and
# LayerPreprocess applies GATConv's own `preprocess` step to the adjacency
# matrix so the graph is in the form the GAT layers expect.
dataset = Citation(
    "cora",
    normalize_x=True,
    transforms=[LayerPreprocess(GATConv)],
)

Pre-processing node features


  self._set_arrayXarray(i, j, x)


In [6]:
# Sanity check: confirm which dataset class was constructed.
print(dataset.__class__)

<class 'spektral.datasets.citation.Citation'>


# Prepare Sample Weights

In [7]:
def mask_to_weights(mask) -> np.ndarray:
    """Turn a node mask into per-node sample weights that sum to 1.

    Each selected node gets ``1 / n_selected`` and each unselected node gets 0.
    Combined with a loss using ``reduction="sum"``, the summed weighted loss is
    then the mean loss over the selected nodes.

    Args:
        mask: 1-D boolean (or 0/1) array-like with one entry per node.

    Returns:
        float32 array with the same shape as ``mask``. If the mask selects no
        nodes, an all-zero array is returned. The naive ``0 / 0`` division
        would give NaN weights (and a NaN loss).
    """
    mask = np.asarray(mask)
    n_selected = np.count_nonzero(mask)
    weights = mask.astype(np.float32)
    if n_selected == 0:
        # Nothing selected: keep the zeros rather than dividing by zero.
        return weights
    return weights / n_selected

In [8]:
# One sample-weight vector per split (training, validation, testing). Nodes
# outside a split get weight 0, so the same graph can be reused for all three.
weights_tr = mask_to_weights(dataset.mask_tr)
weights_va = mask_to_weights(dataset.mask_va)
weights_te = mask_to_weights(dataset.mask_te)

In [9]:
# All three weight vectors should have one entry per node.
print(" ".join(str(w.shape) for w in (weights_tr, weights_va, weights_te)))

(2708,) (2708,) (2708,)


In [10]:
# Peek at the first ten weights of each split (train, validation, test).
print(*(w[:10] for w in (weights_tr, weights_va, weights_te)), sep="\n")


[0.00714286 0.00714286 0.00714286 0.00714286 0.00714286 0.00714286
 0.00714286 0.00714286 0.00714286 0.00714286]
[0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]


# Parameters

In [11]:
# --- Architecture ---
channels = 8      # Output channels per attention head in the first GAT layer
n_attn_heads = 8  # Attention heads in the first GAT layer

# --- Regularization ---
dropout = 0.6     # Rate for the Dropout layers on node features; also passed
                  # to GATConv as its internal `dropout_rate` (see GATConv docs)
l2_reg = 2.5e-4   # L2 penalty on GAT kernels, attention kernels and biases

# --- Optimization ---
learning_rate = 5e-3  # Adam learning rate
epochs = 100          # Maximum number of training epochs
patience = 5          # Epochs without val_loss improvement before early stop

# Graph Dimensions

In [12]:
# Graph dimensions read off the loaded dataset.
N, F, n_out = (
    dataset.n_nodes,          # number of nodes in the graph
    dataset.n_node_features,  # length of each node's feature vector
    dataset.n_labels,         # number of target classes
)

In [13]:
# Report the graph dimensions, one per line.
print(
    f"Nodes: {N}\n"
    f"Features: {F}\n"
    f"Num of classes: {n_out}"
)

Nodes: 2708
Features: 1433
Num of classes: 7


# Model Definition

In [14]:
# Model definition: two GAT layers, each preceded by feature dropout.


class GATConvNoMask(GATConv):
    """GATConv that never forwards a Keras ``mask`` keyword to the parent layer.

    Works around errors raised when a ``None`` mask reaches Spektral's
    ``GATConv.call``. Any other keyword arguments are forwarded unchanged.
    """

    def call(self, inputs, **kwargs):
        kwargs.pop("mask", None)
        return super().call(inputs, **kwargs)


def _regularized_gat(units, **layer_kwargs):
    """Build a GATConvNoMask sharing this notebook's dropout and L2 settings.

    Both GAT layers below use identical dropout/regularization. Each call
    creates fresh ``l2`` regularizer objects, as the original inline code did.
    """
    return GATConvNoMask(
        units,
        dropout_rate=dropout,
        kernel_regularizer=l2(l2_reg),
        attn_kernel_regularizer=l2(l2_reg),
        bias_regularizer=l2(l2_reg),
        **layer_kwargs,
    )


# -- Inputs: F features per node, plus a sparse N x N adjacency matrix
x_in = Input(shape=(F,))
a_in = Input((N,), sparse=True)  # (N,) per row -> square N x N matrix

# -- Dropout + attention (part 1): multi-head, heads concatenated
do_1 = Dropout(dropout)(x_in)
gc_1 = _regularized_gat(
    channels,
    attn_heads=n_attn_heads,
    concat_heads=True,
    activation="elu",
)([do_1, a_in])

# -- Dropout + attention (part 2): single head, softmax over the n_out classes
do_2 = Dropout(dropout)(gc_1)
gc_2 = _regularized_gat(
    n_out,
    attn_heads=1,
    concat_heads=False,
    activation="softmax",
)([do_2, a_in])

# Build model
model = Model(inputs=[x_in, a_in], outputs=gc_2)

# Model Compilation

In [15]:
optimizer = Adam(learning_rate=learning_rate)
model.compile(
    # Accuracy is a *weighted* metric: nodes outside the current split have
    # sample weight 0 and therefore do not count.
    weighted_metrics=["acc"],
    # "sum" reduction: each sample-weight vector already divides by the number
    # of nodes in its split, so the weighted sum is the mean per-node loss.
    loss=CategoricalCrossentropy(reduction="sum"),
    optimizer=optimizer,
)
model.summary()

# Train Model

In [16]:
# Train model: optimize on the training mask, early-stop on the validation mask.
# Both loaders wrap the same single graph and differ only in their sample
# weights. The training log reports one step per epoch.
loader_tr = SingleLoader(dataset, sample_weights=weights_tr)
loader_va = SingleLoader(dataset, sample_weights=weights_va)

# Keep the History object: it holds the per-epoch loss/acc curves. Assigning
# it also stops the notebook from printing its bare repr.
history = model.fit(
    loader_tr.load(),
    steps_per_epoch=loader_tr.steps_per_epoch,
    validation_data=loader_va.load(),
    validation_steps=loader_va.steps_per_epoch,
    epochs=epochs,
    callbacks=[
        EarlyStopping(
            monitor="val_loss",  # Keras' default, spelled out for readers
            patience=patience,
            restore_best_weights=True,
        )
    ],
)

Epoch 1/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 1s/step - acc: 0.1714 - loss: 1.9496 - val_acc: 0.0720 - val_loss: 1.9478
Epoch 2/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 51ms/step - acc: 0.1786 - loss: 1.9468 - val_acc: 0.0740 - val_loss: 1.9474
Epoch 3/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 48ms/step - acc: 0.2357 - loss: 1.9456 - val_acc: 0.0740 - val_loss: 1.9467
Epoch 4/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 53ms/step - acc: 0.1857 - loss: 1.9422 - val_acc: 0.0760 - val_loss: 1.9454
Epoch 5/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 55ms/step - acc: 0.2000 - loss: 1.9408 - val_acc: 0.2000 - val_loss: 1.9433
Epoch 6/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 59ms/step - acc: 0.3571 - loss: 1.9382 - val_acc: 0.6120 - val_loss: 1.9406
Epoch 7/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 47ms/step - acc: 0.4500 - lo

<keras.src.callbacks.history.History at 0x30ad70190>

# Evaluate Model

In [17]:
print("Evaluating model.")

loader_te = SingleLoader(dataset, sample_weights=weights_te)
# return_dict=True keys results by metric name ("loss", "acc" -- the names in
# the training log), so the report does not depend on the order of the values.
eval_results = model.evaluate(
    loader_te.load(),
    steps=loader_te.steps_per_epoch,
    return_dict=True,
)

print(
    "Done.\n"
    f"Test loss: {eval_results['loss']}\n"
    f"Test accuracy: {eval_results['acc']}"
)

Evaluating model.
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 166ms/step - acc: 0.8220 - loss: 1.3019
Done.
Test loss: 1.301896095275879
Test accuracy: 0.8220002055168152
