# Spektral playground

In [1]:
from sklearn.model_selection import train_test_split
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.layers import Input, Dense
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.regularizers import l2

from spektral.datasets import delaunay
from spektral.layers import GraphAttention, GlobalAttentionPool

In [2]:
"""
This example shows how to perform graph classification with a synthetic dataset
of Delaunay triangulations, using a graph attention network in batch mode.
"""

# Load data
A, X, y = delaunay.generate_data(return_type='numpy', classes=[0, 5])

In [6]:
print(A.shape)
A[0,:,:]

(2000, 7, 7)


array([[0., 1., 0., 1., 1., 1., 0.],
       [1., 0., 1., 0., 0., 1., 0.],
       [0., 1., 0., 0., 0., 1., 1.],
       [1., 0., 0., 0., 1., 1., 0.],
       [1., 0., 0., 1., 0., 1., 1.],
       [1., 1., 1., 1., 1., 0., 1.],
       [0., 0., 1., 0., 1., 1., 0.]])

In [8]:
print(X.shape)
X[0,:,:]

(2000, 7, 2)


array([[ 1.02077013,  0.97343257],
       [10.79575071,  4.55267798],
       [ 9.44250279,  6.01343722],
       [ 1.28170365,  3.20603835],
       [-0.45947236,  6.55401358],
       [ 5.86487842,  4.93557475],
       [ 7.10837403,  8.38520455]])

In [10]:
print(y.shape)
y[0,:]

(2000, 2)


array([1., 0.])

In [11]:
# Parameters
N = X.shape[-2]          # Number of nodes in the graphs            |   7
F = X.shape[-1]          # Original feature dimensionality          |   7
n_classes = y.shape[-1]  # Number of classes                        |   2

l2_reg = 5e-4            # Regularization rate for l2
learning_rate = 1e-3     # Learning rate for Adam
epochs = 20              # Number of training epochs
batch_size = 32          # Batch size
es_patience = 200        # Patience fot early stopping

In [12]:
# Train/test split
A_train, A_test, \
x_train, x_test, \
y_train, y_test = train_test_split(A, X, y, test_size=0.1)


In [5]:
# Model definition
X_in = Input(shape=(N, F))
A_in = Input((N, N))

gc1 = GraphAttention(32, activation='relu', kernel_regularizer=l2(l2_reg))([X_in, A_in])
gc2 = GraphAttention(32, activation='relu', kernel_regularizer=l2(l2_reg))([gc1, A_in])
pool = GlobalAttentionPool(128)(gc2)

output = Dense(n_classes, activation='softmax')(pool)

In [6]:
# Build model
model = Model(inputs=[X_in, A_in], outputs=output)
optimizer = Adam(lr=learning_rate)
model.compile(optimizer=optimizer, loss='categorical_crossentropy', metrics=['acc'])
model.summary()

Model: "functional_1"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            [(None, 7, 2)]       0                                            
__________________________________________________________________________________________________
input_2 (InputLayer)            [(None, 7, 7)]       0                                            
__________________________________________________________________________________________________
graph_attention (GraphAttention (None, 7, 32)        160         input_1[0][0]                    
                                                                 input_2[0][0]                    
__________________________________________________________________________________________________
graph_attention_1 (GraphAttenti (None, 7, 32)        1120        graph_attention[0][0] 

In [15]:
# Train model
model.fit([x_train, A_train],
          y_train,
          batch_size=batch_size,
          validation_split=0.1,
          epochs=epochs,
          callbacks=[
              EarlyStopping(patience=es_patience, restore_best_weights=True)
          ])

Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20


<tensorflow.python.keras.callbacks.History at 0x24d145a6cc8>

In [16]:
# Evaluate model
print('Evaluating model.')
eval_results = model.evaluate([x_test, A_test],
                              y_test,
                              batch_size=batch_size)
print('Done. Test loss: {:.4f}. Test acc: {:.2f}'.format(*eval_results))

Evaluating model.
Done. Test loss: 0.7106. Test acc: 0.67
