In [4]:
import torch
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader

# Directory used as the dataset root; MNIST files are downloaded here if missing.
image_path = './data/'

# Per-sample preprocessing pipeline: convert each image to a tensor.
transform = transforms.Compose([transforms.ToTensor()])

# Training split.
mnist_train_dataset = torchvision.datasets.MNIST(
    root=image_path,
    train=True,
    transform=transform,
    download=True,
)

# Held-out test split, preprocessed the same way as the training split.
mnist_test_dataset = torchvision.datasets.MNIST(
    root=image_path,
    train=False,
    transform=transform,
    download=True,
)

batch_size = 64

# Seed the global torch RNG before building the shuffling loader.
torch.manual_seed(42)
train_dl = DataLoader(
    dataset=mnist_train_dataset,
    batch_size=batch_size,
    shuffle=True,
)

100%|██████████| 9.91M/9.91M [00:03<00:00, 3.09MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 93.5kB/s]
100%|██████████| 1.65M/1.65M [00:03<00:00, 447kB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 3.68MB/s]


In [21]:
import torch.nn as nn

# Shape of one transformed training image (first element of the
# (image, label) pair) -- presumably (channels, height, width).
image_size = mnist_train_dataset[0][0].shape

# Number of features after flattening one image. numel() multiplies every
# dimension, so this does not assume the image has exactly three axes.
input_shape = image_size.numel()

class MNISTClassifier(nn.Module):
    """Fully connected classifier with two ReLU-activated hidden layers.

    The input is flattened (all dimensions except the batch dimension) and
    passed through Linear -> ReLU -> Linear -> ReLU -> Linear. The output is
    the raw result of the last linear layer; no softmax is applied.
    """

    def __init__(self, input_size, hidden_size_l1,
                 hidden_size_l2, output_size):
        """Build the layers.

        Args:
            input_size: Number of features per sample after flattening.
            hidden_size_l1: Width of the first hidden layer.
            hidden_size_l2: Width of the second hidden layer.
            output_size: Number of output units.
        """
        super().__init__()
        # Attribute names and creation order are kept stable: they determine
        # the printed module repr and the state_dict keys.
        self.l0 = nn.Flatten()
        self.l1 = nn.Linear(in_features=input_size,
                            out_features=hidden_size_l1)
        self.a1 = nn.ReLU()
        self.l2 = nn.Linear(in_features=hidden_size_l1,
                            out_features=hidden_size_l2)
        self.a2 = nn.ReLU()
        self.l3 = nn.Linear(in_features=hidden_size_l2,
                            out_features=output_size)

    def forward(self, x):
        """Apply every layer in order and return the final layer's output."""
        pipeline = (self.l0, self.l1, self.a1, self.l2, self.a2, self.l3)
        for layer in pipeline:
            x = layer(x)
        return x
    
# Network sized for the flattened images: two hidden layers (32 and 16 units)
# and 10 output units.
model = MNISTClassifier(
    input_size=input_shape,
    hidden_size_l1=32,
    hidden_size_l2=16,
    output_size=10,
)
# Bare expression so the notebook displays the module structure.
model

MNISTClassifier(
  (l0): Flatten(start_dim=1, end_dim=-1)
  (l1): Linear(in_features=784, out_features=32, bias=True)
  (a1): ReLU()
  (l2): Linear(in_features=32, out_features=16, bias=True)
  (a2): ReLU()
  (l3): Linear(in_features=16, out_features=10, bias=True)
)

In [22]:
# Cross-entropy criterion applied to the model outputs and the integer targets.
loss_fn = nn.CrossEntropyLoss()

# Adam over every parameter of the model.
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)

In [26]:
num_epochs = 20

# Make sure the model is in training mode before optimising.
model.train()

for epoch in range(num_epochs):
    # Reset the running totals at the start of every epoch. Initialising them
    # only once before the loop would carry the previous epoch's averaged
    # value into the next epoch's sum and skew the reported metrics.
    train_loss = 0.0
    train_acc = 0.0

    for x, y in train_dl:
        preds = model(x)
        loss = loss_fn(preds, y)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        # loss is the mean over the batch, so weight it by the batch size;
        # this keeps the epoch average exact even when the last batch is
        # smaller than the others.
        train_loss += loss.item() * y.size(0)
        is_correct = (torch.argmax(preds, dim=1) == y).float()
        # .item() keeps the accumulator a plain Python float, not a tensor.
        train_acc += is_correct.sum().item()

    # Both totals are per-sample sums, so normalise by the dataset size.
    train_loss /= len(train_dl.dataset)
    train_acc /= len(train_dl.dataset)

    print(f'Epoch {epoch} | '
          f'loss: {train_loss:.4f} | '
          f'accuracy: {train_acc:.4f}')


Epoch 0 | loss: 0.0273 | accuracy: 0.9915
Epoch 1 | loss: 0.0263 | accuracy: 0.9916
Epoch 2 | loss: 0.0254 | accuracy: 0.9916
Epoch 3 | loss: 0.0239 | accuracy: 0.9925
Epoch 4 | loss: 0.0231 | accuracy: 0.9925
Epoch 5 | loss: 0.0231 | accuracy: 0.9925
Epoch 6 | loss: 0.0206 | accuracy: 0.9934
Epoch 7 | loss: 0.0209 | accuracy: 0.9931
Epoch 8 | loss: 0.0196 | accuracy: 0.9934
Epoch 9 | loss: 0.0201 | accuracy: 0.9934
Epoch 10 | loss: 0.0187 | accuracy: 0.9938
Epoch 11 | loss: 0.0178 | accuracy: 0.9943
Epoch 12 | loss: 0.0160 | accuracy: 0.9953
Epoch 13 | loss: 0.0168 | accuracy: 0.9945
Epoch 14 | loss: 0.0151 | accuracy: 0.9950
Epoch 15 | loss: 0.0164 | accuracy: 0.9947
Epoch 16 | loss: 0.0169 | accuracy: 0.9942
Epoch 17 | loss: 0.0161 | accuracy: 0.9945
Epoch 18 | loss: 0.0126 | accuracy: 0.9961
Epoch 19 | loss: 0.0140 | accuracy: 0.9952


In [27]:
# Evaluate on the whole test set in one forward pass.
# eval() switches the model to inference mode (good practice, even though
# this architecture has no mode-dependent layers), and no_grad() stops
# autograd from recording a graph for the full test set, which only costs
# memory during evaluation.
model.eval()
with torch.no_grad():
    # The raw test images are divided by 255 to scale them before inference.
    pred = model(mnist_test_dataset.data / 255.)

is_correct = (torch.argmax(pred, dim=1) == mnist_test_dataset.targets).float()
print(f'Test accuracy: {is_correct.mean():.4f}')

Test accuracy: 0.9637
