# Gauge Equivariant Convolutional Networks and the Icosahedral CNN
Authors: Zheng Xu
## 1. Introduction and Motivation

Deep learning, driven largely by intuition-guided experimentation, has achieved great success, but we still cannot explain why and when certain architectures work well. As a result, every new application of deep learning requires an extensive, labour- and energy-intensive architecture search. Researchers therefore look for general principles to guide architecture search, and one successful design principle states that network architectures should be equivariant to symmetries.

Equivariant networks have been developed for sets, graphs, and homogeneous spaces such as the sphere. In each case, the network is made equivariant to the global symmetries of the underlying space. However, manifolds do not in general have global symmetries, so it is not obvious how to construct equivariant CNNs for them.

General manifolds do, however, have local gauge symmetries, and respecting them is crucial for building manifold CNNs that depend only on intrinsic geometry. In this paper, the authors define a convolution-like operation on general manifolds $M$ that is equivariant to local gauge transformations. Its inputs and outputs are feature fields. Each field is represented by a number of feature maps whose activations are the coefficients of a geometrical object (such as a tangent vector) relative to a choice of local frame, or gauge. If the gauge is changed, the coefficients change in a predictable way so as to preserve their geometrical meaning. The search for a geometrically natural definition of "manifold convolution" is thereby reduced to the requirement of gauge equivariance.

Furthermore, the paper applies gauge equivariant networks to one specific manifold: the icosahedron. This manifold has some global symmetries (discrete rotations), and its regularity and local flatness allow for a very efficient implementation using existing deep learning primitives (conv2d).

## 2. Related work
### Equivariant Deep Learning 
Equivariant networks have been proposed for sets and graphs (permutation equivariance) and for translations and rotations of the plane and of 3D space. In this paper, the authors generalize G-CNNs from homogeneous spaces to general manifolds.
### Geometric Deep Learning
Geometric deep learning is concerned with the generalization of (convolutional) neural networks to manifolds. Some existing manifold convolutions are gauge equivariant. However, these methods are all limited to particular feature types $\rho$ (typically scalar), and/or use a kernel parameterization that is not maximally flexible.

### Spherical CNNs
The Icosahedral CNN can be viewed as a fast and simple alternative to Spherical CNNs. It uses a spherical grid based on a subdivision of the icosahedron and convolves over it.

## 3. Implementation of the Icosahedral CNN
First, the article introduces gauges, gauge transformations, and the exponential map. A gauge is defined as a position-dependent invertible linear map $w_{p}: \mathbb{R}^{d} \rightarrow T_{p} M$, where $T_{p} M$ is the tangent space of $M$ at $p$. A gauge transformation is a position-dependent change of frame, which can be described by maps $g_{p} \in \mathrm{GL}(d, \mathbb{R})$ (the group of invertible $d \times d$ matrices). The exponential map gives a convenient parameterization of the local neighbourhood of $p \in M$. This map $\exp_{p}: T_{p} M \rightarrow M$ takes a tangent vector $V \in T_{p} M$ and follows the geodesic (locally shortest curve) in the direction of $V$ with speed $\|V\|$ for one unit of time, arriving at the point $q=\exp_{p} V$.
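As a concrete illustration (our own example, not taken from the article), the exponential map of the unit sphere $S^2$ has the closed form $\exp_p V = \cos(\|V\|)\,p + \sin(\|V\|)\,V/\|V\|$. The NumPy sketch below implements it. It also checks that a gauge transformation $w_p \mapsto w_p g_p$, $v \mapsto g_p^{-1} v$ changes only the coefficients $v$. The tangent vector $w_p(v)$ and the point $q = \exp_p w_p(v)$ stay the same. The helpers `tangent_frame` and `rotation` are hypothetical and exist only for illustration.

```python
import numpy as np

def sphere_exp(p: np.ndarray, V: np.ndarray) -> np.ndarray:
    """Exponential map of the unit sphere at p, applied to a tangent vector V (V . p = 0)."""
    norm = np.linalg.norm(V)
    if norm < 1e-12:
        return p.copy()  # exp_p(0) = p
    # Follow the great circle through p in direction V for arc length ||V||
    return np.cos(norm) * p + np.sin(norm) * (V / norm)

def tangent_frame(p: np.ndarray) -> np.ndarray:
    """Hypothetical gauge w_p: R^2 -> T_p S^2, as a 3x2 matrix with orthonormal tangent columns."""
    a = np.array([1.0, 0.0, 0.0]) if abs(p[0]) < 0.9 else np.array([0.0, 1.0, 0.0])
    e1 = a - a.dot(p) * p
    e1 /= np.linalg.norm(e1)
    e2 = np.cross(p, e1)
    return np.stack([e1, e2], axis=1)

def rotation(theta: float) -> np.ndarray:
    """A gauge transformation g_p in SO(2)."""
    c, s = np.cos(theta), np.sin(theta)
    return np.array([[c, -s], [s, c]])

p = np.array([0.0, 0.0, 1.0])          # north pole
w_p = tangent_frame(p)                 # gauge at p
v = np.array([np.pi / 2, 0.0])         # coefficients of a tangent vector in this gauge
q = sphere_exp(p, w_p @ v)             # a quarter great circle away: a point on the equator

# Change of gauge: w_p -> w_p g, v -> g^{-1} v describes the same tangent vector, hence the same q
g = rotation(0.7)
q_new = sphere_exp(p, (w_p @ g) @ (np.linalg.inv(g) @ v))
```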

The next part defines gauge equivariant convolution, beginning with scalar input and output fields. A filter is a locally supported function $K: \mathbb{R}^{d} \rightarrow \mathbb{R}$, where $\mathbb{R}^{d}$ may be identified with $T_{p} M$ via the gauge $w_{p}$. Writing $q_{v}=\exp_{p} w_{p}(v)$ for $v \in \mathbb{R}^{d}$, the scalar convolution of $K$ and $f: M \rightarrow \mathbb{R}$ at $p$ is defined as:
$$
(K \star f)(p)=\int_{\mathbb{R}^{d}} K(v) f\left(q_{v}\right) d v
$$
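For scalar fields, the gauge transformation rules given further on imply that $(K \star f)(p)$ is gauge invariant only if $K(g^{-1}v) = K(v)$ for all $g \in G$. In other words, the filter must be isotropic. Below is a minimal NumPy sketch under two assumptions. First, it works on a flat square grid, where $\exp_p$ is simply $p + v$. Second, it uses 90° rotations as a stand-in for the structure group. The isotropic filter gives the same result in both gauges and the anisotropic one does not. All names and filter values are hypothetical.

```python
import numpy as np

def rot90_int(k: int) -> np.ndarray:
    """Integer matrix of a rotation by k * 90 degrees, used here as a gauge transformation g_p."""
    c, s = [(1, 0), (0, 1), (-1, 0), (0, -1)][k % 4]
    return np.array([[c, -s], [s, c]])

def scalar_conv_at(f: np.ndarray, K: dict, p: tuple, w_p: np.ndarray) -> float:
    """Discretised (K * f)(p) = sum_v K(v) f(q_v), with q_v = p + w_p v on a flat grid.

    K maps integer offsets v, given in the coordinates of the gauge w_p, to filter weights."""
    total = 0.0
    for v, k_v in K.items():
        q = np.asarray(p) + w_p @ np.asarray(v)
        total += k_v * f[q[0], q[1]]
    return total

rng = np.random.default_rng(0)
f = rng.standard_normal((5, 5))                  # a scalar field sampled on a 5x5 grid
p = (2, 2)
identity_gauge = np.eye(2, dtype=int)
rotated_gauge = identity_gauge @ rot90_int(1)    # w_p -> w_p g_p

K_iso = {(0, 0): 1.0, (1, 0): 0.5, (-1, 0): 0.5, (0, 1): 0.5, (0, -1): 0.5}
K_aniso = {(0, 0): 1.0, (1, 0): 2.0, (-1, 0): -1.0, (0, 1): 0.5, (0, -1): 0.0}

out_iso_a = scalar_conv_at(f, K_iso, p, identity_gauge)
out_iso_b = scalar_conv_at(f, K_iso, p, rotated_gauge)       # same value as out_iso_a
out_aniso_a = scalar_conv_at(f, K_aniso, p, identity_gauge)
out_aniso_b = scalar_conv_at(f, K_aniso, p, rotated_gauge)   # generally differs
```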

In the general case, the transformation behaviour of a $C$-dimensional geometrical quantity is described by a representation of the structure group $G$. This is a map $\rho: G \rightarrow \operatorname{GL}(C, \mathbb{R})$ that satisfies $\rho(g h)=\rho(g) \rho(h)$. For general fields, consider a stack of $C_{\text{in}}$ input feature maps on $M$ representing a $C_{\text{in}}$-dimensional $\rho_{\text{in}}$-field, and define a convolution that outputs a $C_{\text{out}}$-dimensional $\rho_{\text{out}}$-field. The filter bank is then a matrix-valued kernel $K: \mathbb{R}^{d} \rightarrow \mathbb{R}^{C_{\text{out}} \times C_{\text{in}}}$. Because $f(q_{v})$ is expressed in the gauge at $q_{v}$, it must first be parallel transported to $p$ before it can be combined with features at $p$. Here $g_{p \leftarrow q_{v}} \in G$ denotes this transporter, expressed relative to the gauges at $q_{v}$ and $p$. This gives the generalized convolution for general fields:
$$
(K \star f)(p)=\int_{\mathbb{R}^{d}} K(v) \rho_{\text {in }}\left(g_{p \leftarrow q_{v}}\right) f\left(q_{v}\right) d v .
$$

Under a gauge transformation:
$$
\begin{aligned}
v & \mapsto g_{p}^{-1} v, & f\left(q_{v}\right) & \mapsto \rho_{\text {in }}\left(g_{q_{v}}^{-1}\right) f\left(q_{v}\right), \\
w_{p} & \mapsto w_{p} g_{p}, & g_{p \leftarrow q_{v}} & \mapsto g_{p}^{-1} g_{p \leftarrow q_{v}} g_{q_{v}}
\end{aligned}
$$
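The following is a quick numeric check of these rules (our own illustration). It uses tangent-vector features with $G = \mathrm{SO}(2)$, where $\rho$ maps an angle to a $2\times 2$ rotation matrix and composing group elements adds their angles. After the gauge transformation, the transported feature $\rho_{\text{in}}(g_{p \leftarrow q_v}) f(q_v)$ picks up only the factor $\rho_{\text{in}}(g_p^{-1})$, because the change of gauge at $q_v$ cancels. This is why the resulting equivariance condition involves only the gauge at $p$.

```python
import numpy as np

def rho(theta: float) -> np.ndarray:
    """SO(2) acting on tangent-vector coefficients: a 2x2 rotation matrix."""
    c, s = np.cos(theta), np.sin(theta)
    return np.array([[c, -s], [s, c]])

rng = np.random.default_rng(1)
g_p, g_q, g_pq = rng.uniform(0.0, 2 * np.pi, size=3)   # angles of g_p, g_{q_v}, g_{p <- q_v}
f_q = rng.standard_normal(2)                            # f(q_v) in the old gauge

# Old gauge: transport f(q_v) to p
transported_old = rho(g_pq) @ f_q

# New gauge: f(q_v) -> rho(g_q^{-1}) f(q_v) and g_{p<-q_v} -> g_p^{-1} g_{p<-q_v} g_{q_v}
# (SO(2) is abelian, so composing rotations adds their angles)
f_q_new = rho(-g_q) @ f_q
transported_new = rho(-g_p + g_pq + g_q) @ f_q_new

# The gauge change at q_v cancels: only rho(g_p^{-1}) remains
assert np.allclose(transported_new, rho(-g_p) @ transported_old)
```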

Gauge equivariant convolution on the icosahedron is implemented in three steps: G-padding, kernel expansion, and 2D convolution (or HexaConv). In a standard CNN, if the output is to have the same size as the input, the input is zero-padded. Here, instead of zero padding, G-padding copies the pixels from the neighbouring chart. Neighbouring charts use different gauges, so a copied feature vector must be re-expressed in the gauge of the chart being padded. The transformation $g_{ij}(p)$ acts on the feature vector at $p$ via the matrix $\rho(g_{ij}(p))$, where $\rho$ is the representation of $G=C_{6}$ associated with the feature space under consideration. Only two kinds of representations are considered in this work: scalar features with $\rho(g)=1$, and regular features with $\rho$ equal to the regular representation:

$$
\rho(2 \pi / 6)=\left[\begin{array}{llllll}
0 & 1 & 0 & 0 & 0 & 0 \\
0 & 0 & 1 & 0 & 0 & 0 \\
0 & 0 & 0 & 1 & 0 & 0 \\
0 & 0 & 0 & 0 & 1 & 0 \\
0 & 0 & 0 & 0 & 0 & 1 \\
1 & 0 & 0 & 0 & 0 & 0
\end{array}\right]
$$

That is, a cyclic permutation of 6 elements. Since $2 \pi / 6$ is a generator of $C_{6}$, the value of $\rho$ at the other group elements is determined by this matrix: $\rho(k \cdot 2 \pi / 6)=\rho(2 \pi / 6)^{k}$. If the feature vector consists of multiple scalar or regular features, we would have a block-diagonal matrix $\rho\left(g_{i j}(p)\right)$.
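The regular representation can be built and checked in a few lines of NumPy. This sketch (the function names are ours) constructs $\rho(k \cdot 2\pi/6)$ as the $k$-th power of the matrix above and verifies the homomorphism property $\rho(gh) = \rho(g)\rho(h)$. It also assembles the block-diagonal matrix for a feature vector made of several scalar and regular fields.

```python
import numpy as np

def regular_rep(k: int, n: int = 6) -> np.ndarray:
    """rho(k * 2*pi/n) for the regular representation of C_n, as an n x n permutation matrix."""
    P = np.roll(np.eye(n, dtype=int), 1, axis=1)  # P[i, (i + 1) % n] = 1: the matrix rho(2*pi/6) above
    return np.linalg.matrix_power(P, k % n)       # rho(k * 2*pi/n) = rho(2*pi/n)^k

def scalar_rep(k: int) -> np.ndarray:
    """rho(g) = 1 for scalar features."""
    return np.ones((1, 1), dtype=int)

def field_rep(k: int, field_types: list) -> np.ndarray:
    """Block-diagonal rho(g) for a feature vector made of several scalar/regular fields."""
    blocks = [scalar_rep(k) if t == "scalar" else regular_rep(k) for t in field_types]
    size = sum(b.shape[0] for b in blocks)
    out = np.zeros((size, size), dtype=int)
    start = 0
    for b in blocks:
        n = b.shape[0]
        out[start:start + n, start:start + n] = b  # place each field's block on the diagonal
        start += n
    return out
```

Acting on a regular feature (six orientation channels), $\rho(2\pi/6)$ is a cyclic shift: `regular_rep(1) @ np.arange(6)` gives `[1, 2, 3, 4, 5, 0]`.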

The article implements G-padding through indexing operations on the feature maps. For each position $p$ to be padded, the authors precompute $g_{ij}(p)$, which can be $+1$, $0$, or $-1$ in units of $2 \pi / 6$. These values are used to precompute four indexing operations (for the top, bottom, left, and right sides of the charts).
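For both scalar and regular features, $\rho(g_{ij}(p))$ is a permutation matrix. This means the copy from the neighbouring chart and the gauge transformation can be folded into a single precomputed gather index. `PadIco` in the code later in this notebook does the same thing with `.roll` on the orientation axis. Below is a toy NumPy sketch with two charts, six orientation channels, and a hypothetical transition $g_{ij}(p) = +1 \cdot (2\pi/6)$ on the right edge of chart 0.

```python
import numpy as np

R, n_charts, W = 6, 2, 4                      # toy sizes: 6 orientation channels, 2 charts of width 4
x = np.arange(R * n_charts * W, dtype=float).reshape(R, n_charts, W)

# Precomputed gather index over the flattened input: chart 0 gets one extra column on the right,
# copied from the first column of chart 1 (a hypothetical neighbour relation)
idx = np.arange(x.size).reshape(R, n_charts, W)
idx_out = np.empty((R, W + 1), dtype=int)
idx_out[:, :W] = idx[:, 0, :]                 # interior: chart 0 itself
# Hypothetical transition of +1 step of 2*pi/6: rho(g) cyclically shifts the orientation axis,
# so it is folded into the index by rolling the channel axis
idx_out[:, W] = np.roll(idx[:, 1, 0], -1)

padded = x.reshape(-1)[idx_out]               # one gather = copy + gauge transformation

# Explicit version of the padded column: rho(2*pi/6) applied to the neighbour's feature vector
P = np.roll(np.eye(R), 1, axis=1)            # the regular-representation matrix above
expected_col = P @ x[:, 1, 0]
```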

For the convolution to be gauge equivariant, the kernel must satisfy the constraint
$$
K\left(g^{-1} v\right)=\rho_{\text{out}}\left(g^{-1}\right) K(v) \rho_{\text{in}}(g) \quad \text{for all } g \in G .
$$
On the icosahedron, the kernel $K: \mathbb{R}^{2} \rightarrow \mathbb{R}^{R_{\text{out}} C_{\text{out}} \times R_{\text{in}} C_{\text{in}}}$, where $R_{\text{in}}, R_{\text{out}} \in \{1, 6\}$ are the dimensions of the input and output representations, is stored in an array of shape $(R_{\text{out}} C_{\text{out}}, R_{\text{in}} C_{\text{in}}, 3, 3)$. The top-right and bottom-left pixel of each $3 \times 3$ filter are fixed at zero so that it corresponds to a 1-ring hexagonal kernel.
Weight sharing can be implemented by constructing a basis of kernels, each of shape $(R_{\text{out}}, R_{\text{in}}, 3, 3)$, that has value 1 at all entries sharing the same weight (drawn in the same colour/shade in the original paper) and 0 elsewhere. The full kernel is then a linear combination of these basis filters with learned weights (one for each of the $C_{\text{in}} \cdot C_{\text{out}}$ input/output channel pairs and each basis kernel).
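The weight sharing can be derived mechanically from the kernel constraint. For regular input and output fields ($R_{\text{in}} = R_{\text{out}} = 6$), $\rho_{\text{out}}(g^{-1}) K(v) \rho_{\text{in}}(g)$ simply shifts both orientation indices of $K(v)$. The constraint therefore ties entry $(a, b, v)$ to $(a+k, b+k, r^{-k} v)$, where $r$ is the rotation by $2\pi/6$. Every orbit has 6 elements, which gives $6 \cdot 6 \cdot 7 / 6 = 42$ free weights per channel pair, matching the `(Cout, Cin, Rin, 7)` weight shape in `ConvIco`. The NumPy sketch below is our own construction, and its names are hypothetical. It assumes the six ring taps are numbered in cyclic order, with a 3×3 layout we read off the `idx_k` table in `ConvIco`.

```python
import numpy as np

# Hypothetical layout: position 0 is the centre of the hexagonal kernel, positions 1..6 are its
# ring neighbours in cyclic order, so a rotation by 2*pi/6 maps ring position j -> j+1 (mod 6).
# (row, col) of each position in the 3x3 array; (0, 2) and (2, 0) stay zero.
RING_TO_3X3 = [(1, 1), (2, 1), (2, 2), (1, 2), (0, 1), (0, 0), (1, 0)]

def rot(pos: int, k: int) -> int:
    """Apply a rotation by k * 2*pi/6 to a kernel position."""
    return 0 if pos == 0 else (pos - 1 + k) % 6 + 1

def expand_regular(w: np.ndarray) -> np.ndarray:
    """Expand free weights w of shape (6, 7) into a regular->regular kernel K of shape (6, 6, 7).

    Each orbit {(a + k, b + k, rot(pos, -k))} is labelled by ((b - a) mod 6, rot(pos, a))."""
    K = np.empty((6, 6, 7))
    for a in range(6):
        for b in range(6):
            for pos in range(7):
                K[a, b, pos] = w[(b - a) % 6, rot(pos, a)]
    return K

def to_3x3(K: np.ndarray) -> np.ndarray:
    """Place the 7 hexagonal taps into (R_out, R_in, 3, 3) arrays with two corners fixed at zero."""
    out = np.zeros(K.shape[:2] + (3, 3))
    for pos, (i, j) in enumerate(RING_TO_3X3):
        out[..., i, j] = K[..., pos]
    return out

def satisfies_constraint(K: np.ndarray) -> bool:
    """Check K(g^{-1} v) = rho_out(g^{-1}) K(v) rho_in(g) for every g in C_6 and every tap v."""
    P = np.roll(np.eye(6), 1, axis=1)  # regular representation of the generator 2*pi/6
    for k in range(6):
        Pk = np.linalg.matrix_power(P, k)
        Pk_inv = np.linalg.matrix_power(P, (6 - k) % 6)
        for pos in range(7):
            if not np.allclose(K[:, :, rot(pos, -k)], Pk_inv @ K[:, :, pos] @ Pk):
                return False
    return True

rng = np.random.default_rng(0)
K = expand_regular(rng.standard_normal((6, 7)))
```

For scalar input and regular output, the same orbit count gives $6 \cdot 7 / 6 = 7$ free weights per channel pair, consistent with `Rin = 1` in the weight shape.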

## 4. Algorithm
The whole layer can be written as:
$$
\operatorname{GConv}(f, w)=\operatorname{conv2d}(\operatorname{GPad}(f), \operatorname{expand}(w))
$$
Here $f$ has shape $(B, C_{\text{in}} R_{\text{in}}, 5H, W)$, with the five charts stacked along the height axis. $\operatorname{GPad}(f)$ has the same layout with a one-pixel border added to every chart, as in `PadIco` below, i.e. shape $(B, C_{\text{in}} R_{\text{in}}, 5(H+2), W+2)$. The weights $w$ have shape $(C_{\text{out}}, C_{\text{in}} R_{\text{in}}, 7)$, and $\operatorname{expand}(w)$ has shape $(C_{\text{out}} R_{\text{out}}, C_{\text{in}} R_{\text{in}}, 3, 3)$. After the padding border is cropped, the output of GConv has shape $(B, C_{\text{out}} R_{\text{out}}, 5H, W)$.
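The last step is an ordinary `conv2d`. The NumPy sketch below covers that step alone. It assumes `GPad` and `expand` have already produced a padded input of shape $(B, C_{\text{in}} R_{\text{in}}, 5, H+2, W+2)$ and a kernel of shape $(C_{\text{out}} R_{\text{out}}, C_{\text{in}} R_{\text{in}}, 3, 3)$. It runs a 3×3 "valid" cross-correlation (which is what `conv2d` computes) on each chart, so no separate cropping is needed. Sizes and data are made up.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

def conv2d_on_padded_charts(f_pad: np.ndarray, kernel: np.ndarray) -> np.ndarray:
    """3x3 'valid' cross-correlation, applied chart by chart.

    f_pad:  (B, C_in*R_in, 5, H+2, W+2)     G-padded input, one-pixel border per chart
    kernel: (C_out*R_out, C_in*R_in, 3, 3)  expanded kernel
    returns (B, C_out*R_out, 5*H, W)        charts stacked along the height axis
    """
    windows = sliding_window_view(f_pad, (3, 3), axis=(-2, -1))   # (B, Ci, 5, H, W, 3, 3)
    y = np.einsum("bicxyuv,oiuv->bocxy", windows, kernel)        # (B, Co, 5, H, W)
    B, Co, n_charts, H, W = y.shape
    return y.reshape(B, Co, n_charts * H, W)

# Toy sizes (hypothetical): r = 1, so H = 2 and W = 4; C_in*R_in = 6, C_out*R_out = 12
rng = np.random.default_rng(0)
B, Ci, Co, H, W = 2, 6, 12, 2, 4
f_pad = rng.standard_normal((B, Ci, 5, H + 2, W + 2))
kernel = rng.standard_normal((Co, Ci, 3, 3))
y = conv2d_on_padded_charts(f_pad, kernel)
```

`ConvIco` below instead runs `torch.nn.functional.conv2d` with `padding=(1, 1)` on the stacked charts and then crops each chart's border. Because every 3×3 window around an interior pixel stays inside its own padded chart, this gives the same values.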


```shell
pip install geomstats einops
pip install https://github.com/DavidDiazGuerra/icoCNN/zipball/master
```

```python
import os
import numpy as np
import matplotlib.pyplot as plt
import geomstats.backend as gs
import geomstats.visualization as visualization

# Absolute imports: the relative form (`from .icoCNN import *`) only works
# inside the icoCNN package itself, not in a notebook
from icoCNN.icoCNN import *
from icoCNN.icoGrid import icosahedral_grid_coordinates
import icoCNN.tools
import icoCNN.plots
```

```python
import torch
import einops
from math import sqrt

__all__ = ["CleanVertices", "SmoothVertices", "PadIco", "ConvIco", "PoolIco", "LNormIco"]


class CleanVertices(torch.nn.Module):
    """ Sets to zero the two pixels of each chart that lie on vertices of the icosahedron.
    r : Resolution of the input icosahedral signal
    Input : [..., 5, 2^r, 2^(r+1)]
    Output : [..., 5, 2^r, 2^(r+1)]
    """
    def __init__(self, r):
        super().__init__()
        self.register_buffer('mask', torch.ones((2**r, 2**(r+1))))
        self.mask[0, 0] = 0
        self.mask[0, 2**r] = 0

    def forward(self, x):
        return x * self.mask


class SmoothVertices(torch.nn.Module):
    """ Replaces the vertex pixels by the mean of their 5 neighbouring pixels
    (the mean is also taken over the axis just before the charts axis). """
    def __init__(self, r):
        super().__init__()
        self.r = r
        self.clear_vertices = CleanVertices(r)
        # (chart, row, col) of the 5 neighbours of the vertices at [0, 0] and [0, 2^r] of every chart
        self.v1_neighbors = torch.LongTensor([[[chart, 1, 0],
                                                [chart, 1, 1],
                                                [chart, 0, 1],
                                                [chart-1, -1, 2**r],
                                                [chart-1, -1, 2**r-1]] for chart in range(5)])
        self.v2_neighbors = torch.LongTensor([[[chart, 1, 2**r],
                                                [chart, 1, 2**r + 1],
                                                [chart, 0, 2**r + 1],
                                                [chart-1, -1, -1],
                                                [chart, 0, 2**r-1]] for chart in range(5)])

    def forward(self, x):
        x = self.clear_vertices(x)
        x[..., 0, 0] += einops.reduce(x[...,
                                        self.v1_neighbors[..., 0],
                                        self.v1_neighbors[..., 1],
                                        self.v1_neighbors[..., 2]],
                                        '... R charts neighbors -> ... 1 charts', 'mean')
        x[..., 0, 2**self.r] += einops.reduce(x[...,
                                                self.v2_neighbors[..., 0],
                                                self.v2_neighbors[..., 1],
                                                self.v2_neighbors[..., 2]],
                                                '... R charts neighbors -> ... 1 charts', 'mean')
        return x


class PadIco(torch.nn.Module):
    """ G-padding: adds a one-pixel border to every chart, copied from the neighbouring charts.
    r : int
    R : int, 1 or 6
        6 when the input signal includes the 6 kernel orientation channels or 1 if it doesn't
    Input : [..., R, 5, 2^r, 2^(r+1)]
    Output : [..., R, 5, 2^r+2, 2^(r+1)+2]
    """
    def __init__(self, r, R, smooth_vertices=False, preserve_vertices=False):
        super().__init__()
        assert R == 1 or R == 6
        self.R = R
        self.r = r
        self.H = 2**r
        self.W = 2**(r+1)
        self.smooth_vertices = smooth_vertices
        if not preserve_vertices:
            self.process_vertices = SmoothVertices(r) if smooth_vertices else CleanVertices(r)
        else:
            assert not smooth_vertices
            self.process_vertices = lambda x: x
        # Precomputed gather index: roll(..., -3) selects a neighbouring chart, and roll(..., -4)
        # cyclically shifts the R orientation channels, i.e. applies rho(g_ij) to regular
        # features (a no-op when R == 1)
        idx_in = torch.arange(R * 5 * self.H * self.W, dtype=torch.long).reshape(R, 5, self.H, self.W)
        idx_out = torch.zeros((R, 5, self.H + 2, self.W + 2), dtype=torch.long)
        idx_out[..., 1:-1, 1:-1] = idx_in
        idx_out[..., 0, 1:2 ** r + 1] = idx_in.roll(1, -3)[..., -1, 2 ** r:]
        idx_out[..., 0, 2 ** r + 1:-1] = idx_in.roll(1, -3).roll(-1, -4)[..., :, -1].flip(-1)
        idx_out[..., -1, 2:2 ** r + 2] = idx_in.roll(-1, -3).roll(-1, -4)[..., :, 0].flip(-1)
        idx_out[..., -1, 2 ** r + 1:-1] = idx_in.roll(-1, -3)[..., 0, 0:2 ** r]
        idx_out[..., 1:-1, 0] = idx_in.roll(1, -3).roll(1, -4)[..., -1, 0:2 ** r].flip(-1)
        idx_out[..., 2:, -1] = idx_in.roll(-1, -3).roll(1, -4)[..., 0, 2 ** r:].flip(-1)
        self.reorder_idx = idx_out

    def forward(self, x):
        x = self.process_vertices(x)
        if self.smooth_vertices:
            smooth_north_pole = einops.reduce(x[..., -1, 0], '... R charts -> ... 1 1', 'mean')
            smooth_south_pole = einops.reduce(x[..., 0, -1], '... R charts -> ... 1 1', 'mean')
        x = einops.rearrange(x, '... R charts H W -> ... (R charts H W)', R=self.R, charts=5, H=self.H, W=self.W)
        y = x[..., self.reorder_idx]
        if self.smooth_vertices:
            y[..., -1, 1] = smooth_north_pole
            y[..., 1, -1] = smooth_south_pole

        return y


class ConvIco(torch.nn.Module):
    """ Gauge equivariant convolution on the icosahedron: G-padding, kernel expansion, conv2d. """
    def __init__(self, r, Cin, Cout, Rin, Rout=6, bias=True, smooth_vertices=False):
        super().__init__()
        assert Rin == 1 or Rin == 6
        self.r = r
        self.Cin = Cin
        self.Cout = Cout
        self.Rin = Rin
        self.Rout = Rout

        self.process_vertices = SmoothVertices(r) if smooth_vertices else CleanVertices(r)
        self.padding = PadIco(r, Rin, smooth_vertices=smooth_vertices)

        # He-style initialisation with fan-in 3 * 3 * Cin * Rin
        s = sqrt(2 / (3 * 3 * Cin * Rin))
        self.weight = torch.nn.Parameter(s * torch.randn((Cout, Cin, Rin, 7)))
        if bias:
            self.bias = torch.nn.Parameter(torch.zeros(Cout))
        else:
            self.register_parameter('bias', None)

        # For every entry of the expanded (Cout, Rout, Cin, Rin, 3x3) kernel, store which weight it
        # copies: (output channel, input channel, rotated input orientation, hexagonal tap)
        self.kernel_expansion_idx = torch.zeros((Cout, Rout, Cin, Rin, 9, 4), dtype=torch.long)
        self.kernel_expansion_idx[..., 0] = torch.arange(Cout).reshape((Cout, 1, 1, 1, 1))
        self.kernel_expansion_idx[..., 1] = torch.arange(Cin).reshape((1, 1, Cin, 1, 1))
        idx_r = torch.arange(0, Rin)
        # Tap index of each 3x3 pixel for each of the 6 kernel rotations (-1: corner, zeroed later)
        idx_k = torch.Tensor(((5, 4, -1, 6, 0, 3, -1, 1, 2),
                                (4, 3, -1, 5, 0, 2, -1, 6, 1),
                                (3, 2, -1, 4, 0, 1, -1, 5, 6),
                                (2, 1, -1, 3, 0, 6, -1, 4, 5),
                                (1, 6, -1, 2, 0, 5, -1, 3, 4),
                                (6, 5, -1, 1, 0, 4, -1, 2, 3)))
        for i in range(Rout):
            self.kernel_expansion_idx[:, i, :, :, :, 2] = idx_r.reshape((1, 1, Rin, 1))
            self.kernel_expansion_idx[:, i, :, :, :, 3] = idx_k[i, :]
            idx_r = idx_r.roll(1)

    def extra_repr(self):
        return "r={}, Cin={}, Cout={}, Rin={}, Rout={}, bias={}"\
            .format(self.r, self.Cin, self.Cout, self.Rin, self.Rout, self.bias is not None)

    def get_kernel(self):
        kernel = self.weight[self.kernel_expansion_idx[..., 0],
                                self.kernel_expansion_idx[..., 1],
                                self.kernel_expansion_idx[..., 2],
                                self.kernel_expansion_idx[..., 3]]
        kernel = kernel.reshape((self.Cout, self.Rout, self.Cin, self.Rin, 3, 3))
        # Zero the top-right and bottom-left pixels: a 1-ring hexagonal kernel
        kernel[..., 0, 2] = 0
        kernel[..., 2, 0] = 0
        return kernel

    def forward(self, x):
        x = self.padding(x)
        # Stack the 5 padded charts along the height axis for a single conv2d call
        x = einops.rearrange(x, '... C R charts H W -> ... (C R) (charts H) W', C=self.Cin, R=self.Rin, charts=5)
        if x.ndim == 3:
            x = x.unsqueeze(0)
            remove_batch_size = True
        else:
            remove_batch_size = False
            batch_shape = x.shape[:-3]
            x = x.reshape((-1,) + x.shape[-3:])

        kernel = self.get_kernel()
        kernel = einops.rearrange(kernel, 'Cout Rout Cin Rin Hk Wk -> (Cout Rout) (Cin Rin) Hk Wk', Hk=3, Wk=3)
        bias = einops.repeat(self.bias, 'Cout -> (Cout Rout)', Cout=self.Cout, Rout=self.Rout) \
            if self.bias is not None else None

        y = torch.nn.functional.conv2d(x, kernel, bias, padding=(1, 1))
        y = einops.rearrange(y, '... (C R) (charts H) W -> ... C R charts H W', C=self.Cout, R=self.Rout, charts=5)
        y = y[..., 1:-1, 1:-1]  # crop the G-padding border of every chart
        if remove_batch_size:
            y = y[0, ...]
        else:
            y = y.reshape(batch_shape + y.shape[1:])

        return self.process_vertices(y)


class PoolIco(torch.nn.Module):
    """ Pools each 7-pixel hexagonal neighbourhood, halving the resolution (r -> r-1). """
    def __init__(self, r, R, function=torch.mean, smooth_vertices=False):
        super().__init__()
        self.function = function
        self.padding = PadIco(r, R, smooth_vertices=smooth_vertices)
        self.process_vertices = SmoothVertices(r-1) if smooth_vertices else CleanVertices(r-1)
        self.neighbors = torch.zeros((2**(r-1), 2**r, 7, 2), dtype=torch.long)
        for h in range(self.neighbors.shape[0]):
            for w in range(self.neighbors.shape[1]):
                self.neighbors[h, w, ...] = torch.Tensor([[1+2*h,   1+2*w  ],
                                                          [1+2*h+1, 1+2*w  ],
                                                          [1+2*h+1, 1+2*w+1],
                                                          [1+2*h,   1+2*w+1],
                                                          [1+2*h-1, 1+2*w  ],
                                                          [1+2*h-1, 1+2*w-1],
                                                          [1+2*h,   1+2*w-1]])

    def forward(self, x):
        x = self.padding(x)
        receptive_field = x[..., self.neighbors[..., 0], self.neighbors[..., 1]]
        y = self.function(receptive_field, -1)
        return self.process_vertices(y)


class LNormIco(torch.nn.Module):
    """ Layer normalisation over the (C, R) axes at every pixel, with a learned per-channel affine. """
    def __init__(self, C, R):
        super().__init__()
        self.norm = torch.nn.LayerNorm((C, R), elementwise_affine=False)
        self.weight = torch.nn.Parameter(torch.ones((C, 1)))
        self.bias = torch.nn.Parameter(torch.zeros((C, 1)))

    def forward(self, x):
        x = einops.rearrange(x, "... C R charts H W -> ... charts H W C R")
        original_shape = x.shape
        x = einops.rearrange(x, "... charts H W C R -> (... charts H W) C R")
        x = self.norm(x)
        x = x * self.weight + self.bias
        x = x.reshape(original_shape)
        x = einops.rearrange(x, "... charts H W C R -> ... C R charts H W")
        return x
```

## Analysis

We can compare the Icosahedral CNN with the original Spherical CNN. The Spherical CNN uses feature maps on the sphere $S^{2}$ and on the rotation group $\mathrm{SO}(3)$ (the latter can be thought of as a regular field on the sphere). These are sampled on the SOFT grids defined by (Kostelec & Rockmore, 2007), which have shape $2B \times 2B$ and $2B \times 2B \times 2B$, respectively (here $B$ is the bandwidth / resolution parameter). Specifically, the grid points are:
$$
\begin{aligned}
\alpha_{j_{1}} &=\frac{2 \pi j_{1}}{2 B}, \\
\beta_{k} &=\frac{\pi(2 k+1)}{4 B}, \\
\gamma_{j_{2}} &=\frac{2 \pi j_{2}}{2 B}
\end{aligned}
$$


where $\left(\alpha_{j_{1}}, \beta_{k}\right)$ form a spherical grid and $\left(\alpha_{j_{1}}, \beta_{k}, \gamma_{j_{2}}\right)$ form an $\mathrm{SO}(3)$ grid (for $j_{1}, k, j_{2}=0, \ldots, 2B-1$). These grids have two downsides.

Firstly, to get a sufficiently high sampling near the equator, we are forced to oversample the poles, and thus waste computational resources. For almost all applications, a more homogeneous grid is more suitable.

The second downside of the SOFT grid on $\operatorname{SO}(3)$ is that as the resolution of the spherical image increases, the number of rotations applied to each filter increases as well, which is undesirable.
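Both downsides can be seen numerically. The sketch below is our own and uses only the grid formulas above. For each $\beta_k$ ring, it computes the arc length between neighbouring $\alpha$ samples, $\sin\beta_k \cdot 2\pi/(2B)$. It also counts the $\gamma$ samples, which is the number of orientations at which each filter is applied.

```python
import numpy as np

def soft_grid(B: int):
    """SOFT grid angles alpha_j1, beta_k, gamma_j2 for j1, k, j2 = 0, ..., 2B-1."""
    j = np.arange(2 * B)
    alpha = 2 * np.pi * j / (2 * B)
    beta = np.pi * (2 * j + 1) / (4 * B)
    gamma = 2 * np.pi * j / (2 * B)
    return alpha, beta, gamma

def longitude_spacing(B: int) -> np.ndarray:
    """Arc length on the unit sphere between neighbouring alpha samples on each ring beta_k."""
    _, beta, _ = soft_grid(B)
    return np.sin(beta) * (2 * np.pi / (2 * B))

for B in (16, 32, 64):
    spacing = longitude_spacing(B)
    _, _, gamma = soft_grid(B)
    print(f"B={B}: equator/pole spacing ratio = {spacing.max() / spacing.min():.1f}, "
          f"filter orientations = {len(gamma)}")
```

The ratio between the largest and smallest spacing roughly doubles each time $B$ doubles, and so does the number of orientations. The Icosahedral CNN always uses 6.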

The grid used in the Icosahedral CNN addresses both concerns. It is spatially very homogeneous, and we apply the filters in 6 orientations, regardless of spatial resolution.

The $\mathrm{SO}(3)$ convolution (used in most layers of a typical Spherical CNN) has complexity $O\left(B^{3} \log B\right)$, which compares favorably to the $O\left(B^{6}\right)$ cost of a naive spatial implementation. If we use filters with a fixed (and usually small) size, the complexity of a naive spatial implementation reduces to $O\left(B^{3}\right)$, which is slightly better. The Icosahedral CNN does better still. Because it uses a fixed number of orientations per filter, its cost is linear in the number of pixels of the grid, which is comparable to $O\left(B^{2}\right)$ for the SOFT grid.
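A back-of-the-envelope count illustrates this scaling. The sketch assumes the chart layout `[5, 2^r, 2^(r+1)]` used in the code above. That layout stores $10 \cdot 4^r$ pixels, two fewer than the $10 \cdot 4^r + 2$ vertices of an $r$-times subdivided icosahedron. Our reading is that the two poles are not stored in the charts. The cost model (pixels × orientations × 7 taps) is a simplification of ours, not a figure from the article.

```python
def ico_pixels(r: int) -> int:
    """Pixels stored by the chart layout [5, 2^r, 2^(r+1)]."""
    return 5 * 2**r * 2**(r + 1)

def ico_vertices(r: int) -> int:
    """Vertices of the icosahedron after r subdivisions."""
    return 10 * 4**r + 2

def ico_layer_cost(r: int, c_in: int, c_out: int) -> int:
    """Multiply-adds of one regular->regular layer: per pixel, 6 output orientations,
    6 input orientations and 7 hexagonal taps per channel pair (independent of r)."""
    return ico_pixels(r) * (c_out * 6) * (c_in * 6) * 7

def soft_so3_points(B: int) -> int:
    """Number of SO(3) SOFT grid points, 2B x 2B x 2B."""
    return (2 * B) ** 3

for r in range(3, 6):
    growth = ico_layer_cost(r + 1, 8, 8) / ico_layer_cost(r, 8, 8)
    print(f"r={r}: {ico_pixels(r)} pixels, cost growth to r+1: x{growth:.0f}")
```

Each extra subdivision multiplies the icosahedral cost by 4 (linear in pixels). By contrast, doubling $B$ multiplies the number of $\mathrm{SO}(3)$ grid points by 8.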
