In [1]:
import torch.nn as nn
import torch
import torch.nn.functional as F

In [7]:
def convblock(in_channels, out_channels, kernel_size=3, stride=1, padding=1, pool=False):
    """Build a Conv2d -> Tanh stage, optionally followed by 2x2 average pooling.

    Args:
        in_channels: Number of channels in the incoming feature map.
        out_channels: Number of filters produced by the convolution.
        kernel_size: Convolution kernel size.
        stride: Convolution stride.
        padding: Zero-padding added on each spatial side before convolving.
        pool: When True, append ``nn.AvgPool2d(kernel_size=2, stride=2)``,
            which halves each spatial dimension (rounding down).

    Returns:
        nn.Sequential containing, in order: Conv2d, Tanh[, AvgPool2d].
    """
    stages = [
        nn.Conv2d(in_channels, out_channels,
                  kernel_size=kernel_size, stride=stride, padding=padding),
        nn.Tanh(),
    ]
    if pool:
        # Optional subsampling stage: average (not max) pooling over 2x2 windows.
        stages.append(nn.AvgPool2d(kernel_size=2, stride=2))
    return nn.Sequential(*stages)

In [8]:
def linearblock(in_features, out_features):
    """Build a fully connected layer followed by a Tanh activation.

    Args:
        in_features: Size of each input sample.
        out_features: Size of each output sample.

    Returns:
        nn.Sequential containing, in order: Linear, Tanh.
    """
    dense = nn.Linear(in_features=in_features, out_features=out_features)
    activation = nn.Tanh()
    return nn.Sequential(dense, activation)

In [22]:
class LeNet5(nn.Module):
    """LeNet-5-style classifier built from tanh conv/linear blocks.

    Layout: conv5x5(6)+tanh+avgpool -> conv5x5(16)+tanh+avgpool -> conv5x5(120)+tanh
    -> global average pool to 1x1 -> linear 120->84 (+tanh) -> linear 84->num_classes
    [-> softmax over dim 1].

    Because of the adaptive pooling, any spatial size of at least 32 in each
    dimension is accepted (e.g. 32x32 or 224x224). Smaller inputs leave block3
    with nothing to convolve.

    Args:
        in_channels: Channels of the input image (1 = grayscale, the original default).
        num_classes: Number of output classes (default 10).
        apply_softmax: If True (default, the original behaviour), ``forward`` returns
            softmax probabilities. Set False to get raw logits, which is what
            ``nn.CrossEntropyLoss`` expects. Feeding it probabilities applies
            softmax twice and hurts training.
    """

    def __init__(self, in_channels=1, num_classes=10, apply_softmax=True):
        super().__init__()
        self.apply_softmax = apply_softmax
        # Submodule names and creation order are unchanged so state_dict keys and
        # the RNG sequence used for weight initialisation match earlier versions.
        self.block1 = convblock(in_channels, 6, 5, 1, 0, pool=True)
        self.block2 = convblock(6, 16, 5, 1, 0, pool=True)
        self.block3 = convblock(16, 120, 5, 1, 0)
        # Collapse whatever spatial size block3 produces to 1x1, so fc1 always sees 120 features.
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc1 = linearblock(120, 84)
        self.fc2 = nn.Linear(in_features=84, out_features=num_classes)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        """Return class probabilities (or logits if ``apply_softmax`` is False), shape (N, num_classes)."""
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        x = self.pool(x)
        # After pooling each sample is 120x1x1; view(-1, 120) also handles an
        # unbatched (C, H, W) input by producing a batch of one.
        x = x.view(-1, 120)
        x = self.fc1(x)
        x = self.fc2(x)
        if self.apply_softmax:
            x = self.softmax(x)
        return x


In [23]:
# Build the network with its default configuration. Weights use PyTorch's default
# (random) initialisation; seed torch before this cell if reproducible weights are needed.
model = LeNet5()

In [24]:
# Smoke test: one random 224x224 single-channel image should still yield a
# (1, 10) output, because the adaptive pooling absorbs the larger spatial size.
# A local, seeded generator makes the dummy input reproducible without
# disturbing the global RNG state used by later cells.
x = torch.randn(1, 1, 224, 224, generator=torch.Generator().manual_seed(0))
output = model(x)
assert output.shape == (1, 10), output.shape
# The model ends in Softmax(dim=1), so each row should sum to 1.
assert torch.allclose(output.sum(dim=1), torch.ones(1)), output.sum(dim=1)
output.shape

torch.Size([1, 10])