In [1]:
# import libraries
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional
import torch.optim
import torch.linalg as linalg
import matplotlib.pyplot as plt

from models import *
from plotting import *
from globals import *

In [2]:
# Run configuration.
# Seed torch's global RNG first so model initialisation and DataLoader shuffling
# are reproducible across kernel restarts.
random_seed = 1234
torch.manual_seed(random_seed)

# NOTE(review): n_epochs is not read by the training calls in this notebook;
# they pass n_epochs explicitly or rely on test_train's default.
n_epochs = 50
batch_size_train, batch_size_test = 20, 40

In [3]:
# Use the first CUDA device when one is visible, otherwise fall back to the CPU.
# NOTE(review): test_train moves batches to DEVICE (presumably from the `globals`
# star import), not to this `device` variable — keep the two in sync.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"
print(device)

cuda:0


In [4]:
# Load MNIST (downloaded into ./data/ on first run).
# ToTensor scales pixels to [0, 1]; Normalize((0.5,), (0.5,)) then maps them to
# [-1, 1]. Training batches are reshuffled each epoch, test batches keep order.
import torchvision
import torchvision.transforms as transforms

mnist_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

trainset = torchvision.datasets.MNIST(root='./data/', train=True, download=True,
                                      transform=mnist_transforms)
testset = torchvision.datasets.MNIST(root='./data/', train=False, download=True,
                                     transform=mnist_transforms)

trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size_train, shuffle=True)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size_test, shuffle=False)

In [None]:
# Load CIFAR-10 (downloaded into ./data/ on first run); not executed in the saved run.
# NOTE(review): running this cell rebinds trainset/trainloader/testset/testloader,
# replacing the MNIST loaders from the previous cell. The MNIST models used below
# (e.g. ConvMNIST) presumably expect single-channel MNIST images, so skip this cell
# when training on MNIST. `results_folder` is not passed to any training call here.
import torchvision
import torchvision.transforms as transforms

# Per-channel mean/std of 0.5 maps ToTensor's [0, 1] pixels to [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

trainset = torchvision.datasets.CIFAR10(root='./data/', train=True, download=True, transform=transform)
testset = torchvision.datasets.CIFAR10(root='./data/', train=False, download=True, transform=transform)

trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size_train, shuffle=True)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size_test, shuffle=False)

classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
results_folder = './cifar10_results'

In [5]:
def test_train(net, trainloader, testloader, name="", n_epochs=4, lr=0.001):

    loss_fn = nn.CrossEntropyLoss(reduction='mean')
    optim = torch.optim.SGD(net.parameters(), lr=lr)

    train_losses = []
    train_counter = []
    test_losses = []
    test_accuracy = []

    def train(epoch):
        net.train()
        for batch_idx, (data, target_idx) in enumerate(trainloader):

            optim.zero_grad()

            data = data.float().to(DEVICE)
            #data = torch.squeeze(data).float().to(DEVICE)
            target_idx = target_idx.to(DEVICE)
            target = nn.functional.one_hot(target_idx,num_classes=10).float()
            output = net(data)
            loss = loss_fn(output, target)

            with torch.no_grad():
                net.update_backwards()

            loss.backward()
            optim.step()

            if batch_idx % 1000 == 0:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, batch_idx * len(data), len(trainloader.dataset),
                    100. * batch_idx / len(trainloader), loss.item()))
                train_losses.append(loss.item())
                train_counter.append(
                    (batch_idx*64) + ((epoch-1)*len(trainloader.dataset)))
                #return None
                
    

    def test():
        net.eval()
        test_loss = 0
        correct = 0
        with torch.no_grad():
            for data, target_idx in testloader:
                #data = torch.squeeze(data).float().to(DEVICE)
                data=data.float().to(DEVICE)
                target_idx = target_idx.to(DEVICE)
                target=nn.functional.one_hot(target_idx,num_classes=10).float()
                output = net(data)
                test_loss += loss_fn(output, target).item()
                pred_idx = torch.argmax(output.data, dim=-1)
                correct += pred_idx.eq(target_idx.data).sum().item()
            test_loss /= len(testloader.dataset)
            test_losses.append(test_loss)
            test_accuracy.append(100. * correct / len(testloader.dataset))
            print('\nTest set: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
                    test_loss, correct, len(testloader.dataset),
                    100. * correct / len(testloader.dataset)))
            
    torch.save(net.state_dict(), './results/model_{0}.pth'.format(name))
    torch.save(optim.state_dict(), './results/optimizer_{0}.pth'.format(name))
            
    #test()
    for epoch in range(1, n_epochs + 1):
        train(epoch)
        test()
    
    return {"train_losses": train_losses, "train_counter": train_counter, "test_losses" : test_losses, "test_accuracy" : test_accuracy}

In [6]:
# Fully connected model (presumably defined in `models`) built with grad_type='pseudo'.
# NOTE(review): the `bp_` prefix suggests back-propagation, but this instance uses
# grad_type='pseudo'; `bp_net` is also not referenced by any later cell here.
bp_net = FullyConnected(grad_type='pseudo')

In [6]:
# Convolutional MNIST model (presumably defined in `models`) using the "pseudo" gradient mode.
convnet = ConvMNIST(grad_type="pseudo")

In [7]:
# Train the pseudo-gradient ConvMNIST for 4 epochs at lr=1e-3.
# NOTE(review): checkpoints are saved under the name "convbp" even though the model
# was built with grad_type="pseudo" — rename if that is unintended.
ps_data = test_train(
    net=convnet,
    trainloader=trainloader,
    testloader=testloader,
    name="convbp",
    n_epochs=4,
    lr=0.001,
)


Test set: Avg. loss: 0.0114, Accuracy: 8718/10000 (87%)



KeyboardInterrupt: 

In [None]:
# Per-epoch test loss of the back-propagation run (one point per test() call).
# NOTE(review): `bp_data` is never assigned in this notebook (the training cell
# stores its history in `ps_data`), so this raises NameError on a fresh kernel
# unless `bp_data` is produced elsewhere.
plt.plot(bp_data['test_losses'])
plt.xlabel('epoch index')
plt.ylabel('test loss')
plt.title('Test loss per epoch')
plt.show()

In [None]:


# Numerically check the gradients of a freshly built ConvBP on a random input shaped
# like one training batch: (batch_size_train, *per-sample shape).
# The shape comes from the dataset rather than from `x`, which is only assigned in a
# later cell, so this cell no longer depends on out-of-order execution.
# NOTE(review): gradcheck expects float64 inputs *and* parameters on the same device;
# z is float64 on DEVICE, while ConvBP() is built with defaults — confirm they match.
# NOTE(review): nondet_tol=1.0 makes gradcheck's determinism check very lenient.
z = torch.randn((batch_size_train, *trainset[0][0].shape), dtype=torch.double,
                device=DEVICE, requires_grad=True)

net = ConvBP()

torch.autograd.gradcheck(net, z, eps=1e-4, atol=1e-4, nondet_tol=1.0, raise_exception=True)

In [None]:
# Sanity check: compare the conv net's predicted classes on one training batch with
# the true labels.
# NOTE(review): this originally called `bp_convnet`, which is never defined in this
# notebook; `convnet` is the conv model trained above — switch back if a separate
# back-propagation conv net is intended.
x, t = next(iter(trainloader))
x = x.float().to(DEVICE)
with torch.no_grad():  # inference only; don't build an autograd graph
    y = convnet(x)
print(torch.argmax(y, dim=-1))
print(t)

In [None]:
# Train `ps_net` with test_train's default epochs and learning rate, then persist the
# returned metrics history.
# NOTE(review): `ps_net` is not defined anywhere in this notebook, and this rebinds
# `ps_data`, discarding the history from the ConvMNIST run above.
ps_data = test_train(
    ps_net,
    trainloader,
    testloader,
    name="pseudo",
)
torch.save(ps_data, "./results/ps_data.pth")

In [None]:
# Compare per-epoch test loss of the back-propagation and pseudo-gradient runs.
# NOTE(review): `bp_data` is not assigned anywhere in this notebook.
fig, ax = plt.subplots()
ax.plot(bp_data['test_losses'], label='bp')
ax.plot(ps_data['test_losses'], label='pseudo')
ax.set_yscale('log')
ax.set(xlabel='epoch index', ylabel='test loss (log scale)', title='Test loss per epoch')
ax.legend()
plt.show()

In [None]:
# Compare per-epoch test accuracy of the back-propagation and pseudo-gradient runs.
# Accuracy is a bounded percentage, so a linear y-axis is used; the log scale
# copied from the loss plot compressed the high-accuracy range that matters.
# NOTE(review): `bp_data` is not assigned anywhere in this notebook.
fig, ax = plt.subplots()
ax.plot(bp_data['test_accuracy'], label='bp')
ax.plot(ps_data['test_accuracy'], label='pseudo')
ax.set(xlabel='epoch index', ylabel='test accuracy (%)', title='Test accuracy per epoch')
ax.legend()
plt.show()

In [None]:
# Run `net` on the first test example, a (transformed image, label) pair.
# NOTE(review): `net` is whichever model the gradcheck cell last built (ConvBP); the
# image is passed without a batch dimension — confirm the model accepts that.
# NOTE(review): `input` shadows the Python builtin; later cells read it by this name.
input, target = testset[0]

output = net(input)

In [None]:
# Convert the model's scores for the single example into class probabilities.
# dim=0 treats `output` as an unbatched score vector — TODO confirm its shape.
softmax = nn.Softmax(dim=0)
out_softmax = softmax(output)

In [None]:
# Round each class probability to the nearest integer (0 or 1) for a quick look at the prediction.
out_softmax.round()

In [None]:
# Visualise the test example with `plot_mnist` (presumably defined in `plotting`).
plot_mnist(input)

In [None]:
from plotting import *

In [None]:
# Scratch tensors for a broadcasting experiment: `a` is random with size 4 along
# dim 1, and `b` holds one scale factor per entry of that dimension.
a = torch.randn(10, 4, 3, 3)
b = torch.tensor([2., 2., 2., 2.])

In [None]:
# Give `b` one leading and two trailing singleton dims (shape (1, 4, 1, 1) for the
# 1-D `b` above) so it scales `a` slice-by-slice along dim 1.
c = a * b.unsqueeze(0).unsqueeze(-1).unsqueeze(-1)

In [None]:
# True if any element of `c` is NaN.
torch.isnan(c).any()