In [2]:
from typing import Dict

import torch
import torch_geometric.utils
from e3nn import o3
from e3nn import nn as enn
from e3nn.math import soft_one_hot_linspace
from e3nn.nn import Gate
from e3nn.util.jit import compile_mode

import rem.e3nn_utils as e3nn_utils
from e3nn.nn import SO3Activation
import rem.pooling as pooling

from datasets import DragMeshDataset
from torch_geometric.data import DataLoader

In [6]:
ds = DragMeshDataset("data/cube50k.dat", "STLs/Cube_38_1m.stl", return_features_separately=False)
dl = iter(DataLoader(ds, batch_size=1, shuffle=False))

  df = pd.read_csv(utils.to_absolute_path(attr_file), delim_whitespace=True, header=None)


In [22]:
class GNN(torch.nn.Module):
  def __init__(
    self,
    irreps_node_input,
    irreps_node_output,
    max_radius,
    mul=50,
    layers=3,
    lmax=2,
    pool_nodes=True,
  ) -> None:
    super().__init__()

    self.lmax = lmax
    self.max_radius = max_radius
    self.number_of_basis = 10
    self.pool_nodes = pool_nodes

    irreps_node_hiddens = list()
    irreps_edge_hiddens = list()
    for layer_lmax in lmax[:-1]:
      irreps_node_hiddens.append(o3.Irreps(
        #[(mul * 2 * l + 1, (l, 1)) for l in range(layer_lmax + 1)]
        [(mul, (l, p)) for l in range(layer_lmax + 1) for p in [-1, 1]]
      ))
      irreps_edge_hiddens.append(o3.Irreps.spherical_harmonics(layer_lmax))


    self.irreps_node_seq = [irreps_node_input] + irreps_node_hiddens + [irreps_node_output]
    irreps_edge_seq = irreps_edge_hiddens + [o3.Irreps.spherical_harmonics(lmax[-1])]
    self.mp = MessagePassing(
      irreps_node_sequence=self.irreps_node_seq,
      irreps_edge_attrs=irreps_edge_seq,
      fc_neurons=[self.number_of_basis, 100],
      max_radius=max_radius,
      lmax=self.lmax
    )
    self.irreps_node_input = self.mp.irreps_node_input
    self.irreps_node_output = self.mp.irreps_node_output

  def forward(self, data: Dict[str, torch.Tensor]) -> torch.Tensor:
    batch, node_outputs = self.mp(data)

    if self.pool_nodes:
      return torch_geometric.utils.scatter(node_outputs, batch, dim=0, reduce='mean')
    else:
      return node_outputs

class MessagePassing(torch.nn.Module):
  def __init__(
    self,
    irreps_node_sequence,
    irreps_edge_attrs,
    fc_neurons,
    max_radius,
    lmax,
  ) -> None:
    super().__init__()
    self.lmax = lmax
    self.max_radius = max_radius
    self.number_of_basis = 10

    irreps_node_sequence = [o3.Irreps(irreps) for irreps in irreps_node_sequence]
    self.irreps_edge_attrs = [o3.Irreps(irreps) for irreps in irreps_edge_attrs]

    act = {
      1: torch.nn.functional.silu,
      -1: torch.tanh,
    }
    act_gates = {
      1: torch.sigmoid,
      -1: torch.tanh,
    }

    self.layers = torch.nn.ModuleList()

    self.irreps_node_sequence = [irreps_node_sequence[0]]
    irreps_node = irreps_node_sequence[0]

    for li, (irreps_node_hidden, irreps_edge_attr) in enumerate(zip(irreps_node_sequence[1:-1], self.irreps_edge_attrs[:-1])):
      irreps_scalars = o3.Irreps(
        [
          (mul, ir)
          for mul, ir in irreps_node_hidden
          if ir.l == 0
          and e3nn_utils.tp_path_exists(
            irreps_node, irreps_edge_attr, ir
          )
        ]
      ).simplify()
      irreps_gated = o3.Irreps(
        [
          (mul, ir)
          for mul, ir in irreps_node_hidden
          if ir.l > 0
          and e3nn_utils.tp_path_exists(
            irreps_node, irreps_edge_attr, ir
          )
        ]
      )
      if irreps_gated.dim > 0:
        if e3nn_utils.tp_path_exists(irreps_node, irreps_edge_attr, "0e"):
          ir = "0e"
        elif e3nn_utils.tp_path_exists(
          irreps_node, irreps_edge_attr, "0o"
        ):
          ir = "0o"
        else:
          raise ValueError(
            f"irreps_node={irreps_node} times irreps_edge_attr={self.irreps_edge_attr} is unable to produce gates "
            f"needed for irreps_gated={irreps_gated}"
          )
      else:
        ir = None
      irreps_gates = o3.Irreps([(mul, ir) for mul, _ in irreps_gated]).simplify()

      gate = Gate(
        irreps_scalars,
        [act[ir.p] for _, ir in irreps_scalars],  # scalar
        irreps_gates,
        [act_gates[ir.p] for _, ir in irreps_gates],  # gates (scalars)
        irreps_gated,  # gated tensors
      )
      conv = GraphConvolution(
        irreps_node,
        irreps_edge_attr,
        gate.irreps_in,
        fc_neurons,
      )
      #if li == 1:
      #  pool = pooling.VoxelPooling(gate.irreps_out, self.lmax, [0.5, 0.5, 0.5])#, start=[-2, -2, -2],end=[2, 2, 2])
      #else:
      #  pool = pooling.VoxelPooling(gate.irreps_out, self.lmax, [0.1, 0.1, 0.5])#, start=[-2, -2, -2],end=[2, 2, 2])
      #pool = pooling.EdgePooling(gate.irreps_out, self.lmax)
      pool = pooling.TopKPooling(gate.irreps_out, self.lmax)
      #pool = None
      self.layers.append(e3nn_utils.GraphConvBlock(conv, gate, pool))
      irreps_node = gate.irreps_out
      self.irreps_node_sequence.append(irreps_node)

    irreps_node_output = irreps_node_sequence[-1]
    self.layers.append(
      e3nn_utils.GraphConvBlock(
        GraphConvolution(
          irreps_node,
          self.irreps_edge_attrs[-1],
          irreps_node_output,
          fc_neurons,
        )
      )
    )
    self.irreps_node_sequence.append(irreps_node_output)

    self.irreps_node_input = self.irreps_node_sequence[0]
    self.irreps_node_output = self.irreps_node_sequence[-1]

  def forward(self, data) -> torch.Tensor:
    for i, lay in enumerate(self.layers):
      # Edge attributes
      edge_sh = o3.spherical_harmonics(
        range(self.lmax[i] + 1), data.edge_vec, True, normalization="component"
      )
      data.edge_attr_sh = edge_sh

      # Edge length embedding
      edge_length = data.edge_vec.norm(dim=1)
      data.edge_scalars = soft_one_hot_linspace(
        edge_length,
        0.0,
        self.max_radius,
        self.number_of_basis,
        basis="smooth_finite",  # the smooth_finite basis with cutoff = True goes to zero at max_radius
        cutoff=True,  # no need for an additional smooth cutoff
      ).mul(self.number_of_basis**0.5)

      # Forward
      data = lay(data)

    return data.batch, data.x

@compile_mode("script")
class GraphConvolution(torch.nn.Module):
  def __init__(
    self, irreps_node_input, irreps_edge_attr, irreps_node_output, fc_neurons
  ) -> None:
    super().__init__()
    self.irreps_node_input = o3.Irreps(irreps_node_input)
    self.irreps_edge_attr = o3.Irreps(irreps_edge_attr)
    self.irreps_node_output = o3.Irreps(irreps_node_output)
    self.sc = o3.Linear(self.irreps_node_input, self.irreps_node_output)
    # self.lin1 = o3.FullyConnectedTensorProduct(self.irreps_node_input, self.irreps_node_attr, self.irreps_node_input)

    irreps_mid = []
    instructions = []
    for i, (mul, ir_in) in enumerate(self.irreps_node_input):
      for j, (_, ir_edge) in enumerate(self.irreps_edge_attr):
        for ir_out in ir_in * ir_edge:
          if ir_out in self.irreps_node_output or ir_out == o3.Irrep(0, 1):
            k = len(irreps_mid)
            irreps_mid.append((mul, ir_out))
            instructions.append((i, j, k, "uvu", True))
    irreps_mid = o3.Irreps(irreps_mid)
    irreps_mid, p, _ = irreps_mid.sort()

    assert irreps_mid.dim > 0, (
      f"irreps_node_input={self.irreps_node_input} time irreps_edge_attr={self.irreps_edge_attr} produces nothing "
      f"in irreps_node_output={self.irreps_node_output}"
    )
    instructions = [(i_1, i_2, p[i_out], mode, train) for i_1, i_2, i_out, mode, train in instructions]

    tp = o3.TensorProduct(
      self.irreps_node_input,
      self.irreps_edge_attr,
      irreps_mid,
      instructions,
      internal_weights=False,
      shared_weights=False,
    )
    self.fc = enn.FullyConnectedNet(fc_neurons + [tp.weight_numel], torch.nn.functional.silu)
    self.tp = tp

    self.lin = o3.Linear(irreps_mid, self.irreps_node_output)
    # self.lin2 = o3.FullyConnectedTensorProduct(...)

    # inspired by https://arxiv.org/pdf/2002.10444.pdf
    self.alpha = o3.Linear(irreps_mid, "0e")
    with torch.no_grad():
      self.alpha.weight.zero_()

  def forward(self, data) -> torch.Tensor:
    weight = self.fc(data.edge_scalars)

    node_self_connection =  self.sc(data.x)
    # print(f"Node self connection: {node_self_connection.shape}")
    # print(node_self_connection)
    edge_features = self.tp(data.x[data.edge_index[0]], data.edge_attr_sh, weight)
    # print(f"Edge features: {edge_features.shape}")
    # print(edge_features)
    node_features = torch_geometric.utils.scatter(edge_features, data.edge_index[1], dim=0, dim_size=data.x.shape[0], reduce='mean')
    # print(f"Node features: {node_features.shape}")
    # print(node_features)

    alpha = self.alpha(node_features)
    # print(f"Alpha: {alpha.shape}")
    # print(alpha)
    node_conv_out = self.lin(node_features)
    # print(f"Node conv out: {node_conv_out.shape}")
    # print(node_conv_out)

    m = self.sc.output_mask
    alpha = (1 - m) + alpha * m
    return node_self_connection + alpha * node_conv_out
  
class Decoder(torch.nn.Module):
  def __init__(self, lmax_in, lmax_out, f_in, f_out):
    super().__init__()

    grid_s2 = e3nn_utils.s2_near_identity_grid()
    grid_so3 = e3nn_utils.so3_near_identity_grid()

    self.so3_conv1 = e3nn_utils.SO3Convolution(
      f_in, 64, lmax_in, kernel_grid=grid_so3
    )
    self.act1 = SO3Activation(lmax_in, lmax_out, torch.relu, resolution=12)

    self.so3_conv2 = e3nn_utils.SO3Convolution(
      64, 128, lmax_in, kernel_grid=grid_so3
    )
    self.act2 = SO3Activation(lmax_in, lmax_out, torch.relu, resolution=12)

    self.so3_conv3 = e3nn_utils.SO3Convolution(
      128, 256, lmax_in, kernel_grid=grid_so3
    )

    # 2) Project out ONLY the l=0 block (the scalars)
      #    This activation sets lmax_out=0 => keep only l=0
    self.act3 = SO3Activation(lmax_in, 0, torch.relu, resolution=12)
    
    self.lin = o3.Linear("1x0e", f"1x0e", f_in=256, f_out=1)

    # # Output: Maps to 53 (rho_0, rho_1, rho_2, rho_3, ...) -> 53 S2 signals
    # if self.invariant_out:
    #   self.act3 = SO3Activation(lmax_in, 0, torch.relu, resolution=12)
    #   self.lin = o3.Linear(256, f_out)
    # else:
    #   self.act3 = SO3Activation(lmax_in, lmax_out, torch.relu, resolution=12)
    #   self.lin = e3nn_utils.SO3ToS2Convolution(
    #     256, f_out, lmax_out, kernel_grid=grid_s2
    #   )

  def forward(self, x):
    x = self.so3_conv1(x)
    x = self.act1(x)

    x = self.so3_conv2(x)
    x = self.act2(x)

    x = self.so3_conv3(x)
    x = self.act3(x)  # keep only l=0
    print(x.shape)

    x = self.lin(x)   # from 256 scalar channels -> f_out scalar channels
    return x

In [67]:
class DEDM(torch.nn.Module):
  def __init__(self, num_node_features, z_lmax, max_radius, out_dim):
    super().__init__()

#    z_lmax = 4
    self.lmax = z_lmax
    self.out_dim = out_dim
    f = 16

    self.irreps_in = o3.Irreps(f"{num_node_features}x0e")
    self.irreps_latent = e3nn_utils.so3_irreps(z_lmax)
    self.irreps_enc_out = o3.Irreps(
      #[(f, (l, p)) for l in range((z_lmax // 2) + 1) for p in [-1,1]]
      [(f, (l, p)) for l in range((z_lmax) + 1) for p in [-1,1]]
    )
    self.irreps_enc_out = o3.Irreps("16x0e")
    
    self.encoder = GNN(
        irreps_node_input=self.irreps_in,
        irreps_node_output=self.irreps_enc_out,
        max_radius=max_radius,
        mul=f,
        #lmax=[self.lmax // 2, self.lmax // 2, self.lmax // 2],
        lmax=[self.lmax, self.lmax],
      )

    # TODO figure out what this linear layer actually is
    # remove nonlinearities (could be an error) then VN could help
    # equivariance error for encoder and decoder (on a layer by layer basis)
    # overfit to a spherical signal in the decoder
    # latent space
    # TODO develop a baseline mesh to radar model and see what the error is
    # resolution?
    self.lin = o3.Linear(self.irreps_enc_out, self.irreps_latent, f_in=1, f_out=f)
    self.decoder = Decoder(z_lmax, z_lmax, f, out_dim)

  def forward(self, x, return_latent=False):
    batch_size = x.batch.max() + 1
    gnn_out = self.encoder(x)
    z = self.lin(gnn_out.view(batch_size, 1, -1))
    out = self.decoder(z)

    return gnn_out, out
    cartesian = self.ar2los(x.orientation)
    out_response = self._getResponse(out, cartesian)

    if return_latent:
      return (out_response, out)
    else:
      return out_response

  def _getResponse(self, out, pose):
    if self.invariant_out:
      return out
    else:
      sh = torch.concatenate(
        [o3.spherical_harmonics(l, pose, True) for l in range(self.lmax + 1)], dim=1
      ).unsqueeze(2)  # B x (L^2 * S^2) x 1
      response = torch.bmm(out, sh).squeeze()  # B x D

      return response

In [68]:
dedm = DEDM(num_node_features=5, z_lmax=4, max_radius=1.8, out_dim=1)



In [73]:
samp, y = next(dl)
samp_copy = samp.clone()
samp_copy2 = samp.clone()
samp.orientation = torch.tensor([[0.0, 0.0]])
print(samp)
print(samp.orientation)
print(samp.x)
print(samp.pos)

DataBatch(x=[38, 5], edge_index=[2, 108], pos=[38, 3], edge_vec=[108, 3], orientation=[1, 2], batch=[38], ptr=[2])
tensor([[0., 0.]])
tensor([[0.6246, 0.9524, 0.0676, 0.6570, 0.0993],
        [0.6246, 0.9524, 0.0676, 0.6570, 0.0993],
        [0.6246, 0.9524, 0.0676, 0.6570, 0.0993],
        [0.6246, 0.9524, 0.0676, 0.6570, 0.0993],
        [0.6246, 0.9524, 0.0676, 0.6570, 0.0993],
        [0.6246, 0.9524, 0.0676, 0.6570, 0.0993],
        [0.6246, 0.9524, 0.0676, 0.6570, 0.0993],
        [0.6246, 0.9524, 0.0676, 0.6570, 0.0993],
        [0.6246, 0.9524, 0.0676, 0.6570, 0.0993],
        [0.6246, 0.9524, 0.0676, 0.6570, 0.0993],
        [0.6246, 0.9524, 0.0676, 0.6570, 0.0993],
        [0.6246, 0.9524, 0.0676, 0.6570, 0.0993],
        [0.6246, 0.9524, 0.0676, 0.6570, 0.0993],
        [0.6246, 0.9524, 0.0676, 0.6570, 0.0993],
        [0.6246, 0.9524, 0.0676, 0.6570, 0.0993],
        [0.6246, 0.9524, 0.0676, 0.6570, 0.0993],
        [0.6246, 0.9524, 0.0676, 0.6570, 0.0993],
        [0.6246,

In [74]:
ret = dedm(samp)
print(ret[1].shape)
print(ret[1])
print(ret[0])

torch.Size([1, 256, 1])
torch.Size([1, 1, 1])
tensor([[[-2.3762e-05]]], grad_fn=<ViewBackward0>)
tensor([[-0.0499,  0.0266,  0.0319,  0.0059, -0.0337,  0.0068,  0.0067,  0.0191,
         -0.0003, -0.0375, -0.0089, -0.0043, -0.0044,  0.0049,  0.0733, -0.0324]],
       grad_fn=<DivBackward0>)


In [75]:
samp_copy.pos = samp_copy.pos @ o3.rand_matrix()
ret_copy = dedm(samp_copy)
print(ret_copy[1].shape)
print(ret_copy[1])
print(ret_copy[0])
assert torch.allclose(ret_copy[0], ret[0], atol=1e-3)

torch.Size([1, 256, 1])
torch.Size([1, 1, 1])
tensor([[[-2.3762e-05]]], grad_fn=<ViewBackward0>)
tensor([[-0.0499,  0.0266,  0.0319,  0.0059, -0.0337,  0.0068,  0.0067,  0.0191,
         -0.0003, -0.0375, -0.0089, -0.0043, -0.0044,  0.0049,  0.0733, -0.0324]],
       grad_fn=<DivBackward0>)


In [None]:
noise = torch.randn_like(samp_copy2.pos) * 1.5  # small random noise shouldn't be equivariant
samp_copy2.pos += noise
ret_copy2 = dedm(samp_copy2)
print(ret_copy2[1].shape)
print(ret_copy2[1])
print(ret_copy2[0])
# ret_copy2[0] all close to ret[0] -> equivariant
assert torch.allclose(ret_copy2[0], ret[0], atol=1e-5)

torch.Size([1, 256, 1])
torch.Size([1, 1, 1])
tensor([[[-2.3762e-05]]], grad_fn=<ViewBackward0>)
tensor([[-0.0499,  0.0266,  0.0319,  0.0059, -0.0337,  0.0068,  0.0067,  0.0191,
         -0.0003, -0.0375, -0.0089, -0.0043, -0.0044,  0.0049,  0.0733, -0.0324]],
       grad_fn=<DivBackward0>)
