In [47]:
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor

tf = ToTensor()
train_dataset = MNIST(root='./data', train=True, transform=tf, download=True)
test_dataset = MNIST(root='./data', train=False, transform=tf)


In [48]:
from torch.utils.data import DataLoader

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

In [49]:
from torch import nn

class ConvolutionNetwork(nn.Module):
  def __init__(self):
    super().__init__()
    self.convolution_stack = nn.Sequential(
      # convolutional part
      nn.Conv2d(1, 32, kernel_size=3, stride=3),
      nn.ReLU(),
      nn.Dropout(0.1),
      nn.Conv2d(32, 64, kernel_size=3, stride=1),
      nn.ReLU(),
      nn.MaxPool2d(kernel_size=2)
    )
    self.dense_stack = nn.Sequential(
      # dense part
      nn.Linear(576, 128),
      nn.ReLU(),
      nn.Linear(128, 10)
    )
    self.output_function = nn.Softmax(dim=1)
  
  def forward(self, x):
    x = self.convolution_stack(x)
    x = x.view(x.size(0), -1)
    x = self.dense_stack(x)
    x = self.output_function(x)
    return x


In [50]:
import torch

device = (
  'cuda' if torch.cuda.is_available() else
  'mps' if torch.backends.mps.is_available() else 
  'cpu'
)


In [41]:
device = 'cpu'

In [51]:
model = ConvolutionNetwork().to(device)

In [52]:
from torch import optim
loss_function = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [53]:
# get first batch
images, _ = next(iter(train_loader))
images = images.to(device)
prediction = model(images)
prediction

  nonzero_finite_vals = torch.masked_select(


tensor([[0.0906, 0.0954, 0.1017, 0.0964, 0.0962, 0.1106, 0.1017, 0.1029, 0.0957,
         0.1088],
        [0.0926, 0.0961, 0.1035, 0.0950, 0.0960, 0.1090, 0.1014, 0.1040, 0.0933,
         0.1091],
        [0.0923, 0.0966, 0.1026, 0.0938, 0.0961, 0.1114, 0.1023, 0.1051, 0.0925,
         0.1071],
        [0.0918, 0.0971, 0.1026, 0.0965, 0.0946, 0.1104, 0.1026, 0.1032, 0.0927,
         0.1085],
        [0.0920, 0.0965, 0.1038, 0.0956, 0.0956, 0.1086, 0.0998, 0.1038, 0.0957,
         0.1085],
        [0.0922, 0.0976, 0.1027, 0.0965, 0.0949, 0.1121, 0.1008, 0.1013, 0.0937,
         0.1082],
        [0.0916, 0.0980, 0.1018, 0.0962, 0.0956, 0.1124, 0.1008, 0.1020, 0.0923,
         0.1094],
        [0.0930, 0.0973, 0.1012, 0.0955, 0.0931, 0.1107, 0.1010, 0.1018, 0.0961,
         0.1103],
        [0.0906, 0.0988, 0.1030, 0.0960, 0.0935, 0.1112, 0.1017, 0.1025, 0.0941,
         0.1086],
        [0.0928, 0.0973, 0.1030, 0.0950, 0.0946, 0.1084, 0.1019, 0.1040, 0.0944,
         0.1086],
        [0

In [54]:
epochs = 20
for current_epoch in range(epochs):
  for images, labels in train_loader:
    images, labels = images.to(device), labels.to(device)
    optimizer.zero_grad()
    predictions = model(images)
    loss = loss_function(predictions, labels)
    loss.backward()
    optimizer.step()
  print(f'Epoch {current_epoch}: last loss = {loss.item()}')

Epoch 0: last loss = 1.6557761430740356
Epoch 1: last loss = 1.6053564548492432
Epoch 2: last loss = 1.6447163820266724
Epoch 3: last loss = 1.4918988943099976
Epoch 4: last loss = 1.4653573036193848
Epoch 5: last loss = 1.4612696170806885
Epoch 6: last loss = 1.5002129077911377
Epoch 7: last loss = 1.4611557722091675
Epoch 8: last loss = 1.4611525535583496
Epoch 9: last loss = 1.461168646812439
Epoch 10: last loss = 1.4965662956237793
Epoch 11: last loss = 1.4611517190933228
Epoch 12: last loss = 1.4632140398025513
Epoch 13: last loss = 1.4615941047668457
Epoch 14: last loss = 1.4918816089630127
Epoch 15: last loss = 1.469423770904541
Epoch 16: last loss = 1.4920732975006104
Epoch 17: last loss = 1.461150050163269
Epoch 18: last loss = 1.4749293327331543
Epoch 19: last loss = 1.4611501693725586


In [55]:
correct, total = 0, 0
model.eval()
with torch.no_grad():
  for images, labels in test_loader:
    # print(images.shape)
    images, labels = images.to(device), labels.to(device)
    predictions = model(images)
    # print(predictions.shape)
    _, predicted = torch.max(predictions.data, 1)
    total += labels.size(0)
    correct += (predicted == labels).sum().item()
accuracy = correct / total
print(f'Accuracy: {accuracy: .2%}')

Accuracy:  98.87%


In [56]:
# save model
postfix = str(round(accuracy * 1e4))
path = f'./models/model_{postfix}.pth'
torch.save(model, path)
state = model.state_dict()
torch.save(state, f'{path}.state')
