In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
import torchvision
from torchvision import datasets, transforms
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

In [2]:
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('./data', train=True, download=False,
    transform=transforms.Compose([
        transforms.ToTensor()
    ])),
    batch_size=32, shuffle=False)

In [3]:
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('./data', train=False,
    transform=transforms.Compose([
        transforms.ToTensor()
    ])),
    batch_size=32, shuffle=False)


In [4]:
class BasicNN(nn.Module):
    def __init__(self):
        super(BasicNN, self).__init__()
        self.net = nn.Linear(28 * 28, 10)
    def forward(self, x):
        batch_size = x.size(0)
        x = x.view(batch_size, -1)
        output = self.net(x)
        return F.softmax(output)

In [5]:
model = BasicNN()
optimizer = optim.SGD(model.parameters(), lr=0.001)

In [6]:
def test():
    total_loss = 0
    correct = 0
    for image, label in test_loader:
        image, label = Variable(image), Variable(label)
        output = model(image)
        total_loss += F.cross_entropy(output, label)
        correct += (torch.max(output, 1)[1].view(label.size()).data == label.data).sum()
    total_loss = total_loss.data[0] / len(test_loader)
    accuracy = correct / len(test_loader.dataset)
    return total_loss, accuracy


In [7]:
def train():
    model.train()
    for image, label in train_loader:
        image, label = Variable(image), Variable(label)
        optimizer.zero_grad()
        output = model(image)
        loss = F.cross_entropy(output, label)
        loss.backward()
        optimizer.step()

In [None]:
best_test_loss = None
for e in range(1, 150):
    train()
    test_loss, test_accuracy = test()
    print("\n[Epoch: %d] Test Loss:%5.5f Test Accuracy:%5.5f" % (e, test_loss, test_accuracy))
    # Save the model if the test_loss is the lowest
    if not best_test_loss or test_loss < best_test_loss:
        best_test_loss = test_loss
    else:
        break
print("\nFinal Results\n-------------\n""Loss:", best_test_loss, "Test Accuracy: ", test_accuracy)


[Epoch: 1] Test Loss:2.27352 Test Accuracy:0.44360

[Epoch: 2] Test Loss:2.22371 Test Accuracy:0.45100

[Epoch: 3] Test Loss:2.16380 Test Accuracy:0.49840

[Epoch: 4] Test Loss:2.09973 Test Accuracy:0.51520

[Epoch: 5] Test Loss:2.04782 Test Accuracy:0.56200

[Epoch: 6] Test Loss:2.00434 Test Accuracy:0.60630

[Epoch: 7] Test Loss:1.96735 Test Accuracy:0.62930

[Epoch: 8] Test Loss:1.93913 Test Accuracy:0.64160

[Epoch: 9] Test Loss:1.91655 Test Accuracy:0.65620

[Epoch: 10] Test Loss:1.89545 Test Accuracy:0.68240

[Epoch: 11] Test Loss:1.87484 Test Accuracy:0.70650

[Epoch: 12] Test Loss:1.85802 Test Accuracy:0.71700

[Epoch: 13] Test Loss:1.84345 Test Accuracy:0.72550

[Epoch: 14] Test Loss:1.82930 Test Accuracy:0.74690

[Epoch: 15] Test Loss:1.81557 Test Accuracy:0.77430

[Epoch: 16] Test Loss:1.80372 Test Accuracy:0.78770

[Epoch: 17] Test Loss:1.79372 Test Accuracy:0.79150

[Epoch: 18] Test Loss:1.78501 Test Accuracy:0.79350

[Epoch: 19] Test Loss:1.77731 Test Accuracy:0.79600

[