In [1]:
%pylab inline
import numpy as np
import torch as th
import torch.nn as nn
import torch.nn.init as init
from torch.autograd import Variable
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torch.nn.functional as F
import torch.optim as optim

import sklearn.metrics as metrics
from sklearn.model_selection import train_test_split

Populating the interactive namespace from numpy and matplotlib


In [2]:
root = './data2'
download = True # download MNIST dataset or not
trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (1.0,))])
train_set = dset.FashionMNIST(root=root, train=True, transform=trans, download=download)
test_set = dset.FashionMNIST(root=root, train=False, transform=trans)

batch_size = 100
n_iters = 5500
num_epochs = n_iters / (len(train_set) / batch_size)
num_epochs = int(num_epochs)

train_loader = th.utils.data.DataLoader(
                 dataset=train_set,
                 batch_size=batch_size,
                 shuffle=True)
test_loader = th.utils.data.DataLoader(
                dataset=test_set,
                batch_size=batch_size,
shuffle=False)

In [5]:
class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        self.pool1 = nn.MaxPool2d(2)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.pool2 = nn.MaxPool2d(2)
        self.fc1 = nn.Linear(4*4*50, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.pool1(x)
        x = F.relu(self.conv2(x))
        x = self.pool2(x)
        x = x.view(-1, 4*4*50)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x
    
    def classify(self, x):
        x = self.forward(x)
        x = x.data.numpy()
        return np.argmax(x, axis=1)


In [6]:
model = LeNet()
optimizer = optim.Adagrad(model.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()
epochs = 10

In [7]:
# building on work here https://github.com/mayurbhangale/fashion-mnist-pytorch/blob/master/CNN_Fashion_MNIST.ipynb and
# using the convolutional network in tutorial. 
iter = 0
for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):
        
        images = Variable(images)
        labels = Variable(labels)
        
        # Clear gradients w.r.t. parameters
        optimizer.zero_grad()
        
        # Forward pass to get output/logits
        outputs = model(images)
        
        # Calculate Loss: softmax --> cross entropy loss
        loss = criterion(outputs, labels)
        
        # Getting gradients w.r.t. parameters
        loss.backward()
        
        # Updating parameters
        optimizer.step()
        
        iter += 1
        
        if iter % 500 == 0:
            # Calculate Accuracy         
            correct = 0
            total = 0
            # Iterate through test dataset
            for images, labels in test_loader:
                images = Variable(images)
                
                # Forward pass only to get logits/output
                outputs = model(images)
                
                # Get predictions from the maximum value
                _, predicted = th.max(outputs.data, 1)
                
                # Total number of labels
                total += labels.size(0)
                
                correct += (predicted == labels).sum()
            
            accuracy = 100 * correct / total
            
            # Print Loss
            print('Iteration: {}. Loss: {}. Accuracy: {}'.format(iter, loss.data[0], accuracy))
print("Accuracy : {}".format(accuracy))

Iteration: 500. Loss: 0.3890562057495117. Accuracy: 82.92
Iteration: 1000. Loss: 0.4187420606613159. Accuracy: 85.43
Iteration: 1500. Loss: 0.3572849631309509. Accuracy: 86.7
Iteration: 2000. Loss: 0.32119765877723694. Accuracy: 87.17
Iteration: 2500. Loss: 0.2694917321205139. Accuracy: 87.71
Iteration: 3000. Loss: 0.32186535000801086. Accuracy: 88.0
Iteration: 3500. Loss: 0.20430637896060944. Accuracy: 88.23
Iteration: 4000. Loss: 0.33797308802604675. Accuracy: 88.46
Iteration: 4500. Loss: 0.3813328146934509. Accuracy: 88.23
Iteration: 5000. Loss: 0.21985380351543427. Accuracy: 88.83
Accuracy : 88.83
