In [1]:
from torchsummary import summary
import torch
from torch.nn import MaxPool2d, Linear, ReLU, BatchNorm2d, Sequential, Conv2d, Dropout


In [4]:
class Fmcnn2(torch.nn.Module):
    """
    Architecture of the Fmcnn1 model CNN.
    """
    def __init__(self):
        super().__init__()
        
        self.cnn_layers = Sequential(

            Conv2d(3, 32, kernel_size=3, stride=1, padding=1),
            BatchNorm2d(32),
            ReLU(inplace=True),
            
            Conv2d(32, 32, kernel_size=3, stride=1, padding=1),
            BatchNorm2d(32),
            ReLU(inplace=True),
            MaxPool2d(kernel_size=2, stride=2),

            Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            BatchNorm2d(64),
            ReLU(inplace=True)
           
            
        )

        self.linear_layers = Sequential(
            Linear(64 * 32 * 32, 5)
        )

   
    def forward(self, x):
        x = self.cnn_layers(x)
        x = x.view(x.size(0), -1)
        x = self.linear_layers(x)
        return x

In [5]:
devnet = Fmcnn1()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
devnet.to(device)
summary(devnet, (3, 64, 64))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1            [-1, 8, 64, 64]             224
       BatchNorm2d-2            [-1, 8, 64, 64]              16
              ReLU-3            [-1, 8, 64, 64]               0
         MaxPool2d-4            [-1, 8, 63, 63]               0
            Conv2d-5           [-1, 16, 63, 63]           1,168
       BatchNorm2d-6           [-1, 16, 63, 63]              32
              ReLU-7           [-1, 16, 63, 63]               0
            Conv2d-8           [-1, 16, 63, 63]           2,320
       BatchNorm2d-9           [-1, 16, 63, 63]              32
             ReLU-10           [-1, 16, 63, 63]               0
        MaxPool2d-11           [-1, 16, 31, 31]               0
           Conv2d-12           [-1, 32, 31, 31]           4,640
      BatchNorm2d-13           [-1, 32, 31, 31]              64
             ReLU-14           [-1, 32,