# General E(2)-Equivariant Steerable CNNs  -  A concrete example


In [1]:
import os
import csv
import time
import random
from tqdm import tqdm

import numpy as np
from PIL import Image
import cv2

import torch
import torch.nn.functional as F
from torch.utils.data import Dataset
from torchvision.transforms import RandomRotation, Pad, Resize, ToTensor, Compose
from torchsummary import summary

from e2cnn import gspaces
from e2cnn import nn

Finally, we build a **Steerable CNN** and try it MNIST.

Let's also use a group a bit larger: we now build a model equivariant to $8$ rotations.
We indicate the group of $N$ discrete rotations as $C_N$, i.e. the **cyclic group** of order $N$.
In this case, we will use $C_8$.

Because the inputs are still gray-scale images, the input type of the model is again a *scalar field*.

However, internally we use *regular fields*: this is equivalent to a *group-equivariant convolutional neural network*.

Finally, we build *invariant* features for the final classification task by pooling over the group using *Group Pooling*.

The final classification is performed by a two fully connected layers.

# The model

Here is the definition of our model:

In [2]:
class C8SteerableCNN(torch.nn.Module):
    
    def __init__(self, n_classes=10, cyclic_group=8):
        self.cyclic_group=cyclic_group
        
        super(C8SteerableCNN, self).__init__()
        
        # the model is equivariant under rotations by 45 degrees, modelled by C8
        self.r2_act = gspaces.Rot2dOnR2(N=self.cyclic_group)
        #print("self.r2_act", self.r2_act)
        
        # the input image is a scalar field, corresponding to the trivial representation
        in_type = nn.FieldType(self.r2_act, 3*[self.r2_act.trivial_repr])
        #print("in_type", in_type)
        
        # we store the input type for wrapping the images into a geometric tensor during the forward pass
        self.input_type = in_type
        
        # convolution 1
        # first specify the output type of the convolutional layer
        # we choose 24 feature fields, each transforming under the regular representation of C8
        out_type = nn.FieldType(self.r2_act, 24*[self.r2_act.regular_repr])
        #print("out_type", out_type)
        self.block1 = nn.SequentialModule(
            nn.MaskModule(in_type, 80, margin=1),
            nn.R2Conv(in_type, out_type, kernel_size=7, padding=1, bias=False),
            nn.InnerBatchNorm(out_type),
            nn.ReLU(out_type, inplace=True)
        )
        print('block1', self.block1.out_type.size)
        
        # convolution 2
        # the old output type is the input type to the next layer
        in_type = self.block1.out_type
        # the output type of the second convolution layer are 48 regular feature fields of C8
        out_type = nn.FieldType(self.r2_act, 48*[self.r2_act.regular_repr])
        self.block2 = nn.SequentialModule(
            nn.R2Conv(in_type, out_type, kernel_size=5, padding=2, bias=False),
            nn.InnerBatchNorm(out_type),
            nn.ReLU(out_type, inplace=True)
        )
        print('block2', self.block2.out_type.size)
        self.pool1 = nn.SequentialModule(
            nn.PointwiseAvgPoolAntialiased(out_type, sigma=0.66, stride=2)
        )
        print('pool1', self.pool1.out_type.size)
        
        # convolution 3
        # the old output type is the input type to the next layer
        in_type = self.block2.out_type
        # the output type of the third convolution layer are 48 regular feature fields of C8
        out_type = nn.FieldType(self.r2_act, 48*[self.r2_act.regular_repr])
        self.block3 = nn.SequentialModule(
            nn.R2Conv(in_type, out_type, kernel_size=5, padding=2, bias=False),
            nn.InnerBatchNorm(out_type),
            nn.ReLU(out_type, inplace=True)
        )
        print('block3', self.block3.out_type.size)
        
        # convolution 4
        # the old output type is the input type to the next layer
        in_type = self.block3.out_type
        # the output type of the fourth convolution layer are 96 regular feature fields of C8
        out_type = nn.FieldType(self.r2_act, 96*[self.r2_act.regular_repr])
        self.block4 = nn.SequentialModule(
            nn.R2Conv(in_type, out_type, kernel_size=5, padding=2, bias=False),
            nn.InnerBatchNorm(out_type),
            nn.ReLU(out_type, inplace=True)
        )
        print('block4', self.block4.out_type.size)
        self.pool2 = nn.SequentialModule(
            nn.PointwiseAvgPoolAntialiased(out_type, sigma=0.66, stride=2)
        )
        print('pool2', self.pool2.out_type.size)
        
        # convolution 5
        # the old output type is the input type to the next layer
        in_type = self.block4.out_type
        # the output type of the fifth convolution layer are 96 regular feature fields of C8
        out_type = nn.FieldType(self.r2_act, 96*[self.r2_act.regular_repr])
        self.block5 = nn.SequentialModule(
            nn.R2Conv(in_type, out_type, kernel_size=5, padding=2, bias=False),
            nn.InnerBatchNorm(out_type),
            nn.ReLU(out_type, inplace=True)
        )
        print('block5', self.block5.out_type.size)
        
        # convolution 6
        # the old output type is the input type to the next layer
        in_type = self.block5.out_type
        # the output type of the sixth convolution layer are 64 regular feature fields of C8
        out_type = nn.FieldType(self.r2_act, 64*[self.r2_act.regular_repr])
        self.block6 = nn.SequentialModule(
            nn.R2Conv(in_type, out_type, kernel_size=5, padding=1, bias=False),
            nn.InnerBatchNorm(out_type),
            nn.ReLU(out_type, inplace=True)
        )
        print('block6', self.block6.out_type.size)
        self.pool3 = nn.PointwiseAvgPoolAntialiased(out_type, sigma=0.66, stride=1, padding=0)
        print('pool3', self.pool3.out_type.size)
        
        self.gpool = nn.GroupPooling(out_type) # pool3.out_type
        
        # number of output channels
        c = self.gpool.out_type.size
        print('gpool', c)
        
        # Fully Connected
        self.fully_net = torch.nn.Sequential(
            torch.nn.Linear(c*13*13, 64),
            torch.nn.BatchNorm1d(64),
            torch.nn.ELU(inplace=True),
            torch.nn.Linear(64, n_classes),
        )
    
    def forward(self, input: torch.Tensor):
        # wrap the input tensor in a GeometricTensor
        # (associate it with the input type)
        #print("input.shape", input.shape)
        x = nn.GeometricTensor(input, self.input_type)
        
        # apply each equivariant block
        
        # Each layer has an input and an output type
        # A layer takes a GeometricTensor in input.
        # This tensor needs to be associated with the same representation of the layer's input type
        #
        # The Layer outputs a new GeometricTensor, associated with the layer's output type.
        # As a result, consecutive layers need to have matching input/output types
        x = self.block1(x)
        x = self.block2(x)
        x = self.pool1(x)
        
        x = self.block3(x)
        x = self.block4(x)
        x = self.pool2(x)
        
        x = self.block5(x)
        x = self.block6(x)
        
        # pool over the spatial dimensions
        x = self.pool3(x)
        #print("x.shape1", x.shape)
        
        # pool over the group
        x = self.gpool(x)
        #print("x.shape2", x.shape)

        # unwrap the output GeometricTensor
        # (take the Pytorch tensor and discard the associated representation)
        x = x.tensor
        #print("x.shape3", x.shape)
        
        # classify with the final fully connected layers)
        x = self.fully_net(x.reshape(x.shape[0], -1))
        
        return x

In [6]:
# Not rotationally-equivariant architecture, mimics above as closely as possible
class NonRECNN(torch.nn.Module):
    def __init__(self, n_classes=10):
        super().__init__()
        
        # convolution 1
        self.block1 = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels=3, out_channels=24, kernel_size=7, stride=1, padding=1, bias=False),
            torch.nn.BatchNorm2d(num_features=24),
            torch.nn.ReLU(inplace=True)
        )
        
        # convolution 2
        self.block2 = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels=24, out_channels=48, kernel_size=5, stride=1, padding=2, bias=False),
            torch.nn.BatchNorm2d(num_features=48),
            torch.nn.ReLU(inplace=True)
        )
        self.pool1 = torch.nn.AvgPool2d(kernel_size=2, stride=2)
        
        # convolution 3
        self.block3 = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels=48, out_channels=48, kernel_size=5, stride=1, padding=2, bias=False),
            torch.nn.BatchNorm2d(num_features=48),
            torch.nn.ReLU(inplace=True)
        )
        
        # convolution 4
        self.block4 = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels=48, out_channels=96, kernel_size=5, stride=1, padding=2, bias=False),
            torch.nn.BatchNorm2d(num_features=96),
            torch.nn.ReLU(inplace=True)
        )
        self.pool2 = torch.nn.AvgPool2d(kernel_size=2, stride=2)
        
        # convolution 5
        self.block5 = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels=96, out_channels=96, kernel_size=5, stride=1, padding=2, bias=False),
            torch.nn.BatchNorm2d(num_features=96),
            torch.nn.ReLU(inplace=True)
        )
        
        # convolution 6
        self.block6 = torch.nn.Sequential(
            # NOTE 1: changed padding=1 to padding=0 in this layer to help match input size of fc layer
            torch.nn.Conv2d(in_channels=96, out_channels=64, kernel_size=5, stride=1, padding=0, bias=False),
            torch.nn.BatchNorm2d(num_features=64),
            torch.nn.ReLU(inplace=True)
        )
        self.pool3 = torch.nn.AvgPool2d(kernel_size=2, stride=1, padding=0)
        # NOTE 2: added another avgpool2d to match input size of fc layer
        self.pool4 = torch.nn.AvgPool2d(kernel_size=2, stride=1, padding=0)
#        self.pool5 = torch.nn.AvgPool2d(kernel_size=2, stride=1, padding=0)
#        self.pool6 = torch.nn.AvgPool2d(kernel_size=2, stride=1, padding=0)

        # Fully Connected
        self.fully_net = torch.nn.Sequential(
            torch.nn.Linear(64*13*13, 64),
            torch.nn.BatchNorm1d(num_features=64),
            torch.nn.ELU(inplace=True),
            torch.nn.Linear(64, n_classes),
        )

    def forward(self, x):
        x = self.block1(x)
        x = self.block2(x)
        x = self.pool1(x)
        
        x = self.block3(x)
        x = self.block4(x)
        x = self.pool2(x)
        
        x = self.block5(x)
        x = self.block6(x)
        x = self.pool3(x)
        x = self.pool4(x)
#        x = self.pool5(x)
#        x = self.pool6(x)
        
        # flatten all dimensions except batch
        x = torch.flatten(x, 1)
        
        # classify with the final fully connected layers)
        x = self.fully_net(x.reshape(x.shape[0], -1))
        
        return x

Let's build the model

In [7]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
#device = 'cpu'
print(device)

cuda


In [236]:
cyclic_group = 8
n_classes = 15

model = C8SteerableCNN(n_classes=n_classes, cyclic_group=cyclic_group).to(device)

summary(model, input_size=(3, 80, 80))

# cyclic_group = 1
# n_classes = 15

# model = NonRECNN(n_classes=n_classes).to(device)

# summary(model, input_size=(3, 80, 80))

block1 192
block2 384
pool1 384
block3 384
block4 768
pool2 768
block5 768
block6 512
pool3 512
gpool 64
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
        MaskModule-1            [-1, 3, 80, 80]               0
SingleBlockBasisExpansion-2             [-1, 8, 1, 49]               0
BlocksBasisExpansion-3                [-1, 3, 49]               0
            R2Conv-4          [-1, 192, 76, 76]               0
       BatchNorm3d-5        [-1, 24, 8, 76, 76]              48
    InnerBatchNorm-6          [-1, 192, 76, 76]               0
              ReLU-7          [-1, 192, 76, 76]               0
  SequentialModule-8          [-1, 192, 76, 76]               0
SingleBlockBasisExpansion-9             [-1, 8, 8, 25]               0
SingleBlockBasisExpansion-10             [-1, 8, 8, 25]               0
SingleBlockBasisExpansion-11             [-1, 8, 8, 25]               0
SingleBlockBasisExpansion-12  

In [8]:
cyclic_group = 8
n_classes = 15

model1 = C8SteerableCNN(n_classes=n_classes, cyclic_group=cyclic_group).to(device)
summary(model1, input_size=(3, 80, 80))

block1 192
block2 384
pool1 384
block3 384
block4 768
pool2 768
block5 768
block6 512
pool3 512
gpool 64
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
        MaskModule-1            [-1, 3, 80, 80]               0
SingleBlockBasisExpansion-2             [-1, 8, 1, 49]               0
BlocksBasisExpansion-3                [-1, 3, 49]               0
            R2Conv-4          [-1, 192, 76, 76]               0
       BatchNorm3d-5        [-1, 24, 8, 76, 76]              48
    InnerBatchNorm-6          [-1, 192, 76, 76]               0
              ReLU-7          [-1, 192, 76, 76]               0
  SequentialModule-8          [-1, 192, 76, 76]               0
SingleBlockBasisExpansion-9             [-1, 8, 8, 25]               0
SingleBlockBasisExpansion-10             [-1, 8, 8, 25]               0
SingleBlockBasisExpansion-11             [-1, 8, 8, 25]               0
SingleBlockBasisExpansion-12  

In [9]:
model2 = NonRECNN(n_classes=n_classes).to(device)
summary(model2, input_size=(3, 80, 80))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 24, 76, 76]           3,528
       BatchNorm2d-2           [-1, 24, 76, 76]              48
              ReLU-3           [-1, 24, 76, 76]               0
            Conv2d-4           [-1, 48, 76, 76]          28,800
       BatchNorm2d-5           [-1, 48, 76, 76]              96
              ReLU-6           [-1, 48, 76, 76]               0
         AvgPool2d-7           [-1, 48, 38, 38]               0
            Conv2d-8           [-1, 48, 38, 38]          57,600
       BatchNorm2d-9           [-1, 48, 38, 38]              96
             ReLU-10           [-1, 48, 38, 38]               0
           Conv2d-11           [-1, 96, 38, 38]         115,200
      BatchNorm2d-12           [-1, 96, 38, 38]             192
             ReLU-13           [-1, 96, 38, 38]               0
        AvgPool2d-14           [-1, 96,

The model is now randomly initialized. 
Therefore, we do not expect it to produce the right class probabilities.

However, the model should still produce the same output for rotated versions of the same image.
This is true for rotations by multiples of $\frac{\pi}{2}$, but is only approximate for rotations by $\frac{\pi}{4}$.

Let's try the model on *rotated* MNIST

In [238]:
# # download the dataset
# !wget -nc http://www.iro.umontreal.ca/~lisa/icml2007data/mnist_rotation_new.zip
# # uncompress the zip file
# !unzip -n mnist_rotation_new.zip -d mnist_rotation_new

In [239]:
# import requests
# url = 'http://www.iro.umontreal.ca/~lisa/icml2007data/mnist_rotation_new.zip'
# doc = requests.get(url)
# with open('mnistrot.zip', 'wb') as f:
#     f.write(doc.content)

Build the dataset

In [240]:
class DOTARotDataset(Dataset):
    
    def __init__(self, mode, transform=None, max_num_examples=0):
        assert mode in ['train', 'test']
        
        basedir = 'C:/Users/Admin/Desktop/data/DOTAv1.0/'
        
        if mode == "train":
            file = basedir+"chips_train/"
        else:
            file = basedir+"chips_val/"
        
        self.transform = transform
        self.max_num_examples = max_num_examples
        self.images = []
        self.labels = []
        new_image_width = 80
        new_image_height = 80
        color = (0,0,0)
        self.classdict = {}
        self.classval = 0

        #data = np.loadtxt(file, delimiter=' ')
        data = []
        
        for root, dirs, filenames in os.walk(file, topdown=False):
            pass
        for d in dirs:
            for root, dirs, filenames in os.walk(file+d, topdown=False):
                pass
            #print(d, len(filenames))
            for f in filenames:
                imgstr = file+d+"/"+f
                #img = Image.open(imgstr)
                #arr = np.asarray(img)
                img = cv2.imread(imgstr)
                old_image_height, old_image_width, channels = img.shape
                #print('old', old_image_height, old_image_width)
                if img.shape[0] < img.shape[1]:
                    result1 = np.full((img.shape[1],img.shape[1], channels), color, dtype=np.uint8)
                    #print('shape', result1.shape)
                    # compute center offset
                    #x_center = np.abs(new_image_width - old_image_width) // 2
                    y_center = np.abs(img.shape[1] - old_image_height) // 2
                    #print('y_center', y_center)
                    # copy img image into center of result image
                    result1[y_center:y_center+old_image_height, :] = img
                elif img.shape[0] > img.shape[1]:
                    result1 = np.full((img.shape[0],img.shape[0], channels), color, dtype=np.uint8)
                    #print('shape', result1.shape)
                    # compute center offset
                    x_center = np.abs(img.shape[0] - old_image_width) // 2
                    #y_center = np.abs(new_image_height - old_image_height) // 2
                    #print('x_center', x_center)
                    # copy img image into center of result image
                    result1[:, x_center:x_center+old_image_width] = img
                else:
                    result1 = img[:,:,:]
                # check if we need to resize
                if not (result1.shape[0] == 80 and result1.shape[1] == 80):
                    #print('result1.shape1', result1.shape)
                    result1 = cv2.resize(result1, dsize=(80, 80), interpolation=cv2.INTER_CUBIC)
                    #print('result1.shape2', result1.shape)
                res2 = result1.reshape(-1, 80, 80).astype(np.float32) # = arr[:, :-1].reshape(-1, 80, 80)
                #print('res2.shape', res2.shape)
                
                # ToTensor screws up the order, so we have to undo it:
                # https://discuss.pytorch.org/t/torchvision-totensor-dont-change-channel-order/82038/2
                #res2 = res2.permute((1, 2, 0)).contiguous()
                
                # convert back to PIL Image object for pytorch transforms (e.g. RandomRotation) to work
                #image = Image.fromarray(image)
                
                self.images.append(res2)
        
                # labels
                if self.classdict.get(d, -1) >= 0:
                    self.labels.append(self.classdict[d])
                else: # it's not in self.classdict yet
                    self.classdict[d] = self.classval
                    self.classval += 1
        
        if self.max_num_examples > 0:
            #z1 = list(zip(self.images, self.labels))
            classdict = {}
            for i in range(len(self.labels)):
                if classdict.get(self.labels[i], False):
                    classdict[self.labels[i]].append(self.images[i])
                else:
                    classdict[self.labels[i]] = [self.images[i]]
            self.images = []
            self.labels = []
            for i in range(self.classval):
                if len(classdict[i]) > self.max_num_examples:
                    temp = random.sample(classdict[i], self.max_num_examples)
                    for t in temp:
                        self.images.append(t)
                        self.labels.append(i)
                else:
                    for t in classdict[i]:
                        self.images.append(t)
                        self.labels.append(i)

        self.num_samples = len(self.labels)
        print("self.num_samples", self.num_samples)
    
    def __getitem__(self, index):
        image, label = self.images[index], self.labels[index]
        # image is a numpy ndarray instead of PIL Image object
        # NOTE: certain pytorch functions (aka RandomRotate) require PIL Image objects
        # ToTensor screws up the shape/channel order, so we have to undo it:
        # https://discuss.pytorch.org/t/torchvision-totensor-dont-change-channel-order/82038/2
        #image = image.permute((1, 2, 0)).contiguous()
        #image = Image.fromarray(image)
        if self.transform is not None:
            image = self.transform(image)
        return image, label
    
    def __len__(self):
        return len(self.labels)

# images are padded to have shape 29x29.
# this allows to use odd-size filters with stride 2 when downsampling a feature map in the model
#pad = Pad((0, 0, 1, 1), fill=0)

# to reduce interpolation artifacts (e.g. when testing the model on rotated images),
# we upsample an image by a factor of 3, rotate it and finally downsample it again
#resize1 = Resize(80*3)
#resize2 = Resize(80)

#totensor = ToTensor()

In [241]:
# build the test set
#raw_mnist_test = DOTARotDataset(mode='test')

In [242]:
# retrieve the first image from the test set
#x, y = next(iter(raw_mnist_test))
#
#print(x.shape)
#print(raw_mnist_test[3530])

In [243]:
def test_model(model: torch.nn.Module, x: Image):
    # evaluate the `model` on 8 rotated versions of the input image `x`
    model.eval()
    totensor = ToTensor()
    
    wrmup = model(torch.randn(1, 3, 80, 80).to(device))
    del wrmup
    
    #x = resize1(pad(x))
    
    print()
    print('##########################################################################################')
    header = 'angle |  ' + '  '.join(["{:6d}".format(d) for d in range(15)])
    print(header)
    with torch.no_grad():
        for r in range(8):
            #print(np.min(x), np.max(x), x.shape)
            intimg = x.astype(np.uint8)
            #print(type(intimg))
            #print(np.min(intimg), np.max(intimg), intimg.shape)
            # go from 3,80,80 -> 80,80,3
            #intimg2 = intimg.permute((1, 2, 0)).contiguous() # can't permute numpy arrays
            intimg2 = intimg.transpose(1, 2, 0)
            #print(np.min(intimg2), np.max(intimg2), intimg2.shape)
            img = Image.fromarray(intimg2)
            #print(np.min(img), np.max(img))
            rotimg = img.rotate(r*45., Image.BILINEAR)
            #print(rotimg.shape)
            x_transformed = totensor(rotimg).reshape(-1, 3, 80, 80)
            x_transformed = x_transformed.to(device)

            y = model(x_transformed)
            y = y.to('cpu').numpy().squeeze()
            
            angle = r * 45
            print("{:3d}:{}".format(angle, y))
    print('##########################################################################################')
    print()

In [244]:
# evaluate the model
#test_model(model, x)

The output of the model is already almost invariant.
However, we still observe small fluctuations in the outputs.

This is because the model contains some operations which might break equivariance.
For instance, every convolution includes a padding of $2$ pixels per side. This is adds information about the actual orientation of the grid where the image/feature map is sampled because the padding is not rotated with the image. 

During training, the model will observe rotated patterns and will learn to ignore the noise coming from the padding.

So, let's train the model now.
The model is exactly the same used to train a normal *PyTorch* architecture:

In [245]:
# Prep the training dataset
train_transform = Compose([
    #pad,
    #resize1,
    #RandomRotation(180, resample=Image.BILINEAR, expand=False),
    #resize2,
    #ToTensor()
])

#data_train = DOTARotDataset(mode='train', transform=train_transform)
#train_loader = torch.utils.data.DataLoader(data_train, batch_size=8, shuffle=True)
#print(len(data_train), len(train_loader))

In [246]:
def get_fresh_training_data(n):
    data_train = DOTARotDataset(mode='train', transform=train_transform, max_num_examples=n)
    train_loader = torch.utils.data.DataLoader(data_train, batch_size=8, shuffle=True, drop_last=True)
    #print(len(data_train), len(train_loader))
    return train_loader

In [247]:
# Prep the testing dataset
test_transform = Compose([
    #pad,
    #ToTensor()
])

data_test = DOTARotDataset(mode='test', transform=test_transform)
test_loader = torch.utils.data.DataLoader(data_test, batch_size=8, shuffle=False, drop_last=True)
print(len(data_test), len(test_loader))

self.num_samples 28838
28838 3604


In [248]:
# mini = 99999
# maxi = 0
# for m in data_train: # data_test
#     if m[1] < mini:
#         mini = m[1]
#     if m[1] > maxi:
#         maxi = m[1]
# print(mini, maxi, "<--- should be 0 14 instead of 1 15")

In [249]:
# Initialize the model
loss_function = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=5e-5, weight_decay=1e-5)

In [250]:
def save_the_model(model, cyclic_group, epoch, train_acc, train_loss, test_acc, test_loss):
    if not os.path.exists("models"):
        os.mkdir("models")
    model_name = "models/model_dota_C"+str(cyclic_group)+"_"+str(epoch)+"_"+str(round(train_acc,4))+"_"+str(round(train_loss,4))+"_"+str(round(test_acc,4))+"_"+str(round(test_loss,4))+".pth"
#    if not os.path.exists("models/model_dota.pth"):
    torch.save(model.state_dict(), model_name)
#    print('saved')
#    else:
#         model.load_state_dict(torch.load("models/model_dota.pth"))
#         print('loaded')

In [251]:
def save_the_results(cyclic_group, epoch, train_accuracy, train_loss, test_accuracy, test_loss):
    add_headers = False
    if not os.path.exists("results"):
        os.mkdir("results")
    if not os.path.exists("results/model_dota.csv"):
        add_headers = True
    # writing to csv file
    with open("results/model_dota.csv", 'a', newline='') as csvfile: 
        # creating a csv writer object 
        csvwriter = csv.writer(csvfile)

        if add_headers:
            csvwriter.writerow(['Cyclic Group', 'Epoch', 'Train Accuracy', 'Train Loss', 'Test Accuracy', 'Test Loss'])

        # writing the data rows
        csvwriter.writerow([cyclic_group, epoch, train_accuracy, train_loss, test_accuracy, test_loss])

In [252]:
# # get initial performance with random weights
# test_total = 0
# test_correct = 0
# test_loss = 0
# with torch.no_grad():
#     model.eval()
    
#     #for i, (x, t) in enumerate(test_loader):
#     with tqdm(test_loader, unit="batch") as tepoch:
#         for x, t in tepoch:
#             tepoch.set_description(f"Epoch {-1}")
            
#             #if i%1000==0:
#             #    print(i, "/", len(test_loader))

#             x = x.to(device)
#             t = t.to(device)

#             y = model(x)

#             _, prediction = torch.max(y.data, 1)
#             if prediction.shape[0] != t.shape[0]:
#                 print(t)
#                 t = t[:-(t.shape[0]-prediction.shape[0])]
#                 print(t)
#             test_total += t.shape[0]
#             test_correct += (prediction == t).sum().item()

#             loss = loss_function(y, t)
#             test_loss += loss

# test_accuracy = test_correct/test_total*100.
    
# print(f"test accuracy: {test_accuracy}")
# print(f"test loss: {test_loss.item()}")

In [253]:
# save_the_results(cyclic_group, 0, 0, test_accuracy, test_loss.item())
# save_the_model(model, cyclic_group, 0, 0, 0, test_accuracy, test_loss.item())

In [254]:
######################################################
# IF CONTINUING TRAINING, LOAD MODEL FROM LAST EPOCH
######################################################

# train_loader = get_fresh_training_data(30000)
# model.load_state_dict(torch.load("models/model_dota_59_97.0042_1035.413_83.879_2458.3879.pth"))

In [255]:
samples_per_class = 500
start_epoch = 0
max_epochs = 100

for epoch in range(start_epoch, max_epochs):
    
    print('starting epoch', epoch)
    
    ########################################
    # TRAIN
    ########################################
    train_total = 0
    train_correct = 0
    train_loss = 0
    model.train()
    
    train_loader = get_fresh_training_data(samples_per_class)
    print("samples_per_class", samples_per_class)
    samples_per_class += 100
    
    #for i, (x, t) in enumerate(train_loader):
    with tqdm(train_loader, unit="batch") as tepoch:
        for x, t in tepoch:
            tepoch.set_description(f"Epoch {epoch}")

            #if i%5000==0:
            #    print(i, "/", len(train_loader))

            optimizer.zero_grad()

            x = x.to(device)
            t = t.to(device)

            y = model(x)

            _, prediction = torch.max(y.data, 1)
            # sometimes at the end of an epoch, prediction.shape can be < t.shape
            if prediction.shape[0] != t.shape[0]:
                #print(t)
                t = t[:-(t.shape[0]-prediction.shape[0])]
                #print(t)
            train_total += t.shape[0]
            train_correct += (prediction == t).sum().item()

            loss = loss_function(y, t)
            train_loss += loss

            loss.backward()

            optimizer.step()
    tepoch.close()

    ########################################
    # TEST
    ########################################
    test_total = 0
    test_correct = 0
    test_loss = 0
    
    with torch.no_grad():
        model.eval()
        
        #for i, (x, t) in enumerate(test_loader):
        with tqdm(test_loader, unit="batch") as tepoch:
            for x, t in tepoch:
                tepoch.set_description(f"Epoch {epoch}")
                
                #if i%5000==0:
                #    print(i, "/", len(test_loader))

                x = x.to(device)
                t = t.to(device)

                y = model(x)

                _, prediction = torch.max(y.data, 1)
                if prediction.shape[0] != t.shape[0]:
                    #print(t)
                    t = t[:-(t.shape[0]-prediction.shape[0])]
                    #print(t)
                test_total += t.shape[0]
                test_correct += (prediction == t).sum().item()

                loss = loss_function(y, t)
                test_loss += loss

    train_accuracy = train_correct/train_total*100.
    test_accuracy = test_correct/test_total*100.

    #print(f"epoch {epoch} | train accuracy: {train_accuracy}")
    #print(f"epoch {epoch} | train loss: {train_loss.item()}")
    #print(f"epoch {epoch} | test accuracy: {test_accuracy}")
    #print(f"epoch {epoch} | test loss: {test_loss.item()}")

    tepoch.set_postfix({"train_accuracy":train_accuracy, 
                        "train_loss":train_loss.item(),
                        "test_accuracy":test_accuracy,
                        "test_loss":test_loss.item()})
    tepoch.close()
    time.sleep(0.5)
    
    save_the_results(cyclic_group, epoch, train_accuracy, train_loss.item(), test_accuracy, test_loss.item())
    save_the_model(model, cyclic_group, epoch, train_accuracy, train_loss.item(), test_accuracy, test_loss.item())

starting epoch 0
self.num_samples 6961


Epoch 0: 100%|██████████| 870/870 [05:05<00:00,  2.84batch/s] 
Epoch 0: 100%|██████████| 3604/3604 [05:11<00:00, 11.58batch/s]


starting epoch 1
self.num_samples 7975


Epoch 1: 100%|██████████| 996/996 [04:32<00:00,  3.66batch/s]
Epoch 1: 100%|██████████| 3604/3604 [05:06<00:00, 11.77batch/s]


starting epoch 2
self.num_samples 8904


Epoch 2: 100%|██████████| 1113/1113 [05:04<00:00,  3.65batch/s]
Epoch 2: 100%|██████████| 3604/3604 [05:07<00:00, 11.73batch/s]


starting epoch 3
self.num_samples 9804


Epoch 3: 100%|██████████| 1225/1225 [05:43<00:00,  3.56batch/s]
Epoch 3: 100%|██████████| 3604/3604 [05:07<00:00, 11.71batch/s]


starting epoch 4
self.num_samples 10704


Epoch 4: 100%|██████████| 1338/1338 [06:05<00:00,  3.66batch/s]
Epoch 4: 100%|██████████| 3604/3604 [05:07<00:00, 11.72batch/s]


starting epoch 5
self.num_samples 11604


Epoch 5: 100%|██████████| 1450/1450 [06:36<00:00,  3.66batch/s]
Epoch 5: 100%|██████████| 3604/3604 [05:06<00:00, 11.77batch/s]


starting epoch 6
self.num_samples 12504


Epoch 6: 100%|██████████| 1563/1563 [07:07<00:00,  3.66batch/s]
Epoch 6: 100%|██████████| 3604/3604 [05:06<00:00, 11.75batch/s]


starting epoch 7
self.num_samples 13404


Epoch 7: 100%|██████████| 1675/1675 [07:37<00:00,  3.66batch/s]
Epoch 7: 100%|██████████| 3604/3604 [05:07<00:00, 11.74batch/s]


starting epoch 8
self.num_samples 14304


Epoch 8: 100%|██████████| 1788/1788 [08:09<00:00,  3.65batch/s]
Epoch 8: 100%|██████████| 3604/3604 [05:09<00:00, 11.66batch/s]


starting epoch 9
self.num_samples 15204


Epoch 9: 100%|██████████| 1900/1900 [08:55<00:00,  3.55batch/s]
Epoch 9: 100%|██████████| 3604/3604 [05:12<00:00, 11.54batch/s]


starting epoch 10
self.num_samples 16104


Epoch 10: 100%|██████████| 2013/2013 [09:28<00:00,  3.54batch/s]
Epoch 10: 100%|██████████| 3604/3604 [05:14<00:00, 11.46batch/s]


starting epoch 11
self.num_samples 17004


Epoch 11: 100%|██████████| 2125/2125 [09:59<00:00,  3.54batch/s]
Epoch 11: 100%|██████████| 3604/3604 [05:14<00:00, 11.46batch/s]


starting epoch 12
self.num_samples 17904


Epoch 12: 100%|██████████| 2238/2238 [10:31<00:00,  3.54batch/s]
Epoch 12: 100%|██████████| 3604/3604 [05:15<00:00, 11.41batch/s]


starting epoch 13
self.num_samples 18739


Epoch 13: 100%|██████████| 2342/2342 [11:01<00:00,  3.54batch/s]
Epoch 13: 100%|██████████| 3604/3604 [05:18<00:00, 11.33batch/s]


starting epoch 14
self.num_samples 19539


Epoch 14: 100%|██████████| 2442/2442 [11:29<00:00,  3.54batch/s]
Epoch 14: 100%|██████████| 3604/3604 [05:13<00:00, 11.49batch/s]


starting epoch 15
self.num_samples 20339


Epoch 15: 100%|██████████| 2542/2542 [11:42<00:00,  3.62batch/s]
Epoch 15: 100%|██████████| 3604/3604 [05:16<00:00, 11.40batch/s]


starting epoch 16
self.num_samples 21085


Epoch 16: 100%|██████████| 2635/2635 [12:08<00:00,  3.62batch/s]
Epoch 16: 100%|██████████| 3604/3604 [05:15<00:00, 11.42batch/s]


starting epoch 17
self.num_samples 21785


Epoch 17: 100%|██████████| 2723/2723 [12:32<00:00,  3.62batch/s]
Epoch 17: 100%|██████████| 3604/3604 [05:16<00:00, 11.40batch/s]


starting epoch 18
self.num_samples 22485


Epoch 18: 100%|██████████| 2810/2810 [12:57<00:00,  3.62batch/s]
Epoch 18: 100%|██████████| 3604/3604 [05:17<00:00, 11.36batch/s]


starting epoch 19
self.num_samples 23151


Epoch 19: 100%|██████████| 2893/2893 [13:27<00:00,  3.58batch/s]
Epoch 19: 100%|██████████| 3604/3604 [05:21<00:00, 11.21batch/s]


starting epoch 20
self.num_samples 23751


Epoch 20: 100%|██████████| 2968/2968 [13:58<00:00,  3.54batch/s]
Epoch 20: 100%|██████████| 3604/3604 [05:21<00:00, 11.20batch/s]


starting epoch 21
self.num_samples 24351


Epoch 21: 100%|██████████| 3043/3043 [14:20<00:00,  3.54batch/s]
Epoch 21: 100%|██████████| 3604/3604 [05:22<00:00, 11.17batch/s]


starting epoch 22
self.num_samples 24951


Epoch 22: 100%|██████████| 3118/3118 [14:40<00:00,  3.54batch/s]
Epoch 22: 100%|██████████| 3604/3604 [05:23<00:00, 11.15batch/s]


starting epoch 23
self.num_samples 25551


Epoch 23: 100%|██████████| 3193/3193 [15:04<00:00,  3.53batch/s]
Epoch 23: 100%|██████████| 3604/3604 [05:24<00:00, 11.11batch/s]


starting epoch 24
self.num_samples 26151


Epoch 24: 100%|██████████| 3268/3268 [15:26<00:00,  3.53batch/s]
Epoch 24: 100%|██████████| 3604/3604 [05:27<00:00, 11.02batch/s]


starting epoch 25
self.num_samples 26751


Epoch 25: 100%|██████████| 3343/3343 [15:48<00:00,  3.52batch/s]
Epoch 25: 100%|██████████| 3604/3604 [05:26<00:00, 11.03batch/s]


starting epoch 26
self.num_samples 27351


Epoch 26: 100%|██████████| 3418/3418 [16:10<00:00,  3.52batch/s]
Epoch 26: 100%|██████████| 3604/3604 [05:28<00:00, 10.97batch/s]


starting epoch 27
self.num_samples 27951


Epoch 27: 100%|██████████| 3493/3493 [16:30<00:00,  3.53batch/s]
Epoch 27: 100%|██████████| 3604/3604 [05:27<00:00, 10.99batch/s]


starting epoch 28
self.num_samples 28551


Epoch 28: 100%|██████████| 3568/3568 [16:50<00:00,  3.53batch/s]
Epoch 28: 100%|██████████| 3604/3604 [05:28<00:00, 10.96batch/s]


starting epoch 29
self.num_samples 29151


Epoch 29: 100%|██████████| 3643/3643 [17:11<00:00,  3.53batch/s]
Epoch 29: 100%|██████████| 3604/3604 [05:29<00:00, 10.95batch/s]


starting epoch 30
self.num_samples 29751


Epoch 30: 100%|██████████| 3718/3718 [17:31<00:00,  3.54batch/s]
Epoch 30: 100%|██████████| 3604/3604 [05:28<00:00, 10.96batch/s]


starting epoch 31
self.num_samples 30351


Epoch 31: 100%|██████████| 3793/3793 [17:51<00:00,  3.54batch/s]
Epoch 31: 100%|██████████| 3604/3604 [05:29<00:00, 10.95batch/s]


starting epoch 32
self.num_samples 30951


Epoch 32: 100%|██████████| 3868/3868 [18:13<00:00,  3.54batch/s]
Epoch 32: 100%|██████████| 3604/3604 [05:29<00:00, 10.95batch/s]


starting epoch 33
self.num_samples 31551


Epoch 33: 100%|██████████| 3943/3943 [18:34<00:00,  3.54batch/s]
Epoch 33: 100%|██████████| 3604/3604 [05:29<00:00, 10.94batch/s]


starting epoch 34
self.num_samples 32151


Epoch 34: 100%|██████████| 4018/4018 [18:54<00:00,  3.54batch/s]
Epoch 34: 100%|██████████| 3604/3604 [05:29<00:00, 10.94batch/s]


starting epoch 35
self.num_samples 32751


Epoch 35: 100%|██████████| 4093/4093 [19:14<00:00,  3.54batch/s]
Epoch 35: 100%|██████████| 3604/3604 [05:22<00:00, 11.19batch/s]


starting epoch 36
self.num_samples 33351


Epoch 36: 100%|██████████| 4168/4168 [19:36<00:00,  3.54batch/s]
Epoch 36: 100%|██████████| 3604/3604 [05:30<00:00, 10.89batch/s]


starting epoch 37
self.num_samples 33951


Epoch 37: 100%|██████████| 4243/4243 [19:56<00:00,  3.55batch/s]
Epoch 37: 100%|██████████| 3604/3604 [05:30<00:00, 10.90batch/s]


starting epoch 38
self.num_samples 34551


Epoch 38: 100%|██████████| 4318/4318 [20:18<00:00,  3.55batch/s]
Epoch 38: 100%|██████████| 3604/3604 [05:30<00:00, 10.89batch/s]


starting epoch 39
self.num_samples 35151


Epoch 39: 100%|██████████| 4393/4393 [20:38<00:00,  3.55batch/s]
Epoch 39: 100%|██████████| 3604/3604 [05:31<00:00, 10.88batch/s]


starting epoch 40
self.num_samples 35751


Epoch 40: 100%|██████████| 4468/4468 [21:00<00:00,  3.54batch/s]
Epoch 40: 100%|██████████| 3604/3604 [05:33<00:00, 10.81batch/s]


starting epoch 41
self.num_samples 36351


Epoch 41: 100%|██████████| 4543/4543 [21:21<00:00,  3.54batch/s]
Epoch 41: 100%|██████████| 3604/3604 [05:32<00:00, 10.83batch/s]


starting epoch 42
self.num_samples 36951


Epoch 42: 100%|██████████| 4618/4618 [21:43<00:00,  3.54batch/s]
Epoch 42: 100%|██████████| 3604/3604 [05:33<00:00, 10.80batch/s]


starting epoch 43
self.num_samples 37551


Epoch 43: 100%|██████████| 4693/4693 [22:04<00:00,  3.54batch/s]
Epoch 43: 100%|██████████| 3604/3604 [05:34<00:00, 10.78batch/s]


starting epoch 44
self.num_samples 38151


Epoch 44: 100%|██████████| 4768/4768 [22:25<00:00,  3.54batch/s]
Epoch 44: 100%|██████████| 3604/3604 [05:34<00:00, 10.79batch/s]


starting epoch 45
self.num_samples 38751


Epoch 45: 100%|██████████| 4843/4843 [22:18<00:00,  3.62batch/s] 
Epoch 45: 100%|██████████| 3604/3604 [05:33<00:00, 10.80batch/s]


starting epoch 46
self.num_samples 39279


Epoch 46: 100%|██████████| 4909/4909 [22:42<00:00,  3.60batch/s]
Epoch 46: 100%|██████████| 3604/3604 [05:24<00:00, 11.11batch/s]


starting epoch 47
self.num_samples 39779


Epoch 47: 100%|██████████| 4972/4972 [22:59<00:00,  3.60batch/s]
Epoch 47: 100%|██████████| 3604/3604 [07:18<00:00,  8.23batch/s]


starting epoch 48
self.num_samples 40279


Epoch 48: 100%|██████████| 5034/5034 [42:38<00:00,  1.97batch/s]  
Epoch 48: 100%|██████████| 3604/3604 [05:37<00:00, 10.69batch/s]


starting epoch 49
self.num_samples 40779


Epoch 49: 100%|██████████| 5097/5097 [23:54<00:00,  3.55batch/s]
Epoch 49: 100%|██████████| 3604/3604 [05:30<00:00, 10.89batch/s]


starting epoch 50
self.num_samples 41279


Epoch 50: 100%|██████████| 5159/5159 [23:48<00:00,  3.61batch/s]
Epoch 50: 100%|██████████| 3604/3604 [05:31<00:00, 10.89batch/s]


starting epoch 51
self.num_samples 41779


Epoch 51: 100%|██████████| 5222/5222 [24:34<00:00,  3.54batch/s]
Epoch 51: 100%|██████████| 3604/3604 [05:40<00:00, 10.57batch/s]


starting epoch 52
self.num_samples 42279


Epoch 52: 100%|██████████| 5284/5284 [24:46<00:00,  3.55batch/s]
Epoch 52: 100%|██████████| 3604/3604 [05:37<00:00, 10.67batch/s]


starting epoch 53
self.num_samples 42779


Epoch 53: 100%|██████████| 5347/5347 [28:40<00:00,  3.11batch/s]
Epoch 53: 100%|██████████| 3604/3604 [05:41<00:00, 10.56batch/s]


starting epoch 54
self.num_samples 43279


Epoch 54: 100%|██████████| 5409/5409 [25:20<00:00,  3.56batch/s]
Epoch 54: 100%|██████████| 3604/3604 [05:38<00:00, 10.66batch/s]


starting epoch 55
self.num_samples 43761


Epoch 55: 100%|██████████| 5470/5470 [25:37<00:00,  3.56batch/s]
Epoch 55: 100%|██████████| 3604/3604 [05:40<00:00, 10.59batch/s]


starting epoch 56
self.num_samples 44161


Epoch 56: 100%|██████████| 5520/5520 [25:59<00:00,  3.54batch/s]
Epoch 56: 100%|██████████| 3604/3604 [05:42<00:00, 10.52batch/s]


starting epoch 57
self.num_samples 44561


Epoch 57: 100%|██████████| 5570/5570 [26:06<00:00,  3.56batch/s]
Epoch 57: 100%|██████████| 3604/3604 [05:40<00:00, 10.58batch/s]


starting epoch 58
self.num_samples 44961


Epoch 58: 100%|██████████| 5620/5620 [26:21<00:00,  3.55batch/s]
Epoch 58: 100%|██████████| 3604/3604 [05:42<00:00, 10.53batch/s]


starting epoch 59
self.num_samples 45361


Epoch 59: 100%|██████████| 5670/5670 [26:35<00:00,  3.55batch/s]
Epoch 59: 100%|██████████| 3604/3604 [05:43<00:00, 10.50batch/s]


starting epoch 60
self.num_samples 45761


Epoch 60: 100%|██████████| 5720/5720 [26:48<00:00,  3.56batch/s]
Epoch 60: 100%|██████████| 3604/3604 [05:44<00:00, 10.48batch/s]


starting epoch 61
self.num_samples 46161


Epoch 61: 100%|██████████| 5770/5770 [27:02<00:00,  3.56batch/s]
Epoch 61: 100%|██████████| 3604/3604 [05:43<00:00, 10.48batch/s]


starting epoch 62
self.num_samples 46561


Epoch 62: 100%|██████████| 5820/5820 [27:15<00:00,  3.56batch/s]
Epoch 62: 100%|██████████| 3604/3604 [05:43<00:00, 10.49batch/s]


starting epoch 63
self.num_samples 46961


Epoch 63: 100%|██████████| 5870/5870 [27:28<00:00,  3.56batch/s]
Epoch 63: 100%|██████████| 3604/3604 [05:43<00:00, 10.50batch/s]


starting epoch 64
self.num_samples 47361


Epoch 64: 100%|██████████| 5920/5920 [27:42<00:00,  3.56batch/s]
Epoch 64: 100%|██████████| 3604/3604 [05:46<00:00, 10.39batch/s]


starting epoch 65
self.num_samples 47761


Epoch 65:   9%|▉         | 543/5970 [02:32<25:20,  3.57batch/s]


KeyboardInterrupt: 

In [None]:
# train_accuracy = train_correct/train_total*100.
# print(train_correct, train_total, train_accuracy)
#save_the_results(cyclic_group, epoch, train_accuracy, train_loss.item(), test_accuracy, test_loss.item())
#save_the_model(model, cyclic_group, epoch, train_accuracy, train_loss.item(), test_accuracy, test_loss.item())

In [None]:
# print(prediction.shape, t.shape)
# print(prediction)
# print(t)
# print(t[:-1])
# print(y.shape)
# print(torch.max(y.data, 1)) # 1 is the dimension

In [None]:
# epoch -1 | test accuracy: 10.011096469935502
# epoch -1 | test loss: 2109988.5

# epoch 0 | train accuracy: 52.91684784257414
# epoch 0 | train loss: 42925.43359375
# epoch 0 | test accuracy: 18.846660656078782
# epoch 0 | test loss: 201328.765625

In [None]:
#data_test = DOTARotDataset(mode='test')

In [54]:
# retrieve the first image from the test set
x, y = next(iter(data_test))

# evaluate the model
test_model(model, x)


##########################################################################################
angle |       0       1       2       3       4       5       6       7       8       9      10      11      12      13      14
  0:[ -0.9596  -4.8888   5.9464  -3.2338  -2.3149 -11.3817  -2.4831 -11.5027  -1.2079   0.5564   0.737   -5.1735  -1.4842 -11.1275  -0.8537]
 45:[ -0.9515  -4.8922   5.9272  -3.2435  -2.309  -11.3752  -2.4823 -11.5275  -1.2121   0.574    0.7156  -5.1749  -1.484  -11.1018  -0.8391]
 90:[ -0.9535  -4.8953   5.9379  -3.2481  -2.3075 -11.3579  -2.4622 -11.5154  -1.2262   0.5651   0.72    -5.1597  -1.479  -11.0946  -0.8476]
135:[ -0.9528  -4.9033   5.9461  -3.2624  -2.2945 -11.3566  -2.4489 -11.4927  -1.2168   0.5497   0.6955  -5.1521  -1.4605 -11.0987  -0.8426]
180:[ -0.9504  -4.895    5.9334  -3.2491  -2.2891 -11.3684  -2.4532 -11.491   -1.2094   0.5527   0.7088  -5.1644  -1.4622 -11.1106  -0.8529]
225:[ -0.9533  -4.8964   5.9366  -3.2424  -2.2844 -11.3635  -2.4812 -11.481