David Klee | 3/22/23 | klee.d@northeastern.edu

For additional details on E3NN, please refer to their [documentation](https://docs.e3nn.org/en/latest/index.html), which is quite extensive.

# 0. Setup imports and util functions

In [None]:
%%capture
! pip install e3nn scipy matplotlib torchvision plotly healpy

In [None]:
import numpy as np
import torch
import torch.nn as nn
import e3nn
from e3nn import o3
import healpy as hp
import plotly.graph_objects as go
import warnings
warnings.filterwarnings('ignore')

In [None]:
# Shared plotly axis settings: hide the background, tick labels, grid lines,
# zero line and title so only the plotted surfaces remain visible.
axis = {
    'showbackground': False,
    'showticklabels': False,
    'showgrid': False,
    'zeroline': False,
    'title': '',
    'nticks': 3,
}

# Diverging blue -> grey -> red colorscale given as [position, color] stops.
cmap_bwr = [
    [0, 'rgb(0,50,255)'],
    [0.5, 'rgb(200,200,200)'],
    [1, 'rgb(255,50,0)'],
]

def trace(r, f, c, radial_abs=True):
    """Build the keyword arguments for one plotly surface.

    :r: tensor whose last dimension holds (x, y, z) coordinates
    :f: signal values used as the surface color
    :c: indexable offset (c[0], c[1], c[2]) added to the coordinates
    :radial_abs: if True the coordinates are scaled by |f| before the offset
                 is added, otherwise they are left unscaled
    :return: dict with keys 'x', 'y', 'z' and 'surfacecolor'
    """
    scale = f.abs() if radial_abs else 1
    surface = {
        name: scale * r[..., k] + c[k]
        for k, name in enumerate(('x', 'y', 'z'))
    }
    surface['surfacecolor'] = f
    return surface

def plot_harmonics(data, radial_abs=True):
    """Draw every channel of ``data`` as its own surface, laid out in one row.

    :data: tensor indexed as [beta, alpha, channel]; one surface is drawn per
           entry of the last dimension
    :radial_abs: forwarded to ``trace``; when True each surface's radius is
                 scaled by the absolute value of the signal
    """
    n = data.shape[-1]
    x_extent = (n + 1) // 2 * 2

    # rebuild the angular grid with both endpoints included (linspace is
    # endpoint-inclusive) so the plotted surface closes without a gap
    betas = torch.linspace(0, np.pi, data.shape[0])
    alphas = torch.linspace(0, 2 * np.pi, data.shape[1])
    beta, alpha = torch.meshgrid(betas, alphas, indexing='ij')
    r = o3.angles_to_xyz(alpha, beta)

    camera = dict(
        up=dict(x=0, y=0, z=1),
        center=dict(x=0, y=0, z=0),
        eye=dict(x=0, y=-5, z=5),
        projection=dict(type='orthographic'),
    )
    scene = dict(
        # the lower bound is written exactly as before: unary minus binds
        # tighter than //, which only matters when n is even
        xaxis=dict(**axis, range=[-(n + 1) // 2 * 2, x_extent]),
        yaxis=dict(**axis, range=[-2, 2]),
        zaxis=dict(**axis, range=[-2, 2]),
        aspectmode='manual',
        aspectratio=dict(x=x_extent, y=2, z=2),
        camera=camera,
    )
    layout = dict(
        width=(n + 1) // 2 * 200,
        height=200,
        scene=scene,
        paper_bgcolor="rgba(1,1,1,1)",
        plot_bgcolor="rgba(0,0,0,0)",
        margin=dict(l=0, r=0, t=0, b=0),
    )

    # place the surfaces two units apart along x, centred on the origin
    surfaces = []
    for i in range(n):
        offset = torch.tensor([2.0 * i - (n - 1.0), 0.0, 0.0])
        surfaces.append(trace(r, data[..., i], offset, radial_abs=radial_abs))

    # one symmetric color range shared by all surfaces
    cmax = max(s['surfacecolor'].abs().max().item() for s in surfaces)
    fig = go.Figure(
        data=[go.Surface(**s, colorscale=cmap_bwr, cmin=-cmax, cmax=cmax) for s in surfaces],
        layout=layout,
    )
    fig.show()

def plot_s2_grid(data, c=None, markersize=1):
    """Show a set of 3D points as a scatter plot, optionally colored.

    :data: tensor whose last dimension holds (x, y, z) coordinates
    :c: optional per-point values; when given they are flattened and mapped
        through the blue/grey/red colorscale
    :markersize: marker size handed to plotly
    """
    marker = dict(
        size=markersize,
        color=None if c is None else c.flatten(),
        colorscale=cmap_bwr,
    )
    points = go.Scatter3d(
        x=data[..., 0].flatten(),
        y=data[..., 1].flatten(),
        z=data[..., 2].flatten(),
        mode='markers',
        marker=marker,
    )

    # every axis gets the same range, slightly larger than [-1, 1]
    scene = {f'{name}axis': dict(range=[-1.1, 1.1]) for name in 'xyz'}
    layout = dict(
        scene=scene,
        width=400,
        height=400,
        margin=dict(l=0, r=0, t=0, b=0),
        plot_bgcolor="rgba(0,0,0,0)",
        paper_bgcolor="rgba(0,0,0,0)",
    )

    fig = go.Figure(data=[points], layout=layout)
    fig.show()

def plot_s2_surface(data):
    """Plot a signal sampled on a (beta, alpha) grid as a colored unit sphere.

    :data: 2D tensor indexed as [beta, alpha]; its values color the surface
           using a symmetric blue/grey/red range
    """
    # reproduce spatial grid of data with both endpoints included (linspace is
    # endpoint-inclusive) to prevent gaps in the surface
    beta, alpha = torch.meshgrid(
        torch.linspace(0, np.pi, data.shape[0]),
        torch.linspace(0, 2 * np.pi, data.shape[1]),
        indexing='ij'
    )
    r = o3.angles_to_xyz(alpha, beta)

    layout = dict(
        width=400,
        height=400,
        margin=dict(l=0, r=0, t=0, b=0),
        plot_bgcolor="rgba(0,0,0,0)",
        paper_bgcolor="rgba(0,0,0,0)",
    )

    # symmetric color limits so that zero maps to the middle of the colorscale
    cmax = data.abs().max().item()
    fig = go.Figure(
        data=[
            go.Surface(
                x=r[..., 0],
                y=r[..., 1],
                z=r[..., 2],
                surfacecolor=data,
                colorscale=cmap_bwr,
                cmin=-cmax,
                cmax=cmax,
            )
        ],
        layout=layout,
    )
    fig.show()

# 1. Visualizing Spherical Harmonics

## 1a. Create Spatial Grid over Sphere

In [None]:
def grid_over_s2(res_beta, res_alpha):
  """Return XYZ positions of a regular (beta, alpha) grid.

  :res_beta: number of samples along beta
  :res_alpha: number of samples along alpha
  :return: tensor of shape [res_beta, res_alpha, 3]
  """
  # 1D sample locations along each angle, as provided by e3nn
  beta_1d, alpha_1d = o3.s2_grid(res_beta=res_beta, res_alpha=res_alpha)

  # expand to a dense 2D grid with beta along the first dimension
  beta_2d, alpha_2d = torch.meshgrid(beta_1d, alpha_1d, indexing='ij')

  # angles -> cartesian coordinates
  return o3.angles_to_xyz(alpha_2d, beta_2d)

# sample a 40 x 80 (beta x alpha) grid and display it as a point cloud
grid = grid_over_s2(res_beta=40, res_alpha=80)

plot_s2_grid(grid)

## 1b. Define Harmonics over Grid

In [None]:
# degree of the spherical harmonics to evaluate; try other values
L = 1
# evaluate the degree-L harmonics at every point of `grid`
Y = o3.spherical_harmonics(L, grid, normalize=True, normalization='integral')

# note how many harmonics exist for every degree L (one surface is drawn per
# entry of the last dimension of Y)
plot_harmonics(Y, radial_abs=True)

# 2. Fourier Transform
Now that we have some intuition about the spherical harmonics, let's convert some spatial signals into harmonics using the Fourier transform.  Take note of how the spacing of the spatial grid and the degree of the harmonics affect the result.

## 2b. Define Signal over S2 Grid
We will use a single-channel signal for visualization purposes, but the `e3nn` functions extend to an arbitrary number of channels.

In [None]:
# resolution shared by both grid options below
res_beta = 30
res_alpha = 60

# option 1: regular grid as before, flattened to [n_points, 3] so that it has
# the same layout as the random grid below
grid = grid_over_s2(res_beta=res_beta, res_alpha=res_alpha).flatten(0, 1)

# option 2 (overrides option 1, comment it out to keep the regular grid):
# grid of random points, with both angles drawn uniformly
grid = o3.angles_to_xyz(
    alpha=torch.FloatTensor(res_beta * res_alpha).uniform_(0, 2*np.pi),
    beta=torch.FloatTensor(res_beta * res_alpha).uniform_(0, np.pi),
)

def wavey_signal(grid, x_freq=0.5, y_freq=0, z_freq=0):
  """Average of three sine waves, one along each cartesian axis.

  :grid: tensor whose last dimension holds (x, y, z) coordinates
  :x_freq, y_freq, z_freq: frequency of the wave along each axis; a frequency
                           of zero contributes a constant zero term
  :return: tensor with the shape of ``grid`` minus its last dimension
  """
  freqs = (x_freq, y_freq, z_freq)
  waves = [
    torch.sin(np.pi * grid[..., k] * freq)
    for k, freq in enumerate(freqs)
  ]
  return torch.stack(waves, dim=-1).mean(dim=-1)

# wave along x only; the y and z terms are zero at zero frequency
signal = wavey_signal(grid, x_freq=0.5, y_freq=0.0, z_freq=0)

plot_s2_grid(grid, c=signal, markersize=5)

## 2c. Convert to Harmonics and plot

In [None]:
# highest harmonic degree kept in the expansion
lmax = 3

# evaluate the harmonics of degree 0..lmax at the angles of every grid point
Y = o3.spherical_harmonics_alpha_beta(
    range(lmax + 1),
    *o3.xyz_to_angles(grid),
    normalization="component"
) # shape [n_grid_points, (lmax+1)**2]

# project the spatial signal onto the harmonics (sum over grid points)
signal_harmonics = signal @ Y


# we can't actually visualize the harmonic coefficients without converting back
# to the spatial domain, so we convert back using a very dense grid
to_s2_grid = o3.ToS2Grid(lmax, res=(100, 101))
signal_to_plot = to_s2_grid(signal_harmonics)

# there is also an `o3.FromS2Grid` which converts to harmonics, but it does not
# give us the flexibility to use arbitrary spatial grids

plot_s2_surface(signal_to_plot)

# 3. $S^2$ Convolution

In [None]:
def s2_irreps(lmax):
  """Irreps for a signal on S2: one copy of every degree l = 0..lmax."""
  terms = []
  for l in range(lmax + 1):
    # (multiplicity, (degree, parity))
    terms.append((1, (l, 1)))
  return o3.Irreps(terms)

def so3_irreps(lmax):
  """Irreps for a signal on SO(3): 2l+1 copies of every degree l = 0..lmax."""
  terms = []
  for l in range(lmax + 1):
    # (multiplicity, (degree, parity))
    terms.append((2 * l + 1, (l, 1)))
  return o3.Irreps(terms)

def flat_wigner(lmax, alpha, beta, gamma):
  """Concatenate the flattened, scaled Wigner-D matrices for l = 0..lmax.

  For every degree l the matrix ``o3.wigner_D(l, alpha, beta, gamma)`` has its
  last two dimensions flattened and is scaled by sqrt(2l+1); the results are
  joined along the last dimension.
  """
  blocks = []
  for l in range(lmax + 1):
    wigner = o3.wigner_D(l, alpha, beta, gamma)
    blocks.append((2 * l + 1) ** 0.5 * wigner.flatten(-2))
  return torch.cat(blocks, dim=-1)


class S2Conv(nn.Module):
  '''S2 group convolution which outputs signal over SO(3) irreps

  :f_in: feature dimensionality of input signal
  :f_out: feature dimensionality of output signal
  :lmax: maximum degree of harmonics used to represent input and output signals
         technically, you can have different degrees for input and output, but
         we do not explore that in our work
  :kernel_grid: tensor of shape [2, n_s2_pts] stacking (alphas, betas), the
                spatial locations over which the filter is defined.
                we find that it is better to parametrize filter in spatial domain
                and project to harmonics at every forward pass.
  '''
  def __init__(self, f_in: int, f_out: int, lmax: int, kernel_grid: torch.Tensor):
    super().__init__()

    # filter weight parametrized over spatial grid on S2
    self.register_parameter(
      "w", torch.nn.Parameter(torch.randn(f_in, f_out, kernel_grid.shape[1]))
    )  # [f_in, f_out, n_s2_pts]

    # linear projection to convert filter weights to fourier domain
    self.register_buffer(
      "Y", o3.spherical_harmonics_alpha_beta(range(lmax + 1), *kernel_grid, normalization="component")
    )  # [n_s2_pts, (lmax+1)**2]

    # defines group convolution using appropriate irreps
    # note, we set internal_weights to False since we defined our own filter above
    self.lin = o3.Linear(s2_irreps(lmax), so3_irreps(lmax),
                         f_in=f_in, f_out=f_out, internal_weights=False)

  def forward(self, x):
    '''Perform S2 group convolution to produce signal over irreps of SO(3).
    First project filter into fourier domain then perform convolution

    :x: tensor of shape (B, f_in, (lmax+1)**2), signal over S2 irreps
    :return: tensor of shape (B, f_out, sum_{l=0}^{lmax} (2l+1)**2)
    '''
    # project the spatial filter onto the harmonics, scaled by 1/sqrt(n_s2_pts)
    psi = torch.einsum("ni,xyn->xyi", self.Y, self.w) / self.Y.shape[0] ** 0.5
    return self.lin(x, weight=psi)

## Example Instance with Filter over half-sphere
Quick example with a filter localized over half the sphere, centered at the north pole.  Keep in mind that the complexity of the spherical convolution is based on the maximum degree of the harmonics, so a locally supported filter does not improve speed.  The support of your filter should be chosen based on domain knowledge or intuition.  The original spherical CNN work used locally supported filters throughout.

In [None]:
def s2_healpix_grid(rec_level: int = 0, max_beta: float = np.pi / 6) -> torch.Tensor:
    """Return the healpix grid points within ``max_beta`` of the north pole.

    :rec_level: recursion level of the healpix grid; nside is 2**rec_level
    :max_beta: radius (in radians) of the disc around (0, 0, 1) that is
               queried for pixels
    :return: float32 tensor of shape [2, n_pts] stacking (alpha, beta)
    """
    n_side = 2**rec_level
    # indices of the pixels that fall inside the disc around the north pole
    pix_in_disc = hp.query_disc(nside=n_side, vec=(0, 0, 1), radius=max_beta)
    # pix2ang returns the two angles of each pixel; the first is used as beta
    beta, alpha = hp.pix2ang(n_side, pix_in_disc)
    alpha = torch.from_numpy(alpha)
    beta = torch.from_numpy(beta)
    return torch.stack((alpha, beta)).float()

lmax = 5
# kernel support: healpix points within a radius of pi/4 around the north pole
kernel_grid = s2_healpix_grid(rec_level=2, max_beta=np.pi/4)
s2_conv = S2Conv(f_in=1, f_out=1, lmax=lmax, kernel_grid=kernel_grid)

# visualize the conv filter in spatial domain
# NOTE(review): `filter` shadows the Python builtin of the same name
filter = s2_conv.w[0,0].detach()
plot_s2_grid(o3.angles_to_xyz(*kernel_grid), filter, markersize=5)

# visualize the conv filter after it has been converted to harmonics
# (projected with the same matrix Y that the layer stores, then mapped back
# to a dense grid for plotting)
filter_harmonics = filter @ s2_conv.Y
filter_to_plot = o3.ToS2Grid(lmax, (100,101))(filter_harmonics)
plot_s2_surface(filter_to_plot)

# 4. Applying Non-Linearities
If we are not careful, applying a non-linearity to our representation could destroy the equivariance properties of the network.  Since our representation is coefficients over the irreps of SO(3), we cannot apply a ReLU directly.  Can you intuit why this is the case? (consider removing a harmonic before or after a rotation).

Instead, we will apply a ReLU in the spatial domain, using the Fourier transform to map back and forth.  This will still introduce some minor equivariance errors since the ReLU introduces higher frequencies that will be lost in the truncated Fourier decomposition.

## 4a. Why Non-Linearities are Bad in Fourier domain

In [None]:
# random single-channel signal over s2 irreps
lmax = 4
irreps = s2_irreps(lmax)
signal = torch.randn(((lmax+1)**2), dtype=torch.float32)

# compare the effect of rotating the signal before or after applying the relu
# directly to the harmonic coefficients; for an equivariant operation the two
# results would be identical
rot_mtx = irreps.D_from_angles(*torch.randn(3))

# rotate first, then relu
signal_rot_before = torch.relu(torch.einsum("ij,...j->...i", rot_mtx, signal))

# relu first, then rotate
signal_rot_after = torch.einsum("ij,...j->...i", rot_mtx, torch.relu(signal))

to_s2grid = o3.ToS2Grid(lmax, res=(100, 101))

plot_s2_surface(to_s2grid(signal_rot_before))
plot_s2_surface(to_s2grid(signal_rot_after))

## 4b. Applying ReLU the Long Way

In [None]:
# I will show it for the SO3 case since most of the time the internal features are in SO(3)
lmax = 4
irreps = so3_irreps(lmax)
# total number of coefficients: (2l+1)**2 per degree l
num_irreps = sum((2*l+1)**2 for l in range(lmax+1))

# 32 channel signal over so3 irreps
so3_signal = torch.randn((32, num_irreps))

# convert to spatial using API (or you could do this by hand with WignerD)
so3_grid = o3.SO3Grid(lmax, resolution=10)
spatial_signal = so3_grid.to_grid(so3_signal)

# apply relu or other non linearity of your choice
spatial_signal = torch.relu(spatial_signal)

# convert back to so3 irreps; stored under its own name so that the
# `so3_grid` transform object is not overwritten by a tensor
so3_signal_relu = so3_grid.from_grid(spatial_signal)

## 4c. Applying ReLU the Easy Way

In [None]:
lmax = 6
# relu for signals over S2 irreps, packaged as a single e3nn module
s2_relu = e3nn.nn.S2Activation(irreps=s2_irreps(lmax), act=torch.relu, res=20)

# this doesn't exist in the docs but it does in the source code
# note: it has different arguments
so3_relu = e3nn.nn.SO3Activation(lmax_in=lmax, lmax_out=lmax, act=torch.relu, resolution=10)

# 5. $SO(3)$ Convolution

In [None]:
class SO3Conv(nn.Module):
  '''SO3 group convolution

  :f_in: feature dimensionality of input signal
  :f_out: feature dimensionality of output signal
  :lmax: maximum degree of harmonics used to represent input and output signals
         technically, you can have different degrees for input and output, but
         we do not explore that in our work
  :kernel_grid: tensor of shape [3, n_so3_pts] stacking (alphas, betas, gammas),
                the spatial locations over which the filter is defined.
                we find that it is better to parametrize filter in spatial domain
                and project to harmonics at every forward pass
  '''
  def __init__(self, f_in: int, f_out: int, lmax: int, kernel_grid: torch.Tensor):
    super().__init__()

    # filter weight parametrized over spatial grid on SO3
    self.register_parameter(
      "w", torch.nn.Parameter(torch.randn(f_in, f_out, kernel_grid.shape[1]))
    )  # [f_in, f_out, n_so3_pts]

    # wigner D matrices used to project spatial signal to irreps of SO(3)
    self.register_buffer("D", flat_wigner(lmax, *kernel_grid))  # [n_so3_pts, sum_{l=0}^{lmax} (2l+1)**2]

    # defines group convolution using appropriate irreps
    # internal_weights is False since the filter defined above is passed in at every forward pass
    self.lin = o3.Linear(so3_irreps(lmax), so3_irreps(lmax),
                         f_in=f_in, f_out=f_out, internal_weights=False)

  def forward(self, x):
    '''Perform SO3 group convolution to produce signal over irreps of SO(3).
    First project filter into fourier domain then perform convolution

    :x: tensor of shape (B, f_in, sum_{l=0}^{lmax} (2l+1)**2), signal over SO3 irreps
    :return: tensor of shape (B, f_out, sum_{l=0}^{lmax} (2l+1)**2)
    '''
    # project the spatial filter onto the Wigner-D basis, scaled by 1/sqrt(n_so3_pts)
    psi = torch.einsum("ni,xyn->xyi", self.D, self.w) / self.D.shape[0] ** 0.5
    return self.lin(x, weight=psi)

## Example SO(3) Convolution with Localized Filter
Here is an example of an SO(3) convolution layer with a filter that is localized to within 30 degrees of the identity rotation. 

**Note**: you can change `lmax` during either S2 or SO(3) convolution.  You can think of this as similar to downsampling/upsampling during 2D convolutions.  For instance, by reducing `lmax` over several layers, the network will learn to extract global (i.e. low frequency) information.

In [None]:
def so3_near_identity_grid(max_beta=np.pi / 8, max_gamma=2 * np.pi, n_alpha=8, n_beta=3, n_gamma=None):
    """Rings of rotations around the identity, as stacked Euler angles.

    All rotations in a ring share the same beta, i.e. the same distance from
    the identity.

    :max_beta: largest beta; betas are max_beta * k / n_beta for k = 1..n_beta
    :max_gamma: the pre-shift gamma samples span [-max_gamma, max_gamma]
    :n_alpha: number of samples of [0, 2*pi] for alpha; the last one (2*pi) is
              dropped, so n_alpha - 1 distinct alphas are used
    :n_beta: number of rings
    :n_gamma: number of gamma samples, defaults to n_alpha
    :return: tensor of shape [3, (n_alpha - 1) * n_beta * n_gamma] stacking
             (alpha, beta, gamma)
    """
    n_gamma = n_alpha if n_gamma is None else n_gamma  # similar to regular representations

    alphas = torch.linspace(0, 2 * np.pi, n_alpha)[:-1]
    betas = torch.arange(1, n_beta + 1) * max_beta / n_beta
    shifted_gammas = torch.linspace(-max_gamma, max_gamma, n_gamma)

    a, b, g = torch.meshgrid(alphas, betas, shifted_gammas, indexing="ij")
    # gamma is measured relative to alpha
    return torch.stack([a.flatten(), b.flatten(), (g - a).flatten()])

# kernel support: rotations whose beta is at most 30 degrees
kernel_grid = so3_near_identity_grid(max_beta=np.radians(30))
so3_conv = SO3Conv(f_in=1, f_out=1, lmax=5, kernel_grid=kernel_grid)

# visualize kernel grid in S2 by ignoring gamma dim (only the alpha and beta
# rows of the grid are used to place the points)
# NOTE(review): `filter` shadows the Python builtin of the same name
filter = so3_conv.w[0,0].detach()
plot_s2_grid(o3.angles_to_xyz(*kernel_grid[:2]), filter, markersize=5)

# 6. Summary
With the help of `e3nn`, spherical convolution is straightforward to implement and computationally efficient.  Spherical convolutions are well-suited for designing end-to-end SO(3) equivariant networks to process:
- 360$^\circ$ camera images
- spherical projections of point clouds
- signals on the sphere (e.g. weather patterns on Earth)

There are also exciting, unexplored opportunities to apply SO(3) equivariant reasoning to inputs that do not already live on the sphere.  Some ideas are:
- map SO(2) equivariant image features to SO(3)
- map features from a discrete subgroup of SO(3) to SO(3)