In [1]:

import tensorflow as tf
import numpy as np
import gnn.gnn_utils as utils
from gnn.GNN import GNN as GraphNetwork


# import tensorflow as tf
# import numpy as np
# import utils
# import GNNs as GNN
# import Net_Karate as n
# from scipy.sparse import coo_matrix

##### GPU configuration
import os

# Enable on-demand GPU memory allocation instead of reserving all memory up front.
# Iterating over the list (rather than indexing [0]) keeps this cell working on
# CPU-only machines, where the list is empty, and applies the same setting to
# every visible GPU, as TensorFlow expects memory growth to match across GPUs.
physical_devices = tf.config.list_physical_devices('GPU')
for gpu_device in physical_devices:
    tf.config.experimental.set_memory_growth(gpu_device, True)


############# training set ################


# Load the karate club dataset: utils.load_karate() returns five values, unpacked
# here as E, N, the labels and the train/test masks
# (presumably E = edges and N = nodes — TODO confirm against gnn_utils).
E, N, labels,  mask_train, mask_test = utils.load_karate()
# Convert E and N into the GNN inputs. `inp` and `arcnode` are used by the cells
# below; `graphnode` is not referenced in the visible cells.
inp, arcnode, graphnode = utils.from_EN_to_GNN(E, N)

Loading karate club dataset...


In [15]:
# Lower bound applied to predictions before taking their logarithm in `loss`.
EPSILON = 1e-8

@tf.function()
def loss(target,output,mask):
    """Masked cross-entropy between `target` and `output`.

    Args:
        target: reference values; cast to float32 before use.
        output: predicted values; floored at EPSILON before the logarithm.
        mask: per-row weights; cast to float32 and divided by their mean.

    Returns:
        Scalar tensor: the mean over rows of the mask-weighted cross-entropy.
        Because the mask is normalised by its own mean, this equals
        sum(xent * mask) / sum(mask).
    """
    target_f = tf.cast(target, tf.float32)
    # Floor the predictions so tf.math.log never receives zero.
    clipped = tf.maximum(output, EPSILON, name="Avoiding_explosions")
    # Cross-entropy per row: reduce along axis 1.
    per_row = -tf.reduce_sum(target_f * tf.math.log(clipped), 1)

    # Normalise the mask by its mean so the final mean is a masked average.
    weights = tf.cast(mask, dtype=tf.float32)
    weights = weights / tf.reduce_mean(weights)
    return tf.reduce_mean(per_row * weights)

@tf.function()
def metric(output, target,mask):
    """Masked accuracy of `output` against `target`.

    Args:
        output: predicted values; the argmax along axis 1 is the predicted index.
        target: reference values; the argmax along axis 1 is the reference index.
        mask: per-row weights; cast to float32 and divided by their mean.

    Returns:
        Scalar tensor: the mean over rows of the mask-weighted 0/1 match
        indicator, i.e. sum(match * mask) / sum(mask).
    """
    predicted = tf.argmax(output, 1)
    expected = tf.argmax(target, 1)
    hits = tf.cast(tf.equal(predicted, expected), tf.float32)

    # Normalise the mask by its mean so the final mean is a masked average.
    weights = tf.cast(mask, dtype=tf.float32)
    weights = weights / tf.reduce_mean(weights)

    return tf.reduce_mean(hits * weights)


In [77]:
# Hyper-parameters and dimensions for the run: GNN threshold, learning rate,
# state/input/output dimensions, maximum number of iterations and number of epochs.
threshold = 0.001  # forwarded to GraphNetwork as `threshold` in create_model
learning_rate = 0.0001  # learning rate given to the Adam optimizer
state_dim = 2  # state dimension forwarded to GraphNetwork
input_dim = inp.shape[1]  # size of the second axis of `inp`
output_dim = labels.shape[1]  # size of the second axis of `labels`
max_it = 50  # NOTE(review): not referenced in the visible cells — confirm whether GraphNetwork should receive it
num_epoch = 1000  # number of iterations of the training loop

def create_model():
    """Build a Keras model that wraps a single GraphNetwork layer.

    Reads the module-level `input_dim`, `state_dim`, `output_dim`, `arcnode`
    and `threshold`.

    Returns:
        tuple: `(model, layer)` where `model` is the `tf.keras.Model` mapping
        the input tensor to the layer output, and `layer` is the
        `GraphNetwork` instance, returned so its own methods can be called.
    """
    # `shape` must be a tuple: `(input_dim)` is just a parenthesised int,
    # whereas `(input_dim,)` is the intended 1-tuple.
    comp_inp = tf.keras.Input(shape=(input_dim,), name="input")

    layer = GraphNetwork(input_dim, state_dim, output_dim,
                         hidden_state_dim = 5, hidden_output_dim = 25,
                         ArcNode=arcnode,GraphNode=None,threshold=threshold)

    output = layer(comp_inp)

    model = tf.keras.Model(comp_inp, output)

    return model,layer


# Reset Keras state, then build a fresh model and keep a handle on its GNN layer.
tf.keras.backend.clear_session()
model, GNNLayer = create_model()

# Tag summarising the hyper-parameters of this run.
param = f"st_d{state_dim}_th{threshold}_lr{learning_rate}"
print(param)

optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
# Full-batch training: one forward/backward pass over `inp` per epoch,
# with a progress report every 10 epochs.
for count in range(num_epoch):

    # Only the forward pass and the loss need to be recorded on the tape.
    with tf.GradientTape() as tape:
        out = model(inp, training=True)
        loss_value = loss(labels, out, mask=mask_train)

    # Differentiate and update outside the tape context so that the gradient
    # computation, the optimizer step and the evaluation are not recorded.
    grads = tape.gradient(loss_value, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))

    if count % 10 == 0:
        # NOTE(review): `out_val` is computed but never used; the metric below is
        # evaluated on `out` (the training forward pass). Confirm whether
        # `out_val` was meant to feed `metric` instead.
        out_val = GNNLayer.predict_node(inp, arcnode)
        # `metric` is declared as (output, target, mask); pass arguments in that
        # order. Despite its name, `loss_value_val` holds the masked accuracy
        # over `mask_test`.
        loss_value_val = metric(out, labels, mask=mask_test)

        print("Epoch ", count)
        print("Training: ", loss_value.numpy())
        print("Test: ", loss_value_val.numpy())

st_d2_th0.001_lr0.0001
Epoch  0
Training:  1.3891187
Test:  0.46666667
Epoch  10
Training:  1.384643
Test:  0.4
Epoch  20
Training:  1.3800702
Test:  0.4
Epoch  30
Training:  1.375344
Test:  0.36666667
Epoch  40
Training:  1.3704196
Test:  0.36666667
Epoch  50
Training:  1.3652649
Test:  0.36666667
Epoch  60
Training:  1.3598564
Test:  0.36666667
Epoch  70
Training:  1.3541766
Test:  0.36666667
Epoch  80
Training:  1.3482156
Test:  0.36666667
Epoch  90
Training:  1.3419695
Test:  0.36666667
Epoch  100
Training:  1.3354416
Test:  0.36666667
Epoch  110
Training:  1.3286406
Test:  0.33333334
Epoch  120
Training:  1.3215774
Test:  0.33333334
Epoch  130
Training:  1.314266
Test:  0.33333334
Epoch  140
Training:  1.3067214
Test:  0.33333334
Epoch  150
Training:  1.2989595
Test:  0.33333334
Epoch  160
Training:  1.2909964
Test:  0.33333334
Epoch  170
Training:  1.2828491
Test:  0.29999998
Epoch  180
Training:  1.2745337
Test:  0.29999998
Epoch  190
Training:  1.2660676
Test:  0.29999998
Epoch

KeyboardInterrupt: 

In [42]:
# Column-wise sum of `labels`: one total per column.
np.sum(labels,axis=0)

array([ 5, 10,  7, 12])

In [78]:
# Display the `arcnode` object returned by utils.from_EN_to_GNN.
arcnode

SparseMatrix(indices=array([[ 0,  0],
       [ 0,  1],
       [ 0,  2],
       [ 0,  3],
       [ 0,  4],
       [ 0,  5],
       [ 0,  6],
       [ 0,  7],
       [ 0,  8],
       [ 0,  9],
       [ 0, 10],
       [ 0, 11],
       [ 0, 12],
       [ 0, 13],
       [ 0, 14],
       [ 0, 15],
       [ 1, 16],
       [ 1, 17],
       [ 1, 18],
       [ 1, 19],
       [ 1, 20],
       [ 1, 21],
       [ 1, 22],
       [ 1, 23],
       [ 2, 24],
       [ 2, 25],
       [ 2, 26],
       [ 2, 27],
       [ 2, 28],
       [ 2, 29],
       [ 2, 30],
       [ 2, 31],
       [ 3, 32],
       [ 3, 33],
       [ 3, 34],
       [ 4, 35],
       [ 4, 36],
       [ 5, 37],
       [ 5, 38],
       [ 5, 39],
       [ 6, 40],
       [ 8, 41],
       [ 8, 42],
       [ 8, 43],
       [ 9, 44],
       [13, 45],
       [14, 46],
       [14, 47],
       [15, 48],
       [15, 49],
       [18, 50],
       [18, 51],
       [19, 52],
       [20, 53],
       [20, 54],
       [22, 55],
       [22, 56],
       [23