In [39]:
from torchsummary import summary
import torch
from torch.nn import MaxPool2d, Linear, ReLU, BatchNorm2d, Sequential, Conv2d, Dropout


In [44]:
class Fmcnn1(torch.nn.Module):
    """Five-stage Conv-BN-ReLU image classifier with a fully connected head.

    The trunk downsamples twice with stride-2 convolutions and once with a
    final 2x2 max-pool. A 128x128 input therefore yields a 64x16x16 feature
    map, which matches the size the head was originally hard-coded for.

    Args:
        num_classes: Number of output logits. Defaults to 5.
        in_channels: Channels of the input image. Defaults to 3.
        input_size: Height/width of the square input the head is sized for.
            Defaults to 128. Inputs of a different size produce a different
            flattened feature count and fail in the first ``Linear``.
        dropout: Dropout probability used in the head. Defaults to 0.5.

    Raises:
        ValueError: If ``input_size`` is too small to survive the trunk's
            downsampling.

    Note:
        The head keeps the original widths (flat -> flat -> flat // 2 ->
        num_classes). At the default 128x128 input, the first two ``Linear``
        layers alone hold roughly 400M parameters.

        A ``ReLU`` follows each hidden ``Linear``. Without it, the stacked
        Linear/Dropout layers reduce to a single affine map at inference
        time. The ReLUs shift the ``linear_layers`` indices (Linear modules
        now sit at 0, 3 and 6), so state dicts saved from the earlier
        activation-free head will not load strictly.
    """

    # (out_channels, stride) for each Conv(3x3, pad 1)-BN-ReLU stage, in order.
    _CONV_STAGES = ((8, 1), (16, 2), (16, 1), (32, 2), (64, 1))

    def __init__(self, num_classes=5, in_channels=3, input_size=128, dropout=0.5):
        super().__init__()

        layers = []
        channels = in_channels
        for out_channels, stride in self._CONV_STAGES:
            layers += [
                Conv2d(channels, out_channels, kernel_size=3, stride=stride, padding=1),
                BatchNorm2d(out_channels),
                ReLU(inplace=True),
            ]
            channels = out_channels
        layers.append(MaxPool2d(kernel_size=2, stride=2))
        # Same module order/indices as the original hand-written Sequential.
        self.cnn_layers = Sequential(*layers)

        flat_features = channels * self._feature_side(input_size) ** 2
        hidden_features = flat_features // 2
        self.linear_layers = Sequential(
            Linear(flat_features, flat_features),
            ReLU(inplace=True),
            Dropout(dropout),
            Linear(flat_features, hidden_features),
            ReLU(inplace=True),
            Dropout(dropout),
            Linear(hidden_features, num_classes),
        )

    @classmethod
    def _feature_side(cls, input_size):
        """Return the spatial side length of the trunk output for a square input.

        Computed arithmetically rather than with a dummy forward pass, which
        would update the BatchNorm running statistics as a side effect.
        """
        side = input_size
        for _, stride in cls._CONV_STAGES:
            # Conv2d(kernel_size=3, padding=1): floor((side + 2 - 3) / stride) + 1
            side = (side - 1) // stride + 1
        # MaxPool2d(kernel_size=2, stride=2): floor((side - 2) / 2) + 1 == side // 2
        side //= 2
        if side < 1:
            raise ValueError(f"input_size={input_size} is too small for this network")
        return side

    def forward(self, x):
        """Return logits of shape (batch, num_classes) for an image batch ``x``."""
        features = self.cnn_layers(x)
        # Flatten every dimension except the batch dimension.
        features = torch.flatten(features, 1)
        return self.linear_layers(features)


In [45]:
# Build the default model and print a per-layer summary for one 3x128x128 input.
# torchsummary's `summary` defaults to device="cuda" and creates CUDA input
# tensors whenever CUDA is available. That breaks for this CPU-resident model,
# so the device is passed explicitly.
devnet = Fmcnn1()
summary(devnet, (3, 128, 128), device="cpu")

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1          [-1, 8, 128, 128]             224
       BatchNorm2d-2          [-1, 8, 128, 128]              16
              ReLU-3          [-1, 8, 128, 128]               0
            Conv2d-4           [-1, 16, 64, 64]           1,168
       BatchNorm2d-5           [-1, 16, 64, 64]              32
              ReLU-6           [-1, 16, 64, 64]               0
            Conv2d-7           [-1, 16, 64, 64]           2,320
       BatchNorm2d-8           [-1, 16, 64, 64]              32
              ReLU-9           [-1, 16, 64, 64]               0
           Conv2d-10           [-1, 32, 32, 32]           4,640
      BatchNorm2d-11           [-1, 32, 32, 32]              64
             ReLU-12           [-1, 32, 32, 32]               0
           Conv2d-13           [-1, 64, 32, 32]          18,496
      BatchNorm2d-14           [-1, 64,