In [26]:
import torch_geometric as tg

import torch
import e3nn

import networkx as nx
import numpy as np

from models.attention_mechanisms import Se3AttentionHead
from utils.load_md17 import load_md17
from utils.transforms import EuclideanInformationTransform

In [11]:
# Load the MD17 aspirin (CCSD) data and keep its 'train' entry for the experiments below.
# NOTE(review): the meaning of the third positional argument (2) is defined by
# utils.load_md17.load_md17 and is not visible here — TODO document it.
MD17_MOLECULE = 'aspirin CCSD'
MD17_ROOT = './../real_datasets/MD17'
dd = load_md17(MD17_MOLECULE, MD17_ROOT, 2)
data = dd['train']

In [23]:
def _stack_matrix(rows):
    """Assemble a 2-D tensor from a nested list of 0-dim tensors, keeping autograd history."""
    return torch.stack([torch.stack(row) for row in rows])


def make_3d_rotation_matrix(alpha, beta, gamma):
    """Build the 3x3 rotation matrix ``R = R_z(alpha) @ R_y(beta) @ R_x(gamma)``.

    Applied to a column vector, ``R`` rotates first by ``gamma`` about the
    x-axis, then by ``beta`` about the y-axis, and finally by ``alpha`` about
    the z-axis.

    Parameters
    ----------
    alpha, beta, gamma : float, int or 0-dim torch.Tensor
        Rotation angles in radians about the z, y and x axes respectively.
        All three are promoted to a common floating-point dtype (at least the
        default dtype), so integer inputs such as ``torch.tensor(0)`` work.

    Returns
    -------
    torch.Tensor
        A ``(3, 3)`` orthogonal matrix with determinant +1. It is built with
        ``torch.stack``, so gradients flow back to tensor-valued angles.
    """
    tensors = [torch.as_tensor(angle) for angle in (alpha, beta, gamma)]

    # Promote to a common floating dtype: matmul does not type-promote, and
    # int angles must become floats before cos/sin.
    dtype = torch.get_default_dtype()
    for t in tensors:
        dtype = torch.promote_types(dtype, t.dtype)
    alpha, beta, gamma = (t.to(dtype) for t in tensors)

    zero = torch.zeros((), dtype=dtype, device=alpha.device)
    one = torch.ones((), dtype=dtype, device=alpha.device)

    cos_a, sin_a = torch.cos(alpha), torch.sin(alpha)
    cos_b, sin_b = torch.cos(beta), torch.sin(beta)
    cos_g, sin_g = torch.cos(gamma), torch.sin(gamma)

    rot_z = _stack_matrix([[cos_a, -sin_a, zero],
                           [sin_a, cos_a, zero],
                           [zero, zero, one]])

    rot_y = _stack_matrix([[cos_b, zero, sin_b],
                           [zero, one, zero],
                           [-sin_b, zero, cos_b]])

    # The (1, 2) entry must be -sin(gamma); using sin(beta) here made the
    # matrix non-orthogonal whenever beta != gamma.
    rot_x = _stack_matrix([[one, zero, zero],
                           [zero, cos_g, -sin_g],
                           [zero, sin_g, cos_g]])

    return rot_z @ rot_y @ rot_x


def rotate_graph(graph: tg.data.Data, alpha, beta, gamma):
    """Return a rotated copy of ``graph``; the input graph is not modified.

    The node positions ``graph.pos`` are rotated by
    ``make_3d_rotation_matrix(alpha, beta, gamma)`` (angles in radians about
    the z, y and x axes). The derived geometric attributes are then recomputed
    from the rotated positions with ``EuclideanInformationTransform``.

    NOTE(review): only ``pos`` is rotated directly. Any other vector-valued
    attributes the graph may carry (e.g. forces) are left as they are unless
    the transform recomputes them — TODO confirm for equivariance tests.
    """
    out_graph = graph.clone()

    # Match the positions' dtype/device; matmul does not promote dtypes.
    rotation_matrix = make_3d_rotation_matrix(alpha, beta, gamma).to(
        dtype=out_graph.pos.dtype, device=out_graph.pos.device)

    # For row-vector positions of shape (..., 3), p' = R p is written p @ R^T.
    out_graph.pos = out_graph.pos @ rotation_matrix.T

    # Rederive the relative positions etc. from the rotated coordinates.
    # The previous code bound the transform *class* and called it on the
    # graph, which passes the graph to the constructor instead of applying
    # the transform. NOTE(review): this assumes EuclideanInformationTransform
    # is a PyG-style transform class whose instances are applied as
    # transform(data) — confirm against utils.transforms.
    transform = EuclideanInformationTransform()
    out_graph = transform(out_graph)

    return out_graph


In [19]:
# Take the first graph of the 'train' data and build a test rotation:
# pi/2 about the z-axis (alpha), none about y (beta) or x (gamma).
graph = data[0]
alpha = torch.tensor(np.pi / 2.)
beta = torch.tensor(0)
gamma = torch.tensor(0)
rot = make_3d_rotation_matrix(alpha, beta, gamma)

In [27]:
# e3nn irreps for the attention head. "10x0e + 10x1e + 10x2e" means ten copies
# each of l = 0, 1, 2 with even parity, i.e. 10*1 + 10*3 + 10*5 = 90 components;
# the geometric irreps have 3*1 + 3*3 + 3*5 = 27 components.
# NOTE(review): ordinary position vectors are odd-parity ("1o") in e3nn. This
# only matters if improper rotations (reflections) are tested — confirm.
feature_irreps = e3nn.o3.Irreps("10x0e + 10x1e + 10x2e")
geometric_irreps = e3nn.o3.Irreps("3x0e+3x1e+3x2e")
# NOTE(review): output_irreps is not used below; feature_output_repr is given
# feature_irreps (which describes the same irreps) — TODO pick one.
output_irreps = e3nn.o3.Irreps("10x0e+10x1e+10x2e")

att = Se3AttentionHead(
    num_attention_layers=3,
    feature_input_repr=feature_irreps,
    feature_output_repr=feature_irreps,
    hidden_feature_repr=feature_irreps,
    key_and_query_irreps=feature_irreps,
    geometric_repr=geometric_irreps,
)



In [31]:
# NOTE(review): this call fails with "AssertionError: Incorrect last dimension for x".
# Presumably `x` is the second argument. Assuming graph.z is 1-D (one value per
# node), graph.z.unsqueeze(1) has a last dimension of 1, while feature_irreps
# has 90 components. The raw node values likely need embedding into
# feature_irreps first, and the geometric input likely needs
# geometric_irreps.dim (27) channels — TODO.
# If Se3AttentionHead is a torch.nn.Module, prefer att(...) over att.forward(...)
# so that module hooks run.
att_edge_index = graph.edge_index.unsqueeze(-1)
att_node_features = graph.z.unsqueeze(1)
att_geometric = graph.relative_positions.unsqueeze(-1)
att_distances = graph.distances.unsqueeze(-1)
att.forward(att_edge_index, att_node_features, att_geometric, att_distances)

AssertionError: Incorrect last dimension for x

In [None]:
# We need to create a small directed test graph (4 nodes, 4 edges) and
# associate node/edge features with it.

g = nx.DiGraph()
vertices = np.arange(4)
edges = [(0, 1),
         (1, 2),
         (2, 1),
         (2, 3)]

g.add_nodes_from(vertices)
# DiGraph.add_edge takes u and v as separate arguments, so the previous
# `g.add_edge(e)` with a (u, v) tuple raised a TypeError. add_edges_from
# accepts the list of (u, v) pairs directly.
g.add_edges_from(edges)

# TODO: features are only given for nodes 0 and 1, and are not attached to `g`
# yet (e.g. via nx.set_node_attributes(g, node_features)).
node_features = {0: {'z': 0, 'x': [0, 0]},
                 1: {'z': 0, 'x': [0, 0]}}


In [3]:
# We need to add batch dimensions
# Random irreps-valued inputs with batch dimensions added:
# Irreps.randn(1, -1) yields one row of irreps.dim values, and unsqueeze(0)
# prepends a leading axis.
random_features = feature_irreps.randn(1, -1).unsqueeze(0)
random_geometric = geometric_irreps.randn(1, -1).unsqueeze(0)

# A single unit distance with shape (1, 1): a batch dimension and a node dimension.
distances = torch.ones(1, 1)

# NOTE(review): this raises "forward() missing 1 required positional argument:
# 'distances'". Only three arguments are passed, but forward needs at least four.
# The earlier call passes an edge index first, so that argument is presumably
# missing here — TODO. This cell also ran (In [3]) before `att` was created
# (In [27]), so it relied on stale kernel state.
output = att.forward(random_features, random_geometric, distances)

TypeError: forward() missing 1 required positional argument: 'distances'

In [None]:
# NOTE(review): neither `test_equivariance` nor `rptp` is defined or imported in
# this notebook, so this cell raises NameError on a fresh kernel — TODO define
# or import them (and confirm what the argument 100 controls).
test_equivariance(rptp, 100, feature_irreps, geometric_irreps)