In [None]:
import sys
sys.path.append("..")
from inflation import BBI

import torch
import torch.nn.functional as F
import torch.nn as nn

import torchvision
import torchvision.transforms as transforms

batch_size = 50

transform = torchvision.transforms.ToTensor()

trainset = torchvision.datasets.MNIST(
             root='./mnist/',
             train=True,
             transform=transform,
             download= True
             )

trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          shuffle=True, num_workers=2)

testset = torchvision.datasets.MNIST(
             root='./mnist/',
             train=False,
             transform=transform,
             download= True
             )

testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                          shuffle=False, num_workers=2)

In [2]:
# The convolutional network

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 16, 5, 1, 2)
        self.pool = nn.MaxPool2d(2)
        self.conv2 = nn.Conv2d(16, 32, 5, 1, 2)
        self.fc1 = nn.Linear(32*7*7, 120)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        return x


criterion = nn.CrossEntropyLoss()

In [3]:
#First a small grid scan with small number of epochs
lrs = [.001, .01, 0.05,.1,.2,.3]
energies = [.0,.1,.5, 1.0, 2.]

scan_result = []
n_epochs = 3



v0 = 1e-6
threshold0 = 100
n_fixed_bounces = 5
threshold = 1000

for LR in lrs:
    for deltaEn in energies:
        print("lr: ", LR, "\tdeltaEn: ", deltaEn)
        
        net = Net()
        optimizer = BBI(net.parameters(), lr=LR, deltaEn = deltaEn, v0 = v0, threshold0 = threshold0, threshold = threshold, n_fixed_bounces = n_fixed_bounces)

        for epoch in range(n_epochs):  

            for i, data in enumerate(trainloader, 0):  
                
                inputs, labels = data
                optimizer.zero_grad()
                outputs = net(inputs)
                loss = criterion(outputs, labels)  
                loss.backward()
                def closure():
                    return loss  
                optimizer.step(closure)
                
        #Then evaluate the performance on the test set
        correct = 0
        total = 0
        with torch.no_grad():
            for data in testloader:
                images, labels = data
                outputs = net(images)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        
        print('\tAccuracy of the network on all the test images: %f' % (
            100 * correct / total))
        scan_result.append([LR, deltaEn,correct / total ])

lr:  0.001 	deltaEn:  0.0
	Accuracy of the network on all the test images: 76.780000
lr:  0.001 	deltaEn:  0.1
	Accuracy of the network on all the test images: 77.670000
lr:  0.001 	deltaEn:  0.5
	Accuracy of the network on all the test images: 76.460000
lr:  0.001 	deltaEn:  1.0
	Accuracy of the network on all the test images: 78.250000
lr:  0.001 	deltaEn:  2.0
	Accuracy of the network on all the test images: 75.090000
lr:  0.01 	deltaEn:  0.0
	Accuracy of the network on all the test images: 97.560000
lr:  0.01 	deltaEn:  0.1
	Accuracy of the network on all the test images: 97.650000
lr:  0.01 	deltaEn:  0.5
	Accuracy of the network on all the test images: 97.800000
lr:  0.01 	deltaEn:  1.0
	Accuracy of the network on all the test images: 97.400000
lr:  0.01 	deltaEn:  2.0
	Accuracy of the network on all the test images: 97.510000
lr:  0.05 	deltaEn:  0.0
	Accuracy of the network on all the test images: 98.850000
lr:  0.05 	deltaEn:  0.1
	Accuracy of the network on all the test image

In [4]:
#Then pick the best performer

i_max = 0
acc_max = scan_result[i_max][-1]
for i in range(len(scan_result)):
    acc = scan_result[i][-1]
    if acc > acc_max:
        acc_max = acc
        i_max = i
print(scan_result[i_max])

[0.2, 1.0, 0.9906]


In [None]:
#Not the runs in the paper (less statistics)

In [5]:
#Then a longer run to check the final accuracy
LR = scan_result[i_max][0]
deltaEn = scan_result[i_max][1]

n_epochs = 50
check_result = []


n_checks = 5
print("lr: ", LR, "\tdeltaEn: ", deltaEn, "\n")

for check in range(n_checks):
    print("Run: ", check )

    net = Net()
    optimizer = BBI(net.parameters(), lr=LR, deltaEn = deltaEn, v0 = v0, threshold0 = threshold0, threshold = threshold, n_fixed_bounces = n_fixed_bounces)

    for epoch in range(n_epochs):  
        tests = []
        for i, data in enumerate(trainloader, 0):  
            inputs, labels = data
            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, labels)  
            loss.backward()
            def closure():
                    return loss
            optimizer.step(closure)

        #Then evaluate the performance on the test set, at each epoch
        correct = 0
        total = 0
        with torch.no_grad():
            for data in testloader:
                images, labels = data
                outputs = net(images)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        if epoch%10 == 0: print('\tEpoch %d\t Accuracy: %f' % (epoch, 100 * correct / total))
        tests.append(correct/total)
    print('\tEpoch %d\t Accuracy: %f' % (epoch, 100 * correct / total))
    check_result.append(tests)

lr:  0.2 	deltaEn:  1.0 

Run:  0
	Epoch 0	 Accuracy: 98.350000
	Epoch 10	 Accuracy: 99.030000
	Epoch 20	 Accuracy: 99.200000
	Epoch 30	 Accuracy: 99.190000
	Epoch 40	 Accuracy: 99.230000
	Epoch 49	 Accuracy: 99.200000
Run:  1
	Epoch 0	 Accuracy: 96.960000
	Epoch 10	 Accuracy: 99.260000
	Epoch 20	 Accuracy: 99.160000
	Epoch 30	 Accuracy: 99.150000
	Epoch 49	 Accuracy: 99.140000
Run:  2
	Epoch 0	 Accuracy: 97.830000
	Epoch 10	 Accuracy: 99.080000
	Epoch 20	 Accuracy: 99.120000
	Epoch 30	 Accuracy: 99.110000
	Epoch 40	 Accuracy: 99.100000
	Epoch 49	 Accuracy: 99.120000
Run:  3
	Epoch 0	 Accuracy: 98.110000
	Epoch 10	 Accuracy: 99.240000
	Epoch 20	 Accuracy: 99.210000
	Epoch 30	 Accuracy: 99.140000
	Epoch 40	 Accuracy: 99.160000
	Epoch 49	 Accuracy: 99.180000
Run:  4
	Epoch 0	 Accuracy: 97.720000
	Epoch 10	 Accuracy: 99.070000
	Epoch 20	 Accuracy: 99.150000
	Epoch 30	 Accuracy: 99.100000
	Epoch 40	 Accuracy: 99.140000
	Epoch 49	 Accuracy: 99.190000


In [6]:
res_tensor = torch.tensor(check_result).flatten()
print(torch.mean(res_tensor))
print(torch.std(res_tensor))


tensor(0.9917)
tensor(0.0003)
