In [1]:
import torch
import random
import math

from e3nn import o3
from torch_geometric.data import Data, DataLoader
from torch_geometric.utils import to_undirected

from cartesian_mace.models.model import CartesianMACE
from src.models import MACEModel

%load_ext snakeviz
%load_ext autoreload
%autoreload 2

In [2]:
def create_rotsym_envs(fold=3, rng=None):
    """Create two ``fold``-fold rotationally symmetric star graphs that differ only by a rotation.

    Each graph has a centre node at the origin joined to ``fold`` spoke nodes
    on the unit circle in the xy-plane, at angles ``2*pi*k/fold`` for
    ``k = 0, ..., fold - 1``. Graph 0 (label 0) is that star; graph 1
    (label 1) is the same star rotated about the z-axis by
    ``q = 2*pi / (fold + m)`` with ``m`` drawn from ``1..fold`` (inclusive),
    so ``0 < q < 2*pi/fold`` and the rotation is never one of the star's own
    z-rotation symmetries.

    Args:
        fold: number of spokes (order of the rotational symmetry); must be >= 1.
        rng: optional object with a ``randint(a, b)`` method (e.g. a seeded
            ``random.Random``) used to draw ``m``. Defaults to the global
            ``random`` module, which is what this function always used before.

    Returns:
        ``[data0, data1]``: ``torch_geometric`` ``Data`` objects with fields
        ``atoms`` (all zeros, long), ``edge_index`` (undirected star edges),
        ``pos`` (``(fold + 1, 3)`` float) and ``y`` (the label, long).

    Raises:
        ValueError: if ``fold < 1`` (there would be no spokes to build a star from).
    """
    if fold < 1:
        raise ValueError(f"fold must be >= 1, got {fold!r}")
    if rng is None:
        rng = random

    # Node 0 is the centre, nodes 1..fold are the spokes; all share atom type 0.
    atoms = torch.zeros(fold + 1, dtype=torch.long)
    # Directed centre -> spoke edges; each graph symmetrises them with to_undirected.
    edge_index = torch.tensor([[0] * fold, list(range(1, fold + 1))], dtype=torch.long)

    def _make_graph(pos, label):
        # Wrap positions and a class label into a Data object with undirected edges.
        data = Data(
            atoms=atoms,
            edge_index=edge_index,
            pos=pos,
            y=torch.tensor([label], dtype=torch.long),
        )
        data.edge_index = to_undirected(data.edge_index)
        return data

    # Environment 0: first spoke on +x, the others at multiples of 2*pi/fold about z.
    spoke = torch.tensor([1.0, 0.0, 0.0])
    pos = [torch.zeros(3), spoke]
    for k in range(1, fold):
        R = o3.matrix_z(torch.tensor([2 * math.pi / fold * k])).squeeze(0)
        pos.append(spoke @ R.T)  # row-vector form of R @ spoke
    pos = torch.stack(pos)

    # Environment 1: the whole star rotated about z by an angle strictly below 2*pi/fold.
    q = 2 * math.pi / (fold + rng.randint(1, fold))
    assert 0 < q < 2 * math.pi / fold
    Q = o3.matrix_z(torch.tensor([q])).squeeze(0)
    rotated_pos = pos @ Q.T

    return [_make_graph(pos, 0), _make_graph(rotated_pos, 1)]

In [3]:
# --- Experiment configuration ---
SEED = 0          # seeds the env rotation draw, loader shuffling and (presumably) weight init
model_name = "cmace"
correlation = 2   # passed as `correlation` (MACEModel) and `nu_max` (CartesianMACE)
max_ell = 3       # passed as `max_ell` (MACEModel) and `feature_rank_max` (CartesianMACE)
fold = 3          # number of spokes, i.e. order of the rotational symmetry
n_layers = 1

# create_rotsym_envs draws its rotation angle from Python's `random`, and the
# shuffling DataLoader uses torch's global RNG, so seed both for a reproducible run.
random.seed(SEED)
torch.manual_seed(SEED)

# Create dataset: two star graphs, the second a rotated copy of the first.
dataset = create_rotsym_envs(fold)

# Create dataloaders (one graph per batch).
dataloader = DataLoader(dataset, batch_size=1, shuffle=True)
val_loader = DataLoader(dataset, batch_size=1, shuffle=False)
test_loader = DataLoader(dataset, batch_size=1, shuffle=False)

# First (shuffled) batch, reused by the profiling/timing cells below.
batch = next(iter(dataloader))

# `n_nodes` is not passed: this CartesianMACE version rejects it
# (TypeError: __init__() got an unexpected keyword argument 'n_nodes').
# NOTE(review): `n_edges` may need the same treatment — confirm against CartesianMACE.__init__.
cmace_model = CartesianMACE(
    n_layers=n_layers,
    dim=3,
    n_channels=3,
    self_tp_rank_max=2,
    basis_rank_max=2,
    n_edges=fold,
    feature_rank_max=max_ell,
    nu_max=correlation,
)
mace_model = MACEModel(
    scalar_pred=False,
    correlation=correlation,
    num_layers=n_layers,
    out_dim=2,
    max_ell=max_ell,
    emb_dim=3,
)

In [None]:
# Profile a single forward pass of the Cartesian MACE model with snakeviz.
# NOTE(review): `-t` presumably opens the report in a new browser tab instead of inline — confirm in the snakeviz docs.
%snakeviz -t cmace_model(batch)

In [None]:
# Compare forward-pass wall time of the Cartesian MACE and MACE models on the same batch.
# NOTE(review): no torch.no_grad() here, so autograd is presumably active and the timings
# include graph-building overhead; wrap the calls if inference-only cost is what matters.
%timeit cmace_model(batch)
%timeit mace_model(batch)

In [None]:
# Run one forward pass and display the raw Cartesian MACE output for inspection.
cmace_model(batch)