In [2]:
import torch
import torch.nn as nn

class View(nn.Module):
    """Reshape each sample of a batch to a fixed target shape.

    The size of dimension 0 is preserved; every remaining dimension of the
    input is replaced by ``shape``. One entry of ``shape`` may be ``-1`` to
    have it inferred from the element count (e.g. ``View(-1)`` flattens each
    sample to a vector).

    Args:
        *shape: Target sizes for the dimensions after dimension 0.
    """

    def __init__(self, *shape):
        super(View, self).__init__()
        self.shape = shape

    def forward(self, x):
        """Return ``x`` reshaped to ``(x.shape[0], *self.shape)``."""
        # reshape() returns a view whenever view() would have succeeded and
        # only copies for non-contiguous input (e.g. after transpose/permute),
        # where view() raises a RuntimeError.
        return x.reshape(x.shape[0], *self.shape)

class Residual_Block(nn.Module):
    """Residual unit with an identity skip connection.

    The residual branch is two consecutive BatchNorm -> ReLU -> 3x3 conv
    stages. Every convolution keeps the channel count at ``n_ch`` and uses
    stride 1 with padding 1, so the branch output has the same shape as the
    input and can be added to it directly.

    Args:
        n_ch: Number of input (and output) channels.
    """

    def __init__(self, n_ch):
        super().__init__()

        def conv3x3():
            # Shape-preserving 3x3 convolution; no bias term is used.
            return nn.Conv2d(in_channels=n_ch, out_channels=n_ch,
                             kernel_size=3, stride=1, padding=1, bias=False)

        self.layers = nn.Sequential(
            nn.BatchNorm2d(num_features=n_ch),
            nn.ReLU(inplace=True),
            conv3x3(),
            nn.BatchNorm2d(num_features=n_ch),
            nn.ReLU(inplace=True),
            conv3x3(),
        )

    def forward(self, x):
        """Return ``x`` plus the output of the residual branch."""
        residual = self.layers(x)
        return x + residual

class Bottleneck(nn.Module):
    """Dense layer that appends ``growth_rate`` new feature maps to its input.

    Branch: BatchNorm -> ReLU -> 1x1 conv (to ``4 * growth_rate`` channels)
    -> BatchNorm -> ReLU -> Residual_Block -> 3x3 conv (to ``growth_rate``
    channels). The branch output is concatenated to the input along
    dimension 1, so the result has ``in_ch + growth_rate`` channels.

    Args:
        in_ch: Number of channels of the incoming tensor.
        growth_rate: Number of channels the layer adds.
    """

    def __init__(self, in_ch, growth_rate):
        super().__init__()
        mid_ch = growth_rate * 4  # width of the intermediate representation
        self.layers = nn.Sequential(
            nn.BatchNorm2d(in_ch),
            nn.ReLU(True),
            nn.Conv2d(in_ch, mid_ch, kernel_size=1, bias=False),
            nn.BatchNorm2d(mid_ch),
            nn.ReLU(True),
            Residual_Block(mid_ch),
            nn.Conv2d(mid_ch, growth_rate, kernel_size=3, stride=1, padding=1, bias=False),
        )

    def forward(self, x):
        """Return the input concatenated with the newly computed features."""
        new_features = self.layers(x)
        return torch.cat((x, new_features), dim=1)

class DenseBlock(nn.Module):
    """Chain of ``n_layers`` Bottleneck layers applied one after another.

    Layer ``i`` is registered as the submodule ``Dense_layer_{i}`` and is
    built for ``n_ch + i * growth_rate`` input channels, matching the
    ``growth_rate`` channels each preceding Bottleneck concatenates.

    Args:
        n_layers: Number of Bottleneck layers in the block.
        n_ch: Number of channels entering the block.
        growth_rate: Number of channels each layer adds.
    """

    def __init__(self, n_layers, n_ch, growth_rate):
        super().__init__()
        self.n_layers = n_layers

        width = n_ch  # channels seen by the next layer to be created
        for idx in range(n_layers):
            self.add_module("Dense_layer_{}".format(idx), Bottleneck(width, growth_rate))
            width += growth_rate

    def forward(self, x):
        """Feed ``x`` through every dense layer in registration order."""
        layer_names = ("Dense_layer_{}".format(idx) for idx in range(self.n_layers))
        for layer_name in layer_names:
            x = getattr(self, layer_name)(x)
        return x

class Transition_layer(nn.Module):
    """Halve both the channel count and the spatial resolution.

    Applies BatchNorm -> ReLU -> 3x3 conv (``in_ch`` to ``in_ch // 2``
    channels, padding 1) followed by 2x2 average pooling with stride 2.

    Args:
        in_ch: Number of channels of the incoming tensor.
    """

    def __init__(self, in_ch):
        super().__init__()
        out_ch = in_ch // 2  # odd channel counts round down
        self.layers = nn.Sequential(
            nn.BatchNorm2d(in_ch),
            nn.ReLU(True),
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1, bias=False),
            nn.AvgPool2d(kernel_size=2, stride=2),
        )

    def forward(self, x):
        """Return the compressed, down-sampled feature map."""
        compressed = self.layers(x)
        return compressed

class Model(nn.Module):
    """Image classifier built from a residual stem and dense blocks.

    Structure: 3x3 stride-2 conv -> BatchNorm -> ReLU -> 3x3 stride-2 max
    pool -> Residual_Block, then one (DenseBlock, Transition_layer) pair per
    entry of ``block_layers``, then global average pooling, flattening and a
    linear classifier. No softmax is applied to the classifier output.

    With the default arguments the constructed layer sequence is identical
    to the original hard-coded configuration (channel widths 16 -> 112 ->
    56 -> 248 -> 124, 10 outputs), so existing ``Model()`` callers and saved
    state dicts are unaffected.

    Args:
        in_ch: Number of channels of the input images.
        n_classes: Number of outputs of the final linear layer.
        growth_rate: Channels added by every dense layer.
        block_layers: Number of dense layers in each dense block; one
            transition layer follows every block.
    """

    def __init__(self, in_ch=3, n_classes=10, growth_rate=16, block_layers=(6, 12)):
        super(Model, self).__init__()
        stem_ch = 16  # width of the stem and of the first residual block

        layers = [nn.Conv2d(in_ch, stem_ch, kernel_size=3, stride=2, padding=1, bias=False),
                  nn.BatchNorm2d(stem_ch),
                  nn.ReLU(True),
                  nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
                  Residual_Block(stem_ch)]

        # Track the running channel count instead of hard-coding it, so the
        # widths stay consistent whenever the configuration changes.
        n_ch = stem_ch
        for n_layers in block_layers:
            layers.append(DenseBlock(n_layers, n_ch, growth_rate))
            n_ch += n_layers * growth_rate  # each dense layer concatenates growth_rate channels
            layers.append(Transition_layer(n_ch))
            n_ch = int(n_ch * 0.5)  # mirrors the channel halving done by Transition_layer

        layers += [nn.AdaptiveAvgPool2d((1, 1)),
                   View(-1),
                   nn.Linear(n_ch, n_classes)]
        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        """Return the classifier outputs, one row of ``n_classes`` values per sample."""
        return self.layers(x)

# Instantiate the network with its default configuration.
model = Model()
# torchsummary is a third-party package; it is only needed for the report below.
from torchsummary import summary
# Print per-layer output shapes and parameter counts for a 3x32x32 input
# (the size is given without the batch dimension).
summary(model, (3,32,32))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 16, 16, 16]             432
       BatchNorm2d-2           [-1, 16, 16, 16]              32
              ReLU-3           [-1, 16, 16, 16]               0
         MaxPool2d-4             [-1, 16, 8, 8]               0
       BatchNorm2d-5             [-1, 16, 8, 8]              32
              ReLU-6             [-1, 16, 8, 8]               0
            Conv2d-7             [-1, 16, 8, 8]           2,304
       BatchNorm2d-8             [-1, 16, 8, 8]              32
              ReLU-9             [-1, 16, 8, 8]               0
           Conv2d-10             [-1, 16, 8, 8]           2,304
   Residual_Block-11             [-1, 16, 8, 8]               0
      BatchNorm2d-12             [-1, 16, 8, 8]              32
             ReLU-13             [-1, 16, 8, 8]               0
           Conv2d-14             [-1, 6

            ReLU-125             [-1, 64, 4, 4]               0
          Conv2d-126             [-1, 64, 4, 4]          36,864
  Residual_Block-127             [-1, 64, 4, 4]               0
          Conv2d-128             [-1, 16, 4, 4]           9,216
      Bottleneck-129             [-1, 88, 4, 4]               0
     BatchNorm2d-130             [-1, 88, 4, 4]             176
            ReLU-131             [-1, 88, 4, 4]               0
          Conv2d-132             [-1, 64, 4, 4]           5,632
     BatchNorm2d-133             [-1, 64, 4, 4]             128
            ReLU-134             [-1, 64, 4, 4]               0
     BatchNorm2d-135             [-1, 64, 4, 4]             128
            ReLU-136             [-1, 64, 4, 4]               0
          Conv2d-137             [-1, 64, 4, 4]          36,864
     BatchNorm2d-138             [-1, 64, 4, 4]             128
            ReLU-139             [-1, 64, 4, 4]               0
          Conv2d-140             [-1, 64

  Residual_Block-253             [-1, 64, 4, 4]               0
          Conv2d-254             [-1, 16, 4, 4]           9,216
      Bottleneck-255            [-1, 232, 4, 4]               0
     BatchNorm2d-256            [-1, 232, 4, 4]             464
            ReLU-257            [-1, 232, 4, 4]               0
          Conv2d-258             [-1, 64, 4, 4]          14,848
     BatchNorm2d-259             [-1, 64, 4, 4]             128
            ReLU-260             [-1, 64, 4, 4]               0
     BatchNorm2d-261             [-1, 64, 4, 4]             128
            ReLU-262             [-1, 64, 4, 4]               0
          Conv2d-263             [-1, 64, 4, 4]          36,864
     BatchNorm2d-264             [-1, 64, 4, 4]             128
            ReLU-265             [-1, 64, 4, 4]               0
          Conv2d-266             [-1, 64, 4, 4]          36,864
  Residual_Block-267             [-1, 64, 4, 4]               0
          Conv2d-268             [-1, 16