# Classify Digits with the MNIST Dataset

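## Overview

This notebook trains a progression of MNIST digit classifiers:

- a linear baseline,
- fully connected networks with one and two hidden layers,
- LeNet-1, LeNet-4 and LeNet-5 convolutional networks.

It then tries momentum, a smaller batch size, and a wider ReLU network. For each model it reports the loss curves, test accuracy and parameter count, and exports the model to ONNX.
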
## Models

In [None]:
import torch
import torch.nn as nn
import tqdm as tqdm # fancy progress bars
import matplotlib.pyplot as plt
import torchvision # for loading MNIST dataset and visualizing images


In [None]:
import onnx

def save_onnx(model, filename, base_path="projects/blog/2-mnist/models/"):
    """Export ``model`` to an ONNX file at ``base_path/filename``.

    The model is traced with one random dummy input of shape ``(1, 1, 28, 28)``.
    The exported graph's input is named ``"images"`` and its output ``"classes"``.

    Args:
        model: The ``torch.nn.Module`` to export.
        filename: Output file name, e.g. ``"lenet-5.onnx"``.
        base_path: Directory to write into. It is created if missing.
    """
    import os

    if base_path:
        os.makedirs(base_path, exist_ok=True)

    # Trace on the model's own device: a CPU dummy tensor fed to a CUDA/MPS
    # model makes the export fail with a device mismatch.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")
    dummy_input = torch.randn(1, 1, 28, 28, device=device)

    torch.onnx.export(
        model,
        dummy_input,
        os.path.join(base_path, filename),
        verbose=True,
        input_names=["images"],
        output_names=["classes"],
        export_params=True,
    )


In [None]:
# Load both MNIST splits (downloaded into ./data if not already present).
# ToTensor converts each image into a tensor.
train_dataset, test_dataset = (
    torchvision.datasets.MNIST(
        root='./data',
        train=is_train_split,
        download=True,
        transform=torchvision.transforms.ToTensor(),
    )
    for is_train_split in (True, False)
)

# Batch both splits by 256; only the training data is shuffled.
train_loader, test_loader = (
    torch.utils.data.DataLoader(split, batch_size=256, shuffle=is_train_split)
    for split, is_train_split in ((train_dataset, True), (test_dataset, False))
)


In [None]:
# Report the number of samples in each split.
for size_label, split_data in (
    ("Train dataset size: ", train_dataset),
    ("Test dataset size: ", test_dataset),
):
    print(size_label, len(split_data))


In [None]:
# Save the first training image as a PNG and print its tensor shape.
import os

image, label = train_dataset[0]
print("Image shape: ", image.shape)

sample_path = "projects/blog/2-mnist/sample_image.png"
# Make sure the output directory exists before writing the file.
os.makedirs(os.path.dirname(sample_path), exist_ok=True)
torchvision.utils.save_image(image, sample_path)


In [None]:
def train(
    model,
    train_loader,
    test_loader,
    num_epochs=10,
    lr=0.1,
    device=None,
):
    """Train ``model`` with plain SGD and cross-entropy loss, recording losses.

    Args:
        model: Network to train; its parameters are updated in place.
        train_loader: Iterable of ``(images, labels)`` batches used for training.
        test_loader: Iterable of ``(images, labels)`` batches evaluated after
            every epoch.
        num_epochs: Number of passes over ``train_loader``.
        lr: SGD learning rate.
        device: Device that batches are moved to. Defaults to the device of
            the model's first parameter (CPU if the model has no parameters).

    Returns:
        dict with keys:
            ``"train_losses"``: the loss of every optimizer step (one entry
            per step).
            ``"test_losses"``: the mean per-sample test loss after each epoch
            (one entry per epoch).
            ``"global_step"``: the total number of optimizer steps taken.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)

    # No placeholder entries. A fake 0 loss would shift every test-loss point
    # one epoch to the right when plotted and cannot be drawn on a log axis.
    stats = {"train_losses": [], "test_losses": [], "global_step": 0}

    with tqdm.trange(num_epochs, desc="Training", unit="epoch") as epochs:
        for _ in epochs:
            model.train()
            for images, labels in train_loader:
                images, labels = images.to(device), labels.to(device)

                loss = criterion(model(images), labels)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                stats["train_losses"].append(loss.item())
                stats["global_step"] += 1

                postfix = f"train loss: {stats['train_losses'][-1]:.6f}"
                if stats["test_losses"]:
                    postfix += f", test loss: {stats['test_losses'][-1]:.6f}"
                epochs.set_postfix_str(postfix)

            model.eval()
            summed_loss, sample_count = 0.0, 0
            with torch.no_grad():
                for images, labels in test_loader:
                    images, labels = images.to(device), labels.to(device)
                    batch_size = labels.size(0)
                    # CrossEntropyLoss returns the batch *mean*. Re-weight it by
                    # batch size so a smaller final batch does not skew the average.
                    summed_loss += criterion(model(images), labels).item() * batch_size
                    sample_count += batch_size

            stats["test_losses"].append(
                summed_loss / sample_count if sample_count else float("nan")
            )

    return stats

def test(model, test_loader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)

            outputs = model(images)

            _, predicted = torch.max(outputs, 1)

            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    return correct / total

def plot_stats(stats, steps_per_epoch=None):
    """Plot per-step training loss and per-epoch test loss on a log-scale axis.

    Args:
        stats: Dictionary returned by ``train``. It needs ``"train_losses"``
            (one entry per optimizer step) and ``"test_losses"`` (one entry
            per evaluation).
        steps_per_epoch: Optimizer steps per epoch. It places the i-th test
            loss at step ``(i + 1) * steps_per_epoch``. Defaults to the length
            of the module-level ``train_loader``, which must be the loader
            that produced ``stats``.
    """
    if steps_per_epoch is None:
        steps_per_epoch = len(train_loader)

    # Build exactly one x position per recorded test loss, so the two
    # sequences can never differ in length.
    test_steps = [
        steps_per_epoch * (epoch_index + 1)
        for epoch_index in range(len(stats["test_losses"]))
    ]

    plt.figure(figsize=(20, 4))
    plt.plot(stats["train_losses"], label="Train Loss")
    plt.plot(test_steps, stats["test_losses"], label="Test Loss", c="r")
    plt.title("Train and Test Loss")
    plt.xlabel("Iterations")
    plt.ylabel("Loss")
    plt.yscale("log")
    plt.legend()
    plt.grid(True)
    plt.show()


### Feed Forward Neural Network

#### Baseline Linear Classifier

In [None]:
# Pick the fastest available backend, in order of preference: CUDA, then
# Apple MPS, then the CPU.
device = torch.device(
    'cuda' if torch.cuda.is_available()
    else 'mps' if torch.backends.mps.is_available()
    else 'cpu'
)

print(f"Using device: {device}")


In [None]:
# Baseline: one linear layer mapping the 28*28 = 784 flattened pixels straight
# to 10 class scores. Trained with cross-entropy, this is softmax regression.
torch.manual_seed(0) # seed torch's RNG so weight init (and data shuffling) repeat across runs

# Define the model
model = nn.Sequential(
    nn.Flatten(), nn.Linear(28 * 28, 10)  # MNIST images are 28x28 and have 10 classes
).to(device)

# Train with the current `train` definition, then plot the loss curves.
stats = train(model, train_loader, test_loader)

plot_stats(stats)

# Report test accuracy, parameter count, and bytes held by parameter tensors.
print(f"Accuracy: {test(model, test_loader):.4f}")
print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
print(f"Model size: {sum(p.numel() * p.element_size() for p in model.parameters()):,} bytes")

# NOTE(review): "classifer" looks like a typo for "classifier". Left unchanged
# because other files may reference this exact filename.
save_onnx(model, "baseline-linear-classifer.onnx")


#### One-Hidden-Layer Fully Connected Multilayer NN

In [None]:
# Fully connected network with one hidden layer: 784 -> 256 (tanh) -> 10.
torch.manual_seed(0) # seed torch's RNG so weight init (and data shuffling) repeat across runs

# Define the model
model = nn.Sequential(
    nn.Flatten(),
    nn.Linear(28*28, 256),
    nn.Tanh(),
    nn.Linear(256, 10),
).to(device)

# Train with the current `train` definition, then plot the loss curves.
stats = train(model, train_loader, test_loader)

plot_stats(stats)

# Report test accuracy, parameter count, and bytes held by parameter tensors.
print(f"Accuracy: {test(model, test_loader):.4f}")
print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
print(f"Model size: {sum(p.numel() * p.element_size() for p in model.parameters()):,} bytes")

save_onnx(model, "one-hidden-layer.onnx")


#### Two-Hidden-Layer Fully Connected Multilayer NN

In [None]:
# Fully connected network with two hidden layers: 784 -> 256 (tanh) -> 128 (tanh) -> 10.
torch.manual_seed(0) # seed torch's RNG so weight init (and data shuffling) repeat across runs

# Define the model
model = nn.Sequential(
    nn.Flatten(),
    nn.Linear(28 * 28, 256),
    nn.Tanh(),
    nn.Linear(256, 128),
    nn.Tanh(),
    nn.Linear(128, 10),
).to(device)

# Train with the current `train` definition, then plot the loss curves.
stats = train(model, train_loader, test_loader)

plot_stats(stats)

# Report test accuracy, parameter count, and bytes held by parameter tensors.
print(f"Accuracy: {test(model, test_loader):.4f}")
print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
print(f"Model size: {sum(p.numel() * p.element_size() for p in model.parameters()):,} bytes")

save_onnx(model, "two-hidden-layer.onnx")


### Convolutional Neural Network

#### LeNet-1

In [None]:
# LeNet-1-style CNN. Shapes for a 1x28x28 input:
#   conv5 -> 4x24x24, pool -> 4x12x12, conv5 -> 12x8x8, pool -> 12x4x4
#   flatten -> 12*4*4 = 192 features -> 10 class scores.
torch.manual_seed(0) # seed torch's RNG so weight init (and data shuffling) repeat across runs

# Define the model
model = nn.Sequential(
    nn.Conv2d(1, 4, kernel_size=5),
    nn.Tanh(),
    nn.AvgPool2d(2),
    nn.Conv2d(4, 12, kernel_size=5),
    nn.Tanh(),
    nn.AvgPool2d(2),
    nn.Flatten(),
    nn.Linear(192, 10),
).to(device)

# Train with the current `train` definition, then plot the loss curves.
stats = train(model, train_loader, test_loader)

plot_stats(stats)

# Report test accuracy, parameter count, and bytes held by parameter tensors.
print(f"Accuracy: {test(model, test_loader):.4f}")
print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
print(f"Model size: {sum(p.numel() * p.element_size() for p in model.parameters()):,} bytes")

save_onnx(model, "lenet-1.onnx")


#### LeNet-4

In [None]:
# LeNet-4-style CNN. Shapes for a 1x28x28 input:
#   conv5 -> 4x24x24, pool -> 4x12x12, conv5 -> 16x8x8, pool -> 16x4x4
#   flatten -> 256 features -> 120 (tanh) -> 10 class scores.
torch.manual_seed(0) # seed torch's RNG so weight init (and data shuffling) repeat across runs

# Define the model
model = nn.Sequential(
    nn.Conv2d(1, 4, kernel_size=5),
    nn.Tanh(),
    nn.AvgPool2d(2),
    nn.Conv2d(4, 16, kernel_size=5),
    nn.Tanh(),
    nn.AvgPool2d(2),
    nn.Flatten(),
    nn.Linear(256, 120),
    nn.Tanh(),
    nn.Linear(120, 10),
).to(device)

# Train with the current `train` definition, then plot the loss curves.
stats = train(model, train_loader, test_loader)

plot_stats(stats)

# Report test accuracy, parameter count, and bytes held by parameter tensors.
print(f"Accuracy: {test(model, test_loader):.4f}")
print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
print(f"Model size: {sum(p.numel() * p.element_size() for p in model.parameters()):,} bytes")

save_onnx(model, "lenet-4.onnx")


#### LeNet-5

In [None]:
# LeNet-5-style CNN. Shapes for a 1x28x28 input:
#   conv5 -> 6x24x24, pool -> 6x12x12, conv5 -> 16x8x8, pool -> 16x4x4
#   flatten -> 256 features -> 120 (tanh) -> 84 (tanh) -> 10 class scores.
torch.manual_seed(0) # seed torch's RNG so weight init (and data shuffling) repeat across runs

# Define the model
model = nn.Sequential(
    nn.Conv2d(1, 6, kernel_size=5),
    nn.Tanh(),
    nn.AvgPool2d(2),
    nn.Conv2d(6, 16, kernel_size=5),
    nn.Tanh(),
    nn.AvgPool2d(2),
    nn.Flatten(),
    nn.Linear(256, 120),
    nn.Tanh(),
    nn.Linear(120, 84),
    nn.Tanh(),
    nn.Linear(84, 10),
).to(device)

# Train with the current `train` definition, then plot the loss curves.
stats = train(model, train_loader, test_loader)

plot_stats(stats)

# Report test accuracy, parameter count, and bytes held by parameter tensors.
print(f"Accuracy: {test(model, test_loader):.4f}")
print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
print(f"Model size: {sum(p.numel() * p.element_size() for p in model.parameters()):,} bytes")

save_onnx(model, "lenet-5.onnx")


##### Optimized Training

In [None]:
def train(
    model,
    train_loader,
    test_loader,
    num_epochs=10,
    lr=0.1,
    momentum=0.9,
    device=None,
):
    """Train ``model`` with SGD plus momentum and cross-entropy loss, recording losses.

    Args:
        model: Network to train; its parameters are updated in place.
        train_loader: Iterable of ``(images, labels)`` batches used for training.
        test_loader: Iterable of ``(images, labels)`` batches evaluated after
            every epoch.
        num_epochs: Number of passes over ``train_loader``.
        lr: SGD learning rate.
        momentum: SGD momentum factor (0.9 by default).
        device: Device that batches are moved to. Defaults to the device of
            the model's first parameter (CPU if the model has no parameters).

    Returns:
        dict with keys:
            ``"train_losses"``: the loss of every optimizer step (one entry
            per step).
            ``"test_losses"``: the mean per-sample test loss after each epoch
            (one entry per epoch).
            ``"global_step"``: the total number of optimizer steps taken.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=momentum)

    stats = {"train_losses": [], "test_losses": [], "global_step": 0}

    with tqdm.trange(num_epochs, desc="epochs", unit="epoch") as epochs:
        for _ in epochs:
            model.train()
            for images, labels in train_loader:
                images, labels = images.to(device), labels.to(device)

                loss = criterion(model(images), labels)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                stats["train_losses"].append(loss.item())
                stats["global_step"] += 1

                epochs.set_postfix_str(f"loss: {loss.item():.4f}")

            model.eval()
            summed_loss, sample_count = 0.0, 0
            with torch.no_grad():
                for images, labels in test_loader:
                    images, labels = images.to(device), labels.to(device)
                    batch_size = labels.size(0)
                    # CrossEntropyLoss returns the batch *mean*. Re-weight it by
                    # batch size so a smaller final batch does not skew the average.
                    summed_loss += criterion(model(images), labels).item() * batch_size
                    sample_count += batch_size

            stats["test_losses"].append(
                summed_loss / sample_count if sample_count else float("nan")
            )

    return stats


In [None]:
# Same LeNet-5-style architecture, now trained with the most recently defined
# `train` (the SGD-with-momentum variant).
torch.manual_seed(0) # seed torch's RNG so weight init (and data shuffling) repeat across runs

# Define the model
model = nn.Sequential(
    nn.Conv2d(1, 6, kernel_size=5),
    nn.Tanh(),
    nn.AvgPool2d(2),
    nn.Conv2d(6, 16, kernel_size=5),
    nn.Tanh(),
    nn.AvgPool2d(2),
    nn.Flatten(),
    nn.Linear(256, 120),
    nn.Tanh(),
    nn.Linear(120, 84),
    nn.Tanh(),
    nn.Linear(84, 10),
).to(device)

stats = train(model, train_loader, test_loader)

plot_stats(stats)

# Report test accuracy, parameter count, and bytes held by parameter tensors.
print(f"Accuracy: {test(model, test_loader):.4f}")
print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
print(f"Model size: {sum(p.numel() * p.element_size() for p in model.parameters()):,} bytes")

save_onnx(model, "lenet-5-momentum.onnx")


##### Smaller Batch Size

In [None]:
# Rebuild the training loader with 32-sample batches, giving more optimizer
# steps per epoch.
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, shuffle=True, batch_size=32)
# The evaluation batch size doesn't change the computed accuracy, so keep it
# large for throughput.
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, shuffle=False, batch_size=256)


In [None]:
# LeNet-5-style architecture again, trained with momentum on the batch-size-32
# `train_loader` built just before this cell.
torch.manual_seed(0) # seed torch's RNG so weight init (and data shuffling) repeat across runs

# Define the model
model = nn.Sequential(
    nn.Conv2d(1, 6, kernel_size=5),
    nn.Tanh(),
    nn.AvgPool2d(2),
    nn.Conv2d(6, 16, kernel_size=5),
    nn.Tanh(),
    nn.AvgPool2d(2),
    nn.Flatten(),
    nn.Linear(256, 120),
    nn.Tanh(),
    nn.Linear(120, 84),
    nn.Tanh(),
    nn.Linear(84, 10),
).to(device)

stats = train(model, train_loader, test_loader)

plot_stats(stats)

# Report test accuracy, parameter count, and bytes held by parameter tensors.
print(f"Accuracy: {test(model, test_loader):.4f}")
print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
print(f"Model size: {sum(p.numel() * p.element_size() for p in model.parameters()):,} bytes")

save_onnx(model, "lenet-5-momentum-small-batch.onnx")


##### 2020s-Style Network: ReLU and Wider Layers

In [None]:
# Wider CNN with ReLU activations. Shapes for a 1x28x28 input:
#   conv5 -> 64x24x24, pool -> 64x12x12, conv5 -> 128x8x8, pool -> 128x4x4
#   flatten -> 128*4*4 = 2048 features -> 256 -> 120 -> 10 class scores.
torch.manual_seed(0) # seed torch's RNG so weight init (and data shuffling) repeat across runs

# Define the model
model = nn.Sequential(
    nn.Conv2d(1, 64, kernel_size=5),
    nn.ReLU(),
    nn.AvgPool2d(2),
    nn.Conv2d(64, 128, kernel_size=5),
    nn.ReLU(),
    nn.AvgPool2d(2),
    nn.Flatten(),
    nn.Linear(128*4*4, 256),
    nn.ReLU(),
    nn.Linear(256, 120),
    nn.ReLU(),
    nn.Linear(120, 10),
).to(device)

stats = train(model, train_loader, test_loader)

plot_stats(stats)

# Report test accuracy, parameter count, and bytes held by parameter tensors.
print(f"Accuracy: {test(model, test_loader):.4f}")
print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
print(f"Model size: {sum(p.numel() * p.element_size() for p in model.parameters()):,} bytes")

save_onnx(model, "lenet-5-momentum-small-batch-relu-wide.onnx")
