# E(n)-Equivariant Steerable CNNs  -  Equivariant MLPs


In [1]:
import torch
import numpy as np

from escnn_jax import gspaces
from escnn_jax import nn
from escnn_jax import group

import jax
import jax.numpy as jnp
from typing import List, Tuple, Any, Mapping
from jaxtyping import Array, Float, Int, PyTree, PRNGKeyArray
import equinox as eqx

2023-06-26 11:52:51.933151: I tensorflow/core/util/port.cc:110] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2023-06-26 11:52:51.968346: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


The **escnn** library also supports MLPs equivariant to compact groups, which can be seen as a special case for $n=0$.
This is done by replacing the convolution layers (e.g. [R3Conv](https://quva-lab.github.io/escnn/api/escnn.nn.html#r3conv)) with the [Linear](https://quva-lab.github.io/escnn/api/escnn.nn.html#linear) layer and by choosing the [no_base_space](https://quva-lab.github.io/escnn/api/escnn.gspaces.html#group-action-trivial-on-single-point) `GSpace` (e.g., instead of [rot3dOnR3](https://quva-lab.github.io/escnn/api/escnn.gspaces.html#escnn.gspaces.rot3dOnR3)). 

All other modules can be used in a similar way, e.g. batch-norm and non-linearities.


Here, we provide an example with `G=SO(3)` and one with `G=O(2)`.

In [3]:
# sphere_grid = "thomson"
sphere_grid = 'ico'

class SO3MLP(nn.EquivariantModule):
    G: group.Group
    gspace: gspaces.GSpace
    layers: List
    
    def __init__(self, key: PRNGKeyArray, n_classes=10):
        keys = jax.random.split(key, 8)

        super(SO3MLP, self).__init__()
        
        # the model is equivariant to the group SO(3)
        self.G = group.so3_group()
        
        # since we are building an MLP, there is no base-space
        self.gspace = gspaces.no_base_space(self.G)
        
        # the input contains the coordinates of a point in the 3D space
        self.in_type = self.gspace.type(self.G.standard_representation())
        
        # Layer 1
        # We will use the representation of SO(3) acting on signals over a sphere, bandlimited to frequency 1
        # To apply a point-wise non-linearity (e.g. ELU), we need to sample the spherical signals over a finite number of points.
        # Note that this makes the equivariance only approximate.
        # The representation of SO(3) on spherical signals is technically a quotient representation,
        # identified by the subgroup of planar rotations, which has id=(False, -1) in our library
        
        # N.B.: the first this model is instantiated, the library computes numerically the spherical grids, which can take some time
        # These grids are then cached on disk, so future calls should be considerably faster.
        
        activation1 = nn.QuotientFourierELU(
            self.gspace,
            subgroup_id=(False, -1),
            channels=3, # specify the number of spherical signals in the output features
            irreps=self.G.bl_sphere_representation(L=1).irreps, # include all frequencies up to L=1
            grid=self.G.sphere_grid(type=sphere_grid, N=16), # build a discretization of the sphere containing 16 equally distributed points            
            # inplace=True
        )
        
        # map with an equivariant Linear layer to the input expected by the activation function, apply batchnorm and finally the activation
        block1 = nn.SequentialModule(
            nn.Linear(self.in_type, activation1.in_type, key=keys[0]),
            # nn.IIDBatchNorm1d(activation1.in_type),
            # activation1,
        )
        
        # Repeat a similar process for a few layers
        
        # 8 spherical signals, bandlimited up to frequency 3
        activation2 = nn.QuotientFourierELU(
            self.gspace,
            subgroup_id=(False, -1),
            channels=8, # specify the number of spherical signals in the output features
            irreps=self.G.bl_sphere_representation(L=3).irreps, # include all frequencies up to L=3
            grid=self.G.sphere_grid(type=sphere_grid, N=40), # build a discretization of the sphere containing 40 equally distributed points            
            # inplace=True
        )
        block2 = nn.SequentialModule(
            nn.Linear(block1.out_type, activation2.in_type, key=keys[1]),
            # nn.IIDBatchNorm1d(activation2.in_type),
            activation2,
        )
        
        # 8 spherical signals, bandlimited up to frequency 3
        activation3 = nn.QuotientFourierELU(
            self.gspace,
            subgroup_id=(False, -1),
            channels=8, # specify the number of spherical signals in the output features
            irreps=self.G.bl_sphere_representation(L=3).irreps, # include all frequencies up to L=3
            grid=self.G.sphere_grid(type=sphere_grid, N=40), # build a discretization of the sphere containing 40 equally distributed points            
            # inplace=True
        )
        block3 = nn.SequentialModule(
            nn.Linear(block2.out_type, activation3.in_type, key=keys[2]),
            # nn.IIDBatchNorm1d(activation3.in_type),
            activation3,
        )
        
        # 5 spherical signals, bandlimited up to frequency 2
        activation4 = nn.QuotientFourierELU(
            self.gspace,
            subgroup_id=(False, -1),
            channels=5, # specify the number of spherical signals in the output features
            irreps=self.G.bl_sphere_representation(L=2).irreps, # include all frequencies up to L=2
            grid=self.G.sphere_grid(type=sphere_grid, N=25), # build a discretization of the sphere containing 25 equally distributed points            
            # inplace=True
        )
        block4 = nn.SequentialModule(
            nn.Linear(block3.out_type, activation4.in_type, key=keys[3]),
            # nn.IIDBatchNorm1d(activation4.in_type),
            activation4,
        )
        
        # Final linear layer mapping to the output features
        # the output is a 5-dimensional vector transforming according to the Wigner-D matrix of frequency 2
        self.out_type = self.gspace.type(self.G.irrep(2))
        block5 = nn.Linear(block4.out_type, self.out_type, key=keys[4])

        self.layers = [block1, block2, block3, block4, block5]
    
    def __call__(self, x: nn.GeometricTensor):
        # check the input has the right type
        assert x.type == self.in_type
        
        # apply each equivariant block
        
        # Each layer has an input and an output type
        # A layer takes a GeometricTensor in input.
        # This tensor needs to be associated with the same representation of the layer's input type
        #
        # The Layer outputs a new GeometricTensor, associated with the layer's output type.
        # As a result, consecutive layers need to have matching input/output types
        for i, layer in enumerate(self.layers):
            x = layer(x)
     
        return x
    
    def evaluate_output_shape(self, input_shape: tuple):
        shape = list(input_shape)
        assert len(shape) ==2, shape
        assert shape[1] == self.in_type.size, shape
        shape[1] = self.out_type.size
        return shape

Let's build the model

In [4]:
SEED = 5678
key = jax.random.PRNGKey(SEED)
model = SO3MLP(key)

Let's test the equivariance of the model

In [5]:
np.set_printoptions(linewidth=10000, precision=4, suppress=True)

model = model.eval()

B = 10

# generates B random points in 3D and wrap them in a GeometricTensor of the right type
x = jax.lax.stop_gradient(jax.random.normal(key, (B, 3)))
x = model.in_type(x)

print('##########################################################################################')
y = model(x)
print("Outputs' magnitudes")
print(jnp.linalg.norm(y.tensor, axis=1).reshape(-1))
print('##########################################################################################')
print("Errors' magnitudes")
for r in range(8):
    # sample a random rotation
    g = model.G.sample()
    
    x_transformed = g @ x
    x_transformed = x_transformed #.to(device)

    y_transformed = model(x_transformed) #.to('cpu')
    
    # verify that f(g@x) = g@f(x)=g@y
    print(jnp.linalg.norm(y_transformed.tensor - (g@y).tensor, axis=1).reshape(-1))        

print('##########################################################################################')
print()



##########################################################################################
Outputs' magnitudes
[0.1641 0.0205 0.1317 0.0536 0.1148 0.0538 0.1501 0.1078 0.1102 0.1649]
##########################################################################################
Errors' magnitudes
[0.0066 0.0004 0.0016 0.0014 0.002  0.0009 0.0038 0.0018 0.002  0.0077]
[0.0086 0.0004 0.0022 0.0014 0.0018 0.0014 0.0018 0.0026 0.0028 0.0082]
[0.0051 0.0004 0.0025 0.0007 0.0026 0.0015 0.0016 0.002  0.0029 0.0071]
[0.0077 0.0001 0.001  0.0009 0.0016 0.0003 0.0036 0.0022 0.0005 0.0024]
[0.009  0.0002 0.0023 0.0012 0.0031 0.0013 0.0046 0.0016 0.0029 0.0068]
[0.0066 0.0003 0.0017 0.0007 0.0013 0.0003 0.0034 0.0018 0.0019 0.0025]
[0.0032 0.0004 0.0023 0.0014 0.0025 0.0015 0.0045 0.001  0.0029 0.0091]
[0.0083 0.0002 0.0017 0.0014 0.0024 0.0014 0.0041 0.0027 0.003  0.0076]
##########################################################################################



In [2]:
class SO3MLPtensor(nn.EquivariantModule):
    G: group.Group
    gspace: gspaces.GSpace
    layers: List
    
    def __init__(self, key: PRNGKeyArray, n_classes=10):
        keys = jax.random.split(key, 8)
        
        super(SO3MLPtensor, self).__init__()
        
        # the model is equivariant to the group SO(3)
        self.G = group.so3_group()
        
        # since we are building an MLP, there is no base-space
        self.gspace = gspaces.no_base_space(self.G)
        
        # the input contains the coordinates of a point in the 3D space
        in_repr = self.G.standard_representation()
        self.in_type = self.gspace.type(in_repr)
        
        # Layer 1
        # We will use the representation of SO(3) acting on signals over a sphere, bandlimited to frequency 2
        # We use the tensor-product non-linearity, which is essentially a quadratic function.
        
        ttype = self.gspace.type(self.G.bl_sphere_representation(L=2))
        activation1 = nn.TensorProductModule(self.in_type, ttype, key=keys[0])
        
        # First we apply batch-norm and then the non-linearity. 
        # In the next blocks, we will also include a Linear layer.
        block1 = nn.SequentialModule(
            # nn.IIDBatchNorm1d(activation1.in_type),
            activation1,
        )
        
        # Repeat a similar process for a few layers
        
        # input and output types must have the same number of fields (here, 8)
        # the input one shouldn't have frequencies higher than the output of the previous block
        activation2 = nn.TensorProductModule(
            in_type = self.gspace.type(*[self.G.bl_sphere_representation(L=2)]*8),
            out_type = self.gspace.type(*[self.G.bl_sphere_representation(L=3)]*8),
            key=keys[1]    
        )
        block2 = nn.SequentialModule(
            nn.Linear(block1.out_type, activation2.in_type, key=keys[2]),
            # nn.IIDBatchNorm1d(activation2.in_type),
            activation2,
        )
        
        activation3 = nn.TensorProductModule(
            in_type = self.gspace.type(*[self.G.bl_sphere_representation(L=3)]*8),
            out_type = self.gspace.type(*[self.G.bl_sphere_representation(L=3)]*8),
            key=keys[3]   
        )
        block3 = nn.SequentialModule(
            nn.Linear(block2.out_type, activation3.in_type, key=keys[4]),
            # nn.IIDBatchNorm1d(activation3.in_type),
            activation3,
        )
        
        activation4 = nn.TensorProductModule(
            in_type = self.gspace.type(*[self.G.bl_sphere_representation(L=3)]*8),
            out_type = self.gspace.type(*[self.G.irrep(2)]*8),    # the final layer only require frequency 2 features, so there is no point in generating other frequencies
            key=keys[5]
        )
        block4 = nn.SequentialModule(
            nn.Linear(block3.out_type, activation4.in_type, key=keys[6]),
            # nn.IIDBatchNorm1d(activation4.in_type),
            activation4,
        )
        
        # Final linear layer mapping to the output features
        # the output is a 5-dimensional vector transforming according to the Wigner-D matrix of frequency 2
        self.out_type = self.gspace.type(self.G.irrep(2))
        block5 = nn.Linear(block4.out_type, self.out_type, key=keys[7])
        self.layers = [block1, block2, block3, block4, block5]
    
    def __call__(self, x: nn.GeometricTensor):
        
        # check the input has the right type
        assert x.type == self.in_type
        
        # apply each equivariant block
        
        # Each layer has an input and an output type
        # A layer takes a GeometricTensor in input.
        # This tensor needs to be associated with the same representation of the layer's input type
        #
        # The Layer outputs a new GeometricTensor, associated with the layer's output type.
        # As a result, consecutive layers need to have matching input/output types
        for i, layer in enumerate(self.layers):
            x = layer(x)
     
        return x

    def evaluate_output_shape(self, input_shape: tuple):
        shape = list(input_shape)
        assert len(shape) ==2, shape
        assert shape[1] == self.in_type.size, shape
        shape[1] = self.out_type.size
        return shape

Let's build the model

In [3]:
SEED = 5678
key = jax.random.PRNGKey(SEED)
model = SO3MLPtensor(key)

Let's test the equivariance of the model

In [4]:
np.set_printoptions(linewidth=10000, precision=4, suppress=True)

model = model.eval()

B = 6

# generates B random points in 3D and wrap them in a GeometricTensor of the right type
x = jax.lax.stop_gradient(jax.random.normal(key, (B, 3)))
x = model.in_type(x)

print('##########################################################################################')
y = model(x)
print("Outputs' magnitudes")
print(jnp.linalg.norm(y.tensor, axis=1).reshape(-1))
print('##########################################################################################')
print("Errors' magnitudes")
for r in range(8):
    # sample a random rotation
    g = model.G.sample()
    
    x_transformed = g @ x
    x_transformed = x_transformed

    y_transformed = model(x_transformed)
    
    # verify that f(g@x) = g@f(x)=g@y
    print(jnp.linalg.norm(y_transformed.tensor - (g@y).tensor, axis=1).reshape(-1))        

print('##########################################################################################')
print()

##########################################################################################
Outputs' magnitudes
[0.1713 0.1712 0.0006 2.4142 0.     0.3532]
##########################################################################################
Errors' magnitudes
[0. 0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0. 0.]
##########################################################################################



In [8]:
class SO2MLP(nn.EquivariantModule):
    G: group.Group
    gspace: gspaces.GSpace
    layers: List
    
    def __init__(self, key: PRNGKeyArray, n_classes=10):
        keys = jax.random.split(key, 8)
        
        super(SO2MLP, self).__init__()
        
        # the model is equivariant to the group O(2)
        self.G = group.so2_group()
        
        # since we are building an MLP, there is no base-space
        self.gspace = gspaces.no_base_space(self.G)
        
        # the input contains the coordinates of a point in the 2D space
        self.in_type = self.gspace.type(self.G.standard_representation())
        
        # Layer 1
        # We will use the regular representation of SO(2) acting on signals over SO(2) itself, bandlimited to frequency 1
        # Most of the comments on the previous SO(3) network apply here as well
       
        activation1 = nn.FourierELU(
            self.gspace,
            channels=3, # specify the number of signals in the output features
            irreps=self.G.bl_regular_representation(L=1).irreps, # include all frequencies up to L=1
            # inplace=True,
            # the following kwargs are used to build a discretization of the circle containing 6 equally distributed points
            type='regular', N=6,   
        )
        
        # map with an equivariant Linear layer to the input expected by the activation function, apply batchnorm and finally the activation
        block1 = nn.SequentialModule(
            nn.Linear(self.in_type, activation1.in_type, key=keys[0]),
            # nn.IIDBatchNorm1d(activation1.in_type),
            activation1,
        )
        
        # Repeat a similar process for a few layers
        
        # 8 signals, bandlimited up to frequency 3
        activation2 = nn.FourierELU(
            self.gspace,
            channels=8, # specify the number of signals in the output features
            irreps=self.G.bl_regular_representation(L=3).irreps, # include all frequencies up to L=3
            # inplace=True,
            # the following kwargs are used to build a discretization of the circle containing 16 equally distributed points
            type='regular', N=16,
        )
        block2 = nn.SequentialModule(
            nn.Linear(block1.out_type, activation2.in_type, key=keys[1]),
            # nn.IIDBatchNorm1d(activation2.in_type),
            activation2,
        )
        
        # 8 signals, bandlimited up to frequency 3
        activation3 = nn.FourierELU(
            self.gspace,
            channels=8, # specify the number of signals in the output features
            irreps=self.G.bl_regular_representation(L=3).irreps, # include all frequencies up to L=3
            # inplace=True,
            # the following kwargs are used to build a discretization of the circle containing 16 equally distributed points
            type='regular', N=16,
        )
        block3 = nn.SequentialModule(
            nn.Linear(block2.out_type, activation3.in_type, key=keys[2]),
            # nn.IIDBatchNorm1d(activation3.in_type),
            activation3,
        )
        
        # 5 signals, bandlimited up to frequency 2
        activation4 = nn.FourierELU(
            self.gspace,
            channels=5, # specify the number of signals in the output features
            irreps=self.G.bl_regular_representation(L=2).irreps, # include all frequencies up to L=2
            # inplace=True,
            # the following kwargs are used to build a discretization of the circle containing 12 equally distributed points
            type='regular', N=12,
        )
        block4 = nn.SequentialModule(
            nn.Linear(block3.out_type, activation4.in_type, key=keys[3]),
            # nn.IIDBatchNorm1d(activation4.in_type),
            activation4,
        )
        
        # Final linear layer mapping to the output features
        # the output is a 2-dimensional vector rotating with frequency 2
        self.out_type = self.gspace.type(self.G.irrep(2))
        block5 = nn.Linear(block4.out_type, self.out_type, key=keys[4])

        self.layers = [block1, block2, block3, block4, block5]
    
    def __call__(self, x: nn.GeometricTensor):
        
        # check the input has the right type
        assert x.type == self.in_type
        
        # apply each equivariant block
        
        # Each layer has an input and an output type
        # A layer takes a GeometricTensor in input.
        # This tensor needs to be associated with the same representation of the layer's input type
        #
        # The Layer outputs a new GeometricTensor, associated with the layer's output type.
        # As a result, consecutive layers need to have matching input/output types
        for i, layer in enumerate(self.layers):
            x = layer(x)
     
        return x
    
    def evaluate_output_shape(self, input_shape: tuple):
        shape = list(input_shape)
        assert len(shape) ==2, shape
        assert shape[1] == self.in_type.size, shape
        shape[1] = self.out_type.size
        return shape

Let's build the model

In [9]:
SEED = 5678
key = jax.random.PRNGKey(SEED)
model = SO2MLP(key)

Let's test the equivariance of the model

In [12]:
np.set_printoptions(linewidth=10000, precision=4, suppress=True)

model = model.eval()

B = 6

# generates B random points in 3D and wrap them in a GeometricTensor of the right type
x = jax.lax.stop_gradient(jax.random.normal(key, (B, 2)))
x = model.in_type(x)

print('##########################################################################################')
y = model(x)
print("Outputs' magnitudes")
print(jnp.linalg.norm(y.tensor, axis=1).reshape(-1))
print('##########################################################################################')
print("Errors' magnitudes")
for r in range(8):
    # sample a random rotation
    g = model.G.sample()
    
    x_transformed = g @ x
    x_transformed = x_transformed

    y_transformed = model(x_transformed)
    
    # verify that f(g@x) = g@f(x)=g@y
    print(jnp.linalg.norm(y_transformed.tensor - (g@y).tensor, axis=1).reshape(-1))        

print('##########################################################################################')
print()

##########################################################################################
Outputs' magnitudes
[0.0454 0.0744 0.0052 0.0691 0.1648 0.1498]
##########################################################################################
Errors' magnitudes
[0.     0.0006 0.     0.0003 0.0015 0.001 ]
[0.0001 0.0003 0.     0.0003 0.0008 0.0011]
[0.0004 0.0011 0.     0.0003 0.0056 0.0033]
[0.0005 0.0012 0.     0.0006 0.0066 0.0037]
[0.0004 0.0011 0.     0.0004 0.0054 0.0029]
[0.0005 0.0013 0.     0.0004 0.0067 0.0039]
[0.0003 0.0012 0.     0.0003 0.0052 0.0031]
[0.0005 0.0012 0.     0.0007 0.0066 0.0031]
##########################################################################################

