# E(n)-Equivariant Steerable CNNs  -  A concrete example


In [36]:
import sys
sys.path.append('../')

import torch

from escnn import gspaces
from escnn import nn

Finally, we build a **Steerable CNN** and try it on MNIST.

Let's also use a group a bit larger: we now build a model equivariant to $8$ rotations.
We indicate the group of $N$ discrete rotations as $C_N$, i.e. the **cyclic group** of order $N$.
In this case, we will use $C_8$.

Because the inputs are still gray-scale images, the input type of the model is again a *scalar field*.

However, internally we use *regular fields*: this is equivalent to a *group-equivariant convolutional neural network*.

Finally, we build *invariant* features for the final classification task by pooling over the group using *Group Pooling*.

The final classification is performed by a two fully connected layers.

# The model

Here is the definition of our model:

In [58]:
class C8SteerableCNN(torch.nn.Module):
    
    def __init__(self, n_classes=10):
        
        super(C8SteerableCNN, self).__init__()
        
        # the model is equivariant under rotations by 45 degrees, modelled by C8
        self.r2_act = gspaces.rot2dOnR2(N=8)
        
        # the input image is a scalar field, corresponding to the trivial representation
        in_type = nn.FieldType(self.r2_act, [self.r2_act.trivial_repr])
        
        # we store the input type for wrapping the images into a geometric tensor during the forward pass
        self.input_type = in_type
        
        # convolution 1
        # first specify the output type of the convolutional layer
        # we choose 24 feature fields, each transforming under the regular representation of C8
        out_type = nn.FieldType(self.r2_act, 24*[self.r2_act.regular_repr])
        self.block1 = nn.SequentialModule(
            nn.MaskModule(in_type, 29, margin=1), # rotating a square image causes corners and edges to move into previously empty regions. This makes my image inherently not symmetric. Hence masking a circle around it to bring the symmetry back
            nn.R2Conv(in_type, out_type, kernel_size=7, padding=1, bias=False),
            nn.InnerBatchNorm(out_type), # equivariant version of Batch Normalization. It normalizes all channels together
            nn.ReLU(out_type, inplace=True)
        )
        
        # convolution 2
        # the old output type is the input type to the next layer
        in_type = self.block1.out_type
        # the output type of the second convolution layer are 48 regular feature fields of C8
        out_type = nn.FieldType(self.r2_act, 48*[self.r2_act.regular_repr])
        self.block2 = nn.SequentialModule(
            nn.R2Conv(in_type, out_type, kernel_size=5, padding=2, bias=False),
            nn.InnerBatchNorm(out_type),
            nn.ReLU(out_type, inplace=True)
        )
        self.pool1 = nn.SequentialModule(
            nn.PointwiseAvgPoolAntialiased(out_type, sigma=0.66, stride=2)
        ) # pooling is invariant
        
        # convolution 3
        # the old output type is the input type to the next layer
        in_type = self.block2.out_type
        # the output type of the third convolution layer are 48 regular feature fields of C8
        out_type = nn.FieldType(self.r2_act, 48*[self.r2_act.regular_repr])
        self.block3 = nn.SequentialModule(
            nn.R2Conv(in_type, out_type, kernel_size=5, padding=2, bias=False),
            nn.InnerBatchNorm(out_type),
            nn.ReLU(out_type, inplace=True)
        )
        
        # convolution 4
        # the old output type is the input type to the next layer
        in_type = self.block3.out_type
        # the output type of the fourth convolution layer are 96 regular feature fields of C8
        out_type = nn.FieldType(self.r2_act, 96*[self.r2_act.regular_repr])
        self.block4 = nn.SequentialModule(
            nn.R2Conv(in_type, out_type, kernel_size=5, padding=2, bias=False),
            nn.InnerBatchNorm(out_type),
            nn.ReLU(out_type, inplace=True)
        )
        self.pool2 = nn.SequentialModule(
            nn.PointwiseAvgPoolAntialiased(out_type, sigma=0.66, stride=2)
        )
        
        # convolution 5
        # the old output type is the input type to the next layer
        in_type = self.block4.out_type
        # the output type of the fifth convolution layer are 96 regular feature fields of C8
        out_type = nn.FieldType(self.r2_act, 96*[self.r2_act.regular_repr])
        self.block5 = nn.SequentialModule(
            nn.R2Conv(in_type, out_type, kernel_size=5, padding=2, bias=False),
            nn.InnerBatchNorm(out_type),
            nn.ReLU(out_type, inplace=True)
        )
        
        # convolution 6
        # the old output type is the input type to the next layer
        in_type = self.block5.out_type
        # the output type of the sixth convolution layer are 64 regular feature fields of C8
        out_type = nn.FieldType(self.r2_act, 64*[self.r2_act.regular_repr])
        self.block6 = nn.SequentialModule(
            nn.R2Conv(in_type, out_type, kernel_size=5, padding=1, bias=False),
            nn.InnerBatchNorm(out_type),
            nn.ReLU(out_type, inplace=True)
        )
        self.pool3 = nn.PointwiseAvgPoolAntialiased(out_type, sigma=0.66, stride=1, padding=0)
        
        self.gpool = nn.GroupPooling(out_type)
        
        # number of output channels
        c = self.gpool.out_type.size
        
        # Fully Connected
        self.fully_net = torch.nn.Sequential(
            torch.nn.Linear(c, 64),
            torch.nn.BatchNorm1d(64),
            torch.nn.ELU(inplace=True),
            torch.nn.Linear(64, n_classes),
        )
    
    def forward(self, input: torch.Tensor):
        # wrap the input tensor in a GeometricTensor
        # (associate it with the input type)
        x = nn.GeometricTensor(input, self.input_type)
        
        # apply each equivariant block
        
        # Each layer has an input and an output type
        # A layer takes a GeometricTensor in input.
        # This tensor needs to be associated with the same representation of the layer's input type
        #
        # The Layer outputs a new GeometricTensor, associated with the layer's output type.
        # As a result, consecutive layers need to have matching input/output types
        print(x.shape)
        x = self.block1(x)
        x = self.block2(x)
        x = self.pool1(x)
        
        x = self.block3(x)
        x = self.block4(x)
        x = self.pool2(x)
        
        x = self.block5(x)
        x = self.block6(x)
        
        # pool over the spatial dimensions
        print(x.shape)
        x = self.pool3(x)
        
        # pool over the group
        print(x.shape)
        
        x = self.gpool(x)
        print(x.shape)
        a=1/0


        # unwrap the output GeometricTensor
        # (take the Pytorch tensor and discard the associated representation)
        x = x.tensor
        
        # classify with the final fully connected layers)
        x = self.fully_net(x.reshape(x.shape[0], -1))
        
        return x


In [26]:
class CNN(torch.nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = torch.nn.Sequential(         
            torch.nn.Conv2d(
                in_channels=1,              
                out_channels=16,            
                kernel_size=3, # prob drop this to 3 (which is normally done)
                stride=1,                   
                padding=1, # not much of a benefit + computationally efficient w/out, but if it was 0, the dimensions would reduce to 26x26
            ),
            torch.nn.BatchNorm2d(16),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.conv2 = torch.nn.Sequential(         
            torch.nn.Conv2d(16, 32, 5, 1, 2),     
            torch.nn.ReLU(),                      
            torch.nn.MaxPool2d(kernel_size=2, stride=2),                
        )
        # fully connected layer, output 10 classes
        self.fc = torch.nn.Sequential(
            torch.nn.Dropout(0.5),
            torch.nn.Linear(32 * 7 * 7, 128),
            torch.nn.ReLU(),
            torch.nn.Linear(128, 10) # reduce number of nodes
        )
    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        # flatten the output of conv2 to (batch_size, 32 * 7 * 7)
        x = x.view(x.size(0), -1)
        output = self.fc(x)
        return output, x    # return x for visualization

Let's try the model on *rotated* MNIST

In [4]:
# download the dataset
!wget -nc http://www.iro.umontreal.ca/~lisa/icml2007data/mnist_rotation_new.zip
# uncompress the zip file
!unzip -n mnist_rotation_new.zip -d mnist_rotation_new

zsh:1: command not found: wget
unzip:  cannot find or open mnist_rotation_new.zip, mnist_rotation_new.zip.zip or mnist_rotation_new.zip.ZIP.


In [24]:
from torch.utils.data import Dataset
from torchvision.transforms import RandomRotation
from torchvision.transforms import Pad
from torchvision.transforms import Resize
from torchvision.transforms import ToTensor
from torchvision.transforms import Compose
from torchvision.transforms import InterpolationMode

import numpy as np

from PIL import Image

# Device configuration
comp = 'cpu'
if torch.cuda.is_available():
    comp = 'cuda' # nvidia gpu parallelization
elif torch.backends.mps.is_available():
    comp = 'mps' # mac Metal Performance Shaders (high performance gpu)

device = torch.device(comp)
device


device(type='mps')

Build the dataset

In [6]:
class MnistRotDataset(Dataset):
    
    def __init__(self, mode, transform=None):
        assert mode in ['train', 'test']
            
        if mode == "train":
            file = "mnist_rotation_new/mnist_all_rotation_normalized_float_train_valid.amat"
        else:
            file = "mnist_rotation_new/mnist_all_rotation_normalized_float_test.amat"
        
        self.transform = transform

        data = np.loadtxt(file, delimiter=' ')
            
        self.images = data[:, :-1].reshape(-1, 28, 28).astype(np.float32)
        self.labels = data[:, -1].astype(np.int64)
        self.num_samples = len(self.labels)
    
    def __getitem__(self, index):
        image, label = self.images[index], self.labels[index]
        image = Image.fromarray(image, mode='F')
        if self.transform is not None:
            image = self.transform(image)
        return image, label
    
    def __len__(self):
        return len(self.labels)

# Padding: images are 28x28; we pad the right and bottom with 1 pixel; images are padded to have shape 29x29.
# this allows to use odd-size filters with stride 2 when downsampling a feature map in the model
pad = Pad((0, 0, 1, 1), fill=0)

# to reduce interpolation artifacts (e.g. when testing the model on rotated images),
# we upsample an image by a factor of 3, rotate it and finally downsample it again (resize1)
# 1. Resize from 28x28 to 87x87 (3x upsampling)
resize1 = Resize(87)
# 2. Resize back down to 29x29
resize2 = Resize(29)

# applies random rotation
# 180.  ->  Random rotation from -180° to +180°
# interpolation=InterpolationMode.BILINEAR  ->  Smooths pixels during rotation, reducing aliasing artifacts
# expand=False  ->   Keeps the output image the same size (may crop corners)
randomRotation = RandomRotation(180., interpolation=InterpolationMode.BILINEAR, expand=False)

totensor = ToTensor()

Let's build the model

In [59]:
model = C8SteerableCNN().to(device)
# model = CNN().to(device)

The model is now randomly initialized. 
Therefore, we do not expect it to produce the right class probabilities.

However, the model should still produce the same output for rotated versions of the same image.
This is true for rotations by multiples of $\frac{\pi}{2}$, but is only approximate for rotations by $\frac{\pi}{4}$.

Let's test it on a random test image:
we feed eight rotated versions of the first image in the test set and print the output logits of the model for each of them.

In [53]:

def test_model(model: torch.nn.Module, x: Image):
    np.set_printoptions(linewidth=10000)
    
    # evaluate the `model` on 8 rotated versions of the input image `x`
    model.eval()
    
    x = resize1(pad(x))
    
    print()
    print('##########################################################################################')
    header = 'angle |  ' + '  '.join(["{:6d}".format(d) for d in range(10)])
    print(header)
    with torch.no_grad():
        for r in range(8):
            x_transformed = totensor(resize2(x.rotate(r*45., Image.BILINEAR))).reshape(1, 1, 29, 29)
            x_transformed = x_transformed.to(device)

            y = model(x_transformed)
            y = y.to('cpu').numpy().squeeze()
            
            angle = r * 45
            print("{:5d} : {}".format(angle, y))
    print('##########################################################################################')
    print()


In [40]:
# build the test set    
raw_mnist_test = MnistRotDataset(mode='test')

In [60]:
# retrieve the first image from the test set
x, y = next(iter(raw_mnist_test))

# evaluate the model
test_model(model, x)


##########################################################################################
angle |       0       1       2       3       4       5       6       7       8       9
torch.Size([1, 1, 29, 29])
torch.Size([1, 512, 5, 5])
torch.Size([1, 512, 1, 1])
torch.Size([1, 64, 1, 1])


ZeroDivisionError: division by zero

The output of the model is already almost invariant.
However, we still observe small fluctuations in the outputs.

This is because the model contains some operations which might break equivariance.
For instance, every convolution includes a padding of $2$ pixels per side. This is adds information about the actual orientation of the grid where the image/feature map is sampled because the padding is not rotated with the image. 

During training, the model will observe rotated patterns and will learn to ignore the noise coming from the padding.

So, let's train the model now.
The model is exactly the same used to train a normal *PyTorch* architecture:

In [17]:
train_transform = Compose([
    pad,
    resize1,
    randomRotation,
    resize2,
    totensor,
])

mnist_train = MnistRotDataset(mode='train', transform=train_transform)
train_loader = torch.utils.data.DataLoader(mnist_train, batch_size=64)


test_transform = Compose([
    pad,
    totensor,
])
mnist_test = MnistRotDataset(mode='test', transform=test_transform)
test_loader = torch.utils.data.DataLoader(mnist_test, batch_size=64)

loader = {
    "train": train_loader,
    "test": test_loader
}

loss_function = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=5e-5, weight_decay=1e-5)

In [42]:
# Hady
import torch.nn.functional as F

# Function to test the model using test data
def evaluate(cnn, test_loader):
    cnn.eval()  # Set model to evaluation mode
    total_test_loss = 0
    correct_test = 0
    total_test = 0
    
    with torch.no_grad():  # Disable gradient calculation
        for images, labels in test_loader["test"]:
            images = images.to(device)
            labels = labels.to(device)
            outputs = cnn(images)[0]

            # Apply softmax to get probabilities
            probabilities = F.softmax(outputs, dim=1)

            loss = loss_function(outputs, labels)
            total_test_loss += loss.item()
            
            # Predictions for test accuracy
            _, predicted = torch.max(probabilities, 1)
            print(probabilities)
            total_test += labels.size(0)
            correct_test += (predicted == labels).sum().item()
    
    avg_test_loss = total_test_loss / len(test_loader)
    test_accuracy = 100 * correct_test / total_test
    return avg_test_loss, test_accuracy

test_loss, test_accuracy = evaluate(model, loader)

print(f'Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.2f}%')
#1:13 secs

IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1)

In [44]:
# Hady added
import torch.nn.functional as F
import matplotlib as plt

num_epochs = 10

def train_and_evaluate(num_epochs, cnn, loaders):
    # Lists to store loss and accuracy for each epoch
    train_losses = []
    test_losses = []
    train_accuracies = []
    test_accuracies = []
    
    for epoch in range(num_epochs):
        cnn.train()  # Set model to training mode
        
        total_train_loss = 0
        correct_train = 0
        total_train = 0
        
        # Training loop
        for images, labels in loaders['train']:  
            images = images.to(device)
            labels = labels.to(device)
            # Forward pass
            output = cnn(images)[0]
            print(output.shape)

            # Apply softmax to get probabilities
            probabilities = F.softmax(output, dim=1)

            loss = loss_function(output, labels)
            total_train_loss += loss.item()
            
            # Predictions for training accuracy
            _, predicted = torch.max(probabilities, 1)
            total_train += labels.size(0)
            correct_train += (predicted == labels).sum().item()
            
            # Backpropagation and optimization
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        
        # Calculate training metrics
        avg_train_loss = total_train_loss / len(loaders['train'])
        train_accuracy = 100 * correct_train / total_train
        train_losses.append(avg_train_loss)
        train_accuracies.append(train_accuracy)
        
        # Evaluate on test data
        test_loss, test_accuracy = evaluate(cnn, loaders['test'])
        test_losses.append(test_loss)
        test_accuracies.append(test_accuracy)
        
        print(f'Epoch [{epoch + 1}/{num_epochs}], '
              f'Train Loss: {avg_train_loss:.4f}, Train Accuracy: {train_accuracy:.2f}%, '
              f'Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.2f}%')
    
    # Plot training and test accuracy
    plt.figure(figsize=(10, 5))
    plt.plot(range(1, num_epochs + 1), train_accuracies, label='Training Accuracy')
    plt.plot(range(1, num_epochs + 1), test_accuracies, label='Test Accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy (%)')
    plt.title('Training vs. Test Accuracy')
    plt.legend()
    plt.show()
    
    # Plot training and test loss
    plt.figure(figsize=(10, 5))
    plt.plot(range(1, num_epochs + 1), train_losses, label='Training Loss')
    plt.plot(range(1, num_epochs + 1), test_losses, label='Test Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training vs. Test Loss')
    plt.legend()
    plt.show()

# Call the function
train_and_evaluate(num_epochs, model, loader)


# NOTE: If loss is bouncing around a lot, that means the lr is too high
# NOTE: train doesnt always show improvements that test does. Hence running a test every epoch and comparing the 

# Overfitting: What it looks like
# train accuracy will increase, but test accuracy decreases

# Underfitting: What it looks like
# both train and test accuracy plateau early

torch.Size([10])


IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1)

In [49]:
max_batches = 25
for epoch in range(11):
    print(epoch)
    model.train()
    print(len(train_loader))
    for i, (x, t) in enumerate(train_loader):
        # if i >= max_batches:
        #     break
        print("inner", i)
        
        optimizer.zero_grad()

        x = x.to(device)
        t = t.to(device)

        y = model(x)

        loss = loss_function(y, t)

        loss.backward()

        optimizer.step()
    
    if epoch % 10 == 0:
        total = 0
        correct = 0
        with torch.no_grad():
            model.eval()
            print(len(test_loader))
            for i, (x, t) in enumerate(test_loader):
                # if i >= max_batches:
                #     break
                print("inner test", i)

                x = x.to(device)
                t = t.to(device)
                
                y = model(x)

                _, prediction = torch.max(y.data, 1)
                total += t.shape[0]
                correct += (prediction == t).sum().item()
        print(f"epoch {epoch} | test accuracy: {correct/total*100.}")


0
188
inner 0
inner 1
inner 2
inner 3
inner 4
inner 5
inner 6
inner 7
inner 8
inner 9
inner 10
inner 11
inner 12
inner 13
inner 14
inner 15
inner 16
inner 17
inner 18
inner 19
inner 20
inner 21
inner 22
inner 23
inner 24
inner 25
inner 26
inner 27
inner 28
inner 29
inner 30
inner 31
inner 32
inner 33
inner 34
inner 35
inner 36
inner 37
inner 38
inner 39
inner 40
inner 41
inner 42
inner 43


KeyboardInterrupt: 

In [13]:
    
# retrieve the first image from the test set
x, y = next(iter(raw_mnist_test))


# evaluate the model
test_model(model, x)


##########################################################################################
angle |       0       1       2       3       4       5       6       7       8       9
    0 : [-0.0811 -0.5303 -2.0083 -1.6987 -0.951  -3.4474  8.7345 -1.9093 -2.324  -0.8096]
   45 : [-0.0796 -0.3873 -1.6391 -1.8317 -0.5286 -3.4706  8.6717 -2.3546 -2.0176 -0.8092]
   90 : [-0.0811 -0.5303 -2.0083 -1.6987 -0.951  -3.4474  8.7345 -1.9093 -2.324  -0.8096]
  135 : [-0.0796 -0.3873 -1.6391 -1.8317 -0.5286 -3.4706  8.6717 -2.3546 -2.0176 -0.8092]
  180 : [-0.0811 -0.5303 -2.0083 -1.6987 -0.951  -3.4474  8.7345 -1.9093 -2.324  -0.8096]
  225 : [-0.0796 -0.3873 -1.6391 -1.8317 -0.5286 -3.4706  8.6717 -2.3546 -2.0176 -0.8092]
  270 : [-0.0811 -0.5303 -2.0083 -1.6987 -0.951  -3.4474  8.7345 -1.9093 -2.324  -0.8096]
  315 : [-0.0796 -0.3873 -1.6391 -1.8317 -0.5286 -3.4706  8.6717 -2.3546 -2.0176 -0.8092]
##########################################################################################

