In [1]:
import os
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision as tv

# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(device)

# os.cpu_count() returns None when the core count cannot be determined, and
# DataLoader needs an integer worker count; 0 means "load in the main process".
num_workers = os.cpu_count() or 0
print('cpu:', num_workers)

cuda:0
cpu: 6


In [2]:
# Directory handed to the dataset constructors as their storage root.
root = os.path.join('.', 'dataset_root')

# Mini-batch size shared by the data loaders.
batch_size = 128

# SGD optimizer settings.
lr = 0.01
momentum = 0.9

# StepLR schedule: the learning rate is multiplied by `gamma` every
# `step_size` scheduler steps.
step_size = 10
gamma = 0.1

# Number of full passes over the training data.
epochs = 20

In [3]:
# Per-channel normalization constants (single grayscale channel).
# NOTE(review): the original comment only says "tested value"; presumably these
# were measured on the MNIST pixels after ToTensor() scaling -- confirm.
mean, std = [0.13066046], [0.30150425]

# Convert each image to a tensor, then standardize it with the constants above.
transform = tv.transforms.Compose([
    tv.transforms.ToTensor(),
    tv.transforms.Normalize(mean, std),
])

# Training and test splits share the same preprocessing; files are downloaded
# into `root` on first use.
train_dataset = tv.datasets.MNIST(
    root,
    train=True,
    download=True,
    transform=transform)
test_dataset = tv.datasets.MNIST(
    root,
    train=False,
    download=True,
    transform=transform)

# Training batches are reshuffled every epoch.
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers
)
# Evaluation only aggregates loss/accuracy over the whole split, so sample
# order is irrelevant: keep it fixed instead of reshuffling every epoch.
test_loader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=num_workers
)

In [4]:
class Net(nn.Module):
    """Small convolutional classifier that returns log-probabilities over 10 classes.

    Layout: conv(1->32, 3x3) -> ReLU -> conv(32->64, 3x3) -> 2x2 max-pool ->
    dropout -> flatten -> linear(9216->128) -> ReLU -> dropout -> linear(128->10)
    -> log-softmax.

    ``fc1`` expects 9216 = 64 * 12 * 12 flattened features, which is what a
    single-channel 28x28 input produces after the two unpadded 3x3 convolutions
    (28 -> 26 -> 24) and the 2x2 pooling (24 -> 12).
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        # Channel-wise dropout on the 4-D feature maps coming out of the pool.
        self.dropout1 = nn.Dropout2d(0.25)
        # Plain dropout here: this layer sees the 2-D (batch, 128) output of
        # fc1, and feeding 2-D tensors to Dropout2d is deprecated in PyTorch.
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        """Map a batch of images to per-class log-probabilities.

        Args:
            x: tensor of shape (batch, 1, H, W); H = W = 28 matches the
                9216-feature input of ``fc1``.

        Returns:
            Tensor of shape (batch, 10) holding log-softmax scores.
        """
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        # NOTE(review): there is no activation between conv2 and the pooling
        # step; the max-pool is the only non-linearity here -- confirm this is
        # intentional before changing it (it would alter trained behaviour).
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)  # keep the batch dimension, flatten the rest
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

In [5]:
def train():
    """Run one optimization pass over ``train_loader``.

    Uses the notebook-level ``model``, ``optimizer``, ``train_loader`` and
    ``device``. The model is switched to training mode first.

    Returns:
        tuple: ``(mean_loss, accuracy)`` where both the sample-weighted loss
        sum and the number of correct predictions are divided by
        ``len(train_loader.dataset)``.
    """
    model.train()
    loss_sum = 0.
    hits = 0.

    for inputs, targets in train_loader:
        inputs = inputs.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = F.nll_loss(outputs, targets)
        loss.backward()
        optimizer.step()

        # Bookkeeping for the epoch report: nll_loss returns the batch mean,
        # so weight it by the number of samples in this batch.
        loss_sum += loss.item() * inputs.size(0)
        _, predicted = outputs.max(1)
        hits += (predicted == targets).sum().item()

    n_samples = len(train_loader.dataset)
    return loss_sum / n_samples, hits / n_samples

In [6]:
def validate():
    """Evaluate the model on ``test_loader`` without updating any weights.

    Uses the notebook-level ``model``, ``test_loader`` and ``device``. The
    model is switched to evaluation mode and gradients are disabled for the
    whole pass.

    Returns:
        tuple: ``(mean_loss, accuracy)`` where both the sample-weighted loss
        sum and the number of correct predictions are divided by
        ``len(test_loader.dataset)``.
    """
    model.eval()
    loss_sum = 0.
    hits = 0.

    with torch.no_grad():
        for inputs, targets in test_loader:
            inputs = inputs.to(device)
            targets = targets.to(device)

            outputs = model(inputs)
            # nll_loss returns the batch mean; weight it by the batch length.
            loss_sum += F.nll_loss(outputs, targets).item() * inputs.size(0)

            _, predicted = outputs.max(1)
            hits += (predicted == targets).sum().item()

    n_samples = len(test_loader.dataset)
    return loss_sum / n_samples, hits / n_samples

In [7]:
# Instantiate the network and move its parameters to the selected device.
model = Net()
model = model.to(device)

# SGD with momentum over all model parameters.
optimizer = optim.SGD(params=model.parameters(), lr=lr, momentum=momentum)

# Step decay: scale the learning rate by `gamma` every `step_size` scheduler steps.
lr_scheduler = optim.lr_scheduler.StepLR(optimizer=optimizer, step_size=step_size, gamma=gamma)

In [8]:
# Template for the per-epoch report; placeholders are filled in order with the
# epoch number, elapsed minutes and seconds, then train and test loss/accuracy.
epoch_report = ('[{:2d}] {:.0f}m {:.1f}s Train Loss: {:.4f} Accuracy: {:.4f}%, '
                'Test Loss: {:.4f} Accuracy: {:.4f}%')

since = time.time()
for epoch in range(epochs):
    start = time.time()

    # One training pass, one evaluation pass, then advance the LR schedule.
    tr_loss, tr_acc = train()
    va_loss, va_acc = validate()
    lr_scheduler.step()

    time_elapsed = time.time() - start
    minutes, seconds = divmod(time_elapsed, 60)
    print(epoch_report.format(
        epoch + 1, minutes, seconds,
        tr_loss, tr_acc * 100.,
        va_loss, va_acc * 100.))

# Wall-clock time for the whole run.
time_elapsed = time.time() - since
minutes, seconds = divmod(time_elapsed, 60)
print('Total run time: {:.0f}m {:.1f}s'.format(minutes, seconds))

[ 1] 0m 3.1s Train Loss: 0.3531 Accuracy: 89.0467%, Test Loss: 0.0932 Accuracy: 97.0000%
[ 2] 0m 2.9s Train Loss: 0.1262 Accuracy: 96.2433%, Test Loss: 0.0487 Accuracy: 98.3200%
[ 3] 0m 2.9s Train Loss: 0.0894 Accuracy: 97.3217%, Test Loss: 0.0433 Accuracy: 98.6300%
[ 4] 0m 2.9s Train Loss: 0.0719 Accuracy: 97.8500%, Test Loss: 0.0367 Accuracy: 98.7900%
[ 5] 0m 3.0s Train Loss: 0.0620 Accuracy: 98.0600%, Test Loss: 0.0377 Accuracy: 98.7900%
[ 6] 0m 3.0s Train Loss: 0.0546 Accuracy: 98.3133%, Test Loss: 0.0381 Accuracy: 98.7800%
[ 7] 0m 2.9s Train Loss: 0.0468 Accuracy: 98.5317%, Test Loss: 0.0303 Accuracy: 98.9800%
[ 8] 0m 2.9s Train Loss: 0.0443 Accuracy: 98.6183%, Test Loss: 0.0297 Accuracy: 99.0400%
[ 9] 0m 2.9s Train Loss: 0.0396 Accuracy: 98.7850%, Test Loss: 0.0294 Accuracy: 98.9900%
[10] 0m 3.1s Train Loss: 0.0404 Accuracy: 98.6783%, Test Loss: 0.0326 Accuracy: 98.9700%
[11] 0m 2.9s Train Loss: 0.0307 Accuracy: 99.0317%, Test Loss: 0.0284 Accuracy: 99.1400%
[12] 0m 3.0s Train Lo

In [9]:
# Write only the model's state dict (not the module object) to disk.
state = model.state_dict()
torch.save(state, 'mnist_cnn.pt')