In [None]:
import torch
from torch import optim
import torch.nn as nn
from torchvision import datasets
from torchvision import transforms
import matplotlib.pyplot as plt

In [None]:
# Root directory handed to torchvision's MNIST dataset constructor.
data_path = "downloads/"

In [None]:
# Training split of MNIST, stored under data_path; download=True fetches the
# archives when they are not present yet. No transform: samples stay PIL images.
mnist = datasets.MNIST(
    root=data_path,
    train=True,
    download=True,
)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to downloads/MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting downloads/MNIST/raw/train-images-idx3-ubyte.gz to downloads/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to downloads/MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting downloads/MNIST/raw/train-labels-idx1-ubyte.gz to downloads/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to downloads/MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting downloads/MNIST/raw/t10k-images-idx3-ubyte.gz to downloads/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to downloads/MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting downloads/MNIST/raw/t10k-labels-idx1-ubyte.gz to downloads/MNIST/raw



In [None]:
# Held-out split of MNIST (train=False), stored next to the training files.
mnist_val = datasets.MNIST(
    root=data_path,
    train=False,
    download=True,
)

In [None]:
# Re-open the training split so every sample comes back as a tensor (ToTensor)
# rather than a PIL image. download=False: relies on the files fetched earlier.
mnist = datasets.MNIST(
    root=data_path,
    train=True,
    transform=transforms.ToTensor(),
    download=False,
)


In [None]:
# Stack every training image (element 0 of each (image, label) sample) along a
# new trailing axis, so the whole dataset lives in one tensor for statistics.
images = torch.stack([sample[0] for sample in mnist], dim=3)

In [None]:
# Dataset-wide pixel statistics: flatten everything into a single row and
# reduce along it. sep='\n' keeps the label and the value on separate lines.
print('mean', images.view(1, -1).mean(dim=1), sep='\n')
print('standard deviation', images.view(1, -1).std(dim=1), sep='\n')

mean
tensor([0.1307])
standard deviation
tensor([0.3081])


In [None]:
# NOTE(review): this cell recomputes exactly the same statistics as the
# preceding statistics cell; it looks redundant and could probably be removed.
print('mean', torch.mean(images.view(1, -1), dim=1), sep='\n')
print('standard deviation', torch.std(images.view(1, -1), dim=1), sep='\n')

mean
tensor([0.1307])
standard deviation
tensor([0.3081])


In [None]:
# Training split with the final preprocessing pipeline: convert each image to
# a tensor, then standardise it with the dataset-wide mean / std computed from
# the training images (0.1307 / 0.3081).
# Normalize expects one value per channel, so each statistic is passed as a
# one-element tuple. Mind the trailing comma: `(0.1307)` is merely a
# parenthesised float, not a tuple.
mnist = datasets.MNIST(data_path, train=True, download=False,
                       transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,),
                                                (0.3081,))]))

In [None]:
# Held-out split with the same preprocessing as the training data. The
# normalisation constants are deliberately the *training* statistics
# (0.1307 / 0.3081) so both splits are scaled identically.
# Normalize expects one value per channel, so each statistic is passed as a
# one-element tuple. Mind the trailing comma: `(0.1307)` is merely a
# parenthesised float, not a tuple.
mnist_val = datasets.MNIST(data_path, train=False, download=False,
                           transform=transforms.Compose([
                               transforms.ToTensor(),
                               transforms.Normalize((0.1307,),
                                                    (0.3081,))]))

In [None]:
# Fully connected classifier: 784 inputs -> 128 -> 64 -> 10 outputs.
input_size = 784
hidden_sizes = [128, 64]
output_size = 10

# The network ends in LogSoftmax, i.e. it emits log-probabilities over the
# 10 outputs (dim=1 is the class axis of a (batch, classes) tensor).
model = nn.Sequential(*[
    nn.Linear(in_features=input_size, out_features=hidden_sizes[0]),
    nn.ReLU(),
    nn.Linear(in_features=hidden_sizes[0], out_features=hidden_sizes[1]),
    nn.ReLU(),
    nn.Linear(in_features=hidden_sizes[1], out_features=output_size),
    nn.LogSoftmax(dim=1),
])

In [None]:
# Train the classifier with plain mini-batch SGD. The model ends in
# LogSoftmax, so NLLLoss on its output is the matching criterion.
train_loader = torch.utils.data.DataLoader(mnist, batch_size=64,
                                           shuffle=True)
optimizer = optim.SGD(model.parameters(), lr=0.01)
loss_fn = nn.NLLLoss()
n_epochs = 10

model.train()  # make the mode explicit in case an earlier cell switched to eval
for epoch in range(n_epochs):
    # Accumulate a sample-weighted sum so the reported number is the mean loss
    # over the whole epoch; the loss of the last mini-batch alone is very noisy.
    epoch_loss = 0.0
    n_seen = 0
    for imgs, labels in train_loader:
        optimizer.zero_grad()
        batch_size = imgs.shape[0]  # the final batch may be smaller than 64
        # Flatten each image into a vector before the first Linear layer.
        output = model(imgs.view(batch_size, -1))
        loss = loss_fn(output, labels)
        loss.backward()
        optimizer.step()
        # NLLLoss averages over the batch by default; re-weight by batch size.
        epoch_loss += loss.item() * batch_size
        n_seen += batch_size
    print("Epoch: %d, Loss: %f" % (epoch, epoch_loss / max(n_seen, 1)))

Epoch: 0, Loss: 0.312053
Epoch: 1, Loss: 0.097645
Epoch: 2, Loss: 0.454697
Epoch: 3, Loss: 0.046747
Epoch: 4, Loss: 0.127001
Epoch: 5, Loss: 0.280274
Epoch: 6, Loss: 0.086072
Epoch: 7, Loss: 0.207579
Epoch: 8, Loss: 0.112312
Epoch: 9, Loss: 0.070654


In [None]:
# Measure classification accuracy on the held-out split.
# Order does not affect accuracy, so the loader is not shuffled; this keeps
# the evaluation deterministic.
val_loader = torch.utils.data.DataLoader(mnist_val, batch_size=64,
                                         shuffle=False)

correct = 0
total = 0
model.eval()  # inference mode for any mode-dependent layers
with torch.no_grad():  # no gradients needed for evaluation
    # The batch is bound to `imgs` and that same name is fed to the model, so
    # predictions really come from the validation images. It also leaves the
    # notebook-level `images` tensor untouched.
    for imgs, labels in val_loader:
        batch_size = imgs.shape[0]
        outputs = model(imgs.view(batch_size, -1))
        # Index of the largest log-probability per row = predicted class.
        predicted = outputs.argmax(dim=1)
        total += labels.shape[0]
        correct += int((predicted == labels).sum())
# Use the % operator so the value is actually formatted into the string.
print("Accuracy: %f" % (correct / total))

Accuracy: %f 0.9663
