In [1]:
from spektral.datasets import citation
A, X, y, train_mask, val_mask, test_mask = citation.load_data('cora')

N = A.shape[0]
F = X.shape[-1]
n_classes = y.shape[-1]

Downloading corafrom https://github.com/tkipf/gcn/raw/master/gcn/data/
Loading cora dataset
Pre-processing node features


In [2]:
from spektral.layers import GraphConv
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Dropout

In [3]:
X_in = Input(shape=(F, ))
A_in = Input((N, ), sparse=True)

X_1 = GraphConv(16, 'relu')([X_in, A_in])
X_1 = Dropout(0.5)(X_1)
X_2 = GraphConv(n_classes, 'softmax')([X_1, A_in])

model = Model(inputs=[X_in, A_in], outputs=X_2)

In [4]:
A = GraphConv.preprocess(A).astype('f4')

In [5]:
model.compile(optimizer='adam',
              loss='categorical_crossentropy',
              weighted_metrics=['acc'])
model.summary()

Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            [(None, 1433)]       0                                            
__________________________________________________________________________________________________
input_2 (InputLayer)            [(None, 2708)]       0                                            
__________________________________________________________________________________________________
graph_conv (GraphConv)          (None, 16)           22944       input_1[0][0]                    
                                                                 input_2[0][0]                    
__________________________________________________________________________________________________
dropout (Dropout)               (None, 16)           0           graph_conv[0][0]             

In [6]:
# Prepare data
X = X.toarray()
A = A.astype('f4')
validation_data = ([X, A], y, val_mask)

# Train model
model.fit([X, A], y,
          sample_weight=train_mask,
          validation_data=validation_data,
          batch_size=N,
          shuffle=False)




<tensorflow.python.keras.callbacks.History at 0x7fdd6b09cf10>

In [7]:
# Evaluate model
eval_results = model.evaluate([X, A],
                              y,
                              sample_weight=test_mask,
                              batch_size=N)
print('Done.\n'
      'Test loss: {}\n'
      'Test accuracy: {}'.format(*eval_results))

Done.
Test loss: 0.7183346748352051
Test accuracy: 0.18000000715255737
