# In [14]:
import numpy as np
import os
import torch
import torchvision
import torchvision.transforms as transforms

### Load dataset - Preprocessing
# Directory passed to load_mnist as the dataset root (MNIST is downloaded
# there when it is not already present).
DATA_PATH = '/tmp/data'
# Mini-batch size used for both the training and the test loader.
BATCH_SIZE = 64

def load_mnist(path, batch_size):
    """Build the MNIST training and test ``DataLoader`` objects.

    The dataset files are downloaded into ``path`` when they are not
    already present there.

    Args:
        path: Directory used as the torchvision dataset root. It is created
            (including any missing parent directories) if it does not exist.
        batch_size: Number of samples per batch for both loaders.

    Returns:
        A ``(train_loader, test_loader)`` tuple. The training loader is
        created with ``shuffle=True``; the test loader keeps dataset order.
    """
    # makedirs(exist_ok=True) also creates missing parents and avoids the
    # check-then-create race of os.path.exists() followed by os.mkdir().
    os.makedirs(path, exist_ok=True)

    # Convert images to tensors, then shift them by subtracting a mean of 0.5
    # (the std of 1.0 leaves the scale untouched).
    trans = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (1.0,)),
    ])

    train_set = torchvision.datasets.MNIST(
        root=path, train=True, transform=trans, download=True)
    test_set = torchvision.datasets.MNIST(
        root=path, train=False, transform=trans, download=True)

    train_loader = torch.utils.data.DataLoader(
        dataset=train_set,
        batch_size=batch_size,
        shuffle=True)

    # Evaluation does not need shuffling; a fixed order keeps it reproducible.
    test_loader = torch.utils.data.DataLoader(
        dataset=test_set,
        batch_size=batch_size,
        shuffle=False)

    return train_loader, test_loader

# Build both loaders once at module level; load_mnist is called with
# download=True datasets, so the first run may need network access.
train_loader, test_loader = load_mnist(DATA_PATH, BATCH_SIZE)


### Build network
# Length of one flattened input sample (a 28x28 image).
IN_SIZE = 28*28
# Width of the single hidden layer.
HIDDEN_SIZE = 50
# Number of output logits produced by the network.
OUT_SIZE = 10
# Learning rate passed to the SGD optimizer.
LR=0.001

class Net(torch.nn.Module):
    """Two-layer fully connected classifier: Linear -> ReLU -> Linear.

    Args:
        in_size: Number of features per sample after flattening. Defaults
            to the module-level ``IN_SIZE`` when omitted.
        hidden_size: Width of the hidden layer. Defaults to the module-level
            ``HIDDEN_SIZE`` when omitted.
        out_size: Number of output logits. Defaults to the module-level
            ``OUT_SIZE`` when omitted.
    """

    def __init__(self, in_size=None, hidden_size=None, out_size=None):
        super(Net, self).__init__()
        # Resolve the defaults lazily so that a plain ``Net()`` behaves
        # exactly as before, while callers may now pick other sizes.
        if in_size is None:
            in_size = IN_SIZE
        if hidden_size is None:
            hidden_size = HIDDEN_SIZE
        if out_size is None:
            out_size = OUT_SIZE
        self.l1 = torch.nn.Linear(in_size, hidden_size)
        self.l2 = torch.nn.Linear(hidden_size, out_size)

    def forward(self, x):
        """Compute the raw class scores (logits) for a batch.

        Args:
            x: Input tensor; it is viewed as ``(-1, in_size)``, so any
                trailing dimensions are flattened into the feature axis.

        Returns:
            Tensor of shape ``(batch, out_size)``. No softmax is applied.
        """
        # Flatten using the first layer's own input width rather than a
        # global, so the model stays consistent with how it was constructed.
        x = x.view(-1, self.l1.in_features)
        x = torch.relu(self.l1(x))
        y_logits = self.l2(x)
        return y_logits

# Model instance used by the training and evaluation code below.
net = Net()
# reduction='sum' adds up the per-sample losses instead of averaging them,
# so the loss values printed later are totals, not per-sample means.
criterion = torch.nn.CrossEntropyLoss(reduction='sum')
# Plain SGD over all parameters of the network.
opti = torch.optim.SGD(net.parameters(), lr=LR)

### Training
# Number of full passes over the training set.
NEPOCHS = 5

for epoch in range(NEPOCHS):

    # Optimisation pass: one SGD step per mini-batch.
    for X, y in train_loader:
        opti.zero_grad()
        y_logits = net(X)
        loss = criterion(y_logits, y)
        loss.backward()
        opti.step()

    # Reporting pass: re-evaluate the updated model on the whole training set.
    # Labels and predictions are class indices, so keep them as integers.
    n_train = len(train_loader.dataset)
    preds = torch.empty(n_train, dtype=torch.long)
    y = torch.empty(n_train, dtype=torch.long)
    loss = torch.zeros(())
    offset = 0
    # No gradients are needed here; without no_grad() every batch would keep
    # its autograd graph alive through the accumulated loss tensor.
    with torch.no_grad():
        for bX, by in train_loader:
            y_logits = net(bX)
            loss += criterion(y_logits, by)
            bpreds = torch.argmax(y_logits, dim=1)
            # Advance by the actual batch length instead of assuming the
            # global BATCH_SIZE, so a short last batch (or a loader built
            # with another batch size) is still written to the right slice.
            end = offset + len(by)
            preds[offset:end] = bpreds
            y[offset:end] = by
            offset = end

    acc = y.eq(preds).sum().float() / len(y)
    print('Epoch {}: Loss = {}, Accuracy = {}'.format(epoch+1,
                                                      loss.item(),
                                                      acc.item()))
    
    
### Evaluate
# Evaluate the trained model on the held-out test set.
# Labels and predictions are class indices, so keep them as integers.
n_test = len(test_loader.dataset)
preds = torch.empty(n_test, dtype=torch.long)
y = torch.empty(n_test, dtype=torch.long)
# Summed test loss; accumulated for inspection but not printed below.
loss = torch.zeros(())
offset = 0
# Inference only: disable autograd so no graph is built or retained.
with torch.no_grad():
    for bX, by in test_loader:
        y_logits = net(bX)
        loss += criterion(y_logits, by)
        bpreds = torch.argmax(y_logits, dim=1)
        # Advance by the actual batch length rather than the global
        # BATCH_SIZE so the slices always line up with the data seen.
        end = offset + len(by)
        preds[offset:end] = bpreds
        y[offset:end] = by
        offset = end

acc = y.eq(preds).sum().float() / len(y)
print('Test Accuracy = {}'.format(acc.item()))

# Output:
# Epoch 1: Loss = 19617.880859375, Accuracy = 0.9027166962623596
# Epoch 2: Loss = 15585.22265625, Accuracy = 0.9244499802589417
# Epoch 3: Loss = 13759.5146484375, Accuracy = 0.932533323764801
# Epoch 4: Loss = 11149.7216796875, Accuracy = 0.9461333155632019
# Epoch 5: Loss = 10233.53125, Accuracy = 0.94964998960495
# Test Accuracy = 0.9467999935150146
