In [1]:
import torch_geometric as tg

import torch
import e3nn

import networkx as nx
import numpy as np

from models.attention_mechanisms import Se3AttentionHead, Se3EquivariantAttentionMechanism
from utils.load_md17 import load_md17
from utils.transforms import EuclideanInformationTransform

In [2]:
# Load the MD17 "aspirin CCSD" dataset and work with its training split.
# TODO(review): document what the third argument (2) of load_md17 controls.
MD17_ROOT = './../real_datasets/MD17'
dd = load_md17('aspirin CCSD', MD17_ROOT, 2)
data = dd['train']

In [3]:
def make_3d_rotation_matrix(alpha, beta, gamma, dtype=None):
    """Build the 3x3 rotation matrix ``R = R_z(alpha) @ R_y(beta) @ R_x(gamma)``.

    Each angle is a rotation in radians about a fixed coordinate axis:
    ``alpha`` about z, ``beta`` about y and ``gamma`` about x. For a column
    vector ``v``, ``R @ v`` applies the x rotation first, then y, then z.

    Args:
        alpha, beta, gamma: Scalar angles in radians. Python numbers and 0-d
            tensors are accepted. Integer inputs are promoted to ``dtype``.
        dtype: Floating dtype of the result. Defaults to
            ``torch.get_default_dtype()``, so the matrix can multiply ordinary
            float tensors (e.g. node positions) without a dtype clash.

    Returns:
        A ``(3, 3)`` tensor of dtype ``dtype``. It lives on ``alpha``'s device
        when ``alpha`` is a tensor. The matrix is assembled with
        ``torch.stack``, so gradients with respect to the angles are kept.
    """
    if dtype is None:
        dtype = torch.get_default_dtype()

    # Bring every angle to a 0-d tensor of ONE floating dtype. Mixed inputs
    # (e.g. float64 alpha with int64 beta/gamma) otherwise produce factor
    # matrices of different dtypes, and the matrix products below fail.
    device = alpha.device if torch.is_tensor(alpha) else None
    alpha, beta, gamma = (
        torch.as_tensor(angle, dtype=dtype, device=device)
        for angle in (alpha, beta, gamma)
    )
    zero = torch.zeros((), dtype=dtype, device=alpha.device)
    one = torch.ones((), dtype=dtype, device=alpha.device)

    def _as_matrix(rows):
        # Stack a nested list of 0-d tensors into a (3, 3) tensor.
        return torch.stack([torch.stack(row) for row in rows])

    cos_a, sin_a = torch.cos(alpha), torch.sin(alpha)
    cos_b, sin_b = torch.cos(beta), torch.sin(beta)
    cos_g, sin_g = torch.cos(gamma), torch.sin(gamma)

    rot_z = _as_matrix([[cos_a, -sin_a, zero],
                        [sin_a, cos_a, zero],
                        [zero, zero, one]])

    rot_y = _as_matrix([[cos_b, zero, sin_b],
                        [zero, one, zero],
                        [-sin_b, zero, cos_b]])

    # The (1, 2) entry of R_x is -sin(gamma). It previously read -sin(beta),
    # which made the product non-orthogonal whenever beta != gamma.
    rot_x = _as_matrix([[one, zero, zero],
                        [zero, cos_g, -sin_g],
                        [zero, sin_g, cos_g]])

    return rot_z @ rot_y @ rot_x


def rotate_graph(graph: tg.data.Data, alpha, beta, gamma):
    """Return a rotated copy of ``graph``; the input graph is not modified.

    The node positions ``graph.pos`` are rotated by
    ``R = make_3d_rotation_matrix(alpha, beta, gamma)``. ``alpha``, ``beta`` and
    ``gamma`` are rotation angles in radians about z, y and x respectively.
    They are not spherical angles. ``EuclideanInformationTransform`` is then
    re-applied to the rotated copy to rederive the position-derived
    attributes. Per the original comment these are the relative positions
    etc. — confirm which attributes the transform writes.

    NOTE(review): only ``pos`` is rotated explicitly. Any other vector-valued
    attribute on the graph (e.g. forces, if present) is copied unrotated unless
    the transform rederives it. Confirm this before using it for equivariance
    tests.
    """
    out_graph = graph.clone()

    # Cast to the positions' dtype/device so the product below cannot fail on
    # a dtype mismatch (e.g. a float64 matrix against float32 positions).
    rotation_matrix = make_3d_rotation_matrix(alpha, beta, gamma).to(out_graph.pos)

    # Row-vector form of ``R @ p`` for every position ``p``; pos is (..., 3).
    out_graph.pos = out_graph.pos @ rotation_matrix.T

    # Rederive the relative positions etc. from the rotated positions.
    transform = EuclideanInformationTransform()
    out_graph = transform(out_graph)

    return out_graph


In [4]:
graph = data[0]
# Create the angles as floating 0-d tensors of the default dtype. Using
# ``map(torch.tensor, (np.pi/2., 0, 0))`` made alpha float64 but beta and
# gamma int64, because of the integer literal 0.
alpha, beta, gamma = (
    torch.tensor(angle, dtype=torch.get_default_dtype())
    for angle in (np.pi / 2.0, 0.0, 0.0)
)
rot = make_3d_rotation_matrix(alpha, beta, gamma)

In [5]:
# e3nn irreps notation: "10x1e" means 10 copies of the l=1 irrep with even parity.
feature_irreps = e3nn.o3.Irreps("10x0e+10x1e+10x2e")   # node features, hidden reps, keys/queries
geometric_irreps = e3nn.o3.Irreps("3x0e+3x1e+3x2e")    # edge geometry (spherical-harmonic encoding)
# NOTE(review): plain position vectors are 1o in e3nn. "1e" only matters if
# parity/reflections are tested — confirm this is intended.
output_irreps = e3nn.o3.Irreps("10x0e+10x1e+10x2e")

# Multi-layer attention head; every feature-like representation uses feature_irreps.
att_head = Se3AttentionHead(
    num_attention_layers=3,
    geometric_repr=geometric_irreps,
    feature_input_repr=feature_irreps,
    hidden_feature_repr=feature_irreps,
    key_and_query_irreps=feature_irreps,
    feature_output_repr=feature_irreps,
)

# Single attention mechanism (positional argument order as in the constructor).
att = Se3EquivariantAttentionMechanism(
    feature_irreps,
    geometric_irreps,
    output_irreps,
    feature_irreps,
)



In [9]:
# Sanity check: the key network's first input should carry the node-feature irreps.
key_in1_matches = att.key_network.irreps_in1 == feature_irreps
print(key_in1_matches)

True


In [10]:
# Sanity check: the key network's second input should carry the geometric irreps.
key_in2_matches = att.key_network.irreps_in2 == geometric_irreps
print(key_in2_matches)

True


In [6]:
# NOTE(review): ``graph.z.shape[1]`` implies ``z`` is 2-D (one-hot atom types?).
# ``nn.Embedding`` maps every integer entry of its input to a vector, so a 2-D
# ``z`` gives a 3-D ``features`` tensor. That is a likely cause of the broadcast
# error this cell raises. TODO confirm the layout of ``z`` returned by load_md17.
embed = torch.nn.Embedding(graph.z.shape[1], feature_irreps.dim)
features = embed(graph.z)

# Spherical-harmonic encoding of ``graph.relative_positions``
# (presumably one vector per edge — confirm).
relative_positions = e3nn.o3.spherical_harmonics(
    geometric_irreps,
    graph.relative_positions,
    normalize=True,
)

# Call the module instance rather than ``.forward`` so registered hooks run.
# NOTE(review): this call passes ``edge_index`` first, while the random-input
# cell below calls the mechanism without it. One of the two call sites is stale.
att(graph.edge_index,
    features,
    relative_positions,
    graph.distances.unsqueeze(-1))

RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
  File "<eval_with_key>.61", line 15, in forward
    getitem_2 = getattr_3[slice(None, -1, None)];  getattr_3 = None
    expand_2 = empty.expand(getitem_2);  empty = getitem_2 = None
    broadcast_tensors = torch.functional.broadcast_tensors(expand, expand_1, expand_2);  expand = expand_1 = expand_2 = None
                        ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
    getitem_3 = broadcast_tensors[0];  broadcast_tensors = None
    getattr_4 = getitem_3.shape;  getitem_3 = None
RuntimeError: The size of tensor a (9) must match the size of tensor b (50) at non-singleton dimension 1


In [None]:
# Build a small directed test graph; node/edge features still need to be attached.

g = nx.DiGraph()
vertices = np.arange(4)
edges = [(0, 1),
         (1, 2),
         (2, 1),
         (2, 3)]

g.add_nodes_from(vertices)
# ``add_edge`` takes u and v as two separate arguments. Passing the tuple
# ``e`` alone raised a TypeError, so add all edges from the list of pairs.
g.add_edges_from(edges)

# TODO: features exist only for nodes 0 and 1, and they are not attached to
# ``g`` yet (e.g. via ``nx.set_node_attributes(g, node_features)``).
node_features = {0: {'z': 0, 'x': [0, 0]},
                 1: {'z': 0, 'x': [0, 0]}}


In [None]:
# Random inputs with a leading batch dimension and a node dimension.
random_features = feature_irreps.randn(1, -1).unsqueeze(0)    # (1, 1, feature_irreps.dim)
random_geometric = geometric_irreps.randn(1, -1).unsqueeze(0)  # (1, 1, geometric_irreps.dim)

distances = torch.ones(1, 1)  # a single unit distance with batch and node dimensions

# Call the module instance rather than ``.forward`` so registered hooks run.
# NOTE(review): this call omits ``edge_index``, unlike the MD17 cell, which
# passes it first. One of the two call signatures is out of date.
output = att(random_features, random_geometric, distances)

In [None]:
# FIXME: neither ``test_equivariance`` nor ``rptp`` is defined or imported
# anywhere in this notebook, so this cell raises NameError on a fresh kernel.
# Define or import them (or substitute the attention module ``att``? — confirm)
# before running.
test_equivariance(rptp, 100, feature_irreps, geometric_irreps)