<center><h1>LeNet</h1></center>

<center><p><a href="https://ieeexplore.ieee.org/abstract/document/726791">Gradient-based learning applied to document recognition</a> (Y. LeCun, L. Bottou, Y. Bengio, P. Haffner, Proceedings of the IEEE, 1998)</p></center>

<img src="https://production-media.paperswithcode.com/methods/LeNet_Original_Image_48T74Lc.jpg" width="800"/>

In [1]:
import torch
from torch import nn

# LeNet

In [2]:
class LeNet(nn.Module):
    """LeNet-5-style convolutional network for 32x32 images.

    Two conv -> sigmoid -> max-pool stages extract features, followed by
    three fully connected layers. The spatial sizes noted in the comments
    assume a 32x32 input; the first ``nn.Linear(16 * 5 * 5, 120)`` depends
    on the feature map being 16x5x5, so other input sizes will not fit.

    Args:
        num_classes: Number of output logits produced by the last layer.
        init_weights: If True, re-initialise every conv/linear layer with
            Xavier-uniform weights and zero biases.
        in_channels: Number of input image channels. Defaults to 1, the
            single-channel input of the original model.
    """

    def __init__(self, num_classes=10, init_weights=True, in_channels=1):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(in_channels, 6, kernel_size=5),  # 32 -> 28
            nn.Sigmoid(),
            nn.MaxPool2d(kernel_size=2, stride=2),  # 28 -> 14
            nn.Conv2d(6, 16, kernel_size=5),  # 14 -> 10
            nn.Sigmoid(),
            nn.MaxPool2d(kernel_size=2, stride=2),  # 10 -> 5
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(16 * 5 * 5, 120),
            nn.Sigmoid(),
            nn.Linear(120, 84),
            nn.Sigmoid(),
            nn.Linear(84, num_classes),
        )
        if init_weights:
            self._initialize_weights()

    def forward(self, x):
        """Return raw class logits (no softmax) for a batch of images."""
        x = self.features(x)
        return self.classifier(x)

    def _initialize_weights(self):
        """Apply Xavier-uniform weights and zero biases to conv/linear layers."""
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                nn.init.xavier_uniform_(m.weight)
                # Layers constructed with bias=False have ``m.bias`` set to
                # None; guard both layer types alike instead of crashing.
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

# Summary

## Data

In [3]:
# Use the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Dummy batch of 32 single-channel 32x32 inputs, moved to the chosen device.
data = torch.randn(32, 1, 32, 32).to(device)

## LeNet

In [4]:
# Build the model on the selected device, print a layer-by-layer summary
# for the dummy batch, then drop the reference.
from torchkeras import summary

net = LeNet(num_classes=10)
net = net.to(device)
summary(net, input_data=data)
del net

--------------------------------------------------------------------------
Layer (type)                            Output Shape              Param #
Conv2d-1                             [-1, 6, 28, 28]                  156
Sigmoid-2                            [-1, 6, 28, 28]                    0
MaxPool2d-3                          [-1, 6, 14, 14]                    0
Conv2d-4                            [-1, 16, 10, 10]                2,416
Sigmoid-5                           [-1, 16, 10, 10]                    0
MaxPool2d-6                           [-1, 16, 5, 5]                    0
Flatten-7                                  [-1, 400]                    0
Linear-8                                   [-1, 120]               48,120
Sigmoid-9                                  [-1, 120]                    0
Linear-10                                   [-1, 84]               10,164
Sigmoid-11                                  [-1, 84]                    0
Linear-12                            