# Equivariance Tutorial

In this tutorial, we will discuss the concept of _equivariance_. As a model problem, we consider classifying FashionMNIST, though the concepts can be easily extended to other image processing such as segmentation and denoising (and also tasks that don't involve images...). FashionMNIST is a dataset consisting of thumbnails of items of clothing from Zalando, accompanied by a label (`["T-shirt/top", "Trouser", "Pullover", "Dress", "Coat", "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot"]`). Throughout this tutorial, we will use an "Ankle boot" sample.

In [None]:
import torch
from torch.nn.functional import affine_grid, grid_sample, relu, softmax, pad, conv2d
from torchvision.io import read_image
from torchvision.transforms import Normalize
from models import CNN, PDEGCNN
from lietorch.nn.m2 import LiftM2Cakewavelets
from lietorch.nn.r2 import morphological_convolution_r2, morphological_kernel_r2_isotropic
import matplotlib.pyplot as plt
plt.rcParams['image.cmap'] = 'gray'
plt.rcParams['figure.constrained_layout.use'] = True

shoe = read_image("content/shoe.png")[None, ...] / 255.

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(5, 5))
ax.set_axis_off()
ax.set_title("Ankle boot")
ax.imshow(shoe.squeeze());

In [None]:
labels = ["T-shirt/top", "Trouser", "Pullover", "Dress", "Coat", "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot"]

def logits_to_label(logits):
    return labels[logits.argmax()]

norm = Normalize(0.5, 0.5)

cnn_model = CNN()
cnn_model.load_state_dict(torch.load("CNN.pth", weights_only=True))
cnn_model.eval()

def cnn(image):
    with torch.no_grad():
        logits = cnn_model(norm(image))[0]
    return logits_to_label(logits), softmax(logits, dim=-1)

pdegcnn_model = PDEGCNN()
pdegcnn_model.load_state_dict(torch.load("PDEGCNN.pth", weights_only=True))
pdegcnn_model.eval()

def pdegcnn(image):
    with torch.no_grad():
        logits = pdegcnn_model(norm(image))[0]
    return logits_to_label(logits), softmax(logits, dim=-1)

def plot_classification(image, model, model_name, ax):
    probs = model(image)[1]
    colours = len(labels) * ["tab:blue"]
    colours[probs.argmax()] = "tab:red"
    ax.bar(labels, probs, color=colours)
    ax.set_title(model_name)
    ax.set_xlabel("Class")
    ax.set_ylim(0, 1)
    ax.set_xticks(range(len(labels)))
    ax.set_xticklabels(labels, rotation=45, ha='right')

In [None]:
def transform(image, x, A):
    """Apply action (x, A) to image; samples outside the image are zero (the `grid_sample` default)."""
    B, C, H, W = image.shape
    x = torch.tensor([-1., 1.]) * x
    affine_matrix = torch.hstack((torch.linalg.inv(A.T), x[None, ...].T))
    grid = affine_grid(affine_matrix[None, ...], (B, C, H, W), align_corners=False)
    return grid_sample(image, grid, align_corners=False)

## Lie Groups

> **Definition (Lie Group)** $G$ is a _Lie group_ if it is 
> 1. a _smooth manifold_ - so smooth and looks locally like $\mathbb{R}^n$ - and
> 2. a _group_ - we have a smooth, well-behaved product $\cdot: G \times G \to G$.

The most important Lie groups (imo) encode continuous symmetries on other spaces, with the group product simply given by composition.

> **Example (Translation Group)**
> The $n$-dimensional _translation group_ $\mathbb{R}^n$ acts on Euclidean space $\mathbb{R}^n$ by translation, namely
> $$ (\mathbf{x}, \mathbf{y}) \mapsto \mathbf{x} + \mathbf{y}, $$
> and has group product
> $$ (\mathbf{x}, \mathbf{y}) \mapsto \mathbf{x} + \mathbf{y}. $$
> Consequently, $\mathbb{R}^n$ also acts on the functions on Euclidean space by
> $$ (\mathbf{x}, f) \mapsto (\mathbf{y} \mapsto f(\mathbf{y} - \mathbf{x})). $$
Of course, this is an incredibly boring example. 
Slightly less trivial is the following:
> **Example (Special Orthogonal Group)**
> The _special orthogonal group_ $\operatorname{SO}(n)$ acts on Euclidean space $\mathbb{R}^n$ by rotation, namely
> $$ (R, \mathbf{y}) \mapsto R\mathbf{y}, $$
> and has group product
> $$ (R, S) \mapsto RS. $$
> Consequently, $\operatorname{SO}(n)$ also acts on the functions on Euclidean space by
> $$ (R, f) \mapsto (\mathbf{x} \mapsto f(R^{-1} \mathbf{x})). $$
Here, we represent the elements of $\operatorname{SO}(n)$ as $n \times n$ orthogonal matrices with determinant $1$. 
For example, in two dimensions the counter-clockwise rotation by angle $\theta$ is given by
$$ R = \begin{pmatrix} \cos(\theta) & -\sin(\theta) \\
\sin(\theta) & \cos(\theta) \end{pmatrix}. $$

In this example the group and the space that is acted on can no longer be identified.

We get our favourite group by combining the two previous ones:
> **Example (Special Euclidean Group)**
> The _special Euclidean group_ $\operatorname{SE}(n) := \mathbb{R}^n \rtimes \operatorname{SO}(n)$ acts on Euclidean space $\mathbb{R}^n$ by roto-translation, namely
> $$ ((\mathbf{x}, R), \mathbf{y}) \mapsto \mathbf{x} + R\mathbf{y}, $$
> and has group product
> $$ ((\mathbf{x}, R), (\mathbf{y}, S)) \mapsto (\mathbf{x} + R\mathbf{y}, RS). $$
> Consequently, $\operatorname{SE}(n)$ also acts on the functions on Euclidean space by
> $$ ((\mathbf{x}, R), f) \mapsto (\mathbf{y} \mapsto f(R^{-1}(\mathbf{y} - \mathbf{x}))). $$

> _Remark_ The actions on functions are examples of so-called _group representations_, a term often encountered in the equivariance literature. For the sake of simplicity here we simply refer to them as actions.
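
To make the $\operatorname{SE}(2)$ formulas above concrete, here is a minimal NumPy sketch that checks numerically that the group product agrees with composing the actions on points. The helper names (`rotation`, `act`, `product`, `inverse`) are our own and purely illustrative; they are not part of LieTorch.

```python
import numpy as np

def rotation(theta):
    """Counter-clockwise rotation matrix R_theta in SO(2)."""
    c, s = np.cos(theta), np.sin(theta)
    return np.array([[c, -s], [s, c]])

def act(g, y):
    """Action of g = (x, R) on a point y: x + R y."""
    x, R = g
    return x + R @ y

def product(g, h):
    """Group product (x, R)(y, S) = (x + R y, R S)."""
    (x, R), (y, S) = g, h
    return x + R @ y, R @ S

def inverse(g):
    """Inverse (x, R)^{-1} = (-R^{-1} x, R^{-1}); for rotations R^{-1} = R^T."""
    x, R = g
    return -R.T @ x, R.T

g = (np.array([1.0, -2.0]), rotation(0.3))
h = (np.array([0.5, 4.0]), rotation(-1.1))
y = np.array([2.0, 3.0])

# Acting with h and then with g is the same as acting once with the product g h.
print(np.allclose(act(g, act(h, y)), act(product(g, h), y)))
# The inverse undoes the action.
print(np.allclose(act(inverse(g), act(g, y)), y))
```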

Finally, we have the simplest group action, doing nothing:
> **Definition (Trivial Action)** Let $G$ be a Lie group and $X$ a set. Then $G$ acts _trivially_ on $X$ if $g x = x$ for all $g \in G$, $x \in X$.

Many problems have inherent symmetries. For example, if we want to classify the object in an image, rotating the object shouldn't change the classification: 

In [None]:
fig, ax = plt.subplots(1, 2, figsize=(10, 5))
ax[0].set_axis_off()
ax[1].set_axis_off()
ax[0].set_title("Ankle boot")
ax[1].set_title("Still an ankle boot")

x = torch.tensor([0., 0.])
R = torch.tensor([[1., 0.], [0., 1.]])
ax[0].imshow(transform(shoe, x, R).squeeze())

# Your roto-translation here ⬇️
x = torch.tensor([0., 0.])
theta = torch.tensor([0.8712346])
R = torch.tensor([
    [torch.cos(theta), -torch.sin(theta)],
    [torch.sin(theta), torch.cos(theta)]
])
ax[1].imshow(transform(shoe, x, R).squeeze());

Lie groups give us a mathematically formal way of thinking about these symmetries. In particular, we can now define _equivariance_, which in essence is a symmetry preservation property.

## Equivariance

> **Definition (Equivariance)** Let $G$ be a Lie group acting on $U$ and $V$. 
> $\Phi: U \to V$ is called _equivariant_ if it commutes with the group actions, i.e.
> $$ \Phi \circ g = g \circ \Phi, $$
> for all $g \in G$.

If the action on $V$ is trivial, then we say $\Phi$ is _invariant_. 

Let's work this out for our classification problem.
- We can see images as functions on $\mathbb{R}^2$, on which the Lie group $\operatorname{SE}(2)$ acts by roto-translation: $((\mathbf{x}, R) f)(\mathbf{y}) = f(R^{-1} (\mathbf{y} - \mathbf{x}))$.
- We have a classifier $\Phi$ which maps an image $f: \mathbb{R}^2 \to \mathbb{R}$ to a probability distribution $p \in \mathbb{P}_c := \{(p_1, \ldots, p_c) \mid \sum_{i = 1}^c p_i = 1, p_i \geq 0\}$ over labels $\{1, \ldots, c\}$, where $c$ is the number of classes.
- $\operatorname{SE}(2)$ acts trivially on the range $\mathbb{P}_c$, so we have $(\mathbf{x}, R) p = p$ for all $p \in \mathbb{P}_c$.

Then the classifier is invariant if $\Phi(g f) = g \Phi(f) = \Phi(f)$ for all $g \in \operatorname{SE}(2)$ and images $f: \mathbb{R}^2 \to \mathbb{R}$.

In this problem invariance is clearly a desirable property. But how would we go about constructing an invariant classifier? 

We could train a normal convolutional neural network, and hope that it learns to be invariant. This is highly unlikely, unless we perform _data augmentation_, and even then there are no guarantees.

Alternatively, we could construct a neural network architecture that is inherently invariant. For this, we can make use of the following result:
> **Lemma (Composition of Equivariant Maps)** Let $G$ be a Lie group acting on $U$, $V$, and $W$. Suppose $\Phi: U \to V$ and $\Psi: V \to W$ are equivariant. Then, their composition $\Psi \circ \Phi: U \to W$ is also equivariant. 

_proof_: Simply note that $\Psi \circ \Phi \circ g = \Psi \circ g \circ \Phi = g \circ \Psi \circ \Phi$.
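
As a quick sanity check of the lemma, the following sketch composes two maps that each commute with a rotation of the pixel grid by $\pi / 2$ and confirms that the composition does too. The maps `phi` and `psi` and the periodic boundary handling are our own illustrative choices, not layers from the models in this repository.

```python
import numpy as np

def g(image):
    """Group action used for the check: rotate the pixel grid by 90 degrees."""
    return np.rot90(image)

def phi(image):
    """Point-wise nonlinearity (ReLU); it commutes with any permutation of pixels."""
    return np.maximum(image, 0.0)

def psi(image):
    """Average of the four periodic nearest neighbours. The stencil is
    symmetric under 90 degree rotations, so psi commutes with g."""
    return 0.25 * (
        np.roll(image, 1, axis=0) + np.roll(image, -1, axis=0)
        + np.roll(image, 1, axis=1) + np.roll(image, -1, axis=1)
    )

rng = np.random.default_rng(0)
f = rng.standard_normal((8, 8))

print(np.allclose(phi(g(f)), g(phi(f))))            # phi is equivariant
print(np.allclose(psi(g(f)), g(psi(f))))            # psi is equivariant
print(np.allclose(psi(phi(g(f))), g(psi(phi(f)))))  # and so is psi after phi
```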

Hence, we can make an equivariant neural network architecture by composing equivariant layers. A typical layer in a neural network consists of the composition of something linear (e.g. matrix multiplication, convolution, linear combinations) with something nonlinear (e.g. activation function, normalisation). Common nonlinearities such as the ReLU activation function act point-wise; it is not hard to see that such point-wise operations are equivariant.

In [None]:
x = torch.tensor([0., 0.])
theta = torch.tensor([torch.pi/2]) # On-grid rotation by 90 degrees; try other angles
R = torch.tensor([
    [torch.cos(theta), -torch.sin(theta)],
    [torch.sin(theta), torch.cos(theta)]
])

rotation_before_relu = relu(transform(shoe - 0.5, x, R))
relu_before_rotation = transform(relu(shoe - 0.5), x, R)

fig, ax = plt.subplots(1, 3, figsize=(15, 5))
ax[0].set_axis_off()
ax[1].set_axis_off()
ax[2].set_axis_off()
ax[0].set_title("Rotation before ReLU")
ax[1].set_title("ReLU before rotation")
ax[2].set_title("Difference")
ax[0].imshow(rotation_before_relu.squeeze())
ax[1].imshow(relu_before_rotation.squeeze())
cbar = ax[2].imshow((relu_before_rotation - rotation_before_relu).squeeze())
fig.colorbar(cbar, ax=ax[2]);

We're working with images in this tutorial; convolutions are a natural choice for the linear part. In a convolutional layer an image is convolved/cross-correlated with a trainable filter (Fig. 5.3 from [1]):
![Schematic depiction of a convolution.](content/convolution.png)
Convolutional layers are translation equivariant: if you shift the input image, the output is shifted accordingly (up to boundary effects...):

In [None]:
kernel = torch.randn(1, 1, 3, 3)

k = 1 # number of pixels to shift
padding = k+1 # deal with boundary issues
x_before = torch.tensor([2 * k / (shoe.shape[-1] + 2 * padding), 0.])
x_after = torch.tensor([2 * k / (shoe.shape[-1] + 2 * (padding - 1)), 0.])
theta = torch.tensor([0.])
R = torch.tensor([
    [torch.cos(theta), -torch.sin(theta)],
    [torch.sin(theta), torch.cos(theta)]
])

# Use the shift that matches each image size: the convolution shrinks the padded image by 2 pixels.
shift_before_convolution = conv2d(transform(pad(shoe, 4*[padding], mode="constant"), x_before, R), kernel, torch.tensor([0.]))
convolution_before_shift = transform(conv2d(pad(shoe, 4*[padding], mode="constant"), kernel, torch.tensor([0.])), x_after, R)

fig, ax = plt.subplots(2, 2, figsize=(10, 10))
ax[0, 0].set_axis_off()
ax[0, 1].set_axis_off()
ax[1, 0].set_axis_off()
ax[1, 1].set_axis_off()
ax[0, 0].set_title("Shift before convolution")
ax[0, 1].set_title("Convolution before shift")
ax[1, 0].set_title("Difference")
ax[1, 1].set_title("Convolution kernel")
ax[0, 0].imshow(shift_before_convolution.squeeze())
ax[0, 1].imshow(convolution_before_shift.squeeze())
cbar = ax[1, 0].imshow((convolution_before_shift - shift_before_convolution).squeeze())
fig.colorbar(cbar, ax=ax[1, 0])
ax[1, 1].imshow(kernel.squeeze());

However, convolutional layers typically won't be rotation equivariant.

In [None]:
x = torch.tensor([0., 0.])
theta = torch.tensor([torch.pi/2])
R = torch.tensor([
    [torch.cos(theta), -torch.sin(theta)],
    [torch.sin(theta), torch.cos(theta)]
])

rotation_before_convolution = conv2d(transform(shoe, x, R), kernel, torch.tensor([0.]))
convolution_before_rotation = transform(conv2d(shoe, kernel, torch.tensor([0.])), x, R)

fig, ax = plt.subplots(2, 2, figsize=(10, 10))
ax[0, 0].set_axis_off()
ax[0, 1].set_axis_off()
ax[1, 0].set_axis_off()
ax[1, 1].set_axis_off()
ax[0, 0].set_title("Rotation before convolution")
ax[0, 1].set_title("Convolution before rotation")
ax[1, 0].set_title("Difference")
ax[1, 1].set_title("Convolution kernel")
ax[0, 0].imshow(rotation_before_convolution.squeeze())
ax[0, 1].imshow(convolution_before_rotation.squeeze())
cbar = ax[1, 0].imshow((convolution_before_rotation - rotation_before_convolution).squeeze())
fig.colorbar(cbar, ax=ax[1, 0])
ax[1, 1].imshow(kernel.squeeze());

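Indeed, convolutions are rotation equivariant if and only if the kernel is isotropic. We can achieve this by averaging a normal (anisotropic) kernel over all possible rotations. On the pixel grid, the cell below averages over the four rotations by multiples of $\pi / 2$, which map the grid to itself: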

In [None]:
isotropic_kernel = (
    kernel + 
    torch.rot90(kernel, dims=(-2, -1)) + 
    torch.rot90(kernel, k=-1, dims=(-2, -1)) +
    torch.rot90(torch.rot90(kernel, dims=(-2, -1)), dims=(-2, -1))
) / 4

x = torch.tensor([0., 0.])
theta = torch.tensor([torch.pi/2])
R = torch.tensor([
    [torch.cos(theta), -torch.sin(theta)],
    [torch.sin(theta), torch.cos(theta)]
])

rotation_before_convolution = conv2d(transform(shoe, x, R), isotropic_kernel, torch.tensor([0.]))
convolution_before_rotation = transform(conv2d(shoe, isotropic_kernel, torch.tensor([0.])), x, R)

fig, ax = plt.subplots(2, 2, figsize=(10, 10))
ax[0, 0].set_axis_off()
ax[0, 1].set_axis_off()
ax[1, 0].set_axis_off()
ax[1, 1].set_axis_off()
ax[0, 0].set_title("Rotation before convolution")
ax[0, 1].set_title("Convolution before rotation")
ax[1, 0].set_title("Difference")
ax[1, 1].set_title("Convolution kernel")
ax[0, 0].imshow(rotation_before_convolution.squeeze())
ax[0, 1].imshow(convolution_before_rotation.squeeze())
cbar = ax[1, 0].imshow((convolution_before_rotation - rotation_before_convolution).squeeze())
fig.colorbar(cbar, ax=ax[1, 0])
ax[1, 1].imshow(isotropic_kernel.squeeze());

One limitation of this approach is that we are greatly limiting the number of models that can be expressed with the same number of parameters. For example, if we rotate the initial kernel, then that will lead to the same isotropic kernel, even though the kernel parameters are different. This issue can be addressed by _lifting_.
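
The loss of expressivity can be seen in a few lines. The sketch below (NumPy, with a hypothetical helper `symmetrise` that averages over the four rotations by multiples of $\pi / 2$) shows that a kernel and its rotated copy are different parameter sets that collapse onto the same symmetrised kernel.

```python
import numpy as np

def symmetrise(kernel):
    """Average a square kernel over the four rotations by multiples of 90 degrees."""
    return sum(np.rot90(kernel, k) for k in range(4)) / 4

rng = np.random.default_rng(0)
kernel = rng.standard_normal((3, 3))
rotated = np.rot90(kernel)

# The two kernels have different parameters ...
print(np.allclose(kernel, rotated))
# ... but they are mapped to the same symmetrised kernel,
print(np.allclose(symmetrise(kernel), symmetrise(rotated)))
# which is itself unchanged by a 90 degree rotation.
print(np.allclose(symmetrise(kernel), np.rot90(symmetrise(kernel))))
```

For a $3 \times 3$ kernel, the nine parameters are reduced to three independent values: the centre, the edges, and the corners.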

## Lifting

Up to now, we have seen images as functions that map the plane $\mathbb{R}^2$ to a scalar in $\mathbb{R}$. Roughly speaking, the gray value (or maybe the change thereof) at a given point is a measure of the presence of structure at that location. However, when we look at an image, we can say more. 

In [None]:
cross = torch.zeros(9, 9, 3)
cross[4, 1:8] = 1.
cross[1:8, 4] = 1.
cross[4, 2] = torch.tensor([1., 0., 0.])
cross[2, 4] = torch.tensor([0., 0., 1.])

fig, ax = plt.subplots(1, 1, figsize=(5, 5))
ax.set_axis_off()
ax.imshow(cross);

In this simple image of a cross, we can see that the structure at the red dot is horizontally oriented, while the structure at the blue dot is vertically oriented. In other words, at the position of the red dot and the horizontal orientation, there is structure present, but not at e.g. the position of the red dot and the vertical orientation. Hence, we can ascribe to every position and orientation a gray value; we call this an _orientation score_:
> **Definition (Orientation Score Transform)** Let $\psi : \mathbb{R}^2 \to \mathbb{R}$ be a wavelet. Then the _orientation score transform_ $\mathcal{W}_\psi$ maps an image $f : \mathbb{R}^2 \to \mathbb{R}$ to an _orientation score_ $\mathcal{W}_\psi f : \mathbb{R}^2 \times S^1 \cong \operatorname{SE}(2) \to \mathbb{R}$ via 
> $$ \mathcal{W}_\psi f(\mathbf{x}, \theta) = \int_{\mathbb{R}^2} ((\mathbf{x}, R_\theta) \psi)(\mathbf{y}) f(\mathbf{y}) \mathrm{d} \mathbf{y} = \int_{\mathbb{R}^2} \psi(R_\theta^{-1} (\mathbf{y} - \mathbf{x})) f(\mathbf{y}) \mathrm{d} \mathbf{y}. $$
Classically, we choose the filter wavelet $\psi$ to pick up horizontally oriented features; then the gray scale values in the "image" $\mathcal{W}_\psi f(\cdot, \theta): \mathbb{R}^2 \to \mathbb{R}$ are a measure for the presence of a feature at some position with orientation $\theta$:

In [None]:
Or = 16
cross = torch.zeros(1, 1, 64, 64)
cross[..., 31:34, 8:56] = 1.
cross[..., 8:56, 31:34] = 1.
cakewavelet_lift = LiftM2Cakewavelets(cross.shape[-3], Or, inflection_point=1.)
lifted_cross = cakewavelet_lift(cross)

k = 0
dtheta = 2 * torch.pi / Or
fig, ax = plt.subplots(1, 4, figsize=(20, 5))
ax[0].set_axis_off()
ax[1].set_axis_off()
ax[2].set_axis_off()
ax[3].set_axis_off()
ax[0].set_title("$f$")
ax[1].set_title(fr"$\mathcal{{W}}_\psi f(\cdot , {k * dtheta / torch.pi:.2f} \pi)$")
ax[2].set_title(fr"$\mathcal{{W}}_\psi f(\cdot , {(k+Or//4) * dtheta / torch.pi:.2f} \pi)$")
ax[3].set_title(r"$\sum_{\theta \in S^1} \mathcal{W}_\psi f(\cdot , \theta)$")
ax[0].imshow(cross.squeeze())
ax[1].imshow(lifted_cross[0, 0, k])
ax[2].imshow(lifted_cross[0, 0, k+(Or//4)])
ax[3].imshow(lifted_cross.sum(-3).squeeze());

The orientation score transform is equivariant for any wavelet $\psi$:
> **Lemma (Orientation Score Transform is Equivariant)**
> Let $\psi, f: \mathbb{R}^2 \to \mathbb{R}$, and let $g \in \operatorname{SE}(2)$. Then,
> $$ g (\mathcal{W}_\psi f) = \mathcal{W}_\psi (g f). $$
_proof:_ We can simply rewrite:
$$\begin{split}
(\mathbf{x}, R) (\mathcal{W}_\psi f)(\mathbf{y}, S) & := \mathcal{W}_\psi f ((\mathbf{x}, R)^{-1} (\mathbf{y}, S)) = \int_{\mathbb{R}^2} \big(((\mathbf{x}, R)^{-1} (\mathbf{y}, S)) \psi\big)(\mathbf{z}) f(\mathbf{z}) \mathrm{d} \mathbf{z} = \int_{\mathbb{R}^2} \psi((\mathbf{y}, S)^{-1} (\mathbf{x}, R) \mathbf{z}) f(\mathbf{z}) \mathrm{d} \mathbf{z} \\
& = \int_{\mathbb{R}^2} \psi((\mathbf{y}, S)^{-1} \mathbf{v}) f((\mathbf{x}, R)^{-1} \mathbf{v}) \mathrm{d} (\mathbf{x}, R)^{-1} \mathbf{v} \overset{(1)}{=} \int_{\mathbb{R}^2} \psi((\mathbf{y}, S)^{-1} \mathbf{v}) f((\mathbf{x}, R)^{-1} \mathbf{v}) \mathrm{d} \mathbf{v} \\
& = \mathcal{W}_\psi ((\mathbf{x}, R) f) (\mathbf{y}, S),
\end{split}$$
where we used the $\operatorname{SE}(2)$ invariance of the standard Lebesgue measure in $(1)$.
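
The lemma can also be checked numerically in a discrete toy setting. The sketch below is an assumption-laden stand-in for the orientation score transform: it uses only four orientations (multiples of $\pi / 2$), a random $3 \times 3$ "wavelet", and periodic boundaries, so that everything stays exactly on the grid. The functions `correlate_periodic`, `lift` and `act` are our own names and are unrelated to `LiftM2Cakewavelets`.

```python
import numpy as np

def correlate_periodic(image, kernel):
    """Cross-correlate with a (2r+1)x(2r+1) kernel using periodic boundaries."""
    r = kernel.shape[0] // 2
    out = np.zeros_like(image)
    for a in range(-r, r + 1):
        for b in range(-r, r + 1):
            # np.roll(image, -a)[i] == image[i + a], so this term is
            # kernel(a, b) * image[p + (a, b)].
            out += kernel[a + r, b + r] * np.roll(image, (-a, -b), axis=(0, 1))
    return out

def lift(image, wavelet):
    """Responses to the wavelet rotated by k * 90 degrees, k = 0, 1, 2, 3."""
    return np.stack([correlate_periodic(image, np.rot90(wavelet, k)) for k in range(4)])

def act(score):
    """Rotation by 90 degrees acting on a score of shape (4, N, N):
    rotate every spatial slice and shift the orientation axis by one."""
    return np.roll(np.rot90(score, axes=(1, 2)), 1, axis=0)

rng = np.random.default_rng(0)
f = rng.standard_normal((16, 16))
psi = rng.standard_normal((3, 3))

# Lifting the rotated image equals rotating the lifted image.
print(np.allclose(lift(np.rot90(f), psi), act(lift(f, psi))))
```

Note that the check passes for an arbitrary `psi`, in line with the lemma: equivariance does not require the wavelet to have any symmetry.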

We now have a way of lifting images from $\mathbb{R}^2$ to orientation scores on the group $\operatorname{SE}(2)$. The next step is to find equivariant operators on such orientation scores. It turns out that convolutions can be generalised to group convolutions. 
> **Definition (Group Convolution)** Let $G$ be a Lie group, and let $f, k : G \to \mathbb{R}$ be functions thereon. Then we define the _group convolution_ of $f$ and $k$ as
> $$ (k * f)(g) := \int_G k(h^{-1} g) f(h) \mathrm{d} h. $$

To get a feel for group convolutions, consider the following examples.
> **Example (Convolution on the Translation Group)**
> Let $f, k : \mathbb{R}^2 \to \mathbb{R}$. Then, 
> $$ (k * f)(\mathbf{x}) := \int_{\mathbb{R}^2} k(\mathbf{x} - \mathbf{y}) f(\mathbf{y}) \mathrm{d} \mathbf{y}. $$

> **Example (Convolution on the Special Orthogonal Group)**
> Let $f, k : \operatorname{SO}(2) \cong S^1 \to \mathbb{R}$. Then, 
> $$ (k * f)(\theta) := \int_{\operatorname{SO}(2)} k(S^{-1} R_\theta) f(S) \mathrm{d} S = \int_{S^1} k(\theta - \phi) f(\phi) \mathrm{d} \phi. $$

> **Example (Convolution on the Special Euclidean Group)**
> Let $f, k : \operatorname{SE}(2) \cong \mathbb{R}^2 \times S^1 \to \mathbb{R}$. Then, 
> $$ (k * f)(\mathbf{x}, \theta) := \int_{\operatorname{SE}(2)} k((\mathbf{y}, S)^{-1} (\mathbf{x}, R_\theta)) f(\mathbf{y}, S) \mathrm{d} \mathbf{y} \mathrm{d} S = \int_{\mathbb{R}^2 \times S^1} k(R_{\phi}^{-1} (\mathbf{x} - \mathbf{y}), \theta - \phi) f(\mathbf{y}, \phi) \mathrm{d} \mathbf{y} \mathrm{d} \phi. $$

In the same way that normal convolutions - which are defined on functions on Euclidean space/translation group - are translation equivariant, group convolutions are equivariant to the corresponding group.
> **Lemma (Group Convolutions are Equivariant)**
> Let $G$ be a Lie group, let $f, k : G \to \mathbb{R}$ be functions thereon, and let $r \in G$. Then, 
> $$ r (k * f) = k * (r f). $$
_proof:_ We can simply rewrite:
$$\begin{split}
r (k * f)(g) & := (k * f)(r^{-1} g) := \int_G k(h^{-1} r^{-1} g) f(h) \mathrm{d} h = \int_G k((r h)^{-1} g) f(h) \mathrm{d} h \\
& = \int_G k(q^{-1} g) f(r^{-1} q) \mathrm{d} r^{-1}q \overset{(1)}{=} \int_G k(q^{-1} g) f(r^{-1} q) \mathrm{d} q \\
& =  (k * (r f))(g),
\end{split}$$
where we used the left-invariance of the Haar measure in $(1)$.
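
For a finite group the integral becomes a sum, and the lemma can be verified directly. The sketch below uses the cyclic group $\mathbb{Z}_n$ as a stand-in for $\operatorname{SO}(2)$ sampled at $n$ orientations; this discretisation and the helper names are our own assumptions.

```python
import numpy as np

def cyclic_group_convolution(k, f):
    """(k * f)(theta) = sum_phi k(theta - phi) f(phi) on the cyclic group Z_n."""
    n = len(f)
    return np.array([
        sum(k[(theta - phi) % n] * f[phi] for phi in range(n))
        for theta in range(n)
    ])

def act(r, f):
    """Left action of r in Z_n on a function: (r f)(theta) = f(theta - r)."""
    return np.roll(f, r)

rng = np.random.default_rng(0)
n = 8
f = rng.standard_normal(n)
k = rng.standard_normal(n)

# r (k * f) == k * (r f) for every group element r.
print(all(
    np.allclose(act(r, cyclic_group_convolution(k, f)),
                cyclic_group_convolution(k, act(r, f)))
    for r in range(n)
))
```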

In practice, we usually want our output to be an image again, so we need a way to equivariantly map orientation scores to images. We call this projection. In machine learning, we typically use max projection:
> **Definition (Max Projection)** Let $f : \operatorname{SE}(2) \to \mathbb{R}$. Then we define the _max projection_ of $f$ as
> $$ \operatorname{Proj} f(\mathbf{x}) := \max_{R \in \operatorname{SO}(2)} f(\mathbf{x}, R). $$
It is not hard to see that max projection is indeed equivariant:
$$\begin{split}
(\mathbf{x}, R) (\operatorname{Proj} f)(\mathbf{y}) & := \operatorname{Proj} f((\mathbf{x}, R)^{-1} \mathbf{y}) = \operatorname{Proj} f(R^{-1}(\mathbf{y} - \mathbf{x})) \\
& = \max_{S \in \operatorname{SO}(2)} f(R^{-1}(\mathbf{y} - \mathbf{x}), S) = \max_{S \in \operatorname{SO}(2)} f(R^{-1}(\mathbf{y} - \mathbf{x}), R^{-1} S) = \max_{S \in \operatorname{SO}(2)} (\mathbf{x}, R) f(\mathbf{y}, S) \\
& = (\operatorname{Proj} (\mathbf{x}, R) f)(\mathbf{y}).
\end{split}$$
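
The same computation can be mirrored on arrays. In this sketch an orientation score is stored as an array of shape `(4, N, N)` with four orientations, and a rotation by $\pi / 2$ acts by rotating each spatial slice and shifting the orientation axis; both the storage layout and the function names are illustrative assumptions.

```python
import numpy as np

def act(score):
    """Rotation by 90 degrees on a score of shape (4, N, N)."""
    return np.roll(np.rot90(score, axes=(1, 2)), 1, axis=0)

def max_projection(score):
    """Maximum over the orientation axis: (4, N, N) -> (N, N)."""
    return score.max(axis=0)

rng = np.random.default_rng(0)
score = rng.standard_normal((4, 16, 16))

# Projecting the rotated score equals rotating the projected image.
print(np.array_equal(max_projection(act(score)), np.rot90(max_projection(score))))
```

The shift along the orientation axis has no effect on the result, because the maximum runs over all orientations anyway.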

Hence, one way to make a roto-translation equivariant neural network architecture is as follows:
1. Start with a lifting layer that maps images on $\mathbb{R}^2$ to orientation scores on the group $\operatorname{SE}(2)$. The lifting wavelets $\psi$ are usually trained.
2. Subsequently apply $\operatorname{SE}(2)$ group convolutions and point-wise nonlinearities in alternating fashion.
3. Project the orientation scores back to images using max projection.

Neural networks with this architecture, which mirrors the classical multi-orientation image processing pipeline [2], are called Group equivariant Convolutional Neural Networks (G-CNNs).
![Multi-orientation processing pipeline: first lift with the orientation score transform, then perform equivariant processing on the orientation scores, and finally project back down to an image.](content/multi-orientation_processing.png)

For our classification task, the output should not be an image but a "vector" in $\mathbb{P}_c := \{(p_1, \ldots, p_c) \mid \sum_{i = 1}^c p_i = 1, p_i \geq 0\}$, on which the group $\operatorname{SE}(2)$ acts trivially. We can achieve this by adding additional layers to our network. The output images need to be invariantly converted to numbers, e.g. by taking the maximum over the image, which can then be combined in whatever manner we like. 
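
A minimal sketch of such an invariant head, assuming feature maps of shape `(C, H, W)` and a hypothetical fully connected layer given by a random weight matrix: the global maximum of each feature map does not change under on-grid rotations or periodic shifts, so anything computed from these numbers is invariant as well.

```python
import numpy as np

def global_max_pool(feature_maps):
    """Collapse feature maps of shape (C, H, W) to C numbers by a spatial maximum."""
    return feature_maps.max(axis=(1, 2))

rng = np.random.default_rng(0)
feature_maps = rng.standard_normal((6, 16, 16))
weights = rng.standard_normal((10, 6))  # hypothetical dense layer: 6 invariants -> 10 logits

rotated = np.rot90(feature_maps, axes=(1, 2))          # on-grid rotation
shifted = np.roll(feature_maps, (3, -5), axis=(1, 2))  # periodic translation

logits = weights @ global_max_pool(feature_maps)
print(np.allclose(logits, weights @ global_max_pool(rotated)))
print(np.allclose(logits, weights @ global_max_pool(shifted)))
```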

Similar approaches can be taken for other affine groups, such as the translation-scaling group and the similarity group. One limitation of lifting is that it increases memory use: if we use $K$ discrete orientations, then the orientation score will use $K$ times as much memory as the image. This problem gets worse for higher dimensional groups, such as the similarity group $\operatorname{SIM}(2) := \mathbb{R}^2 \rtimes (\operatorname{SO}(2) \times \mathbb{R}_+)$, the group of translations, rotations and scalings: if we use $K$ discrete orientations and $M$ discrete scales, memory use will be $K \times M$ times as large. In other words, memory use scales exponentially in the dimension of the group. Consequently, it is important to choose only the most important symmetries to integrate into the architecture.
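
To put rough numbers on this, the following back-of-the-envelope helper multiplies out the size of a lifted feature map. The values $K = 8$, $M = 4$ and 4 bytes per value (32-bit floats) are illustrative assumptions; $28 \times 28$ is the size of a FashionMNIST thumbnail.

```python
def lifted_memory_bytes(height, width, channels, orientations=1, scales=1, bytes_per_value=4):
    """Memory of one feature map stack sampled at the given orientations and scales."""
    return height * width * channels * orientations * scales * bytes_per_value

image = lifted_memory_bytes(28, 28, 1)
se2 = lifted_memory_bytes(28, 28, 1, orientations=8)             # K = 8
sim2 = lifted_memory_bytes(28, 28, 1, orientations=8, scales=4)  # K = 8, M = 4

print(image, se2, sim2)            # bytes per channel
print(se2 // image, sim2 // image)  # 8 32
```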

## Generalised Convolutions

In the previous section we introduced Group equivariant CNNs (G-CNNs), which first lift the data, then apply a succession of linear convolutions and (ReLU) point-wise nonlinearities in alternating fashion, and finally project the data back to an image. The alternating of linear and nonlinear operations turns out to be crucial for the expressivity of the model. One way to see this is with the following example:
> **Example (Collapse of Linear Functions)** A composition of linear layers can only ever make a linear model, and so we cannot describe nonlinear models.
> Additionally, there is an "inefficiency" in composing linear layers.
> Let $A, B: \mathbb{R}^n \to \mathbb{R}^n$ be linear. Then, $B A: \mathbb{R}^n \to \mathbb{R}^n$ is also linear. 
> Then $B A$ can be described by $n \times n$ parameters in a matrix, whereas together $A$ and $B$ have $2 \times n \times n$ parameters. 
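
The collapse is easy to see numerically. In the sketch below (random matrices, chosen only for illustration) two linear layers are reproduced exactly by a single matrix with half the parameters, while inserting a ReLU between them gives a map that fails a basic linearity test.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 5
A = rng.standard_normal((n, n))
B = rng.standard_normal((n, n))
x = rng.standard_normal(n)

# Two linear layers equal one linear layer with the matrix B A.
print(np.allclose(B @ (A @ x), (B @ A) @ x))
print(A.size + B.size, (B @ A).size)  # 50 25

def with_relu(v):
    """Two layers with a point-wise ReLU in between."""
    return B @ np.maximum(A @ v, 0.0)

# A linear map L satisfies L(x) + L(-x) = 0; the ReLU network does not.
print(np.allclose(with_relu(x) + with_relu(-x), 0.0))
```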

Our research group developed a generalisation of G-CNNs, in which the (fixed) point-wise nonlinearities are replaced with (trainable) generalised convolutions.

Recall the linear group convolution:
$$ (k * f)(g) := \int_G k(h^{-1} g) f(h) \mathrm{d} h. $$
Essentially, we place the kernel $k$ at the correct location $g$, multiply it point-wise with the function $f$, and then integrate that. In a generalised convolution, we change our notion of multiplication and integration. For example, for _morphological_ convolutions, multiplication becomes addition and integration becomes taking the infimum:
$$ (k \square f)(g) := \inf_{h \in G} (k(h^{-1} g) + f(h)). $$
Morphological convolutions are the core operations in mathematical morphology, and can be used to perform _dilation_ and _erosion_:
$$\begin{align*}
\textrm{Dilation of $f$ by $k$: } & -(k \square -f), \\
\textrm{Erosion of $f$ by $k$: } & (k \square f).
\end{align*}$$
Dilation expands light areas of $f$ according to the kernel $k$, whereas erosion expands the dark areas.
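
Before turning to images, here is a one-dimensional sketch of the morphological convolution on the translation group $\mathbb{R}$, sampled on a grid. It is our own toy implementation, not `morphological_convolution_r2`: the kernel is given on the offsets $-r, \ldots, r$, and the signal is extended at the boundary by repeating its edge values, which is an arbitrary choice.

```python
import numpy as np

def morphological_convolution_1d(k, f):
    """(k □ f)(x) = min_y (k(x - y) + f(y)), with k given on offsets -r..r."""
    r = len(k) // 2
    padded = np.pad(f, r, mode="edge")
    # Row x of `windows` holds f(x - r), ..., f(x + r); reversing k pairs
    # the kernel value k(x - y) with the sample f(y).
    windows = np.lib.stride_tricks.sliding_window_view(padded, len(k))
    return (windows + k[::-1]).min(axis=1)

f = np.array([0., 0., 1., 1., 1., 0., 0.])
k = np.zeros(3)  # flat kernel on offsets -1, 0, 1

eroded = morphological_convolution_1d(k, f)
dilated = -morphological_convolution_1d(k, -f)

print(eroded.astype(int).tolist())   # [0, 0, 0, 1, 0, 0, 0]
print(dilated.astype(int).tolist())  # [0, 1, 1, 1, 1, 1, 0]
```

With the flat kernel, erosion is a running minimum and dilation a running maximum: the light plateau shrinks by one sample on each side under erosion and grows by one sample under dilation.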

In [None]:
begin = 10
end = 35
data = torch.zeros(1, 1, 64, 64)
data[0, 0, begin:end, begin:end] = 1.
data[0, 0, -end:-begin, -end:-begin] = 1.
kernel = morphological_kernel_r2_isotropic(torch.tensor([0.2]), 4, 0.65)
dilated = -morphological_convolution_r2(-data, kernel)
eroded = morphological_convolution_r2(data, kernel)

fig, ax = plt.subplots(2, 2, figsize=(10, 10))
ax[0, 0].set_axis_off()
ax[0, 1].set_axis_off()
ax[1, 0].set_axis_off()
ax[1, 1].set_axis_off()
ax[0, 0].set_title("Input $f$")
ax[0, 1].set_title("Kernel $k$")
ax[1, 0].set_title("Dilated $-(k □ -f)$")
ax[1, 1].set_title("Eroded $k □ f$")
ax[0, 0].imshow(data.squeeze())
cbar = ax[0, 1].imshow(kernel.squeeze())
fig.colorbar(cbar, ax=ax[0, 1])
ax[1, 0].imshow(dilated.squeeze())
ax[1, 1].imshow(eroded.squeeze());

We can even go a step further, by using parametrised kernels such that the convolutions solve classical image processing PDEs. Such models are called PDE-based Group equivariant CNNs (PDE-G-CNNs) [3]. Here are some of the PDEs that are currently available in LieTorch, which is our implementation of PDE-G-CNNs.
> **Example (Diffusion)** Diffusion, that is 
> $$\partial_t U = \Delta U,$$
> is solved by a linear group convolution with the heat kernel, which is determined by a small number of parameters.

> **Example (Dilation)** The dilation PDE,
> $$\partial_t U = \Vert \nabla U \Vert^\alpha,$$
> is solved by a morphological group convolution with a kernel that is determined by a small number of parameters.

> **Example (Erosion)** The erosion PDE,
> $$\partial_t U = -\Vert \nabla U \Vert^\alpha,$$
> is solved by a morphological group convolution with a kernel that is again determined by a small number of parameters.

> **Example (Convection)** The convection PDE,
> $$\partial_t U = -\mathcal{A} U,$$
> where $\mathcal{A}$ is a left-invariant vector field, is solved by a linear group convolution with a correctly placed delta peak. 
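
As a toy illustration of the convection example, consider the one-dimensional periodic analogue $\partial_t U = -c \, \partial_x U$, whose solution at time $t$ is the initial condition shifted by $c t$. The sketch below, which is our own simplification and not the LieTorch implementation, checks that at $t = 1$ with an integer velocity $c$ this shift is exactly a linear convolution with a delta peak placed at $c$.

```python
import numpy as np

def periodic_convolution(k, f):
    """(k * f)(x) = sum_y k(x - y) f(y) on a periodic grid with n points."""
    n = len(f)
    return np.array([sum(k[(x - y) % n] * f[y] for y in range(n)) for x in range(n)])

n, c = 12, 3  # grid size and velocity in grid points per unit time (illustrative)
f = np.sin(2 * np.pi * np.arange(n) / n)

delta = np.zeros(n)
delta[c] = 1.0  # delta peak placed at the displacement c

transported = periodic_convolution(delta, f)

# Convolving with the shifted delta transports the signal: U(x, 1) = f(x - c).
print(np.allclose(transported, np.roll(f, c)))
```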

Here is a brief overview comparing a typical layer in a (G-)CNN with a PDE-G-CNN layer. The PDE evolution consists of a composition of a number of PDEs, for example Convection-Dilation-Erosion (CDE).

![Comparison of a traditional CNN layer and a PDE-G-CNN layer.](content/layer_comparison.webp)

It has been shown in various works that PDE-G-CNNs can achieve performance competitive with e.g. CNNs, but with a large reduction in the number of model parameters. Additionally, they tend to be more data efficient, so they can be applied in situations where data is not abundant.

## Application

This repository contains two trained models: a CNN and a roto-translation invariant PDE-G-CNN. The architectures can be found in the [`models.py`](models.py) module. The models were trained using the [`train.py`](train.py) script.

The CNN is based on the classic LeNet-5 architecture. It consists of two convolutional layers with max-pooling and ReLU activation function, which are (approximately) translation equivariant, followed by three fully connected layers. Notably, the first fully connected layer is _not_ translation invariant, and we also do not have guaranteed translation invariance on the whole model.

For our invariant network, we use a PDE-G-CNN. The first layer lifts the data. Then there are two layers of Convection-Dilation-Erosion PDE evolution, followed by a max projection layer. This leaves us with a collection of images. If we now immediately used a fully connected layer, that would break the equivariance. Hence, we max pool over each image, creating individual invariants. These can safely be used in a fully connected layer.

In [None]:
x = torch.tensor([0., 0.])
theta_on_grid = torch.tensor([torch.pi / 2])
R_on_grid = torch.tensor([
    [torch.cos(theta_on_grid), -torch.sin(theta_on_grid)],
    [torch.sin(theta_on_grid), torch.cos(theta_on_grid)]
])
theta_off_grid = torch.tensor([-0.682354])
R_off_grid = torch.tensor([
    [torch.cos(theta_off_grid), -torch.sin(theta_off_grid)],
    [torch.sin(theta_off_grid), torch.cos(theta_off_grid)]
])

shoe_on_grid = transform(shoe, x, R_on_grid)
shoe_off_grid = transform(shoe, x, R_off_grid)

fig, ax = plt.subplots(3, 3, figsize=(15, 15))
ax[0, 0].set_axis_off()
ax[0, 1].set_axis_off()
ax[0, 2].set_axis_off()
ax[0, 0].set_title("Original")
ax[0, 1].set_title("On-grid rotation")
ax[0, 2].set_title("Off-grid rotation")
ax[0, 0].imshow(shoe.squeeze())
ax[0, 1].imshow(shoe_on_grid.squeeze())
ax[0, 2].imshow(shoe_off_grid.squeeze())
plot_classification(shoe, pdegcnn, "PDE-G-CNN", ax[1, 0])
plot_classification(shoe_on_grid, pdegcnn, "PDE-G-CNN", ax[1, 1])
plot_classification(shoe_off_grid, pdegcnn, "PDE-G-CNN", ax[1, 2])
plot_classification(shoe, cnn, "CNN", ax[2, 0])
plot_classification(shoe_on_grid, cnn, "CNN", ax[2, 1])
plot_classification(shoe_off_grid, cnn, "CNN", ax[2, 2])

This example of an "Ankle boot" is in its canonical orientation correctly classified by both the CNN and the equivariant PDE-G-CNN. However, when we rotate by $\pi / 2$, the output of the PDE-G-CNN remains unchanged, while now the CNN confidently predicts the wrong class label. Rotations by $k \cdot \pi / 2$ are nice because they map the grid to itself. With other rotations, interpolation is necessary. Because of this, even the PDE-G-CNN is not exactly equivariant to off-grid rotations. However, as we can see in this example, it still fairly confidently predicts the right class, whereas the CNN yet again confidently predicts the wrong class.

# References

<a id="Smets2024GeometricProcessing"></a>[1] B.M.N. Smets. Geometric Partial Differential Equations in Deep Learning and Image Processing (2024). <https://research.tue.nl/en/publications/geometric-partial-differential-equations-in-deep-learning-and-ima>

```bib
@phdthesis{smets2024geometric,
  title={Geometric Partial Differential Equations in Deep Learning and Image Processing},
  author={Smets, Bart M.N.},
  year={2024},
  isbn={978-90-386-6133-9},
}
```

<a id="Sherry2025DiffusionSpace"></a>[2] F.M. Sherry, K. Schaefer, R. Duits. Diffusion-Shock PDEs for Deep Learning on Position-Orientation Space. arXiv preprint (2025). <https://doi.org/10.48550/arXiv.2509.06405>

```bib
@article{sherry2025diffusion,
  title={Diffusion-Shock PDEs for Deep Learning on Position-Orientation Space},
  author={Sherry, F.M. and Schaefer, K. and Duits, R.},
  journal={arXiv preprint},
  year={2025},
  doi={10.48550/arXiv.2509.06405},
}
```

<a id="Smets2022PDENetworks"></a>[3] B.M.N. Smets, J. Portegies, E.J. Bekkers, R. Duits. PDE-Based Group Equivariant Convolutional Neural Networks. J Math Imaging Vis (2022). <https://doi.org/10.1007/s10851-022-01114-x>

```bib
@article{smets2022pde,
  title={PDE-based Group Equivariant Convolutional Neural Networks},
  author={Smets, Bart M.N. and Portegies, Jim and Bekkers, Erik J. and Duits, Remco},
  journal={Journal of Mathematical Imaging and Vision},
  publisher={Springer},
  year={2022},
  doi={10.1007/s10851-022-01114-x},
}
```