In [1]:
# Enable IPython's autoreload extension so that edits to imported modules are
# picked up without restarting the kernel.
%load_ext autoreload
# Mode 2: reload all modules before each cell is executed.
%autoreload 2

In [37]:
import networkx as nx
import torch
import torch_geometric as tg
import numpy as np

from utils.transforms import EuclideanInformationTransform, OneHot
from  models.se3_attention_mechanisms import Se3AttentionMechanism
from models.se3_equivariant_transformer import Se3EquivariantTransformer

import e3nn

### Let's test the SE(3)-equivariant version of the GAT

In [38]:
# Build a tiny three-node directed graph to serve as the fixture for the
# equivariance check. Every node carries an integer label `z` and a 3-D
# position `pos`.
vertices = (0, 1, 2)
edges = [
    (0, 1),
    (1, 0),
    (1, 2),
    (2, 0),
]

z = [0, 1, 2]
pos = [
    (0.0, 0.0, 0.0),
    (-1.0, -1.0, -1.0),
    (1.0, 1.0, 1.0),
]

# Deliberately not called `features`: that name is bound to the node-feature
# tensor in a later cell, and reusing it for a dict invites hidden-state bugs
# when cells are run out of order.
node_attributes = {v: {'z': z[v], 'pos': pos[v]} for v in vertices}

g = nx.DiGraph()
g.add_nodes_from(vertices)
g.add_edges_from(edges)
nx.set_node_attributes(g, node_attributes)

# Convert to a torch_geometric graph and apply the project transforms.
# NOTE(review): presumably EuclideanInformationTransform attaches the
# `relative_positions` / `distances` attributes that later cells read, and
# OneHot('z', 'z') one-hot encodes `z` in place -- confirm in utils.transforms.
euc_transform = EuclideanInformationTransform()
one_hot_transform = OneHot('z', 'z')
transform = tg.transforms.Compose([euc_transform, one_hot_transform])

graph = transform(tg.utils.from_networkx(g))


In [39]:
# Wrap two copies of the fixture graph in a loader and keep only the first
# mini-batch it yields.
test_dataloader = tg.data.DataLoader([graph, graph.clone()], batch_size=1)
batch = next(iter(test_dataloader))

In [41]:
# Irreps layouts (e3nn string notation) used by the different parts of the
# network; several of these are reused by the equivariance check below.
feature_irreps = e3nn.o3.Irreps("5x0e")
geometric_irreps = e3nn.o3.Irreps("3x0e+3x1e")
output_irreps = e3nn.o3.Irreps("10x0e + 10x1e")
internal_key_query_irreps = e3nn.o3.Irreps("5x0e+5x1e")

# Hyper-parameters gathered in one place so the constructor call stays short.
transformer_config = {
    "num_features": 1,
    "num_attention_layers": 4,
    "num_feature_channels": 10,
    "num_attention_heads": 2,
    "feature_output_repr": output_irreps,
    "geometric_repr": geometric_irreps,
    "hidden_feature_repr": feature_irreps,
    "key_and_query_irreps": internal_key_query_irreps,
    "radial_network_hidden_units": 5,
}

net = Se3EquivariantTransformer(**transformer_config)


In [None]:
# NOTE(review): GraphInputEquivarianceTest is not referenced by any of the
# cells visible here -- confirm whether this import is still needed.
from tests.test_attention_mechanism_equivariance import GraphInputEquivarianceTest
# Draw one random rotation as three Euler angles; the cells below reuse it to
# rotate both the network inputs and the network output.
# The argument 1 is the requested shape -- presumably each angle comes back as
# a tensor of shape (1,); confirm against the e3nn docs.
alpha, beta, gamma = e3nn.o3.rand_angles(1)

In [26]:
# Reference pass: run the network on the un-rotated inputs, then rotate the
# network's output.
embed = torch.nn.Linear(batch.z.shape[1], feature_irreps.dim)
features = embed(batch.z.float())
edge_harmonics = e3nn.o3.spherical_harmonics(
    geometric_irreps,
    batch.relative_positions,
    normalize=False,
)

output = net.forward(
    edge_index=batch.edge_index,
    features=features,
    edge_features=edge_harmonics,
    distances=batch.distances,
)

# The angles were drawn with rand_angles(1), so D_from_angles presumably
# returns a stack with a leading axis of length 1. squeeze(0) drops that axis,
# the same way the input rotation matrices are built, so that rotated_output
# keeps the shape of `output` instead of relying on broadcasting downstream.
rotation_matrix_for_output = output_irreps.D_from_angles(alpha, beta, gamma).squeeze(0)
rotated_output = output @ rotation_matrix_for_output

In [32]:
# Second pass: rotate the inputs first, then run the network.
# The rotated harmonics are obtained by applying the representation matrix of
# `geometric_irreps` to the harmonics already computed, rather than by
# recomputing them from rotated positions.
edge_harmonic_rotater = geometric_irreps.D_from_angles(alpha, beta, gamma).squeeze(0)
rotated_edge_harmonics = edge_harmonics @ edge_harmonic_rotater

feature_rotater = feature_irreps.D_from_angles(alpha, beta, gamma).squeeze(0)
rotated_features = features @ feature_rotater

# Feed the *rotated* node features as well as the rotated edge harmonics.
# With purely scalar feature irreps the two coincide, but passing the
# un-rotated tensor would make this check wrong as soon as `feature_irreps`
# contains anything other than scalars.
output_from_rotated = net.forward(
    edge_index=batch.edge_index,
    features=rotated_features,
    edge_features=rotated_edge_harmonics,
    distances=batch.distances,
)


In [33]:
# Largest absolute deviation between "rotate the output" and "run on rotated
# inputs". abs() is required: max() of the signed difference would report a
# small number even if some entries were off by a large negative amount.
(rotated_output - output_from_rotated).abs().max()

tensor(1.4901e-08, grad_fn=<MaxBackward1>)