In [213]:
import math
import random
from collections import deque, defaultdict, namedtuple

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import networkx as nx
import torch_geometric
from torch_geometric.data import Data
from torch_geometric.nn import MessagePassing
from torch_geometric.nn import GCNConv
from torch_geometric.utils import add_self_loops, degree

In [212]:
class CustomGraph(Data):
    
    def __init__(self):
        self.graph_type = None
        self.adj_list = None
        self.num_nodes = 0
        self.nx_graph = nx.Graph()
    
    def create_ladder(self, n):
        self.graph_type = 'ladder'
        self.num_nodes = 2 * n
        edges = [(0, 1), (1, 0)]
        nx_edges = [(0, 1)]
        
        for i in range(2, 2 * n, 2):
            edges.append((i, i + 1))
            edges.append((i + 1, i))
            nx_edges.append((i, i + 1))
            edges.append((i, i - 2))
            edges.append((i - 2, i))
            nx_edges.append((i, i - 2))
            edges.append((i + 1, i - 1))
            edges.append((i - 1, i + 1))
            nx_edges.append((i + 1, i - 1))
            
        edge_index = torch.tensor(edges, dtype=torch.long)
        self.edge_index = edge_index.t().contiguous()
        self.nx_graph.add_edges_from(nx_edges)
        
        self.initialize_node_attr()
        self.initialize_hidden_states()
        self.randomize_edge_attr()
        self.add_self_edges()
        
    def create_grid(self, n, m):
        self.graph_type = 'grid'
        self.num_nodes = n * m
        edges, nx_edges = [], []
        
        for j in range(m - 1):
            edges.append((j, j + 1))
            edges.append((j + 1, j))
            nx_edges.append((j, j + 1))

        for i in range(1, n):
            for j in range(m - 1):
                edges.append((m*i + j, m*i + j + 1))
                edges.append((m*i + j + 1, m*i + j))
                nx_edges.append((m*i + j, m*i + j + 1))
                edges.append((m*(i-1) + j, m*i + j))
                edges.append((m*i + j, m*(i-1) + j))
                nx_edges.append((m*(i-1) + j, m*i + j))
            edges.append((m*(i-1) + m - 1, m*i + m - 1))
            edges.append((m*i + m - 1, m*(i-1) + m - 1))
            nx_edges.append((m*(i-1) + m - 1, m*i + m - 1))
            
        edge_index = torch.tensor(edges, dtype=torch.long)
        self.edge_index = edge_index.t().contiguous()
        self.nx_graph.add_edges_from(nx_edges)
        
        self.initialize_node_attr()
        self.initialize_hidden_states()
        self.randomize_edge_attr()
        self.add_self_edges()
        
    def initialize_node_attr(self, val=0.):
        self.x = torch.tensor([val for _ in range(self.num_nodes)]).unsqueeze_(-1)
        
    def initialize_hidden_states(self, dim=32, val=0.):
        self.h = torch.zeros([self.num_nodes, dim])
    
    def randomize_edge_attr(self, a=0.2, b=1):
        self.edge_attr = []
        for i, edge in enumerate(self.edge_index[0]):
            if i % 2 == 0:
                weight = random.uniform(a, b)
                self.edge_attr.append(weight)
            else:
                self.edge_attr.append(weight)
                
        self.edge_attr = torch.tensor(self.edge_attr, dtype=torch.float).unsqueeze_(-1)
        
    def add_self_edges(self):
        """Append edge attributes with ones for every self edge.
        We use ones instead of zeros, as these edge weights could somehow multiply messages
        in the Processor network (MPNN). Maybe, idk.
        """
        self.edge_index, _ = add_self_loops(self.edge_index, num_nodes=self.num_nodes)
        self.edge_attr = torch.cat((self.edge_attr, torch.ones([self.num_nodes, 1],
                                                                dtype=torch.float)), dim=0)

    def create_adj_list(self):
        self.adj_list = defaultdict(list)
        sources, targets = self.edge_index
        for s, t in zip(sources, targets):
            self.adj_list[s.item()].append(t.item())
    
    def draw(self):
        if self.graph_type == 'ladder':
            nx.draw_spring(self.nx_graph, with_labels=True)
        elif self.graph_type == 'grid':
            nx.draw(self.nx_graph, with_labels=True)
    
    def bfs(self, s):
        step_list = []
        Step = namedtuple('Step', ['step', 'x'])
        if self.adj_list is None:
            self.create_adj_list()
        self.initialize_node_attr()
        #print(self.graph.x)
        queue = deque()
        queue.append(s)
        i = 0
        
        while queue:
            n = queue.popleft()
            if self.x[n] == 1:
                continue
            print(f"Step {i}: visited node {n}")
            self.x[n] = 1.
            #print(self.graph.x)
            queue.extend(self.adj_list[n])
            step_list.append(Step(i, self.x.clone()))
            i += 1
                             
        return step_list
            
    def belman_ford(self, s):
        if self.adj_list is None:
            self.create_adj_list()
        if self.edge_attr is None:
            self.randomize_edge_attr()
        self.initialize_node_attr(math.inf)
        self.x[s] = 0.
        edge_weights = {}
        for edge, weight in zip(self.edge_index.t(), self.edge_attr):
            edge_weights[(tuple(map(int, edge)))] = weight.item()
            
        for _ in range(self.num_nodes - 1):
            for edge, weight in edge_weights.items():
                src, dest = edge[0], edge[1]
                if self.x[src] != math.inf and self.x[src] + weight < self.x[dest]:
                    self.x[dest] = self.x[src] + weight
                    
        # Checking for negative cycles
        # Not necessary for our case as edge features are positive
        for edge, weight in edge_weights.items():
            src, dest = edge[0], edge[1]
            if self.x[src] != math.inf and self.x[src] + weight < self.x[dest]:
                print("Negative cycle detected!")
                
        for n, x in enumerate(self.x):
            print(f"Distance to node {n}: {x.item()}")
            
        # return queue like in bfs?
        
    def __repr__(self):
        rep = f"Graph type: {self.graph_type}\n" + \
                f"Number of nodes: {self.num_nodes}\n" + \
                f"Adjacency list:\n{self.adj_list}\n" + \
                f"Node features:\n{self.x}"
        return rep

In [193]:
class Encoder(nn.Module):
    """Neural network which performs linear projection of current
    features (x^t) and latent features from previous state (h^t-1)
    for each node, and outputs encoded node features (z^t)
    """
    def __init__(self, in_channel=33, out_channel=10):
        super(Encoder, self).__init__()
        self.lin = nn.Linear(in_channel, out_channel, bias=False)
        
    def forward(self, x, h):
        z = torch.cat((x, h), dim=1)
        z = self.lin(z)
        return z

In [325]:
class Processor(MessagePassing):
    """Neural network performing message-passing between the nodes.
    It takes encoded node features (z^t) and edge features (e^t), and produces latent
    features (h^t) for each node.
    """
    def __init__(self, in_channels=10, hidden_channels=20, out_channels=32):
        super(Processor, self).__init__(aggr='max', flow='source_to_target')
        self.lin1 = nn.Linear(in_channels, hidden_channels)
        self.lin2 = nn.Linear(hidden_channels, out_channels)

    # I need edge features as well, but don't know how tu plug them in
    def forward(self, x, edge_index, edge_attr):
        # x has shape [N, in_channels]
        # edge_index has shape [2, E]

        # Step 1: Add self-loops to the adjacency matrix.
        #edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))

        # U propagate uvijek saljes edge_index i size, i sve dodatne kwargse koje koristis u funkcijama
        # message i update
        return self.propagate(edge_index, size=(x.size(0), x.size(0)), x=x, edge_attr=edge_attr)

    def message(self, x_j, edge_attr):
        h = edge_attr.view(-1, 1) * x_j
        return self.lin1(h)

    def update(self, aggr_out):
        return self.lin2(aggr_out)

In [32]:
class Decoder(nn.Module):
    """Neural network which performs linear projection of encoded
    features (z^t) and latent features (h^t) for each node, and produces
    node specific outputs (y^t)
    """
    def __init__(self, in_channel=42, out_channel=1):
        super(Decoder, self).__init__()
        self.lin = nn.Linear(in_channel, out_channel, bias=False)
        
    def forward(self, z, h):
        y = torch.cat((z, h), dim=1)
        y = self.lin(y)
        return y

In [33]:
class Termination(nn.Module):
    """Neural network to determine whether an exectuion should be terminated.
    Takes in hidden states, and outputs the probability that the algorithm
    has finished.
    """
    def __init__(self, in_chanels=32, out_channels=1):
        super(Termination, self).__init__()
        self.lin = nn.Linear(in_chanels, out_channels)
        
    def forward(self, h):
        h = torch.mean(h, dim=0)
        return torch.sigmoid(self.lin(h))

In [327]:
enc = Encoder()
proc = Processor()
dec = Decoder()
term = Termination()

In [336]:
G = CustomGraph()
G.create_ladder(5)
G.x[3] = 1.
z = enc(G)
h = proc(x=z, edge_index=G.edge_index, edge_attr=G.edge_attr)
y = dec(z, h)
t = term(h)

In [341]:
print(t)

tensor([0.4770], grad_fn=<SigmoidBackward>)


### Here I have to write the training loop and the rest of the code

### Here is testing

In [73]:
graph = CustomGraph()
graph.create_ladder(10)
bfs_queue = graph.bfs(5)
graph.initialize_node_attr()
graph.graph.x[5] = 1
graph.h = torch.zeros([20, 15])
enc = Encoder(16, 10)
proc = Processor(10, 20, 15)
dec = Decoder(25, 1)
#term = Terminate(15, 1)  # Kaj se zapravo predaje ovoj mrezi? U kojem se trenu ona tocno pokrece?

x = graph.graph.x.clone()
h = graph.h.clone()
edge_index = graph.graph.edge_index

for i in range(len(bfs_queue)):
    step = bfs_queue[i]
    node_attr = step.x
    terminate = True if i == len(bfs_queue)-1 else False
    z = enc(x, h)
    h = proc(z, edge_index)
    y = dec(z, h)
    print(y)
    x = y.clone()

Step 0: visited node 5
Step 1: visited node 4
Step 2: visited node 3
Step 3: visited node 7
Step 4: visited node 2
Step 5: visited node 6
Step 6: visited node 1
Step 7: visited node 9
Step 8: visited node 0
Step 9: visited node 8
Step 10: visited node 11
Step 11: visited node 10
Step 12: visited node 13
Step 13: visited node 12
Step 14: visited node 15
Step 15: visited node 14
Step 16: visited node 17
Step 17: visited node 16
Step 18: visited node 19
Step 19: visited node 18
tensor([[0.0263],
        [0.0263],
        [0.0263],
        [0.0120],
        [0.0120],
        [0.0046],
        [0.0263],
        [0.0120],
        [0.0263],
        [0.0263],
        [0.0263],
        [0.0263],
        [0.0263],
        [0.0263],
        [0.0263],
        [0.0263],
        [0.0263],
        [0.0263],
        [0.0263],
        [0.0263]], grad_fn=<MmBackward>)
tensor([[0.0213],
        [0.0208],
        [0.0208],
        [0.0172],
        [0.0172],
        [0.0176],
        [0.0208],
        [0.

In [11]:
d = CustomGraph()
d.create_ladder(5)
d.graph.x = torch.rand((10, 5))

x = d.graph.x
h = torch.rand((10, 10))
edge_index = d.graph.edge_index.t().contiguous()

E = Encoder(15, 20)
P = Processor(20, 30, 10)
D = Decoder(

z = E(x, h)
P(z, edge_index)

SyntaxError: invalid syntax (<ipython-input-11-ed25aeac3d27>, line 14)

## TODO:
- [X] Create a queue for algorithm execution
- [X] Put named tuples inside with step, x, termination, ...
- [ ] Write the training loop
- [X] Use edge attributes somehow - easy, just use the default implementation for GCN, not your own shitty implementation
- [ ] Define loss
- [X] Cluster nodes by their hidden states to see if similar nodes fall within same cluster

### Testing MPNNs

In [208]:
class GCNConv(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super(GCNConv, self).__init__(aggr='add')  # "Add" aggregation.
        self.lin = torch.nn.Linear(in_channels, out_channels)

    def forward(self, x, edge_index):
        # x has shape [N, in_channels]
        # edge_index has shape [2, E]

        # Step 1: Add self-loops to the adjacency matrix.
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
        #print(edge_index)
        # Step 2: Linearly transform node feature matrix.
        x = self.lin(x)
        #print(x.shape)

        # Step 3: Compute normalization
        row, col = edge_index
        #print(row)
        #print(col)
        deg = degree(row, num_nodes=x.size(0), dtype=x.dtype)
        #print(deg)
        deg_inv_sqrt = deg.pow(-0.5)
        #print(deg_inv_sqrt)
        #print(edge_index)
        #print()
        #print(row.shape)
        #print(row)
        #print(deg_inv_sqrt.shape)
        #print(deg_inv_sqrt)
        #print(deg_inv_sqrt[row].shape)
        #print(deg_inv_sqrt[row])
        
        
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]
        #print(norm)

        # Step 4-6: Start propagating messages.
        return self.propagate(edge_index, size=(x.size(0), x.size(0)), x=x,
                              norm=norm)

    def message(self, x_j, norm):
        # x_j has shape [E, out_channels]

        # Step 4: Normalize node features.
        print("MESSAGE")
        print(norm.shape, x_j.shape)
        b = norm.view(-1, 1)
        c = norm.unsqueeze(1)
        print(b.shape)
        print(c.shape)
        #print(x_j)
        a = norm.view(-1, 1) * x_j
        print(a.shape)
        #print(a)
        return a

    def update(self, aggr_out):
        # aggr_out has shape [N, out_channels]

        # Step 6: Return new node embeddings.
        print("UPDATE")
        print(aggr_out.shape)
        #print(aggr_out)
        return aggr_out

In [209]:
d = CustomGraph()
d.create_ladder(5)
d.graph.x = torch.rand((10, 15))
print(d.graph.x)
print(d.graph.edge_index)

x = d.graph.x
edge_index = d.graph.edge_index.t().contiguous()

print(x.shape)
print(edge_index.shape)

tensor([[0.1406, 0.6591, 0.1697, 0.0669, 0.7949, 0.4376, 0.4883, 0.8505, 0.0106,
         0.2651, 0.3136, 0.5625, 0.9660, 0.1673, 0.5228],
        [0.5707, 0.3044, 0.3018, 0.4780, 0.5388, 0.1321, 0.4225, 0.1104, 0.8837,
         0.4332, 0.1959, 0.1481, 0.6184, 0.3152, 0.5037],
        [0.2432, 0.9420, 0.7616, 0.5633, 0.8338, 0.7043, 0.1028, 0.8667, 0.7833,
         0.1941, 0.2652, 0.0949, 0.3141, 0.7705, 0.7781],
        [0.5991, 0.7170, 0.1504, 0.5622, 0.3782, 0.5089, 0.4149, 0.4699, 0.4020,
         0.0640, 0.7310, 0.4683, 0.3597, 0.9037, 0.2889],
        [0.2039, 0.1050, 0.1193, 0.7079, 0.0773, 0.7893, 0.2088, 0.8580, 0.4904,
         0.6377, 0.9905, 0.5525, 0.1263, 0.0445, 0.2113],
        [0.5072, 0.1929, 0.4061, 0.9080, 0.8984, 0.4496, 0.3166, 0.3896, 0.4508,
         0.4633, 0.3505, 0.4236, 0.3648, 0.7482, 0.3292],
        [0.9047, 0.5309, 0.0570, 0.5134, 0.9501, 0.0661, 0.4373, 0.1053, 0.4888,
         0.9868, 0.9355, 0.5249, 0.4799, 0.0290, 0.1758],
        [0.5446, 0.9496, 0.

In [287]:
import torch
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree

class GCNConv(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super(GCNConv, self).__init__(aggr='add')  # "Add" aggregation.
        self.lin = torch.nn.Linear(in_channels, out_channels)

    def forward(self, x, edge_index, edge_attr):
        # x has shape [N, in_channels]
        # edge_index has shape [2, E]

        # Step 1: Add self-loops to the adjacency matrix.
        #edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))

        # Step 2: Linearly transform node feature matrix.
        #print(x.shape)
        x = self.lin(x)
        #x = torch.ones((x.size(0), 6))
        print(x)
        # Step 3: Compute normalization
        row, col = edge_index
        #print(row, col)
        deg = degree(row, x.size(0), dtype=x.dtype)
        deg_inv_sqrt = deg.pow(-0.5)
        print(deg_inv_sqrt[row])
        #print(deg_inv_sqrt[col])
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]
        #print(norm)
        #print(norm.shape)

        # Step 4-6: Start propagating messages.
        return self.propagate(edge_index, size=(x.size(0), x.size(0)), x=x,
                              norm=norm, edge_attr=edge_attr)

    def message(self, x_i, x_j, norm, edge_attr):
        # x_j has shape [E, out_channels]

        # Step 4: Normalize node features.
        print(edge_attr.shape)
        print(edge_attr.view(-1, 1).shape)
        #print("x_i")
        #print(x_i)
        #print("x_j")
        #print(x_j)
        #print(x_j.shape)
        
        ret = edge_attr.view(-1, 1) * x_j
        print(ret)
        return ret

    def update(self, aggr_out):
        # aggr_out has shape [N, out_channels]

        # Step 6: Return new node embeddings.
        return aggr_out

In [288]:
graph = CustomGraph()
graph.create_ladder(7)
graph.x = torch.rand((graph.num_nodes, 3), dtype=torch.float)
print(graph.x)
print(graph.edge_index)

tensor([[0.5129, 0.0593, 0.9542],
        [0.1566, 0.5738, 0.9336],
        [0.0140, 0.7003, 0.3221],
        [0.2465, 0.7691, 0.4411],
        [0.7643, 0.0187, 0.1459],
        [0.4074, 0.5697, 0.0577],
        [0.8975, 0.7874, 0.5048],
        [0.4921, 0.5120, 0.1873],
        [0.7630, 0.9798, 0.7898],
        [0.1147, 0.4239, 0.2617],
        [0.8958, 0.0777, 0.4051],
        [0.6195, 0.4228, 0.9112],
        [0.2767, 0.3269, 0.1440],
        [0.8497, 0.2596, 0.5317]])
tensor([[ 0,  1,  2,  3,  2,  0,  3,  1,  4,  5,  4,  2,  5,  3,  6,  7,  6,  4,
          7,  5,  8,  9,  8,  6,  9,  7, 10, 11, 10,  8, 11,  9, 12, 13, 12, 10,
         13, 11,  0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13],
        [ 1,  0,  3,  2,  0,  2,  1,  3,  5,  4,  2,  4,  3,  5,  7,  6,  4,  6,
          5,  7,  9,  8,  6,  8,  7,  9, 11, 10,  8, 10,  9, 11, 13, 12, 10, 12,
         11, 13,  0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13]])


In [289]:
GCN = GCNConv(3, 6)
GCN(graph.x, graph.edge_index, graph.edge_attr)

tensor([[ 0.3094,  0.6654,  0.7840, -0.4709,  0.7415, -0.3710],
        [-0.0818,  0.6899,  0.3689, -0.1690,  0.2802, -0.3338],
        [ 0.0485,  0.5425, -0.0821,  0.2015,  0.0677, -0.0499],
        [ 0.0567,  0.5319,  0.0455,  0.1433,  0.1678, -0.0222],
        [ 0.7595,  0.4094,  0.4755, -0.1422,  0.8018,  0.0789],
        [ 0.3774,  0.4137,  0.0056,  0.2091,  0.3132,  0.1576],
        [ 0.2777,  0.4399,  0.3442,  0.0270,  0.5085,  0.1403],
        [ 0.3868,  0.4377,  0.1394,  0.1081,  0.4033,  0.1053],
        [ 0.0112,  0.5277,  0.3457,  0.0010,  0.3692,  0.0135],
        [ 0.2521,  0.5252,  0.0643,  0.0865,  0.2594, -0.0619],
        [ 0.6755,  0.4537,  0.6409, -0.2569,  0.8700,  0.0077],
        [ 0.1858,  0.6156,  0.6260, -0.2961,  0.6011, -0.2274],
        [ 0.4130,  0.4723,  0.1172,  0.0735,  0.3823,  0.0168],
        [ 0.5139,  0.4850,  0.5995, -0.2254,  0.7645, -0.0203]],
       grad_fn=<AddmmBackward>)
tensor([0.5774, 0.5774, 0.5000, 0.5000, 0.5000, 0.5774, 0.5000, 0.5774,

tensor([[ 0.2724,  1.6822,  1.0523, -0.4727,  1.0416, -0.7057],
        [ 0.2448,  1.7580,  1.1089, -0.4613,  1.0914, -0.6846],
        [ 0.9324,  1.6419,  0.9130, -0.1916,  1.3625, -0.2706],
        [ 0.3849,  1.8351,  0.3385,  0.3032,  0.7646, -0.1881],
        [ 1.0244,  1.1550,  0.5410,  0.1000,  1.1505,  0.1409],
        [ 0.8616,  1.2869,  0.2686,  0.3601,  0.9297,  0.2116],
        [ 0.7814,  1.1387,  0.8023,  0.0311,  1.2462,  0.2347],
        [ 0.8368,  1.1195,  0.3563,  0.2648,  0.9489,  0.2284],
        [ 0.6700,  1.4076,  0.8755, -0.0238,  1.2526,  0.0484],
        [ 0.5155,  1.4040,  0.6603,  0.0194,  0.9874, -0.0926],
        [ 1.0904,  1.4104,  1.2761, -0.3958,  1.6857, -0.1213],
        [ 1.0758,  1.4517,  1.4776, -0.5824,  1.7919, -0.2611],
        [ 1.0180,  0.9142,  0.7196, -0.1642,  1.1876,  0.0171],
        [ 0.7464,  1.0284,  1.0601, -0.4106,  1.2751, -0.1725]],
       grad_fn=<ScatterAddBackward>)