In [1]:
import torch
import torch.nn as nn
from collections import OrderedDict

In [2]:
def conv2d(filter_in, filter_out, kernel_size, groups=1, stride=1):
    """Conv2d (no bias) -> BatchNorm2d -> ReLU6 as an ``nn.Sequential``.

    The submodules are named ``conv``, ``bn`` and ``relu``.

    Args:
        filter_in: number of input channels.
        filter_out: number of output channels (also the BatchNorm width).
        kernel_size: square kernel size; assumed to be an int (TODO confirm no tuples).
        groups: grouping passed straight to ``nn.Conv2d``.
        stride: convolution stride.

    Returns:
        ``nn.Sequential`` wrapping the three layers.
    """
    # Padding of (k - 1) // 2 keeps the spatial size for odd kernels at stride 1.
    padding = 0
    if kernel_size:
        padding = (kernel_size - 1) // 2

    layers = OrderedDict()
    layers["conv"] = nn.Conv2d(
        filter_in,
        filter_out,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        groups=groups,
        bias=False,  # the following BatchNorm supplies the affine shift
    )
    layers["bn"] = nn.BatchNorm2d(filter_out)
    layers["relu"] = nn.ReLU6(inplace=True)
    return nn.Sequential(layers)

In [3]:
def make_three_conv(filters_list, in_filters):
    """Build a 1x1 conv -> depthwise-separable 3x3 conv -> 1x1 conv stack.

    Channels flow ``in_filters -> filters_list[0] -> filters_list[1] -> filters_list[0]``.
    Every stage uses stride 1, so the spatial size is unchanged.

    Args:
        filters_list: indexable with at least two channel counts; only [0] and [1] are read.
        in_filters: number of input channels.

    Returns:
        ``nn.Sequential`` of the three stages.
    """
    ch_a = filters_list[0]
    ch_b = filters_list[1]
    stages = [
        conv2d(in_filters, ch_a, 1),  # 1x1 conv to ch_a channels
        conv_dw(ch_a, ch_b),          # 3x3 depthwise + 1x1 pointwise to ch_b
        conv2d(ch_b, ch_a, 1),        # 1x1 conv back to ch_a channels
    ]
    return nn.Sequential(*stages)

In [4]:
def conv_dw(filter_in, filter_out, stride=1):
    """Depthwise-separable convolution block.

    Applies a 3x3 depthwise conv (groups = ``filter_in``) with BN and ReLU6.
    It then applies a 1x1 pointwise conv to ``filter_out`` channels with BN and ReLU6.

    Args:
        filter_in: input channels (also the depthwise group count).
        filter_out: output channels of the pointwise conv.
        stride: stride of the depthwise 3x3 conv (padding is 1).

    Returns:
        ``nn.Sequential`` with six unnamed submodules (indices 0-5).
    """
    depthwise = [
        nn.Conv2d(filter_in, filter_in, kernel_size=3, stride=stride, padding=1,
                  groups=filter_in, bias=False),
        nn.BatchNorm2d(filter_in),
        nn.ReLU6(inplace=True),
    ]
    pointwise = [
        nn.Conv2d(filter_in, filter_out, kernel_size=1, stride=1, padding=0, bias=False),
        nn.BatchNorm2d(filter_out),
        nn.ReLU6(inplace=True),
    ]
    return nn.Sequential(*depthwise, *pointwise)

In [5]:
# Seed the global torch RNG so this random input, and any weights initialised
# afterwards, are identical across Restart & Run All.
torch.manual_seed(0)
x = torch.randn(1, 3, 224, 224)

In [6]:
class MobileNetV3_large(nn.Module):
    """Small harness wrapping a single ``make_three_conv`` stack.

    NOTE(review): despite its name this is not a MobileNetV3-Large network.
    It only applies 1x1 conv -> depthwise-separable 3x3 conv -> 1x1 conv.

    Args:
        filters_list: channel pair ``(c0, c1)``; the stack maps
            ``in_filters -> c0 -> c1 -> c0`` channels. Defaults to ``(16, 32)``,
            the values previously hard-coded here.
        in_filters: number of input channels. Defaults to 3.
    """

    def __init__(self, filters_list=(16, 32), in_filters=3):
        super().__init__()
        # Copy into a list so callers may pass any sequence without it being shared.
        self.mk = make_three_conv(list(filters_list), in_filters)

    def forward(self, x):
        """Apply the stack to ``x``.

        Every stage has stride 1, so H and W are unchanged; only the channel count changes.
        """
        return self.mk(x)

In [7]:
# Put the model in eval mode. This notebook only inspects the output shape, and
# in train mode every forward pass would update the BatchNorm running statistics,
# so re-running the forward cell would silently change the model's state.
net = MobileNetV3_large().eval()

In [8]:
# Shape check only: skip building an autograd graph for this forward pass.
# (no_grad does not affect BatchNorm: if `net` is in train mode its running
# statistics are still updated here.)
with torch.no_grad():
    y = net(x)

In [9]:
# Output shape (a torch.Size) displayed as the cell result.
y.shape

torch.Size([1, 16, 224, 224])