In [6]:
import os
import numpy as np
from PIL import Image
from torch_implementation.mnist_dataset import MNISTDataset
from torch_implementation.mnist_digit_recognizer_neural_net import MNISTDigitRecognizerNeuralNet
from torch.utils.data import DataLoader

import torch
import torch.optim as optim

In [7]:
# Path handed to MNISTDataset when the training dataset is built
# (presumably the root folder of the training images; confirm against the dataset class).
imagePath = "./MNIST/training/"
# Checkpoint file: loaded at start-up when it already exists, and overwritten
# by every periodic save during training.
saved_model_path = "./pytorch_models/mnist_digit_recognizer_neural_net.pth"

# Step size passed to the SGD optimizer.
learning_rate = 0.001
# A checkpoint is written whenever the epoch number is a multiple of this value.
epochs_between_saves = 10
# Upper bound (inclusive) on the epoch counter of the training loop.
epochs_total = 400

In [8]:
# Wrap the data found under imagePath in the project's dataset class.
dataset = MNISTDataset(imagePath)

# Serve the dataset in reshuffled mini-batches of 32, loaded by two worker processes.
dataloader = DataLoader(
    dataset,
    batch_size=32,
    shuffle=True,
    num_workers=2,
)

### 🥾 Initialize the Neural Net

In [9]:
# Build the network and, when a checkpoint from an earlier run exists, resume from it.
net = MNISTDigitRecognizerNeuralNet()

if os.path.isfile(saved_model_path):
    # The checkpoint is written from a CUDA-resident model, so its tensors are
    # tagged with a GPU device. Mapping them to the CPU on load keeps this cell
    # working on machines without a GPU; the model is moved to its final device below.
    net.load_state_dict(torch.load(saved_model_path, map_location='cpu'))

# Prefer the GPU, but fall back to the CPU instead of raising when CUDA is unavailable.
net.to('cuda' if torch.cuda.is_available() else 'cpu')
loss_function = torch.nn.CrossEntropyLoss()
# The optimizer is created after the move so it references the parameters on their final device.
optimizer = optim.SGD(net.parameters(), lr=learning_rate)

In [10]:
# Train for `epochs_total` epochs, printing the mean batch loss of every epoch and
# checkpointing the weights every `epochs_between_saves` epochs and at the very end.

# Create the checkpoint directory up front so the first periodic save cannot
# fail after several epochs of training have already been spent.
checkpoint_dir = os.path.dirname(saved_model_path)
if checkpoint_dir:
    os.makedirs(checkpoint_dir, exist_ok=True)

# Send batches to whichever device the network's parameters live on instead of
# hard-coding 'cuda'.
device = next(net.parameters()).device

# Make sure layers with train/eval-dependent behaviour are in training mode.
net.train()

for epoch in range(1, epochs_total + 1):
    print(f'Training epoch {epoch}')

    running_loss = 0.0
    for inputs, labels in dataloader:
        inputs, labels = inputs.to(device), labels.to(device)

        # Gradients accumulate across backward() calls, so clear them per batch.
        optimizer.zero_grad()

        outputs = net(inputs)

        loss = loss_function(outputs, labels)
        loss.backward()

        optimizer.step()

        running_loss += loss.item()

    # Also save after the last epoch so the final weights are not lost when
    # epochs_total is not a multiple of epochs_between_saves.
    if epoch % epochs_between_saves == 0 or epoch == epochs_total:
        print('Saving model...')
        torch.save(net.state_dict(), saved_model_path)

    # Mean of the per-batch losses over this epoch.
    print (f'Loss: {running_loss / len(dataloader):.4f}')

Training epoch 1
Loss: 1.8370
Training epoch 2
Loss: 0.6846
Training epoch 3
Loss: 0.4573
Training epoch 4
Loss: 0.3967
Training epoch 5
Loss: 0.3673
Training epoch 6
Loss: 0.3494
Training epoch 7
Loss: 0.3368
Training epoch 8
Loss: 0.3272
Training epoch 9
Loss: 0.3196
Training epoch 10
Saving model...
Loss: 0.3134
Training epoch 11
Loss: 0.3082
Training epoch 12
Loss: 0.3033
Training epoch 13
Loss: 0.2992
Training epoch 14
Loss: 0.2954
Training epoch 15
Loss: 0.2918
Training epoch 16
Loss: 0.2887
Training epoch 17
Loss: 0.2855
Training epoch 18
Loss: 0.2826
Training epoch 19
Loss: 0.2799
Training epoch 20
Saving model...
Loss: 0.2770
Training epoch 21
Loss: 0.2746
Training epoch 22
Loss: 0.2722
Training epoch 23
Loss: 0.2696
Training epoch 24
Loss: 0.2672
Training epoch 25
Loss: 0.2648
Training epoch 26
Loss: 0.2624
Training epoch 27
Loss: 0.2601
Training epoch 28
Loss: 0.2579
Training epoch 29
Loss: 0.2555
Training epoch 30
Saving model...
Loss: 0.2532
Training epoch 31
Loss: 0.2508
