In [None]:
# import libraries
import os

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional
import torch.optim
import torch.linalg as linalg
import matplotlib.pyplot as plt

In [None]:
from models import *
from plotting import *

In [None]:
# Experiment configuration: seeds, batch sizes, epoch count and output folder.
random_seed = 1234
torch.manual_seed(random_seed)
np.random.seed(random_seed)  # numpy is imported as well; seed it for reproducibility
n_epochs = 3             # training epochs per network
batch_size_train = 40    # mini-batch size used for each SGD step
batch_size_test = 1      # evaluation batch size
num_classes = 10         # CIFAR-10 has 10 classes

results_folder = './cifar10_results'
# torch.save does not create missing directories, so make sure it exists now.
os.makedirs(results_folder, exist_ok=True)

In [None]:
# Load CIFAR-10 (32x32 RGB images, 10 classes). ToTensor + Normalize(0.5, 0.5)
# maps each channel to [-1, 1]; the image is then flattened to a 3072-dim
# vector for the fully connected networks below.
import torchvision
import torchvision.transforms as transforms

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    transforms.Lambda(lambda x: torch.flatten(x)),
])

trainset = torchvision.datasets.CIFAR10(root='./data/', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size_train, shuffle=True)

testset = torchvision.datasets.CIFAR10(root='./data/', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size_test, shuffle=False)

# Human-readable names for the 10 label indices.
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

In [None]:
def test_train(net, trainloader, testloader, name="", n_epochs=4, lr=0.0001):

    loss_fn = nn.CrossEntropyLoss(reduction='mean')
    optim = torch.optim.SGD(net.parameters(), lr=lr)

    train_losses = []
    train_counter = []
    test_losses = []
    test_accuracy = []

    def train(epoch):
        net.train()
        for batch_idx, (data, target) in enumerate(trainloader):

            optim.zero_grad()

            for i in range(batch_size_train):
                d=data[i,:].float()
                t=nn.functional.one_hot(target[i,:],num_classes=num_classes).float()
                #optim.zero_grad()
                o = net(d)
                loss += loss_fn(o, t)

            loss.backward()
            optim.step()

            with torch.no_grad():
                    net.update_backwards()

            if batch_idx % 1000 == 0:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, batch_idx * len(data), len(trainloader.dataset),
                    100. * batch_idx / len(trainloader), loss.item()))
                train_losses.append(loss.item())
                train_counter.append(
                    (batch_idx*64) + ((epoch-1)*len(trainloader.dataset)))
    

    def test():
        net.eval()
        test_loss = 0
        correct = 0
        with torch.no_grad():
            for data, target_idx in testloader:
                data=torch.squeeze(data).float()
                target_idx = torch.squeeze(target_idx)
                target=nn.functional.one_hot(torch.squeeze(target_idx),num_classes=num_classes.float())
                output = net(data)
                test_loss += loss_fn(output, target).item()
                pred_idx = torch.argmax(output.data)
                correct += pred_idx.eq(target_idx.data).sum()
            test_loss /= len(testloader.dataset)
            test_losses.append(test_loss)
            test_accuracy.append(100. * correct / len(testloader.dataset))
            print('\nTest set: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
                    test_loss, correct, len(testloader.dataset),
                    100. * correct / len(testloader.dataset)))
            
    torch.save(net.state_dict(), './results/model_{0}.pth'.format(name))
    torch.save(optim.state_dict(), './results/optimizer_{0}.pth'.format(name))
            
    test()
    for epoch in range(1, n_epochs + 1):
        train(epoch)
        test()
    
    return {"train_losses": train_losses, "train_counter": train_counter, "test_losses" : test_losses, "test_accuracy" : test_accuracy}

In [None]:
# Baseline network trained with grad_type='backprop'.
bp_net = FullyConnected(n_hidden=3, input_size=(32*32*3), hidden_size=1024, output_size=num_classes, grad_type='backprop')
# Pass the configured epoch count explicitly; otherwise test_train's own default wins.
bp_data = test_train(bp_net, trainloader, testloader, name="backprop", n_epochs=n_epochs)
torch.save(bp_data, results_folder+'/bp_data.pth')

In [None]:
# Same architecture trained with grad_type='pseudo'.
ps_net = FullyConnected(n_hidden=3, input_size=(32*32*3), hidden_size=1024, output_size=num_classes, grad_type='pseudo')
# Pass the configured epoch count explicitly; otherwise test_train's own default wins.
ps_data = test_train(ps_net, trainloader, testloader, name="pseudo", n_epochs=n_epochs)
torch.save(ps_data, results_folder+'/ps_data.pth')

In [None]:
# Same architecture trained with grad_type='random'.
rd_net = FullyConnected(n_hidden=3, input_size=(32*32*3), hidden_size=1024, output_size=num_classes, grad_type='random')
# Pass the configured epoch count explicitly; otherwise test_train's own default wins.
rd_data = test_train(rd_net, trainloader, testloader, name="random", n_epochs=n_epochs)
torch.save(rd_data, results_folder+'/rd_data.pth')