In [2]:
import torch
import torch.nn as nn

class View(nn.Module):
    """Reshape every sample while keeping the leading (batch) dimension.

    ``View(*shape)`` maps a tensor of shape ``(B, ...)`` to ``(B, *shape)``;
    for example ``View(-1)`` flattens each sample into a vector.

    Args:
        *shape: target per-sample shape; may contain a single ``-1``.
    """

    def __init__(self, *shape):
        super(View, self).__init__()
        self.shape = shape

    def forward(self, x):
        # reshape instead of view: it returns a view whenever view() would
        # succeed, and copies (rather than raising) for non-contiguous input.
        return x.reshape(x.shape[0], *self.shape)

class Sepconv(nn.Module):
    """Depthwise-separable convolution followed by batch normalisation.

    The stack has three parts:
    - a 3x3 depthwise convolution (``groups=in_channel``, stride 1, padding 1,
      so the spatial size is preserved);
    - a 1x1 pointwise convolution mapping ``in_channel`` -> ``outchannel``;
    - ``BatchNorm2d`` over the output channels.

    Neither convolution uses a bias term.

    Args:
        in_channel: number of input channels.
        outchannel: number of output channels.
    """

    def __init__(self, in_channel, outchannel):
        super().__init__()
        depthwise = nn.Conv2d(in_channel, in_channel, kernel_size=3, stride=1,
                              padding=1, groups=in_channel, bias=False)
        pointwise = nn.Conv2d(in_channel, outchannel, kernel_size=1, stride=1,
                              bias=False)
        norm = nn.BatchNorm2d(outchannel)
        # Kept as a single Sequential so parameter names stay layers.0/1/2.
        self.layers = nn.Sequential(depthwise, pointwise, norm)

    def forward(self, x):
        return self.layers(x)

class Entry_flow(nn.Module):
    """Entry flow of a compact Xception-style network.

    The stem is two 3x3 convolutions (the first with stride 2), each followed
    by BatchNorm and ReLU. Three residual stages follow. In each stage the
    main path runs depthwise-separable convs and a stride-2 max-pool, the
    shortcut is a stride-2 1x1 conv + BatchNorm, and the two are summed.

    Channels go 3 -> 32 -> 64 -> 128 -> 270. Both paths use the same
    stride-2 output-size formula, so their spatial sizes always match. For a
    3x32x32 input the output is 270x2x2.
    """

    def __init__(self):
        super(Entry_flow, self).__init__()

        layer1 = []
        layer1 += [nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1, bias=False),
                   nn.BatchNorm2d(16),
                   nn.ReLU(inplace=True),
                   nn.Conv2d(16, 32, kernel_size=3, padding=1, bias=False),
                   nn.BatchNorm2d(32),
                   nn.ReLU(True)]
        layer2 = []
        layer2 += [Sepconv(32, 64),
                   nn.ReLU(True),
                   Sepconv(64, 64),
                   nn.MaxPool2d(kernel_size=3, stride=2, padding=1)]
        layer3 = []
        # The leading ReLU must NOT be in-place. Its input x2 is also fed to
        # skip_con2 in forward(). An in-place ReLU would overwrite x2, so the
        # shortcut would see relu(x2) instead of the pre-activation sum.
        layer3 += [nn.ReLU(),
                   Sepconv(64, 128),
                   nn.ReLU(True),
                   Sepconv(128, 128),
                   nn.MaxPool2d(3, 2, padding=1)]
        layer4 = []
        # Same reason: x3 is reused by skip_con3, so do not activate it in place.
        layer4 += [nn.ReLU(),
                   Sepconv(128, 270),
                   nn.ReLU(True),
                   Sepconv(270, 270),
                   nn.MaxPool2d(3, 2, padding=1)]

        self.layer1 = nn.Sequential(*layer1)
        self.layer2 = nn.Sequential(*layer2)
        self.layer3 = nn.Sequential(*layer3)
        self.layer4 = nn.Sequential(*layer4)
        # Projection shortcuts: 1x1 stride-2 conv + BN to match the channel
        # count and spatial size of each stage's main path.
        self.skip_con1 = nn.Sequential(*[nn.Conv2d(32, 64, kernel_size=1, stride=2, bias=False),
                                         nn.BatchNorm2d(64)])
        self.skip_con2 = nn.Sequential(*[nn.Conv2d(64, 128, 1, 2, bias=False),
                                         nn.BatchNorm2d(128)])
        self.skip_con3 = nn.Sequential(*[nn.Conv2d(128, 270, 1, 2, bias=False),
                                         nn.BatchNorm2d(270)])

    def forward(self, x):
        x1 = self.layer1(x)
        x2 = self.layer2(x1)
        skip = self.skip_con1(x1)
        x2 = x2+skip
        # layer3/layer4 start with a non-inplace ReLU, so x2/x3 are still
        # the raw residual sums when the shortcuts read them below.
        x3 = self.layer3(x2)
        skip = self.skip_con2(x2)
        x3 = x3+skip
        x4 = self.layer4(x3)
        skip = self.skip_con3(x3)
        return x4+skip

class Middle_flow(nn.Module):
    """One residual block of the Xception-style middle flow.

    The main path is three repetitions of ``ReLU -> Sepconv(270, 270)``. The
    block input is added back to the main-path output. The block is
    shape-preserving: ``(B, 270, H, W) -> (B, 270, H, W)``.
    """

    def __init__(self):
        super(Middle_flow, self).__init__()

        layer = []
        # The first ReLU must NOT be in-place: forward() adds the original x
        # back after the main path. An in-place ReLU would overwrite x, so the
        # residual would become relu(x) instead of the block input.
        layer += [nn.ReLU(),
                  Sepconv(270, 270),
                  nn.ReLU(True),
                  Sepconv(270, 270),
                  nn.ReLU(True),
                  Sepconv(270, 270)]

        self.layers = nn.Sequential(*layer)

    def forward(self, x):
        out = self.layers(x)
        return out+x

class Exit_flow(nn.Module):
    """Exit flow and classification head of the Xception-style network.

    The network runs in three parts:
    - Main path: ReLU -> Sepconv(270, 270) -> ReLU -> Sepconv(270, 512)
      -> stride-2 max-pool.
    - Shortcut: a stride-2 1x1 conv (270 -> 512) followed by BatchNorm.
    - Head, applied to the sum of the two: Sepconv(512, 768) -> ReLU
      -> Sepconv(768, 1024) -> ReLU -> global average pool -> flatten
      -> Linear(1024, 10).

    Returns a ``(B, 10)`` tensor of un-normalised class scores.
    """

    def __init__(self):
        super(Exit_flow, self).__init__()

        layer1 = []
        # The leading ReLU must NOT be in-place: forward() also feeds the
        # original x to self.skip_con. An in-place ReLU would overwrite x, so
        # the shortcut would see relu(x) instead of the block input.
        layer1 += [nn.ReLU(),
                   Sepconv(270, 270),
                   nn.ReLU(True),
                   Sepconv(270, 512),
                   nn.MaxPool2d(3, 2, padding=1)]

        layer2 = []
        layer2 += [Sepconv(512, 768),
                   nn.ReLU(True),
                   Sepconv(768, 1024),
                   nn.ReLU(True),
                   nn.AdaptiveAvgPool2d((1,1)),
                   View(-1),
                   nn.Linear(1024, 10)]

        self.layer1 = nn.Sequential(*layer1)
        self.layer2 = nn.Sequential(*layer2)
        self.skip_con = nn.Conv2d(in_channels=270, out_channels=512, kernel_size=1, stride=2, bias=False)
        self.bn = nn.BatchNorm2d(512)

    def forward(self, x):
        out1 = self.layer1(x)
        # x is untouched by layer1 (non-inplace first ReLU), so the shortcut
        # sees the block input.
        skip = self.skip_con(x)
        skip = self.bn(skip)
        output = self.layer2(out1+skip)
        return output
    
class Model(nn.Module):
    """Compact Xception-style classifier.

    Chains ``Entry_flow`` -> one ``Middle_flow`` block -> ``Exit_flow``. Input
    is an image batch with 3 channels. The output is the exit flow's
    ``(B, 10)`` score tensor.
    """

    def __init__(self):
        super().__init__()
        # Built as a single Sequential so submodules stay named layers.0/1/2.
        self.layers = nn.Sequential(Entry_flow(), Middle_flow(), Exit_flow())

    def forward(self, x):
        return self.layers(x)

from torchsummary import summary

# Build the network and print a per-layer table of output shapes and
# parameter counts for one 3x32x32 input. The batch dimension is shown as -1.
model = Model()
summary(model, (3, 32, 32))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 16, 16, 16]             432
       BatchNorm2d-2           [-1, 16, 16, 16]              32
              ReLU-3           [-1, 16, 16, 16]               0
            Conv2d-4           [-1, 32, 16, 16]           4,608
       BatchNorm2d-5           [-1, 32, 16, 16]              64
              ReLU-6           [-1, 32, 16, 16]               0
            Conv2d-7           [-1, 32, 16, 16]             288
            Conv2d-8           [-1, 64, 16, 16]           2,048
       BatchNorm2d-9           [-1, 64, 16, 16]             128
          Sepconv-10           [-1, 64, 16, 16]               0
             ReLU-11           [-1, 64, 16, 16]               0
           Conv2d-12           [-1, 64, 16, 16]             576
           Conv2d-13           [-1, 64, 16, 16]           4,096
      BatchNorm2d-14           [-1, 64,