Testing a local regularization system where adjacent nodes in the last hidden layer are regularized to have 
similar associated weights

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import random_split
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import time
import os
from windows_inhibitor import WindowsInhibitor
import pickle
from groupy.gconv.pytorch_gconv.splitgconv2d import P4MConvZ2, P4MConvP4M, P4ConvZ2, P4ConvP4
from groupy.gconv.pytorch_gconv.pooling import plane_group_spatial_max_pooling
import numpy as np
from torch.utils.data.sampler import SubsetRandomSampler
import matplotlib.pyplot as plt
import logging
logging.basicConfig(level=logging.INFO, filename='testing_sugg_soft.log', 
                    format='%(asctime)s:%(name)s:%(levelname)s:%(message)s')
logging.info('Starting New Test')

In [10]:
indices = list(range(60_000))
# np.random.shuffle(indices)
train_indices = indices[:10_000]
valid_indices = indices[50_000:]
logging.info(f'Train length: {len(train_indices)}')
logging.info(f'Valid length: {len(valid_indices)}')

In [11]:
AUGMENT = False
if AUGMENT:
    train_transform = transforms.Compose([
    # transforms.RandomHorizontalFlip(),
    # transforms.Pad(2),
    transforms.RandomAffine(15, translate=(0.1, 0.1), scale=(0.95, 1.05)),
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))])
    
    valid_transform = transforms.Compose([
        # transforms.Pad(2),
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))])
else:
    train_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))])
    valid_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))])
    
logging.info(f'train_transform: {repr(train_transform)}')
logging.info(f'valid_transform: {repr(valid_transform)}')
    
translate_transform = transforms.Compose([
    transforms.RandomAffine(0, translate=(0.2, 0.2), scale=(1, 1)),
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))])

rotate_transform = transforms.Compose([
    transforms.RandomAffine(30),
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))])

train_sampler = SubsetRandomSampler(train_indices)
valid_sampler = SubsetRandomSampler(valid_indices)

train_set = torchvision.datasets.MNIST(root='./data', train=True,
                                      download=True, transform=train_transform)
valid_set = torchvision.datasets.MNIST(root='./data', train=False,
                                      download=True, transform=valid_transform)
translate_set = torchvision.datasets.MNIST(root='./data', train=False,
                                      download=True, transform=translate_transform)
rotation_set = torchvision.datasets.MNIST(root='./data', train=False,
                                      download=True, transform=rotate_transform)
# trainset, validation_set = random_split(full_train_set,[50_000,10_000])
# trainloader = torch.utils.data.DataLoader(trainset, batch_size=20,
#                                           shuffle=True, num_workers=2)
# testset = torchvision.datasets.MNIST(root='./data', train=False,
#                                      download=True, transform=transforms.ToTensor())
classes = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal',
               'Shirt', 'Sneaker', 'Bag', 'Ankle boot']

In [12]:
USE_CUDA = True
# Regularization hyper-parameters
BASE = 1
LAMBDA = 30
DECAY = 0.1 ** 6
logging.info(f'(Regularization) BASE = {BASE}, LAMBDA = {LAMBDA}, DECAY = {DECAY}')
# Optimization hyper-parameters
LR = 0.01
LR_BASE = 0.95
MOMENTUM = 0.5
logging.info(f'LR = {LR}, LR_BASE = {LR_BASE}, MOMENTUM = {MOMENTUM}')
# testloader = torch.utils.data.DataLoader(testset, batch_size=1024,
#                                          shuffle=False, num_workers=2)

trainloader = torch.utils.data.DataLoader(train_set, batch_size=20, sampler=train_sampler,
                                          shuffle=False, num_workers=2, pin_memory=USE_CUDA)
validation_loader = torch.utils.data.DataLoader(valid_set, batch_size=256,
                                         shuffle=False, num_workers=2, pin_memory=False)
translate_loader = torch.utils.data.DataLoader(translate_set, batch_size=256,
                                         shuffle=False, num_workers=2, pin_memory=False)
rotation_loader = torch.utils.data.DataLoader(rotation_set, batch_size=256,
                                         shuffle=False, num_workers=2, pin_memory=False)
logging.info(f'Trainloader = {repr(trainloader)}')

In [13]:
# functions to show an image


def imshow(img):
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()


# get some random training images
# dataiter = iter(trainloader)
# images, labels = dataiter.next()
# 
# # show images
# imshow(torchvision.utils.make_grid(images))

In [14]:
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = P4ConvZ2(1, 16, kernel_size=3)
        self.conv2 = P4ConvP4(16, 32, kernel_size=3)
        # self.conv3 = P4MConvP4M(20, 20, kernel_size=5, padding=2, stride=2)
        self.conv4 = P4ConvP4(32, 64, kernel_size=3)
        self.conv5 = P4ConvP4(64, 128, kernel_size=3)
        # self.conv6 = P4MConvP4M(40, 40, kernel_size=5, padding=2, stride=2)
        self.fc1 = nn.Linear(4 * 4 * 128 * 4, 10)
        # self.prior_mean = nn.Parameter(torch.zeros([128, 10, 4, 1, 1], requires_grad=True))
        # self.fc2 = nn.Linear(512, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        # x = F.relu(self.conv3(x))
        x = plane_group_spatial_max_pooling(x, 2, 2)
        x = F.relu(self.conv4(x))
        x = F.relu(self.conv5(x))
        # x = F.relu(self.conv6(x))
        x = plane_group_spatial_max_pooling(x, 2, 2)
        # print(x.shape)
        x = x.view(x.size()[0], -1)
        x = self.fc1(x)
        # x = F.dropout(x, training=self.training, p=0.4)
        # x = self.fc2(x)
        return x

In [15]:
def run():
    net = Net()
    # from my_nets import DoubleEquivFat
    # net = DoubleEquivFat()
    train_error_list = []
    test_error_list = []
    true_epoch = 0
    total_time = 0
    torch.cuda.empty_cache()
    logging.info(repr(net))
    
    # device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # net.to(device)
    # trainloader = torch.utils.data.DataLoader(train_set, batch_size=20, sampler=train_sampler,
    #                                           shuffle=False, num_workers=2, pin_memory=USE_CUDA)
    criterion = nn.CrossEntropyLoss()
    equivariant_loss = nn.MSELoss()
    logging.info(f'equivariant loss function: {repr(equivariant_loss)}')
    # weight_penalty = nn.MSELoss()
    
    # device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    device = torch.device("cuda:0" if USE_CUDA else "cpu")
    net.to(device)
    learned_train_set = False
    vert_convolver = torch.tensor([[1], [-1]], dtype=torch.float32, requires_grad=False, device=device)
    horiz_convolver = torch.tensor([[1, -1]], dtype=torch.float32, requires_grad=False, device=device)
    vert_convolver, horiz_convolver = vert_convolver.view([1, 1, 1, 2, 1]), horiz_convolver.view([1, 1, 1, 1, 2])
    # net.gconv1.flat_indices = net.gconv1.flat_indices.cuda()
    for epoch in range(80):  # loop over the dataset multiple times
        start_time = time.time()
        # little_lambda = LAMBDA * BASE ** true_epoch
        optimizer = optim.SGD(net.parameters(), lr=LR * LR_BASE ** true_epoch, momentum=MOMENTUM, weight_decay=DECAY)
        little_lambda = (1 / (len(trainloader) * 2)) * LAMBDA * BASE ** true_epoch
        true_epoch += 1
        logging.info(f'Starting epoch {true_epoch}')
        running_loss = 0.0
        true_train_total = 0.0
        correct_train_total = 0.0
        correct_train = 0.0
        total_train = 0.0
        total_equi_loss = 0.0
        for i, data in enumerate(trainloader, 0):
            # get the inputs
            inputs, labels = data
            inputs, labels = inputs.to(device), labels.to(device)
            # zero the parameter gradients
            optimizer.zero_grad()
    
            # forward + backward + optimize
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            
            nonequiv_weights1 = net.fc1.weight
            nonequiv_weights1 = nonequiv_weights1.view(1, 1, 128 * 10 * 4, 4, 4)
            # prime_weights = net.prior_mean.expand(128, 10, 4, 4, 4)
            # if i % 500 == 499:
            #     print(nonequiv_weights1[0,0,:,:,:])
            #     print(prime_weights[0,0,:,:,:])
            if little_lambda:
                discrepency =  torch.sum(F.conv3d(nonequiv_weights1, vert_convolver) ** 2) +\
                    torch.sum(F.conv3d(nonequiv_weights1, horiz_convolver) ** 2)
                loss += little_lambda * discrepency
                total_equi_loss += little_lambda * discrepency.detach()
                # if true_epoch > 3:
                #     print(little_lambda * equivariant_loss(nonequiv_weights1, prime_weights))
            # if DECAY:
            #     loss += DECAY * weight_penalty(net.fc2.weight,prime_weights.new_zeros([1],requires_grad=False))  

            _, predicted = torch.max(outputs.data, 1)
            total_train += labels.size(0)
            true_train_total += labels.size(0)
            correct_train += (predicted == labels).sum().item()
            correct_train_total += (predicted == labels).sum().item()

            loss.backward()
            optimizer.step()

            # print statistics
            running_loss += loss.item()
            if i % 500 == 499:  # print every 500 mini-batches
                msg = f'[{true_epoch}, {i + 1}] loss: {running_loss / total_train:.4f}, ' \
                      f'e-loss: {total_equi_loss / total_train:.4f}, ' \
                      f'Correct Rate: {100 * correct_train / total_train:.2f}%'
                print(msg)
                
                # if running_loss / 100  < 0.005:
                #     break
                running_loss = 0.0
                correct_train = 0.0
                total_train = 0.0
                total_equi_loss = 0.0
        train_error_list.append(100*correct_train_total / true_train_total)
        print(f"Correctness on training epoch: {100*correct_train_total / true_train_total:.2f}%")
        logging.info(f"Correctness on training epoch: {100*correct_train_total / true_train_total:.2f}%")
        correct = 0
        total = 0
        with torch.no_grad():
            for val_data in validation_loader:
                images, labels = val_data
                images, labels = images.to(device), labels.to(device)
                outputs = net(images)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
            test_error_list.append(100 * correct / total)
            print(f'Accuracy of the network on the 10000 validation images: {100 * correct / total:.2f}%')
            logging.info(f'Accuracy of the network on the 10000 validation images: {100 * correct / total:.2f}%')
            for trans_data in translate_loader:
                images, labels = trans_data
                images, labels = images.to(device), labels.to(device)
                outputs = net(images)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
            test_error_list.append(100 * correct / total)
            print(f'Accuracy of the network on the 10000 translated images: {100 * correct / total:.2f}%')
            logging.info(f'Accuracy of the network on the 10000 translated images: {100 * correct / total:.2f}%')
            for rotate_data in rotation_loader:
                images, labels = rotate_data
                images, labels = images.to(device), labels.to(device)
                outputs = net(images)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
            test_error_list.append(100 * correct / total)
            print(f'Accuracy of the network on the 10000 rotated images: {100 * correct / total:.2f}%')
            logging.info(f'Accuracy of the network on the 10000 rotated images: {100 * correct / total:.2f}%')
        total_time += time.time() - start_time
        print(f'Finished epoch, cumulative time: {total_time}s')
        if correct_train_total == true_train_total:
            if not learned_train_set:
                learned_train_set = True
            else:
                break
    print("Finished")
    logging.info("Finished")
    

In [16]:
with WindowsInhibitor():
    logging.info("-----------------------Starting Testing--------------------------------")
    for i in range(6):
        logging.info(f'-------------------------Starting test {i+1}----------------------------')
        run()
    logging.info(f'-------------------------Testing Finished----------------------------')

Preventing Windows from going to sleep
[1, 500] loss: 0.0218, e-loss: 0.0072, Correct Rate: 90.79%
Correctness on training epoch: 90.79%
Accuracy of the network on the 10000 validation images: 97.05%
Accuracy of the network on the 10000 translated images: 72.23%
Accuracy of the network on the 10000 rotated images: 78.22%
Finished epoch, cumulative time: 28.3751163482666s
[2, 500] loss: 0.0088, e-loss: 0.0037, Correct Rate: 97.03%
Correctness on training epoch: 97.03%
Accuracy of the network on the 10000 validation images: 97.72%
Accuracy of the network on the 10000 translated images: 76.27%
Accuracy of the network on the 10000 rotated images: 81.38%
Finished epoch, cumulative time: 56.91287541389465s
[3, 500] loss: 0.0066, e-loss: 0.0030, Correct Rate: 97.88%
Correctness on training epoch: 97.88%
Accuracy of the network on the 10000 validation images: 96.83%
Accuracy of the network on the 10000 translated images: 77.58%
Accuracy of the network on the 10000 rotated images: 81.84%
Finish