In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import random_split
from torch.utils.data.sampler import SubsetRandomSampler
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import time
from windows_inhibitor import WindowsInhibitor
from groupy.gconv.pytorch_gconv.splitgconv2d import P4MConvZ2, P4MConvP4M, P4ConvZ2, P4ConvP4
from groupy.gconv.pytorch_gconv.pooling import plane_group_spatial_max_pooling
import logging
logging.basicConfig(level=logging.INFO, filename='testing_sugg.log', 
                    format='%(asctime)s:%(name)s:%(levelname)s:%(message)s')
logging.info('Starting New Test')

In [2]:
indices = list(range(60_000))
train_indices = indices[:10_000]
valid_indices = indices[50_000:]
logging.info(f'Train length: {len(train_indices)}')
logging.info(f'Valid length: {len(valid_indices)}')

In [3]:
AUGMENT = False
if AUGMENT:
    train_transform = transforms.Compose([
    transforms.RandomAffine(15, translate=(0.1, 0.1), scale=(0.95, 1.05)),
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))])
    
    valid_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))])
else:
    train_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))])
    valid_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))])
    
logging.info(f'train_transform: {repr(train_transform)}')
logging.info(f'valid_transform: {repr(valid_transform)}')
    
translate_transform = transforms.Compose([
    transforms.RandomAffine(0, translate=(0.2, 0.2), scale=(1, 1)),
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))])

rotate_transform = transforms.Compose([
    transforms.RandomAffine(30),
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))])

train_sampler = SubsetRandomSampler(train_indices)
valid_sampler = SubsetRandomSampler(valid_indices)

train_set = torchvision.datasets.MNIST(root='./data', train=True,
                                      download=True, transform=train_transform)
valid_set = torchvision.datasets.MNIST(root='./data', train=False,
                                      download=True, transform=valid_transform)
translate_set = torchvision.datasets.MNIST(root='./data', train=False,
                                      download=True, transform=translate_transform)
rotation_set = torchvision.datasets.MNIST(root='./data', train=False,
                                      download=True, transform=rotate_transform)

classes = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal',
               'Shirt', 'Sneaker', 'Bag', 'Ankle boot']

In [4]:
USE_CUDA = True
BASE = 1
LAMBDA = 3_000_000
DECAY = 0.1**6
logging.info(f'BASE = {BASE}, LAMBDA = {LAMBDA}, DECAY = {DECAY}')
LR = 0.01
LR_BASE = 0.95
MOMENTUM = 0.5
logging.info(f'LR = {LR}, LR_BASE = {LR_BASE}, MOMENTUM = {MOMENTUM}')

trainloader = torch.utils.data.DataLoader(train_set, batch_size=20, sampler=train_sampler,
                                          shuffle=False, num_workers=2, pin_memory=USE_CUDA)
validation_loader = torch.utils.data.DataLoader(valid_set, batch_size=256,
                                         shuffle=False, num_workers=2, pin_memory=False)
translate_loader = torch.utils.data.DataLoader(translate_set, batch_size=256,
                                         shuffle=False, num_workers=2, pin_memory=False)
rotation_loader = torch.utils.data.DataLoader(rotation_set, batch_size=256,
                                         shuffle=False, num_workers=2, pin_memory=False)


In [6]:
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = P4ConvZ2(1, 16, kernel_size=3)
        self.conv2 = P4ConvP4(16, 32, kernel_size=3)
        self.conv4 = P4ConvP4(32, 64, kernel_size=3)
        self.conv5 = P4ConvP4(64, 128, kernel_size=3)
        self.fc1 = nn.Linear(4 * 4 * 128 * 4, 10)
        self.prior_mean = nn.Parameter(torch.zeros([128, 10, 1, 1, 1], requires_grad=True))

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = plane_group_spatial_max_pooling(x, 2, 2)
        x = F.relu(self.conv4(x))
        x = F.relu(self.conv5(x))
        x = plane_group_spatial_max_pooling(x, 2, 2)
        x = x.view(x.size()[0], -1)
        x = self.fc1(x)
        return x

In [7]:
def run():
    net = Net()
    train_error_list = []
    test_error_list = []
    true_epoch = 0
    total_time = 0
    torch.cuda.empty_cache()
    logging.info(repr(net))
    
    criterion = nn.CrossEntropyLoss()
    equivariant_loss = nn.MSELoss()
    logging.info(f'equivariant loss function: {repr(equivariant_loss)}')

    device = torch.device("cuda:0" if USE_CUDA else "cpu")
    net.to(device)
    learned_train_set = False
    for epoch in range(80):  
        # loop over the dataset multiple times
        start_time = time.time()
        optimizer = optim.SGD(net.parameters(), lr=LR * LR_BASE ** true_epoch, momentum=MOMENTUM, weight_decay=DECAY)
        little_lambda = (1/len(trainloader)) * LAMBDA * BASE ** true_epoch
        true_epoch += 1
        logging.info(f'Starting epoch {true_epoch}')
        running_loss = 0.0
        true_train_total = 0.0
        correct_train_total = 0.0
        correct_train = 0.0
        total_train = 0.0
        total_equi_loss = 0.0
        for i, data in enumerate(trainloader, 0):
            # get the inputs
            inputs, labels = data
            inputs, labels = inputs.to(device), labels.to(device)
            # zero the parameter gradients
            optimizer.zero_grad()
    
            # forward + backward + optimize
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            
            nonequiv_weights1 = net.fc1.weight
            nonequiv_weights1 = nonequiv_weights1.view(128, 10, 4, 4, 4)
            prime_weights = net.prior_mean.expand(128, 10, 4, 4, 4)
            if little_lambda:
                total_equi_loss += equivariant_loss(nonequiv_weights1, prime_weights).detach()*little_lambda
                loss += little_lambda * equivariant_loss(nonequiv_weights1, prime_weights)

            _, predicted = torch.max(outputs.data, 1)
            total_train += labels.size(0)
            true_train_total += labels.size(0)
            correct_train += (predicted == labels).sum().item()
            correct_train_total += (predicted == labels).sum().item()

            loss.backward()
            optimizer.step()

            # print statistics
            running_loss += loss.item()
            if i % 500 == 499:  
                # print every 500 mini-batches
                msg = f'[{true_epoch}, {i + 1}] loss: {running_loss / total_train:.4f}, ' \
                      f'e-loss: {total_equi_loss / total_train:.4f}, ' \
                      f'Correct Rate: {100 * correct_train / total_train:.2f}%'
                print(msg)
                
                running_loss = 0.0
                correct_train = 0.0
                total_train = 0.0
                total_equi_loss = 0.0
        train_error_list.append(100*correct_train_total / true_train_total)
        print(f"Correctness on training epoch: {100*correct_train_total / true_train_total:.2f}%")
        logging.info(f"Correctness on training epoch: {100*correct_train_total / true_train_total:.2f}%")
        correct = 0
        total = 0
        with torch.no_grad():
            for val_data in validation_loader:
                images, labels = val_data
                images, labels = images.to(device), labels.to(device)
                outputs = net(images)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
            test_error_list.append(100 * correct / total)
            print(f'Accuracy of the network on the 10000 validation images: {100 * correct / total:.2f}%')
            logging.info(f'Accuracy of the network on the 10000 validation images: {100 * correct / total:.2f}%')
            for trans_data in translate_loader:
                images, labels = trans_data
                images, labels = images.to(device), labels.to(device)
                outputs = net(images)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
            test_error_list.append(100 * correct / total)
            print(f'Accuracy of the network on the 10000 translated images: {100 * correct / total:.2f}%')
            logging.info(f'Accuracy of the network on the 10000 translated images: {100 * correct / total:.2f}%')
            for rotate_data in rotation_loader:
                images, labels = rotate_data
                images, labels = images.to(device), labels.to(device)
                outputs = net(images)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
            test_error_list.append(100 * correct / total)
            print(f'Accuracy of the network on the 10000 rotated images: {100 * correct / total:.2f}%')
            logging.info(f'Accuracy of the network on the 10000 rotated images: {100 * correct / total:.2f}%')
        total_time += time.time() - start_time
        print(f'Finished epoch, cumulative time: {total_time}s')
        if correct_train_total == true_train_total:
            if not learned_train_set:
                learned_train_set = True
            else:
                break
    print("Finished")
    logging.info("Finished")
    

In [8]:
with WindowsInhibitor():
    logging.info("-----------------------Starting Testing--------------------------------")
    for i in range(3,6):
        logging.info(f'-------------------------Starting test {i}----------------------------')
        run()
    logging.info(f'-------------------------Testing Finished----------------------------')

Preventing Windows from going to sleep
[1, 500] loss: 0.0235, e-loss: 0.0086, Correct Rate: 90.75%
Correctness on training epoch: 90.75%
Accuracy of the network on the 10000 validation images: 96.84%
Accuracy of the network on the 10000 translated images: 70.61%
Accuracy of the network on the 10000 rotated images: 76.98%
Finished epoch, cumulative time: 28.89766788482666s
[2, 500] loss: 0.0104, e-loss: 0.0050, Correct Rate: 96.74%
Correctness on training epoch: 96.74%
Accuracy of the network on the 10000 validation images: 96.57%
Accuracy of the network on the 10000 translated images: 71.61%
Accuracy of the network on the 10000 rotated images: 78.03%
Finished epoch, cumulative time: 56.465943574905396s
[3, 500] loss: 0.0090, e-loss: 0.0045, Correct Rate: 97.19%
Correctness on training epoch: 97.19%
Accuracy of the network on the 10000 validation images: 97.66%
Accuracy of the network on the 10000 translated images: 74.22%
Accuracy of the network on the 10000 rotated images: 79.78%
Fini