In [2]:
import torch
import torch.nn.functional as F  # Parameterless functions, like (some) activation functions
import torchvision.datasets as datasets  # Standard datasets
import torchvision.transforms as transforms  # Transformations we can perform on our dataset for augmentation
from torch import optim  # For optimizers like SGD, Adam, etc.
from torch import nn  # All neural network modules
from torch.utils.data import (
    DataLoader,
)  # Gives easier dataset managment by creating mini batches etc.
from tqdm import tqdm  # For nice progress bar!
import matplotlib.pyplot as plt

# Architecture

In [134]:
class LeNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.relu = nn.ReLU()
        self.pool = nn.AvgPool2d(kernel_size=2, stride=2)
        self.conv1 = nn.Conv2d(
            in_channels=1,
            out_channels=6,
            kernel_size=5,
            stride=1,
            padding='valid',
        )
        self.conv2 = nn.Conv2d(
            in_channels=6,
            out_channels=16,
            kernel_size=5,
            stride=1,
            padding='valid',
        )
        self.linear1 = nn.Linear(256, 120)
        self.linear2 = nn.Linear(120, 84)
        self.linear3 = nn.Linear(84, 10)

    def forward(self, inputs):
        inputs = self.relu(self.conv1(inputs))
        inputs = self.pool(inputs)
        inputs = self.relu(self.conv2(inputs))
        inputs = self.pool(inputs)
        # inputs = self.relu(self.conv3(inputs))  # num_examples x 120 x 1 x 1 --> num_examples x 120
        inputs = inputs.reshape(inputs.shape[0], -1)
        inputs = self.relu(self.linear1(inputs))
        inputs = self.linear2(inputs)
        inputs = self.linear3(inputs)
        return inputs

# Test

In [150]:
batch_size = 32
learning_rate = 0.001
num_epochs = 5

In [152]:
train_dataset = datasets.MNIST(
    root="dataset/", train=True, transform=transforms.ToTensor(), download=True
)
test_dataset = datasets.MNIST(
    root="dataset/", train=False, transform=transforms.ToTensor(), download=True
)
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=True)

In [154]:
#Initial model
le_net = LeNet()

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(le_net.parameters(), lr=learning_rate)

In [None]:
# Train Network
for epoch in range(num_epochs):
    for batch_idx, (data, targets) in enumerate(tqdm(train_loader)):
        # Forward
        scores = le_net(data)
        loss = criterion(scores, targets)

        # Backward
        optimizer.zero_grad()
        loss.backward()

        # Gradient descent or adam step
        optimizer.step()

In [172]:
def check_accuracy(loader, model):
    """
    Check accuracy of our trained model given a loader and a model

    Parameters:
        loader: torch.utils.data.DataLoader
            A loader for the dataset you want to check accuracy on
        model: nn.Module
            The model you want to check accuracy on

    Returns:
        acc: float
            The accuracy of the model on the dataset given by the loader
    """

    num_correct = 0
    num_samples = 0
    model.eval()

    # We don't need to keep track of gradients here so we wrap it in torch.no_grad()
    with torch.no_grad():
        # Loop through the data
        for x, y in loader:
            # Forward pass 
            scores = model(x) #64 x 10 matrix 64 is the batch size, 10 is num_classes
            _, predictions = scores.max(1) #give us the maximum value of each row of the matrix and its indice, we can use scores.argmax(1) instead

            # Check how many we got correct
            num_correct += (predictions == y).sum()

            # Keep track of number of samples
            num_samples += predictions.size(0)

    model.train()
    return num_correct / num_samples


In [174]:
# Check accuracy on training & test to see how good our model
print(f"Accuracy on training set: {check_accuracy(train_loader, le_net)*100:.2f}")
print(f"Accuracy on test set: {check_accuracy(test_loader, le_net)*100:.2f}")

Accuracy on training set: 56.75
Accuracy on test set: 57.19
