In [319]:
# Propagation test for Cora (to make sure it works); currently run on a small random complete graph

In [320]:
import networkx as nx
import matplotlib.pyplot as plt
import numpy as np
import torch.nn as nn
import torch
import torch.optim as optim

In [321]:
N = 5  # Num nodes
f = 12  # Num input feats
h = f  # Hidden state size per node; must equal f because the raw features are the initial states
m = 6  # Message size
num_class = 1  # TODO: 1 output = regression for this smoke test; set back to Cora's class count for Cora
o = num_class  # Output size
num_props = 3  # Propagation steps per forward pass; must be >= 1 (outputs come from the last step)
num_epochs = 5
SEED = 0  # Seeds numpy (node features) and torch (weight init) so reruns reproduce the same losses

np.random.seed(SEED)
torch.manual_seed(SEED)

Inputs needed:
1. The graph, as a networkx graph (Cora is undirected, so each edge is added in both directions).
2. A predecessors list: the graph is static, so it is computed once and reused by the aggregate step.
3. An N x h matrix of node hidden states, updated on every propagation step. Because the graph is static, keeping all states in one dense matrix is faster.

In [322]:
# Build the graph: N nodes, each with a random feature vector in [0, 1)^f, fully connected both ways.
# node_feats holds a stacked (N, f) copy; node_dict maps node id -> its own drawn feature array.
feature_rows = [np.random.uniform(0, 1, f) for _ in range(N)]
node_feats = np.array(feature_rows).reshape(N, f)
node_dict = dict(enumerate(feature_rows))
G = nx.complete_graph(node_dict, nx.DiGraph())
nx.set_node_attributes(G, node_dict, 'v')

In [323]:
# nx.draw(G)
# plt.show()

In [324]:
# Make numpy matrix of N x f
# node_feats = np.zeros(shape=(N, f))

In [325]:
# In-neighbour list for every node, built once up front because the graph never changes
predecessors = [list(G.predecessors(node)) for node in G.nodes]
assert len(predecessors) == N

In [326]:
# Input: (N x h) which is all nodes hidden states
# Outut: (N x m) all nodes messages
class MessageModel(nn.Module):
    def __init__(self):
        super(MessageModel, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(h, m),
            nn.ReLU()
        )
    
    def forward(self, nodes):
        assert nodes.shape == (N, h)
        messages = self.model(nodes)
        assert messages.shape == (N, m)
        return messages

In [327]:
# Input: (num_nodes, m) aggregated messages
# Output: (num_nodes, h) new hidden states for the nodes
class UpdateModel(nn.Module):
    """Maps each node's aggregated message to its new hidden state.

    Applies Linear(m, h) followed by ReLU. Note that the node's previous hidden state is
    not an input here: the new state depends only on the aggregated messages.
    """

    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(m, h),
            nn.ReLU(),
        )

    def forward(self, messages):
        # Check only the message width; the node count may differ from the global N.
        assert messages.dim() == 2 and messages.shape[1] == m, messages.shape
        updates = self.model(messages)
        assert updates.shape == (messages.shape[0], h)
        return updates

In [328]:
# Input: (num_nodes, h) updated node hidden states
# Output: (num_nodes, o) raw (un-normalised) outputs for each node
class OutputModel(nn.Module):
    """Per-node readout: a single Linear(h, o) layer.

    No softmax is applied. With o == 1 the output is a regression value. For classification,
    treat the outputs as logits and pair them with a loss that normalises internally
    (e.g. nn.CrossEntropyLoss), or apply softmax over the last dim at inference.
    """

    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(h, o),
        )

    def forward(self, nodes):
        # Check only the feature width; the node count may differ from the global N.
        assert nodes.dim() == 2 and nodes.shape[1] == h, nodes.shape
        outputs = self.model(nodes)
        assert outputs.shape == (nodes.shape[0], o)
        return outputs

In [329]:
# Graph NN block
class GNBlock(nn.Module):
    """One message-passing step: message -> mean-aggregate over in-neighbours -> update (-> output).

    Calling the block repeatedly reuses the same MessageModel/UpdateModel weights on every
    propagation step.

    Args:
        predecessors: optional list whose entry i holds the in-neighbour indices of node i.
            When omitted (the original behaviour), the notebook-level ``predecessors`` list
            is read at call time.
    """

    def __init__(self, predecessors=None):
        super().__init__()
        self.message_model = MessageModel()
        self.update_model = UpdateModel()
        self.output_model = OutputModel()
        self.predecessors = predecessors

    def aggregate(self, predecessors, messages):
        """Mean-aggregate incoming messages for every node.

        Args:
            predecessors: list of length num_nodes; entry i lists node i's in-neighbour indices.
            messages: (num_nodes, m) tensor of per-node outgoing messages.

        Returns:
            (num_nodes, m) tensor. Row i is the mean of the messages from node i's
            predecessors, or all zeros when node i has none.
        """
        agg = []
        for preds in predecessors:
            if len(preds) > 0:
                # Indexing with a list always keeps the leading dim, even for a single predecessor.
                in_mess = messages[preds, :]
                assert in_mess.shape == (len(preds), m)
                agg.append(in_mess.mean(dim=0))
            else:
                # Match messages' dtype/device so torch.stack cannot fail on mixed tensors.
                agg.append(messages.new_zeros(m))
        if not agg:
            # Empty graph: torch.stack([]) would raise, so return an empty (0, m) tensor.
            return messages.new_zeros((0, m))
        stack = torch.stack(agg)
        assert stack.shape == (len(predecessors), m)
        return stack

    def forward(self, node_states, get_output):
        """Run one propagation step.

        Args:
            node_states: (num_nodes, h) current hidden states.
            get_output: when true, also run the output head on the new states.

        Returns:
            (updates, outputs): the (num_nodes, h) new hidden states, and the (num_nodes, o)
            outputs or None when ``get_output`` is false.
        """
        # Prefer the graph given at construction; fall back to the notebook-global list.
        preds = self.predecessors if self.predecessors is not None else predecessors
        messages = self.message_model(node_states)
        aggregates = self.aggregate(preds, messages)
        updates = self.update_model(aggregates)
        outputs = self.output_model(updates) if get_output else None
        return updates, outputs

    def backward(self, outputs):
        """Placeholder objective: minimise the sum of all outputs and backpropagate.

        This only checks that gradients flow. The objective is unbounded below, so the
        reported loss keeps decreasing (and goes negative). Replace it with a real loss
        against node targets before training on Cora.

        Returns:
            The loss as a Python float.
        """
        loss = outputs.sum()
        loss.backward()
        return loss.item()

In [330]:
# Instantiate the GN block and an Adam optimizer over every one of its parameters
gnn = GNBlock()
optimizer = optim.Adam(params=gnn.parameters(), lr=0.01)

In [331]:
def run_epoch(node_feats):
    """Run one full-batch training step: propagate num_props times, compute loss, update weights.

    Args:
        node_feats: (N, f) array of initial node features. It is converted to a float32
            tensor and used as the initial hidden states, so f must equal h.

    Returns:
        The epoch's loss as a Python float (it is also printed).

    Raises:
        ValueError: if ``num_props`` < 1, because outputs are only produced on the last step.
    """
    if num_props < 1:
        # Otherwise `outputs` would never be assigned and gnn.backward would hit an unbound name.
        raise ValueError('num_props must be >= 1, got {}'.format(num_props))
    # One full-batch step per epoch, so gradients are zeroed once here.
    optimizer.zero_grad()
    # Single conversion straight to float32 (avoids an intermediate float64 tensor).
    node_states = torch.as_tensor(node_feats, dtype=torch.float32)
    for p in range(num_props):
        # Only the final propagation step needs the output head.
        node_states, outputs = gnn(node_states, p == num_props - 1)
    loss = gnn.backward(outputs)
    optimizer.step()
    print('loss: {}'.format(loss))
    return loss

In [332]:
# Train for num_epochs full-batch steps; run_epoch prints each epoch's loss
for _epoch in range(num_epochs):
    run_epoch(node_feats)

loss: 0.4241909980773926
loss: 0.19192063808441162
loss: -0.04645250737667084
loss: -0.293054461479187
loss: -0.5538526177406311
