# LeNet

This notebook trains the LeNet-5 neural network on the MNIST database.

PyTorch and JAX implementations are provided in their respective subsections below.

#### Imports

In [4]:
import numpy as np
import random

SEED = 12

random.seed(SEED)
np.random.seed(SEED)

### PyTorch

In [5]:
import torch
import torchvision
from models.torch_lenet import TorchLeNet

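torch.manual_seed(SEED)  # seeds the RNG on the CPU and all CUDA devices
torch.cuda.manual_seed(SEED)  # seeds the current GPU (redundant with the line above, kept for explicitness)
torch.backends.cudnn.deterministic = True  # restrict cuDNN to deterministic algorithms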

In [3]:
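lenet5 = TorchLeNet()
lenet5.param_count()  # helper defined in models.torch_lenet; prints the total parameter count
lenet5.eval()  # eval() returns the module, so the cell displays its architecture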

Parameters:  61706


TorchLeNet(
  (conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (fc1): Linear(in_features=400, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=10, bias=True)
)
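The printed architecture fixes both the `in_features=400` of `fc1` and the 61,706 parameters reported by `param_count()`. The sketch below recomputes both by hand from the layer hyperparameters shown in the repr. The helper names (`out_size`, `conv_params`, `linear_params`) are hypothetical and not part of `models.torch_lenet`. It assumes `param_count()` counts every weight and bias, and it uses the fact that pooling layers have no parameters.

```python
def out_size(n, kernel, stride=1, padding=0):
    # Spatial output size of a conv/pool layer (floor division matches ceil_mode=False)
    return (n + 2 * padding - kernel) // stride + 1

def conv_params(c_in, c_out, k):
    # One k x k kernel per (input, output) channel pair, plus one bias per output channel
    return c_in * c_out * k * k + c_out

def linear_params(n_in, n_out):
    # Weight matrix plus one bias per output feature
    return n_in * n_out + n_out

# Trace a single 28x28 MNIST image through the layers printed above
n = 28
n = out_size(n, kernel=5, padding=2)  # conv1: 28x28, 6 channels
n = out_size(n, kernel=2, stride=2)   # pool1: 14x14
n = out_size(n, kernel=5)             # conv2: 10x10, 16 channels
n = out_size(n, kernel=2, stride=2)   # pool2: 5x5
flat_features = 16 * n * n            # flattened input to fc1

layer_params = {
    "conv1": conv_params(1, 6, 5),
    "conv2": conv_params(6, 16, 5),
    "fc1": linear_params(flat_features, 120),
    "fc2": linear_params(120, 84),
    "fc3": linear_params(84, 10),
}
total = sum(layer_params.values())
print(flat_features, total)  # 400 61706
```

Most of the parameters sit in `fc1` (48,120 of 61,706), which is typical of LeNet-style networks where the convolutional layers are small.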

In [33]:
mnist_train = torchvision.datasets.MNIST('data', train=True, download=True)
mnist_test = torchvision.datasets.MNIST('data', train=False, download=True)

In [40]:
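train_images = mnist_train.data.numpy() / 255.0  # 28x28 grayscale images, scaled from [0, 255] to [0, 1]
train_labels = mnist_train.targets.numpy()  # digit labels (0-9)

test_images = mnist_test.data.numpy() / 255.0  # 28x28 grayscale images, scaled from [0, 255] to [0, 1]
test_labels = mnist_test.targets.numpy()  # digit labels (0-9)

print("Train images shape:", train_images.shape)
print("Train labels shape:", train_labels.shape)
print("Test images shape:", test_images.shape)
print("Test labels shape:", test_labels.shape)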

Train images shape: (60000, 28, 28)
Train labels shape: (60000,)
Test images shape: (10000, 28, 28)
Test labels shape: (10000,)
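The arrays above have shape `(N, 28, 28)` and, after division by `255.0`, dtype `float64`, whereas `conv1` (`Conv2d(1, 6, ...)`) expects a batch of shape `(N, 1, 28, 28)` in `float32`, the default dtype of PyTorch parameters. A minimal sketch of that remaining preprocessing and of shuffled minibatching follows; the helpers `to_nchw` and `iterate_minibatches` and the batch size of 64 are hypothetical choices, and random stand-in arrays keep the example self-contained.

```python
import numpy as np

def to_nchw(images):
    # (N, 28, 28) float64 -> (N, 1, 28, 28) float32: add a channel axis and cast
    return images[:, None, :, :].astype(np.float32)

def iterate_minibatches(images, labels, batch_size, rng):
    # Shuffle once per epoch, then yield consecutive slices; the last batch may be smaller
    order = rng.permutation(len(images))
    for start in range(0, len(images), batch_size):
        idx = order[start:start + batch_size]
        yield images[idx], labels[idx]

# Stand-in data with MNIST-like shapes (the notebook would pass train_images / train_labels)
rng = np.random.default_rng(12)
fake_images = rng.random((1000, 28, 28))
fake_labels = rng.integers(0, 10, size=1000)

x = to_nchw(fake_images)
batches = list(iterate_minibatches(x, fake_labels, batch_size=64, rng=rng))
print(x.shape, x.dtype, len(batches), batches[-1][0].shape)
```

Each batch could then be wrapped with `torch.from_numpy` before being fed to `lenet5`; `torch.utils.data.DataLoader` with `shuffle=True` is the built-in alternative to the hand-rolled generator.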


### JAX