In [57]:
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as transforms
# NOTE(review): duplicate of the line above; safe to remove.
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.datasets import mnist

from model import get_model

In [58]:
# Preprocessing applied to every MNIST sample: ToTensor converts the image to a
# float tensor scaled to [0, 1], then Normalize standardizes it with the widely
# used MNIST normalization constants (mean 0.1307, std 0.3081, one channel).
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize(mean=(0.1307,), std=(0.3081,))]
)

## Download the MNIST dataset

In [59]:
# Images per mini-batch, shared by the train and test DataLoaders.
batch_size = 256
# Both splits are stored under the same root directory.
data_root = "./dataset"
# `download=True` skips the download when the files are already present, so
# requesting it for the test split as well keeps that split from depending on
# the train split having triggered the download first.
train_dataset = mnist.MNIST(
    root=data_root, train=True, download=True, transform=transform
)
test_dataset = mnist.MNIST(
    root=data_root, train=False, download=True, transform=transform
)

## Wrap the training set in a DataLoader

In [60]:
# Shuffle the training set so every epoch visits the samples in a fresh order;
# replaying the exact same batch sequence each epoch can bias optimization.
# Without an explicit `generator`, the shuffle draws from torch's global RNG, so
# call torch.manual_seed(...) beforehand if run-to-run reproducibility matters.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

## Wrap the test set in a DataLoader

In [61]:
# Evaluation loader: same batch size as training, samples kept in stored order.
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

In [62]:
# `get_model` (project-local `model` module) returns a callable that builds the
# network — presumably the nn.Module subclass; confirm in model.py.
Net = get_model()
# Pick the best available backend instead of hard-coding "mps": moving a model
# to an unavailable backend raises, so fall back MPS -> CUDA -> CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
model = Net().to(device)

In [63]:
# Intended store for per-epoch training losses.
# NOTE(review): none of the cells shown here append to this list (`train` only
# prints its metrics) — confirm whether a later cell uses it.
train_losses = list()

In [64]:
def train(model, device, train_loader, optimizer, epoch):
    """Run one training epoch, print a summary line, and return its metrics.

    Args:
        model: network to optimize; put into training mode here. Its output is
            passed to ``F.nll_loss``, so it is assumed to emit per-class
            log-probabilities (e.g. a final ``log_softmax``) — TODO confirm in
            ``model.get_model``.
        device: device each batch is moved to (should match the model's).
        train_loader: iterable of ``(data, target)`` batches.
        optimizer: optimizer over ``model``'s parameters; stepped once per batch.
        epoch: epoch number, used only in the printed summary.

    Returns:
        ``(epoch_loss, epoch_accuracy)``: the per-sample mean training loss and
        the accuracy in percent. Both are ``nan`` if ``train_loader`` yields no
        batches. The accuracy is a running figure: each batch is scored with the
        weights as they were *before* that batch's optimizer step.
    """
    model.train()
    loss_sum = 0.0
    correct = 0
    total = 0
    for data, target in train_loader:
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()

        n_samples = target.size(0)
        # nll_loss returns the batch *mean*; weight it by the batch size so a
        # short final batch does not count as much as a full one.
        loss_sum += loss.item() * n_samples
        # Predicted class = index of the largest log-probability.
        correct += (output.argmax(dim=1) == target).sum().item()
        total += n_samples

    if total == 0:
        # Empty loader: report nan rather than dividing by zero.
        epoch_loss = epoch_accuracy = float("nan")
    else:
        epoch_loss = loss_sum / total
        epoch_accuracy = 100.0 * correct / total
    print(
        "Train Epoch: {} \tAverage Loss: {:.6f}\tAccuracy: {:.2f}%".format(
            epoch, epoch_loss, epoch_accuracy
        )
    )
    return epoch_loss, epoch_accuracy

In [65]:
# Adam optimizer; `weight_decay` adds an L2 term to the gradients (coupled
# decay, unlike AdamW). Referenced as `torch.optim` so this cell only needs the
# top-level `import torch`.
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-5)

In [66]:
# Train for a fixed number of epochs; `train` prints one summary line per epoch.
num_epochs = 40
for epoch in range(1, num_epochs + 1):
    train(model, device, train_loader, optimizer, epoch)

# Make sure the checkpoint directory exists so torch.save does not fail on a
# fresh checkout.
checkpoint_path = Path("model") / "mnist.pth"
checkpoint_path.parent.mkdir(parents=True, exist_ok=True)
# state_dict() builds a new dict on every call, so replacing its values with CPU
# copies leaves the model (still on `device`) untouched. Saving CPU tensors lets
# the checkpoint load on machines without the training backend (e.g. no MPS).
cpu_state_dict = model.state_dict()
for key in list(cpu_state_dict):
    cpu_state_dict[key] = cpu_state_dict[key].cpu()
torch.save(cpu_state_dict, checkpoint_path)

Train Epoch: 1 	Average Loss: 0.293752	Accuracy: 90.90%
Train Epoch: 2 	Average Loss: 0.095086	Accuracy: 97.23%
Train Epoch: 3 	Average Loss: 0.070231	Accuracy: 97.92%
Train Epoch: 4 	Average Loss: 0.056409	Accuracy: 98.28%
Train Epoch: 5 	Average Loss: 0.049109	Accuracy: 98.47%
Train Epoch: 6 	Average Loss: 0.041312	Accuracy: 98.71%
Train Epoch: 7 	Average Loss: 0.038316	Accuracy: 98.80%
Train Epoch: 8 	Average Loss: 0.033667	Accuracy: 98.90%
Train Epoch: 9 	Average Loss: 0.030488	Accuracy: 99.04%
Train Epoch: 10 	Average Loss: 0.027102	Accuracy: 99.11%
Train Epoch: 11 	Average Loss: 0.025441	Accuracy: 99.15%
Train Epoch: 12 	Average Loss: 0.024146	Accuracy: 99.22%
Train Epoch: 13 	Average Loss: 0.024450	Accuracy: 99.20%
Train Epoch: 14 	Average Loss: 0.022286	Accuracy: 99.27%
Train Epoch: 15 	Average Loss: 0.021227	Accuracy: 99.30%
Train Epoch: 16 	Average Loss: 0.019786	Accuracy: 99.38%
Train Epoch: 17 	Average Loss: 0.018641	Accuracy: 99.39%
Train Epoch: 18 	Average Loss: 0.017780	