In [1]:
%load_ext autoreload
%autoreload 2

In [43]:
import networkx as nx
import torch
import torch_geometric as tg
import numpy as np

from utils.transforms import EuclideanInformationTransform, OneHot
from  models.se3_attention_mechanisms import Se3EquivariantAttentionMechanism
from models.tensor_field_networks import QueryNetwork

import e3nn

### Let's test the SE3 equivariant version of the GAT

In [3]:
g = nx.DiGraph()

vertices = (0, 1, 2)
edges = [(0, 1),
         (1, 0),
         (1, 2),
         (2, 0),
         ]

z = [0, 1, 2,]
pos = [(0.,   0.,  0.),
       (-1., -1., -1.),
       (1.,   1.,  1.),
     ]

features = {i: {'z': z[i], 'pos': pos[i]} for i in vertices }

for v in vertices:
    g.add_node(v)

for e in edges:
    g.add_edge(*e)

nx.set_node_attributes(g, features)

graph = tg.utils.from_networkx(g)

euc_transform = EuclideanInformationTransform()
one_hot_transform = OneHot('z', 'z')
transform = tg.transforms.Compose([euc_transform, one_hot_transform])

graph = transform(graph)


In [4]:
test_dataloader = tg.data.DataLoader([graph, graph.clone()], batch_size=1)
for batch in test_dataloader:
    break



In [5]:
feature_irreps = e3nn.o3.Irreps("10x0e + 10x1e + 10x2e")
geometric_irreps = e3nn.o3.Irreps("3x0e+3x1e+3x2e")
output_irreps = e3nn.o3.Irreps("10x0e + 10x1e + 10x2e")
internal_key_query_irreps = e3nn.o3.Irreps("5x0e + 5x1e + 5x2e")


net = Se3EquivariantAttentionMechanism(
    feature_irreps=feature_irreps,
    geometric_irreps=geometric_irreps,
    value_out_irreps=output_irreps,
    key_and_query_out_irreps=internal_key_query_irreps,
    radial_network_hidden_units=16
)




In [6]:
from tests.test_attention_mechanism_equivariance import GraphInputEquivarianceTest

alpha, beta, gamma = e3nn.o3.rand_angles(1)

In [7]:
# For normal inputs
embed = torch.nn.Linear(batch.z.shape[1], feature_irreps.dim)
features = embed(batch.z.float())
edge_harmonics = e3nn.o3.spherical_harmonics(geometric_irreps,
                                             batch.relative_positions,
                                             normalize=False)

output = net.forward(edge_index=batch.edge_index,
           features=features,
           edge_features=edge_harmonics,
           distances=batch.distances,
           )

rotation_matrix_for_output = output_irreps.D_from_angles(alpha, beta, gamma)
rotated_output = output @ rotation_matrix_for_output

In [38]:
# Alternative way of getting the rotated harmonics
edge_harmonic_rotater = geometric_irreps.D_from_angles(alpha, beta, gamma).squeeze(0)
rotated_edge_harmonics =  edge_harmonics @ edge_harmonic_rotater

output_from_rotated = net.forward(
           edge_index=batch.edge_index,
           features=features,
           edge_features=rotated_edge_harmonics,
           distances=batch.distances,
           )


In [40]:
(rotated_output - output_from_rotated)

tensor([[[-4.4111e-02,  7.1985e-03, -3.0878e-02, -3.4925e-02,  4.6860e-02,
          -4.9958e-02, -1.8769e-02,  1.8522e-01, -8.1026e-02,  6.1958e-02,
           1.4775e-03,  2.2824e-02,  1.9159e-02, -9.8669e-02, -2.5568e-02,
           8.5812e-02,  3.6213e-02, -9.2414e-04,  2.3195e-02, -3.9827e-02,
          -4.0071e-02,  1.0477e-01,  2.5204e-02, -3.2784e-02, -8.4977e-02,
           3.3356e-02, -5.6514e-02,  2.4619e-02,  1.4071e-02, -2.0325e-02,
           2.4853e-02,  1.9631e-02, -4.6869e-02,  4.1648e-02,  4.1100e-02,
          -5.7533e-02, -1.6084e-02,  2.1501e-02,  6.6564e-02, -9.5900e-02,
          -2.5916e-02,  5.3230e-03, -8.7527e-03,  6.3623e-02,  1.9251e-02,
          -2.3297e-02, -3.1619e-03, -1.6661e-02,  3.0277e-02, -3.4938e-02,
           1.8330e-02,  4.3064e-03,  6.9326e-03, -4.0454e-02, -2.2670e-02,
           1.5212e-02, -8.6178e-03, -4.2730e-02, -4.3235e-02, -1.6684e-03,
           4.5444e-02, -4.4011e-03,  3.8577e-02,  2.5120e-02,  3.7315e-03,
           7.4096e-02,  3

In [None]:
# For rotated inputs
rotated_graph = batch.clone()
rotation_matrix = e3nn.o3._rotation.angles_to_matrix(alpha, beta, gamma)
rotated_graph.relative_positions = (rotation_matrix @ rotated_graph.relative_positions.unsqueeze(-1)).squeeze(-1)



edge_harmonics_from_rotated = e3nn.o3.spherical_harmonics(geometric_irreps,
                                                     rotated_graph.relative_positions,
                                                     normalize=False)


In [48]:
q = QueryNetwork(feature_irreps=feature_irreps, irreps_out=internal_key_query_irreps)

In [49]:
q.forward(features)

tensor([[-0.1527, -0.0830,  0.1483, -0.7328, -0.6267,  0.0179,  0.2110, -0.3453,
          0.2425, -0.7961, -0.1468,  0.4997, -0.9843, -0.4274,  0.5686, -0.1127,
         -0.0883, -0.4857,  0.3248,  0.8573, -0.0110, -0.2758, -0.0941, -0.1833,
          0.5809, -0.2757,  0.3822,  0.0727, -0.4175, -0.4101, -0.8587,  0.3687,
         -0.4869, -0.1748,  0.3836,  0.2581,  0.2781,  0.5381, -0.0721, -0.6052,
         -0.4293,  0.6057,  0.1193,  0.4705, -0.3993],
        [ 0.3417, -0.3640,  0.6812, -0.6029, -0.1906,  0.2185,  0.0493,  0.1199,
         -0.0414, -0.6892, -0.6158, -0.4395,  0.1015, -0.7816,  0.5274, -0.0480,
         -0.3881, -0.4681,  0.2227,  0.2753,  0.2720, -0.0160, -0.2497,  0.0951,
         -0.1425, -0.4509, -0.0651,  0.7093, -0.0478, -0.0314, -1.3942, -0.0447,
         -0.8410, -0.0384,  0.0301, -0.4153,  0.2948,  0.9690, -0.2024,  0.3105,
         -0.3496,  0.3410,  1.2558,  0.2711, -0.0403],
        [ 0.1065, -0.1452,  0.0343, -0.7553, -0.2603,  0.2558, -0.2037,  0.2095,