# General E(2)-Equivariant Steerable CNNs  -  A concrete example


In [1]:
import os
import csv
import time
import random
from tqdm import tqdm

import numpy as np
from PIL import Image
import cv2

import torch
import torch.nn.functional as F
from torch.utils.data import Dataset
from torchvision.transforms import RandomRotation, Pad, Resize, ToTensor, Compose
from torchsummary import summary

from e2cnn import gspaces
from e2cnn import nn

import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

%matplotlib inline

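Finally, we build a **Steerable CNN** and try it on image chips from the DOTA v1.0 / xView datasets (this notebook adapts the MNIST example of the e2cnn tutorial).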

Let's also use a slightly larger group: we now build a model equivariant to $8$ rotations.
We indicate the group of $N$ discrete rotations as $C_N$, i.e. the **cyclic group** of order $N$.
In this case, we will use $C_8$.

The inputs are 3-channel color image chips, so the input type of the model is three *scalar fields* (one per channel), each transforming under the trivial representation.

However, internally we use *regular fields*: this is equivalent to a *group-equivariant convolutional neural network*.

Finally, we build *invariant* features for the final classification task by pooling over the group using *Group Pooling*.

The final classification is performed by two fully connected layers.

# The model

Here is the definition of our model:

In [2]:
class C8SteerableCNN(torch.nn.Module):
    """Rotation-equivariant CNN on C_N (N discrete rotations), built with e2cnn.

    Six equivariant convolution blocks whose features are regular fields of C_N,
    antialiased average pooling, group pooling (invariance to the rotations of C_N)
    and a two-layer fully connected classifier. Despite the class name, N is set by
    ``cyclic_group`` (default 8, i.e. rotations by multiples of 45 degrees).

    Expects (batch, 3, 80, 80) inputs: both the input mask (size 80) and the first
    Linear layer (``c*13*13`` inputs) are hard-coded for 80x80 images.

    Args:
        n_classes: number of output logits.
        cyclic_group: order N of the rotation group C_N.
    """

    def __init__(self, n_classes=10, cyclic_group=8):
        # nn.Module.__init__ must run before any attribute is assigned on the module
        super(C8SteerableCNN, self).__init__()
        self.cyclic_group = cyclic_group

        # the model is equivariant under the N rotations of C_N (by multiples of 360/N degrees)
        self.r2_act = gspaces.Rot2dOnR2(N=self.cyclic_group)

        # the input has 3 channels, each a scalar field (trivial representation)
        in_type = nn.FieldType(self.r2_act, 3*[self.r2_act.trivial_repr])

        # stored for wrapping the images into a GeometricTensor during the forward pass
        self.input_type = in_type

        # convolution 1: 24 regular feature fields of C_N.
        # The MaskModule is built for 80x80 inputs.
        out_type = nn.FieldType(self.r2_act, 24*[self.r2_act.regular_repr])
        self.block1 = nn.SequentialModule(
            nn.MaskModule(in_type, 80, margin=1),
            nn.R2Conv(in_type, out_type, kernel_size=7, padding=1, bias=False),
            nn.InnerBatchNorm(out_type),
            nn.ReLU(out_type, inplace=True)
        )
        print('block1', self.block1.out_type.size)

        # convolution 2: 48 regular fields, followed by a stride-2 antialiased pool
        self.block2 = self._conv_block(self.block1.out_type, 48, kernel_size=5, padding=2)
        print('block2', self.block2.out_type.size)
        self.pool1 = nn.SequentialModule(
            nn.PointwiseAvgPoolAntialiased(self.block2.out_type, sigma=0.66, stride=2)
        )
        print('pool1', self.pool1.out_type.size)

        # convolution 3: 48 regular fields
        self.block3 = self._conv_block(self.block2.out_type, 48, kernel_size=5, padding=2)
        print('block3', self.block3.out_type.size)

        # convolution 4: 96 regular fields, followed by a stride-2 antialiased pool
        self.block4 = self._conv_block(self.block3.out_type, 96, kernel_size=5, padding=2)
        print('block4', self.block4.out_type.size)
        self.pool2 = nn.SequentialModule(
            nn.PointwiseAvgPoolAntialiased(self.block4.out_type, sigma=0.66, stride=2)
        )
        print('pool2', self.pool2.out_type.size)

        # convolution 5: 96 regular fields
        self.block5 = self._conv_block(self.block4.out_type, 96, kernel_size=5, padding=2)
        print('block5', self.block5.out_type.size)

        # convolution 6: 64 regular fields (padding 1, so the map shrinks)
        self.block6 = self._conv_block(self.block5.out_type, 64, kernel_size=5, padding=1)
        print('block6', self.block6.out_type.size)
        # stride-1, unpadded antialiased pool: smooths (and slightly crops) the map,
        # it does NOT pool the spatial dimensions away
        self.pool3 = nn.PointwiseAvgPoolAntialiased(self.block6.out_type, sigma=0.66, stride=1, padding=0)
        print('pool3', self.pool3.out_type.size)

        # max over the N channels of every regular field -> rotation-invariant features
        self.gpool = nn.GroupPooling(self.block6.out_type)

        # number of output channels
        c = self.gpool.out_type.size
        print('gpool', c)

        # Fully Connected; 13*13 is the spatial size left for 80x80 inputs
        self.fully_net = torch.nn.Sequential(
            torch.nn.Linear(c*13*13, 64),
            torch.nn.BatchNorm1d(64),
            torch.nn.ELU(inplace=True),
            torch.nn.Linear(64, n_classes),
        )

    def _conv_block(self, in_type, n_fields, kernel_size, padding):
        """Build R2Conv -> InnerBatchNorm -> ReLU producing `n_fields` regular fields of C_N."""
        out_type = nn.FieldType(self.r2_act, n_fields*[self.r2_act.regular_repr])
        return nn.SequentialModule(
            nn.R2Conv(in_type, out_type, kernel_size=kernel_size, padding=padding, bias=False),
            nn.InnerBatchNorm(out_type),
            nn.ReLU(out_type, inplace=True)
        )

    def forward(self, input: torch.Tensor):
        """Return class logits of shape (batch, n_classes) for a (batch, 3, 80, 80) tensor."""
        # wrap the input tensor in a GeometricTensor (associate it with the input type)
        x = nn.GeometricTensor(input, self.input_type)

        # Each equivariant layer takes a GeometricTensor whose type matches its input
        # type and returns one typed with its output type, so consecutive layers must
        # have matching input/output types.
        x = self.block1(x)
        x = self.block2(x)
        x = self.pool1(x)

        x = self.block3(x)
        x = self.block4(x)
        x = self.pool2(x)

        x = self.block5(x)
        x = self.block6(x)

        # final antialiased smoothing (stride 1); spatial dims are kept
        x = self.pool3(x)

        # pool over the group
        x = self.gpool(x)

        # unwrap the output GeometricTensor
        # (take the PyTorch tensor and discard the associated representation)
        x = x.tensor

        # flatten channels and spatial dims, then classify with the fully connected layers
        x = self.fully_net(x.reshape(x.shape[0], -1))

        return x

In [3]:
# Not rotationally-equivariant baseline; mimics C8SteerableCNN's layout as closely as possible
class NonRECNN(torch.nn.Module):
    """Plain CNN baseline with the same block layout as C8SteerableCNN.

    Ordinary Conv2d/BatchNorm2d/ReLU blocks and AvgPool2d replace the e2cnn modules,
    and there is no group pooling. Built for (batch, 3, 80, 80) inputs:
    80 -> block1 (k7, p1) 76 -> pool1 38 -> pool2 19 -> block6 (k5, p0) 15
    -> pool3 14 -> pool4 13, so the classifier sees 64*13*13 features.

    Args:
        n_classes: number of output logits.
    """

    def __init__(self, n_classes=10):
        super().__init__()

        # convolutions 1-2, then downsample by 2
        self.block1 = self._conv_block(3, 24, kernel_size=7, padding=1)
        self.block2 = self._conv_block(24, 48, kernel_size=5, padding=2)
        self.pool1 = torch.nn.AvgPool2d(kernel_size=2, stride=2)

        # convolutions 3-4, then downsample by 2
        self.block3 = self._conv_block(48, 48, kernel_size=5, padding=2)
        self.block4 = self._conv_block(48, 96, kernel_size=5, padding=2)
        self.pool2 = torch.nn.AvgPool2d(kernel_size=2, stride=2)

        # convolutions 5-6
        self.block5 = self._conv_block(96, 96, kernel_size=5, padding=2)
        # padding=0 here (the equivariant model uses 1) to help match the fc input size
        self.block6 = self._conv_block(96, 64, kernel_size=5, padding=0)
        # two stride-1 average pools: 15x15 -> 14x14 -> 13x13, matching the fc layer
        self.pool3 = torch.nn.AvgPool2d(kernel_size=2, stride=1, padding=0)
        self.pool4 = torch.nn.AvgPool2d(kernel_size=2, stride=1, padding=0)

        # Fully Connected
        self.fully_net = torch.nn.Sequential(
            torch.nn.Linear(64*13*13, 64),
            torch.nn.BatchNorm1d(num_features=64),
            torch.nn.ELU(inplace=True),
            torch.nn.Linear(64, n_classes),
        )

    @staticmethod
    def _conv_block(in_channels, out_channels, kernel_size, padding):
        """Build Conv2d (stride 1, no bias) -> BatchNorm2d -> ReLU."""
        return torch.nn.Sequential(
            torch.nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                            kernel_size=kernel_size, stride=1, padding=padding, bias=False),
            torch.nn.BatchNorm2d(num_features=out_channels),
            torch.nn.ReLU(inplace=True)
        )

    def forward(self, x):
        """Return class logits of shape (batch, n_classes) for a (batch, 3, 80, 80) tensor."""
        x = self.block1(x)
        x = self.block2(x)
        x = self.pool1(x)

        x = self.block3(x)
        x = self.block4(x)
        x = self.pool2(x)

        x = self.block5(x)
        x = self.block6(x)
        x = self.pool3(x)
        x = self.pool4(x)

        # flatten all dimensions except batch, then classify
        x = torch.flatten(x, 1)
        return self.fully_net(x)

Let's build the model

In [4]:
# Run on the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(device)

cuda


In [5]:
EQUIVARIANT = True
dataname = 'DOTAv1.0'

# number of object classes in each supported dataset
classes_per_dataset = {'DOTAv1.0': 15, 'xView': 62}
n_classes = classes_per_dataset.get(dataname, -1)

assert n_classes > 0, "unknown dataname {!r}; expected one of {}".format(
    dataname, sorted(classes_per_dataset))

In [6]:
# Build the model selected by EQUIVARIANT. `cyclic_group` stays a global for later
# cells (e.g. naming saved models); the plain baseline corresponds to C_1.
if not EQUIVARIANT:
    cyclic_group = 1
    model = NonRECNN(n_classes=n_classes)
else:
    cyclic_group = 8
    model = C8SteerableCNN(n_classes=n_classes, cyclic_group=cyclic_group)
model = model.to(device)

block1 192


  full_mask[mask] = norms.to(torch.uint8)


block2 384
pool1 384
block3 384
block4 768
pool2 768
block5 768
block6 512
pool3 512
gpool 64


In [7]:
# cyclic_group = 8
# n_classes = 15

# model1 = C8SteerableCNN(n_classes=n_classes, cyclic_group=cyclic_group).to(device)
# summary(model1, input_size=(3, 80, 80))

In [8]:
# model2 = NonRECNN(n_classes=n_classes).to(device)
# summary(model2, input_size=(3, 80, 80))

The model is now randomly initialized. 
Therefore, we do not expect it to produce the right class probabilities.

However, the model should still produce the same output for rotated versions of the same image.
This is true for rotations by multiples of $\frac{\pi}{2}$, but is only approximate for rotations by $\frac{\pi}{4}$.

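Let's try the model on *rotated* inputs (the original tutorial used rotated MNIST here; this notebook uses its own image chips, loaded below).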

In [9]:
# # download the dataset
# !wget -nc http://www.iro.umontreal.ca/~lisa/icml2007data/mnist_rotation_new.zip
# # uncompress the zip file
# !unzip -n mnist_rotation_new.zip -d mnist_rotation_new

In [10]:
# import requests
# url = 'http://www.iro.umontreal.ca/~lisa/icml2007data/mnist_rotation_new.zip'
# doc = requests.get(url)
# with open('mnistrot.zip', 'wb') as f:
#     f.write(doc.content)

Build the dataset

In [11]:
class TheDataset(Dataset):
    """Eagerly loaded image-chip classification dataset with one folder per class.

    Chips are read from ``<basedir><dataname>/chips_<mode>/<class_name>/<id>.png``.
    Each chip is zero-padded to a square, resized to ``chip_size`` x ``chip_size``,
    and stored as a float32 ``(channels, chip_size, chip_size)`` array. Values are in
    [0, 255] and channels keep OpenCV's BGR order.

    Args:
        mode: 'train' or 'val'; selects the ``chips_<mode>`` directory.
        transform: optional callable applied to each image array in ``__getitem__``.
        max_num_examples: per-class cap. Classes with more chips are randomly subsampled.
        chip_size: side length in pixels of the square output images.
        dataname: dataset folder name, e.g. 'DOTAv1.0' or 'xView'.
        index_lists: optional ``{class_name: [chip ids]}``. If None/empty, it is filled
            by scanning the chip directory. A dict passed in is filled in place, so
            callers can reuse it.
        classdict: optional ``{class_name: integer label}``. If None/empty, labels
            0..K-1 are assigned in ``index_lists`` key order. A dict passed in is filled
            in place; pass the same dict to train and val sets to share labels.
        basedir: root folder that holds one sub-folder per dataset.

    Raises:
        FileNotFoundError: if the chip directory must be scanned but does not exist,
            or if a chip image cannot be read.
    """

    def __init__(self, mode, transform=None, max_num_examples=999999,
                 chip_size=80, dataname='xView', index_lists=None, classdict=None,
                 basedir='C:/Users/Admin/Desktop/data/'):
        assert mode in ['train', 'val']

        # dataname: DOTAv1.0, xView, FAIR1M, etc
        if basedir and not basedir.endswith(('/', '\\')):
            basedir += '/'
        self.basedir = basedir + dataname + '/'
        self.chipdir = self.basedir + "chips_" + mode + "/"

        self.transform = transform
        self.max_num_examples = max_num_examples
        self.chip_size = chip_size
        self.color = (0, 0, 0)  # fill colour of the square padding
        # `None` defaults give every instance its own dicts. A shared `{}` default would
        # carry one instance's classes into the next. Dicts passed in are used as-is.
        self.index_lists = {} if index_lists is None else index_lists
        self.classdict = {} if classdict is None else classdict
        self.classval = 0

        # first time through, we need to create self.index_lists and self.classdict
        if not self.index_lists:
            self._scan_chip_dirs()
        if not self.classdict:
            self._assign_labels()

        # load every chip of each class, or a random sample if the class is over the cap
        self.images = []
        self.labels = []
        with tqdm(self.index_lists.keys(), unit="class") as tkeys:
            for k in tkeys:
                ids = self.index_lists[k]
                if len(ids) > self.max_num_examples:
                    ids = random.sample(ids, self.max_num_examples)
                for t in ids:
                    self.images.append(self.prep_image(k, t))
                    self.labels.append(self.classdict[k])

        self.num_samples = len(self.labels)

    def _scan_chip_dirs(self):
        """Fill self.index_lists with {class folder name: [integer chip ids]}."""
        if not os.path.isdir(self.chipdir):
            raise FileNotFoundError("chip directory not found: " + self.chipdir)
        # Sort so that train and val (scanned separately) get the same class order.
        class_dirs = sorted(name for name in os.listdir(self.chipdir)
                            if os.path.isdir(os.path.join(self.chipdir, name)))
        for d in class_dirs:
            # only '<int id>.png' files can be loaded by prep_image
            ids = [int(name[:-4]) for name in os.listdir(self.chipdir + d)
                   if name.lower().endswith('.png')]
            if ids:
                self.index_lists[d] = ids

    def _assign_labels(self):
        """Give each class in self.index_lists a label 0..K-1 (PyTorch needs 0-based labels)."""
        for k in self.index_lists:
            self.classdict[k] = self.classval
            self.classval += 1

    def _pad_to_square(self, img):
        """Centre an (H, W, C) uint8 image on a square canvas of side max(H, W) filled with self.color."""
        height, width, channels = img.shape
        if height == width:
            return img
        side = max(height, width)
        canvas = np.full((side, side, channels), self.color, dtype=np.uint8)
        top = (side - height) // 2
        left = (side - width) // 2
        canvas[top:top + height, left:left + width] = img
        return canvas

    def prep_image(self, d, f):
        """Load chip `f` of class folder `d` as a float32 (C, chip_size, chip_size) array.

        Raises:
            FileNotFoundError: if OpenCV cannot read the image file.
        """
        imgstr = self.chipdir + d + "/" + str(f) + ".png"
        img = cv2.imread(imgstr)
        if img is None:  # cv2.imread signals failure by returning None
            raise FileNotFoundError("could not read image: " + imgstr)

        # first pad to a square, then resize if needed
        square = self._pad_to_square(img)
        if square.shape[0] != self.chip_size or square.shape[1] != self.chip_size:
            square = cv2.resize(square, dsize=(self.chip_size, self.chip_size),
                                interpolation=cv2.INTER_CUBIC)

        # OpenCV gives (H, W, C); PyTorch wants (C, H, W). This must be a transpose:
        # a plain reshape to (-1, H, W) would interleave pixels across channels.
        # NOTE: models trained before this fix saw the scrambled reshape layout.
        return np.ascontiguousarray(square.transpose(2, 0, 1), dtype=np.float32)

    def __getitem__(self, index):
        """Return (image array, integer label), applying self.transform to the image if set."""
        image, label = self.images[index], self.labels[index]
        # image is a numpy ndarray, not a PIL Image; some torchvision transforms
        # (e.g. RandomRotation) need PIL images or tensors instead
        if self.transform is not None:
            image = self.transform(image)
        return image, label

    def __len__(self):
        return len(self.labels)

# images are padded to have shape 29x29.
# this allows to use odd-size filters with stride 2 when downsampling a feature map in the model
#pad = Pad((0, 0, 1, 1), fill=0)

# to reduce interpolation artifacts (e.g. when testing the model on rotated images),
# we upsample an image by a factor of 3, rotate it and finally downsample it again
#resize1 = Resize(80*3)
#resize2 = Resize(80)

#totensor = ToTensor()

In [12]:
# build the test set
# Sanity check: load a single chip per class from the validation split and show
# the class-name -> label mapping that TheDataset builds.
label_probe = TheDataset(mode='val', dataname=dataname, max_num_examples=1)
print(label_probe.classdict)
del label_probe

100%|██████████| 15/15 [00:00<00:00, 934.89class/s]

{'baseball-diamond': 0, 'basketball-court': 1, 'bridge': 2, 'ground-track-field': 3, 'harbor': 4, 'helicopter': 5, 'large-vehicle': 6, 'plane': 7, 'roundabout': 8, 'ship': 9, 'small-vehicle': 10, 'soccer-ball-field': 11, 'storage-tank': 12, 'swimming-pool': 13, 'tennis-court': 14}





In [13]:
# retrieve the first image from the test set
#x, y = next(iter(test))
#
#print(x.shape)
#print(test[3530])

In [14]:
def test_model(model: torch.nn.Module, x: np.ndarray, n_rotations: int = 8):
    """Print the model's outputs for `n_rotations` evenly rotated copies of one image.

    Use it to eyeball rotation invariance: for an invariant model all printed rows
    should be (nearly) identical. Rotations use PIL bilinear resampling about the
    image centre and keep the original size.

    Args:
        model: classifier mapping a (1, C, H, W) float tensor to (1, n_outputs) scores.
        x: one image as a (C, H, W) array with values in [0, 255], the layout and
            scale TheDataset yields. It is cast to uint8 for PIL, so C is expected
            to be 3.
        n_rotations: number of rotation angles, spaced 360 / n_rotations degrees apart.
    """
    model.eval()
    # run on whatever device the model's parameters live on
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else torch.device('cpu')

    # PIL rotates (H, W, C) uint8 images; convert once, outside the rotation loop
    pil_img = Image.fromarray(np.asarray(x).astype(np.uint8).transpose(1, 2, 0))
    step = 360.0 / n_rotations

    with torch.no_grad():
        # warm-up pass; also tells us how many outputs the model produces
        n_outputs = model(torch.randn(1, *np.shape(x), device=run_device)).shape[-1]

        print()
        print('##########################################################################################')
        header = 'angle |  ' + '  '.join(["{:6d}".format(d) for d in range(n_outputs)])
        print(header)
        for r in range(n_rotations):
            angle = r * step
            rotated = np.asarray(pil_img.rotate(angle, Image.BILINEAR), dtype=np.float32)
            # Keep the [0, 255] scale the model is trained on (ToTensor would rescale
            # to [0, 1]), and go back to (1, C, H, W).
            batch = torch.from_numpy(np.ascontiguousarray(rotated.transpose(2, 0, 1))).unsqueeze(0)
            y = model(batch.to(run_device)).cpu().numpy().squeeze()
            print("{:3g}:{}".format(angle, y))
        print('##########################################################################################')
        print()

In [15]:
# evaluate the model
#test_model(model, x)

The output of the model is already almost invariant.
However, we still observe small fluctuations in the outputs.

This is because the model contains some operations which might break equivariance.
For instance, the convolutions pad their inputs ($2$ pixels per side in most blocks, $1$ in the first and last convolution). This adds information about the actual orientation of the grid where the image/feature map is sampled, because the padding is not rotated with the image. 

During training, the model will observe rotated patterns and will learn to ignore the noise coming from the padding.

So, let's train the model now.
Training works exactly as it does for a normal *PyTorch* architecture:

In [16]:
#dataname: DOTAv1.0, xView, FAIR1M, etc
#dataname = 'xView'

In [17]:
# Prep the training dataset
# Training-time transforms: currently none. An empty Compose returns its input
# unchanged, so TheDataset's arrays reach the model as-is. Augmentations such as
# RandomRotation would need PIL images or tensors rather than raw numpy arrays.
train_transform = Compose([])

#data_train = TheDataset(mode='train', transform=train_transform, dataname=dataname)
#train_loader = torch.utils.data.DataLoader(data_train, batch_size=8, shuffle=True)
#print(len(data_train), len(train_loader))

In [18]:
def get_fresh_training_data(n, index_lists=None, classdict=None, batch_size=8):
    """Build a training DataLoader with a fresh random sample of at most `n` chips per class.

    Args:
        n: per-class cap, passed to TheDataset as max_num_examples.
        index_lists: {class: [chip ids]} from a previous call. If None/empty, it is
            scanned from disk; passing it back in skips the scan.
        classdict: {class: label}. Pass the validation set's mapping to keep labels
            aligned. If None/empty, it is built from index_lists.
        batch_size: DataLoader batch size.

    Returns:
        (train_loader, index_lists, classdict). The two dicts can be reused in later calls.
    """
    # Fresh dicts per call. The old `{}` defaults were filled in place by TheDataset,
    # so one call's class lists silently carried over into every later call.
    if index_lists is None:
        index_lists = {}
    if classdict is None:
        classdict = {}
    data_train = TheDataset(mode='train', transform=train_transform, max_num_examples=n,
                            chip_size=80, dataname=dataname, index_lists=index_lists, classdict=classdict)
    # drop_last avoids a trailing batch of size 1, which BatchNorm1d rejects in training mode
    train_loader = torch.utils.data.DataLoader(data_train, batch_size=batch_size, shuffle=True, drop_last=True)
    return train_loader, data_train.index_lists, data_train.classdict

In [19]:
# Prep the testing dataset
val_transform = Compose([
    #pad,
    #ToTensor()
])

data_val = TheDataset(mode='val', transform=val_transform, 
                      chip_size=80, dataname=dataname)
val_loader = torch.utils.data.DataLoader(data_val, batch_size=8, shuffle=True, drop_last=True)
print(len(data_val), "labels")
print(len(val_loader), "batches")

100%|██████████| 15/15 [00:17<00:00,  1.16s/class]

28853 labels
3606 batches





In [20]:
# Peek at one validation batch to check that labels are 0-based class ids.
label_is = []
for _images, labels in val_loader:
    label_is.append(labels)
    break
print(label_is)

[tensor([ 9,  6,  9,  9, 12,  9,  7,  9])]


In [21]:
# mini = 99999
# maxi = 0
# for m in data_train: # data_val
#     if m[1] < mini:
#         mini = m[1]
#     if m[1] > maxi:
#         maxi = m[1]
# print(mini, maxi, "<--- should be 0 14 instead of 1 15")

In [22]:
def save_the_model(model, cyclic_group, epoch, train_acc, train_loss, test_acc, test_loss, dataname):
    """Save `model.state_dict()` under models/, encoding the run's metrics in the file name.

    The file is
    ``models/model_<dataname>_C<cyclic_group>_<epoch>_<train_acc>_<train_loss>_<test_acc>_<test_loss>.pth``,
    with each metric rounded to 4 decimals. The models/ folder is created if missing.
    """
    if not os.path.exists("models"):
        os.mkdir("models")
    metrics = "_".join(str(round(value, 4)) for value in (train_acc, train_loss, test_acc, test_loss))
    file_name = "model_{}_C{}_{}_{}.pth".format(dataname, cyclic_group, epoch, metrics)
    torch.save(model.state_dict(), "models/" + file_name)

In [23]:
def save_the_results(cyclic_group, epoch, train_accuracy, train_loss, test_accuracy, test_loss, dataname,
                     results_dir="results"):
    """Append one row of per-epoch metrics to ``<results_dir>/model_<dataname>.csv``.

    ``results_dir`` (default ``"results"``) is created if missing. A header row is written
    first whenever the CSV file does not exist yet or is empty.
    """
    os.makedirs(results_dir, exist_ok=True)
    csv_path = os.path.join(results_dir, "model_" + dataname + ".csv")
    # Decide on the header before opening: opening in append mode creates the file.
    add_headers = not os.path.exists(csv_path) or os.path.getsize(csv_path) == 0
    with open(csv_path, 'a', newline='') as csvfile:
        csvwriter = csv.writer(csvfile)
        if add_headers:
            csvwriter.writerow(['Dataname', 'Cyclic Group', 'Epoch', 'Train Accuracy', 'Train Loss',
                                'Test Accuracy', 'Test Loss'])
        csvwriter.writerow([dataname, cyclic_group, epoch, train_accuracy, train_loss,
                            test_accuracy, test_loss])

In [24]:
# Loss and optimizer. The optimizer holds references to the model's parameter tensors;
# load_state_dict below copies checkpoint values into those same tensors, so the
# optimizer built here remains valid after a resume.
loss_function = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=5e-5, weight_decay=1e-5)

######################################################
# IF CONTINUING TRAINING, LOAD MODEL FROM LAST EPOCH
######################################################
# NOTE: only model weights are restored; save_the_model writes model.state_dict() only,
# so Adam's running moment estimates start from scratch on resume.
CONTINUING = False
if CONTINUING:
    checkpoint_path = (
        "models/model_dota_C8_65_97.6337_874.7356_84.8987_2469.3884.pth" if EQUIVARIANT
        else "models/model_dota_C1_64_95.0866_865.1543_86.0953_1910.6196.pth"
    )
    # map_location="cpu" lets a checkpoint saved from GPU tensors load on any machine;
    # load_state_dict then copies the values onto whatever device the model lives on.
    model.load_state_dict(torch.load(checkpoint_path, map_location="cpu"))

In [25]:
# # get initial performance with random weights
# test_total = 0
# test_correct = 0
# test_loss = 0
# with torch.no_grad():
#     model.eval()
    
#     #for i, (x, t) in enumerate(val_loader):
#     with tqdm(val_loader, unit="batch") as tepoch:
#         for x, t in tepoch:
#             tepoch.set_description(f"Epoch {-1}")
            
#             #if i%1000==0:
#             #    print(i, "/", len(test_loader))

#             x = x.to(device)
#             t = t.to(device)

#             y = model(x)

#             _, prediction = torch.max(y.data, 1)
#             if prediction.shape[0] != t.shape[0]:
#                 print(t)
#                 t = t[:-(t.shape[0]-prediction.shape[0])]
#                 print(t)
#             test_total += t.shape[0]
#             test_correct += (prediction == t).sum().item()

#             loss = loss_function(y, t)
#             test_loss += loss

# test_accuracy = test_correct/test_total*100.
    
# print(f"test accuracy: {test_accuracy}")
# print(f"test loss: {test_loss.item()}")

In [None]:
# Training schedule: the per-class sample budget starts at samples_per_class and grows by
# increase_by each epoch until it reaches max_per_class.
samples_per_class = 500
max_per_class = 3000
increase_by = 100

index_lists = {}
classdict = data_val.classdict.copy() # so we use the same labels for training later

start_epoch = 0
max_epochs = 100

for epoch in range(start_epoch, max_epochs):

    print('starting epoch', epoch)

    ########################################
    # TRAIN
    ########################################
    # Accumulate plain Python numbers: summing the loss *tensor* (train_loss += loss) would
    # keep every batch's autograd history alive for the whole epoch.
    train_total = 0
    train_correct = 0
    train_loss = 0.0
    model.train()

    # NOTE: i pull fresh training data examples each epoch to counter the huge class imbalance.
    #       i keep the validation data the same each epoch though to make sure i'm consistently
    #       measuring validation accuracy/loss.
    train_loader, index_lists, classdict = get_fresh_training_data(samples_per_class,
                                                                   index_lists=index_lists,
                                                                   classdict=classdict)
    print("samples_per_class", samples_per_class)
    # NOTE: increasing the number of samples per class will mess up the loss curve because the
    #       reported loss is a sum over batches, and more examples mean more batches each epoch.
    #       Comment out the next `if` if you wish to avoid this.
    if samples_per_class < max_per_class:
        samples_per_class += increase_by

    with tqdm(train_loader, unit="batch") as tepoch:
        tepoch.set_description(f"Epoch {epoch}")
        for x, t in tepoch:
            optimizer.zero_grad()

            x = x.to(device)
            t = t.to(device)

            y = model(x)
            # Predictions are bookkeeping only, so take them from a detached copy of the logits.
            prediction = y.detach().argmax(dim=1)

            train_total += t.shape[0]
            train_correct += (prediction == t).sum().item()

            loss = loss_function(y, t)
            train_loss += loss.item()

            loss.backward()
            optimizer.step()

            tepoch.set_postfix(Train_Acc=100. * train_correct / train_total, Train_Loss=train_loss)

    # Guard against an empty loader so train_accuracy is always defined.
    train_accuracy = 100. * train_correct / train_total if train_total else 0.0

    ########################################
    # TEST
    ########################################
    test_total = 0
    test_correct = 0
    test_loss = 0.0

    model.eval()
    with torch.no_grad():
        with tqdm(val_loader, unit="batch") as tepoch:
            tepoch.set_description(f"Epoch {epoch}")
            for x, t in tepoch:
                x = x.to(device)
                t = t.to(device)

                y = model(x)
                prediction = y.argmax(dim=1)

                test_total += t.shape[0]
                test_correct += (prediction == t).sum().item()

                loss = loss_function(y, t)
                test_loss += loss.item()

                tepoch.set_postfix(Test_Acc=100. * test_correct / test_total, Test_Loss=test_loss)

    test_accuracy = 100. * test_correct / test_total if test_total else 0.0
    # Presumably lets the tqdm bars finish rendering before the next epoch prints.
    time.sleep(0.5)

    # Losses are sums of per-batch mean losses over the epoch (same quantity as before).
    save_the_results(cyclic_group, epoch, train_accuracy, train_loss, test_accuracy, test_loss, dataname)
    save_the_model(model, cyclic_group, epoch, train_accuracy, train_loss, test_accuracy, test_loss, dataname)


starting epoch 0


100%|██████████| 15/15 [00:06<00:00,  2.33class/s]


samples_per_class 500


Epoch 0: 100%|██████████| 870/870 [04:31<00:00,  3.20batch/s]
Epoch 0: 100%|██████████| 3606/3606 [05:19<00:00, 11.29batch/s]


starting epoch 1


100%|██████████| 15/15 [00:08<00:00,  1.77class/s]


samples_per_class 600


Epoch 1: 100%|██████████| 997/997 [05:56<00:00,  2.80batch/s]
Epoch 1: 100%|██████████| 3606/3606 [05:24<00:00, 11.12batch/s]


starting epoch 2


100%|██████████| 15/15 [00:09<00:00,  1.53class/s]


samples_per_class 700


Epoch 2: 100%|██████████| 1113/1113 [05:48<00:00,  3.19batch/s] 
Epoch 2: 100%|██████████| 3606/3606 [04:05<00:00, 14.68batch/s]


starting epoch 3


100%|██████████| 15/15 [00:08<00:00,  1.72class/s]


samples_per_class 800


Epoch 3: 100%|██████████| 1226/1226 [05:39<00:00,  3.61batch/s]
Epoch 3: 100%|██████████| 3606/3606 [04:05<00:00, 14.71batch/s]


starting epoch 4


100%|██████████| 15/15 [00:09<00:00,  1.64class/s]


samples_per_class 900


Epoch 4: 100%|██████████| 1338/1338 [06:10<00:00,  3.62batch/s]
Epoch 4: 100%|██████████| 3606/3606 [04:05<00:00, 14.71batch/s]


starting epoch 5


100%|██████████| 15/15 [00:10<00:00,  1.48class/s]


samples_per_class 1000


Epoch 5: 100%|██████████| 1451/1451 [06:41<00:00,  3.62batch/s]
Epoch 5: 100%|██████████| 3606/3606 [04:05<00:00, 14.71batch/s]


starting epoch 6


100%|██████████| 15/15 [00:10<00:00,  1.43class/s]


samples_per_class 1100


Epoch 6: 100%|██████████| 1563/1563 [07:12<00:00,  3.62batch/s]
Epoch 6: 100%|██████████| 3606/3606 [04:07<00:00, 14.59batch/s]


starting epoch 7


100%|██████████| 15/15 [00:12<00:00,  1.15class/s]


samples_per_class 1200


Epoch 7: 100%|██████████| 1676/1676 [08:23<00:00,  3.33batch/s]
Epoch 7: 100%|██████████| 3606/3606 [04:27<00:00, 13.50batch/s]


starting epoch 8


100%|██████████| 15/15 [00:11<00:00,  1.26class/s]


samples_per_class 1300


Epoch 8: 100%|██████████| 1788/1788 [08:16<00:00,  3.60batch/s]
Epoch 8: 100%|██████████| 3606/3606 [04:05<00:00, 14.70batch/s]


starting epoch 9


100%|██████████| 15/15 [00:12<00:00,  1.24class/s]


samples_per_class 1400


Epoch 9: 100%|██████████| 1901/1901 [08:45<00:00,  3.61batch/s]
Epoch 9: 100%|██████████| 3606/3606 [04:05<00:00, 14.68batch/s]


starting epoch 10


100%|██████████| 15/15 [00:12<00:00,  1.17class/s]


samples_per_class 1500


Epoch 10: 100%|██████████| 2013/2013 [09:17<00:00,  3.61batch/s]
Epoch 10: 100%|██████████| 3606/3606 [04:06<00:00, 14.65batch/s]


starting epoch 11


100%|██████████| 15/15 [00:13<00:00,  1.11class/s]


samples_per_class 1600


Epoch 11: 100%|██████████| 2126/2126 [09:48<00:00,  3.61batch/s]
Epoch 11: 100%|██████████| 3606/3606 [04:06<00:00, 14.66batch/s]


starting epoch 12


100%|██████████| 15/15 [00:13<00:00,  1.09class/s]


samples_per_class 1700


Epoch 12: 100%|██████████| 2238/2238 [10:19<00:00,  3.61batch/s]
Epoch 12: 100%|██████████| 3606/3606 [04:06<00:00, 14.64batch/s]


starting epoch 13


100%|██████████| 15/15 [00:14<00:00,  1.04class/s]


samples_per_class 1800


Epoch 13: 100%|██████████| 2343/2343 [10:45<00:00,  3.63batch/s]
Epoch 13: 100%|██████████| 3606/3606 [04:31<00:00, 13.28batch/s]


starting epoch 14


100%|██████████| 15/15 [00:17<00:00,  1.18s/class]


samples_per_class 1900


Epoch 14: 100%|██████████| 2443/2443 [12:15<00:00,  3.32batch/s]
Epoch 14: 100%|██████████| 3606/3606 [04:05<00:00, 14.70batch/s]


starting epoch 15


100%|██████████| 15/15 [00:17<00:00,  1.13s/class]


samples_per_class 2000


Epoch 15: 100%|██████████| 2543/2543 [11:41<00:00,  3.62batch/s]
Epoch 15: 100%|██████████| 3606/3606 [04:05<00:00, 14.67batch/s]


starting epoch 16


100%|██████████| 15/15 [00:17<00:00,  1.15s/class]


samples_per_class 2100


Epoch 16: 100%|██████████| 2636/2636 [12:06<00:00,  3.63batch/s]
Epoch 16: 100%|██████████| 3606/3606 [04:05<00:00, 14.71batch/s]


starting epoch 17


100%|██████████| 15/15 [00:17<00:00,  1.18s/class]


samples_per_class 2200


Epoch 17: 100%|██████████| 2724/2724 [12:30<00:00,  3.63batch/s]
Epoch 17: 100%|██████████| 3606/3606 [04:03<00:00, 14.78batch/s]


starting epoch 18


100%|██████████| 15/15 [00:18<00:00,  1.23s/class]


samples_per_class 2300


Epoch 18: 100%|██████████| 2811/2811 [12:53<00:00,  3.64batch/s]
Epoch 18: 100%|██████████| 3606/3606 [04:04<00:00, 14.77batch/s]


starting epoch 19


100%|██████████| 15/15 [00:19<00:00,  1.27s/class]


samples_per_class 2400


Epoch 19: 100%|██████████| 2895/2895 [14:15<00:00,  3.38batch/s]
Epoch 19: 100%|██████████| 3606/3606 [04:17<00:00, 13.99batch/s]


starting epoch 20


100%|██████████| 15/15 [00:19<00:00,  1.29s/class]


samples_per_class 2500


Epoch 20: 100%|██████████| 2970/2970 [13:47<00:00,  3.59batch/s]
Epoch 20: 100%|██████████| 3606/3606 [04:07<00:00, 14.55batch/s]


starting epoch 21


100%|██████████| 15/15 [00:20<00:00,  1.38s/class]


samples_per_class 2600


Epoch 21:  23%|██▎       | 709/3045 [03:15<10:48,  3.60batch/s]

In [None]:
# save_the_results(cyclic_group, epoch, train_accuracy, train_loss.item(), test_accuracy, test_loss.item(), dataname)
# save_the_model(model, cyclic_group, epoch, train_accuracy, train_loss.item(), test_accuracy, test_loss.item(), dataname)

In [None]:
# train_accuracy = train_correct/train_total*100.
# print(train_correct, train_total, train_accuracy)
#save_the_results(cyclic_group, epoch, train_accuracy, train_loss.item(), test_accuracy, test_loss.item())
#save_the_model(model, cyclic_group, epoch, train_accuracy, train_loss.item(), test_accuracy, test_loss.item())

In [None]:
# print(prediction.shape, t.shape)
# print(prediction)
# print(t)
# print(t[:-1])
# print(y.shape)
# print(torch.max(y.data, 1)) # 1 is the dimension

In [None]:
# epoch -1 | test accuracy: 10.011096469935502
# epoch -1 | test loss: 2109988.5

# epoch 0 | train accuracy: 52.91684784257414
# epoch 0 | train loss: 42925.43359375
# epoch 0 | test accuracy: 18.846660656078782
# epoch 0 | test loss: 201328.765625

In [None]:
#data_test = TheDataset(mode='test', dataname=dataname)

In [None]:
# Retrieve the first (image, label) pair from the validation dataset (data_val).
# Note: this rebinds the notebook-level names x and y.
x, y = next(iter(data_val))

# Evaluate the model on that single image. test_model is defined elsewhere in the
# notebook (not visible in this section).
test_model(model, x)

In [None]:
# Confusion matrix over the validation loader: rows = true label, columns = predicted label.
# Requires n_classes to be defined already (e.g. 15).
confusion_matrix = torch.zeros(n_classes, n_classes)

# Inference must run in eval mode: if the model was just loaded or is still in train mode,
# layers such as batch-norm/dropout (if present) would behave as in training.
model.eval()
with torch.no_grad():
    with tqdm(val_loader, unit="batch") as tepoch:
        for x, t in tepoch:
            x = x.to(device)
            y = model(x)
            preds = y.argmax(dim=1).view(-1).long().cpu()
            t = t.view(-1).long().cpu()
            # Count every (true, predicted) pair of the batch at once on the CPU instead of
            # one element at a time (each element-wise update on device tensors forces a sync).
            pair_ids = t * n_classes + preds
            batch_counts = torch.bincount(pair_ids, minlength=n_classes * n_classes)
            confusion_matrix += batch_counts.reshape(n_classes, n_classes).to(confusion_matrix.dtype)

In [None]:
print(confusion_matrix)  # raw counts: row = true label, column = predicted label

In [None]:
import sys

# Print the whole matrix (no "..." summarisation) as integer counts.
np.set_printoptions(threshold=sys.maxsize)
cm = confusion_matrix.numpy().astype('int32')  # astype returns a new array, so no separate copy is needed
print(cm)

In [None]:
# Names of the immediate subdirectories of the chip directory, used below as heatmap tick
# labels. Only the top-level entry of os.walk is needed, so take the first yielded tuple
# instead of walking every chip file in the tree. The order is the directory-listing
# order (not sorted).
walk_top = next(os.walk(data_val.chipdir), None)
if walk_top is None:
    raise FileNotFoundError(f"chip directory not found or unreadable: {data_val.chipdir}")
root, dirs, filenames = walk_top
print(dirs)

In [None]:
# Heatmap of raw confusion-matrix counts, saved to results/.
cnum = "8" if EQUIVARIANT else "1"

# NOTE(review): tick labels come from `dirs` (directory-listing order); confirm this order
# matches the label indices used for the confusion-matrix rows/columns.
dfcm = pd.DataFrame(cm, index=dirs, columns=dirs)
fig, ax = plt.subplots(figsize=(18, 18))
sns.heatmap(dfcm, annot=True, cmap="Blues", fmt='d', ax=ax)
ax.set(xlabel="Predicted label", ylabel="True label",
       title=f"C{cnum} confusion matrix ({dataname}) - counts")
# savefig fails if the target directory does not exist yet.
os.makedirs("results", exist_ok=True)
fig.savefig("results/c"+cnum+"_confusion_matrix_"+dataname+"_counts.png",
            dpi=100,
            bbox_inches='tight')

In [None]:
# Per-class accuracy (recall) in percent: diagonal (correct) counts divided by row sums
# (true samples per class). A class with no samples yields NaN (0/0).
print(100*confusion_matrix.diag()/confusion_matrix.sum(1))

In [None]:
# Independent float copy of the counts; later cells turn it into row percentages.
cm2 = np.array(confusion_matrix.numpy())  # np.array copies by default, so cm2 does not share memory with the tensor
print(cm2)

In [None]:
print(confusion_matrix.sum(1))  # number of evaluated samples per true class (row sums)

In [None]:
# Convert counts to row percentages: cell (i, j) = % of true-class-i samples predicted as j.
# Computed from confusion_matrix instead of dividing cm2 in place, so re-running this cell
# does not normalise an already-normalised matrix. Rows with no samples become NaN (0/0).
row_totals = confusion_matrix.sum(1).numpy()
with np.errstate(divide='ignore', invalid='ignore'):
    cm2 = (100 * confusion_matrix.numpy()) / row_totals[:, np.newaxis]

In [None]:
# cm2 is normally already float32 at this point, so this just makes a float32 copy.
# Swap in the commented int32 line to show truncated whole percentages instead.
cm2 = cm2.astype('float32') #float32, int32
#cm2 = cm2.astype('int32') #float32, int32
print(cm2)

In [None]:
# Heatmap of row-normalised confusion matrix (percent of each true class), saved to results/.
cnum = "8" if EQUIVARIANT else "1"

# NOTE(review): tick labels come from `dirs` (directory-listing order); confirm this order
# matches the label indices used for the confusion-matrix rows/columns.
dfcm2 = pd.DataFrame(cm2, index=dirs, columns=dirs)
fig, ax = plt.subplots(figsize=(18, 18))
# fmt='.0f': annotate cells with whole percentages.
sns.heatmap(dfcm2, annot=True, cmap="Blues", fmt='.0f', ax=ax)
ax.set(xlabel="Predicted label", ylabel="True label",
       title=f"C{cnum} confusion matrix ({dataname}) - % of true class")
# savefig fails if the target directory does not exist yet.
os.makedirs("results", exist_ok=True)
fig.savefig("results/c"+cnum+"_confusion_matrix_"+dataname+"_accuracies.png",
            dpi=100,
            bbox_inches='tight')