In [1]:
# Enable IPython's autoreload extension so edits to the imported project
# modules are picked up without restarting the kernel.
%load_ext autoreload
# Mode 2: reload all modules (except explicitly excluded ones) before each
# cell execution.
%autoreload 2

In [2]:
import networkx as nx
import torch
import torch_geometric as tg

from utils.transforms import EuclideanInformationTransform, OneHot
from  models.se3_attention_mechanisms import Se3EquivariantAttentionMechanism

import e3nn

### Testing the SE(3)-equivariant version of the graph attention network (GAT)

In [3]:
# Build a tiny directed 3-cycle (0 -> 1 -> 2 -> 0) in networkx, attach a
# 'z' label and a 3D 'pos' coordinate to every node, convert it to a
# PyTorch Geometric graph and run the project transforms on it.
g = nx.DiGraph()

vertices = (0, 1, 2)
edges = [(0, 1), (1, 2), (2, 0)]

# NOTE(review): `z` and `pos` hold four entries but only three vertices are
# declared above, so the last entry of each list is never attached to a node.
# Confirm whether a fourth vertex was intended.
z = [0, 1, 2, 1]
pos = [
    (0., 0., 0.),
    (-1., -1., -1.),
    (1., 1., 1.),
    (2., 2., 2.),
]

# Per-node attribute dictionaries, keyed by vertex id.
features = {v: {'z': z[v], 'pos': pos[v]} for v in vertices}

g.add_nodes_from(vertices)
g.add_edges_from(edges)
nx.set_node_attributes(g, features)

graph = tg.utils.from_networkx(g)

# Project transforms, applied in this order. Presumably the first one adds the
# `relative_positions` / `distances` fields read further down and the second
# one-hot encodes `z` in place -- TODO confirm in utils.transforms.
euc_transform = EuclideanInformationTransform()
one_hot_transform = OneHot('z', 'z')
transform = tg.transforms.Compose([euc_transform, one_hot_transform])
graph = transform(graph)


In [4]:
# Wrap two copies of the test graph in a loader (one graph per batch) and
# keep only the first batch for the checks below.
# NOTE(review): newer PyTorch Geometric releases expose DataLoader under
# `torch_geometric.loader` -- confirm against the installed version.
test_dataloader = tg.data.DataLoader([graph, graph.clone()], batch_size=1)
batch = next(iter(test_dataloader))



In [5]:
# Irreps layouts handed to the attention mechanism. The input features and the
# value output share the same layout; the key/query representation is smaller.
feature_irreps = e3nn.o3.Irreps("10x0e + 10x1e + 10x2e")
output_irreps = e3nn.o3.Irreps("10x0e + 10x1e + 10x2e")
internal_key_query_irreps = e3nn.o3.Irreps("5x0e + 5x1e + 5x2e")

# Layout of the edge-wise geometric features (the spherical harmonics computed
# later in the notebook are requested with this same layout).
geometric_irreps = e3nn.o3.Irreps("3x0e+3x1e+3x2e")

# Attention mechanism under test.
net = Se3EquivariantAttentionMechanism(
    radial_network_hidden_units=16,
    key_and_query_out_irreps=internal_key_query_irreps,
    value_out_irreps=output_irreps,
    geometric_irreps=geometric_irreps,
    feature_irreps=feature_irreps,
)




In [6]:
# Linear embedding from the width of `batch.z` (second dimension) to the
# total dimension of the feature irreps.
embed = torch.nn.Linear(in_features=batch.z.size(1),
                        out_features=feature_irreps.dim)
features = embed(batch.z.float())

# Spherical harmonics of the per-edge relative positions, in the geometric
# irreps layout; the input vectors are not normalized first.
edge_harmonics = e3nn.o3.spherical_harmonics(
    geometric_irreps,
    batch.relative_positions,
    normalize=False,
)

In [8]:
# Run the attention mechanism on the first batch. The module is invoked via
# `net(...)` instead of `net.forward(...)` so that any registered forward
# hooks are executed too (assumes `net` is a torch.nn.Module -- its `V`
# attribute prints as a torch module tree below).
# NOTE(review): per the recorded output of this cell, the call currently fails
# with a RuntimeError inside a TorchScript-compiled broadcast (tensor sizes
# 6 vs 3 at dimension 0). The mismatch originates in project code that is not
# visible here and still needs to be resolved.
net(
    edge_index=batch.edge_index,
    features=features,
    edge_features=edge_harmonics,
    distances=batch.distances,
)

RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
  File "<eval_with_key>.25", line 15, in forward
    getitem_2 = getattr_3[slice(None, -1, None)];  getattr_3 = None
    expand_2 = empty.expand(getitem_2);  empty = getitem_2 = None
    broadcast_tensors = torch.functional.broadcast_tensors(expand, expand_1, expand_2);  expand = expand_1 = expand_2 = None
                        ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
    getitem_3 = broadcast_tensors[0];  broadcast_tensors = None
    getattr_4 = getitem_3.shape;  getitem_3 = None
RuntimeError: The size of tensor a (6) must match the size of tensor b (3) at non-singleton dimension 0


In [9]:
# Display the submodule stored in `net.V`; its repr (tensor product plus
# radial weight network) is shown in the cell output.
net.V

GraphAdaptedTensorProduct(
  (tensor_product): FullyConnectedTensorProduct(10x0e+10x1e+10x2e x 3x0e+3x1e+3x2e -> 10x0e+10x1e+10x2e | 4500 paths | 4500 weights)
  (radial_net): RadialWeightFunction(
    (net): Sequential(
      (0): Linear(in_features=1, out_features=16, bias=True)
      (1): SiLU()
      (2): Linear(in_features=16, out_features=4500, bias=True)
    )
  )
)

###