In [11]:
import matplotlib.pyplot as plt
import torch
from torchvision import datasets, transforms,utils
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torch import nn
from torch.nn import functional as F
from dfw import DFW
from dfw.losses import MultiClassHingeLoss

import itertools


In [2]:
# Module-level mini-batch size constant.
# NOTE(review): the DataLoaders built in this notebook hard-code batch_size=3000
# and batch_size=1000 instead of reading this value — confirm which is intended.
BATCH_SIZE = 50

In [3]:
# Preprocessing pipeline shared by both splits: tensor conversion followed by
# Normalize with mean 0 and std 1.
ts = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0,), (1,)),
    ]
)

# Training split and its shuffled loader (3000 samples per batch).
mnist_train = datasets.MNIST(
    'data/',
    train=True,
    download=True,
    transform=ts,
)
# Test split and its shuffled loader (1000 samples per batch).
mnist_test = datasets.MNIST(
    'data/',
    train=False,
    download=True,
    transform=ts,
)
train_data = DataLoader(mnist_train, batch_size=3000, shuffle=True)
test_data = DataLoader(mnist_test, batch_size=1000, shuffle=True)



In [2]:
# Scratch demonstration of enumerate(): iterates (index, value) pairs over a
# small list and prints only the index.
# NOTE(review): looks like a leftover experiment — no other visible cell uses
# `l`, `i` or `j`; consider removing it from the final notebook.
l = [1, 2, 3]
for i,j in enumerate(l):
    print(i)

0
1
2


In [168]:
class DigitNet(nn.Module):
    """Two-conv / two-linear network producing 10 raw class scores per image."""

    def __init__(self):
        super().__init__()
        # Feature extractor: two 3x3 convolutions (1 -> 8 -> 16 channels).
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=8, kernel_size=3)
        self.conv2 = nn.Conv2d(in_channels=8, out_channels=16, kernel_size=3)
        # Classifier head: 16*5*5 flattened features -> 200 hidden -> 10 scores.
        self.fc1 = nn.Linear(in_features=16 * 5 * 5, out_features=200)
        self.fc2 = nn.Linear(in_features=200, out_features=10)

    def forward(self, x):
        """Compute class scores for a batch of single-channel images.

        Each convolution is followed by a 2x2 max-pool and a ReLU. The result
        is reshaped to rows of 16*5*5 values, which matches 28x28 inputs
        (28 -> 26 -> 13 -> 11 -> 5 per spatial side).

        Returns the output of the last linear layer unchanged; no softmax is
        applied here.
        """
        first_stage = F.relu(F.max_pool2d(self.conv1(x), 2))
        second_stage = F.relu(F.max_pool2d(self.conv2(first_stage), 2))
        flat = second_stage.view(-1, 16 * 5 * 5)
        hidden = F.relu(self.fc1(flat))
        scores = self.fc2(hidden)
        return scores


In [169]:
class Model:
    """Training / evaluation wrapper around a ``DigitNet`` instance."""

    def __init__(self):
        # The network that train() optimizes and test() evaluates.
        self.model = DigitNet()

    def train(self, train_data, nb_epochs=10, lr=1e-1, verbose=True):
        """Fit ``self.model`` using the DFW optimizer and a multi-class hinge loss.

        Args:
            train_data: Iterable yielding ``(images, labels)`` mini-batches.
            nb_epochs: Number of full passes over ``train_data``.
            lr: Value forwarded to ``DFW`` as its ``eta`` argument.
            verbose: When true, print the loss of the last mini-batch once
                each epoch has finished (epochs are numbered from 1).

        Returns:
            None.
        """
        optimizer = DFW(self.model.parameters(), eta=lr)
        criterion = MultiClassHingeLoss()

        self.model.train()
        for epoch in range(nb_epochs):
            # Stays None when train_data yields nothing, so reporting is skipped
            # instead of failing on an undefined name.
            loss = None

            for images, labels in train_data:
                # Forward pass
                model_output = self.model(images)
                loss = criterion(model_output, labels)

                # Backward pass and parameter update
                optimizer.zero_grad()
                loss.backward()
                # step() is handed a closure that returns the current loss as a float.
                optimizer.step(lambda: float(loss))

            # Report after the epoch completes so that the final epoch is
            # reported too and the printed loss belongs to the labelled epoch.
            if verbose and loss is not None:
                print("Epochs {}".format(epoch + 1))
                print("loss = {}".format(float(loss)))

    def test(self, test_data):
        """Return the fraction of correctly classified samples in ``test_data``.

        The denominator is the number of samples actually iterated over, so
        the result does not depend on any batch-size constant and is correct
        when batches have different sizes.

        Args:
            test_data: Iterable yielding ``(images, labels)`` mini-batches.

        Returns:
            ``nb_correct / nb_samples`` as a Python float.

        Raises:
            ZeroDivisionError: If ``test_data`` yields no samples.
        """
        nb_correct = 0
        nb_samples = 0

        # Evaluate in eval mode without autograd, then restore the previous mode.
        was_training = self.model.training
        self.model.eval()
        try:
            with torch.no_grad():
                for images, labels in test_data:
                    # Predicted class = index of the highest score in each row.
                    predicted_labels = self.model(images).argmax(dim=1)
                    nb_correct += (predicted_labels == labels).sum().item()
                    nb_samples += labels.size(0)
        finally:
            self.model.train(was_training)

        return nb_correct / nb_samples
       

In [170]:
# Build a fresh Model wrapper and train it on the training loader, relying on
# the default arguments of Model.train.
model = Model()
model.train(train_data)

Epochs 1
loss = 0.04388212040066719
Epochs 2
loss = 0.023867564275860786
Epochs 3
loss = 0.025275349617004395
Epochs 4
loss = 0.012584369629621506
Epochs 5
loss = 0.00035743237822316587
Epochs 6
loss = 0.005485172383487225
Epochs 7
loss = 0.004172940272837877
Epochs 8
loss = 0.0
Epochs 9
loss = 0.0


In [171]:
# Evaluate the trained model on the test loader; the cell displays the value
# returned by Model.test.
model.test(test_data)

0.9887