In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
from torchsummary import summary

In [3]:
class LeNet(nn.Module):
  """LeNet-5 style CNN: two conv -> ReLU -> 2x2 max-pool stages followed by
  three fully connected layers (fc1, fc2 with ReLU; fc3 emits raw logits).

  The fc1 input size (16*5*5) assumes 32x32 spatial input:
  32 -conv5-> 28 -pool2-> 14 -conv5-> 10 -pool2-> 5.

  Args:
    in_channels: number of channels in the input images (1 = grayscale,
      e.g. 3 for RGB). Defaults to 1.
    num_classes: number of logits produced by fc3. Defaults to 10.
  """
  def __init__(self, in_channels=1, num_classes=10):
    super().__init__()
    # Feature extractor: unpadded 5x5 convolutions shrink each side by 4.
    self.conv1=nn.Conv2d(in_channels=in_channels,out_channels=6,kernel_size=(5,5),stride=(1,1),padding=(0,0))
    self.conv2=nn.Conv2d(in_channels=6,out_channels=16,kernel_size=(5,5),stride=(1,1),padding=(0,0))
    # Classifier head: 16 feature maps of 5x5 each after the second pool.
    self.fc1=nn.Linear(in_features=16*5*5,out_features=120)
    self.fc2=nn.Linear(in_features=120,out_features=84)
    self.fc3=nn.Linear(in_features=84,out_features=num_classes)

  def forward(self,input):
    """Run the network.

    Args:
      input: float tensor of shape (N, in_channels, 32, 32).

    Returns:
      Tensor of shape (N, num_classes) with unnormalized logits
      (no softmax is applied here).
    """
    output=F.relu(self.conv1(input))
    output=F.max_pool2d(output,kernel_size=(2,2),stride=(2,2))
    output=F.relu(self.conv2(output))
    output=F.max_pool2d(output,kernel_size=(2,2),stride=(2,2))
    # Keep the batch dimension, flatten (C, H, W) into one feature vector.
    output=output.flatten(start_dim=1)
    output=F.relu(self.fc1(output))
    output=F.relu(self.fc2(output))
    output=self.fc3(output)
    return output

In [6]:
# Build the network and inspect it: `print` shows the module tree, while
# torchsummary's `summary` lists per-layer output shapes and parameter counts
# for a single 1x32x32 input.
model = LeNet()
print(model)
# torchsummary's `device` argument defaults to "cuda"; on a machine with a GPU
# it would then feed CUDA tensors into this CPU-resident model and fail with a
# device/type mismatch. Pin it to where the model actually lives.
summary(model, (1,32,32), device="cpu")

LeNet(
  (conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=400, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=10, bias=True)
)
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1            [-1, 6, 28, 28]             156
            Conv2d-2           [-1, 16, 10, 10]           2,416
            Linear-3                  [-1, 120]          48,120
            Linear-4                   [-1, 84]          10,164
            Linear-5                   [-1, 10]             850
Total params: 61,706
Trainable params: 61,706
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.05
Params size (MB): 0.24
Estimated Total Size (MB): 0.29

In [14]:
def test(batch_size=64):
  """Smoke-test LeNet's forward pass on a random batch.

  Builds a fresh, untrained LeNet, feeds it `batch_size` random inputs of
  shape (1, 32, 32) and prints the resulting output shape.

  Args:
    batch_size: number of random samples in the batch. Defaults to 64.

  Raises:
    AssertionError: if the output is not (batch_size, 10).
  """
  model = LeNet()
  x = torch.randn(batch_size, 1, 32, 32)
  # Only the output shape is inspected, so skip building the autograd graph.
  with torch.no_grad():
    output = model(x)
  print("output.shape : ",output.size())
  # Make the check actually fail on a regression instead of just printing:
  # one logit per class (10) for every sample in the batch.
  assert output.shape == (batch_size, 10), f"unexpected output shape {tuple(output.shape)}"

In [15]:
# Run the forward-pass smoke test: a fresh LeNet on a random 64x1x32x32 batch; prints the output shape.
test()

output.shape :  torch.Size([64, 10])
