In [1]:
import torch

from e3nn.point.message_passing import TensorPassingHomogenous
from e3nn.radial import GaussianRadialModel
from e3nn.non_linearities.rescaled_act import swish
from e3nn.point.gate import Gate
from e3nn import rs

from torch_geometric.data import Data

import random
from functools import partial

from tqdm.notebook import tqdm

In [2]:
# Create every floating-point tensor in double precision unless a dtype is given explicitly.
torch.set_default_dtype(torch.float64)

# The notebook draws random representations, graphs and weights; seed both RNGs it
# uses so that "Restart Kernel -> Run All" reproduces the same run.
SEED = 0
random.seed(SEED)
torch.manual_seed(SEED)

# Utils to generate random graphs

In [3]:
def gen_Rs(max_l):
    """Draw a random representation list of ``(mul, L, 0)`` triples.

    The rotation orders are distinct, sorted ascending and sampled from
    ``range(min(11, max_l))`` -- i.e. ``max_l`` is an *exclusive* bound that is
    additionally capped at 11.  Each order gets a multiplicity drawn uniformly
    from 1..16 and a parity entry fixed to 0.

    Args:
        max_l: exclusive upper bound on the rotation order; must be >= 1.

    Returns:
        A list with between 1 and ``min(11, max_l)`` tuples ``(mul, L, 0)``.

    Raises:
        ValueError: from ``random.randint`` when ``max_l`` is smaller than 1.
    """
    upper = min(11, max_l)
    count = random.randint(1, upper)
    orders = random.sample(range(upper), k=count)
    orders.sort()
    # One multiplicity is drawn per order, in ascending-order sequence.
    return [(random.randint(1, 16), order, 0) for order in orders]

def gen_edge_index(n_nodes, n_edges, device):
    """Generate a random, symmetric edge index without self-loops.

    ``n_edges`` (source, target) pairs are drawn uniformly from
    ``[0, n_nodes)``.  Wherever source and target coincide, the target ``k`` is
    replaced by ``|k - 1|`` (node 0 is redirected to node 1, every other node
    to its predecessor).  The reversed copy of every edge is then appended.

    Args:
        n_nodes: number of nodes; node ids are ``0 .. n_nodes - 1``.
        n_edges: number of random directed pairs drawn before symmetrisation.
        device: torch device on which the index tensor is created.

    Returns:
        ``int64`` tensor of shape ``(2, 2 * n_edges)``; the second half of the
        columns is the first half with its two rows swapped.

    Raises:
        ValueError: if edges are requested for a graph with fewer than two
            nodes, where a self-loop cannot be redirected to a valid node.
    """
    if n_edges > 0 and n_nodes < 2:
        raise ValueError(
            "gen_edge_index needs at least 2 nodes to build edges without self-loops, "
            "got n_nodes={}".format(n_nodes)
        )

    edge_index = torch.randint(low=0, high=n_nodes, size=(2, n_edges), dtype=torch.int64, device=device)
    src, dst = edge_index[0], edge_index[1]

    # Vectorised self-loop removal: a per-edge Python loop would force one
    # host/device round trip per edge when the tensor lives on the GPU.
    dst = torch.where(src == dst, (dst - 1).abs(), dst)

    directed = torch.stack([src, dst])
    return torch.cat([directed, directed.flip(0)], dim=1)

def gen_graph(n_nodes, n_edges, Rs_features, device='cuda:0'):
    """Build a random geometric graph wrapped in a ``torch_geometric`` ``Data`` object.

    Node positions and node features are standard-normal samples; edges come
    from ``gen_edge_index``.  The label is the parity of the node count.

    Args:
        n_nodes: number of nodes in the graph.
        n_edges: number of random directed pairs passed to ``gen_edge_index``.
        Rs_features: representation list describing the node features; the
            feature width is ``rs.dim(Rs_features)``.
        device: torch device on which all tensors of the graph are created.

    Returns:
        ``Data`` with ``x`` (node features), ``edge_index``, ``pos``,
        ``rel_vec`` (unit vectors along each edge), ``abs_distances`` (edge
        lengths) and the scalar label ``y``.
    """
    pos = torch.randn((n_nodes, 3), device=device)
    edge_index = gen_edge_index(n_nodes, n_edges, device)

    # Edge vector points from the node in row 0 to the node in row 1.
    rel_vec = pos[edge_index[1]] - pos[edge_index[0]]
    abs_distances = torch.norm(rel_vec, dim=1)
    rel_vec = torch.nn.functional.normalize(rel_vec)

    x = torch.randn((n_nodes, rs.dim(Rs_features)), device=device)
    # Create the label on the same device as the rest of the graph so a loss
    # computed against the model output does not mix CPU and GPU tensors.
    y = torch.tensor(n_nodes % 2, dtype=torch.get_default_dtype(), device=device)
    return Data(x, edge_index, pos=pos, rel_vec=rel_vec, abs_distances=abs_distances, y=y)

# Define shapes of the features (layers)

### (mul, L, p) 
### mul = multiplicity (number of copies)
### L = rotation order (order of polynomial)
### p = parity (0 means no parity, ignore for now)

In [6]:
# Random representations for the input and the two hidden layers; the output is a
# single scalar (one copy of L=0).  gen_Rs(1) can only yield L=0 entries.
input_Rs, hidden1_Rs, hidden2_Rs = gen_Rs(1), gen_Rs(4), gen_Rs(4)
output_Rs = [(1, 0, 0)]

# Ordered from input to output; this list is what the model is built from.
representations = [input_Rs, hidden1_Rs, hidden2_Rs, output_Rs]

layer_names = ('input', 'hidden_1', 'hidden_2', 'output')
print(*zip(layer_names, representations))

('input', [(8, 0, 0)]) ('hidden_1', [(7, 0, 0), (11, 1, 0), (11, 2, 0)]) ('hidden_2', [(15, 1, 0), (11, 2, 0), (3, 3, 0)]) ('output', [(1, 0, 0)])


# Define radial model
### trainable function defined on absolute distances only  

In [7]:
# Factory for the radial function: the hyperparameters are pre-bound here and the
# remaining constructor arguments are supplied by whoever calls `radial_model`.
# NOTE(review): `h` and `L` look like hidden width and layer count of the radial
# network -- confirm against GaussianRadialModel.
radial_model = partial(
    GaussianRadialModel,
    max_radius=3.2,
    min_radius=0.7,
    number_of_basis=10,
    h=100,
    L=3,
    act=swish,
)

# Define gate
### Equivariant nonlinearity

In [8]:
# Factory for the gate non-linearity with both activations pre-bound.
# NOTE(review): the ReLU *class* is passed, not an instance -- confirm that Gate
# instantiates its activations itself.
gate = partial(
    Gate,
    scalar_act=torch.nn.ReLU,
    tensor_act=torch.nn.ReLU,
)

# Instantiate model
### For a homogeneous model, all instances of radial models and gates are of the same type

In [9]:
# Build the network from the layer representations and the shared radial-model /
# gate factories, then move it to the GPU.
model = TensorPassingHomogenous(representations, radial_model, gate).cuda()

# Train

In [11]:
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
# BCEWithLogitsLoss fuses the sigmoid into the loss, which is numerically more
# stable than applying torch.sigmoid and BCELoss separately.
criterion = torch.nn.BCEWithLogitsLoss()

# Toy task: predict the parity of the node count of a freshly generated random graph.
for _ in tqdm(range(10000)):
    n_nodes = random.randint(2, 100)
    # Number of random directed pairs; gen_edge_index adds the reverse of each one.
    n_edges = random.randint(n_nodes, 5*n_nodes)
    graph = gen_graph(n_nodes, n_edges, input_Rs)

    optimizer.zero_grad()
    # Sum the model output over the graph to get a single raw logit.
    logit = model(graph).sum()
    # Keep the target on the logit's device (no-op when it already is).
    loss = criterion(logit, graph.y.to(logit.device))
    loss.backward()
    optimizer.step()

HBox(children=(FloatProgress(value=0.0, max=10000.0), HTML(value='')))


