In [6]:
import torch
import torchvision
from torch import nn
from torch.utils.data import DataLoader
from torch import optim
from torchvision import datasets, transforms

In [3]:
# device configuration: use the first CUDA GPU when one is available, otherwise the CPU
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

In [9]:
# get MNIST data (downloaded into 'torch/data' on first run)
# ToTensor turns each image into a tensor the network can consume.
transform = transforms.ToTensor()

train_dataset = torchvision.datasets.MNIST(root='torch/data', train=True, download=True, transform=transform)
test_dataset = torchvision.datasets.MNIST(root='torch/data', train=False, download=True, transform=transform)

# Reshuffle the training set at every epoch so the mini-batches differ between
# epochs; with shuffle=False each epoch replays the same batches in the same order.
# The test loader keeps a fixed order because ordering does not affect accuracy.
train_loader = DataLoader(dataset=train_dataset, batch_size=100, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=100, shuffle=False)

## convolutional neural network architecture

In [17]:
class ConvNet(nn.Module):
    """Small convolutional classifier for 28x28 images.

    Architecture: two blocks of Conv2d(5x5, stride 1, padding 2) -> BatchNorm2d
    -> ReLU -> MaxPool2d(2x2, stride 2), followed by one fully-connected layer.
    The padded convolutions keep the spatial size and each pooling step halves
    it, so a 28x28 input becomes 14x14 after the first block and 7x7 after the
    second. The fully-connected layer is sized for exactly that 7x7x32 feature
    map, so inputs must be 28x28.

    Args:
        num_classes: number of output scores produced per image (default 10).
        in_channels: number of channels in the input images (default 1).
    """

    def __init__(self, num_classes=10, in_channels=1):
        super().__init__()

        # CONVOLUTIONAL LAYER
        # in: 28-x-28, `in_channels` channels (1 by default)
        # out: 14-x-14, 16 channels
        self.layer1 = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=16, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2))

        # CONVOLUTIONAL LAYER
        # in: 14-x-14, 16 channels
        # out: 7-x-7, 32 channels
        self.layer2 = nn.Sequential(
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2))

        # FULLY-CONNECTED LAYER
        # in: 7 * 7 * 32 flattened features
        # out: `num_classes` raw scores (no softmax is applied here)
        self.fc = nn.Linear(7*7*32, num_classes)

    def forward(self, x):
        """Return the raw class scores for a batch of images.

        Args:
            x: tensor of shape (batch, in_channels, 28, 28).

        Returns:
            Tensor of shape (batch, num_classes).
        """
        # feed-forward through the convolutional layers
        x = self.layer1(x)
        x = self.layer2(x)

        # flatten everything except the batch dimension for the fully-connected layer
        x = torch.flatten(x, 1)
        x = self.fc(x)

        return x

### instantiate the neural network

In [22]:
# build the network, then move its parameters onto the selected device
model = ConvNet()
model = model.to(device)

# cross-entropy as the training loss; Adam as the optimizer over all model parameters
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.001)

# MODEL TRAINING

In [23]:
# Put the network in training mode. This matters for this model: its BatchNorm2d
# layers normalise with per-batch statistics (and update their running estimates)
# only while in training mode.
model.train()

n_epochs = 2
total_step = len(train_loader)  # number of mini-batches in one epoch

for epoch in range(n_epochs):
    # count mini-batches from 1 so the counter can be printed directly
    for step, (images, labels) in enumerate(train_loader, start=1):
        images = images.to(device)
        labels = labels.to(device)

        # forward pass (calling the module runs its forward method), then the loss
        output = model(images)
        loss = criterion(output, labels)

        # clear the gradients left over from the previous step so they do not
        # accumulate, then back-propagate the current loss
        optimizer.zero_grad()
        loss.backward()

        # update the weights using the freshly computed gradients
        optimizer.step()

        # report the loss of the current mini-batch every 100 steps
        if step % 100 == 0:
            print('Epoch: [{}/{}], Step: [{}/{}], Loss: {:.4f}'.format(epoch+1, n_epochs, step, total_step, loss.item()))

Epoch: [1/2], Step: [100/600], Loss: 0.0903
Epoch: [1/2], Step: [200/600], Loss: 0.0731
Epoch: [1/2], Step: [300/600], Loss: 0.1183
Epoch: [1/2], Step: [400/600], Loss: 0.0625
Epoch: [1/2], Step: [500/600], Loss: 0.0942
Epoch: [1/2], Step: [600/600], Loss: 0.2129
Epoch: [2/2], Step: [100/600], Loss: 0.0204
Epoch: [2/2], Step: [200/600], Loss: 0.0236
Epoch: [2/2], Step: [300/600], Loss: 0.0613
Epoch: [2/2], Step: [400/600], Loss: 0.0325
Epoch: [2/2], Step: [500/600], Loss: 0.0469
Epoch: [2/2], Step: [600/600], Loss: 0.2090


# MODEL TESTING

In [25]:
# Switch to evaluation mode so the BatchNorm2d layers use their running
# statistics instead of per-batch statistics.
model.eval()

correct = 0
total = 0

# gradients are not needed for inference
with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)

        # forward-propagate test-set data
        output = model(images)

        # the predicted class is the index of the largest raw score (logit);
        # argmax avoids the legacy `.data` attribute and does not overwrite `_`,
        # which IPython uses for the last cell output
        predicted = output.argmax(dim=1)

        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print('Test accuracy of the model on 10,000 it never saw previously: {}%'.format(100 * correct / total))

# save the model's checkpoint (learned parameters only, not the class definition)
torch.save(model.state_dict(), 'torch/model.ckpt')

Test accuracy of the model on 10,000 it never saw previously: 98.52%
