# Equivariance Tutorial

In [None]:
import torch
from torch.nn.functional import affine_grid, grid_sample, relu, softmax, pad, conv2d
from torchvision.io import read_image
from torchvision.transforms import Normalize
from models import CNN, PDEGCNN
import matplotlib.pyplot as plt
plt.rcParams['image.cmap'] = 'gray'

In [None]:
norm = Normalize(0.5, 0.5)
shoe = norm(read_image("images/shoe.png")[None, ...] / 255.)

In [None]:
labels = ["T-shirt/top", "Trouser", "Pullover", "Dress", "Coat", "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot"]

def logits_to_label(logits):
    return labels[logits.argmax()]

cnn_model = CNN()
cnn_model.load_state_dict(torch.load("CNN.pth", weights_only=True))
cnn_model.eval()

def cnn(image):
    with torch.no_grad():
        logits = cnn_model(image)[0]
    return logits_to_label(logits), softmax(logits, dim=-1)

pdegcnn_model = PDEGCNN()
pdegcnn_model.load_state_dict(torch.load("PDEGCNN.pth", weights_only=True))
pdegcnn_model.eval()

def pdegcnn(image):
    with torch.no_grad():
        logits = pdegcnn_model(image)[0]
    return logits_to_label(logits), softmax(logits, dim=-1)

def plot_classification(image, model, model_name, ax):
    probs = model(image)[1]
    colours = len(labels) * ["tab:blue"]
    colours[probs.argmax()] = "tab:red"
    ax.bar(labels, probs, color=colours)
    ax.set_title(model_name)
    ax.set_xlabel("Class")
    ax.set_ylim(0, 1)
    ax.set_xticks(range(len(labels)))
    ax.set_xticklabels(labels, rotation=45, ha='right')

In [None]:
def transform(image, x, A):
    """Apply the roto-translation (x, A) to image, filling pixels from outside the image with border values."""
    B, C, H, W = image.shape
    x = torch.tensor([-1., 1.]) * x
    affine_matrix = torch.hstack((torch.linalg.inv(A.T), x[None, ...].T))
    grid = affine_grid(affine_matrix[None, ...], (B, C, H, W), align_corners=False)
    return grid_sample(image, grid, padding_mode="border", align_corners=False)

## Theory

### Lie Groups

> **Definition (Lie Group)** $G$ is a _Lie group_ if it is 
> 1. a _smooth manifold_, i.e. a space that locally looks like $\mathbb{R}^n$, and
> 2. a _group_ whose product $\cdot: G \times G \to G$ and inversion $g \mapsto g^{-1}$ are smooth maps.

The most important Lie groups (imo) encode continuous symmetries of other spaces, with the group product given by composition.

> **Example (Translation Group)**
> The $n$-dimensional _translation group_ $\mathbb{R}^n$ acts on Euclidean space $\mathbb{R}^n$ by translation, namely
> $$ (\mathbf{x}, \mathbf{y}) \mapsto \mathbf{x} + \mathbf{y}, $$
> and has group product
> $$ (\mathbf{x}, \mathbf{y}) \mapsto \mathbf{x} + \mathbf{y}. $$
> Consequently, $\mathbb{R}^n$ also acts on the functions on Euclidean space by
> $$ (\mathbf{x}, f) \mapsto (\mathbf{y} \mapsto f(\mathbf{y} - \mathbf{x})). $$

Of course, this is an incredibly boring example. 
Slightly less trivial is the following:
> **Example (Special Orthogonal Group)**
> The _special orthogonal group_ $\operatorname{SO}(n)$ acts on Euclidean space $\mathbb{R}^n$ by rotation, namely
> $$ (R, \mathbf{y}) \mapsto R\mathbf{y}, $$
> and has group product
> $$ (R, S) \mapsto RS. $$
> Consequently, $\operatorname{SO}(n)$ also acts on the functions on Euclidean space by
> $$ (R, f) \mapsto (\mathbf{x} \mapsto f(R^{-1} \mathbf{x})). $$

Here, we represent the elements of $\operatorname{SO}(n)$ as $n \times n$ orthogonal matrices with determinant $1$. 
For example, in two dimensions the counter-clockwise rotation by angle $\theta$ is given by
$$ R = \begin{pmatrix} \cos(\theta) & -\sin(\theta) \\
\sin(\theta) & \cos(\theta) \end{pmatrix}. $$
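
To make this concrete, here is a small sketch (the helper `rotation_matrix` is ours, not part of the notebook) that builds the matrix above in PyTorch and checks numerically that it is orthogonal with determinant $1$, that the group product (matrix multiplication) corresponds to adding angles, and that the rotation is indeed counter-clockwise:

```python
import math
import torch

def rotation_matrix(theta: float) -> torch.Tensor:
    """Counter-clockwise rotation by `theta` radians as a 2x2 matrix (hypothetical helper)."""
    c, s = math.cos(theta), math.sin(theta)
    return torch.tensor([[c, -s], [s, c]], dtype=torch.float64)

a, b = 0.3, 1.1
R_a, R_b = rotation_matrix(a), rotation_matrix(b)

# Elements of SO(2): orthogonal (R^T R = I) with determinant 1.
is_orthogonal = torch.allclose(R_a.T @ R_a, torch.eye(2, dtype=torch.float64))
has_det_one = abs(torch.linalg.det(R_a).item() - 1.0) < 1e-12

# The group product is matrix multiplication: rotating by b and then by a is rotating by a + b.
product_is_composition = torch.allclose(R_a @ R_b, rotation_matrix(a + b))

# Counter-clockwise: a quarter turn sends the first basis vector to the second one.
e1 = torch.tensor([1.0, 0.0], dtype=torch.float64)
e2 = torch.tensor([0.0, 1.0], dtype=torch.float64)
is_counter_clockwise = torch.allclose(rotation_matrix(math.pi / 2) @ e1, e2)

print(is_orthogonal, has_det_one, product_is_composition, is_counter_clockwise)  # True True True True
```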

In this example, unlike in the translation example where both are $\mathbb{R}^n$, the group and the space it acts on can no longer be identified with each other.

We get our favourite group by combining the two previous ones:
> **Example (Special Euclidean Group)**
> The _special Euclidean group_ $\operatorname{SE}(n)$ acts on Euclidean space $\mathbb{R}^n$ by roto-translation, namely
> $$ ((\mathbf{x}, R), \mathbf{y}) \mapsto \mathbf{x} + R\mathbf{y}, $$
> and has group product
> $$ ((\mathbf{x}, R), (\mathbf{y}, S)) \mapsto (\mathbf{x} + R\mathbf{y}, RS). $$
> Consequently, $\operatorname{SE}(n)$ also acts on the functions on Euclidean space by
> $$ ((\mathbf{x}, R), f) \mapsto (\mathbf{y} \mapsto f(R^{-1}(\mathbf{y} - \mathbf{x}))). $$
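
A minimal sketch of $\operatorname{SE}(2)$ in code, representing a group element as a pair `(x, R)` of a translation vector and a rotation matrix. All helper names (`se2_product`, `se2_inverse`, `se2_act`, `act_on_function`) are hypothetical. The sketch checks that acting with $h$ and then with $g$ is the same as acting with the product $gh$, both on points and on functions, where the action on functions is $(g f)(\mathbf{y}) = f(g^{-1}\mathbf{y})$:

```python
import math
import torch

def rotation_matrix(theta):
    """Counter-clockwise rotation by `theta` radians."""
    c, s = math.cos(theta), math.sin(theta)
    return torch.tensor([[c, -s], [s, c]], dtype=torch.float64)

def se2_product(g, h):
    """Group product (x, R)(y, S) = (x + R y, R S)."""
    (x, R), (y, S) = g, h
    return (x + R @ y, R @ S)

def se2_inverse(g):
    """Inverse (x, R)^{-1} = (-R^T x, R^T), using R^{-1} = R^T for rotations."""
    x, R = g
    return (-R.T @ x, R.T)

def se2_act(g, y):
    """Action on a point: (x, R) y = x + R y."""
    x, R = g
    return x + R @ y

def act_on_function(g, f):
    """Action on a function: (g f)(y) = f(g^{-1} y) = f(R^{-1}(y - x))."""
    return lambda y: f(se2_act(se2_inverse(g), y))

def f(p):
    """An arbitrary test function on R^2."""
    return torch.sin(p[0]) * p[1]

g = (torch.tensor([1.0, 2.0], dtype=torch.float64), rotation_matrix(0.7))
h = (torch.tensor([-0.5, 3.0], dtype=torch.float64), rotation_matrix(-1.2))
y = torch.tensor([0.25, -4.0], dtype=torch.float64)

# g (h y) == (g h) y: the action respects the group product.
points_compatible = torch.allclose(se2_act(g, se2_act(h, y)), se2_act(se2_product(g, h), y))
# g^{-1} undoes g.
inverse_ok = torch.allclose(se2_act(se2_inverse(g), se2_act(g, y)), y)
# The same compatibility holds for the action on functions.
functions_compatible = torch.allclose(
    act_on_function(g, act_on_function(h, f))(y),
    act_on_function(se2_product(g, h), f)(y),
)
print(points_compatible, inverse_ok, functions_compatible)  # True True True
```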

> _Remark_ The actions on functions are examples of so-called _group representations_, a term often encountered in the equivariance literature. For simplicity, we just call them actions here.

Finally, there is the simplest group action of all, which does nothing:
> **Definition (Trivial Action)** Let $G$ be a Lie group and $X$ a set. We say that $G$ acts _trivially_ on $X$ if $g x = x$ for all $g \in G$, $x \in X$.

Many problems have inherent symmetries. For example, if we want to classify the object in an image, rotating the object shouldn't change the classification: 

In [None]:
fig, ax = plt.subplots(1, 2, figsize=(10, 5), constrained_layout=True)
ax[0].set_axis_off()
ax[1].set_axis_off()
ax[0].set_title("Shoe")
ax[1].set_title("Still a shoe")

x = torch.tensor([0., 0.])
R = torch.tensor([[1., 0.], [0., 1.]])
ax[0].imshow(transform(shoe, x, R).squeeze())

# Your roto-translation here ⬇️
x = torch.tensor([0., 0.])
theta = torch.tensor([0.8712346])
R = torch.tensor([
    [torch.cos(theta), -torch.sin(theta)],
    [torch.sin(theta), torch.cos(theta)]
])
ax[1].imshow(transform(shoe, x, R).squeeze());

Lie groups give us a formal mathematical language for these symmetries. In particular, we can now define _equivariance_, which is, in essence, a symmetry-preservation property.

### Equivariance

> **Definition (Equivariance)** Let $G$ be a Lie group acting on $U$ and $V$. 
> $\Phi: U \to V$ is called _equivariant_ if it commutes with the group actions, i.e.
> $$ \Phi \circ g = g \circ \Phi, $$
> for all $g \in G$.

If the action on $V$ is trivial, then we say $\Phi$ is _invariant_. 
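
Equivariance can be sanity-checked numerically: pick an input $f$ and a group element $g$, and compare $\Phi(g f)$ with $g \Phi(f)$. The sketch below (all helper names are hypothetical) uses quarter turns via `torch.rot90`, which permute pixels exactly, so no interpolation errors get in the way. Keep in mind that a single check like this can refute equivariance but never prove it.

```python
import torch

def rotate(f):
    """Counter-clockwise quarter turn of the last two dimensions; exact on the pixel grid."""
    return torch.rot90(f, k=1, dims=(-2, -1))

def trivial(v):
    """The trivial action: leaves its argument unchanged."""
    return v

def commutes(phi, act_in, act_out, f, atol=1e-5):
    """Check Phi(g f) == g Phi(f) for one input f and one group element g."""
    return torch.allclose(phi(act_in(f)), act_out(phi(f)), atol=atol)

def total(img):
    """Sum over all pixels."""
    return img.sum(dim=(-2, -1))

def corner(img):
    """Read off the top-left pixel."""
    return img[..., 0, 0]

torch.manual_seed(0)
f = torch.randn(1, 1, 8, 8)

pointwise_ok = commutes(torch.tanh, rotate, rotate, f)   # point-wise map: equivariant
sum_invariant = commutes(total, rotate, trivial, f)      # global sum: invariant
corner_invariant = commutes(corner, rotate, trivial, f)  # single pixel: not invariant
print(pointwise_ok, sum_invariant, corner_invariant)  # True True False
```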

Let's work this out for our classification problem.
- We can view images as functions on $\mathbb{R}^2$, on which the Lie group $\operatorname{SE}(2)$ acts by roto-translation: $((\mathbf{x}, R) f)(\mathbf{y}) = f(R^{-1} (\mathbf{y} - \mathbf{x}))$.
- We have a classifier $\Phi$ which maps an image $f: \mathbb{R}^2 \to \mathbb{R}$ to a label $k \in \{1, \ldots, c\}$, where $c$ is the number of classes.
- $\operatorname{SE}(2)$ acts trivially on the range $\{1, \ldots, c\}$, so we have $(\mathbf{x}, R) k = k$ for all $k \in \{1, \ldots, c\}$.

Then the classifier is invariant if $\Phi(g f) = g \Phi(f) = \Phi(f)$ for all $g \in \operatorname{SE}(2)$ and images $f: \mathbb{R}^2 \to \mathbb{R}$.

In this problem invariance is clearly a desirable property. But how would we go about constructing an invariant classifier? 

We could train an ordinary convolutional neural network and hope that it learns to be invariant. This is highly unlikely unless we use _data augmentation_, i.e. train on randomly transformed copies of the images, and even then there are no guarantees.
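
For reference, a minimal sketch of rotation-based data augmentation, assuming square images in a `(B, C, H, W)` batch. The function `augment_with_rotations` is hypothetical and only uses quarter turns, so every rotated image keeps its shape and no interpolation is needed:

```python
import torch

def augment_with_rotations(images, generator=None):
    """Rotate each image in a (B, C, H, W) batch by a random multiple of 90 degrees.

    Assumes square images (H == W). Returns the rotated batch and the number of quarter turns used.
    """
    ks = torch.randint(0, 4, (images.shape[0],), generator=generator)
    rotated = [torch.rot90(img, k=int(k), dims=(-2, -1)) for img, k in zip(images, ks)]
    return torch.stack(rotated), ks

batch = torch.arange(2 * 1 * 4 * 4, dtype=torch.float32).reshape(2, 1, 4, 4)
augmented, ks = augment_with_rotations(batch, generator=torch.Generator().manual_seed(0))
print(ks, augmented.shape)
```

For arbitrary angles, torchvision's `RandomRotation` transform does a similar job, at the cost of interpolation and boundary effects.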

Alternatively, we could construct a neural network architecture that is inherently invariant. For this, we can make use of the following result:
> **Lemma (Composition of Equivariant Maps)** Let $G$ be a Lie group acting on $U$, $V$, and $W$. Suppose $\Phi: U \to V$ and $\Psi: V \to W$ are equivariant. Then, their composition $\Psi \circ \Phi: U \to W$ is also equivariant. 

_Proof_: For every $g \in G$, we have $\Psi \circ \Phi \circ g = \Psi \circ g \circ \Phi = g \circ \Psi \circ \Phi$, using first the equivariance of $\Phi$ and then that of $\Psi$.

Hence, we can make an equivariant neural network architecture by composing equivariant layers. A typical layer in a neural network consists of the composition of something linear (e.g. matrix multiplication, convolution, linear combinations) with something nonlinear (e.g. activation function, normalisation). Common nonlinearities such as the ReLU activation function and batch normalisation (at inference time) act point-wise. Such a point-wise operation $\sigma$ is equivariant because it changes pixel values without moving them: $(\sigma \circ (g f))(\mathbf{y}) = \sigma(f(g^{-1}\mathbf{y})) = (g (\sigma \circ f))(\mathbf{y})$.
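
As a sketch of the composition lemma in action, we can stack point-wise maps, $2 \times 2$ average pooling (not point-wise, but it still commutes with quarter turns on even-sized images) and a global mean, and check that the result is invariant. All helper names are hypothetical; `batchnorm_eval` is only a stand-in for batch normalisation in inference mode, i.e. a fixed affine map per pixel.

```python
import torch
from torch.nn.functional import relu, avg_pool2d

def rotate(f):
    """Counter-clockwise quarter turn of the last two dimensions (an exact pixel permutation)."""
    return torch.rot90(f, k=1, dims=(-2, -1))

def batchnorm_eval(f, scale=1.7, shift=-0.2):
    """Hypothetical stand-in for batch normalisation at inference time: a fixed affine map per pixel."""
    return scale * f + shift

def pool(f):
    """2x2 average pooling: not point-wise, but commutes with quarter turns on even-sized images."""
    return avg_pool2d(f, kernel_size=2)

def global_mean(f):
    """Average over all pixel positions: invariant to any permutation of the pixels."""
    return f.mean(dim=(-2, -1))

def network(f):
    """Equivariant layers composed with each other, followed by an invariant read-out."""
    return global_mean(pool(relu(batchnorm_eval(f))))

torch.manual_seed(0)
f = torch.randn(1, 3, 8, 8)

pointwise_ok = torch.allclose(relu(batchnorm_eval(rotate(f))), rotate(relu(batchnorm_eval(f))))
pool_ok = torch.allclose(pool(rotate(f)), rotate(pool(f)), atol=1e-6)
network_invariant = torch.allclose(network(rotate(f)), network(f), atol=1e-6)
print(pointwise_ok, pool_ok, network_invariant)  # True True True
```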

In [None]:
x = torch.tensor([0., 0.])
theta = torch.tensor([-0.862364]) # Random number
R = torch.tensor([
    [torch.cos(theta), -torch.sin(theta)],
    [torch.sin(theta), torch.cos(theta)]
])

fig, ax = plt.subplots(1, 2, figsize=(10, 5), constrained_layout=True)
ax[0].set_axis_off()
ax[1].set_axis_off()
ax[0].set_title("Transform before ReLU")
ax[1].set_title("ReLU before Transform")
ax[0].imshow(relu(transform(shoe, x, R)).squeeze())
ax[1].imshow(transform(relu(shoe), x, R).squeeze());

We're working with images in this tutorial; convolutions are a natural choice for the linear part. In a convolutional layer an image is convolved/cross-correlated with a trainable filter (Fig. 5.3 from [1]):
![Schematic depiction of a convolution.](convolution.png)
Convolutional layers are translation equivariant: if you shift the input image, the output is shifted accordingly (up to boundary effects):

In [None]:
kernel = torch.randn(1, 1, 3, 3)

k = 4 # number of pixels to shift
padding = k + 1 # deal with boundary issues
x = torch.tensor([k * (2 / (shoe.shape[-1] + 2 * padding)), 0.])
theta = torch.tensor([0.])
R = torch.tensor([
    [torch.cos(theta), -torch.sin(theta)],
    [torch.sin(theta), torch.cos(theta)]
])

fig, ax = plt.subplots(1, 2, figsize=(10, 5), constrained_layout=True)
ax[0].set_axis_off()
ax[1].set_axis_off()
ax[0].set_title("Shift before Convolution")
ax[1].set_title("Convolution before Shift")
ax[0].imshow(conv2d(transform(pad(shoe, 4*[padding], mode="replicate"), x, R), kernel, torch.tensor([0.])).squeeze())
ax[1].imshow(transform(conv2d(pad(shoe, 4*[padding], mode="replicate"), kernel, torch.tensor([0.])), x, R).squeeze())
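
Two things are hidden in the cell above: `conv2d` computes a cross-correlation (the kernel is not flipped), and the boundary handling is what spoils exact equivariance. The sketch below (helper names are hypothetical) checks `conv2d` against a naive loop, and then shows that with circular padding, where the image is treated as living on a torus, shifting by whole pixels with `torch.roll` commutes with the convolution exactly:

```python
import torch
from torch.nn.functional import conv2d, pad

def cross_correlate(image, kernel):
    """Naive 'valid' cross-correlation of a 2D image with a 2D kernel."""
    H, W = image.shape
    kh, kw = kernel.shape
    out = torch.zeros(H - kh + 1, W - kw + 1)
    for i in range(out.shape[0]):
        for j in range(out.shape[1]):
            # Slide the (unflipped) kernel over the image and sum the products.
            out[i, j] = (image[i:i + kh, j:j + kw] * kernel).sum()
    return out

def circular_conv(f, kernel):
    """conv2d after wrap-around padding, so the output keeps the input's size."""
    p = kernel.shape[-1] // 2
    return conv2d(pad(f, 4 * [p], mode="circular"), kernel)

def shift(f, dy=2, dx=3):
    """Translate by whole pixels, wrapping around at the boundary."""
    return torch.roll(f, shifts=(dy, dx), dims=(-2, -1))

torch.manual_seed(0)
f = torch.randn(1, 1, 10, 10)
kernel = torch.randn(1, 1, 3, 3)

matches_naive = torch.allclose(conv2d(f, kernel)[0, 0], cross_correlate(f[0, 0], kernel[0, 0]), atol=1e-5)
exactly_equivariant = torch.allclose(circular_conv(shift(f), kernel), shift(circular_conv(f, kernel)), atol=1e-5)
print(matches_naive, exactly_equivariant)  # True True
```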

However, convolutional layers typically won't be rotation equivariant.

In [None]:
x = torch.tensor([0., 0.])
theta = torch.tensor([torch.pi/2])
R = torch.tensor([
    [torch.cos(theta), -torch.sin(theta)],
    [torch.sin(theta), torch.cos(theta)]
])

fig, ax = plt.subplots(1, 2, figsize=(10, 5), constrained_layout=True)
ax[0].set_axis_off()
ax[1].set_axis_off()
ax[0].set_title("Rotate before Convolution")
ax[1].set_title("Convolution before Rotate")
ax[0].imshow(conv2d(transform(shoe, x, R), kernel, torch.tensor([0.])).squeeze())
ax[1].imshow(transform(conv2d(shoe, kernel, torch.tensor([0.])), x, R).squeeze())
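
Why does this fail? Rotating the input before a convolution is the same as rotating the output of a convolution with the _oppositely_ rotated kernel. The sketch below (the helper `rot` is hypothetical) verifies this identity for quarter turns with zero padding, which keeps the image size and is itself rotation symmetric. So the convolution only commutes with the rotation when the kernel is unchanged by that rotation:

```python
import torch
from torch.nn.functional import conv2d

def rot(t, k=1):
    """Rotate the last two dimensions by k counter-clockwise quarter turns."""
    return torch.rot90(t, k=k, dims=(-2, -1))

torch.manual_seed(0)
f = torch.randn(1, 1, 9, 9)
kernel = torch.randn(1, 1, 3, 3)

# Rotate, then convolve with the kernel ...
lhs = conv2d(rot(f), kernel, padding=1)
# ... equals: convolve with the oppositely rotated kernel, then rotate.
identity_holds = torch.allclose(lhs, rot(conv2d(f, rot(kernel, k=-1), padding=1)), atol=1e-5)
# With the unrotated kernel the two orders disagree (the kernel is not symmetric under quarter turns).
commutes_plain = torch.allclose(lhs, rot(conv2d(f, kernel, padding=1)), atol=1e-5)
print(identity_holds, commutes_plain)  # True False
```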

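Indeed, in the continuous setting a convolution is rotation equivariant if and only if its kernel is isotropic. On the pixel grid only rotations by multiples of $90^\circ$ are exact, so below we average an ordinary (anisotropic) kernel over these four rotations. The resulting kernel makes the convolution equivariant to quarter turns, but not to arbitrary angles: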

In [None]:
isotropic_kernel = sum(
    torch.rot90(kernel, k=k, dims=(-2, -1))  # all four quarter turns: 0°, 90°, 180°, 270°
    for k in range(4)
) / 4

x = torch.tensor([0., 0.])
theta = torch.tensor([torch.pi/2])
R = torch.tensor([
    [torch.cos(theta), -torch.sin(theta)],
    [torch.sin(theta), torch.cos(theta)]
])

fig, ax = plt.subplots(1, 2, figsize=(10, 5), constrained_layout=True)
ax[0].set_axis_off()
ax[1].set_axis_off()
ax[0].set_title("Rotate before Convolution")
ax[1].set_title("Convolution before Rotate")
ax[0].imshow(conv2d(transform(shoe, x, R), isotropic_kernel, torch.tensor([0.])).squeeze())
ax[1].imshow(transform(conv2d(shoe, isotropic_kernel, torch.tensor([0.])), x, R).squeeze());
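
Since rotating the input amounts to convolving with an oppositely rotated kernel, a kernel that is unchanged by quarter turns gives exact equivariance to them. A short sketch (the helper `rot` is hypothetical) that checks this with `torch.rot90` instead of the interpolating `transform`, so that boundary and interpolation effects play no role:

```python
import torch
from torch.nn.functional import conv2d

def rot(t, k=1):
    """Rotate the last two dimensions by k counter-clockwise quarter turns."""
    return torch.rot90(t, k=k, dims=(-2, -1))

torch.manual_seed(0)
kernel = torch.randn(1, 1, 3, 3)
# Average over the four quarter turns, as in the construction of `isotropic_kernel` above.
c4_kernel = sum(rot(kernel, k) for k in range(4)) / 4
f = torch.randn(1, 1, 9, 9)

kernel_is_symmetric = torch.allclose(rot(c4_kernel), c4_kernel, atol=1e-6)
commutes = torch.allclose(conv2d(rot(f), c4_kernel, padding=1), rot(conv2d(f, c4_kernel, padding=1)), atol=1e-5)
print(kernel_is_symmetric, commutes)  # True True
```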

One limitation of this approach is that it greatly restricts the set of models that can be expressed with a given number of parameters. For example, rotating the initial kernel by $90^\circ$ leads to exactly the same averaged kernel, even though the kernel parameters are different. This issue can be addressed by _lifting_.
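
To see the redundancy concretely, here is a small sketch (the helper `c4_average` is hypothetical): two different kernels give the same averaged kernel, and since averaging is linear, the rank of its matrix shows that the $9$ parameters of a $3 \times 3$ kernel collapse to only $3$ degrees of freedom (centre, edge midpoints, corners):

```python
import torch

def c4_average(kernel):
    """Average a 2D kernel over its four quarter-turn rotations."""
    return sum(torch.rot90(kernel, k, dims=(-2, -1)) for k in range(4)) / 4

torch.manual_seed(0)
kernel = torch.randn(3, 3)
rotated_kernel = torch.rot90(kernel, 1, dims=(-2, -1))

same_params = torch.equal(kernel, rotated_kernel)
same_average = torch.allclose(c4_average(kernel), c4_average(rotated_kernel), atol=1e-6)

# c4_average is linear; the rank of its matrix counts the degrees of freedom that survive.
basis = torch.eye(9)
averaging_matrix = torch.stack([c4_average(e.reshape(3, 3)).flatten() for e in basis])
dof = int(torch.linalg.matrix_rank(averaging_matrix))
print(same_params, same_average, dof)  # False True 3
```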

### Lifting

## Application

This repository contains two trained models: a CNN and a roto-translation invariant PDE-G-CNN. The architectures can be found in the [`models.py`](models.py) module. The models have been trained using the [`train.py`](train.py) script.

The CNN is based on the classic LeNet-5 architecture. It consists of two convolutional layers with max-pooling, which are (approximately) translation equivariant, followed by three fully connected layers. Notably, the first fully connected layer is _not_ translation equivariant, so the model as a whole has no guaranteed translation invariance.
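
For orientation, here is a sketch of what a LeNet-5-style network can look like in PyTorch. It is a hypothetical stand-in, not the repository's `models.CNN`: we assume $1 \times 28 \times 28$ inputs and ReLU activations. The flattening step followed by a fully connected layer gives every spatial position its own weights, which is where translation equivariance is lost:

```python
import torch
from torch import nn

class LeNet5Sketch(nn.Module):
    """LeNet-5-style classifier for 1x28x28 images (hypothetical; not the repository's models.CNN)."""

    def __init__(self, num_classes: int = 10):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 6, kernel_size=5, padding=2), nn.ReLU(), nn.MaxPool2d(2),  # -> 6x14x14
            nn.Conv2d(6, 16, kernel_size=5), nn.ReLU(), nn.MaxPool2d(2),            # -> 16x5x5
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(16 * 5 * 5, 120), nn.ReLU(),  # position-dependent weights: not translation equivariant
            nn.Linear(120, 84), nn.ReLU(),
            nn.Linear(84, num_classes),
        )

    def forward(self, x):
        return self.classifier(self.features(x))

torch.manual_seed(0)
model = LeNet5Sketch().eval()
image = torch.randn(1, 1, 28, 28)
shifted = torch.roll(image, shifts=(0, 4), dims=(-2, -1))
with torch.no_grad():
    logits, logits_shifted = model(image), model(shifted)
print(logits.shape)                           # torch.Size([1, 10])
print(torch.allclose(logits, logits_shifted))  # compares logits of the original and shifted image
```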

The second model is a PDE-based Group equivariant CNN (PDE-G-CNN) [2]. PDE-G-CNNs are closely related to Group equivariant CNNs (G-CNNs).

![Comparison of a traditional CNN layer and a PDE-G-CNN layer.](layer_comparison.webp)

In [None]:
x = torch.tensor([0., 0.])
theta = torch.tensor([torch.pi/2])
R = torch.tensor([
    [torch.cos(theta), -torch.sin(theta)],
    [torch.sin(theta), torch.cos(theta)]
])

fig, ax = plt.subplots(2, 2, figsize=(10, 10), constrained_layout=True)
plot_classification(shoe, pdegcnn, "PDE-G-CNN", ax[0, 0])
plot_classification(transform(shoe, x, R), pdegcnn, "PDE-G-CNN", ax[0, 1])
plot_classification(shoe, cnn, "CNN", ax[1, 0])
plot_classification(transform(shoe, x, R), cnn, "CNN", ax[1, 1])

# References

<a id="Smets2024GeometricProcessing"></a>[1] B.M.N. Smets. Geometric Partial Differential Equations in Deep Learning and Image Processing (2024). <https://research.tue.nl/en/publications/geometric-partial-differential-equations-in-deep-learning-and-ima>

```bib
@phdthesis{smets2024geometric,
  title={Geometric Partial Differential Equations in Deep Learning and Image Processing},
  author={Smets, Bart M.N.},
  year={2024},
  isbn={978-90-386-6133-9},
}
```

<a id="Smets2022PDENetworks"></a>[2] B.M.N. Smets, J. Portegies, E.J. Bekkers, R. Duits. PDE-Based Group Equivariant Convolutional Neural Networks. J Math Imaging Vis (2022). <https://doi.org/10.1007/s10851-022-01114-x>

```bib
@article{smets2022pde,
  title={PDE-based Group Equivariant Convolutional Neural Networks},
  author={Smets, Bart M.N. and Portegies, Jim and Bekkers, Erik J. and Duits, Remco},
  journal={Journal of Mathematical Imaging and Vision},
  publisher={Springer},
  year={2022},
  doi={10.1007/s10851-022-01114-x},
}
```