In [38]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from matplotlib import pyplot as plt
from torchmetrics.functional.classification import accuracy
from torchmetrics.aggregation import MeanMetric
from src.model import MNIST

In [39]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cuda')
print(f"Device = {device}")

Device = cuda


In [31]:
train_data = datasets.MNIST(root="./data/mnist/",
                            train=True,
                            download=True,
                            transform=transforms.ToTensor()
                            )

test_data = datasets.MNIST(root="./data/mnist/",
                           train=False,
                           download=True,
                           transform=transforms.ToTensor()
                           )

print(len(train_data), len(test_data))

60000 10000


In [32]:
batch_size = 50
learning_rate = 0.0001
epoch_num = 20

In [33]:
train_loader = torch.utils.data.DataLoader(dataset=train_data, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_data, batch_size=batch_size, shuffle=True)

In [34]:
model = MNIST.Net().to(device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
loss_fn = nn.CrossEntropyLoss()
metric_fn = accuracy

In [35]:
model.train()

loss_mean = MeanMetric()
metric_mean = MeanMetric()

for epoch in range(epoch_num):
    for data, target in train_loader:
        data = data.to(device)
        target = target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = loss_fn(output, target)
        metric = metric_fn(output, target)
        loss.backward()
        optimizer.step()
        
        loss_mean.update(loss.to('cpu'))
        metric_mean.update(metric.to('cpu'))
    print(f"Epoch: {epoch+1}, Accuracy: {metric_mean.compute():.4f}, Loss: {loss_mean.compute():.4f}")
        


Epoch: 1, Accuracy: 0.7572, Loss: 0.8556
Epoch: 2, Accuracy: 0.8341, Loss: 0.5749
Epoch: 3, Accuracy: 0.8683, Loss: 0.4523
Epoch: 4, Accuracy: 0.8892, Loss: 0.3789
Epoch: 5, Accuracy: 0.9033, Loss: 0.3294
Epoch: 6, Accuracy: 0.9137, Loss: 0.2934
Epoch: 7, Accuracy: 0.9217, Loss: 0.2659
Epoch: 8, Accuracy: 0.9280, Loss: 0.2442
Epoch: 9, Accuracy: 0.9331, Loss: 0.2264
Epoch: 10, Accuracy: 0.9374, Loss: 0.2117
Epoch: 11, Accuracy: 0.9411, Loss: 0.1992
Epoch: 12, Accuracy: 0.9442, Loss: 0.1883
Epoch: 13, Accuracy: 0.9470, Loss: 0.1789
Epoch: 14, Accuracy: 0.9495, Loss: 0.1705
Epoch: 15, Accuracy: 0.9517, Loss: 0.1630
Epoch: 16, Accuracy: 0.9536, Loss: 0.1563
Epoch: 17, Accuracy: 0.9555, Loss: 0.1502
Epoch: 18, Accuracy: 0.9571, Loss: 0.1446
Epoch: 19, Accuracy: 0.9586, Loss: 0.1395
Epoch: 20, Accuracy: 0.9600, Loss: 0.1348


In [36]:
model.eval()

correct = 0
for data, target in test_loader:
    data = data.to(device)
    target = target.to(device)
    output = model(data)
    prediction = output.data.max(1)[1]
    correct += prediction.eq(target.data).sum()
print('Test set Accuracy : {:.2f}%'.format(100. * correct / len(test_loader.dataset)))

Test set Accuracy : 98.55%


In [37]:
state_dict = {
    'model': model.state_dict(),
    'optimizer': optimizer.state_dict(),
}
checkpoint_path = './checkpoints/mnist_net.pth'

torch.save(state_dict, checkpoint_path)