In [1]:
%load_ext autoreload
%autoreload 2

import torch

from e2cnn import gspaces
from e2cnn import nn
from tqdm import tqdm
import numpy as np

from tutorial import MnistRotDataset, test_model, device

In [2]:
download_data = False

if download_data:
    # download the dataset
    !wget -nc http://www.iro.umontreal.ca/~lisa/icml2007data/mnist_rotation_new.zip
    # uncompress the zip file
    !unzip -n mnist_rotation_new.zip -d mnist_rotation_new

In [3]:
# build the test set    
raw_mnist_test = MnistRotDataset(mode='test')

# retrieve the first image from the test set
x, y = next(iter(raw_mnist_test))

In [4]:
class C8SteerableCNN(torch.nn.Module):
    
    def __init__(self, n_classes=10):
        
        super(C8SteerableCNN, self).__init__()
        
        # the model is equivariant under rotations by 45 degrees, modelled by C8
        self.r2_act = gspaces.Rot2dOnR2(N=8)
        
        # the input image is a scalar field, corresponding to the trivial representation
        in_type = nn.FieldType(self.r2_act, [self.r2_act.trivial_repr])
        
        # we store the input type for wrapping the images into a geometric tensor during the forward pass
        self.input_type = in_type
        
        # convolution 1
        # first specify the output type of the convolutional layer
        # we choose 24 feature fields, each transforming under the regular representation of C8
        out_type = nn.FieldType(self.r2_act, 24*[self.r2_act.regular_repr])
        self.block1 = nn.SequentialModule(
            nn.MaskModule(in_type, 29, margin=1),
            nn.R2Conv(in_type, out_type, kernel_size=7, padding=1, bias=False),
            nn.InnerBatchNorm(out_type),
            nn.ReLU(out_type, inplace=True)
        )
        
        # convolution 2
        # the old output type is the input type to the next layer
        in_type = self.block1.out_type
        # the output type of the second convolution layer are 48 regular feature fields of C8
        out_type = nn.FieldType(self.r2_act, 48*[self.r2_act.regular_repr])
        self.block2 = nn.SequentialModule(
            nn.R2Conv(in_type, out_type, kernel_size=5, padding=2, bias=False),
            nn.InnerBatchNorm(out_type),
            nn.ReLU(out_type, inplace=True)
        )
        self.pool1 = nn.SequentialModule(
            nn.PointwiseAvgPoolAntialiased(out_type, sigma=0.66, stride=2)
        )
        
        # convolution 3
        # the old output type is the input type to the next layer
        in_type = self.block2.out_type
        # the output type of the third convolution layer are 48 regular feature fields of C8
        out_type = nn.FieldType(self.r2_act, 48*[self.r2_act.regular_repr])
        self.block3 = nn.SequentialModule(
            nn.R2Conv(in_type, out_type, kernel_size=5, padding=2, bias=False),
            nn.InnerBatchNorm(out_type),
            nn.ReLU(out_type, inplace=True)
        )
        
        # convolution 4
        # the old output type is the input type to the next layer
        in_type = self.block3.out_type
        # the output type of the fourth convolution layer are 96 regular feature fields of C8
        out_type = nn.FieldType(self.r2_act, 96*[self.r2_act.regular_repr])
        self.block4 = nn.SequentialModule(
            nn.R2Conv(in_type, out_type, kernel_size=5, padding=2, bias=False),
            nn.InnerBatchNorm(out_type),
            nn.ReLU(out_type, inplace=True)
        )
        self.pool2 = nn.SequentialModule(
            nn.PointwiseAvgPoolAntialiased(out_type, sigma=0.66, stride=2)
        )
        
        # convolution 5
        # the old output type is the input type to the next layer
        in_type = self.block4.out_type
        # the output type of the fifth convolution layer are 96 regular feature fields of C8
        out_type = nn.FieldType(self.r2_act, 96*[self.r2_act.regular_repr])
        self.block5 = nn.SequentialModule(
            nn.R2Conv(in_type, out_type, kernel_size=5, padding=2, bias=False),
            nn.InnerBatchNorm(out_type),
            nn.ReLU(out_type, inplace=True)
        )
        
        # convolution 6
        # the old output type is the input type to the next layer
        in_type = self.block5.out_type
        # the output type of the sixth convolution layer are 64 regular feature fields of C8
        out_type = nn.FieldType(self.r2_act, 64*[self.r2_act.regular_repr])
        self.block6 = nn.SequentialModule(
            nn.R2Conv(in_type, out_type, kernel_size=5, padding=1, bias=False),
            nn.InnerBatchNorm(out_type),
            nn.ReLU(out_type, inplace=True)
        )
        self.pool3 = nn.PointwiseAvgPoolAntialiased(out_type, sigma=0.66, stride=1, padding=0)
        
        self.gpool = nn.GroupPooling(out_type)
        
        # number of output channels
        c = self.gpool.out_type.size
        
        # Fully Connected
        self.fully_net = torch.nn.Sequential(
            torch.nn.Linear(c, 64),
            torch.nn.BatchNorm1d(64),
            torch.nn.ELU(inplace=True),
            torch.nn.Linear(64, n_classes),
        )
    
    def forward(self, input: torch.Tensor):
        # wrap the input tensor in a GeometricTensor
        # (associate it with the input type)
        print('\nInput')
        print(input.shape)
        x = nn.GeometricTensor(input, self.input_type)
        print(x.shape)
        
        # apply each equivariant block
        
        # Each layer has an input and an output type
        # A layer takes a GeometricTensor in input.
        # This tensor needs to be associated with the same representation of the layer's input type
        #
        # The Layer outputs a new GeometricTensor, associated with the layer's output type.
        # As a result, consecutive layers need to have matching input/output types
        print('\nBlock A')
        print(x.shape)
        x = self.block1(x)
        print(x.shape)
        x = self.block2(x)
        print(x.shape)
        x = self.pool1(x)
        print(x.shape)
        
        print('\nBlock B')
        print(x.shape)
        x = self.block3(x)
        print(x.shape)
        x = self.block4(x)
        print(x.shape)
        x = self.pool2(x)
        print(x.shape)
        
        print('\nBlock C')
        print(x.shape)
        x = self.block5(x)
        print(x.shape)
        x = self.block6(x)
        print(x.shape)
        x = self.pool3(x) # pool over the spatial dimensions
        print(x.shape)
        

        print('Group Pooling')
        # pool over the group
        print(x.shape)
        x = self.gpool(x)
        print(x.shape)

        # unwrap the output GeometricTensor
        # (take the Pytorch tensor and discard the associated representation)
        x = x.tensor
        print(x.shape)
        
        print('Fully Connected')
        # classify with the final fully connected layers)
        print(x.shape)
        x = self.fully_net(x.reshape(x.shape[0], -1))
        print(x.shape)
        
        return x

model = C8SteerableCNN().to(device)
# evaluate the model
# test_model(model, x)
model.eval()
model(torch.randn(1, 1, 29, 29).to(device))

  full_mask[mask] = norms.to(torch.uint8)



Input
torch.Size([1, 1, 29, 29])
torch.Size([1, 1, 29, 29])

Block A
torch.Size([1, 1, 29, 29])
torch.Size([1, 192, 25, 25])
torch.Size([1, 384, 25, 25])
torch.Size([1, 384, 13, 13])

Block B
torch.Size([1, 384, 13, 13])
torch.Size([1, 384, 13, 13])
torch.Size([1, 768, 13, 13])
torch.Size([1, 768, 7, 7])

Block C
torch.Size([1, 768, 7, 7])
torch.Size([1, 768, 7, 7])
torch.Size([1, 512, 5, 5])
torch.Size([1, 512, 1, 1])
Group Pooling
torch.Size([1, 512, 1, 1])
torch.Size([1, 64, 1, 1])
torch.Size([1, 64, 1, 1])
Fully Connected
torch.Size([1, 64, 1, 1])
torch.Size([1, 10])


tensor([[-0.5282, -0.0270,  0.1523, -0.1359, -0.1879,  0.0852,  0.1005,  0.0568,
          0.2910,  0.3013]], grad_fn=<AddmmBackward0>)

In [7]:
24*8, 48*8, 48*8, 96*8, 96*8, 64*8

(192, 384, 384, 768, 768, 512)