In [72]:
import math
import random
from collections import deque, defaultdict

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import networkx as nx
import torch_geometric
from torch_geometric.data import Data
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree

In [73]:
class CustomGraph:
    
    def __init__(self):
        self.graph = Data()
        self.graph_type = None
        self.adj_list = None
        self.num_nodes = 0
        self.nx_graph = nx.Graph()
    
    def create_ladder(self, n):
        self.graph_type = 'ladder'
        self.num_nodes = 2 * n
        edges = [(0, 1), (1, 0)]
        nx_edges = [(0, 1)]
        
        for i in range(2, 2 * n, 2):
            edges.append((i, i + 1))
            edges.append((i + 1, i))
            nx_edges.append((i, i + 1))
            edges.append((i, i - 2))
            edges.append((i - 2, i))
            nx_edges.append((i, i - 2))
            edges.append((i + 1, i - 1))
            edges.append((i - 1, i + 1))
            nx_edges.append((i + 1, i - 1))
            
        edge_index = torch.tensor(edges, dtype=torch.long)
        self.graph.edge_index = edge_index
        self.nx_graph.add_edges_from(nx_edges)
        
    def create_grid(self, n, m):
        self.graph_type = 'grid'
        self.num_nodes = n * m
        edges, nx_edges = [], []
        
        for j in range(m - 1):
            edges.append((j, j + 1))
            edges.append((j + 1, j))
            nx_edges.append((j, j + 1))

        for i in range(1, n):
            for j in range(m - 1):
                edges.append((m*i + j, m*i + j + 1))
                edges.append((m*i + j + 1, m*i + j))
                nx_edges.append((m*i + j, m*i + j + 1))
                edges.append((m*(i-1) + j, m*i + j))
                edges.append((m*i + j, m*(i-1) + j))
                nx_edges.append((m*(i-1) + j, m*i + j))
            edges.append((m*(i-1) + m - 1, m*i + m - 1))
            edges.append((m*i + m - 1, m*(i-1) + m - 1))
            nx_edges.append((m*(i-1) + m - 1, m*i + m - 1))
            
        edge_index = torch.tensor(edges, dtype=torch.long)
        self.graph.edge_index = edge_index
        self.nx_graph.add_edges_from(nx_edges)
        
    def initialize_node_attr(self, val=0.):
        self.graph.x = torch.tensor([val for _ in range(self.num_nodes)]).unsqueeze_(-1)
    
    def randomize_edge_attr(self, a=0.2, b=1):
        self.graph.edge_attr = torch.tensor([random.uniform(a, b) for _ in range(len(self.graph.edge_index))], 
                                            dtype=torch.float).unsqueeze_(-1)
    
    def create_adj_list(self):
        self.adj_list = defaultdict(list)
        for edge in self.graph.edge_index:
            a, b = edge[0].item(), edge[1].item()
            self.adj_list[a].append(b)
    
    def draw(self):
        if self.graph_type == 'ladder':
            nx.draw_spring(self.nx_graph, with_labels=True)
        elif self.graph_type == 'grid':
            nx.draw(self.nx_graph, with_labels=True)
    
    def bfs(self, s):
        if self.adj_list is None:
            self.create_adj_list()
        self.initialize_node_attr()
        queue = deque()
        queue.append(s)
        step = 0
        
        while queue:
            n = queue.popleft()
            if self.graph.x[n] == 1:
                continue
            print(f"Step {step}: visited node {n}")
            self.graph.x[n] = 1.
            queue.extend(self.adj_list[n])
            step += 1
            
        # return queue with each step?
        
    def belman_ford(self, s):
        if self.adj_list is None:
            self.create_adj_list()
        if self.graph.edge_attr is None:
            self.randomize_edge_attr()
        self.initialize_node_attr(math.inf)
        self.graph.x[s] = 0.
        edge_weights = {}
        for edge, weight in zip(self.graph.edge_index, self.graph.edge_attr):
            edge_weights[(tuple(map(int, edge)))] = weight.item()
            
        for _ in range(self.num_nodes - 1):
            for edge, weight in edge_weights.items():
                src, dest = edge[0], edge[1]
                if self.graph.x[src] != math.inf and self.graph.x[src] + weight < self.graph.x[dest]:
                    self.graph.x[dest] = self.graph.x[src] + weight
                    
        # Checking for negative cycles
        # Not necessary for our case as edge features are positive
        for edge, weight in edge_weights.items():
            src, dest = edge[0], edge[1]
            if self.graph.x[src] != math.inf and self.graph.x[src] + weight < self.graph.x[dest]:
                print("Negative cycle detected!")
                
        for n, x in enumerate(self.graph.x):
            print(f"Distance to node {n}: {x.item()}")
            
        # return queue like in bfs?
        
    def __repr__(self):
        rep = f"Graph type: {self.graph_type}\n" + \
                f"Number of nodes: {self.num_nodes}\n" + \
                f"Adjacency list:\n{self.adj_list}\n" + \
                f"Node features:\n{self.graph.x}"
        return rep

In [44]:
class Encoder(nn.Module):
    """Neural network which performs linear projection of current
    features (x^t) and latent features from previous state (h^t-1)
    for each node, and outputs encoded node features (z^t)
    """
    def __init__(self, in_channel=33, out_channel=10):
        super(Encoder, self).__init__()
        self.lin = nn.Linear(in_channel, out_channel, bias=False)
        
    def forward(self, x, h):
        z = torch.cat((x, h), dim=1)
        z = self.lin(z)
        return z

In [226]:
class Processor(MessagePassing):
    """Neural network performing message-passing between the nodes.
    It takes encoded node features (z^t) and edge features (e^t), and produces latent
    features (h^t) for each node.
    """
    def __init__(self, encoded_channels, hidden_channels, latent_channels):
        super(Processor, self).__init__(aggr='max', flow='source_to_target')
        self.lin1 = nn.Linear(encoded_channels, hidden_channels)
        self.lin2 = nn.Linear(hidden_channels, latent_channels)

    # I need edge features as well, but don't know how tu plug them in
    def forward(self, z, edge_index):
        # x has shape [N, in_channels]
        # edge_index has shape [2, E]

        # Step 1: Add self-loops to the adjacency matrix.
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))

        # U propagate uvijek saljes edge_index i size, i sve dodatne kwargse koje koristis u funkcijama
        # message i update
        return self.propagate(edge_index, size=(x.size(0), x.size(0)), x=z)

    def message(self, x_j):
        z = self.lin1(x_j)
        return z

    def update(self, x):
        h = self.lin2(x)
        return h

In [77]:
class Decoder(nn.Module):
    """Neural network which performs linear projection of encoded
    features (z^t) and latent features (h^t) for each node, and produces
    node specific outputs (y^t)
    """
    def __init__(self, in_channel=42, out_channel=1):
        super(Decoder, self).__init__()
        self.lin = nn.Linear(in_channel, out_channel, bias=False)
        
    def forward(self, z, h):
        y = torch.cat((z, h), dim=1)
        y = self.lin(y)
        return y

In [260]:
class Termination(nn.Module):
    """Neural network to determine whether an exectuion should be terminated.
    Takes in hidden states, and outputs the probability that the algorithm
    has finished.
    """
    def __init__(self, in_chanels=32, out_channels=1):
        super(Termination, self).__init__()
        self.lin = nn.Linear(in_chanels, out_channels)
        
    def forward(self, h):
        h = torch.mean(h, dim=0)
        return torch.sigmoid(self.lin(h))

In [261]:
h = torch.rand(3, 5)
T = Termination(5, 1)
T(h)

tensor([0.5522], grad_fn=<SigmoidBackward>)

In [253]:
print(h.shape)
print(_dm.shape)
print(type(h))
print(type(_dm))
_dm = _dm.view(1, -1)
print(_dm.shape)
torch.cat((h, _dm), dim=0)

torch.Size([3, 5])
torch.Size([1, 5])
<class 'torch.Tensor'>
<class 'torch.Tensor'>
torch.Size([1, 5])


tensor([[0.3000, 0.3335, 0.5576, 0.9318, 0.2304],
        [0.4302, 0.6794, 0.1346, 0.9331, 0.2009],
        [0.4771, 0.2420, 0.3100, 0.4492, 0.2473],
        [0.4024, 0.4183, 0.3341, 0.7714, 0.2262]])

In [None]:
d = CustomGraph()
d.create_ladder(5)
d.graph.x = torch.rand((10, 5))

x = d.graph.x
h = torch.rand((10, 10))
edge_index = d.graph.edge_index.t().contiguous()

E = Encoder(15, 20)
P = Processor(20, 30, 10)
D = Decoder(

z = E(x, h)
P(z, edge_index)

In [70]:
#X Define encoder network f_A for each algorithm A
#X     - take current node features x and previous latent features h
#X     - produce enocded inputs z
# Create processor network P - MPNN
#     - take current encoded inputs z and edge features e
#     - produce latent node features
#X Create decoder network g_A for each algorithm A
#X     - take current encoded inputs z and latent features h
#X     - produce outputs y
# Processor network also needs to make a decision whether to terminate
# This is done by algorithm specific termination network T_A

In [79]:
# Napravi nekakav queue i unutra stavi neke (named) tupleove (step, x, termination, ...)
# Dobro si rekao, stavi x u kojem su pohranjena sva trenutna stanja i to usporedjuj u svakom koraku
# Provjeri treba li nesto promijeniti u smislu da se stanja updateaju paralalno (samo preko susjedstva)
# Iako to i nije toliko bitno, jer ce samo nauceni alg biti malo drugaciji, ali ce raditi istu stvar
# Danas napravi i ono kaj je Petar rekao - dva grafa, klasteriraj hidden stateove
# Hidden stateovi nekih karakteristcnih nodeova bi se trebali poklapati
# 

In [208]:
class GCNConv(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super(GCNConv, self).__init__(aggr='add')  # "Add" aggregation.
        self.lin = torch.nn.Linear(in_channels, out_channels)

    def forward(self, x, edge_index):
        # x has shape [N, in_channels]
        # edge_index has shape [2, E]

        # Step 1: Add self-loops to the adjacency matrix.
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
        #print(edge_index)
        # Step 2: Linearly transform node feature matrix.
        x = self.lin(x)
        #print(x.shape)

        # Step 3: Compute normalization
        row, col = edge_index
        #print(row)
        #print(col)
        deg = degree(row, num_nodes=x.size(0), dtype=x.dtype)
        #print(deg)
        deg_inv_sqrt = deg.pow(-0.5)
        #print(deg_inv_sqrt)
        #print(edge_index)
        #print()
        #print(row.shape)
        #print(row)
        #print(deg_inv_sqrt.shape)
        #print(deg_inv_sqrt)
        #print(deg_inv_sqrt[row].shape)
        #print(deg_inv_sqrt[row])
        
        
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]
        #print(norm)

        # Step 4-6: Start propagating messages.
        return self.propagate(edge_index, size=(x.size(0), x.size(0)), x=x,
                              norm=norm)

    def message(self, x_j, norm):
        # x_j has shape [E, out_channels]

        # Step 4: Normalize node features.
        print("MESSAGE")
        print(norm.shape, x_j.shape)
        b = norm.view(-1, 1)
        c = norm.unsqueeze(1)
        print(b.shape)
        print(c.shape)
        #print(x_j)
        a = norm.view(-1, 1) * x_j
        print(a.shape)
        #print(a)
        return a

    def update(self, aggr_out):
        # aggr_out has shape [N, out_channels]

        # Step 6: Return new node embeddings.
        print("UPDATE")
        print(aggr_out.shape)
        #print(aggr_out)
        return aggr_out

In [209]:
d = CustomGraph()
d.create_ladder(5)
d.graph.x = torch.rand((10, 15))
print(d.graph.x)
print(d.graph.edge_index)

x = d.graph.x
edge_index = d.graph.edge_index.t().contiguous()

print(x.shape)
print(edge_index.shape)

tensor([[0.1406, 0.6591, 0.1697, 0.0669, 0.7949, 0.4376, 0.4883, 0.8505, 0.0106,
         0.2651, 0.3136, 0.5625, 0.9660, 0.1673, 0.5228],
        [0.5707, 0.3044, 0.3018, 0.4780, 0.5388, 0.1321, 0.4225, 0.1104, 0.8837,
         0.4332, 0.1959, 0.1481, 0.6184, 0.3152, 0.5037],
        [0.2432, 0.9420, 0.7616, 0.5633, 0.8338, 0.7043, 0.1028, 0.8667, 0.7833,
         0.1941, 0.2652, 0.0949, 0.3141, 0.7705, 0.7781],
        [0.5991, 0.7170, 0.1504, 0.5622, 0.3782, 0.5089, 0.4149, 0.4699, 0.4020,
         0.0640, 0.7310, 0.4683, 0.3597, 0.9037, 0.2889],
        [0.2039, 0.1050, 0.1193, 0.7079, 0.0773, 0.7893, 0.2088, 0.8580, 0.4904,
         0.6377, 0.9905, 0.5525, 0.1263, 0.0445, 0.2113],
        [0.5072, 0.1929, 0.4061, 0.9080, 0.8984, 0.4496, 0.3166, 0.3896, 0.4508,
         0.4633, 0.3505, 0.4236, 0.3648, 0.7482, 0.3292],
        [0.9047, 0.5309, 0.0570, 0.5134, 0.9501, 0.0661, 0.4373, 0.1053, 0.4888,
         0.9868, 0.9355, 0.5249, 0.4799, 0.0290, 0.1758],
        [0.5446, 0.9496, 0.

In [210]:
conv = GCNConv(15, 15)
x = conv(x, edge_index)

MESSAGE
torch.Size([36]) torch.Size([36, 15])
torch.Size([36, 1])
torch.Size([36, 1])
torch.Size([36, 15])
UPDATE
torch.Size([10, 15])


In [193]:
x

tensor([[-0.1007, -0.4942, -0.3864,  0.2864,  0.3258, -0.2034, -0.0524,  0.1674,
         -0.0017, -0.4090, -0.4654,  0.0227,  0.1265, -0.3814,  0.1445],
        [-0.0683, -0.4782, -0.2408,  0.3723,  0.2370, -0.1755,  0.0338,  0.0707,
          0.1085, -0.4143, -0.3083, -0.0044,  0.0373, -0.2913,  0.1844],
        [-0.0501, -0.4270, -0.4408,  0.4282,  0.3975, -0.2402, -0.0827,  0.0800,
          0.0392, -0.4487, -0.4702, -0.1703,  0.1890, -0.3226,  0.0918],
        [-0.0246, -0.5789, -0.3633,  0.2874,  0.3116, -0.2448, -0.0776,  0.2522,
          0.0698, -0.4552, -0.5559, -0.1788,  0.0650, -0.2393,  0.0809],
        [-0.1630, -0.4594, -0.4700,  0.2730,  0.3662, -0.3421, -0.1767,  0.2028,
          0.0874, -0.3411, -0.5032, -0.2898,  0.1067, -0.1521,  0.1213],
        [-0.0617, -0.4119, -0.3368,  0.3303,  0.3137, -0.2586, -0.1469,  0.1412,
          0.1226, -0.3848, -0.4314, -0.3293,  0.0183, -0.0392,  0.0463],
        [-0.1945, -0.3703, -0.3736,  0.3020,  0.2706, -0.2837, -0.2085,  0.1