# LeNet

This notebook trains the LeNet-5 convolutional neural network on the MNIST handwritten-digit dataset.

Implementations in both PyTorch and JAX are provided under the respective subsections.

#### Imports

In [38]:
import random

import numpy as np

# Single seed shared by every RNG in the notebook so Restart & Run All is reproducible.
SEED = 12

np.random.seed(SEED)  # NumPy's legacy global RNG
random.seed(SEED)     # Python standard-library RNG

### PyTorch

In [27]:
import torch
import torch.nn as nn
import torchvision
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from models.torch_lenet import TorchLeNet

# Seed torch's generators. manual_seed_all covers every visible GPU, whereas
# torch.cuda.manual_seed only seeds the current device.
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
# Request deterministic cuDNN kernels and disable the cuDNN autotuner, which may
# otherwise select different algorithms from run to run.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

#### Data

In [41]:
import data  # project-local module providing get_MNIST()
import importlib

# Reload so edits to data.py are picked up without restarting the kernel.
importlib.reload(data)

BATCH_SIZE = 64         # training mini-batch size
EVAL_BATCH_SIZE = 1000  # larger batches are fine for evaluation-only passes

# NOTE(review): assumes get_MNIST() returns (train, test, val) in this order,
# matching the unpacking below — confirm against data.py.
train_dataset, test_dataset, val_dataset = data.get_MNIST()

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
# Previously the test loader was built twice and the validation split was never wrapped.
val_loader = DataLoader(val_dataset, batch_size=EVAL_BATCH_SIZE, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=EVAL_BATCH_SIZE, shuffle=False)

#### Model

In [25]:
# Prefer the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# TODO: Add params for activation, reproduce original init
lenet5 = TorchLeNet()
lenet5 = lenet5.to(device)
lenet5.param_count()  # the captured output below shows it reports the parameter count
# Last expression: switches to eval mode and displays the module repr.
# The training cell switches back with .train() at the start of every epoch.
lenet5.eval()

Parameters:  61706


TorchLeNet(
  (conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (fc1): Linear(in_features=400, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=10, bias=True)
)

#### Training

In [28]:
# Hyperparams -> TODO: try to reproduce the original paper in this section
LEARNING_RATE = 1e-3
EPOCHS = 6

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(lenet5.parameters(), lr=LEARNING_RATE)

for epoch in range(EPOCHS):
    lenet5.train()  # ensure training mode (an earlier cell calls .eval())
    epoch_loss = 0.0  # sum of the per-batch losses over this epoch

    for x, y in train_loader:
        x, y = x.to(device), y.to(device)

        optimizer.zero_grad()
        y_pred = lenet5(x)
        loss = criterion(y_pred, y)
        loss.backward()
        optimizer.step()

        # .item() converts to a detached Python float. Adding the loss tensor
        # itself would chain autograd history across every batch in the epoch,
        # leaking memory, and the printout would show a tensor rather than a number.
        epoch_loss += loss.item()

    print(f'Loss at epoch {epoch}: {epoch_loss}')

Loss at epoch 0: 1030.6214599609375
Loss at epoch 1: 233.1431427001953
Loss at epoch 2: 135.980712890625
Loss at epoch 3: 101.42192840576172
Loss at epoch 4: 80.53797912597656
Loss at epoch 5: 68.14480590820312


#### Evaluation

### JAX