In [1]:
# python imports
import os
from tqdm import tqdm

# torch imports
import torch
import torch.nn as nn
import torch.optim as optim

# helper functions for computer vision
import torchvision
import torchvision.transforms as transforms

In [2]:
class LeNet(nn.Module):
    def __init__(self, input_shape=(32, 32), num_classes=100):
        super(LeNet, self).__init__()
        # certain definitions

        # step 1
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5, stride=1)
        self.relu1 = nn.ReLU()
        self.maxpool1 = nn.MaxPool2d(2, 2)

        # step 2
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5, stride=1)
        self.relu2 = nn.ReLU()
        self.maxpool2 = nn.MaxPool2d(2, 2)

        # step 3
        self.flatten = nn.Flatten()

        # step 4
        self.fc1 = nn.Linear(16 * 5 * 5, 256)
        self.relu3 = nn.ReLU()

        # step 5
        self.fc2 = nn.Linear(256, 128)
        self.relu4 = nn.ReLU()

        # step 6
        self.fc3 = nn.Linear(128, num_classes)

    def forward(self, x):
        shape_dict = {}

        # step 1
        out = self.conv1(x)
        out = self.relu1(out)
        out = self.maxpool1(out)
        shape_dict[1] = out.shape

        # step 2
        out = self.conv2(out)
        out = self.relu2(out)
        out = self.maxpool2(out)
        shape_dict[2] = out.shape

        # step 3
        out = self.flatten(out)
        shape_dict[3] = out.shape

        # step 4
        out = self.fc1(out)
        shape_dict[4] = out.shape

        # step 5
        out = self.fc2(out)
        shape_dict[5] = out.shape

        # step 6
        out = self.fc3(out)
        shape_dict[6] = out.shape

        return out, shape_dict

In [3]:
test = LeNet()

test.forward(torch.randn(5, 3, 32, 32))

(tensor([[-0.1837,  0.1295, -0.0351,  0.1377, -0.0972,  0.1317, -0.0026,  0.1051,
          -0.0641,  0.0285,  0.0568, -0.1232,  0.0380,  0.0971,  0.0638, -0.0405,
           0.0335,  0.0172,  0.2030, -0.0460,  0.0230, -0.0184,  0.0760,  0.0584,
          -0.1500,  0.1259, -0.0769,  0.0636, -0.0432, -0.0532,  0.2191, -0.0071,
           0.1280,  0.0541, -0.0425, -0.0135, -0.0776, -0.0126,  0.1713,  0.1308,
           0.0434, -0.0332, -0.0636, -0.0511, -0.0220, -0.0062, -0.1975,  0.0645,
          -0.0475, -0.1438,  0.0226,  0.0162,  0.1012, -0.0087,  0.0698, -0.0993,
          -0.1616, -0.0054, -0.0317, -0.0298, -0.1672,  0.1127,  0.0771,  0.1481,
          -0.2321, -0.0005, -0.0594, -0.1144,  0.1837, -0.1163,  0.1377, -0.0728,
           0.0778, -0.0370, -0.0991,  0.1671, -0.0113, -0.0900,  0.0388,  0.0042,
          -0.1313, -0.0122, -0.2741, -0.1449, -0.0367,  0.0769, -0.2415,  0.0454,
          -0.0074, -0.0296,  0.1728,  0.0551,  0.0681, -0.0243,  0.0491,  0.0987,
           0.047

In [6]:
def count_model_params():
    '''
    return the number of trainable parameters of LeNet.
    '''
    model = LeNet()
    model_params = 0.0

    for param in model.parameters():
        model_params += param.numel()
        

    return model_params

In [14]:
model.named_parameters

<bound method Module.named_parameters of LeNet(
  (conv1): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
  (relu1): ReLU()
  (maxpool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (relu2): ReLU()
  (maxpool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (fc1): Linear(in_features=400, out_features=256, bias=True)
  (relu3): ReLU()
  (fc2): Linear(in_features=256, out_features=128, bias=True)
  (relu4): ReLU()
  (fc3): Linear(in_features=128, out_features=100, bias=True)
)>

In [16]:
model = LeNet()
model_params = 0.0

for param in model.parameters():
    print(param.numel())
    model_params += param.numel()

model_params/1e6

450
6
2400
16
102400
256
32768
128
12800
100


0.151324