In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from matplotlib import pyplot as plt
from torchmetrics.functional.classification import accuracy
from torchmetrics.aggregation import MeanMetric
from src.model import MNIST

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cuda')
print(f"Device = {device}")

Device = cuda


In [3]:
train_data = datasets.MNIST(root="./data/mnist/",
                            train=True,
                            download=True,
                            transform=transforms.ToTensor()
                            )

test_data = datasets.MNIST(root="./data/mnist/",
                           train=False,
                           download=True,
                           transform=transforms.ToTensor()
                           )

print(len(train_data), len(test_data))

60000 10000


In [4]:
batch_size = 50
learning_rate = 0.0001
epoch_num = 20

In [5]:
train_loader = torch.utils.data.DataLoader(dataset=train_data, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_data, batch_size=batch_size, shuffle=True)

In [6]:
model = MNIST.CNet().to(device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
loss_fn = nn.CrossEntropyLoss()
metric_fn = accuracy

In [7]:
model.train()

loss_mean = MeanMetric()
metric_mean = MeanMetric()

for epoch in range(epoch_num):
    for data, target in train_loader:
        data = data.to(device)
        target = target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = loss_fn(output, target)
        metric = metric_fn(output, target)
        loss.backward()
        optimizer.step()
        
        loss_mean.update(loss.to('cpu'))
        metric_mean.update(metric.to('cpu'))
    print(f"Epoch: {epoch+1}, Accuracy: {metric_mean.compute():.4f}, Loss: {loss_mean.compute():.4f}")
        
    
        

Epoch: 1, Accuracy: 0.7851, Loss: 0.8042
Epoch: 2, Accuracy: 0.8435, Loss: 0.5650
Epoch: 3, Accuracy: 0.8703, Loss: 0.4611
Epoch: 4, Accuracy: 0.8874, Loss: 0.3963
Epoch: 5, Accuracy: 0.9000, Loss: 0.3498
Epoch: 6, Accuracy: 0.9097, Loss: 0.3142
Epoch: 7, Accuracy: 0.9175, Loss: 0.2861
Epoch: 8, Accuracy: 0.9239, Loss: 0.2630
Epoch: 9, Accuracy: 0.9293, Loss: 0.2437
Epoch: 10, Accuracy: 0.9339, Loss: 0.2274
Epoch: 11, Accuracy: 0.9378, Loss: 0.2133
Epoch: 12, Accuracy: 0.9413, Loss: 0.2011
Epoch: 13, Accuracy: 0.9444, Loss: 0.1904
Epoch: 14, Accuracy: 0.9470, Loss: 0.1809
Epoch: 15, Accuracy: 0.9495, Loss: 0.1725
Epoch: 16, Accuracy: 0.9516, Loss: 0.1649
Epoch: 17, Accuracy: 0.9536, Loss: 0.1580
Epoch: 18, Accuracy: 0.9554, Loss: 0.1518
Epoch: 19, Accuracy: 0.9571, Loss: 0.1461
Epoch: 20, Accuracy: 0.9586, Loss: 0.1408


In [8]:
model.eval()

correct = 0
for data, target in test_loader:
    data = data.to(device)
    target = target.to(device)
    output = model(data)
    prediction = output.data.max(1)[1]
    correct += prediction.eq(target.data).sum()
print('Test set Accuracy : {:.2f}%'.format(100. * correct / len(test_loader.dataset)))

Test set Accuracy : 98.56%


In [9]:
state_dict = {
    'model': model.state_dict(),
    'optimizer': optimizer.state_dict(),
}
checkpoint_path = './checkpoints/mnist_cnet.pth'

torch.save(state_dict, checkpoint_path)