# LeNet

This notebook trains the LeNet-5 neural network on the MNIST database.

Implementations in both PyTorch and JAX are provided under the respective subsections.

#### Imports

In [76]:
import numpy as np
import random
import time

SEED = 12

random.seed(SEED)
np.random.seed(SEED)

In [None]:
# TEMP: Reload /data after updating
import data  # your module
import importlib

importlib.reload(data)

### PyTorch

In [27]:
import torch
import torch.nn as nn
import torchvision
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from models.torch_lenet import TorchLeNet

torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
torch.backends.cudnn.deterministic = True

#### Data

In [42]:
BATCH_SIZE = 64

train_dataset, test_dataset, val_dataset = data.get_MNIST()

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1000, shuffle=False)
val_loader = DataLoader(val_dataset, batch_size=1000, shuffle=False)

#### Model

In [43]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

lenet5 = TorchLeNet().to(device) # TODO: Add params for activations, reproduce original init
lenet5.param_count()
lenet5.eval()

Parameters:  61706


TorchLeNet(
  (conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (fc1): Linear(in_features=400, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=10, bias=True)
)

#### Training

In [None]:
# Hyperparams -> TODO: try to reproduce the original paper in this section
LEARNING_RATE = 1e-3
EPOCHS = 8

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(lenet5.parameters(), lr=LEARNING_RATE)

for epoch in range(EPOCHS):
    lenet5.train()
    start_time = time.time()
    epoch_loss = 0
    val_loss = 0

    for x,y in train_loader:
        x,y = x.to(device), y.to(device)

        optimizer.zero_grad()
        y_pred = lenet5(x)
        loss = criterion(y_pred, y)
        loss.backward()
        optimizer.step()

        epoch_loss += loss

    end_time = time.time()
    epoch_duration = end_time - start_time

    # Validation Loss
    lenet5.eval()
    with torch.no_grad():
        for x,y in val_loader:
            x,y = x.to(device), y.to(device)
            y_pred = lenet5(x)
            loss = criterion(y_pred, y)
            val_loss += loss

    print(f'Epoch {epoch} | {epoch_duration:.2f}s | Train Loss: {epoch_loss / len(train_loader):.4f}, Val Loss: {val_loss / len(val_loader):.4f}')

    #TODO: Save model

Epoch 0 | 27.28s | Train Loss: 0.025, Val Loss: 0.046
Epoch 1 | 26.42s | Train Loss: 0.022, Val Loss: 0.046
Epoch 2 | 27.14s | Train Loss: 0.020, Val Loss: 0.042
Epoch 3 | 27.28s | Train Loss: 0.019, Val Loss: 0.045
Epoch 4 | 27.12s | Train Loss: 0.019, Val Loss: 0.041
Epoch 5 | 26.62s | Train Loss: 0.015, Val Loss: 0.047
Epoch 6 | 26.82s | Train Loss: 0.014, Val Loss: 0.047
Epoch 7 | 26.14s | Train Loss: 0.014, Val Loss: 0.040


#### Evaluation

In [80]:
test_x = test_dataset.data.unsqueeze(1).float()
test_y = test_dataset.targets

lenet5.eval()
with torch.no_grad():
    test_x, test_y = test_x.to(device), test_y.to(device)
    logits = lenet5(test_x)
    test_preds = logits.argmax(axis=1)
    accuracy = (test_preds == test_y).sum() / len(test_y)

print(f'Test Set Accuracy: {100 * accuracy:.2f}%')

Test Set Accuracy: 97.65%


### JAX