In [1]:
### CONFIG ###
from trafficgraphnn.sumo_network import SumoNetwork

sn = SumoNetwork(
    'data/networks/simonnet/simonnet.net.xml', routefile='data/networks/simonnet/simonnet_rand_routes.routes.xml',
    lanewise=True, addlfiles=['data/networks/simonnet/simonnet_e1.add.xml', 'data/networks/simonnet/simonnet_e2.add.xml', 'data/networks/simonnet/tls_output.add.xml']
)
#sn.run()

2018-07-27 13:42:38,052 matplotlib.font_manager:1465 DEBUG    Using fontManager instance from /home/simon/.cache/matplotlib/fontList.json
2018-07-27 13:42:38,396 matplotlib.backends:90  DEBUG    backend module://ipykernel.pylab.backend_inline version unknown
2018-07-27 13:42:38,701 matplotlib.backends:90  DEBUG    backend module://ipykernel.pylab.backend_inline version unknown


In [2]:
from trafficgraphnn.preprocess_data import PreprocessData
preprocess = PreprocessData(sn)

A, X_train, Y_train, X_test, Y_test, X_val, Y_val = preprocess.preprocess_for_gat()


In [12]:
from __future__ import division
import numpy as np

from keras.callbacks import EarlyStopping, TensorBoard
from keras.layers import Input, Dropout
from keras.models import Model
from keras.optimizers import Adam
from keras.regularizers import l2

from keras_gat import GraphAttention
from keras_gat.utils import load_data

# Parameters
N = X_train.shape[0]          # Number of nodes in the graph
F = X_train.shape[1]          # Original feature dimesnionality
n_classes = Y_train.shape[1]  # Number of classes
F_ = Y_train.shape[1]         # Output dimension of first GraphAttention layer
n_attn_heads = 8              # Number of attention heads in first GAT layer
dropout_rate = 0.6            # Dropout rate applied to the input of GAT layers
l2_reg = 5e-4                 # Regularization rate for l2
learning_rate = 5e-3          # Learning rate for SGD
epochs = 2000                 # Number of epochs to run for
es_patience = 100             # Patience fot early stopping

print('N:', N)
print('F:', F)
print('n_classes:', n_classes)

# Model definition (as per Section 3.3 of the paper)
X_in = Input(shape=(F,))
A_in = Input(shape=(N,))

dropout1 = Dropout(dropout_rate)(X_in)
graph_attention_1 = GraphAttention(F_,
                                   attn_heads=n_attn_heads,
                                   attn_heads_reduction='concat',
                                   activation='elu',
                                   kernel_regularizer=l2(l2_reg))([dropout1, A_in])
dropout2 = Dropout(dropout_rate)(graph_attention_1)
graph_attention_2 = GraphAttention(n_classes,
                                   attn_heads=1,
                                   attn_heads_reduction='average',
                                   activation='softmax',
                                   kernel_regularizer=l2(l2_reg))([dropout2, A_in])

# Build model
model = Model(inputs=[X_in, A_in], outputs=graph_attention_2)
optimizer = Adam(lr=learning_rate)
model.compile(optimizer=optimizer,
              loss='mean_squared_error',
              weighted_metrics=['acc'])
model.summary()

# Callbacks
es_callback = EarlyStopping(monitor='val_weighted_acc', patience=es_patience)
tb_callback = TensorBoard(batch_size=N)

# Train model
validation_data = ([X_val, A], Y_val)
model.fit([X_train, A],
          Y_train,
          epochs=epochs,
          batch_size=N,
          validation_data = validation_data,
          shuffle=False,  # Shuffling data means shuffling the whole graph
          callbacks=[es_callback, tb_callback])

# Evaluate model
eval_results = model.evaluate([X_test, A],
                              Y_test,
                              batch_size=N,
verbose=0)

print('Done.\n'
      'Test loss: {}\n'
'Test accuracy: {}'.format(*eval_results))

N: 120
F: 500
n_classes: 6
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_9 (InputLayer)            (None, 500)          0                                            
__________________________________________________________________________________________________
dropout_45 (Dropout)            (None, 500)          0           input_9[0][0]                    
__________________________________________________________________________________________________
input_10 (InputLayer)           (None, 120)          0                                            
__________________________________________________________________________________________________
graph_attention_9 (GraphAttenti (None, 48)           24096       dropout_45[0][0]                 
                                                                 input_10[0][0]   

Epoch 40/2000
Epoch 41/2000
Epoch 42/2000
Epoch 43/2000
Epoch 44/2000
Epoch 45/2000
Epoch 46/2000
Epoch 47/2000
Epoch 48/2000
Epoch 49/2000
Epoch 50/2000
Epoch 51/2000
Epoch 52/2000
Epoch 53/2000
Epoch 54/2000
Epoch 55/2000
Epoch 56/2000
Epoch 57/2000
Epoch 58/2000
Epoch 59/2000
Epoch 60/2000
Epoch 61/2000
Epoch 62/2000
Epoch 63/2000
Epoch 64/2000
Epoch 65/2000
Epoch 66/2000
Epoch 67/2000
Epoch 68/2000
Epoch 69/2000
Epoch 70/2000
Epoch 71/2000
Epoch 72/2000
Epoch 73/2000
Epoch 74/2000
Epoch 75/2000
Epoch 76/2000
Epoch 77/2000
Epoch 78/2000
Epoch 79/2000
Epoch 80/2000
Epoch 81/2000
Epoch 82/2000
Epoch 83/2000
Epoch 84/2000
Epoch 85/2000
Epoch 86/2000
Epoch 87/2000
Epoch 88/2000
Epoch 89/2000
Epoch 90/2000


Epoch 91/2000
Epoch 92/2000
Epoch 93/2000
Epoch 94/2000
Epoch 95/2000
Epoch 96/2000
Epoch 97/2000
Epoch 98/2000
Epoch 99/2000
Epoch 100/2000
Epoch 101/2000
Done.
Test loss: 81126.1796875
Test accuracy: 0.10000000149011612


In [13]:
prediction = model.predict([X_train, A], batch_size = N)
print(prediction)

[[2.79451942e-05 3.55771859e-04 8.78717983e-05 1.76893675e-03
  9.97683764e-01 7.56653899e-05]
 [7.85502272e-15 1.66014737e-11 1.26800249e-13 2.48488618e-07
  9.99999762e-01 5.54285062e-15]
 [1.33273460e-12 1.33443578e-09 7.59848157e-11 7.09313284e-08
  9.99999881e-01 2.91140397e-12]
 [2.60495069e-03 6.80634240e-03 4.19700053e-03 3.30059268e-02
  9.50867057e-01 2.51863874e-03]
 [1.07801725e-05 9.12051473e-05 2.35390362e-05 7.45395664e-04
  9.99119580e-01 9.50819958e-06]
 [2.19474191e-06 3.46322668e-05 7.02762918e-06 7.94531603e-04
  9.99159694e-01 1.91302502e-06]
 [1.26600256e-07 3.32677928e-06 3.97650439e-07 8.62715315e-05
  9.99909759e-01 1.75651408e-07]
 [9.85163661e-06 9.90979970e-05 2.56504227e-05 1.81090832e-03
  9.98047829e-01 6.64176059e-06]
 [9.85163661e-06 9.90979970e-05 2.56504227e-05 1.81090832e-03
  9.98047829e-01 6.64176059e-06]
 [2.15237583e-06 2.70165838e-05 3.99330429e-06 1.01933270e-04
  9.99863148e-01 1.75257810e-06]
 [1.96144404e-03 1.48349931e-03 2.13661790e-03 5.3