In [1]:
from pathlib import Path
import gzip
import pickle
import numpy as np
import matplotlib.pyplot as plt

In [2]:
import torch
import torch.nn.functional as F
from torch import nn, optim
from torch.utils.data import TensorDataset, DataLoader

In [3]:
data_path = Path('.').resolve(strict=True).parent / 'data'
path = data_path / 'mnist'

path.mkdir(parents=True, exist_ok=True)

In [4]:
url = "https://github.com/pytorch/tutorials/raw/master/_static/"
file_name = 'mnist.pkl.gz'

if not (path / file_name).exists():
    response = requests.get(url + file_name)
    with (path / file_name).open('wb') as file:
        file.write(response.content)

In [5]:
with gzip.open((path / file_name).as_posix(), 'rb') as file:
    ((X_train, y_train), (X_valid, y_valid), _) = \
        pickle.load(file, encoding='latin-1')

In [6]:
X_train, y_train, X_test, y_test = map(
    torch.tensor, (X_train, y_train, X_valid, y_valid))
n, c = X_train.shape

In [7]:
bs = 128
lr = 0.1
epochs = 20
momentum = 0.9
loss_func = F.cross_entropy

In [8]:
class Mnist_CNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, stride=2, padding=1)
        self.conv2 = nn.Conv2d(16, 16, kernel_size=3, stride=2, padding=1)
        self.conv3 = nn.Conv2d(16, 10, kernel_size=3, stride=2, padding=1)

    def forward(self, xb):
        xb = xb.view(-1, 1, 28, 28)
        xb = F.relu(self.conv1(xb))
        xb = F.relu(self.conv2(xb))
        xb = F.relu(self.conv3(xb))
        xb = F.avg_pool2d(xb, 4)
        return xb.view(-1, xb.size(1))

In [9]:
train_ds = TensorDataset(X_train, y_train)
train_dl = DataLoader(train_ds, batch_size=bs)

model = Mnist_CNN()
opt = optim.SGD(model.parameters(), lr=lr, momentum=momentum)

In [10]:
losses = []
accuracies = []

for epoch in range(epochs):
    model.train()
    for xb, yb in train_dl:
        pred = model(xb)
        loss = loss_func(pred, yb)

        loss.backward()
        opt.step()
        opt.zero_grad()

In [11]:
preds = torch.argmax(model(X_test), dim=1)
(preds == y_test).float().mean()

tensor(0.9616)