In [5]:
import torch
import torch_geometric as tg
import networkx as nx

In [6]:
class ToyTransformer(tg.nn.MessagePassing):
    """Minimal single-head, transformer-style attention layer over a graph.

    Each target node ``i`` sums the values ``v_j`` of its in-neighbours ``j``
    (``aggr='add'``), each weighted by the dot-product score ``k_j . q_i``.
    In this toy version the query, key and value maps are all identities.
    ``K`` takes the edge features so a relative-position-aware key can be
    plugged in later, but it currently ignores them.

    Args:
        normalize: if True, softmax-normalize the scores of each target node
            over its incoming edges. Defaults to False, which keeps the raw
            dot-product scores.
    """

    def __init__(self, normalize=False):
        super().__init__(aggr='add')
        self.Q = torch.nn.Identity()
        self.V = torch.nn.Identity()
        self.normalize = normalize

    def K(self, x, rel_pos):
        """Key projection. ``rel_pos`` is accepted but currently unused."""
        return x

    def forward(self, edge_index, x, edge_features):
        """Run one round of attention-weighted message passing.

        Args:
            edge_index: graph connectivity in PyG ``edge_index`` format.
            x: node feature matrix.
            edge_features: per-edge features, forwarded to ``K``.

        Returns:
            The aggregated node features produced by ``propagate``.
        """
        k = self.K(x, edge_features)
        q = self.Q(x)
        v = self.V(x)
        # Pass the node-level tensors and let PyG gather them per edge.
        # This avoids materializing a dense (N, N) score matrix.
        return self.propagate(edge_index, q=q, k=k, v=v)

    def message(self, q_i, k_j, v_j, index, ptr, size_i):
        """Build the per-edge message ``alpha_ij * v_j``.

        PyG lifts ``q``, ``k`` and ``v`` to per-edge tensors. The ``_j``
        suffix selects rows by edge source and the ``_i`` suffix by edge
        target, so every tensor argument has one row per edge. ``index`` is
        the target node of each edge; ``ptr`` and ``size_i`` are passed
        through to the grouped softmax.
        """
        # One score per edge. keepdim=True yields shape (E, 1), which
        # broadcasts over the feature dimension of v_j.
        alpha = (k_j * q_i).sum(dim=-1, keepdim=True)
        if self.normalize:
            # Normalize over the incoming edges of each target node,
            # not over all nodes in the graph.
            alpha = tg.utils.softmax(alpha, index, ptr, size_i)
        return alpha * v_j


In [7]:
# Toy directed 3-cycle 1 -> 2 -> 3 -> 1. Each node carries a 2-d feature 'x'
# and each edge a 3-d feature 'rel'. The graph is then converted to a PyG
# Data object.
vertices = [1, 2, 3]
edges = [(1, 2), (2, 3), (3, 1)]

node_attributes = {
    1: {'x': [1., 0.]},
    2: {'x': [0., 1.]},
    3: {'x': [1., 1.]},
}

edge_attributes = {
    (1, 2): {'rel': [1., 0, 0]},
    (2, 3): {'rel': [1., 0, 0]},
    (3, 1): {'rel': [1., 0, 0]},
}

graph = nx.DiGraph()
graph.add_nodes_from(vertices)
graph.add_edges_from(edges)
nx.set_node_attributes(graph, node_attributes)
nx.set_edge_attributes(graph, edge_attributes)

toy_data = tg.utils.convert.from_networkx(graph)

In [8]:
tt = ToyTransformer()
# Call the module instance rather than `tt.forward(...)` so that
# nn.Module.__call__ (and any registered hooks) runs.
transformed = tt(toy_data.edge_index, toy_data.x, toy_data.rel)

# Expected output for the 1 -> 2 -> 3 -> 1 cycle. Each node receives
# (k_src . q_dst) * v_src from its single in-neighbour.
expected = torch.tensor([[1., 1.],
                         [0., 0.],
                         [0., 1.]])
torch.testing.assert_close(transformed, expected)
transformed


tensor([[1., 1.],
        [0., 0.],
        [0., 1.]])