Simple PyTorch neural network for classifying handwritten digits using MNIST dataset

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR

In [2]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()

        # First 2D convolutional layer, taking in 1 input channel (image),
        # outputting 32 convolutional features, with a square kernel size of 3
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        # Second 2D convolutional layer, taking in the 32 input layers,
        # outputting 64 convolutional features, with a square kernel size of 3
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout2d(0.25)
        self.dropout2 = nn.Dropout2d(0.5)

        # First fully connected layer
        self.fc1 = nn.Linear(9216, 128)
        # Second fully connected layer that outputs our 10 labels
        self.fc2 = nn.Linear(128, 10)
        
    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)

        output = F.log_softmax(x, dim=1)
        return output

Parameters

In [3]:
batch_size = 64
test_batch_size = 1000
epochs = 14
lr = 1.0
gamma = 0.7
seed = 1
log_interval = 10
save_model = False

torch.manual_seed(seed)

device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)
print(f"Using {device} device")

Using cpu device


In [4]:
def train (model, device, train_loader, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % log_interval == 0:
            print(f"Train Epoch: {epoch} [{batch_idx*len(data)}/{len(train_loader.dataset)}] Loss: {loss.item()}")
            
def test (model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction="sum").item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
            
    test_loss /= len(test_loader.dataset)
    print(f"\nTest set: Average loss: {test_loss}, Accuracy: {correct}/{len(test_loader.dataset)} ({100. * correct / len(test_loader.dataset)})\n")

In [5]:
train_kwargs = {"batch_size": batch_size}
test_kwargs = {"batch_size": test_batch_size}

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)) # incredible magic numbers
])

dataset1 = datasets.MNIST("../data", train=True, download=True, transform=transform)
dataset2 = datasets.MNIST("../data", train=False, transform=transform)

train_loader = torch.utils.data.DataLoader(dataset1, shuffle=True, **train_kwargs)
test_loader = torch.utils.data.DataLoader(dataset2, shuffle=True, **test_kwargs)

model = Net().to(device)
optimizer = optim.Adadelta(model.parameters(), lr=lr)

scheduler = StepLR(optimizer, step_size=1, gamma=gamma)

for epoch in range(1, epochs + 1):
    train(model, device, train_loader, optimizer, epoch)
    test(model, device, test_loader)
    scheduler.step()
    
if save_model:
    torch.save(model.state_dict(), "mnist_cnn.pt")



Train Epoch: 1 [0/60000] Loss: 2.303724527359009
Train Epoch: 1 [640/60000] Loss: 1.550111174583435
Train Epoch: 1 [1280/60000] Loss: 1.0296837091445923
Train Epoch: 1 [1920/60000] Loss: 0.7578520178794861
Train Epoch: 1 [2560/60000] Loss: 0.6128470301628113
Train Epoch: 1 [3200/60000] Loss: 0.40133553743362427
Train Epoch: 1 [3840/60000] Loss: 0.3413877785205841
Train Epoch: 1 [4480/60000] Loss: 0.5638273358345032
Train Epoch: 1 [5120/60000] Loss: 0.23237666487693787
Train Epoch: 1 [5760/60000] Loss: 0.3067222237586975
Train Epoch: 1 [6400/60000] Loss: 0.2873188257217407
Train Epoch: 1 [7040/60000] Loss: 0.3278321921825409
Train Epoch: 1 [7680/60000] Loss: 0.33765214681625366
Train Epoch: 1 [8320/60000] Loss: 0.22571925818920135
Train Epoch: 1 [8960/60000] Loss: 0.2623680531978607
Train Epoch: 1 [9600/60000] Loss: 0.1240774467587471
Train Epoch: 1 [10240/60000] Loss: 0.20467790961265564
Train Epoch: 1 [10880/60000] Loss: 0.30298545956611633
Train Epoch: 1 [11520/60000] Loss: 0.2496369