In [1]:
import torch_geometric as tg
from utils.load_md17 import load_md17

from models.tensor_field_networks import RadiallyParamaterisedTensorProduct
import torch
import e3nn

In [2]:
# Load the MD17 aspirin (CCSD) dataset and keep its training split.
# `radius` is passed straight through to load_md17 (presumably the neighbourhood
# cutoff used to build graph edges — TODO confirm against utils.load_md17).
# NOTE(review): `train_data` is not used by the equivariance-test cells that
# follow — confirm it is still needed here.
MD17_DIR = './../real_datasets/MD17/'

dd = load_md17(dataset_name='aspirin CCSD',
               dataset_dir=MD17_DIR,
               radius=2)
train_data = dd['train']

In [27]:
def compute_equivariance_error(alpha, beta, gamma, model, feature_irreps, geometric_irreps):
    random_features = feature_irreps.randn(1, -1)
    random_geometric = geometric_irreps.randn(1, -1)
    distances = torch.tensor(1.).unsqueeze(0).unsqueeze(0) # Add a batch dimension and a node dimension

    # Need to compute one 'rotation matrix' for each set of irreps
    rotation_matrix_features = feature_irreps.D_from_angles(alpha, beta, gamma)
    rotated_features = random_features @ rotation_matrix_features

    rotation_matrix_geometric = geometric_irreps.D_from_angles(alpha, beta, gamma)
    rotated_geometric = random_geometric @ rotation_matrix_geometric

    output = model.forward(random_features.unsqueeze(0).unsqueeze(0),
                           random_geometric.unsqueeze(0).unsqueeze(0),
                           distances)

    rotated_output = output @ rotation_matrix_features
    output_from_rotated_inputs = model.forward(rotated_features.unsqueeze(0).unsqueeze(0),
                                               rotated_geometric.unsqueeze(0).unsqueeze(0),
                                               distances)

    error =  (rotated_output - output_from_rotated_inputs).pow(2)/rotated_output.pow(2).sum()

    return error


def test_equivariance(model, n, feature_irreps, geometric_irreps):
    angles = e3nn.o3.rand_angles(n)

    errors = []
    for alpha, beta, gamma in zip(*angles):
        error = compute_equivariance_error(alpha, beta, gamma, model, feature_irreps, geometric_irreps)
        errors.append(error)

    return torch.concat(errors).max()

In [28]:
# Irreps for the tensor-product equivariance check.
# compute_equivariance_error rotates the model output with feature_irreps'
# rotation matrix (by default), so output_irreps must match feature_irreps
# for that check to be valid.
feature_irreps = e3nn.o3.Irreps("10x0e + 10x1e + 10x2e")
geometric_irreps = e3nn.o3.Irreps("3x0e+3x1e+3x2e")
output_irreps = e3nn.o3.Irreps("10x0e + 10x1e + 10x2e")

# Seed torch's RNG so any random weight initialisation in the model and the
# random test inputs drawn afterwards are reproducible on Restart & Run All.
torch.manual_seed(0)

rptp = RadiallyParamaterisedTensorProduct(
    feature_irreps,
    geometric_irreps,
    output_irreps,
    radial_hidden_units=16,
)

In [None]:
# Inspect the tensor product's first-input irreps (presumably feature_irreps —
# confirm). This cell was not executed in the saved run.
rptp.irreps_in1

In [29]:
# Worst-case normalised equivariance error over 100 random rotations.
# The saved run gave ~1e-10 (numerical round-off); a value anywhere near 1
# would mean the model is not equivariant.
max_equivariance_error = test_equivariance(rptp, 100, feature_irreps, geometric_irreps)
assert max_equivariance_error < 1e-6, f"model is not equivariant: {max_equivariance_error}"
max_equivariance_error

tensor(1.2935e-10, grad_fn=<MaxBackward1>)

In [9]:
# One random example pushed through the tensor product by hand.
# `[None]` prepends a single extra leading axis to each (1, dim) sample.
# NOTE(review): compute_equivariance_error adds two leading axes to its inputs,
# while this cell adds one — confirm which layout forward expects.
random_features = feature_irreps.randn(1, -1)[None]
random_geometric = geometric_irreps.randn(1, -1)[None]

distances = torch.ones(1, 1)  # batch and node axes around a single unit distance

output = rptp.forward(random_features, random_geometric, distances)