In [17]:
import os

import torch
import torch.nn as nn
from torchvision import datasets, transforms

In [4]:
train_dataset = datasets.MNIST(
    "data/",
    download=True,
    train=True,
    transform=transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,)),]
    ),
)
test_dataset = datasets.MNIST(
    "data/",
    download=True,
    transform=transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,)),]
    ),
)

In [5]:
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=64, shuffle=True, num_workers=15
)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=True)

In [59]:
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 6, 5, padding=2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.mp = nn.MaxPool2d(2)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

        self.ReLU = nn.ReLU()
        self.LogSoftmax = nn.LogSoftmax(1)

    def forward(self, x):
        in_size = x.size(0)
        out = self.ReLU(self.mp(self.conv1(x)))
        out = self.ReLU(self.mp(self.conv2(out)))
        out = out.view(in_size, -1)
        out = self.ReLU(self.fc1(out))
        out = self.ReLU(self.fc2(out))
        out = self.LogSoftmax(self.fc3(out))
        return out

In [60]:
USE_CUDA = torch.cuda.is_available()

In [61]:
net = Net()
if USE_CUDA:
    net = net.cuda()

In [62]:
loss_fn = nn.NLLLoss()
optimizer = torch.optim.SGD(net.parameters(), lr=0.01, momentum=0.5)

In [63]:
PATH = "./cnn_net.pth"
if os.path.exists(PATH):
    net.load_state_dict(torch.load(PATH))

EPOCH = 1
err = 0
for epoch in range(EPOCH):
    for i, (data, target) in enumerate(train_loader):
        data = data.cuda()
        target = target.cuda()
        loss = loss_fn(net(data), target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        err += loss.item()
        if i % 200 == 199:
            print(f"[{epoch}, {i}] loss: {err/200}")
            err = 0
            torch.save(net.state_dict(), PATH)

torch.save(net.state_dict(), PATH)

[0, 199] loss: 0.0009491274016505713
[0, 399] loss: 0.0009376812739014895
[0, 599] loss: 0.0009789808118659947
[0, 799] loss: 0.002684974198150485


In [64]:
correct = 0
for data, target in test_loader:
    data = data.cuda()
    target = target.cuda()
    output = net(data)
    # get the index of the max log-probability
    pred = output.max(1, keepdim=True)[1]
    correct += pred.eq(target.view_as(pred)).cpu().sum()
print("{:.3f}%\n".format(100.0 * correct / len(test_loader.dataset)))

99.988%

