# Import libraries

In [3]:
import numpy as np
import tensorflow as tf 
import spektral

# Load the CORA dataset

In [None]:
# Load Cora through spektral's Citation loader. random_split=True requests a
# random train/val/test split instead of the default one; normalize_x=False
# leaves the node features as loaded.
dataset = spektral.datasets.citation.Citation(
    name='cora',
    random_split=True,
    normalize_x=False,
)

# Only graph 0 is used. Index the graph already loaded by the constructor
# instead of calling dataset.read() once per attribute: read() is spektral's
# loading routine, so every extra call repeats the load (and, with
# random_split=True, may redraw and overwrite the split masks read below).
graph = dataset[0]

# Node features
features = graph.x

# Adjacency matrix
adj = graph.a

# Node-wise labels
labels = graph.y

# Train / validation / test masks, exposed as attributes of the dataset
train_mask = dataset.mask_tr
val_mask = dataset.mask_va
test_mask = dataset.mask_te

# Print the shapes of the features, adjacency matrix and labels
print(features.shape)
print(adj.shape)
print(labels.shape)

# Print the number of nodes selected by each mask
print(np.sum(train_mask))
print(np.sum(val_mask))
print(np.sum(test_mask))

'''
Summary of the printed output:

Number of nodes: 2708
Number of features per node: 1433
Number of labels (classes): 7

------

Number of training samples: 140
Number of validation samples: 210
Number of test samples: 2358
'''

# Utility functions

In [44]:
def masked_softmax_cross_entropy(logits, labels, mask):
    """Return the mean softmax cross-entropy over the nodes selected by ``mask``.

    Args:
        logits: Unnormalised class scores (class axis last, as required by
            ``tf.nn.softmax_cross_entropy_with_logits``). Presumably one row
            per node -- TODO confirm at the call site.
        labels: Target distributions with the same shape and dtype as
            ``logits`` (one-hot rows are assumed in this notebook).
        mask: Per-node selector (boolean or 0/1); cast to float32.

    Returns:
        Scalar tensor: the sum of per-node losses where ``mask`` is set,
        divided by the number of such nodes. An all-zero mask yields 0
        instead of NaN.
    """
    # Per-node loss for every node in the graph.
    loss = tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=labels)

    # Normalise the mask so its entries sum to 1; divide_no_nan avoids a
    # 0 / 0 NaN when the mask selects no nodes.
    mask = tf.cast(mask, dtype=tf.float32)
    mask = tf.math.divide_no_nan(mask, tf.reduce_sum(mask))

    # With a mask that sums to 1, the weighted sum is exactly the mean over
    # the masked nodes. (A reduce_mean here would divide by the total node
    # count a second time and shrink the loss by a factor of N.)
    return tf.reduce_sum(loss * mask)

def masked_accuracy(logits, labels, mask):
    """Return the classification accuracy over the nodes selected by ``mask``.

    Args:
        logits: Class scores; the prediction is ``argmax`` along axis 1.
        labels: Targets; the true class is ``argmax`` along axis 1
            (one-hot rows are assumed in this notebook).
        mask: Per-node selector (boolean or 0/1); cast to float32.

    Returns:
        Scalar float32 tensor in [0, 1]: the fraction of masked nodes whose
        predicted class matches the label. An all-zero mask yields 0
        instead of NaN.
    """
    # 1.0 where the predicted class matches the label, 0.0 otherwise.
    correct_prediction = tf.equal(tf.argmax(logits, 1), tf.argmax(labels, 1))
    accuracy_all = tf.cast(correct_prediction, dtype=tf.float32)

    # Normalise the mask so its entries sum to 1; divide_no_nan avoids a
    # 0 / 0 NaN when the mask selects no nodes.
    mask = tf.cast(mask, dtype=tf.float32)
    mask = tf.math.divide_no_nan(mask, tf.reduce_sum(mask))

    # Reduce to a scalar: with a mask that sums to 1, this weighted sum is
    # the accuracy over the masked nodes only.
    return tf.reduce_sum(accuracy_all * mask)

# The basic GNN unit
def gnn(features, adj, transform, activation):
    """Apply one graph layer: ``activation(adj @ transform(features))``.

    Args:
        features: Node feature matrix, passed to ``transform`` first.
        adj: Matrix used to aggregate the transformed features; must be
            accepted by ``tf.matmul``. NOTE(review): the adjacency loaded
            from spektral may be a sparse matrix -- confirm the caller
            converts it before calling this.
        transform: Callable applied to ``features`` (the ``X @ W`` step,
            e.g. a Dense layer).
        activation: Callable applied to the aggregated result.

    Returns:
        ``activation(tf.matmul(adj, transform(features)))``.
    """
    # sigma(A @ (X @ W)): project node features, then mix them along adj.
    projected = transform(features)
    return activation(tf.matmul(adj, projected))

# Train function

In [None]:
def train_cora(fts, adj, gnn_fn, units, epochs, lr):
    