In [None]:
# Install libraries
!pip install --pre dgl-cu101

# Import libraries
import torch
from torch import nn
import torch.nn.functional as F

import dgl
import dgl.function as fn
from dgl import DGLGraph

import numpy as np
import math

!pip install git+https://github.com/AMLab-Amsterdam/lie_learn
import lie_learn

from lie_learn.representations.SO3.wigner_d import wigner_D_matrix

In [None]:
def get_spherical_from_cartesian_torch(cartesian, divide_radius_by=1.0):
    """Convert Cartesian vectors to spherical coordinates (radius, alpha, beta).

    The last dimension of ``cartesian`` is read with a permuted axis order:
    x = cartesian[..., 2], y = cartesian[..., 0], z = cartesian[..., 1].

    Angles as computed here:
      * alpha = atan2(y, x)                 -- azimuth, in (-pi, pi]
      * beta  = atan2(sqrt(x^2 + y^2), z)   -- 0 for vectors along +z, pi along -z
    ``precompute_sh`` converts beta to the colatitude expected by
    ``SphericalHarmonics`` via theta = pi - beta.

    Args:
        cartesian: tensor whose last dimension has size 3.
        divide_radius_by: the radius is divided by this value unless it
            equals 1.0.

    Returns:
        Tensor shaped like ``cartesian`` whose last dimension holds
        (radius, alpha, beta).
    """
    # Axis permutation inherited from the 3D steerable CNN code.
    x = cartesian[..., 2]
    y = cartesian[..., 0]
    z = cartesian[..., 1]

    xy_sq = x ** 2 + y ** 2

    radius = torch.sqrt(xy_sq + z ** 2)
    if divide_radius_by != 1.0:
        radius = radius / divide_radius_by

    spherical = torch.zeros_like(cartesian)
    spherical[..., 0] = radius
    spherical[..., 1] = torch.atan2(y, x)
    spherical[..., 2] = torch.atan2(torch.sqrt(xy_sq), z)
    return spherical

def pochhammer(x, k):
    """Compute the rising factorial (Pochhammer symbol) (x)_k.

    (x)_k = x * (x+1) * (x+2) * ... * (x+k-1), and (x)_0 = 1 (empty product).

    Args:
        x: positive int, first factor.
        k: non-negative int, number of factors.
    Returns:
        float for (x)_k
    """
    result = 1.0
    # Exactly k consecutive factors starting at x; for k == 0 the loop is
    # empty and the empty product 1.0 is returned (previously this returned x).
    for n in range(x, x + k):
        result *= n
    return result
    
def semifactorial(x):
    """Compute the double factorial x!! = x * (x-2) * (x-4) * ...

    Factors are taken while they are greater than 1, so 0!!, 1!! and
    (-1)!! all evaluate to 1.0.

    Args:
        x: int
    Returns:
        float for x!!
    """
    product = 1.
    factor = x
    while factor > 1:
        product *= factor
        factor -= 2
    return product

class SphericalHarmonics(object):
    """Evaluates real (tesseral) spherical harmonics via associated Legendre functions.

    Legendre values are memoised in ``self.leg`` keyed by (l, m) only -- the
    argument ``x`` is NOT part of the key. Cached values are therefore only
    valid while every call uses the same angles; call ``clear()`` (or
    ``get(..., refresh=True)``) before evaluating at different angles.
    """
    def __init__(self):
        # (l, m) -> tensor P_l^m(x) for the x used since the last clear().
        self.leg = {}

    def clear(self):
        """Drop all memoised Legendre values."""
        self.leg = {}

    def negative_lpmv(self, l, m, y):
        """Turn P_l^{|m|} into P_l^m for negative order m.

        For m < 0, multiplies ``y`` in place by
        (-1)^m / pochhammer(l+m+1, -2m) = (-1)^m * (l-|m|)! / (l+|m|)!.
        For m >= 0, ``y`` is returned unchanged.
        """
        if m < 0:
            y *= ((-1)**m / pochhammer(l+m+1, -2*m))
        return y

    def lpmv(self, l, m, x):
        """Associated Legendre function P_l^m(x) including Condon-Shortley phase.

        Results are memoised in ``self.leg`` (keyed by (l, m), not by x).
        Args:
            m: int order
            l: int degree
            x: float argument tensor
        Returns:
            tensor of x-shape, or None when |m| > l
        """
        # Check memoized versions
        m_abs = abs(m)
        if (l,m) in self.leg:
            return self.leg[(l,m)]
        elif m_abs > l:
            return None
        elif l == 0:
            self.leg[(l,m)] = torch.ones_like(x)
            return self.leg[(l,m)]
        
        # Check if on boundary else recurse solution down to boundary
        if m_abs == l:
            # Compute P_m^m = (-1)^m (2m-1)!! (1 - x^2)^(m/2)
            y = (-1)**m_abs * semifactorial(2*m_abs-1)
            y *= torch.pow(1-x*x, m_abs/2)
            self.leg[(l,m)] = self.negative_lpmv(l, m, y)
            return self.leg[(l,m)]
        else:
            # Recursively precompute lower degree harmonics
            self.lpmv(l-1, m, x)

        # Compute P_{l}^m from recursion in P_{l-1}^m and P_{l-2}^m:
        #   (l-|m|) P_l^|m| = (2l-1) x P_{l-1}^|m| - (l+|m|-1) P_{l-2}^|m|
        # The recursion runs on |m|; negative orders are converted at the end.
        # y is a freshly allocated tensor, so the in-place update below does
        # not modify any cached entry.
        y = ((2*l-1) / (l-m_abs)) * x * self.lpmv(l-1, m_abs, x)
        if l - m_abs > 1:
            # The lpmv(l-1, |m|) call above has populated (l-2, |m|).
            y -= ((l+m_abs-1)/(l-m_abs)) * self.leg[(l-2, m_abs)]
        # (The positive-order result is intentionally not cached under
        # (l, m_abs) here; only (l, m) is stored below.)
        
        if m < 0:
            y = self.negative_lpmv(l, m, y)
        self.leg[(l,m)] = y

        return self.leg[(l,m)]

    def get_element(self, l, m, theta, phi):
        """Tesseral spherical harmonic with Condon-Shortley phase.
        The Tesseral spherical harmonics are also known as the real spherical
        harmonics. Orders m > 0 use cos(m*phi), m < 0 use sin(|m|*phi), and
        m = 0 has no azimuthal factor.
        Args:
            l: int for degree
            m: int for order, where -l <= m <= l
            theta: collatitude or polar angle
            phi: longitude or azimuth
        Returns:
            tensor of shape theta
        """
        assert abs(m) <= l, "absolute value of order m must be <= degree l"

        # Normalisation sqrt((2l+1)/(4 pi)); non-zero orders get the extra
        # factor sqrt(2 (l-|m|)! / (l+|m|)!) below.
        N = np.sqrt((2*l+1) / (4*np.pi))
        # Always evaluated at |m|, so lpmv's negative-order branch is not
        # exercised from here.
        leg = self.lpmv(l, abs(m), torch.cos(theta))
        if m == 0:
            return N*leg
        elif m > 0:
            Y = torch.cos(m*phi) * leg
        else:
            Y = torch.sin(abs(m)*phi) * leg
        # pochhammer(l-|m|+1, 2|m|) == (l+|m|)! / (l-|m|)!
        N *= np.sqrt(2. / pochhammer(l-abs(m)+1, 2*abs(m)))
        Y *= N
        return Y

    def get(self, l, theta, phi, refresh=True):
        """Tesseral harmonic with Condon-Shortley phase.
        The Tesseral spherical harmonics are also known as the real spherical
        harmonics.
        Args:
            l: int for degree
            theta: collatitude or polar angle
            phi: longitude or azimuth
            refresh: if True, clear the Legendre cache first. Pass False only
                when theta is the same as in the previous calls.
        Returns:
            tensor of shape [*theta.shape, 2*l+1], ordered m = -l..l
        """
        results = []
        if refresh:
            self.clear()
        for m in range(-l, l+1):
            results.append(self.get_element(l, m, theta, phi))
        return torch.stack(results, -1)

def precompute_sh(r_ij, max_J):
    """
    Precompute real spherical harmonics of every degree 0..max_J.

    :param r_ij: spherical coordinates as produced by
        get_spherical_from_cartesian_torch; last dimension holds
        (radius, alpha, beta)
    :param max_J: highest degree to evaluate (inclusive)
    :return: dict mapping J -> tensor of shape [*r_ij.shape[:-1], 2J+1]
    """
    # SphericalHarmonics takes the colatitude theta; convert from beta.
    theta = math.pi - r_ij[..., 2]
    phi = r_ij[..., 1]

    harmonics = SphericalHarmonics()
    # refresh=False lets degree J reuse Legendre values memoised while the
    # lower degrees were computed (the angles are the same for every J).
    result = {J: harmonics.get(J, theta=theta, phi=phi, refresh=False)
              for J in range(max_J + 1)}
    harmonics.clear()
    return result

def irr_repr(order, alpha, beta, gamma, dtype=None):
    """
    Wigner D-matrix of the given order: a (2*order+1) x (2*order+1)
    irreducible representation of SO(3).
    - compatible with compose and spherical_harmonics

    Delegates to lie_learn's wigner_D_matrix. The Euler angles may be Python
    floats or 0-d tensors; they are converted with np.array before the call.

    :param dtype: torch dtype of the result; torch's current default dtype
        when None.
    """
    from lie_learn.representations.SO3.wigner_d import wigner_D_matrix

    # TODO (non-essential): try to do everything in torch
    matrix = wigner_D_matrix(order, np.array(alpha), np.array(beta), np.array(gamma))
    target_dtype = torch.get_default_dtype() if dtype is None else dtype
    return torch.tensor(matrix, dtype=target_dtype)


def kron(a, b):
    """
    Kronecker product over the last two (matrix) dimensions, batched.

    Adapted from pylabyk's numpytorch.py (https://github.com/yulkang/pylabyk).
    Leading batch dimensions of a and b are broadcast against each other.

    :type a: torch.Tensor  -- shape [..., m, n]
    :type b: torch.Tensor  -- shape [..., p, q]
    :rtype: torch.Tensor   -- shape [..., m*p, n*q]
    """
    rows = a.shape[-2] * b.shape[-2]
    cols = a.shape[-1] * b.shape[-1]
    # outer[..., i, r, j, s] = a[..., i, j] * b[..., r, s]
    outer = a[..., :, None, :, None] * b[..., None, :, None, :]
    return outer.reshape(*outer.shape[:-4], rows, cols)

def get_matrix_kernel(A, eps=1e-10):
    '''
    Compute an orthonormal basis of the kernel (null space) of A, i.e.
    vectors (x_1, x_2, ...) with
    A x_i = 0
    scalar_product(x_i, x_j) = delta_ij
    :param A: 2-D matrix of shape [m, n]
    :param eps: singular values below this absolute threshold count as zero
    :return: matrix of shape [k, n] where each row is a basis vector of the kernel of A
    '''
    m, n = A.shape
    # torch.svd returns V (not V^T): the right singular vectors are the rows
    # of v.t(). For a wide matrix (m < n) the reduced SVD only returns
    # min(m, n) of them, silently dropping part of the kernel, so request the
    # full V there. Tall/square inputs keep the cheaper reduced SVD (V is
    # already n x n in that case).
    _u, s, v = torch.svd(A, some=(m >= n))

    # Right singular vectors beyond the first min(m, n) have an implicit
    # singular value of zero; pad s so it lines up with the rows of v.t().
    s_full = torch.zeros(n, dtype=s.dtype, device=s.device)
    s_full[:s.shape[0]] = s

    kernel = v.t()[s_full < eps]
    return kernel


def get_matrices_kernel(As, eps=1e-10):
    '''
    Compute the common kernel of all the matrices in As.

    The matrices (which must share the same number of columns) are stacked
    vertically, so a vector is in the result's span exactly when every A in
    As maps it to zero.
    '''
    stacked = torch.cat(As, dim=0)
    return get_matrix_kernel(stacked, eps)
    
class torch_default_dtype:
    """Context manager that temporarily changes torch's default dtype.

    The default active on entry is remembered and restored on exit, including
    when the body raises (exceptions are not suppressed). ``__enter__``
    returns None.
    """

    def __init__(self, dtype):
        self.dtype = dtype
        # Captured by __enter__; None until the context is entered.
        self.saved_dtype = None

    def __enter__(self):
        previous = torch.get_default_dtype()
        self.saved_dtype = previous
        torch.set_default_dtype(self.dtype)

    def __exit__(self, exc_type, exc_value, traceback):
        torch.set_default_dtype(self.saved_dtype)

def _basis_transformation_Q_J(J, order_in, order_out, version=3):  # pylint: disable=W0613
    """
    Solve for the change-of-basis block Q_J that intertwines the tensor-product
    representation D^{order_out} (x) D^{order_in} with the irreducible
    representation D^J, i.e. a Q_J satisfying
        kron(D^{order_out}(g), D^{order_in}(g)) @ Q_J == Q_J @ D^J(g)
    (this is exactly what the final assert checks).

    Q_J is taken as the common null space of the Sylvester-equation matrices
    built for five fixed Euler-angle triples. That null space is asserted to be
    one-dimensional, and the result is re-checked on four random angle triples.

    :param J: order of the spherical harmonics
    :param order_in: order of the input representation
    :param order_out: order of the output representation
    :param version: unused; kept for signature compatibility
    :return: one part of the Q^-1 matrix of the article, a float64 tensor of
        shape [(2*order_out+1)*(2*order_in+1), 2*J+1]
    """
    # Every tensor created inside this block (irr_repr results, torch.eye,
    # torch.rand) defaults to float64, so the 1e-10 singular-value threshold
    # used by get_matrices_kernel is meaningful.
    with torch_default_dtype(torch.float64):
        def _R_tensor(a, b, c): return kron(irr_repr(order_out, a, b, c), irr_repr(order_in, a, b, c))

        def _sylvester_submatrix(J, a, b, c):
            ''' generate Kronecker product matrix for solving the Sylvester equation in subspace J

            For a row-major flattening vec(Q), (R (x) I - I (x) D_J^T) vec(Q)
            equals vec(R Q - Q D_J), so its null space holds the solutions of
            R Q = Q D_J for the rotation (a, b, c).
            '''
            R_tensor = _R_tensor(a, b, c)  # [m_out * m_in, m_out * m_in]
            R_irrep_J = irr_repr(J, a, b, c)  # [m, m]
            return kron(R_tensor, torch.eye(R_irrep_J.size(0))) - \
                kron(torch.eye(R_tensor.size(0)), R_irrep_J.t())  # [(m_out * m_in) * m, (m_out * m_in) * m]

        # Fixed (hence deterministic) Euler angles used to pin down Q_J.
        random_angles = [
            [4.41301023, 5.56684102, 4.59384642],
            [4.93325116, 6.12697327, 4.14574096],
            [0.53878964, 4.09050444, 5.36539036],
            [2.16017393, 3.48835314, 5.55174441],
            [2.52385107, 0.2908958, 3.90040975]
        ]
        null_space = get_matrices_kernel([_sylvester_submatrix(J, a, b, c) for a, b, c in random_angles])
        assert null_space.size(0) == 1, null_space.size()  # unique subspace solution
        # NOTE: this is a unit-norm SVD singular vector, so its overall sign is
        # whatever the SVD returns.
        Q_J = null_space[0]  # [(m_out * m_in) * m]
        Q_J = Q_J.view((2 * order_out + 1) * (2 * order_in + 1), 2 * J + 1)  # [m_out * m_in, m]
        # Sanity check on 4 random angle triples drawn from [0, 1). Like every
        # assert here, this is skipped when Python runs with -O.
        assert all(torch.allclose(_R_tensor(a, b, c) @ Q_J, Q_J @ irr_repr(J, a, b, c)) for a, b, c in torch.rand(4, 3))

    assert Q_J.dtype == torch.float64
    return Q_J  # [m_out * m_in, m]

In [None]:
class PointCloud:
  '''Represents a point cloud in R^3 as a DGL graph.

  Every pair of distinct points closer than `cutoff` is joined by an edge in
  both directions. On construction the class computes and stores per-edge
  geometry (relative vectors, spherical coordinates, distances) and the real
  spherical harmonics of degrees 0..J_max. Equivariant basis kernels are
  computed lazily by `get_w_j` and cached per (l, k).

  Args:
    pos: tensor of shape [num_points, 3].
    cutoff: pairs with distance strictly below this value become edges.
    J_max: highest spherical-harmonic degree precomputed eagerly.
  '''
  def __init__(self, pos, cutoff=8.0, J_max=4):
    edges = self._find_edges_(pos, cutoff)
    # Pass num_nodes explicitly: otherwise DGL infers the node count from the
    # largest edge endpoint, so trailing isolated points (no neighbour within
    # the cutoff) would be dropped and the 'pos' assignment below would fail.
    self.graph = dgl.graph(edges, num_nodes=pos.shape[0])
    self.graph.ndata['pos'] = pos
    self._calc_edge_info_(J_max)
    # Cache of basis kernels keyed by (l, k); filled lazily by get_w_j.
    self.w_j = dict()

  def _find_edges_(self, pos, cutoff):
    '''Return (src, dst) int64 index tensors for all pairs closer than `cutoff`.

    The first half of the edge list holds the pairs (i, j) with i < j ordered
    by j then i; the second half holds the same pairs reversed. Memory is still
    O(n^2) for the distance matrix, but there is no Python-level double loop.
    '''
    vec_mat = pos[:, None, :] - pos[None, :, :]
    dist_mat = torch.sqrt(torch.sum(vec_mat ** 2, dim=-1))
    idx = torch.arange(pos.shape[0], device=pos.device)
    upper = idx[:, None] < idx[None, :]
    mask = (dist_mat < cutoff) & upper
    # nonzero() enumerates in row-major order, so scanning the transpose gives
    # the pairs sorted by j, then i -- the same order as a nested j/i loop.
    j_idx, i_idx = torch.nonzero(mask.t(), as_tuple=True)
    u = torch.cat([i_idx, j_idx])
    v = torch.cat([j_idx, i_idx])
    return (u, v)

  def _calc_edge_info_(self, J_max):
    '''Store per-edge geometry in graph.edata and precompute harmonics.

    Sets edata['vec'] = pos[src] - pos[dst], edata['r_ij'] (spherical
    coordinates of 'vec'), edata['dist'] (Euclidean length of 'vec'), and
    self.Y, a dict J -> spherical harmonics of degree J for J = 0..J_max.
    '''
    u, v = self.graph.edges()
    pos = self.graph.ndata['pos']
    vec = pos[u]-pos[v]
    self.graph.edata['vec'] = vec
    r_ij = get_spherical_from_cartesian_torch(vec)
    self.graph.edata['r_ij'] = r_ij
    self.Y = precompute_sh(r_ij, J_max)
    self.graph.edata['dist'] = torch.norm(vec, dim=1)

  def get_sh(self, J):
    '''Return the spherical harmonics of degree J for every edge.

    Degrees above the eagerly computed J_max are evaluated on demand and
    merged into the cache without overwriting existing entries.
    '''
    if J not in self.Y:
      Y_new = precompute_sh(self.graph.edata['r_ij'], J)
      for key, value in Y_new.items():
        self.Y.setdefault(key, value)
    return self.Y[J]

  def get_w_j(self, l, k):
    '''Return the basis kernels mapping degree-k to degree-l features.

    Shape [num_edges, J_size, 2l+1, 2k+1] with J_size = 2*min(k, l) + 1; slice j
    corresponds to J = |k-l| + j. Results are cached per (l, k).
    '''
    if (l,k) not in self.w_j:
      w_j = torch.zeros(k+l-abs(k-l)+1, self.graph.number_of_edges(), 2*l+1, 2*k+1)
      for j, J in enumerate(range(abs(k-l), k+l+1)):
        Y_J = self.get_sh(J)
        Q_J = _basis_transformation_Q_J(J, k, l).float()
        w_j[j] = torch.matmul(Y_J,Q_J.T).reshape(self.graph.number_of_edges(), 2*l+1, 2*k+1)
      self.w_j[(l,k)] = w_j.transpose(0,1)
    return self.w_j[(l,k)]

In [None]:
# TODO: Self interaction
# TODO: Ability to change head size

# Dictionary for indices
# e: edges
# o: c_out
# i: c_in
# l: output tensor representation
# k: input tensor representation
# j: hidden tensor representation

class WLayer(nn.Module):
  '''Learned equivariant kernel mapping degree-k features to degree-l features.

  On each edge the kernel is sum_J R_J(dist) * W_J, where W_J are the fixed
  basis kernels from `pc.get_w_j(l, k)` and the radial weights R (one per
  J, c_out, c_in) come from a small MLP of the edge distance.

  Args:
    k: degree of the input features.
    l: degree of the output features.
    c_in: number of input channels.
    c_out: number of output channels.
  '''
  def __init__(self, k, l, c_in=1, c_out=1):
    super(WLayer, self).__init__()
    self.k = k
    self.l = l
    self.c_in = c_in
    self.c_out = c_out

    # Number of J values in |k-l|..k+l, i.e. 2*min(k, l) + 1.
    J_size = k+l-abs(k-l)+1
    self.J_size = J_size
    r_size = J_size * c_out * c_in

    # BatchNorm1d's second positional argument is `eps`; the former
    # `nn.BatchNorm1d(32, 32)` silently used eps=32. Use the default eps.
    self.radial = nn.Sequential(nn.Linear(1,32),
                                nn.BatchNorm1d(32),
                                nn.ReLU(),
                                nn.Linear(32,32),
                                nn.BatchNorm1d(32),
                                nn.ReLU(),
                                nn.Linear(32,r_size))

  def forward(self, pc):
    '''Return the per-edge kernel of shape [E, c_out, c_in, 2l+1, 2k+1].'''
    l, k = self.l, self.k
    w_j = pc.get_w_j(l, k)  # expected [E, J_size, 2l+1, 2k+1]
    dist = pc.graph.edata['dist'][:,None]  # [E, 1]
    size = (pc.graph.number_of_edges(), self.J_size, self.c_out, self.c_in)
    R = self.radial(dist).view(*size)
    w = torch.einsum('ejoi,ejlk->eoilk', R, w_j)
    return w

In [None]:
class AttnBlock(nn.Module):
  '''Computes per-edge attention weights for an equivariant graph layer.

  For an edge src -> dst, the attention logit is the inner product of a
  query built from the destination node's features and a key built by
  applying learned equivariant kernels (`WLayer`) to the source node's
  features. `forward` returns exp(logit) divided by the sum of exp(logit)
  over all edges entering the same destination, i.e. a softmax over each
  node's incoming edges.

  Args:
    d_in: highest input degree; f must have keys 0..d_in.
    d_out: highest output degree; queries/keys use degrees 0..min(d_in, d_out).
    c_in, c_out: input and output channel counts.
  '''
  def __init__(self, d_in, d_out, c_in=1, c_out=1):
    super(AttnBlock, self).__init__()
    self.d_in = d_in
    self.d_out = d_out

    self.c_in = c_in
    self.c_out = c_out

    # An nn.Parameter (not a bare tensor with requires_grad=True) so the query
    # weights are registered: returned by .parameters(), saved in state_dict
    # and moved by .to().
    self.wq = nn.Parameter(torch.randn(d_in+1, c_out, c_in))

    # wk_layers[l][k]: key kernel from input degree k to key degree l.
    self.wk_layers = nn.ModuleList([
                      nn.ModuleList([WLayer(k, l, c_in, c_out)
                        for k in range(d_in+1)])
                      for l in range(d_out+1)])

  def forward(self, pc, f):
    '''Return the attention weight of every edge (a tensor of shape [E]).

    Stores f's entries and 'q' in pc.graph.ndata and the key kernels under
    (k, l) in pc.graph.edata; update_all adds ndata['sum'].
    '''
    for key, value in f.items():
      pc.graph.ndata[key] = value
    pc.graph.ndata['q'] = self.calc_q(pc, f)

    for l in range(self.d_out+1):
      for k in range(self.d_in+1):
        pc.graph.edata[(k,l)] = self.wk_layers[l][k](pc)

    pc.graph.update_all(self.attn_msg, self.attn_rdc)
    dst = pc.graph.edges()[1]
    a = pc.graph.edata['exp'] / pc.graph.ndata['sum'][dst]
    return a

  def attn_msg(self, edges):
    '''Message function: exp(<q_dst, k_edge>) summed over channels and components.'''
    k = self.calc_k(edges)
    q = edges.dst['q']
    # NOTE(review): no max-subtraction before exp, so large logits can
    # overflow to inf; consider dgl's edge_softmax if this becomes an issue.
    exp = torch.exp(torch.einsum('eol,eol->e',q,k))
    edges.data['exp'] = exp
    return {'exp': exp}

  def attn_rdc(self, nodes):
    '''Reduce function: sum of exp over each node's incoming edges (sum over j').'''
    exp = nodes.mailbox['exp']
    total = torch.sum(exp, dim=1)
    return {'sum': total}

  def calc_q(self, pc, f):
    '''Query per node: channel-mix f[k] with wq[k], concatenated over degrees.'''
    ql = []
    for k in range(min(self.d_in,self.d_out)+1):
      q_k = torch.einsum('oi,nik->nok',self.wq[k],f[k])
      ql.append(q_k)
    q = torch.cat(ql, dim=2)
    return q

  def calc_k(self, edges):
    '''Key per edge, concatenated over degrees 0..min(d_in, d_out).'''
    kl = []
    for l in range(min(self.d_in,self.d_out)+1):
      wks = []
      for k in range(self.d_in+1):
        # Keys come from the neighbour (source) node's features; the query is
        # taken at the destination. Using edges.dst here would make every key
        # into a node depend only on that node's own features.
        wk = torch.einsum('eoilk,eik->eol',
                          edges.data[(k,l)],
                          edges.src[k])
        wks.append(wk)
      stack = torch.stack(wks, dim=3)
      kl.append(torch.sum(stack, dim=3))
    k = torch.cat(kl, dim=2)
    return k

class TransLayer(nn.Module):
  '''Equivariant attention layer mapping degrees 0..d_in to degrees 0..d_out.

  For each edge src -> dst and output degree l, the value message is the sum
  over input degrees k of the value kernel `wv_layers[k][l]` applied to the
  source node's degree-k features, scaled by the edge's attention weight from
  `AttnBlock`. Messages are summed at the destination. For degrees
  l <= min(d_in, d_out) a self-interaction term is added: an MLP maps the
  channel-by-channel inner products of f[l] (invariant under orthogonal
  transforms of the 2l+1 components) to a per-node [c_out, c_in] mixing matrix.

  Args:
    d_in: highest input degree; f must have keys 0..d_in.
    d_out: highest output degree.
    c_in, c_out: input and output channel counts.
  '''
  def __init__(self, d_in, d_out, c_in=1, c_out=1):
    super(TransLayer, self).__init__()
    self.d_in = d_in
    self.d_out = d_out
    self.c_in = c_in
    self.c_out = c_out

    # wv_layers[k][l]: value kernel from input degree k to output degree l.
    self.wv_layers = nn.ModuleList([
                        nn.ModuleList([
                          WLayer(k, l, c_in=c_in, c_out=c_out)
                          for l in range(d_out+1)])
                        for k in range(d_in+1)])

    self.block = AttnBlock(d_in, d_out, c_in=c_in, c_out=c_out)

    # nn.ModuleList (not a plain list) so the self-interaction MLPs are
    # registered submodules: their parameters are returned by .parameters(),
    # saved in state_dict, moved by .to(), and their BatchNorm layers follow
    # .train()/.eval().
    self.si_nets = nn.ModuleList([
                     nn.Sequential(nn.Linear(c_in**2, 2*c_in*c_out),
                                   nn.BatchNorm1d(2*c_in*c_out),
                                   nn.ReLU(),
                                   nn.Linear(2*c_in*c_out, 2*c_in*c_out),
                                   nn.BatchNorm1d(2*c_in*c_out),
                                   nn.ReLU(),
                                   nn.Linear(2*c_in*c_out, c_in*c_out))
                     for _ in range(min(d_in, d_out)+1)])

  def forward(self, pc, f):
    '''Apply the layer; returns a dict degree l (0..d_out) -> [N, c_out, 2l+1].'''
    pc.graph.edata['a'] = self.block(pc, f)
    for k in range(self.d_in+1):
      pc.graph.ndata[k] = f[k]
      for l in range(self.d_out+1):
        # Replaces the key kernels AttnBlock stored under the same (k, l) keys.
        pc.graph.edata[(k,l)] = self.wv_layers[k][l](pc)
    pc.graph.update_all(self.msg_func, self.rdc_func)
    si = self.calc_si(pc, f)
    out = dict()
    for l in range(self.d_out+1):
      data = pc.graph.ndata[l]
      if l <= self.d_in:
        data = data + si[l]
      out[l] = data
    return out

  def msg_func(self, edges):
    '''Message function: one value message per output degree l.'''
    vls = dict()
    for l in range(self.d_out+1):
      vls[l] = self.calc_vl(edges, l)
    return vls

  def rdc_func(self, nodes):
    '''Reduce function: sum every message type over the incoming edges.'''
    f = dict()
    for key, value in nodes.mailbox.items():
      f[key] = torch.sum(value, dim=1)
    return f

  def calc_vl(self, edges, l):
    '''Attention-weighted degree-l value message on each edge, summed over k.'''
    a = edges.data['a']
    vlks = []
    for k in range(self.d_in+1):
      wk = edges.data[(k,l)]
      f = edges.src[k]
      vlk = torch.einsum('e,eoilk,eik->eol', a, wk, f)
      vlks.append(vlk)
    vlk = torch.stack(vlks, dim=3)
    vl = torch.sum(vlk, dim=3)
    return vl

  def calc_si(self, pc, f):
    '''Self-interaction terms for degrees 0..min(d_in, d_out).'''
    si = dict()
    num_nodes = pc.graph.num_nodes()
    c_in = self.c_in
    c_out = self.c_out
    size_in = (num_nodes, c_in*c_in)
    size_out = (num_nodes, c_out, c_in)
    for l in range(min(self.d_in,self.d_out)+1):
      f_l = f[l]
      # [N, c_in, c_in] channel inner products, flattened for the MLP.
      # reshape (not view) in case the einsum output is not contiguous.
      inner = torch.einsum('ncl,ndl->ncd',f_l,f_l).reshape(*size_in)
      si_w = self.si_nets[l](inner).reshape(*size_out)
      si[l] = torch.einsum('noi,nil->nol', si_w, f_l)
    return si

class MultiHead(nn.Module):
  '''Runs `heads` independent TransLayers on disjoint channel groups.

  Each feature tensor is split into `heads` equal chunks along dim 1 (the
  channel dimension). Each chunk is processed by its own TransLayer with
  c_in//heads input and c_out//heads output channels, and the outputs are
  concatenated back along dim 1.

  Raises:
    ValueError: if c_in is not divisible by heads.
  '''
  def __init__(self, d_in, d_out, c_in=1, c_out=1, heads = 1):
    super(MultiHead, self).__init__()
    # Otherwise torch.chunk in split_f yields chunks whose width does not match
    # the per-head layers (or fewer than `heads` chunks), failing later with an
    # obscure shape or index error.
    if c_in % heads != 0:
      raise ValueError("c_in ({}) must be divisible by heads ({})".format(c_in, heads))
    self.d_in = d_in
    self.d_out = d_out
    self.c_in = c_in
    self.c_out = c_out
    self.heads = heads
    layers = [TransLayer(d_in, d_out, c_in//heads, c_out//heads)
              for _ in range(heads)]
    self.layers = nn.ModuleList(layers)

  def forward(self, pc, f):
    '''Apply every head to its channel chunk and merge the results.'''
    heads = self.heads
    fs = self.split_f(f, heads)
    outs = [self.layers[i](pc, fs[i]) for i in range(heads)]
    out = self.merge_fs(outs)
    return out

  @staticmethod
  def split_f(f, heads):
    '''Split each feature tensor into `heads` chunks along dim 1.

    Returns a list of `heads` dicts with the same keys as f.
    '''
    # One fresh dict per head: `[dict()] * heads` would alias a single dict,
    # so every head would receive the last chunk.
    fs = [dict() for _ in range(heads)]
    for l, f_l in f.items():
      chunks = torch.chunk(f_l, heads, dim=1)
      for i in range(heads):
        fs[i][l] = chunks[i]
    return fs

  @staticmethod
  def merge_fs(fs):
    '''Concatenate per-head feature dicts back together along dim 1.'''
    f = dict()
    for l in fs[0].keys():
      f_ls = [f_i[l] for f_i in fs]
      f[l] = torch.cat(f_ls, dim=1)
    return f

class GraphReLU(nn.Module):
  '''Norm-based nonlinearity for equivariant features.

  For each degree l, the per-channel norm over the 2l+1 components is passed
  through LayerNorm (over channels) and ReLU, and the result rescales the unit
  direction f_l / |f_l|. Only the norms are transformed, so the layer commutes
  with orthogonal transforms of the components.

  Args:
    d_in: highest feature degree; one LayerNorm per degree 0..d_in.
    channels: number of channels (dim 1 of each feature tensor).
    eps: lower bound on the norm used for the direction, so all-zero
      vectors map to zero instead of NaN.
  '''
  def __init__(self, d_in, channels=1, eps=1e-12):
    super(GraphReLU, self).__init__()
    self.d_in = d_in
    self.channels = channels
    self.eps = eps
    self.lns = nn.ModuleList([nn.LayerNorm(channels) for _ in range(d_in+1)])
    self.relu = nn.ReLU()

  def forward(self, f):
    out = dict()
    for l, f_l in f.items():
      norm = torch.norm(f_l, dim=2, keepdim=True)  # [N, C, 1]
      # Clamp only for the direction; the gate still sees the true norm.
      direction = f_l / norm.clamp(min=self.eps)
      gate = self.relu(self.lns[l](norm[..., 0])).unsqueeze(-1)
      out[l] = gate * direction
    return out


# Testing

In [None]:
# float32 machine epsilon: the equivariance errors reported below should be of
# roughly this magnitude (assuming the outputs are O(1)).
float32_eps = torch.finfo(torch.float32).eps
float32_eps

In [None]:
# Create test graphs, point clouds, and features: a random point cloud, the
# same cloud rotated by the Euler angles (a, b, c), and random node/edge
# features of degrees 0..2 together with their rotated counterparts.
torch.manual_seed(0)  # reproducible fixtures under Restart & Run All

num_pts = 10

# Euler angles of the test rotation passed to wigner_D_matrix.
a = 0.2
b = 0.4
c = 0.7

in_channels = 12
out_channels = 16

pos = torch.rand(num_pts, 3)
pc = PointCloud(pos)

# Take the edge count from the graph rather than assuming it is fully
# connected (that only holds while every pair lies within the cutoff).
num_edgs = pc.graph.number_of_edges()

nf = {0: torch.rand(num_pts, in_channels, 1),
      1: torch.rand(num_pts, in_channels, 3),
      2: torch.rand(num_pts, in_channels, 5)}

ef = {0: torch.rand(num_edgs, in_channels, 1),
      1: torch.rand(num_edgs, in_channels, 3),
      2: torch.rand(num_edgs, in_channels, 5)}

d1 = torch.tensor(wigner_D_matrix(1, a, b, c)).float()

pos_rot = torch.einsum('lk,nk->nl', d1, pos)
pc_rot = PointCloud(pos_rot)

# The checks below compare edge-by-edge, so both clouds must share one edge list.
assert torch.equal(torch.stack(pc.graph.edges()), torch.stack(pc_rot.graph.edges()))

nf_rot = dict()
for key, value in nf.items():
  d = torch.tensor(wigner_D_matrix(key, a, b, c)).float()
  nf_rot[key] = torch.einsum('lk,nik->nil', d, value)

ef_rot = dict()
for key, value in ef.items():
  d = torch.tensor(wigner_D_matrix(key, a, b, c)).float()
  ef_rot[key] = torch.einsum('lk,eik->eil', d, value)

In [None]:
# Edge distances are rotation-invariant, so the per-edge difference between
# the original and rotated clouds should be at machine-epsilon level.
dist_diff = pc.graph.edata['dist'] - pc_rot.graph.edata['dist']
dist_diff.abs().max()

In [None]:
# Edge vectors are degree-1 features: rotating the original vectors with D^1
# should reproduce the edge vectors of the rotated cloud.
d1 = torch.tensor(wigner_D_matrix(1, a, b, c)).float()
vec = pc.graph.edata['vec']

vec_post = torch.einsum('lk,ek->el', d1, vec)   # rotate after building the graph
vec_pre = pc_rot.graph.edata['vec']             # build the graph from rotated points

diff_vec = vec_post - vec_pre
diff_vec.abs().max()

In [None]:
# WLayer(0 -> 0) maps scalars to scalars: running it on the rotated cloud with
# rotated edge features must reproduce the unrotated output.
w_layer = WLayer(0, 0, in_channels, out_channels)

w = w_layer(pc)
w_f = torch.einsum('eoilk,eik->eol', w, ef[0])

w_rot = w_layer(pc_rot)
w_rot_f = torch.einsum('eoilk,eik->eol', w_rot, ef_rot[0])

(w_f - w_rot_f).abs().max()

In [None]:
# WLayer(0 -> 1): the degree-1 output must rotate with D^1.
w_layer = WLayer(0, 1, in_channels, out_channels)
d1 = torch.tensor(wigner_D_matrix(1, a, b, c)).float()

w = w_layer(pc)
w_f = torch.einsum('eoilk,eik->eol', w, ef[0])
w_f_rot = torch.einsum('lk,eok->eol', d1, w_f)              # rotate the output

w_rot = w_layer(pc_rot)
w_rot_f = torch.einsum('eoilk,eik->eol', w_rot, ef_rot[0])  # rotate the input

(w_f_rot - w_rot_f).abs().max()

In [None]:
# WLayer(1 -> 1): rotated vectors on the rotated cloud must equal the
# unrotated output rotated with D^1.
w_layer = WLayer(1, 1, in_channels, out_channels)
d1 = torch.tensor(wigner_D_matrix(1, a, b, c)).float()

w = w_layer(pc)
w_f = torch.einsum('eoilk,eik->eol', w, ef[1])
w_f_rot = torch.einsum('lk,eok->eol', d1, w_f)              # rotate the output

w_rot = w_layer(pc_rot)
w_rot_f = torch.einsum('eoilk,eik->eol', w_rot, ef_rot[1])  # rotate the input

(w_f_rot - w_rot_f).abs().max()

In [None]:
# WLayer(2 -> 3): degree-2 inputs map to degree-3 outputs, which must rotate
# with D^3.
w_layer = WLayer(2, 3, in_channels, out_channels)
d3 = torch.tensor(wigner_D_matrix(3, a, b, c)).float()

w = w_layer(pc)
w_f = torch.einsum('eoilk,eik->eol', w, ef[2])
w_f_rot = torch.einsum('lk,eok->eol', d3, w_f)              # rotate the output

w_rot = w_layer(pc_rot)
w_rot_f = torch.einsum('eoilk,eik->eol', w_rot, ef_rot[2])  # rotate the input

(w_f_rot - w_rot_f).abs().max()

In [None]:
# Check TransLayer (2 -> 3) for equivariance: running it on the rotated
# cloud/features must match rotating its output with D^l, for every degree l.
layer = TransLayer(2, 3, in_channels, out_channels)

out = layer(pc, nf)
out_rot = layer(pc_rot, nf_rot)

out_post = dict()
for key, value in out.items():
  d = torch.tensor(wigner_D_matrix(key, a, b, c)).float()
  out_post[key] = torch.einsum('lk,nik->nil', d, value)

# Worst absolute error per output degree (previously only degree 3 was checked).
{l: torch.max(torch.abs(out_rot[l] - out_post[l])).item() for l in out}

In [None]:
# Check MultiHead (2 -> 3, 4 heads) for equivariance across every output degree.
layer = MultiHead(2, 3, in_channels, out_channels, heads=4)

out = layer(pc, nf)
out_rot = layer(pc_rot, nf_rot)

out_post = dict()
for key, value in out.items():
  d = torch.tensor(wigner_D_matrix(key, a, b, c)).float()
  out_post[key] = torch.einsum('lk,nik->nil', d, value)

# Worst absolute error per output degree (previously only degree 3 was checked).
{l: torch.max(torch.abs(out_rot[l] - out_post[l])).item() for l in out}

In [None]:
# Check GraphReLU for equivariance across every feature degree.
layer = GraphReLU(2, channels=in_channels)

out = layer(nf)
out_rot = layer(nf_rot)

out_post = dict()
for key, value in out.items():
  d = torch.tensor(wigner_D_matrix(key, a, b, c)).float()
  out_post[key] = torch.einsum('lk,nik->nil', d, value)

# Worst absolute error per degree (previously only degree 1 was checked).
{l: torch.max(torch.abs(out_rot[l] - out_post[l])).item() for l in out}