Code based on this [tutorial](https://keras.io/examples/graph/gnn_citations/).

In [1]:
import pandas as pd
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers


## Simple MPNN with two layers

### This is how the full MPNN classifier looks.

In [16]:
class MPNN(tf.keras.Model):
    def __init__(
        self,
        node_features,
        edges,
        num_classes,
        hidden_layers = [32,32],
        dropout_rate=0.2,
        normalize=True,
        *args,
        **kwargs,
    ):
        
        super(MPNN, self).__init__(*args, **kwargs)

# In this code we store full graph in memory 

        self.node_features = node_features
        self.edges = edges

# two MPNN layers

        self.layer1 = MPNNlayer(
            hidden_layers,
            dropout_rate,
            name="layer1",
        )
        # Create the second GraphConv layer.
        self.layer2 = MPNNlayer(
            hidden_layers,
            dropout_rate,
            name="layer2",
        )
        
        # Compute logits layer for the classifier (the decoder part)
        self.clas = layers.Dense(units=num_classes, name="logits")

        
    def call(self, batch_indices):
         
        messages_1 = self.layer1(self.node_features,self.edges)
        
        messages_2 = self.layer2(messages_1,self.edges)
        

        ### Now we gather the embeddings in batch_indexes, so that we classify and 
        ### compare the result only for those nodes.
        
        batch_node_embeddings = tf.gather(messages_2, batch_indices)

        # Readout to get the paper subjects from the embeddings
        
        return self.clas(batch_node_embeddings)
        

### Taking a deeper look: first defining our multi-layered perceptrons, then MPNN layers

In [3]:
def create_MLP(hidden_layers, dropout_rate, name=None):
    mlp = []

    for layer in hidden_layers:
        mlp.append(layers.BatchNormalization())
        mlp.append(layers.Dropout(dropout_rate))
        mlp.append(layers.Dense(layer, activation=tf.nn.gelu))

    return keras.Sequential(mlp, name=name)


### MPNN layer.

In [13]:
class MPNNlayer(layers.Layer):
    def __init__(
        self,
        hidden_layers = [32,32],
        dropout_rate=0.2,
        normalize=False,
        *args,
        **kwargs,
    ):
        super(MPNNlayer, self).__init__(*args, **kwargs)
            
        ### These are the two trainable parts of the network. 
        
        self.mlp_aggregate = create_MLP(hidden_layers, dropout_rate)

        self.mlp_update = create_MLP(hidden_layers, dropout_rate)

    def call(self, node_repesentations, edges):

        # node_indexes: for each edge, the source node of the edge
        # neihbor_indexes: for each edge, the target node of the edge
        node_indexes, neighbor_indexes = edges[0], edges[1]
        
        # neighbour_repesentations: for each edge, the embedding of the target node
        neighbour_repesentations = tf.gather(node_repesentations, neighbor_indexes)

        ### Now comes computation:
        
        #Aggregate
        aggregated_messages = self.aggregate(
            node_indexes, neighbour_repesentations, node_repesentations
        )
        
        #Update
        return self.update(node_repesentations, aggregated_messages)

    def aggregate(self, node_indexes, neighbour_messages, node_repesentations):
        num_nodes = node_repesentations.shape[0]
        
        ### Run node features through an MLP
        
        preprocessed_messages = self.mlp_aggregate(neighbour_messages)
        
        ### The sum is taken care of with tf's unsorted segment sum. 
        ### First, get all positions of edges where node i is the source. 
        ### Then, sum all preprocessed messages in this position (they 
        ### have the embedding of the target of the edge). 
        ### The sum is put on the i-th entry of aggregated_message

        aggregated_message = tf.math.unsorted_segment_sum(
            preprocessed_messages, node_indexes, num_segments=num_nodes
        )
        
        return aggregated_message

    def update(self, node_repesentations, aggregated_messages):
        
        ### concatenate aggregated message with my own representation
        updated_messages = tf.concat([node_repesentations, aggregated_messages], axis=1)
        
        ### apply trainable MLP to the result
        
        node_embeddings = self.mlp_update(updated_messages)

        return node_embeddings


## Trying out our model

In [5]:
### Loading Quora dataset
citations = pd.read_csv("cora.cites",
    sep="\t",
    header=None,
    names=["target", "source"],
)

column_names = ["paper_id"] + [f"term_{idx}" for idx in range(1433)] + ["subject"]
papers = pd.read_csv("cora.content", sep="\t", header=None, names=column_names,
)

In [6]:
### Some structuring, cleaning

class_values = sorted(papers["subject"].unique())
class_idx = {name: id for id, name in enumerate(class_values)}
paper_idx = {name: idx for idx, name in enumerate(sorted(papers["paper_id"].unique()))}

papers["paper_id"] = papers["paper_id"].apply(lambda name: paper_idx[name])
citations["source"] = citations["source"].apply(lambda name: paper_idx[name])
citations["target"] = citations["target"].apply(lambda name: paper_idx[name])
papers["subject"] = papers["subject"].apply(lambda value: class_idx[value])

train_data, test_data = [], []

for _, group_data in papers.groupby("subject"):
    # Select around 50% of the dataset for training.
    random_selection = np.random.rand(len(group_data.index)) <= 0.5
    train_data.append(group_data[random_selection])
    test_data.append(group_data[~random_selection])

train_data = pd.concat(train_data).sample(frac=1)
test_data = pd.concat(test_data).sample(frac=1)

print("Train data shape:", train_data.shape)
print("Test data shape:", test_data.shape)



Train data shape: (1364, 1435)
Test data shape: (1344, 1435)


In [7]:
train_data

Unnamed: 0,paper_id,term_0,term_1,term_2,term_3,term_4,term_5,term_6,term_7,term_8,...,term_1424,term_1425,term_1426,term_1427,term_1428,term_1429,term_1430,term_1431,term_1432,subject
1936,1590,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,1
2675,593,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,6
2140,2406,0,0,0,0,0,0,0,0,1,...,0,0,0,0,0,0,0,0,0,3
2308,2573,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
2560,202,0,0,0,0,1,0,0,0,0,...,0,0,0,0,0,0,0,0,0,6
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
525,1413,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,6
629,418,0,0,1,0,1,0,0,0,0,...,0,0,0,0,0,0,0,1,0,0
2152,1613,0,0,0,0,0,0,1,0,0,...,0,0,0,0,0,0,0,0,0,1
704,1059,0,0,1,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,3


In [8]:
feature_names = set(papers.columns) - {"paper_id", "subject"}
num_features = len(feature_names)
num_classes = len(class_idx)

# Create train and test features as a numpy array.
x_train = train_data["paper_id"].to_numpy()
x_test = test_data["paper_id"].to_numpy()
# Create train and test targets as a numpy array.
y_train = train_data["subject"].to_numpy()
y_test = test_data["subject"].to_numpy()

x_train, y_train

(array([1590,  593, 2406, ..., 1613, 1059,   69]),
 array([1, 6, 3, ..., 1, 3, 0]))

In [17]:
#Adyacency list 
edges = citations[["source", "target"]].to_numpy().T

# Tensorflow vector for each node
node_features = tf.cast(
    papers.sort_values("paper_id")[feature_names].to_numpy(), dtype=tf.dtypes.float32
)

print("Edges shape:", edges.shape)
print("Nodes shape:", node_features.shape)

# Create the GNN. Note how the entire graph is passed onto the model. 

GNN = MPNN(
    node_features = node_features,
    edges = edges,
    num_classes=7,
    hidden_layers=[32,32],
    dropout_rate=0.2,
    name="mpnn_model",
)





Edges shape: (2, 5429)
Nodes shape: (2708, 1433)


  papers.sort_values("paper_id")[feature_names].to_numpy(), dtype=tf.dtypes.float32


In [18]:
GNN.compile(
        optimizer=keras.optimizers.Adam(0.01),
        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=[keras.metrics.SparseCategoricalAccuracy(name="acc")],
    )

# Create an early stopping callback.
early_stopping = keras.callbacks.EarlyStopping(
        monitor="val_acc", patience=50, restore_best_weights=True
    )


    # Fit the model.
history = GNN.fit(
        x=x_train,
        y=y_train,
        epochs=300,
        batch_size=256,
        validation_split=0.15,
        callbacks=[early_stopping],
    )



Epoch 1/300
Epoch 2/300
Epoch 3/300
Epoch 4/300
Epoch 5/300
Epoch 6/300
Epoch 7/300
Epoch 8/300
Epoch 9/300
Epoch 10/300
Epoch 11/300
Epoch 12/300
Epoch 13/300
Epoch 14/300
Epoch 15/300
Epoch 16/300
Epoch 17/300
Epoch 18/300
Epoch 19/300
Epoch 20/300
Epoch 21/300
Epoch 22/300
Epoch 23/300
Epoch 24/300
Epoch 25/300
Epoch 26/300
Epoch 27/300
Epoch 28/300
Epoch 29/300
Epoch 30/300
Epoch 31/300
Epoch 32/300
Epoch 33/300
Epoch 34/300
Epoch 35/300
Epoch 36/300
Epoch 37/300
Epoch 38/300
Epoch 39/300
Epoch 40/300
Epoch 41/300
Epoch 42/300
Epoch 43/300
Epoch 44/300
Epoch 45/300
Epoch 46/300
Epoch 47/300
Epoch 48/300
Epoch 49/300
Epoch 50/300
Epoch 51/300
Epoch 52/300
Epoch 53/300
Epoch 54/300
Epoch 55/300
Epoch 56/300
Epoch 57/300
Epoch 58/300
Epoch 59/300
Epoch 60/300
Epoch 61/300
Epoch 62/300
Epoch 63/300


Epoch 64/300
Epoch 65/300
Epoch 66/300
Epoch 67/300
Epoch 68/300
Epoch 69/300
Epoch 70/300
Epoch 71/300
Epoch 72/300
Epoch 73/300
Epoch 74/300
Epoch 75/300
Epoch 76/300
Epoch 77/300
Epoch 78/300
Epoch 79/300
Epoch 80/300
Epoch 81/300
Epoch 82/300
Epoch 83/300
Epoch 84/300
Epoch 85/300
Epoch 86/300
Epoch 87/300
Epoch 88/300
Epoch 89/300
Epoch 90/300
Epoch 91/300
Epoch 92/300
Epoch 93/300
Epoch 94/300
Epoch 95/300
Epoch 96/300
Epoch 97/300
Epoch 98/300
Epoch 99/300
Epoch 100/300
Epoch 101/300
Epoch 102/300
Epoch 103/300
Epoch 104/300
Epoch 105/300
Epoch 106/300
Epoch 107/300
Epoch 108/300
Epoch 109/300
Epoch 110/300
Epoch 111/300
Epoch 112/300
Epoch 113/300
Epoch 114/300
Epoch 115/300


In [19]:
GNN.evaluate(x=x_test, y=y_test)



[1.1346373558044434, 0.8266369104385376]

Not bad!