In [1]:
import torch

from inflation import BBI, BBI_v0tuning

import torch.nn.functional as F
import torch.nn as nn

import torchvision
import torchvision.transforms as transforms

import warnings
# Silence all warnings for the rest of the notebook (keeps training output readable).
warnings.filterwarnings('ignore')

batch_size = 50

# Images are converted to float tensors; no normalisation is applied.
transform = torchvision.transforms.ToTensor()


def _mnist_split(is_train):
    """Return one MNIST split stored under ./mnist/, downloading it if missing."""
    return torchvision.datasets.MNIST(
        root='./mnist/',
        train=is_train,
        transform=transform,
        download=True,
    )


# Training data is reshuffled every epoch; test data keeps a fixed order.
trainset = _mnist_split(True)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=batch_size, shuffle=True, num_workers=2
)

testset = _mnist_split(False)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=batch_size, shuffle=False, num_workers=2
)



In [2]:
# The convolutional network

class Net(nn.Module):
    """Small two-convolution CNN for single-channel 28x28 images (MNIST).

    Layout: conv(5x5, 1->16) -> ReLU -> 2x2 max-pool
            -> conv(5x5, 16->32) -> ReLU -> 2x2 max-pool
            -> flatten -> linear.
    Both convolutions use stride 1 and padding 2, so only the pools shrink the
    image (28 -> 14 -> 7); that is why the linear layer takes 32*7*7 features.

    Args:
        num_classes: width of the output logit vector. Defaults to 120, the
            value this notebook was run with. MNIST has only 10 digit classes
            (labels 0-9), so logits 10..119 are never a target.
            NOTE(review): 10 is most likely what was intended; confirm before
            changing the default, since it changes every reported result.
    """

    def __init__(self, num_classes: int = 120):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 16, 5, 1, 2)
        self.pool = nn.MaxPool2d(2)
        self.conv2 = nn.Conv2d(16, 32, 5, 1, 2)
        self.fc1 = nn.Linear(32 * 7 * 7, num_classes)

    def forward(self, x):
        """Map a batch of 1x28x28 images to raw (unnormalised) class logits."""
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1)  # keep the batch dimension, flatten the rest
        x = self.fc1(x)
        return x


# Takes the raw logits from Net; CrossEntropyLoss applies log-softmax itself.
criterion = nn.CrossEntropyLoss()

In [3]:
# Hyperparameters shared by the BBI experiments below.
v0 = 1e-20            # DeltaV handed to the optimizer; tiny in the baseline run
threshold0 = 100
n_fixed_bounces = 5
threshold = 1000
LR = .2
deltaEn = .0
weight_decay = 0.0    # passed to BBI_v0tuning (the plain BBI run does not use it)

# Seed torch's global RNG so weight initialisation and data shuffling are
# reproducible under Restart Kernel -> Run All. Runs within one experiment
# still differ from each other because the RNG keeps advancing.
SEED = 0
torch.manual_seed(SEED)

In [4]:
# Baseline: plain BBI using the tiny v0 (DeltaV) from the config cell, a value
# the loss never actually reaches. Each of the n_checks runs trains a fresh Net
# for n_epochs and measures test accuracy after every epoch.

n_epochs = 20
n_checks = 5

check_result = []      # per run: [final-epoch test accuracy], as a fraction
accuracy_history = []  # per run: test accuracy after every epoch

for check in range(n_checks):
    print("Run: ", check)

    net = Net()
    optimizer = BBI(net.parameters(), lr=LR, deltaEn=deltaEn, v0=v0,
                    threshold0=threshold0, threshold=threshold,
                    n_fixed_bounces=n_fixed_bounces)

    # Reset once per run (not per epoch) so it holds the whole accuracy curve.
    tests = []
    for epoch in range(n_epochs):
        for inputs, labels in trainloader:
            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()

            # BBI.step takes a closure; the loss has already been computed and
            # backpropagated above, so the closure just hands it back.
            def closure():
                return loss
            optimizer.step(closure)

        # Evaluate on the full test set at the end of every epoch.
        correct = 0
        total = 0
        with torch.no_grad():
            for images, labels in testloader:
                predicted = net(images).argmax(dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        tests.append(correct / total)

        # `loss` is the last training mini-batch loss of this epoch.
        if epoch % 10 == 0 or epoch == n_epochs - 1:
            print('\tEpoch %d\t\t Accuracy: %f\t Loss: %.20f' % (epoch, 100 * correct / total, loss.item()))

    accuracy_history.append(tests)
    # One single-element list per run, so the summary statistics computed from
    # check_result describe final-epoch accuracy only.
    check_result.append(tests[-1:])

Run:  0
	Epoch 0		 Accuracy: 98.770000	 Loss: 0.16509181261062622070
	Epoch 10		 Accuracy: 99.150000	 Loss: 0.00000046729527980460
	Epoch 19		 Accuracy: 99.170000	 Loss: 0.00006044310066499747
Run:  1
	Epoch 0		 Accuracy: 97.920000	 Loss: 0.02141444385051727295
	Epoch 10		 Accuracy: 99.290000	 Loss: 0.00001236405296367593
	Epoch 19		 Accuracy: 99.230000	 Loss: 0.00002361184851906728
Run:  2
	Epoch 0		 Accuracy: 97.750000	 Loss: 0.32982391119003295898
	Epoch 10		 Accuracy: 99.230000	 Loss: 0.00001444772806280525
	Epoch 19		 Accuracy: 99.250000	 Loss: 0.00000003814693627646
Run:  3
	Epoch 0		 Accuracy: 98.060000	 Loss: 0.01596713997423648834
	Epoch 10		 Accuracy: 99.180000	 Loss: 0.00000384775921702385
	Epoch 19		 Accuracy: 99.230000	 Loss: 0.00000113007581603597
Run:  4
	Epoch 0		 Accuracy: 97.850000	 Loss: 0.00033564228215254843
	Epoch 10		 Accuracy: 99.370000	 Loss: 0.00010727781045716256
	Epoch 19		 Accuracy: 99.320000	 Loss: 0.00000021219146617568


In [5]:
# Mean and (Bessel-corrected) standard deviation of the final-epoch test
# accuracy across the baseline runs.
res_tensor = torch.tensor(check_result).flatten()
print(res_tensor.mean())
print(res_tensor.std())


tensor(0.9924)
tensor(0.0005)


In [6]:
# This shows that even when started from a much higher v0 (DeltaV),
# BBI_v0tuning with v0_tuning=True adjusts v0 on its own during training
# (the output below shows the "New v0" values it picks).
# NOTE: this rebinds the module-level v0 from the config cell; re-running the
# baseline cell after this one would use v0 = 1e-4 instead of 1e-20.
v0 = 0.0001
n_epochs = 20
n_checks = 5

check_result = []      # per run: [final-epoch test accuracy], as a fraction
accuracy_history = []  # per run: test accuracy after every epoch

for check in range(n_checks):
    print("Run: ", check)

    net = Net()
    optimizer = BBI_v0tuning(net.parameters(), lr=LR, deltaEn=deltaEn, v0=v0,
                             threshold0=threshold0, threshold=threshold,
                             n_fixed_bounces=n_fixed_bounces,
                             weight_decay=weight_decay, v0_tuning=True)

    # Reset once per run (not per epoch) so it holds the whole accuracy curve.
    tests = []
    for epoch in range(n_epochs):
        for inputs, labels in trainloader:
            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()

            # The optimizer's step takes a closure; the loss is already computed
            # and backpropagated above, so the closure just hands it back.
            def closure():
                return loss
            optimizer.step(closure)

        # Evaluate on the full test set at the end of every epoch.
        correct = 0
        total = 0
        with torch.no_grad():
            for images, labels in testloader:
                predicted = net(images).argmax(dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        tests.append(correct / total)

        # `loss` is the last training mini-batch loss of this epoch.
        if epoch % 10 == 0 or epoch == n_epochs - 1:
            print('\tEpoch %d\t\t Accuracy: %f\t Loss: %.20f' % (epoch, 100 * correct / total, loss.item()))

    accuracy_history.append(tests)
    # One single-element list per run, so the summary statistics computed from
    # check_result describe final-epoch accuracy only.
    check_result.append(tests[-1:])

Run:  0
	Epoch 0		 Accuracy: 98.420000	 Loss: 0.08118575811386108398
Shifting v0, remember this is still in development!
New v0:  -0.0002124840102624148
	Epoch 10		 Accuracy: 99.080000	 Loss: 0.00000737064692657441
	Epoch 19		 Accuracy: 99.260000	 Loss: 0.00000000000000000000
Run:  1
	Epoch 0		 Accuracy: 97.160000	 Loss: 0.30432710051536560059
Shifting v0, remember this is still in development!
New v0:  -4.9153204599861056e-05
	Epoch 10		 Accuracy: 99.320000	 Loss: 0.00009690123260952532
	Epoch 19		 Accuracy: 99.230000	 Loss: 0.00188360107131302357
Run:  2
Shifting v0, remember this is still in development!
New v0:  -3.2778050808701664e-05
	Epoch 0		 Accuracy: 97.910000	 Loss: 0.15555025637149810791
	Epoch 10		 Accuracy: 99.270000	 Loss: 0.01298281550407409668
	Epoch 19		 Accuracy: 99.260000	 Loss: 0.00000032424702567369
Run:  3
	Epoch 0		 Accuracy: 98.510000	 Loss: 0.00991395208984613419
Shifting v0, remember this is still in development!
New v0:  8.727991371415555e-05
Shifting v0, re

In [7]:
# Mean and (Bessel-corrected) standard deviation of the final-epoch test
# accuracy across the v0-tuning runs.
res_tensor = torch.tensor(check_result).flatten()
print(res_tensor.mean())
print(res_tensor.std())


tensor(0.9925)
tensor(0.0003)
