# LeNet

This notebook trains the LeNet-5 neural network on the MNIST database.

Implementations in both PyTorch and JAX are provided under the respective subsections.

#### Imports

In [76]:
import numpy as np
import random
import time

SEED = 12

random.seed(SEED)
np.random.seed(SEED)

In [None]:
# TEMP: Reload /data after updating
import data  # your module
import importlib

importlib.reload(data)

### PyTorch

In [83]:
import torch
import torch.nn as nn
import torchvision
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from models.torch_lenet import TorchLeNet

torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
torch.backends.cudnn.deterministic = True

TORCH_PATH = './data/lenet5.pth' # Change filename for different models

#### Data

In [84]:
BATCH_SIZE = 64

train_dataset, test_dataset, val_dataset = data.get_MNIST()

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1000, shuffle=False)
val_loader = DataLoader(val_dataset, batch_size=1000, shuffle=False)

#### Model

In [85]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

lenet5 = TorchLeNet().to(device) # TODO: Add params for activations, reproduce original init
lenet5.param_count()
lenet5.eval()

Parameters:  61706


TorchLeNet(
  (conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (fc1): Linear(in_features=400, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=10, bias=True)
)

#### Training

In [87]:
# Hyperparams -> TODO: try to reproduce the original paper in this section
LEARNING_RATE = 1e-3
EPOCHS = 10

train_time = 0

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(lenet5.parameters(), lr=LEARNING_RATE)

for epoch in range(EPOCHS):
    lenet5.train()
    start_time = time.time()
    epoch_loss = 0
    val_loss = 0

    for x,y in train_loader:
        x,y = x.to(device), y.to(device)

        optimizer.zero_grad()
        y_pred = lenet5(x)
        loss = criterion(y_pred, y)
        loss.backward()
        optimizer.step()

        epoch_loss += loss

    end_time = time.time()
    epoch_duration = end_time - start_time
    train_time += epoch_duration

    # Validation Loss
    lenet5.eval()
    with torch.no_grad():
        for x,y in val_loader:
            x,y = x.to(device), y.to(device)
            y_pred = lenet5(x)
            loss = criterion(y_pred, y)
            val_loss += loss

    print(f'Epoch {epoch} | {epoch_duration:.2f}s | Train Loss: {epoch_loss / len(train_loader):.4f}, Val Loss: {val_loss / len(val_loader):.4f}')

print(f'Total training time: {train_time:.2f}s')

torch.save(lenet5.state_dict(), TORCH_PATH)
print(f'Saved model at {TORCH_PATH}.')

Epoch 0 | 25.87s | Train Loss: 0.0402, Val Loss: 0.0494
Epoch 1 | 26.65s | Train Loss: 0.0352, Val Loss: 0.0526
Epoch 2 | 26.02s | Train Loss: 0.0328, Val Loss: 0.0455
Epoch 3 | 25.80s | Train Loss: 0.0317, Val Loss: 0.0533
Epoch 4 | 25.96s | Train Loss: 0.0275, Val Loss: 0.0440
Epoch 5 | 26.23s | Train Loss: 0.0262, Val Loss: 0.0466
Epoch 6 | 26.67s | Train Loss: 0.0231, Val Loss: 0.0481
Epoch 7 | 26.04s | Train Loss: 0.0224, Val Loss: 0.0521
Epoch 8 | 26.34s | Train Loss: 0.0199, Val Loss: 0.0437
Epoch 9 | 27.30s | Train Loss: 0.0187, Val Loss: 0.0447
Total training time: 262.88s
Saved model at ./data/lenet5.pth.


#### Evaluation

In [88]:
# Load train model
lenet5 = TorchLeNet().to(device) # TODO: Add params for activations, reproduce original init
lenet5.load_state_dict(torch.load(TORCH_PATH))
lenet5.eval()

TorchLeNet(
  (conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (fc1): Linear(in_features=400, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=10, bias=True)
)

In [89]:
test_x = test_dataset.data.unsqueeze(1).float()
test_y = test_dataset.targets

lenet5.eval()
with torch.no_grad():
    test_x, test_y = test_x.to(device), test_y.to(device)
    logits = lenet5(test_x)
    test_preds = logits.argmax(axis=1)
    accuracy = (test_preds == test_y).sum() / len(test_y)

print(f'Test Set Accuracy: {100 * accuracy:.2f}%')

Test Set Accuracy: 97.27%


### JAX