In [22]:
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor

tf = ToTensor()
train_dataset = MNIST(root='./data', train=True, transform=tf, download=True)
test_dataset = MNIST(root='./data', train=False, transform=tf)


In [30]:
64 * 7 * 7

3136

In [58]:
from torch import nn

class DigitRecognitionCNN(nn.Module):
  def __init__(self):
    super().__init__()
    self.extraction_base = nn.Sequential(
      nn.Conv2d(1, 32, kernel_size=3, padding=1),
      nn.ReLU(),
      nn.MaxPool2d(kernel_size=2, stride=2),
      nn.Conv2d(32, 64, kernel_size=3, padding=1),
      nn.ReLU(),
      nn.MaxPool2d(kernel_size=2, stride=2),
    )
    self.flatten = nn.Flatten()
    self.classification_head = nn.Sequential(
      nn.Linear(64 * 7 * 7, 128),
      nn.ReLU(),
      nn.Dropout(0.1),
      nn.Linear(128, 10),
    )
    self.softmax = nn.Softmax(dim=1)
  
  def forward(self, x):
    x = self.extraction_base(x)
    x = self.flatten(x)
    x = self.classification_head(x)
    return self.softmax(x)

In [59]:
import torch

device = (
  'cuda' if torch.cuda.is_available()
  else 'mps' if torch.backends.mps.is_available()
  else 'cpu'
)

In [60]:
model = DigitRecognitionCNN().to(device)

In [61]:
from torch.utils.data import DataLoader

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

In [62]:
from torch import optim
loss_function = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)


In [63]:
# get first batch
images, _ = next(iter(train_loader))
images = images.to(device)
prediction = model(images)
prediction

tensor([[0.1039, 0.1007, 0.0989, 0.0953, 0.0932, 0.0916, 0.1082, 0.1066, 0.0915,
         0.1101],
        [0.1047, 0.1021, 0.1011, 0.0933, 0.0917, 0.0917, 0.1055, 0.1080, 0.0910,
         0.1111],
        [0.1047, 0.1048, 0.1011, 0.0916, 0.0882, 0.0919, 0.1059, 0.1076, 0.0923,
         0.1118],
        [0.1047, 0.1021, 0.0998, 0.0923, 0.0907, 0.0912, 0.1048, 0.1124, 0.0932,
         0.1088],
        [0.1046, 0.1035, 0.0997, 0.0939, 0.0899, 0.0898, 0.1061, 0.1083, 0.0926,
         0.1117],
        [0.1027, 0.1021, 0.1013, 0.0939, 0.0915, 0.0910, 0.1064, 0.1089, 0.0920,
         0.1103],
        [0.1044, 0.1031, 0.1005, 0.0938, 0.0909, 0.0912, 0.1057, 0.1081, 0.0928,
         0.1096],
        [0.1027, 0.1033, 0.1022, 0.0941, 0.0896, 0.0919, 0.1044, 0.1090, 0.0930,
         0.1099],
        [0.1039, 0.1032, 0.0995, 0.0927, 0.0895, 0.0912, 0.1052, 0.1119, 0.0923,
         0.1106],
        [0.1045, 0.1032, 0.0996, 0.0926, 0.0902, 0.0922, 0.1032, 0.1112, 0.0927,
         0.1107],
        [0

In [64]:
epochs = 20
total_loss, total_tries = 0, 0
for current_epoch in range(epochs):
  for images, labels in train_loader:
    images, labels = images.to(device), labels.to(device)
    optimizer.zero_grad()
    predictions = model(images)
    loss = loss_function(predictions, labels)
    loss.backward()
    optimizer.step()
    total_loss += loss.item()
    total_tries += 1
  average_loss = total_loss / total_tries
  print(f'Epoch {current_epoch}: average loss = {average_loss}')
  total_loss, total_tries = 0, 0


Epoch 0: average loss = 1.5697455500284831
Epoch 1: average loss = 1.4849194522857665
Epoch 2: average loss = 1.479354282951355
Epoch 3: average loss = 1.4768625860850015
Epoch 4: average loss = 1.4752865712483725
Epoch 5: average loss = 1.4741421333312987
Epoch 6: average loss = 1.4727361528396608
Epoch 7: average loss = 1.4721535385131836
Epoch 8: average loss = 1.4713026896794636
Epoch 9: average loss = 1.4709596195856731
Epoch 10: average loss = 1.4699616762161254
Epoch 11: average loss = 1.47053470287323


KeyboardInterrupt: 

In [65]:
# evaluate model on a test set
correct, total = 0, 0
model.eval()
with torch.no_grad():
  for images, labels in test_loader:
    # print(images.shape)
    images, labels = images.to(device), labels.to(device)
    predictions = model(images)
    # print(predictions.shape)
    _, predicted = torch.max(predictions.data, 1)
    total += labels.size(0)
    correct += (predicted == labels).sum().item()
accuracy = correct / total
print(f'Accuracy: {accuracy: .2%}')

Accuracy:  98.85%


In [66]:
# save model
postfix = str(round(accuracy * 1e4))
path = f'./models/model_{postfix}.pth'
torch.save(model, path)
state = model.state_dict()
torch.save(state, f'{path}.state')
