In [1]:
%load_ext autoreload
%autoreload 2

In [15]:
import networkx as nx
import torch
import torch_geometric as tg
import numpy as np

from utils.transforms import EuclideanInformationTransform, OneHot
from  models.se3_attention_mechanisms import Se3EquivariantAttentionMechanism

import e3nn

### Let's test the SE3 equivariant version of the GAT

In [3]:
g = nx.DiGraph()

vertices = (0, 1, 2)
edges = [(0, 1),
         (1, 0),
         (1, 2),
         (2, 0),
         ]

z = [0, 1, 2,]
pos = [(0.,   0.,  0.),
       (-1., -1., -1.),
       (1.,   1.,  1.),
     ]

features = {i: {'z': z[i], 'pos': pos[i]} for i in vertices }

for v in vertices:
    g.add_node(v)

for e in edges:
    g.add_edge(*e)

nx.set_node_attributes(g, features)

graph = tg.utils.from_networkx(g)

euc_transform = EuclideanInformationTransform()
one_hot_transform = OneHot('z', 'z')
transform = tg.transforms.Compose([euc_transform, one_hot_transform])

graph = transform(graph)


In [4]:
test_dataloader = tg.data.DataLoader([graph, graph.clone()], batch_size=1)
for batch in test_dataloader:
    break



In [5]:
feature_irreps = e3nn.o3.Irreps("10x0e + 10x1e + 10x2e")
geometric_irreps = e3nn.o3.Irreps("3x0e+3x1e+3x2e")
output_irreps = e3nn.o3.Irreps("10x0e + 10x1e + 10x2e")
internal_key_query_irreps = e3nn.o3.Irreps("5x0e + 5x1e + 5x2e")


net = Se3EquivariantAttentionMechanism(
    feature_irreps=feature_irreps,
    geometric_irreps=geometric_irreps,
    value_out_irreps=output_irreps,
    key_and_query_out_irreps=internal_key_query_irreps,
    radial_network_hidden_units=16
)




In [51]:
from tests.test_attention_mechanism_equivariance import GraphInputEquivarianceTest

alpha = torch.tensor(np.pi/4)
beta = torch.tensor(0.)
gamma = torch.tensor(0.)

In [52]:
# For normal inputs
embed = torch.nn.Linear(batch.z.shape[1], feature_irreps.dim)
features = embed(batch.z.float())
edge_harmonics = e3nn.o3.spherical_harmonics(geometric_irreps,
                                             batch.relative_positions,
                                             normalize=False)

output = net.forward(edge_index=batch.edge_index,
           features=features,
           edge_features=edge_harmonics,
           distances=batch.distances,
           )

rotation_matrix_for_output = output_irreps.D_from_angles(alpha, beta, gamma)
rotated_output = output @ rotation_matrix_for_output

In [53]:
# For rotated inputs
rotated_graph = GraphInputEquivarianceTest.rotate_graph(batch, alpha, beta, gamma)

# For normal inputs
embed = torch.nn.Linear(rotated_graph.z.shape[1], feature_irreps.dim) # This is already invariant!
features_from_rotated = embed(rotated_graph.z.float())

rotated_edge_harmonics = e3nn.o3.spherical_harmonics(geometric_irreps,
                                                     rotated_graph.relative_positions,
                                                     normalize=False
                                                     )


output_from_rotated = net.forward(
           edge_index=rotated_graph.edge_index,
           features=features,
           edge_features=edge_harmonics,
           distances=rotated_graph.distances,
           )


In [54]:
(rotated_output - output_from_rotated).abs().max()

tensor(0.5020, grad_fn=<MaxBackward1>)

In [43]:
vv = e3nn.o3.Irreps('1x1e')

In [49]:
valpha, vbeta, vgamma = torch.tensor(np.pi/2), torch.tensor(0.), torch.tensor(0.)
vv.D_from_angles(valpha, vbeta, vgamma)

tensor([[ 2.3842e-07,  0.0000e+00,  1.0000e+00],
        [ 0.0000e+00,  1.0000e+00,  0.0000e+00],
        [-1.0000e+00,  0.0000e+00,  2.3842e-07]])

In [50]:
GraphInputEquivarianceTest.make_3d_rotation_matrix(valpha, vbeta, vgamma)

tensor([[-4.3711e-08,  0.0000e+00,  1.0000e+00],
        [ 0.0000e+00,  1.0000e+00,  0.0000e+00],
        [-1.0000e+00,  0.0000e+00, -4.3711e-08]])