# LeNet

This notebook trains the LeNet-5 neural network on the MNIST database.

Implementations in both PyTorch and JAX are provided under the respective subsections.

#### Imports

In [38]:
import numpy as np
import random

SEED = 12

random.seed(SEED)
np.random.seed(SEED)

### PyTorch

In [27]:
import torch
import torch.nn as nn
import torchvision
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from models.torch_lenet import TorchLeNet

torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
torch.backends.cudnn.deterministic = True

#### Data

In [42]:
import data  # your module
import importlib

importlib.reload(data)

BATCH_SIZE = 64

train_dataset, test_dataset, val_dataset = data.get_MNIST()

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1000, shuffle=False)
val_loader = DataLoader(val_dataset, batch_size=1000, shuffle=False)

#### Model

In [43]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

lenet5 = TorchLeNet().to(device) # TODO: Add params for activation, reproduce original init
lenet5.param_count()
lenet5.eval()

Parameters:  61706


TorchLeNet(
  (conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (fc1): Linear(in_features=400, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=10, bias=True)
)

#### Training

In [44]:
# Hyperparams -> TODO: try to reproduce the original paper in this section
LEARNING_RATE = 1e-3
EPOCHS = 8

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(lenet5.parameters(), lr=LEARNING_RATE)

#TODO: tqdm
for epoch in range(EPOCHS):
    lenet5.train()
    epoch_loss = 0
    val_loss = 0

    for x,y in train_loader:
        x,y = x.to(device), y.to(device)

        optimizer.zero_grad()
        y_pred = lenet5(x)
        loss = criterion(y_pred, y)
        loss.backward()
        optimizer.step()

        epoch_loss += loss

    # Validation Loss
    lenet5.eval()
    with torch.no_grad():
        for x,y in val_loader:
            x,y = x.to(device), y.to(device)
            y_pred = lenet5(x)
            loss = criterion(y_pred, y)
            val_loss += loss

    print(f'Epoch {epoch} | Train Loss: {epoch_loss / len(train_loader)}, Val Loss: {val_loss / len(val_loader)}')

Epoch 0 | Train Loss: 839.9629516601562, Val Loss: 1.9005024433135986
Epoch 1 | Train Loss: 192.3130340576172, Val Loss: 1.0523744821548462
Epoch 2 | Train Loss: 117.95606231689453, Val Loss: 0.9231815934181213
Epoch 3 | Train Loss: 87.24773406982422, Val Loss: 0.5728389024734497
Epoch 4 | Train Loss: 72.04277801513672, Val Loss: 0.47332829236984253
Epoch 5 | Train Loss: 61.30828857421875, Val Loss: 0.46507734060287476


#### Evaluation

### JAX