In [24]:
import torch
import torch.nn as nn 
import torch.nn.functional as F
from torchsummary import summary

class RecCNN(nn.Module):
    """Small two-stage CNN classifier for single-channel images.

    Two ``Conv2d -> ReLU -> MaxPool2d(2)`` stages (1 -> 64 -> 16 channels)
    feed a 32-unit hidden linear layer followed by a linear output layer.

    Args:
        num_classes: Number of output scores produced per sample.
        input_size: ``(height, width)`` of the images the model will receive.
            Defaults to ``(128, 128)``, which yields the ``32*32*16``
            flattened feature size.
    """

    def __init__(self, num_classes, input_size=(128, 128)):
        super().__init__()
        self.convnet = nn.Sequential(
            nn.Conv2d(1, 64, 3, padding='same'), nn.ReLU(),
            nn.MaxPool2d(2, stride=2),

            nn.Conv2d(64, 16, 3, padding='same'), nn.ReLU(),
            nn.MaxPool2d(2, stride=2),
        )
        # 'same' padding keeps H x W through each conv; each 2x2 / stride-2
        # pool halves both dimensions (rounding down), and the last conv
        # emits 16 channels.
        height, width = input_size
        flat_features = 16 * (height // 2 // 2) * (width // 2 // 2)
        self.fc1 = nn.Linear(flat_features, 32)
        self.fc2 = nn.Linear(32, num_classes)

    def forward(self, x):
        """Return one row of ``num_classes`` scores per sample in ``x``.

        No softmax is applied; the output of the last linear layer is
        returned as-is.
        """
        x = torch.flatten(self.convnet(x), 1)
        # Without a non-linearity here fc1 and fc2 would collapse into a
        # single affine map, making the hidden layer pointless.
        x = F.relu(self.fc1(x))
        return self.fc2(x)

In [25]:
# Dummy batch: one single-channel 128x128 image with uniform-random pixels.
x = torch.rand(1, 1, 128, 128)
# Five-way classifier instance used by the summary cell below.
model = RecCNN(5)

In [26]:
# `model` was created without being moved to a GPU. torchsummary's `device`
# argument defaults to "cuda", so on a CUDA-capable machine it would feed a
# CUDA tensor to CPU weights and fail; pin it to "cpu" to match the model.
summary(model, (1, 128, 128), device="cpu")

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 128, 128]             640
              ReLU-2         [-1, 64, 128, 128]               0
         MaxPool2d-3           [-1, 64, 64, 64]               0
            Conv2d-4           [-1, 16, 64, 64]           9,232
              ReLU-5           [-1, 16, 64, 64]               0
         MaxPool2d-6           [-1, 16, 32, 32]               0
            Linear-7                   [-1, 32]         524,320
            Linear-8                    [-1, 5]             165
Total params: 534,357
Trainable params: 534,357
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.06
Forward/backward pass size (MB): 19.13
Params size (MB): 2.04
Estimated Total Size (MB): 21.23
----------------------------------------------------------------


In [27]:
from thop import profile

In [29]:
# Profile a fresh model on one single-channel 128x128 random image.
x = torch.rand((1, 1, 128, 128))
model = RecCNN(num_classes=5)
macs, params = profile(model, inputs=(x, ))
# Dividing by 1000**2 scales to millions, so both figures are labelled (M).
print('MACs (M): ', macs/1000**2)
print('Params (M): ', params/1000**2)

[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.activation.ReLU'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.pooling.MaxPool2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.container.Sequential'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
MACs (G):  47.710368
Params (M):  0.534357
