In [7]:
from typing import Dict

import torch
import torch_geometric.utils
from e3nn import o3
from e3nn import nn as enn
from e3nn.math import soft_one_hot_linspace
from e3nn.nn import Gate
from e3nn.util.jit import compile_mode

import rem.e3nn_utils as e3nn_utils
import rem.pooling as pooling

from datasets import DragMeshDataset
from torch_geometric.data import DataLoader

In [None]:
class GNN(torch.nn.Module):
  def __init__(
    self,
    irreps_node_input,
    irreps_node_output,
    max_radius,
    mul=50,
    layers=3,
    lmax=2,
    pool_nodes=True,
  ) -> None:
    super().__init__()

    self.lmax = lmax
    self.max_radius = max_radius
    self.number_of_basis = 10
    self.pool_nodes = pool_nodes

    irreps_node_hiddens = list()
    irreps_edge_hiddens = list()
    for layer_lmax in lmax[:-1]:
      irreps_node_hiddens.append(o3.Irreps(
        #[(mul * 2 * l + 1, (l, 1)) for l in range(layer_lmax + 1)]
        [(mul, (l, p)) for l in range(layer_lmax + 1) for p in [-1, 1]]
      ))
      irreps_edge_hiddens.append(o3.Irreps.spherical_harmonics(layer_lmax))


    self.irreps_node_seq = [irreps_node_input] + irreps_node_hiddens + [irreps_node_output]
    irreps_edge_seq = irreps_edge_hiddens + [o3.Irreps.spherical_harmonics(lmax[-1])]
    self.mp = MessagePassing(
      irreps_node_sequence=self.irreps_node_seq,
      irreps_edge_attrs=irreps_edge_seq,
      fc_neurons=[self.number_of_basis, 100],
      max_radius=max_radius,
      lmax=self.lmax
    )
    self.irreps_node_input = self.mp.irreps_node_input
    self.irreps_node_output = self.mp.irreps_node_output

  def forward(self, data: Dict[str, torch.Tensor]) -> torch.Tensor:
    batch, node_outputs = self.mp(data)

    if self.pool_nodes:
      return torch_geometric.utils.scatter(node_outputs, batch, dim=0, reduce='mean')
    else:
      return node_outputs

class MessagePassing(torch.nn.Module):
  def __init__(
    self,
    irreps_node_sequence,
    irreps_edge_attrs,
    fc_neurons,
    max_radius,
    lmax,
  ) -> None:
    super().__init__()
    self.lmax = lmax
    self.max_radius = max_radius
    self.number_of_basis = 10

    irreps_node_sequence = [o3.Irreps(irreps) for irreps in irreps_node_sequence]
    self.irreps_edge_attrs = [o3.Irreps(irreps) for irreps in irreps_edge_attrs]

    act = {
      1: torch.nn.functional.silu,
      -1: torch.tanh,
    }
    act_gates = {
      1: torch.sigmoid,
      -1: torch.tanh,
    }

    self.layers = torch.nn.ModuleList()

    self.irreps_node_sequence = [irreps_node_sequence[0]]
    irreps_node = irreps_node_sequence[0]

    for li, (irreps_node_hidden, irreps_edge_attr) in enumerate(zip(irreps_node_sequence[1:-1], self.irreps_edge_attrs[:-1])):
      irreps_scalars = o3.Irreps(
        [
          (mul, ir)
          for mul, ir in irreps_node_hidden
          if ir.l == 0
          and e3nn_utils.tp_path_exists(
            irreps_node, irreps_edge_attr, ir
          )
        ]
      ).simplify()
      irreps_gated = o3.Irreps(
        [
          (mul, ir)
          for mul, ir in irreps_node_hidden
          if ir.l > 0
          and e3nn_utils.tp_path_exists(
            irreps_node, irreps_edge_attr, ir
          )
        ]
      )
      if irreps_gated.dim > 0:
        if e3nn_utils.tp_path_exists(irreps_node, irreps_edge_attr, "0e"):
          ir = "0e"
        elif e3nn_utils.tp_path_exists(
          irreps_node, irreps_edge_attr, "0o"
        ):
          ir = "0o"
        else:
          raise ValueError(
            f"irreps_node={irreps_node} times irreps_edge_attr={self.irreps_edge_attr} is unable to produce gates "
            f"needed for irreps_gated={irreps_gated}"
          )
      else:
        ir = None
      irreps_gates = o3.Irreps([(mul, ir) for mul, _ in irreps_gated]).simplify()

      gate = Gate(
        irreps_scalars,
        [act[ir.p] for _, ir in irreps_scalars],  # scalar
        irreps_gates,
        [act_gates[ir.p] for _, ir in irreps_gates],  # gates (scalars)
        irreps_gated,  # gated tensors
      )
      conv = GraphConvolution(
        irreps_node,
        irreps_edge_attr,
        gate.irreps_in,
        fc_neurons,
      )
      #if li == 1:
      #  pool = pooling.VoxelPooling(gate.irreps_out, self.lmax, [0.5, 0.5, 0.5])#, start=[-2, -2, -2],end=[2, 2, 2])
      #else:
      #  pool = pooling.VoxelPooling(gate.irreps_out, self.lmax, [0.1, 0.1, 0.5])#, start=[-2, -2, -2],end=[2, 2, 2])
      #pool = pooling.EdgePooling(gate.irreps_out, self.lmax)
      pool = pooling.TopKPooling(gate.irreps_out, self.lmax)
      #pool = None
      self.layers.append(e3nn_utils.GraphConvBlock(conv, gate))
      irreps_node = gate.irreps_out
      self.irreps_node_sequence.append(irreps_node)

    irreps_node_output = irreps_node_sequence[-1]
    self.layers.append(
      e3nn_utils.GraphConvBlock(
        GraphConvolution(
          irreps_node,
          self.irreps_edge_attrs[-1],
          irreps_node_output,
          fc_neurons,
        )
      )
    )
    

    self.irreps_node_sequence.append(irreps_node_output)

    self.irreps_node_input = self.irreps_node_sequence[0]
    self.irreps_node_output = self.irreps_node_sequence[-1]

  def forward(self, data) -> torch.Tensor:
    for i, lay in enumerate(self.layers):
      # Edge attributes
      edge_sh = o3.spherical_harmonics(
        range(self.lmax[i] + 1), data.edge_vec, True, normalization="component"
      )
      data.edge_attr_sh = edge_sh

      # Edge length embedding
      edge_length = data.edge_vec.norm(dim=1)
      data.edge_scalars = soft_one_hot_linspace(
        edge_length,
        0.0,
        self.max_radius,
        self.number_of_basis,
        basis="smooth_finite",  # the smooth_finite basis with cutoff = True goes to zero at max_radius
        cutoff=True,  # no need for an additional smooth cutoff
      ).mul(self.number_of_basis**0.5)

      # Forward
      data = lay(data)
    
    # data.x = self.final_act(data.x)

    return data.batch, data.x

@compile_mode("script")
class GraphConvolution(torch.nn.Module):
  def __init__(
    self, irreps_node_input, irreps_edge_attr, irreps_node_output, fc_neurons
  ) -> None:
    super().__init__()
    self.irreps_node_input = o3.Irreps(irreps_node_input)
    self.irreps_edge_attr = o3.Irreps(irreps_edge_attr)
    self.irreps_node_output = o3.Irreps(irreps_node_output)
    self.sc = o3.Linear(self.irreps_node_input, self.irreps_node_output)
    # self.lin1 = o3.FullyConnectedTensorProduct(self.irreps_node_input, self.irreps_node_attr, self.irreps_node_input)

    irreps_mid = []
    instructions = []
    for i, (mul, ir_in) in enumerate(self.irreps_node_input):
      for j, (_, ir_edge) in enumerate(self.irreps_edge_attr):
        for ir_out in ir_in * ir_edge:
          if ir_out in self.irreps_node_output or ir_out == o3.Irrep(0, 1):
            k = len(irreps_mid)
            irreps_mid.append((mul, ir_out))
            instructions.append((i, j, k, "uvu", True))
    irreps_mid = o3.Irreps(irreps_mid)
    irreps_mid, p, _ = irreps_mid.sort()

    assert irreps_mid.dim > 0, (
      f"irreps_node_input={self.irreps_node_input} time irreps_edge_attr={self.irreps_edge_attr} produces nothing "
      f"in irreps_node_output={self.irreps_node_output}"
    )
    instructions = [(i_1, i_2, p[i_out], mode, train) for i_1, i_2, i_out, mode, train in instructions]

    tp = o3.TensorProduct(
      self.irreps_node_input,
      self.irreps_edge_attr,
      irreps_mid,
      instructions,
      internal_weights=False,
      shared_weights=False,
    )
    self.fc = enn.FullyConnectedNet(fc_neurons + [tp.weight_numel], torch.nn.functional.silu)
    self.tp = tp

    self.lin = o3.Linear(irreps_mid, self.irreps_node_output)
    # self.lin2 = o3.FullyConnectedTensorProduct(...)

    # inspired by https://arxiv.org/pdf/2002.10444.pdf
    self.alpha = o3.Linear(irreps_mid, "0e")
    # with torch.no_grad():
    #   self.alpha.weight.zero_()

  def forward(self, data) -> torch.Tensor:
    weight = self.fc(data.edge_scalars)

    node_self_connection =  self.sc(data.x)
    # print(f"Node self connection: {node_self_connection.shape}")
    # print(node_self_connection)
    edge_features = self.tp(data.x[data.edge_index[0]], data.edge_attr_sh, weight)
    # print(f"Edge features: {edge_features.shape}")
    # print(edge_features)
    node_features = torch_geometric.utils.scatter(edge_features, data.edge_index[1], dim=0, dim_size=data.x.shape[0], reduce='mean')
    # print(f"Node features: {node_features.shape}")
    # print(node_features)

    alpha = self.alpha(node_features)
    # print(f"Alpha: {alpha.shape}")
    # print(alpha)
    node_conv_out = self.lin(node_features)
    # print(f"Node conv out: {node_conv_out.shape}")
    # print(node_conv_out)

    m = self.sc.output_mask
    alpha = (1 - m) + alpha * m
    return node_self_connection + alpha * node_conv_out



In [400]:
lmax = 5
irreps_in = o3.Irreps("5x0e")
# irreps_out = o3.Irreps("16x0o+16x0e+16x1o+16x1e+16x2o+16x2e+16x3o+16x3e+16x4o+16x4e")
irreps_out = o3.Irreps("4x0e + 1x1o")
encoder = GNN(
    irreps_node_input=irreps_in,
    irreps_node_output=irreps_out,
    max_radius=1.7,
    mul=50,
    pool_nodes=True,
    lmax=[lmax, lmax, lmax],
)

print(encoder.irreps_node_input)
print(encoder.irreps_node_output)
print(encoder.irreps_node_seq)

encoder



5x0e
4x0e+1x1o
[5x0e, 50x0o+50x0e+50x1o+50x1e+50x2o+50x2e+50x3o+50x3e+50x4o+50x4e+50x5o+50x5e, 50x0o+50x0e+50x1o+50x1e+50x2o+50x2e+50x3o+50x3e+50x4o+50x4e+50x5o+50x5e, 4x0e+1x1o]


GNN(
  (mp): MessagePassing(
    (layers): ModuleList(
      (0): GraphConvBlock(
        (gconv): GraphConvolution(
          (sc): Linear(5x0e -> 300x0e+50x1o+50x2e+50x3o+50x4e+50x5o | 1500 weights)
          (fc): FullyConnectedNet[10, 100, 30]
          (tp): TensorProduct(5x0e x 1x0e+1x1o+1x2e+1x3o+1x4e+1x5o -> 5x0e+5x1o+5x2e+5x3o+5x4e+5x5o | 30 paths | 30 weights)
          (lin): Linear(5x0e+5x1o+5x2e+5x3o+5x4e+5x5o -> 300x0e+50x1o+50x2e+50x3o+50x4e+50x5o | 2750 weights)
          (alpha): Linear(5x0e+5x1o+5x2e+5x3o+5x4e+5x5o -> 1x0e | 5 weights)
        )
        (act): Gate (300x0e+50x1o+50x2e+50x3o+50x4e+50x5o -> 50x0e+50x1o+50x2e+50x3o+50x4e+50x5o)
      )
      (1): GraphConvBlock(
        (gconv): GraphConvolution(
          (sc): Linear(50x0e+50x1o+50x2e+50x3o+50x4e+50x5o -> 550x0e+50x1o+50x1e+50x2o+50x2e+50x3o+50x3e+50x4o+50x4e+50x5o+50x5e | 40000 weights)
          (fc): FullyConnectedNet[10, 100, 5550]
          (tp): TensorProduct(50x0e+50x1o+50x2e+50x3o+50x4e+50x5o x

In [401]:
ds = DragMeshDataset("data/cube50k.dat", "STLs/Cube_38_1m.stl", return_features_separately=False)
dl = iter(DataLoader(ds, batch_size=1, shuffle=False))

  df = pd.read_csv(utils.to_absolute_path(attr_file), delim_whitespace=True, header=None)


In [402]:
samp, y = next(dl)
samp_copy = samp.clone()
samp_copy2 = samp.clone()
samp.orientation = torch.tensor([[0.0, 0.0]])
print(samp)
print(samp.orientation)
print(samp.x)
print(samp.pos)

DataBatch(x=[38, 5], edge_index=[2, 108], pos=[38, 3], edge_vec=[108, 3], orientation=[1, 2], batch=[38], ptr=[2])
tensor([[0., 0.]])
tensor([[0.4587, 0.5424, 0.5712, 0.2145, 0.8279],
        [0.4587, 0.5424, 0.5712, 0.2145, 0.8279],
        [0.4587, 0.5424, 0.5712, 0.2145, 0.8279],
        [0.4587, 0.5424, 0.5712, 0.2145, 0.8279],
        [0.4587, 0.5424, 0.5712, 0.2145, 0.8279],
        [0.4587, 0.5424, 0.5712, 0.2145, 0.8279],
        [0.4587, 0.5424, 0.5712, 0.2145, 0.8279],
        [0.4587, 0.5424, 0.5712, 0.2145, 0.8279],
        [0.4587, 0.5424, 0.5712, 0.2145, 0.8279],
        [0.4587, 0.5424, 0.5712, 0.2145, 0.8279],
        [0.4587, 0.5424, 0.5712, 0.2145, 0.8279],
        [0.4587, 0.5424, 0.5712, 0.2145, 0.8279],
        [0.4587, 0.5424, 0.5712, 0.2145, 0.8279],
        [0.4587, 0.5424, 0.5712, 0.2145, 0.8279],
        [0.4587, 0.5424, 0.5712, 0.2145, 0.8279],
        [0.4587, 0.5424, 0.5712, 0.2145, 0.8279],
        [0.4587, 0.5424, 0.5712, 0.2145, 0.8279],
        [0.4587,

In [403]:
ret = encoder(samp)
print(ret.shape)
print(ret)

torch.Size([1, 7])
tensor([[ 0.0529, -0.0796, -0.4811, -0.1557, -0.0017, -0.0722, -0.0789]],
       grad_fn=<DivBackward0>)


In [404]:
samp_copy.pos = samp_copy.pos @ o3.rand_matrix()
ret_copy = encoder(samp_copy)
print(ret_copy.shape)
print(ret_copy)
# assert all close to ret
assert torch.allclose(ret, ret_copy)

torch.Size([1, 7])
tensor([[ 0.0529, -0.0796, -0.4811, -0.1557, -0.0017, -0.0722, -0.0789]],
       grad_fn=<DivBackward0>)


In [405]:
noise = torch.randn_like(samp_copy2.pos) * 1.5  # small random noise shouldn't be equivariant
samp_copy2.pos += noise
ret_copy2 = encoder(samp_copy2)
print(ret_copy2.shape)
print(ret_copy2)

assert torch.allclose(ret_copy2, ret, atol=1e-5)

torch.Size([1, 7])
tensor([[ 0.0529, -0.0796, -0.4811, -0.1557, -0.0017, -0.0722, -0.0789]],
       grad_fn=<DivBackward0>)


## Old

In [None]:
samp, y = next(dl)
samp_copy = samp.clone()
from e3nn.o3 import rand_matrix

def spherical_to_cartesian(alpha, beta):
    """
    Convert spherical angles (alpha, beta) to cartesian coordinates.
    alpha: azimuthal angle, beta: polar angle
    """
    x = torch.cos(beta) * torch.cos(alpha)
    y = torch.cos(beta) * torch.sin(alpha)
    z = torch.sin(beta)
    return torch.stack([x, y, z], dim=-1)


def rotate_orientation(orientation, R):
    """
    Rotate 2D orientation (alpha, beta) values using a rotation matrix.
    Orientation is converted to Cartesian, rotated, and then converted back.
    """
    # Convert spherical angles (alpha, beta) to Cartesian coordinates
    cartesian = spherical_to_cartesian(orientation[:, 0], orientation[:, 1])

    # Apply rotation
    rotated_cartesian = torch.einsum('ij,nj->ni', R, cartesian)

    # Convert back to spherical coordinates
    rho = rotated_cartesian.norm(dim=-1)
    beta = torch.asin(rotated_cartesian[:, 2] / rho)
    alpha = torch.atan2(rotated_cartesian[:, 1], rotated_cartesian[:, 0])
    return torch.stack([alpha, beta], dim=-1)


def test_encoder_equivariance(encoder, data):
    """
    Test the equivariance of rem.encoder with respect to SO(3) rotations.

    Parameters:
        rem: REM
            The REM model with the encoder to test.
        data: torch_geometric.data.Data
            Input graph data with `data.orientation` as (alpha, beta).

    Returns:
        bool: True if invariant, False otherwise.
    """
    # Generate a random SO(3) rotation matrix
    R = rand_matrix()

    # Rotate orientation
    original_orientation = data.orientation.clone()
    print("Original orientation", original_orientation)
    rotated_orientation = rotate_orientation(original_orientation, R)
    print("New orientation", rotated_orientation)

    if hasattr(data, 'pos'):
        data.pos = torch.einsum('ij,nj->ni', R, data.pos)
        data.pos[:, 1] += 0.5

    # Clone data and apply rotated orientation
    data_rotated = data.clone()
    data_rotated.orientation = rotated_orientation

    # Forward pass on original and rotated data
    encoder.eval()  # Ensure evaluation mode
    with torch.no_grad():
        output_original = encoder(data)
        # print("Output original", output_original)
        output_rotated = encoder(data_rotated)
        # print("Output rotated", output_rotated)

    # Check invariance: the outputs should be identical
    is_invariant = torch.allclose(output_original, output_rotated, atol=1e-6)

    print(f"Equivariance Test: {'Passed' if is_invariant else 'Failed'}")
    return is_invariant

test_encoder_equivariance(encoder, samp_copy)

Original orientation tensor([[ 5.6706, -0.7876]])
New orientation tensor([[-1.1105,  0.7781]])
Equivariance Test: Passed


True