In [1]:
try:
  import google.colab
  IN_COLAB = True
except:
  IN_COLAB = False
if IN_COLAB:
    !pip3 install torch matplotlib torchmetrics scikit-image segmentation-models-pytorch

## Import

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as dset
from torchvision import datasets, transforms

## Set global device

In [4]:
# GPU
# device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

if torch.cuda.is_available():
    device = 'cuda:0'
elif torch.backends.mps.is_available():
    device = 'mps:0'
else:
    device = 'cpu'

print('GPU State:', device)

GPU State: mps:0


## Functions

In [5]:
def training_loop(model, loss, optimizer, loader, epochs, verbose=True, device=device):
    """
    Run training of CNN model given a loss function, optimizer, and set of training and validation data
    """

    # Train
    for epoch in range(epochs):
        # Setting initial loss
        running_loss = 0.0

        for times, data in enumerate(loader):
            # Retrieve inputs and labels
            inputs, labels = data[0].to(device), data[1].to(device)
            #inputs = inputs.view(inputs.shape[0], -1)
            #print(inputs)
        
            # Resetting optimizer gradients
            optimizer.zero_grad()

            # Forward + backward + steps
            outputs = model(inputs)
            loss_tensor = loss(outputs, labels)
            loss_tensor.backward()
            optimizer.step()

            # Print statistics
            running_loss += loss_tensor.item()
            if verbose:
                if times % 100 == 99 or times+1 == len(loader):
                    print('[%d/%d, %d/%d] loss: %.3f' % (epoch+1, epochs, times+1, len(loader), running_loss/2000))

In [6]:
def evaluate_model(model, loader, device):
    """
    Evaluate a model on all batches of a torch DataLoader

    Returns:
        Total number of correct classifications
        Total number of images
        List of the per class correct classification
        List of hte per class total number of images
    """

    # Test
    correct = 0
    total = 0

    with torch.no_grad():
        for data in loader:
            inputs, labels = data
            inputs, labels = inputs.to(device), labels.to(device)
            #inputs = inputs.view(inputs.shape[0], -1)

            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    class_correct = [0 for i in range(10)]
    class_total = [0 for i in range(10)]

    with torch.no_grad():
        for data in loader:
            inputs, labels = data[0].to(device), data[1].to(device)
            #inputs = inputs.view(inputs.shape[0], -1)

            outputs = model(inputs)
            _, predicted = torch.max(outputs, 1)
            c = (predicted == labels).squeeze()
            for i in range(10):
                label = labels[i]
                class_correct[label] += c[i].item()
                class_total[label] += 1 
    
    return (correct, total, class_correct, class_total)
            

## Main program

In [7]:
# Transform
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5,), (0.5,)),]
)

In [8]:
# Data
trainSet = datasets.MNIST(root='MNIST', download=True, train=True, transform=transform)
testSet = datasets.MNIST(root='MNIST', download=True, train=False, transform=transform)
trainLoader = dset.DataLoader(trainSet, batch_size=64, shuffle=True)
testLoader = dset.DataLoader(testSet, batch_size=64, shuffle=False)

In [9]:
# Model
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.main = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=(3, 3), stride=(1, 1)),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False),
            nn.Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1)),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False),
            nn.Flatten(start_dim=1, end_dim=-1), 
            nn.Linear(in_features=800, out_features=10, bias=True),
            nn.LogSoftmax(dim=1)
        )
    
    def forward(self, input):
        return self.main(input)
    
net = Net().to(device)
print(net)

Net(
  (main): Sequential(
    (0): Conv2d(1, 16, kernel_size=(3, 3), stride=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1))
    (4): ReLU()
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Flatten(start_dim=1, end_dim=-1)
    (7): Linear(in_features=800, out_features=10, bias=True)
    (8): LogSoftmax(dim=1)
  )
)


In [20]:
calculated_parameters = (3*3*1+1)*16 + (3*3*16+1)*32 + (800*10)+10
calculated_parameters

12810

In [21]:
print((3*3*1+1)*16 + (3*3*16+1)*32)

4800


In [19]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters())
params = count_parameters(net)
params

12810

In [10]:
# Checking parameters
[(name, param.shape) for name, param in net.named_parameters()]

[('main.0.weight', torch.Size([16, 1, 3, 3])),
 ('main.0.bias', torch.Size([16])),
 ('main.3.weight', torch.Size([32, 16, 3, 3])),
 ('main.3.bias', torch.Size([32])),
 ('main.7.weight', torch.Size([10, 800])),
 ('main.7.bias', torch.Size([10]))]

In [9]:
# Parameters
epochs = 4
lr = 0.002
loss = nn.NLLLoss()
optimizer = optim.SGD(net.parameters(), lr=0.002, momentum=0.9)

# Train
print('Training on %d images' % trainSet.data.shape[0])
training_loop(net, loss, optimizer, trainLoader, epochs)
print('Training Finished.\n')

# Test
correct, total, class_correct, class_total = evaluate_model(net, testLoader, device)
print('Accuracy of the network on the %d test images: %d %%' % (testSet.data.shape[0], (100*correct / total)))
for i in range(10):
    print('Accuracy of %d: %3f' % (i, (class_correct[i]/class_total[i])))

Training on 60000 images
[1/4, 100/938] loss: 0.098
[1/4, 200/938] loss: 0.132
[1/4, 300/938] loss: 0.151
[1/4, 400/938] loss: 0.166
[1/4, 500/938] loss: 0.179
[1/4, 600/938] loss: 0.190
[1/4, 700/938] loss: 0.200
[1/4, 800/938] loss: 0.210
[1/4, 900/938] loss: 0.219
[1/4, 938/938] loss: 0.222
[2/4, 100/938] loss: 0.008
[2/4, 200/938] loss: 0.015
[2/4, 300/938] loss: 0.022
[2/4, 400/938] loss: 0.029
[2/4, 500/938] loss: 0.035
[2/4, 600/938] loss: 0.040
[2/4, 700/938] loss: 0.046
[2/4, 800/938] loss: 0.051
[2/4, 900/938] loss: 0.057
[2/4, 938/938] loss: 0.059
[3/4, 100/938] loss: 0.005
[3/4, 200/938] loss: 0.010
[3/4, 300/938] loss: 0.015
[3/4, 400/938] loss: 0.020
[3/4, 500/938] loss: 0.025
[3/4, 600/938] loss: 0.029
[3/4, 700/938] loss: 0.034
[3/4, 800/938] loss: 0.038
[3/4, 900/938] loss: 0.042
[3/4, 938/938] loss: 0.044
[4/4, 100/938] loss: 0.004
[4/4, 200/938] loss: 0.008
[4/4, 300/938] loss: 0.012
[4/4, 400/938] loss: 0.015
[4/4, 500/938] loss: 0.018
[4/4, 600/938] loss: 0.023
[4/