In [1]:
# Reload all imported modules before each cell executes, so edits to the project's
# `utils` / `models` / `tests` packages take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import networkx as nx
import torch
import torch_geometric as tg
import numpy as np

from utils.transforms import EuclideanInformationTransform, OneHot
from  models.se3_attention_mechanisms import Se3AttentionMechanism
from models.se3_equivariant_transformer import Se3EquivariantTransformer

import e3nn

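### Testing the SE(3)-equivariant version of the GAT

We build a small 3-node directed graph and run it through the network. We then check rotation equivariance: rotating the input `relative_positions` by a random rotation should give the same result as applying the matching Wigner-D rotation to the original output.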

In [3]:
# Toy input: a 3-node directed graph whose nodes carry an integer label `z`
# (one-hot encoded by the transform below) and a 3D position `pos`.
vertices = (0, 1, 2)
edges = [(0, 1), (1, 0), (1, 2), (2, 0)]
z = [0, 1, 2]
pos = [
    (0., 0., 0.),
    (-1., -1., -1.),
    (1., 1., 1.),
]
features = {v: {'z': z[v], 'pos': pos[v]} for v in vertices}

g = nx.DiGraph()
g.add_nodes_from(vertices)
g.add_edges_from(edges)
nx.set_node_attributes(g, features)

# Derive the Euclidean edge information, then one-hot encode `z` in place.
euc_transform = EuclideanInformationTransform()
one_hot_transform = OneHot('z', 'z')
transform = tg.transforms.Compose([euc_transform, one_hot_transform])
graph = transform(tg.utils.from_networkx(g))


In [4]:
# `torch_geometric.loader.DataLoader` supersedes the deprecated
# `torch_geometric.data.DataLoader` (PyG >= 2.0); fall back on older installs.
GraphDataLoader = tg.loader.DataLoader if hasattr(tg, "loader") else tg.data.DataLoader

# Collate two copies of the graph (batch_size=1) and take the first batch.
# `next(iter(...))` fails loudly if the loader is empty, unlike a `for ...: break`
# loop, which would silently leave `batch` undefined.
test_dataloader = GraphDataLoader([graph, graph.clone()], batch_size=1)
batch = next(iter(test_dataloader))



In [5]:
# Representations used by the model under test. `output_irreps` and
# `num_attention_heads` are reused later to build the output rotation matrix.
feature_irreps = e3nn.o3.Irreps("5x0e")
geometric_irreps = e3nn.o3.Irreps("3x0e+3x1e")
output_irreps = e3nn.o3.Irreps("10x0e + 10x1e")
internal_key_query_irreps = e3nn.o3.Irreps("5x0e+5x1e")
num_attention_heads = 2

# Seed torch's global RNG before construction. The model presumably draws its
# initial weights from it, so seeding makes a failing equivariance check replayable
# under Restart & Run All.
NET_SEED = 0
torch.manual_seed(NET_SEED)

net = Se3EquivariantTransformer(
    num_features=graph.z.shape[1],  # width of `z` after the one-hot transform
    num_attention_layers=4,
    num_feature_channels=10,
    num_attention_heads=num_attention_heads,
    feature_output_repr=output_irreps,
    geometric_repr=geometric_irreps,
    hidden_feature_repr=feature_irreps,
    key_and_query_irreps=internal_key_query_irreps,
    radial_network_hidden_units=5,
)




In [6]:
from tests.test_attention_mechanism_equivariance import GraphInputEquivarianceTest
# Draw one random rotation as Euler angles (each a tensor of shape (1,)).
# `rand_angles` samples from torch's RNG, so seed first to make the rotation,
# and therefore the equivariance check, reproducible.
ROTATION_SEED = 0
torch.manual_seed(ROTATION_SEED)
alpha, beta, gamma = e3nn.o3.rand_angles(1)

In [8]:
# Reference pass: run the unrotated graph, then rotate the *output* with the
# Wigner-D matrix of the full output representation (`output_irreps` repeated
# once per attention head).
# NOTE(review): assumes the net lays out the heads' outputs in this order — confirm.
output = net.forward(graph=graph)

final_output_irreps = (output_irreps * num_attention_heads).simplify()
rotation_matrix_for_output = final_output_irreps.D_from_angles(alpha, beta, gamma)[0]
rotated_output = torch.matmul(output, rotation_matrix_for_output)

In [9]:
# Rotated pass: rotate the graph's relative position vectors with the l=1
# Wigner-D matrix for the same angles, then run the network again.
# NOTE(review): only `relative_positions` is rotated; if the model also reads `pos`
# or other orientation-dependent attributes, those need rotating too — confirm.
position_rotator = e3nn.o3.Irreps('1x1e').D_from_angles(alpha, beta, gamma)[0]
rotated_graph = graph.clone()
rotated_graph.relative_positions = torch.matmul(rotated_graph.relative_positions, position_rotator)

output_from_rotated = net.forward(graph=rotated_graph)

In [13]:
# Equivariance check: rotating the output must match running on rotated input.
# Take the absolute value before reducing; a signed `.max()` would let arbitrarily
# large negative deviations pass unnoticed.
EQUIVARIANCE_ATOL = 1e-6
max_abs_error = (rotated_output - output_from_rotated).abs().max()
assert torch.allclose(rotated_output, output_from_rotated, atol=EQUIVARIANCE_ATOL), (
    f"output is not rotation-equivariant: max |error| = {max_abs_error.item():.3e}"
)
max_abs_error

tensor(5.2387e-10, grad_fn=<MaxBackward1>)