In [12]:
from torchsummary import summary

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models

class MobileNetV3WithConv(nn.Module):
    """Truncated MobileNetV3-Small feature extractor with a 1x1 conv head.

    Keeps the first nine entries of ``mobilenet_v3_small(...).features``
    (built with the ``'DEFAULT'`` weights) and feeds their output into a
    bias-free 1x1 convolution that maps 48 input channels to a single
    output channel.

    Args:
        num_classes: Currently unused; the head always emits one channel.
            NOTE(review): confirm whether the head was meant to produce
            ``num_classes`` channels instead.
    """

    def __init__(self, num_classes=2):
        super().__init__()

        # Build the complete MobileNetV3-Small first, then retain only the
        # leading slice of its feature stack as this module's backbone.
        pretrained = models.mobilenet_v3_small(weights='DEFAULT')
        self.mobilenet_v3 = pretrained.features[:9]

        # Pointwise projection: 48 channels in, 1 channel out, no bias.
        # Stride 1 and zero padding are the Conv2d defaults.
        self.conv = nn.Conv2d(
            in_channels=48,
            out_channels=1,
            kernel_size=1,
            bias=False,
        )

    def forward(self, x):
        """Apply the truncated backbone, then the 1x1 head, to ``x``."""
        return self.conv(self.mobilenet_v3(x))
    
# Instantiate the model and print a layer-by-layer summary for a
# 3x240x240 input.
net = MobileNetV3WithConv()

# torchsummary defaults to device="cuda" and, when CUDA is available, builds
# its probe tensor on the GPU. `net` has not been moved anywhere, so pass the
# device type its parameters actually live on to avoid a device-mismatch
# error on CUDA-capable machines.
summary(net, (3, 240, 240), device=next(net.parameters()).device.type)


----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 16, 120, 120]             432
       BatchNorm2d-2         [-1, 16, 120, 120]              32
         Hardswish-3         [-1, 16, 120, 120]               0
            Conv2d-4           [-1, 16, 60, 60]             144
       BatchNorm2d-5           [-1, 16, 60, 60]              32
              ReLU-6           [-1, 16, 60, 60]               0
 AdaptiveAvgPool2d-7             [-1, 16, 1, 1]               0
            Conv2d-8              [-1, 8, 1, 1]             136
              ReLU-9              [-1, 8, 1, 1]               0
           Conv2d-10             [-1, 16, 1, 1]             144
      Hardsigmoid-11             [-1, 16, 1, 1]               0
SqueezeExcitation-12           [-1, 16, 60, 60]               0
           Conv2d-13           [-1, 16, 60, 60]             256
      BatchNorm2d-14           [-1, 16,

In [None]:
# Push one random 3x600x600 sample through the network and display the result.
# NOTE(review): `input` shadows the Python builtin of the same name; the name
# is kept because other cells may refer to it.
# NOTE(review): no `net.eval()` / `torch.no_grad()` is visible here, so this
# presumably runs in training mode with autograd enabled; confirm intent.
input = torch.rand(1, 3, 600, 600)
net(input)