In [9]:
import torch
import torch.nn as nn
import torch.nn.functional as F
# Touch the CUDA runtime only when a GPU is actually present: current_device()
# raises on CPU-only machines, which would make the CPU fallback for `device`
# below unreachable.
if torch.cuda.is_available():
    torch.cuda.current_device()
import torchvision
import torchvision.transforms as transforms
from torch.utils.data.sampler import SubsetRandomSampler
import time
import os
import logging
import numpy as np
from my_layers import LatticeLocalSL2, GroupLocalSL2, GroupMaxPool, GroupReLU
# Route INFO-and-above log records to a file so per-epoch results persist
# beyond the notebook's printed output.
logging.basicConfig(
    filename='MNIST_test.log',
    level=logging.INFO,
    format='%(asctime)s:%(name)s:%(levelname)s:%(message)s',
)

# Use the first CUDA GPU when one is available, otherwise the CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')


In [10]:
# Split the 60,000 MNIST training images by index: the first 50,000 are used for
# training and the last 10,000 are held out for validation. The split is not
# shuffled, so every run validates on exactly the same images.
indices = list(range(60_000))
train_indices, valid_indices = indices[:50_000], indices[50_000:]


In [11]:
# Convert images to tensors and standardise them. The constants are presumably
# the MNIST training-set pixel mean/std (0.1307, 0.3081) — widely used values.
transformation = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

# Both samplers draw from the same 60k-image training split; validation uses
# the held-out indices rather than the official MNIST test split.
train_sampler = SubsetRandomSampler(train_indices)
valid_sampler = SubsetRandomSampler(valid_indices)

train_set = torchvision.datasets.MNIST(root='./data', train=True,
                                       download=True, transform=transformation)
# Same split and same transform as train_set, so reuse it instead of building
# (and loading) a second identical dataset. If a train-only transform such as
# augmentation is ever added, give valid_set its own instance again.
valid_set = train_set

classes = tuple(str(digit) for digit in range(10))
NUM_CLASSES = len(classes)

In [12]:
class DoubleMNIST(nn.Module):
    """Two-stage group-convolutional digit classifier built from ``my_layers``.

    Pipeline: LatticeLocalSL2 -> GroupReLU -> GroupMaxPool ->
    GroupLocalSL2 -> GroupReLU -> GroupMaxPool -> flatten -> fc1 -> ReLU -> fc2.

    Args:
        first_channels: output channels of the first (lattice) group convolution,
            which takes a single input channel.
        second_channels: output channels of the second group convolution.
        fc_channels: width of the hidden fully connected layer.
        num_classes: number of output logits (defaults to the 10 MNIST digits).
    """

    def __init__(self, first_channels=32, second_channels=64, fc_channels=128, num_classes=10):
        super().__init__()
        self.gconv1 = LatticeLocalSL2(1, first_channels, 5, 2, len_fun="len", group='SL2', pad_type='partial')
        self.gconv2 = GroupLocalSL2(first_channels, second_channels, 5, 2, len_fun="len")
        # NOTE(review): assumes the pooled gconv2 output flattens to
        # second_channels * 8 * 4 * 4 features per sample. The 8 and 4*4 factors
        # depend on my_layers and on the input image size — TODO confirm.
        self.fc1 = nn.Linear(second_channels * 8 * 4 * 4, fc_channels)
        self.fc2 = nn.Linear(fc_channels, num_classes)
        self.pool = GroupMaxPool(2, 2)
        self.grelu = GroupReLU()

    def forward(self, x):
        """Return raw (un-softmaxed) class logits of shape (batch, num_classes)."""
        # gconv1 returns a pair; its first element is handed on to gconv2.
        dict1, x = self.gconv1(x)
        x = self.pool(self.grelu(x))
        _, x = self.gconv2(x, dict1)
        x = self.pool(self.grelu(x))
        # flatten (unlike .view) also works if the pooled tensor is non-contiguous.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        return self.fc2(x)

In [13]:
import torch.optim as optim

# Loss on raw logits, averaged over the batch; shared by every training run below.
predict_criterion = nn.CrossEntropyLoss(reduction='mean')

In [14]:
    
def _run_train_epoch(net, loader, optimizer, epoch_num, log_interval, start_time):
    """Run one SGD pass over ``loader``; return (correct, seen) counts for the epoch.

    Every ``log_interval`` mini-batches, prints the mean loss and the accuracy over
    just those batches, plus the time elapsed since ``start_time``.
    """
    net.train()
    running_loss = 0.0
    window_correct = 0
    window_seen = 0
    epoch_correct = 0
    epoch_seen = 0
    for i, (inputs, labels) in enumerate(loader):
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()

        outputs = net(inputs)
        loss = predict_criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Accuracy comes from this batch's forward pass, i.e. before the weight update.
        n_correct = (outputs.argmax(dim=1) == labels).sum().item()
        batch_size = labels.size(0)
        window_correct += n_correct
        window_seen += batch_size
        epoch_correct += n_correct
        epoch_seen += batch_size
        running_loss += loss.item()

        if i % log_interval == log_interval - 1:
            # Divide by the actual number of batches accumulated since the last report.
            print(f'[{epoch_num}, {i + 1}] loss: {running_loss / log_interval:.4f}, '
                  f'Correct Rate: {100 * window_correct / window_seen:.2f}%, '
                  f'Cumulative time: {time.time() - start_time}s')
            running_loss = 0.0
            window_correct = 0
            window_seen = 0
    return epoch_correct, epoch_seen


def _evaluate_accuracy(net, loader):
    """Return (correct, seen) prediction counts of ``net`` over ``loader``, without autograd."""
    net.eval()
    correct = 0
    seen = 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            predicted = net(images).argmax(dim=1)
            seen += labels.size(0)
            correct += (predicted == labels).sum().item()
    return correct, seen


def train(net, epochs=30, lr=0.01, momentum=0.5, log_interval=500):
    """Train ``net`` with SGD on the training subset and validate it after every epoch.

    Uses the module-level ``train_set``/``valid_set`` with their samplers,
    ``predict_criterion`` and ``device``. Progress goes to stdout and to the log.

    Training stops early at the second epoch (not necessarily consecutive)
    in which every training sample was classified correctly.

    Args:
        net: model to train; assumed to already live on ``device``.
        epochs: maximum number of passes over the training subset.
        lr: SGD learning rate.
        momentum: SGD momentum.
        log_interval: print running statistics every this many mini-batches.

    Returns:
        (train_accuracies, valid_accuracies): per-epoch accuracy percentages.
    """
    # Pinned host memory only helps (and is only supported) when copying to a GPU.
    pin = device.type == 'cuda'
    trainloader = torch.utils.data.DataLoader(train_set, batch_size=20, sampler=train_sampler,
                                              num_workers=2, pin_memory=pin)
    testloader = torch.utils.data.DataLoader(valid_set, batch_size=128, sampler=valid_sampler,
                                             num_workers=2, pin_memory=pin)
    optimizer = optim.SGD(net.parameters(), lr=lr, momentum=momentum)

    train_accuracies = []
    valid_accuracies = []
    total_time = 0.0
    learned_train_set = False
    for epoch_num in range(1, epochs + 1):
        start_time = time.time()

        correct_train, seen_train = _run_train_epoch(
            net, trainloader, optimizer, epoch_num, log_interval, start_time)
        train_acc = 100 * correct_train / seen_train
        train_accuracies.append(train_acc)
        print(f"Correctness on training epoch {epoch_num}: {train_acc:.2f}%")
        logging.info(f"Correctness on training epoch {epoch_num}: {train_acc:.2f}%")

        # "test images" here are the held-out validation indices of the training split.
        correct, total = _evaluate_accuracy(net, testloader)
        valid_acc = 100 * correct / total
        valid_accuracies.append(valid_acc)
        print(f'Accuracy of the network on the {total} test images for epoch {epoch_num}: {valid_acc:.2f}%')
        logging.info(f'Accuracy of the network on the {total} test images for epoch {epoch_num}: {valid_acc:.2f}%')

        total_time += time.time() - start_time
        print(f'Finished epoch {epoch_num}, cumulative time: {total_time}s')

        if correct_train == seen_train:
            if learned_train_set:
                break
            learned_train_set = True
    print("Finished")
    logging.info("Finished Training")
    return train_accuracies, valid_accuracies

In [15]:
logging.info("------------------------Beginning Test---------------------")
# Six independent training runs, each starting from a freshly initialised model.
for i in range(6):
    run_label = f'Starting test {i + 1}'
    print(run_label)
    logging.info(run_label)

    # Module.to() moves registered parameters/buffers and returns the module itself.
    net = DoubleMNIST(9, 32, 128).to(device)
    # NOTE(review): flat_indices is moved by hand, presumably because it is a plain
    # tensor attribute rather than a registered buffer — confirm in my_layers.
    net.gconv1.flat_indices = net.gconv1.flat_indices.to(device)

    logging.info(repr(net))
    train(net)
logging.info("------------------------Ending Test---------------------")