In [7]:
# Dependencies
import torch
import torch.nn as nn  # For torch.nn.Module object, of which each model is a subclass
import torch.nn.functional as F  # For activation functions and other related utilities

In [25]:
class LeNet(nn.Module):
    r"""Implements the basic LeNet architecture.

    Two ``conv -> ReLU -> max-pool`` stages are followed by three fully
    connected layers. ``fc1`` is hard-wired to a ``16 x 5 x 5`` feature map,
    which is what a ``28 x 28`` input produces; only inputs whose spatial size
    reduces to ``5 x 5`` after both stages will fit.

    Parameters
    ----------
    in_channels: int, optional
        Number of channels of the input images. Default: ``1``.
    num_classes: int, optional
        Number of output logits. Default: ``10``.
    """

    #
    def __init__(self, in_channels=1, num_classes=10):
        r"""The initializer; builds all layers of the network.
        """
        super().__init__()
        # Define layers and other components
        self.conv1 = nn.Conv2d(
            in_channels=in_channels, out_channels=6, kernel_size=(3, 3)
        )  # [28, 28] -> [26, 26]
        self.maxpool1 = nn.MaxPool2d(kernel_size=(2, 2))  # [26, 26] -> [13, 13]
        self.conv2 = nn.Conv2d(
            in_channels=6, out_channels=16, kernel_size=(3, 3)
        )  # [13, 13] -> [11, 11]
        # [11, 11] -> [5, 5]: pooling floors, so the last row/column is dropped.
        self.maxpool2 = nn.MaxPool2d(kernel_size=(2, 2))
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, num_classes)

    #
    def forward(self, x):
        r"""Implements the forward pass of the model.

        Parameters
        ----------
        x:
            The input.
            SHAPE: [<batch>, <in_channels>, <height=28>, <width=28>].

        Returns
        -------
        logits (implicit):
            The logits for the input.
            SHAPE: [<batch>, <num_classes>].
        """
        # Conv and pool layers.
        y1 = self.maxpool1(F.relu(self.conv1(x)))
        y2 = self.maxpool2(F.relu(self.conv2(y1)))
        # Flatten everything but the batch dimension. Unlike `.view`,
        # `torch.flatten` also works when `y2` is not contiguous
        # (e.g. when the model runs in channels_last memory format).
        y3 = F.relu(self.fc1(torch.flatten(y2, start_dim=1)))
        y4 = F.relu(self.fc2(y3))
        return self.fc3(y4)

In [26]:
# Checking architecture.
# Seed first: weight initialisation is random, so without a seed every
# Restart & Run All produces a different model (and different logits below).
torch.manual_seed(0)
model = LeNet()
print(model)

LeNet(
  (conv1): Conv2d(1, 6, kernel_size=(3, 3), stride=(1, 1))
  (maxpool1): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(6, 16, kernel_size=(3, 3), stride=(1, 1))
  (maxpool2): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
  (fc1): Linear(in_features=400, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=10, bias=True)
)


In [29]:
# Smoke test: push a random batch of 28x28 single-channel images through the
# model and check that we get one row of 10 logits per sample.
batch_size = 16
ip = torch.randn(batch_size, 1, 28, 28)  # standard-normal input, [B, C, H, W]
print('input shape:', ip.shape)
op = model(ip)
print('output shape:', op.shape)
assert op.shape == (batch_size, 10), f"unexpected output shape {tuple(op.shape)}"

input shape: torch.Size([16, 1, 28, 28])
output shape: torch.Size([16, 10])
