In [30]:
from spektral.datasets import citation
from spektral.layers import GraphConv # for Graph Convolutional Neural Network
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Dropout
import pandas as pd

In [2]:
A, X, y, train_mask, val_mask, test_mask = citation.load_data('cora')
# A: adjacency matrix == shape(N,N)
# X: node features == shape(N,F)
# y: labels == shape(N,n_classes)
# train_mask, val_mask, test_mask: boolean, shows which nodes belong to training, validation, testing sets

Loading cora dataset
Pre-processing node features


In [3]:
N = A.shape[0]
F = X.shape[-1]
n_classes = y.shape[-1]

In [4]:
# Create GNN

X_in = Input(shape=(F,))
A_in = Input(shape=(N,), sparse=True)

X_1 = GraphConv(16, 'relu')([X_in, A_in])
X_1 = Dropout(.5)(X_1)
X_2 = GraphConv(n_classes, 'softmax')([X_1, A_in])

model = Model(inputs=[X_in, A_in], outputs=X_2)

An important thing to notice at this point is how we defined the Input layers of our model. Because the "elements" of our dataset are the nodes themselves, we are telling Keras to consider each node as a separate sample so that the batch axis is implicitly defined as None.
In other words, a sample of the node attributes will be a vector of shape (F, ) and a sample of the adjacency matrix will be one row of shape (N, ).

# Training GCN

When training GCN, we have to pre-process the adjacency matrix to 1) add self-loops and 2) scale the weights of a node's connections according to its degree.

Some layers in Spektral require a different type of pre-processing in order to work correctly, and some work out-of-the-box on the binary A. The pre-processing required by each layer is available as a static class method preprocess().

In [5]:
A = GraphConv.preprocess(A).astype('f4')

In [6]:
model.compile(optimizer='adam',
             loss='categorical_crossentropy',
             weighted_metrics=['acc']) # weighted metrics instead of metrics due to particular semi-supervised problem
model.summary()

Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            [(None, 1433)]       0                                            
__________________________________________________________________________________________________
input_2 (InputLayer)            [(None, 2708)]       0                                            
__________________________________________________________________________________________________
graph_conv (GraphConv)          (None, 16)           22944       input_1[0][0]                    
                                                                 input_2[0][0]                    
__________________________________________________________________________________________________
dropout (Dropout)               (None, 16)           0           graph_conv[0][0]             

In [7]:
# prep data
X = X.toarray()
A = A.astype('f4')
validation_data = ([X, A], y, val_mask)

X, A

(array([[0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        ...,
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.]], dtype=float32),
 <2708x2708 sparse matrix of type '<class 'numpy.float32'>'
 	with 13264 stored elements in Compressed Sparse Row format>)

In [8]:
X.shape, A.shape, validation_data

((2708, 1433),
 (2708, 2708),
 ([array([[0., 0., 0., ..., 0., 0., 0.],
          [0., 0., 0., ..., 0., 0., 0.],
          [0., 0., 0., ..., 0., 0., 0.],
          ...,
          [0., 0., 0., ..., 0., 0., 0.],
          [0., 0., 0., ..., 0., 0., 0.],
          [0., 0., 0., ..., 0., 0., 0.]], dtype=float32),
   <2708x2708 sparse matrix of type '<class 'numpy.float32'>'
   	with 13264 stored elements in Compressed Sparse Row format>],
  array([[0, 0, 0, ..., 0, 0, 0],
         [0, 0, 0, ..., 1, 0, 0],
         [0, 0, 0, ..., 1, 0, 0],
         ...,
         [0, 0, 0, ..., 0, 0, 0],
         [0, 0, 0, ..., 0, 0, 0],
         [0, 0, 0, ..., 0, 0, 0]], dtype=int32),
  array([False, False, False, ..., False, False, False])))

We have set batch_size=N and shuffle=False. This is because the default behaviour of Keras is to split the data into batches of 32 and shuffle the samples at each epoch. However, shuffling the adjacency matrix along one axis and not the other means that row i will represent a different node than column i.
At the same time, if we split the graph into batches we may end up in a situation where we need to use a node attribute that is not part of the batch. The only solution is to take all the node features at the same time, hence batch_size=N.

In [22]:
# train model
model.fit([X, A], y,
         sample_weight=train_mask,
         validation_data=validation_data,
         batch_size=N,
         shuffle=False,
         epochs=50)

Epoch 1/50
Epoch 2/50
Epoch 3/50
Epoch 4/50
Epoch 5/50
Epoch 6/50
Epoch 7/50
Epoch 8/50
Epoch 9/50
Epoch 10/50
Epoch 11/50
Epoch 12/50
Epoch 13/50
Epoch 14/50
Epoch 15/50
Epoch 16/50
Epoch 17/50
Epoch 18/50
Epoch 19/50
Epoch 20/50
Epoch 21/50
Epoch 22/50
Epoch 23/50
Epoch 24/50
Epoch 25/50
Epoch 26/50
Epoch 27/50
Epoch 28/50
Epoch 29/50
Epoch 30/50
Epoch 31/50
Epoch 32/50
Epoch 33/50
Epoch 34/50
Epoch 35/50
Epoch 36/50
Epoch 37/50
Epoch 38/50
Epoch 39/50
Epoch 40/50
Epoch 41/50
Epoch 42/50
Epoch 43/50
Epoch 44/50
Epoch 45/50
Epoch 46/50
Epoch 47/50
Epoch 48/50
Epoch 49/50
Epoch 50/50


<tensorflow.python.keras.callbacks.History at 0x1454edda0>

In [21]:
pd.DataFrame([train_mask, val_mask, test_mask]).T.head(20)

Unnamed: 0,0,1,2
0,True,False,False
1,True,False,False
2,True,False,False
3,True,False,False
4,True,False,False
5,True,False,False
6,True,False,False
7,True,False,False
8,True,False,False
9,True,False,False


In [25]:
# Evaluate the model
eval_results = model.evaluate([X, A],
                              y,
                              sample_weight=test_mask,
                              batch_size=N)



In [26]:
print('Done.\n'
      'Test loss: {}\n'
      'Test accuracy: {}'.format(*eval_results))

Done.
Test loss: 0.7065826654434204
Test accuracy: 0.6809999942779541
