We implement LeNet-5 in the modern form it evolved into from LeCun's original network for recognizing handwritten zip codes. Our implementation follows the architecture shown in this [image](https://upload.wikimedia.org/wikipedia/commons/c/cc/Comparison_image_neural_networks.svg), modified slightly for a different input size (32 × 32 RGB images). We apply it to the CIFAR10 dataset, mainly to see how it stacks up against more modern architectures.

In [2]:
# Choose the best available accelerator: CUDA first, then Apple MPS, else CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

print(f"Using {device}")

if device == "cuda":
    # Report the name of the GPU behind the current CUDA device index.
    current_index = torch.cuda.current_device()
    print(torch.cuda.get_device_name(current_index))

Using cuda
Quadro T2000


In [38]:
class LeNet(nn.Module):
    """LeNet-5-style CNN with sigmoid activations and average pooling.

    The fully connected head expects 16 * 6 * 6 = 576 flattened features,
    which is what a 32 x 32 spatial input (e.g. CIFAR10) produces. Other
    input sizes will fail at ``layer0``.

    Args:
        in_channels: Number of channels in the input images (3 for RGB).
        num_classes: Number of output logits. Defaults to 10, the output
            size this class originally hard-coded.
    """

    def __init__(self, in_channels: int, num_classes: int = 10) -> None:
        super().__init__()

        # A single Sigmoid module is reused at every activation site.
        self.sigmoid = nn.Sigmoid()
        # (in_channels, 32, 32) -> (6, 32, 32); padding=2 preserves size.
        self.conv0 = nn.Conv2d(in_channels, 6, kernel_size=5, padding=2)
        # (6, 32, 32) -> (6, 16, 16)
        self.avgpool0 = nn.AvgPool2d(kernel_size=2, stride=2)
        # (6, 16, 16) -> (16, 12, 12); no padding, so 16 - 5 + 1 = 12.
        self.conv1 = nn.Conv2d(6, 16, kernel_size=5)
        # (16, 12, 12) -> (16, 6, 6)
        self.avgpool1 = nn.AvgPool2d(kernel_size=2, stride=2)
        # Fully connected head (see the linked architecture image).
        self.layer0 = nn.Linear(16 * 6 * 6, 120)
        self.layer1 = nn.Linear(120, 84)
        self.layer2 = nn.Linear(84, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return unnormalized class scores of shape (N, num_classes).

        Args:
            x: Image batch of shape (N, in_channels, 32, 32).
        """
        x = self.avgpool0(self.sigmoid(self.conv0(x)))
        x = self.avgpool1(self.sigmoid(self.conv1(x)))
        # Flatten all but the batch dimension; unlike reshape(N, -1) this
        # also works for an empty batch.
        x = torch.flatten(x, start_dim=1)
        x = self.sigmoid(self.layer0(x))
        x = self.sigmoid(self.layer1(x))
        return self.layer2(x)

In [39]:
transform_CIFAR10 = transforms.Compose([
    transforms.ToTensor(),
    # Mean and standard deviation for CIFAR10 dataset
    # Sourced from gist.github.com/weiaicunzai/e623931921efefd4c331622c344d8151
    transforms.Normalize(
        mean=[0.4914, 0.4822, 0.4465],
        std=[0.2470, 0.2435, 0.2616]
    )
])

train_data = datasets.CIFAR10(
    root="data",
    train=True,
    transform=transform_CIFAR10,
    download=True
)

test_data = datasets.CIFAR10(
    root="data",
    train=False,
    transform=transform_CIFAR10,
    download=True
)

BATCH_SIZE = 64
# Fraction of the official training split held out for validation.
VALID_SIZE = 0.1

# Shuffle the training indices once, then hold out the last VALID_SIZE
# fraction for validation; the rest is used for training.
train_indices = list(range(len(train_data)))
np.random.shuffle(train_indices)
num_valid = int(len(train_data) * VALID_SIZE)
valid_split = len(train_data) - num_valid
valid_indices = train_indices[valid_split:]
train_indices = train_indices[:valid_split]
valid_data = Subset(train_data, valid_indices)
train_data = Subset(train_data, train_indices)

print(f"Training data: {len(train_data)}")
print(f"Validation data: {len(valid_data)}")
print(f"Test data: {len(test_data)}")

# Only the training loader reshuffles each epoch; evaluation order is fixed.
train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
valid_loader = DataLoader(valid_data, batch_size=BATCH_SIZE)
test_loader = DataLoader(test_data, batch_size=BATCH_SIZE)

Files already downloaded and verified
Files already downloaded and verified
Training data: 45000
Validation data: 5000
Test data: 10000


In [40]:
# Training hyperparameters.
# NOTE(review): CLASSES is not used when constructing the model below.
CLASSES = 10
EPOCHS = 20
LR = 0.1
WEIGHT_DECAY = 0.001
MOMENTUM = 0.9

# LeNet for 3-channel input, wrapped in DataParallel and moved to `device`.
# nn.Module.to moves the parameters in place and returns the same module.
model = nn.DataParallel(LeNet(3)).to(device)

# Cross-entropy loss and SGD with momentum and L2 weight decay.
metric = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(
    model.parameters(),
    lr=LR,
    momentum=MOMENTUM,
    weight_decay=WEIGHT_DECAY,
)

In [41]:
def train(
    loader: DataLoader,
    model: nn.Module,
    metric: nn.Module,
    optimizer: torch.optim.Optimizer
) -> None:
    """Run one training epoch over ``loader``.

    Each batch is moved to the module-level ``device`` and passed through
    ``model``. It is scored with ``metric`` against its targets, and one
    ``optimizer`` step is taken. Every 100 batches, the function prints:
    the current batch loss, the number of samples processed so far, and
    the learning rate of the first parameter group.

    Args:
        loader: Yields ``(inputs, targets)`` batches.
        model: Network to train; switched to training mode.
        metric: Loss function called as ``metric(predictions, targets)``.
        optimizer: Optimizer over ``model``'s parameters.
    """
    total = len(loader.dataset)
    model.train()
    # Running count of samples seen this epoch; stays correct even if a
    # reporting batch is the smaller final batch.
    progress = 0

    for batch, (x, y) in enumerate(loader):
        x = x.to(device)
        y = y.to(device)
        progress += len(x)

        pred = model(x)
        loss = metric(pred, y)

        if batch % 100 == 99:
            print(f"\tLoss: {loss.item():>7f} [{progress:>5d} / {total:>5d}]")
            print(f"\t\tLearning rate: {optimizer.param_groups[0]['lr']:>8f}")

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

In [42]:
def test(
    loader: DataLoader,
    model: nn.Module,
    metric: nn.Module,
) -> None:
    """Evaluate ``model`` on ``loader`` and print accuracy and average loss.

    Batches are moved to the module-level ``device``, and gradients are
    disabled. The printed loss is the per-sample average over the whole
    dataset.

    Args:
        loader: Yields ``(inputs, targets)`` batches.
        model: Network to evaluate; switched to eval mode.
        metric: Loss function called as ``metric(predictions, targets)``.
            Assumed to use mean reduction (the ``nn.CrossEntropyLoss``
            default).

    Raises:
        ValueError: If the loader's dataset is empty.
    """
    total = len(loader.dataset)
    if total == 0:
        raise ValueError("cannot evaluate on an empty dataset")

    total_loss = 0.0
    total_correct = 0
    model.eval()

    with torch.no_grad():
        for x, y in loader:
            x = x.to(device)
            y = y.to(device)

            pred = model(x)
            # Weight each batch's mean loss by its size, so a smaller final
            # batch does not skew the dataset-wide average.
            total_loss += metric(pred, y).item() * len(y)
            total_correct += (pred.argmax(1) == y).sum().item()

    total_loss /= total
    accuracy = total_correct / total
    print(f"\tAccuracy: {(100 * accuracy):>0.1f}%")
    print(f"\tAverage loss: {total_loss:>8f}")

In [43]:
# Alternate one training epoch with a validation pass, EPOCHS times.
for t in range(EPOCHS):
    print("Epoch:", t + 1)
    train(train_loader, model, metric, optimizer)
    test(valid_loader, model, metric)

# Save the learned weights. NOTE(review): when `model` is an nn.DataParallel
# wrapper, the state_dict keys carry a "module." prefix; load accordingly.
weights = model.state_dict()
torch.save(weights, "lenet5.pth")

Epoch: 1
	Loss: 2.326713 [ 6400 / 45000]
		Learning rate: 0.100000
	Loss: 2.283743 [12800 / 45000]
		Learning rate: 0.100000
	Loss: 2.315441 [19200 / 45000]
		Learning rate: 0.100000
	Loss: 2.287435 [25600 / 45000]
		Learning rate: 0.100000
	Loss: 2.318593 [32000 / 45000]
		Learning rate: 0.100000
	Loss: 2.312937 [38400 / 45000]
		Learning rate: 0.100000
	Loss: 2.315701 [44800 / 45000]
		Learning rate: 0.100000
	Accuracy: 9.4%
	Average loss: 2.303887
Epoch: 2
	Loss: 2.307292 [ 6400 / 45000]
		Learning rate: 0.100000
	Loss: 2.280637 [12800 / 45000]
		Learning rate: 0.100000
	Loss: 2.091264 [19200 / 45000]
		Learning rate: 0.100000
	Loss: 2.229997 [25600 / 45000]
		Learning rate: 0.100000
	Loss: 1.960535 [32000 / 45000]
		Learning rate: 0.100000
	Loss: 2.018706 [38400 / 45000]
		Learning rate: 0.100000
	Loss: 2.129301 [44800 / 45000]
		Learning rate: 0.100000
	Accuracy: 19.0%
	Average loss: 2.074731
Epoch: 3
	Loss: 1.974731 [ 6400 / 45000]
		Learning rate: 0.100000
	Loss: 2.030190 [12800

In [44]:
# Final evaluation on the held-out CIFAR10 test split.
test(loader=test_loader, model=model, metric=metric)

	Accuracy: 49.8%
	Average loss: 1.362571


Given how small the architecture is relative to the complexity of the dataset, and that neither the learning rate nor any other hyperparameter was tuned, this result is not bad.