In [24]:
import torchvision
import torchvision.transforms as transform
import torch
from torch.utils.data import TensorDataset, DataLoader
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [25]:
image_path = './MNIST/'
transform = transform.Compose([
    transform.ToTensor()
])

In [26]:
mnist_train_ds = torchvision.datasets.MNIST(
    image_path, train=True, transform=transform,  download=True
)

mnist_test_ds = torchvision.datasets.MNIST(
    image_path, train=False, transform=transform,  download=True
)

In [27]:
batch_size = 64
torch.manual_seed(1)
train_dl = DataLoader(mnist_train_ds, 
                      batch_size=batch_size, shuffle=True)

In [28]:
image_size = mnist_train_ds[0][0].shape
input_size = image_size[2]*image_size[1]*image_size[0]
output_shape = 10

import torch.nn as nn
class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.fl = nn.Flatten()
        self.l1 = nn.Linear(input_size, 32)
        self.a1 = nn.ReLU()
        self.l2 = nn.Linear(32, 16)
        self.a2 = nn.ReLU()
        self.l3 = nn.Linear(16, 10)
        self.a3 = nn.Softmax()
    def forward(self, x):
        x = self.fl(x)
        x = self.l1(x)
        x = self.a1(x)
        x = self.l2(x)
        x = self.a2(x)
        x = self.l3(x)
        x = self.a3(x)

        return x
    
model = Model().to(device)
model

Model(
  (fl): Flatten(start_dim=1, end_dim=-1)
  (l1): Linear(in_features=784, out_features=32, bias=True)
  (a1): ReLU()
  (l2): Linear(in_features=32, out_features=16, bias=True)
  (a2): ReLU()
  (l3): Linear(in_features=16, out_features=10, bias=True)
  (a3): Softmax(dim=None)
)

In [29]:
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr= 0.001)

In [30]:
num_epochs = 20
log_epochs = 1

acc_hist = [0]*num_epochs

for epoch in range(num_epochs):
    for x_batch, y_batch in train_dl:
        pred = model(x_batch.to(device))
        loss = loss_fn(pred, y_batch.to(device))
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        is_correct = (torch.argmax(pred, dim=1) == y_batch.to(device)).float()
        acc_hist[epoch] += is_correct.sum().item()

    acc_hist[epoch] /= len(train_dl.dataset)
    if epoch % log_epochs == 0 :
        print(f'Epoch {epoch} Acc : {acc_hist[epoch]:4f}')

Epoch 0 Acc : 0.811650
Epoch 1 Acc : 0.917483
Epoch 2 Acc : 0.929900
Epoch 3 Acc : 0.936783
Epoch 4 Acc : 0.941067
Epoch 5 Acc : 0.945867
Epoch 6 Acc : 0.948633
Epoch 7 Acc : 0.950967
Epoch 8 Acc : 0.953967
Epoch 9 Acc : 0.955983
Epoch 10 Acc : 0.957983
Epoch 11 Acc : 0.959350
Epoch 12 Acc : 0.961967
Epoch 13 Acc : 0.962500
Epoch 14 Acc : 0.964450
Epoch 15 Acc : 0.965283
Epoch 16 Acc : 0.965867
Epoch 17 Acc : 0.966767
Epoch 18 Acc : 0.967267
Epoch 19 Acc : 0.968117


In [31]:
with torch.no_grad():
    pred = model(mnist_test_ds.data.to(device) / 255.)
    is_correct = (
    torch.argmax(pred, dim=1) ==
    mnist_test_ds.targets.to(device)).float()
    print(f'Test accuracy: {is_correct.mean().item():.4f}')

Test accuracy: 0.9549
