In [None]:
# Code from https://clay-atlas.com/us/blog/2021/04/22/pytorch-en-tutorial-4-train-a-model-to-classify-mnist/

In [4]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as dset
from torchvision import datasets, transforms

In [5]:
# GPU
device = "cuda:0" if torch.cuda.is_available() else "cpu"
print("GPU State:", device)

GPU State: cpu


In [6]:
# Transform
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ]
)

In [7]:
# Data
trainSet = datasets.MNIST(root="MNIST", download=True, train=True, transform=transform)
testSet = datasets.MNIST(root="MNIST", download=True, train=False, transform=transform)
trainLoader = dset.DataLoader(trainSet, batch_size=64, shuffle=True)
testLoader = dset.DataLoader(testSet, batch_size=64, shuffle=False)

In [8]:
# Model
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.main = nn.Sequential(
            nn.Linear(in_features=784, out_features=128),
            nn.ReLU(),
            nn.Linear(in_features=128, out_features=64),
            nn.ReLU(),
            nn.Linear(in_features=64, out_features=10),
            nn.LogSoftmax(dim=1),
        )

    def forward(self, input):
        return self.main(input)

In [9]:
net = Net().to(device)
print(net)

Net(
  (main): Sequential(
    (0): Linear(in_features=784, out_features=128, bias=True)
    (1): ReLU()
    (2): Linear(in_features=128, out_features=64, bias=True)
    (3): ReLU()
    (4): Linear(in_features=64, out_features=10, bias=True)
    (5): LogSoftmax(dim=1)
  )
)


In [10]:
# Parameters
epochs = 3
lr = 0.002
criterion = nn.NLLLoss()
optimizer = optim.SGD(net.parameters(), lr=0.002, momentum=0.9)

In [11]:
# Train
for epoch in range(epochs):
    running_loss = 0.0

    for times, data in enumerate(trainLoader):
        inputs, labels = data[0].to(device), data[1].to(device)
        inputs = inputs.view(inputs.shape[0], -1)

        # Zero the parameter gradients
        optimizer.zero_grad()

        # Foward + backward + optimize
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Print statistics
        running_loss += loss.item()
        if times % 100 == 99 or times + 1 == len(trainLoader):
            print(
                "[%d/%d, %d/%d] loss: %.3f"
                % (epoch + 1, epochs, times + 1, len(trainLoader), running_loss / 2000)
            )

print("Training Finished.")

[1/3, 100/938] loss: 0.109
[1/3, 200/938] loss: 0.186
[1/3, 300/938] loss: 0.227
[1/3, 400/938] loss: 0.256
[1/3, 500/938] loss: 0.281
[1/3, 600/938] loss: 0.302
[1/3, 700/938] loss: 0.324
[1/3, 800/938] loss: 0.343
[1/3, 900/938] loss: 0.361
[1/3, 938/938] loss: 0.367
[2/3, 100/938] loss: 0.017
[2/3, 200/938] loss: 0.035
[2/3, 300/938] loss: 0.052
[2/3, 400/938] loss: 0.068
[2/3, 500/938] loss: 0.084
[2/3, 600/938] loss: 0.099
[2/3, 700/938] loss: 0.115
[2/3, 800/938] loss: 0.129
[2/3, 900/938] loss: 0.144
[2/3, 938/938] loss: 0.150
[3/3, 100/938] loss: 0.015
[3/3, 200/938] loss: 0.029
[3/3, 300/938] loss: 0.043
[3/3, 400/938] loss: 0.055
[3/3, 500/938] loss: 0.069
[3/3, 600/938] loss: 0.082
[3/3, 700/938] loss: 0.095
[3/3, 800/938] loss: 0.108
[3/3, 900/938] loss: 0.121
[3/3, 938/938] loss: 0.126
Training Finished.


In [12]:
# Test
correct = 0
total = 0

with torch.no_grad():
    for data in testLoader:
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)
        inputs = inputs.view(inputs.shape[0], -1)

        outputs = net(inputs)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(
    "Accuracy of the network on the 10000 test images: %d %%" % (100 * correct / total)
)

class_correct = [0 for i in range(10)]
class_total = [0 for i in range(10)]

with torch.no_grad():
    for data in testLoader:
        inputs, labels = data[0].to(device), data[1].to(device)
        inputs = inputs.view(inputs.shape[0], -1)

        outputs = net(inputs)
        _, predicted = torch.max(outputs, 1)
        c = (predicted == labels).squeeze()
        for i in range(10):
            label = labels[i]
            class_correct[label] += c[i].item()
            class_total[label] += 1
            # print(class_correct)
            # print(class_total)

for i in range(10):
    print("Accuracy of %d: %3f" % (i, (class_correct[i] / class_total[i])))

Accuracy of the network on the 10000 test images: 91 %
Accuracy of 0: 0.973684
Accuracy of 1: 0.978378
Accuracy of 2: 0.930233
Accuracy of 3: 0.750000
Accuracy of 4: 0.937853
Accuracy of 5: 0.912698
Accuracy of 6: 0.914062
Accuracy of 7: 0.914634
Accuracy of 8: 0.951049
Accuracy of 9: 0.910180
