In [None]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F

In [None]:
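## LeNet predates ReLU and max pooling, so this version uses sigmoid activations and average pooling.
## Replacing them with ReLU and max pooling (see BetterLeNet below) usually gives better results.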

class LeNet(nn.Module):
    """LeNet-style CNN with sigmoid activations and average pooling.

    Expects inputs of shape (batch, 1, 28, 28): ``fc1`` is sized for the
    16 x 5 x 5 feature map that a 28 x 28 image produces after the two
    conv/pool stages. Returns raw scores of shape (batch, 10); no softmax
    is applied.
    """

    def __init__(self):
        super().__init__()
        # Feature extractor; shape comments are per sample as (C, H, W).
        self.conv1 = nn.Conv2d(1, 6, 5, padding=2)  # (1, 28, 28) -> (6, 28, 28); padding keeps 28x28
        self.pool = nn.AvgPool2d(2, 2)              # (6, 28, 28) -> (6, 14, 14)
        self.conv2 = nn.Conv2d(6, 16, 5)            # (6, 14, 14) -> (16, 10, 10)
        self.pool2 = nn.AvgPool2d(2, 2)             # (16, 10, 10) -> (16, 5, 5)
        # Classifier head on the flattened 16 * 5 * 5 = 400 features.
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)
        # One parameter-free Sigmoid module, shared by every hidden layer.
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Map a (batch, 1, 28, 28) tensor to (batch, 10) scores."""
        conv_stages = ((self.conv1, self.pool), (self.conv2, self.pool2))
        for conv, pool in conv_stages:
            x = pool(self.sigmoid(conv(x)))
        x = x.flatten(start_dim=1)  # keep the batch dim, merge C, H, W
        for hidden in (self.fc1, self.fc2):
            x = self.sigmoid(hidden(x))
        return self.fc3(x)


net = LeNet()

In [None]:
!pip install -q torchinfo
from torchinfo import summary

In [None]:
# Layer-by-layer table for the sigmoid / average-pooling LeNet on one 1x28x28 input.
summary(
    net,
    input_size=(1, 1, 28, 28),  # (batch, channels, height, width)
    col_names=["input_size", "output_size", "num_params", "trainable"],
    col_width=20,
    row_settings=["var_names"],
)

Layer (type (var_name))                  Input Shape          Output Shape         Param #              Trainable
LeNet (LeNet)                            [1, 1, 28, 28]       [1, 10]              --                   True
├─Conv2d (conv1)                         [1, 1, 28, 28]       [1, 6, 28, 28]       156                  True
├─Sigmoid (sigmoid)                      [1, 6, 28, 28]       [1, 6, 28, 28]       --                   --
├─AvgPool2d (pool)                       [1, 6, 28, 28]       [1, 6, 14, 14]       --                   --
├─Conv2d (conv2)                         [1, 6, 14, 14]       [1, 16, 10, 10]      2,416                True
├─Sigmoid (sigmoid)                      [1, 16, 10, 10]      [1, 16, 10, 10]      --                   --
├─AvgPool2d (pool2)                      [1, 16, 10, 10]      [1, 16, 5, 5]        --                   --
├─Linear (fc1)                           [1, 400]             [1, 120]             48,120               True
├─Sigmoid (sigmoid)   

In [None]:
class BetterLeNet(nn.Module):
    """LeNet topology with ReLU activations and max pooling.

    Layer sizes match ``LeNet``: inputs of shape (batch, 1, 28, 28), outputs
    of shape (batch, 10) raw scores (no softmax). Activations come from
    ``F.relu`` and downsampling from ``nn.MaxPool2d``.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 6, 5, padding=2)  # (1, 28, 28) -> (6, 28, 28)
        # A single stateless max-pool module, reused after both conv layers.
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)            # (6, 14, 14) -> (16, 10, 10)
        # Classifier head on the flattened 16 * 5 * 5 = 400 features.
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        """Map a (batch, 1, 28, 28) tensor to (batch, 10) scores."""
        for conv in (self.conv1, self.conv2):
            x = self.pool(F.relu(conv(x)))
        x = x.flatten(start_dim=1)  # keep the batch dim, merge C, H, W
        for hidden in (self.fc1, self.fc2):
            x = F.relu(hidden(x))
        return self.fc3(x)

In [None]:
# Layer-by-layer table for the ReLU / max-pooling variant on one 1x28x28 input.
summary(
    BetterLeNet(),
    input_size=(1, 1, 28, 28),  # (batch, channels, height, width)
    col_names=["input_size", "output_size", "num_params", "trainable"],
    col_width=20,
    row_settings=["var_names"],
)

Layer (type (var_name))                  Input Shape          Output Shape         Param #              Trainable
BetterLeNet (BetterLeNet)                [1, 1, 28, 28]       [1, 10]              --                   True
├─Conv2d (conv1)                         [1, 1, 28, 28]       [1, 6, 28, 28]       156                  True
├─MaxPool2d (pool)                       [1, 6, 28, 28]       [1, 6, 14, 14]       --                   --
├─Conv2d (conv2)                         [1, 6, 14, 14]       [1, 16, 10, 10]      2,416                True
├─MaxPool2d (pool)                       [1, 16, 10, 10]      [1, 16, 5, 5]        --                   --
├─Linear (fc1)                           [1, 400]             [1, 120]             48,120               True
├─Linear (fc2)                           [1, 120]             [1, 84]              10,164               True
├─Linear (fc3)                           [1, 84]              [1, 10]              850                  True
Total params: 61,7