In [1]:
import sys
sys.path.append("..")
from inflation import BBI

import torch
import torch.nn.functional as F
import torch.nn as nn

import torchvision
import torchvision.transforms as transforms

# Mini-batch size shared by the training and the test loader.
batch_size = 50

# Only conversion to tensors is applied to the images.
transform = transforms.ToTensor()

# Both splits live under the same root directory and are downloaded on demand.
_mnist_kwargs = dict(root='./mnist/', transform=transform, download=True)
trainset = torchvision.datasets.MNIST(train=True, **_mnist_kwargs)
testset = torchvision.datasets.MNIST(train=False, **_mnist_kwargs)

# Training batches are reshuffled; test batches keep the dataset order.
_loader_kwargs = dict(batch_size=batch_size, num_workers=2)
trainloader = torch.utils.data.DataLoader(trainset, shuffle=True, **_loader_kwargs)
testloader = torch.utils.data.DataLoader(testset, shuffle=False, **_loader_kwargs)

In [2]:
# The convolutional network

class Net(nn.Module):
    """Small two-stage convolutional classifier for 1-channel 28x28 images.

    Each stage is a 5x5 convolution (stride 1, padding 2, so the spatial size
    is preserved) followed by ReLU and a 2x2 max-pool. After two stages a
    28x28 input has become a 32x7x7 feature map, which is flattened and mapped
    to the output scores by a single linear layer.

    Args:
        num_classes: number of output scores produced per image. The default
            of 120 reproduces the original hard-coded layer width, so
            ``Net()`` behaves exactly as before.
            NOTE(review): MNIST labels only span 10 classes, so 10 outputs
            would presumably suffice -- confirm before changing the default,
            since doing so alters the trained model.
    """

    def __init__(self, num_classes=120):
        super().__init__()
        # Layers are created in the same order as before so that, for a given
        # RNG seed, the default-constructed network gets identical weights.
        self.conv1 = nn.Conv2d(1, 16, 5, 1, 2)
        self.pool = nn.MaxPool2d(2)
        self.conv2 = nn.Conv2d(16, 32, 5, 1, 2)
        self.fc1 = nn.Linear(32*7*7, num_classes)

    def forward(self, x):
        """Return the raw (un-normalised) output scores for a batch ``x``.

        ``x`` must be shaped (batch, 1, 28, 28) for the flattened features to
        match the 32*7*7 inputs expected by ``fc1``.
        """
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # Keep the batch dimension, flatten channels and spatial dimensions.
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        return x


# Loss shared by all training loops in this notebook; it is called as
# criterion(outputs, labels) on the network outputs and the dataset labels.
criterion = nn.CrossEntropyLoss()

In [3]:
#First a grid scan with small number of epochs
# Every (learning rate, momentum) pair trains a fresh Net with SGD for
# n_epochs and is then scored on the test set. Each outcome is stored in
# scan_result as [lr, momentum, accuracy] with accuracy in [0, 1].
# NOTE(review): the scan is scored on the same test set that is later used to
# report the final accuracy; a held-out validation split would avoid
# selection bias.
lrs = [0.001, 0.01, 0.05, 0.1, 0.2, 0.3]
momenta = [0.85, 0.9, 0.95, 0.99, 0.999]

scan_result = []
n_epochs = 3

for LR in lrs:
    for MOMENTUM in momenta:
        print("lr: ", LR, "\tmomentum: ", MOMENTUM)

        net = Net()
        optimizer = torch.optim.SGD(net.parameters(), lr=LR, momentum=MOMENTUM)

        # Training: one optimizer step per mini-batch.
        for epoch in range(n_epochs):
            for inputs, labels in trainloader:
                optimizer.zero_grad()
                outputs = net(inputs)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

        # Evaluation: count how often the top-scoring output equals the label.
        n_correct, n_seen = 0, 0
        with torch.no_grad():
            for images, labels in testloader:
                predicted = net(images).data.max(1)[1]
                n_seen += labels.size(0)
                n_correct += (predicted == labels).sum().item()

        print('\tAccuracy of the network on all the test images: %f ' % (
            100 * n_correct / n_seen))
        scan_result.append([LR, MOMENTUM, n_correct / n_seen])

lr:  0.001 	momentum:  0.85
	Accuracy of the network on all the test images: 95.930000 
lr:  0.001 	momentum:  0.9
	Accuracy of the network on all the test images: 96.830000 
lr:  0.001 	momentum:  0.95
	Accuracy of the network on all the test images: 97.980000 
lr:  0.001 	momentum:  0.99
	Accuracy of the network on all the test images: 98.700000 
lr:  0.001 	momentum:  0.999
	Accuracy of the network on all the test images: 94.300000 
lr:  0.01 	momentum:  0.85
	Accuracy of the network on all the test images: 98.650000 
lr:  0.01 	momentum:  0.9
	Accuracy of the network on all the test images: 98.750000 
lr:  0.01 	momentum:  0.95
	Accuracy of the network on all the test images: 98.980000 
lr:  0.01 	momentum:  0.99
	Accuracy of the network on all the test images: 94.700000 
lr:  0.01 	momentum:  0.999
	Accuracy of the network on all the test images: 10.280000 
lr:  0.05 	momentum:  0.85
	Accuracy of the network on all the test images: 98.740000 
lr:  0.05 	momentum:  0.9
	Accuracy of

In [4]:
#Pick the best performer
# Linear scan over scan_result for the first entry whose accuracy (the last
# column) is strictly the highest; its index is kept in i_max.
i_max, acc_max = 0, scan_result[0][-1]
for idx, entry in enumerate(scan_result):
    if entry[-1] > acc_max:
        i_max, acc_max = idx, entry[-1]
print(scan_result[i_max])

[0.01, 0.95, 0.9898]


In [None]:
#Note: these are not the runs used in the paper (they have lower statistics)

In [5]:
#Then a longer run to check the final accuracy
# Re-trains a fresh Net n_checks times with the (lr, momentum) pair selected
# from scan_result, evaluating on the test set after every epoch.
LR = scan_result[i_max][0]
MOMENTUM = scan_result[i_max][1]

n_epochs = 50
n_checks = 5

# check_result[k] is a one-element list holding the final test accuracy of
# run k (same layout as before, so downstream statistics are unaffected).
check_result = []
# check_history[k][e] is the test accuracy of run k after epoch e.
check_history = []

print("lr: ", LR, "\tmomentum: ", MOMENTUM, "\n")

for check in range(n_checks):
    print("Run: ", check )

    net = Net()
    optimizer = torch.optim.SGD(net.parameters(), lr=LR, momentum=MOMENTUM)

    # Created once per run. It used to be re-created at the start of every
    # epoch, which silently discarded all but the last epoch's accuracy.
    tests = []
    for epoch in range(n_epochs):
        for inputs, labels in trainloader:
            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

        #Then evaluate the performance on the test set, at each epoch
        correct = 0
        total = 0
        with torch.no_grad():
            for images, labels in testloader:
                outputs = net(images)
                # Gradients are already disabled here, so no .data is needed.
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        if epoch%10 == 0: print('\tEpoch %d\t Accuracy: %f' % (epoch, 100 * correct / total))
        tests.append(correct/total)

    # The last epoch is not a multiple of 10, so report it explicitly.
    print('\tEpoch %d\t Accuracy: %f' % (epoch, 100 * correct / total))
    check_history.append(tests)
    # Only the final accuracy feeds the summary statistics.
    check_result.append(tests[-1:])

lr:  0.01 	momentum:  0.95 

Run:  0
	Epoch 0	 Accuracy: 98.100000
	Epoch 10	 Accuracy: 98.960000
	Epoch 20	 Accuracy: 98.880000
	Epoch 30	 Accuracy: 99.040000
	Epoch 40	 Accuracy: 99.040000
	Epoch 49	 Accuracy: 99.040000
Run:  1
	Epoch 0	 Accuracy: 97.700000
	Epoch 10	 Accuracy: 98.860000
	Epoch 20	 Accuracy: 98.900000
	Epoch 30	 Accuracy: 98.870000
	Epoch 40	 Accuracy: 99.070000
	Epoch 49	 Accuracy: 99.070000
Run:  2
	Epoch 0	 Accuracy: 98.410000
	Epoch 10	 Accuracy: 99.080000
	Epoch 20	 Accuracy: 99.250000
	Epoch 30	 Accuracy: 99.260000
	Epoch 40	 Accuracy: 99.260000
	Epoch 49	 Accuracy: 99.250000
Run:  3
	Epoch 0	 Accuracy: 98.350000
	Epoch 10	 Accuracy: 98.800000
	Epoch 20	 Accuracy: 98.980000
	Epoch 30	 Accuracy: 99.080000
	Epoch 40	 Accuracy: 99.070000
	Epoch 49	 Accuracy: 99.100000
Run:  4
	Epoch 0	 Accuracy: 98.400000
	Epoch 10	 Accuracy: 99.090000
	Epoch 20	 Accuracy: 99.040000
	Epoch 30	 Accuracy: 99.180000
	Epoch 40	 Accuracy: 99.200000
	Epoch 49	 Accuracy: 99.190000


In [6]:
# Mean and standard deviation over every accuracy stored in check_result,
# after collapsing the nested list into a 1-D tensor.
res_tensor = torch.tensor(check_result).reshape(-1)
print(res_tensor.mean())
print(res_tensor.std())


tensor(0.9913)
tensor(0.0009)
