In [0]:
# Heavily inspired by: https://github.com/pytorch/examples/blob/master/mnist/main.py
import os
import torch
from torch import nn
from torch.autograd import Variable
import torchvision.datasets as dset
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

In [0]:
class NN(nn.Module):
    def __init__(self):
        super(NN, self).__init__()
        
        self.model = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.LeakyReLU(),
            nn.Linear(512, 256),
            nn.LeakyReLU(),
            nn.Linear(256, 10),
            nn.Sigmoid(),
        )
        
    def forward(self, img):
        img_flat = img.view(img.size(0), -1)
        output = self.model(img_flat)
        return output

In [0]:
def train(model, device, train_loader, optimizer, criterion, epoch):
    model.train()
    train_loss = 0
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        train_loss += loss.item()
        loss.backward()
        optimizer.step()
        
    train_loss /= len(train_loader.dataset)
    print('Train Epoch: {}, Average Loss: {:.6f}'.format(
        epoch, train_loss))

In [0]:
def test(model, device, test_loader, criterion):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += criterion(output, target).item() # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print('Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))

In [5]:
def main():
    # Load data
    root = './data'
    if not os.path.exists(root):
        os.mkdir(root)

    # normalizes values to interval [-1, 1]
    trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
    train_set = dset.MNIST(root=root, train=True, transform=trans, download=True)
    test_set = dset.MNIST(root=root, train=False, transform=trans, download=True)

    batch_size = 64
    train_loader = torch.utils.data.DataLoader(
                     dataset=train_set,
                     batch_size=batch_size,
                     shuffle=True)
    test_loader = torch.utils.data.DataLoader(
                    dataset=test_set,
                    batch_size=batch_size,
                    shuffle=False)
    
    # Cuda stuff
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    print("Device is " + str(device) + ".")
    
    # Training
    epochs = 10
    model = NN().to(device)
    criterion = torch.nn.CrossEntropyLoss(reduction='sum')
    optimizer = torch.optim.Adam(model.parameters())
    
    for epoch in range(1, epochs + 1):
        train(model, device, train_loader, optimizer, criterion, epoch)
        test(model, device, test_loader, criterion)
    
if __name__ == '__main__':
    main()

0it [00:00, ?it/s]

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


9920512it [00:03, 3116342.13it/s]                             


Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz


0it [00:00, ?it/s]

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


32768it [00:00, 48831.17it/s]                           
0it [00:00, ?it/s]

Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


1654784it [00:02, 736011.49it/s]                             
0it [00:00, ?it/s]

Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


8192it [00:00, 18403.26it/s]            


Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz
Processing...
Done!
Device is cuda.
Train Epoch: 1, Average Loss: 1.566326
Test set: Average loss: 1.5192, Accuracy: 9299/10000 (93%)
Train Epoch: 2, Average Loss: 1.508968
Test set: Average loss: 1.5002, Accuracy: 9537/10000 (95%)
Train Epoch: 3, Average Loss: 1.498298
Test set: Average loss: 1.4950, Accuracy: 9581/10000 (96%)
Train Epoch: 4, Average Loss: 1.492261
Test set: Average loss: 1.4880, Accuracy: 9660/10000 (97%)
Train Epoch: 5, Average Loss: 1.489244
Test set: Average loss: 1.4902, Accuracy: 9642/10000 (96%)
Train Epoch: 6, Average Loss: 1.486266
Test set: Average loss: 1.4872, Accuracy: 9695/10000 (97%)
Train Epoch: 7, Average Loss: 1.483044
Test set: Average loss: 1.4908, Accuracy: 9601/10000 (96%)
Train Epoch: 8, Average Loss: 1.483542
Test set: Average loss: 1.4870, Accuracy: 9686/10000 (97%)
Train Epoch: 9, Average Loss: 1.481793
Test set: Average loss: 1.4857, Accuracy: 9698/10000 (97%)
Train Epoch: 10, Average Los