In [35]:
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor

tf = ToTensor()
train_dataset = MNIST(root='./data', train=True, transform=tf, download=True)
test_dataset = MNIST(root='./data', train=False, transform=tf)


In [36]:
from torch.utils.data import DataLoader

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

In [37]:
from torch import nn

class ConvolutionNetwork(nn.Module):
  '''Network with convolutional layers'''
  def __init__(self):
    super().__init__()
    self.convolution_stack = nn.Sequential(
      # convolutional part
      nn.Conv2d(1, 32, kernel_size=3, stride=2),
      nn.ReLU(),
      nn.Conv2d(32, 64, kernel_size=3, stride=3),
      nn.ReLU(),
      nn.MaxPool2d(kernel_size=2)
    )
    self.dense_stack = nn.Sequential(
      # dense part
      nn.Linear(256, 128),
      nn.ReLU(),
      nn.Dropout(0.2),
      nn.Linear(128, 10)
    )
    self.output_function = nn.Softmax(dim=1)
    self.flatten = nn.Flatten()
  
  def forward(self, x):
    x = self.convolution_stack(x)
    x = self.flatten(x)
    x = self.dense_stack(x)
    x = self.output_function(x)
    return x


In [38]:
import torch

device = (
  'cuda' if torch.cuda.is_available() else
  'mps' if torch.backends.mps.is_available() else 
  'cpu'
)


In [31]:
device = 'cpu'

In [39]:
model = ConvolutionNetwork().to(device)

In [40]:
from torch import optim
loss_function = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [41]:
# get first batch
images, _ = next(iter(train_loader))
images = images.to(device)
prediction = model(images)
prediction

tensor([[0.1009, 0.1054, 0.1043, 0.1109, 0.0924, 0.0957, 0.0961, 0.0944, 0.1112,
         0.0887],
        [0.1033, 0.1083, 0.1034, 0.1129, 0.0910, 0.0996, 0.0920, 0.0923, 0.1121,
         0.0851],
        [0.1024, 0.1151, 0.0974, 0.1100, 0.0983, 0.0960, 0.0896, 0.0950, 0.1067,
         0.0896],
        [0.1067, 0.1085, 0.0992, 0.1107, 0.0930, 0.0962, 0.0917, 0.0955, 0.1122,
         0.0862],
        [0.1023, 0.1060, 0.1028, 0.1148, 0.0925, 0.0936, 0.0918, 0.0980, 0.1092,
         0.0889],
        [0.1029, 0.1072, 0.1004, 0.1080, 0.0977, 0.0967, 0.0900, 0.0953, 0.1165,
         0.0854],
        [0.1020, 0.1075, 0.0986, 0.1048, 0.0959, 0.0980, 0.0940, 0.0972, 0.1129,
         0.0890],
        [0.1037, 0.1037, 0.1012, 0.1139, 0.0955, 0.0949, 0.0965, 0.0941, 0.1075,
         0.0890],
        [0.1015, 0.1092, 0.1037, 0.1138, 0.0904, 0.1034, 0.0877, 0.0957, 0.1112,
         0.0835],
        [0.0990, 0.1105, 0.1000, 0.1088, 0.0985, 0.0956, 0.0915, 0.0973, 0.1114,
         0.0874],
        [0

In [43]:
epochs = 25
for current_epoch in range(epochs):
  for images, labels in train_loader:
    images, labels = images.to(device), labels.to(device)
    optimizer.zero_grad()
    predictions = model(images)
    loss = loss_function(predictions, labels)
    loss.backward()
    optimizer.step()
  print(f'Epoch {current_epoch}: last loss = {loss.item()}')

Epoch 0: last loss = 1.4987726211547852
Epoch 1: last loss = 1.4611507654190063
Epoch 2: last loss = 1.488769769668579
Epoch 3: last loss = 1.461150050163269
Epoch 4: last loss = 1.4611657857894897
Epoch 5: last loss = 1.492398738861084
Epoch 6: last loss = 1.4623394012451172
Epoch 7: last loss = 1.4611914157867432
Epoch 8: last loss = 1.4612189531326294
Epoch 9: last loss = 1.4915333986282349
Epoch 10: last loss = 1.4978288412094116
Epoch 11: last loss = 1.461153507232666
Epoch 12: last loss = 1.461151123046875
Epoch 13: last loss = 1.461547613143921
Epoch 14: last loss = 1.4611501693725586
Epoch 15: last loss = 1.4920281171798706
Epoch 16: last loss = 1.5167100429534912
Epoch 17: last loss = 1.461150884628296
Epoch 18: last loss = 1.4649856090545654
Epoch 19: last loss = 1.461155891418457
Epoch 20: last loss = 1.4910223484039307
Epoch 21: last loss = 1.4611501693725586
Epoch 22: last loss = 1.4615724086761475
Epoch 23: last loss = 1.4632105827331543
Epoch 24: last loss = 1.4611500501

In [44]:
correct, total = 0, 0
model.eval()
with torch.no_grad():
  for images, labels in test_loader:
    # print(images.shape)
    images, labels = images.to(device), labels.to(device)
    predictions = model(images)
    # print(predictions.shape)
    _, predicted = torch.max(predictions.data, 1)
    total += labels.size(0)
    correct += (predicted == labels).sum().item()
accuracy = correct / total
print(f'Accuracy: {accuracy: .2%}')

Accuracy:  98.20%


In [45]:
# save model
postfix = str(round(accuracy * 1e4))
path = f'./models/model_{postfix}.pth'
torch.save(model, path)
state = model.state_dict()
torch.save(state, f'{path}.state')
