In [None]:
import networkx as nx
import torch
import torch_geometric as tg

from utils.transforms import EuclideanInformationTransform, OneHot
from models.tensor_field_networks import RadiallyParamaterisedTensorProduct, QueryNetwork

import e3nn

In [95]:
class GraphAttentionNetwork(tg.nn.MessagePassing):

    def __init__(self, K, Q, V):
        super().__init__(aggr='add')
        self.K = K
        self.Q = Q
        self.V = V

    def compute_alpha(self, edge_index, k_uv, q):
        """Creates a matrix of alpha values based on keys and queries"""
        alphas = torch.zeros((dots.shape[0], dots.shape[0]))
        for node in range(q.shape[0]): # iterate through the nodes
            neighbourhood_edge_indices = (edge_index[1,:] == node).nonzero() # Finds the indices of the edges for which this node is a target.
            neighbourhood_edge_indices = neighbourhood_edge_indices.flatten()

            neighbourhood_k = k_uv[neighbourhood_edge_indices, :] # Get all k in this neighbourhood
            q_node = q[node]

            neighbourhood_dot = q_node @ neighbourhood_k.T # Matrix multiplication gives dot products
            neighbourhood_alphas = torch.nn.functional.softmax(neighbourhood_dot, dim=0)

            # Now, use the edges to store the alphas at the correct points
            neighbourhood_edges = edge_index[:, neighbourhood_edge_indices]
            source_nodes = neighbourhood_edges[0, :]
            alphas[node, source_nodes] = neighbourhood_alphas

        # Finally, we force an attention coefficient from each node to itself
        diagonal_indices = torch.arange(0, alphas.shape[0])
        alphas[diagonal_indices, diagonal_indices] = 1.

        return alphas


    def forward(self, edge_index, features, geometric_information, distances):
        k_uv = self.K(edge_index, features, geometric_information, distances)

        q = self.Q(features)
        alpha = self.compute_alpha(edge_index, k_uv, q)
        v = self.V(edge_index, features, geometric_information, distances)

        return self.propagate(edge_index, alpha=alpha, v=v)


    def message(self, alpha, v_j, edge_index):
        """
        Absolutely horrendous - v_j is the value of each message, and it is
        actually a tensor as long as there are edges in the graph.
        Thus, we need to reference the edge index and reshape alpha into
        a shape that reflects the edge structure.

        awful awful awful
        """

        alpha_j = alpha[edge_index[0, :], edge_index[1, :]]
        alpha_j = alpha_j.reshape(alpha_j.shape[0], 1)

        return alpha_j*v_j

In [3]:
g = nx.DiGraph()

vertices = (0, 1, 2, 3)
edges = [(0, 1),
         (1, 2),
         (2, 0),
         (2, 3),
         (3, 0)
         ]
z = [0, 1, 2, 1]
pos = [(0.,   0.,  0.),
             (-1., -1., -1.),
             (1.,   1.,  1.),
             (2.,   2.,  2.),
             ]

features = {i: {'z': z[i], 'pos': pos[i]} for i in vertices }

for v in vertices:
    g.add_node(v)

for e in edges:
    g.add_edge(*e)

nx.set_node_attributes(g, features)

graph = tg.utils.from_networkx(g)

euc_transform = EuclideanInformationTransform()
one_hot_transform = OneHot('z', 'z')
transform = tg.transforms.Compose([euc_transform, one_hot_transform])

graph = transform(graph)


In [4]:
test_dataloader = tg.data.DataLoader([graph, graph.clone()], batch_size=1)
for batch in test_dataloader:
    break



In [5]:
feature_irreps = e3nn.o3.Irreps("10x0e + 10x1e + 10x2e")
geometric_irreps = e3nn.o3.Irreps("3x0e+3x1e+3x2e")
output_irreps = e3nn.o3.Irreps("10x0e + 10x1e + 10x2e")
internal_key_query_irreps = e3nn.o3.Irreps("5x0e + 5x1e + 5x2e")

K = RadiallyParamaterisedTensorProduct(feature_irreps,
                                      geometric_irreps,
                                      internal_key_query_irreps,
                                      radial_hidden_units=16
                                      )
Q = QueryNetwork(feature_irreps,
                 internal_key_query_irreps)

V = RadiallyParamaterisedTensorProduct(feature_irreps,
                                       geometric_irreps,
                                       output_irreps,
                                       radial_hidden_units=16)


embed = torch.nn.Linear(batch.z.shape[1], feature_irreps.dim)
features = embed(batch.z.float())
edge_harmonics = e3nn.o3.spherical_harmonics(geometric_irreps,
                                             batch.relative_positions,
                                             normalize=False)

tt = ToyTransformer(Q=Q, K=K, V=V)





In [6]:
# Code to derive the edge features
weights = K.radial_net(batch.distances)
source_indices = batch.edge_index[0, :]
source_features = features[source_indices]

k_uv = K.tensor_product(source_features,
                 edge_harmonics,
                 weights)

In [7]:
(K(source_features, edge_harmonics, batch.distances) == k_uv).all()

tensor(True)

In [9]:
# One row for each node
q = Q.forward(features)
print(q.shape) # a query vector for each node)

###

In [11]:
k_uv.shape # A key vector for each edge

torch.Size([5, 45])

In [12]:
dots = q @ k_uv.T # Matrix multiplication gives us a number of nodes * number of edges matrix
                  # This isn't actually ideal because you waste a few computations

In [13]:
k_uv.shape

torch.Size([5, 45])

In [14]:
targets = batch.edge_index[1, :]
targets

tensor([1, 2, 0, 3, 0])

In [15]:
k_uv.shape

torch.Size([5, 45])

In [92]:
edge_index = batch.edge_index

alphas = torch.zeros((dots.shape[0], dots.shape[0]))
for node in range(dots.shape[0]): # iterate through the nodes
    # This gets the indices of the
    neighbourhood_edge_indices = (edge_index[1,:] == node).nonzero() # Finds the indices of the edges for which this node is a target.
    neighbourhood_edge_indices = neighbourhood_edge_indices.flatten()

    neighbourhood_k = k_uv[neighbourhood_edge_indices, :] # Get all k in this neighbourhood
    q_node = q[node]

    neighbourhood_dot = q_node @ neighbourhood_k.T # Matrix multiplication
    neighbourhood_alphas = torch.nn.functional.softmax(neighbourhood_dot, dim=0)

    # Now, use the edges to store the alphas at the correct points
    neighbourhood_edges = edge_index[:, neighbourhood_edge_indices]
    source_nodes = neighbourhood_edges[0, :]
    alphas[node, source_nodes] = neighbourhood_alphas

# Finally, we force an attention coefficient from each node to itself
diagonal_indices = torch.arange(0, alphas.shape[0])
alphas[diagonal_indices, diagonal_indices] = 1.

In [94]:
alphas

tensor([[1.0000, 0.0000, 0.5000, 0.5000],
        [1.0000, 1.0000, 0.0000, 0.0000],
        [0.0000, 1.0000, 1.0000, 0.0000],
        [0.0000, 0.0000, 1.0000, 1.0000]], grad_fn=<IndexPutBackward0>)

In [89]:
torch.nn.functional.softmax(neighbourhood_dot, dim=0)

tensor([1.], grad_fn=<SoftmaxBackward0>)

In [86]:
diagonal_indices = torch.arange(0, )

tensor([0., 0., 0., 0.], grad_fn=<DiagBackward0>)

In [81]:
alphas

tensor([[0., 0., 1., 1.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.]])

In [73]:
q_node.shape

torch.Size([45])

In [74]:
neighbourhood_edge_indices

tensor([2, 4])

In [75]:
neighbourhood_edges

tensor([[2, 3],
        [0, 0]])

In [76]:
alphas = torch.zeros((dots.shape[0], dots.shape[0]))

In [77]:
alphas[neighbourhood_edges[1, :], neighbourhood_edges[0, :]] = 1.0

In [78]:
alphas

tensor([[0., 0., 1., 1.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.]])

In [17]:
# Need to fill these in POST softmax
alpha = torch.zeros((q.shape[0], q.shape[0]))


In [18]:
for source, target in batch.edge_index.T:
    alpha[target, source] = dots

RuntimeError: expand(torch.FloatTensor{[4, 5]}, size=[]): the number of sizes provided (0) must be greater or equal to the number of dimensions in the tensor (2)

In [None]:
batch.edge_index