In [1]:
import networkx as nx
import torch
import torch_geometric as tg

from utils.transforms import EuclideanInformationTransform, OneHot
from models.tensor_field_networks import RadiallyParamaterisedTensorProduct, QueryNetwork

import e3nn

In [2]:
class ToyTransformer(tg.nn.MessagePassing):

    def __init__(self, K, Q, V):
        super().__init__(aggr='add')
        self.K = K
        self.Q = torch.nn.Identity()
        self.V = torch.nn.Identity()

    # Should replace these adaptors with a 'graph adapted' subclass of the parent
    def K_adaptor(self, edge_index, features, geometric_information, distances):
        source_indices = edge_index[0, :]
        source_features = features[source_indices]

        # Key queries, represented as a set of edge features
        k_uv = K.forward(source_features,
                         geometric_information,
                         distances)
        return k_uv

    def V_adaptor(self, edge_index, features, geometric_information, distances):
        source_indices = edge_index[0, :]
        source_features = features[source_indices]

        # Key queries, represented as a set of edge features
        k_uv = K.forward(source_features,
                         geometric_information,
                         distances)
        return k_uv



    def forward(self, edge_index, features, geometric_information, distances):
        k_uv = self.K_adaptor(edge_index, features, geometric_information, distances)

        # alpha =  k_ij @ q.T
        # alpha = torch.nn.functional.softmax(alpha, dim=1)
        return self.propagate(edge_index, x=x, alpha=alpha, v=v)


    def message(self, alpha, v_j, edge_index):
        """
        Absolutely horrendous - v_j is the value of each message, and it is
        actually a tensor as long as there are edges in the graph.
        Thus, we need to reference the edge index and reshape alpha into
        a shape that reflects the edge structure.

        awful awful awful
        """

        alpha_j = alpha[edge_index[0, :], edge_index[1, :]]
        alpha_j = alpha_j.reshape(alpha_j.shape[0], 1)

        return alpha_j*v_j


In [3]:
g = nx.DiGraph()

vertices = (0, 1, 2, 3)
edges = [(0, 1),
         (1, 2),
         (2, 0),
         (2, 3),
         (3, 0)
         ]
z = [0, 1, 2, 1]
pos = [(0.,   0.,  0.),
             (-1., -1., -1.),
             (1.,   1.,  1.),
             (2.,   2.,  2.),
             ]

features = {i: {'z': z[i], 'pos': pos[i]} for i in vertices }

for v in vertices:
    g.add_node(v)

for e in edges:
    g.add_edge(*e)

nx.set_node_attributes(g, features)

graph = tg.utils.from_networkx(g)

euc_transform = EuclideanInformationTransform()
one_hot_transform = OneHot('z', 'z')
transform = tg.transforms.Compose([euc_transform, one_hot_transform])

graph = transform(graph)


In [4]:
test_dataloader = tg.data.DataLoader([graph, graph.clone()], batch_size=1)
for batch in test_dataloader:
    break



In [5]:
feature_irreps = e3nn.o3.Irreps("10x0e + 10x1e + 10x2e")
geometric_irreps = e3nn.o3.Irreps("3x0e+3x1e+3x2e")
output_irreps = e3nn.o3.Irreps("10x0e + 10x1e + 10x2e")
internal_key_query_irreps = e3nn.o3.Irreps("5x0e + 5x1e + 5x2e")

K = RadiallyParamaterisedTensorProduct(feature_irreps,
                                      geometric_irreps,
                                      internal_key_query_irreps,
                                      radial_hidden_units=16
                                      )
Q = QueryNetwork(feature_irreps,
                 internal_key_query_irreps)

V = RadiallyParamaterisedTensorProduct(feature_irreps,
                                       geometric_irreps,
                                       output_irreps,
                                       radial_hidden_units=16)


embed = torch.nn.Linear(batch.z.shape[1], feature_irreps.dim)
features = embed(batch.z.float())
edge_harmonics = e3nn.o3.spherical_harmonics(geometric_irreps,
                                             batch.relative_positions,
                                             normalize=False)

tt = ToyTransformer(Q=Q, K=K, V=V)





In [6]:
weights = K.radial_net(batch.distances)
source_indices = batch.edge_index[0, :]
source_features = features[source_indices]

k_uv = K.tensor_product(source_features,
                 edge_harmonics,
                 weights)

In [7]:
(K(source_features, edge_harmonics, batch.distances) == k_uv).all()

tensor(True)

In [8]:
features

tensor([[ 3.3037e-01, -7.4980e-01,  4.3395e-01,  3.6405e-01, -5.8631e-02,
         -5.3934e-01, -3.0915e-01, -1.0250e+00,  3.2032e-01,  5.6237e-01,
          5.5003e-01,  4.0327e-01, -6.3422e-04,  2.6351e-01, -1.6178e-01,
         -4.2415e-01,  8.4053e-01,  8.1346e-01, -8.2944e-01, -2.7472e-01,
          4.5746e-01,  3.2546e-01,  8.5883e-01,  3.6332e-01,  5.2407e-02,
         -6.7889e-01,  1.3784e-01,  3.7491e-02, -2.9573e-02, -1.0870e-01,
          5.0632e-01, -8.5697e-01,  6.1405e-01, -6.4359e-01,  4.7958e-01,
         -9.0270e-02, -3.4762e-02, -8.0660e-01,  4.1118e-01,  3.3420e-02,
         -8.2740e-01, -7.7128e-01, -6.1208e-01, -2.7663e-01,  1.7288e-02,
          1.2395e-01, -1.0792e+00,  5.6299e-02, -1.0036e+00,  4.7238e-01,
          2.6964e-01,  3.6934e-01,  3.1888e-02, -7.6808e-01,  1.1107e+00,
         -7.4334e-01,  4.0324e-01, -7.6760e-01,  4.1159e-01,  1.9941e-01,
          6.4150e-01,  2.8485e-02,  3.3016e-01, -1.5398e-01,  5.7032e-02,
          7.4063e-01, -4.5163e-01,  2.

In [9]:
Q.tensor_product._in2_dim

1

In [10]:
Q.constant.shape

torch.Size([1])

In [11]:
# One row for each edge!
q = Q.forward(features)

###

In [14]:
q.shape

torch.Size([4, 45])

In [12]:
tt.forward(
    batch.edge_index,
    features,
    edge_harmonics,
    graph.distances,
    )

NameError: name 'x' is not defined