In [4]:
from datasets import DragMeshDataset
from torch_geometric.data import DataLoader

In [5]:
# GRACE_A_tpmc
ds = DragMeshDataset("data/cube50k.dat", "STLs/Cube_38_1m.stl", return_features_separately=False)
dl = iter(DataLoader(ds, batch_size=1, shuffle=False))

  df = pd.read_csv(utils.to_absolute_path(attr_file), delim_whitespace=True, header=None)


In [1]:
import math
import torch
import torch as tr
import torch.nn as nn
import torch.nn.functional as F

from e3nn import nn as enn
from e3nn import o3
from e3nn.nn import SO3Activation

from rem import e3nn_utils
from rem.equiv_gnn import GNN
from rem.equiv_gnn_w_attrs import AttrGNN


# TODO try getting rid of nonlinearities and see what happens

class Decoder(nn.Module):
  def __init__(self, lmax_in, lmax_out, f_in, f_out, invariant_out=False):
    super().__init__()
    self.invariant_out = invariant_out

    grid_s2 = e3nn_utils.s2_near_identity_grid()
    grid_so3 = e3nn_utils.so3_near_identity_grid()

    self.so3_conv1 = e3nn_utils.SO3Convolution(
      f_in, 64, lmax_in, kernel_grid=grid_so3
    )
    self.act1 = SO3Activation(lmax_in, lmax_out, torch.relu, resolution=12)

    self.so3_conv2 = e3nn_utils.SO3Convolution(
      64, 128, lmax_in, kernel_grid=grid_so3
    )
    self.act2 = SO3Activation(lmax_in, lmax_out, torch.relu, resolution=12)

    self.so3_conv3 = e3nn_utils.SO3Convolution(
      128, 256, lmax_in, kernel_grid=grid_so3
    )

    # Output: Maps to 53 (rho_0, rho_1, rho_2, rho_3, ...) -> 53 S2 signals
    if self.invariant_out:
      self.act3 = SO3Activation(lmax_in, 0, torch.relu, resolution=12)
      self.lin = o3.Linear(256, f_out)
    else:
      self.act3 = SO3Activation(lmax_in, lmax_out, torch.relu, resolution=12)
      self.lin = e3nn_utils.SO3ToS2Convolution(
        256, f_out, lmax_out, kernel_grid=grid_s2
      )

  def forward(self, x):
    x = self.so3_conv1(x)
    x = self.act1(x)

    x = self.so3_conv2(x)
    x = self.act2(x)

    x = self.so3_conv3(x)
    x = self.act3(x)

    x = self.lin(x)

    return x

class REM(nn.Module):
  def __init__(self, num_node_features, z_lmax, max_radius, out_dim, invariant_out=False):
    super().__init__()

#    z_lmax = 4
    self.lmax = z_lmax
    self.out_dim = out_dim
    self.invariant_out = invariant_out
    f = 16

    self.irreps_in = o3.Irreps(f"{num_node_features}x0e")
    self.irreps_latent = e3nn_utils.so3_irreps(z_lmax)
    self.irreps_enc_out = o3.Irreps(
      #[(f, (l, p)) for l in range((z_lmax // 2) + 1) for p in [-1,1]]
      [(f, (l, p)) for l in range((z_lmax) + 1) for p in [-1,1]]
    )
    if self.invariant_out:
      self.irreps_node_attr = o3.Irreps("1x1e")
      self.encoder = AttrGNN(
        irreps_node_input=self.irreps_in,
        irreps_node_attr=self.irreps_node_attr,
        irreps_node_output=self.irreps_enc_out,
        max_radius=max_radius,
        layers=2,
        mul=f,
        lmax=[self.lmax, self.lmax, self.lmax],
      )
    else:
      self.encoder = GNN(
        irreps_node_input=self.irreps_in,
        irreps_node_output=self.irreps_enc_out,
        max_radius=max_radius,
        layers=2,
        mul=f,
        #lmax=[self.lmax // 2, self.lmax // 2, self.lmax // 2],
        lmax=[self.lmax, self.lmax, self.lmax],
      )

    # TODO figure out what this linear layer actually is
    # remove nonlinearities (could be an error) then VN could help
    # equivariance error for encoder and decoder (on a layer by layer basis)
    # overfit to a spherical signal in the decoder
    # latent space
    # TODO develop a baseline mesh to radar model and see what the error is
    # resolution?
    self.lin = o3.Linear(self.irreps_enc_out, self.irreps_latent, f_in=1, f_out=f)
    self.decoder = Decoder(z_lmax, z_lmax, f, out_dim, invariant_out=invariant_out)

  def forward(self, x, return_latent=False):
    batch_size = x.batch.max() + 1
    gnn_out = self.encoder(x)
    z = self.lin(gnn_out.view(batch_size, 1, -1))
    out = self.decoder(z)
    cartesian = self.ar2los(x.orientation)
    out_response = self._getResponse(out, cartesian)

    if return_latent:
      return (out_response, out)
    else:
      return out_response

  def _getResponse(self, out, pose):
    if self.invariant_out:
      return out
    else:
      sh = torch.concatenate(
        [o3.spherical_harmonics(l, pose, True) for l in range(self.lmax + 1)], dim=1
      ).unsqueeze(2)  # B x (L^2 * S^2) x 1
      response = torch.bmm(out, sh).squeeze()  # B x D

      return response
    
  def ar2los(self, x_ar):
    """Convert a unit spherical coordinate to cartesian.
    Parameters
    ----------
    x_ar: Tensor, shape-(N, ..., [2, 4, 6])
        Aspect/Roll coordinates
    Returns
    -------
    x_los: Tensor, shape-(N, ..., [3, 6, 9])
        Cartesian coordinates
    """
    assert x_ar.shape[-1] % 2 == 0
    assert x_ar.shape[-1] <= 6

    # Line-of-sight in XYZ
    a = x_ar[..., 0]
    r = x_ar[..., 1]

    x = -tr.sin(a) * tr.cos(r)
    y = -tr.sin(a) * tr.sin(r)
    z = -tr.cos(a)

    if x_ar.shape[-1] == 2:
        return tr.stack([x, y, z], dim=-1)

    # First time derivative
    da_dt = x_ar[..., 2]
    dr_dt = x_ar[..., 3]

    # Non-zero partial derivatives
    dxlos_da = -tr.cos(a) * tr.cos(r)
    dxlos_dr = tr.sin(a) * tr.sin(r)
    dylos_da = -tr.cos(a) * tr.sin(r)
    dylos_dr = -tr.sin(a) * tr.cos(r)
    dzlos_da = tr.sin(a)

    # Time derivative of line-of-sight
    xd = dxlos_da * da_dt + dxlos_dr * dr_dt
    yd = dylos_da * da_dt + dylos_dr * dr_dt
    zd = dzlos_da * da_dt

    if x_ar.shape[-1] == 4:
        return tr.stack([x, y, z, xd, yd, zd], dim=-1)

    da_dtdt = x_ar[..., 4]
    dr_dtdt = x_ar[..., 5]

    # Second partial derivatives
    dxlos_dada = tr.sin(a) * tr.cos(r)
    dxlos_dadr = tr.cos(a) * tr.sin(r)
    dxlos_drda = tr.cos(a) * tr.sin(r)
    dxlos_drdr = tr.sin(a) * tr.cos(r)
    dylos_dada = tr.sin(a) * tr.sin(r)
    dylos_dadr = -tr.cos(a) * tr.cos(r)
    dylos_drda = -tr.cos(a) * tr.cos(r)
    dylos_drdr = tr.sin(a) * tr.sin(r)
    dzlos_dada = tr.cos(a)

    # Second time derivative of line-of-sight
    xdd = (
        (dxlos_dada * da_dt + dxlos_dadr * dr_dt) * da_dt
        + dxlos_da * da_dtdt
        + (dxlos_drda * da_dt + dxlos_drdr * dr_dt) * dr_dt
        + dxlos_dr * dr_dtdt
    )
    ydd = (
        (dylos_dada * da_dt + dylos_dadr * dr_dt) * da_dt
        + dylos_da * da_dtdt
        + (dylos_drda * da_dt + dylos_drdr * dr_dt) * dr_dt
        + dylos_dr * dr_dtdt
    )
    zdd = (dzlos_dada * da_dt) * da_dt + dzlos_da * da_dtdt

    return tr.stack([x, y, z, xd, yd, zd, xdd, ydd, zdd], dim=-1)

rem = REM(num_node_features=5, z_lmax=4, max_radius=1.8, out_dim=1)



## Test equivariance

In [13]:
samp, y = next(dl)
samp_copy = samp.clone()
print(rem(samp))

tensor(0.0114, grad_fn=<SqueezeBackward0>)


## Debug

In [52]:
def compute_max_radius(data) -> float:
    """
    Computes the maximum radius (maximum pairwise Euclidean distance) 
    for a torch_geometric Data object.

    Parameters:
        data (Data): A PyTorch Geometric Data object with node positions in `data.pos`.

    Returns:
        float: The maximum radius.
    """
    # Ensure the data object has 'pos' attribute
    if not hasattr(data, 'pos') or data.pos is None:
        raise ValueError("The Data object must have a 'pos' attribute for node positions.")

    # Compute pairwise distances
    pairwise_distances = torch.cdist(data.pos, data.pos, p=2)

    # Get the maximum distance
    max_radius = pairwise_distances.max().item()
    
    return max_radius

print(compute_max_radius(samp_copy))

1.7320507764816284


In [None]:
from e3nn.o3 import rand_matrix
samp, y = next(dl)
samp_copy = samp.clone()

def spherical_to_cartesian(alpha, beta):
    """
    Convert spherical angles (alpha, beta) to cartesian coordinates.
    alpha: azimuthal angle, beta: polar angle
    """
    x = torch.cos(beta) * torch.cos(alpha)
    y = torch.cos(beta) * torch.sin(alpha)
    z = torch.sin(beta)
    return torch.stack([x, y, z], dim=-1)


def rotate_orientation(orientation, R):
    """
    Rotate 2D orientation (alpha, beta) values using a rotation matrix.
    Orientation is converted to Cartesian, rotated, and then converted back.
    """
    # Convert spherical angles (alpha, beta) to Cartesian coordinates
    cartesian = spherical_to_cartesian(orientation[:, 0], orientation[:, 1])

    # Apply rotation
    rotated_cartesian = torch.einsum('ij,nj->ni', R, cartesian)

    # Convert back to spherical coordinates
    rho = rotated_cartesian.norm(dim=-1)
    beta = torch.asin(rotated_cartesian[:, 2] / rho)
    alpha = torch.atan2(rotated_cartesian[:, 1], rotated_cartesian[:, 0])
    return torch.stack([alpha, beta], dim=-1)

In [None]:

def test_encoder_equivariance(rem, data):
    """
    Test the equivariance of rem.encoder with respect to SO(3) rotations.

    Parameters:
        rem: REM
            The REM model with the encoder to test.
        data: torch_geometric.data.Data
            Input graph data with `data.orientation` as (alpha, beta).

    Returns:
        bool: True if invariant, False otherwise.
    """
    # Generate a random SO(3) rotation matrix
    # R = rand_matrix()
    R = torch.tensor([
        [1, 0, 0],
        [0, 0, -1],
        [0, 1, 0]
    ], dtype=torch.float32)

    # Rotate orientation
    original_orientation = data.orientation.clone()
    print("Original orientation", original_orientation)
    rotated_orientation = rotate_orientation(original_orientation, R)
    print("New orientation", rotated_orientation)

    R_random = o3.rand_matrix()
    rotated_orientation_random = rotate_orientation(original_orientation, R_random)
    print("New orientation random", rotated_orientation_random)

    # Clone data and apply rotated orientation
    data_rotated = data.clone()
    data_rotated.orientation = rotated_orientation

    data_rotated_random = data.clone()
    data_rotated_random.orientation = rotated_orientation_random

    # Forward pass on original and rotated data
    rem.eval()  # Ensure evaluation mode
    with torch.no_grad():
        output_original = rem(data)
        print("Output original", output_original)
        output_rotated = rem(data_rotated)
        print("Output rotated", output_rotated)
        output_rotated_random = rem(data_rotated_random)
        print("Output rotated random", output_rotated_random)

    # Check invariance: the outputs should be identical
    is_invariant = torch.allclose(output_original, output_rotated, atol=1e-6)

    print(f"Equivariance Test: {'Passed' if is_invariant else 'Failed'}")
    return is_invariant

test_encoder_equivariance(rem, samp_copy)

Original orientation tensor([[2.4197, 0.7810]])
New orientation tensor([[-2.2189,  0.4885]])
New orientation random tensor([[-1.8800,  0.1439]])
Output original tensor(0.0013)
Output rotated tensor(0.0010)
Output rotated random tensor(0.0005)
Equivariance Test: Failed


False

## TODO try what chatgpt says

Try all close on each layer's latent tensor