In [1]:
import torch
import torch.nn as nn
        
class Inception(nn.Module):
    """Inception block: four parallel branches concatenated on the channel axis.

    Padding values are derived from the usual output-size formula so that every
    branch keeps the input's spatial size (stride is 1 everywhere); 'S' in the
    reference table is treated as 'same' padding and 'V' as 'valid'.

    Branches (each convolution is followed by BatchNorm2d and an in-place ReLU):
        1. 1x1 conv                         -> ``o_c1`` channels
        2. 1x1 conv -> 3x3 conv (pad 1)     -> ``o_c2b`` channels
        3. 1x1 conv -> 5x5 conv (pad 2)     -> ``o_c3b`` channels
        4. 3x3 max-pool (pad 1) -> 1x1 conv -> ``o_c4`` channels

    Args:
        common_in_ch: Number of channels of the input fed to every branch.
        o_c1: Output channels of branch 1.
        o_c2a: Output channels of the 1x1 reduction in branch 2.
        i_c2: Input channels of the 3x3 conv in branch 2. It consumes the
            output of the 1x1 reduction, so it has to equal ``o_c2a``.
        o_c2b: Output channels of branch 2.
        o_c3a: Output channels of the 1x1 reduction in branch 3.
        i_c3: Input channels of the 5x5 conv in branch 3. It consumes the
            output of the 1x1 reduction, so it has to equal ``o_c3a``.
        o_c3b: Output channels of branch 3.
        o_c4: Output channels of branch 4.
    """

    def __init__(self, common_in_ch, o_c1, o_c2a, i_c2, o_c2b, o_c3a, i_c3, o_c3b, o_c4):
        super().__init__()

        # Attribute names and the order of the sub-modules inside each
        # Sequential are kept stable so that state_dict keys do not change.

        # Branch 1: plain 1x1 convolution.
        self.layer1 = nn.Sequential(
            *self._conv_bn_relu(common_in_ch, o_c1, kernel=1, pad=0),
        )

        # Branch 2: 1x1 reduction followed by a 3x3 convolution.
        self.layer2 = nn.Sequential(
            *self._conv_bn_relu(common_in_ch, o_c2a, kernel=1, pad=0),
            *self._conv_bn_relu(i_c2, o_c2b, kernel=3, pad=1),
        )

        # Branch 3: 1x1 reduction followed by a 5x5 convolution.
        self.layer3 = nn.Sequential(
            *self._conv_bn_relu(common_in_ch, o_c3a, kernel=1, pad=0),
            *self._conv_bn_relu(i_c3, o_c3b, kernel=5, pad=2),
        )

        # Branch 4: size-preserving max-pool followed by a 1x1 projection.
        self.layer4 = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
            *self._conv_bn_relu(common_in_ch, o_c4, kernel=1, pad=0),
        )

    @staticmethod
    def _conv_bn_relu(in_ch, out_ch, kernel, pad):
        """Return [Conv2d (stride 1, no bias), BatchNorm2d, in-place ReLU]."""
        return [
            nn.Conv2d(in_ch, out_ch, kernel_size=kernel, stride=1, padding=pad, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        ]

    def forward(self, x):
        """Run all four branches on ``x`` and concatenate the results.

        Returns:
            Tensor with ``o_c1 + o_c2b + o_c3b + o_c4`` channels; concatenation
            is on dim 1, i.e. the channel axis of a
            (batch, channel, height, width) tensor.
        """
        branches = (self.layer1, self.layer2, self.layer3, self.layer4)
        return torch.cat([branch(x) for branch in branches], dim=1)

class GoogleNet(nn.Module):
    """Small GoogLeNet-style classifier producing 10 logits per image.

    Pipeline:
        * ``layer1``: 3x3 stride-2 stem conv, 3x3 stride-2 max-pool, a 1x1 conv
          and a 3x3 conv (each conv followed by BatchNorm + ReLU), then two
          Inception blocks, another stride-2 max-pool and two more Inception
          blocks ending with 512 channels.
        * ``conv`` / ``bn`` / ``relu``: 1x1 projection to 256 channels. The
          convolution uses ``padding=1``, so it grows each spatial side by 2.
        * ``avgpool``: adaptive average pooling to 6x6, so the flattened
          feature size is always 256 * 6 * 6 = 9216, matching ``dense1``.
        * ``dense1`` -> ``dense2``: 9216 -> 32 -> 10.

    For a 3x32x32 input the feature map is 16x16 after the stem, 8x8 after the
    first max-pool, 4x4 after the second one and 6x6 after the padded 1x1
    conv, so the adaptive pool is an identity there; it only changes the
    result for other input sizes, which previously failed in ``dense1``.

    NOTE: ``drop`` is constructed but is not applied in ``forward``.
    """

    def __init__(self):
        super(GoogleNet, self).__init__()

        layer1 = []
        layer1 += [nn.Conv2d(in_channels = 3, out_channels = 64, kernel_size = 3, stride = 2, padding = 1, bias=False),
                   nn.BatchNorm2d(64),
                   nn.ReLU(True),
                   nn.MaxPool2d(kernel_size = 3, stride = 2, padding = 1),
                   # Marked "(V)" in the reference table, so treated as 'valid' (no padding).
                   nn.Conv2d(in_channels = 64, out_channels = 64, kernel_size = 1, stride = 1, padding = 0, bias=False),
                   nn.BatchNorm2d(64),
                   nn.ReLU(True),
                   nn.Conv2d(in_channels = 64, out_channels = 192, kernel_size = 3, stride = 1, padding = 1, bias=False),
                   nn.BatchNorm2d(192),
                   nn.ReLU(True),
                   # Each Inception outputs o_c1 + o_c2b + o_c3b + o_c4 channels,
                   # which is the input width of the next stage.
                   Inception(192, 64, 96, 96, 128, 16, 16, 32, 32),     # -> 256
                   Inception(256, 128, 128, 128, 192, 32, 32, 96, 64),  # -> 480
                   nn.MaxPool2d(kernel_size = 3, stride = 2, padding = 1),
                   Inception(480, 192, 96, 96, 208, 16, 16, 48, 64),    # -> 512
                   Inception(512, 160, 112, 112, 224, 24, 24, 64, 64)]  # -> 512

        self.layer1 = nn.Sequential(*layer1)

        # 1x1 conv with padding=1: adds a zero border, enlarging H and W by 2.
        # Kept as-is because dense1's input size (9216) was chosen around it.
        self.conv = nn.Conv2d(in_channels = 512, out_channels = 256, kernel_size = 1, stride = 1, padding = 1, bias=False)

        self.bn = nn.BatchNorm2d(256)

        self.relu = nn.ReLU(True)

        # 9216 = 256 channels * 6 * 6 spatial positions.
        self.dense1 = nn.Linear(in_features = 9216, out_features = 32)

        self.dense2 = nn.Linear(in_features = 32, out_features = 10)

        # Not applied in forward(); kept so the attribute remains available.
        self.drop = nn.Dropout(0.3)

        # Pool to 6x6 so the flattened size always equals dense1.in_features.
        # (The previous (5, 5) target was unused and would have produced
        # 256 * 5 * 5 = 6400 features, which does not fit dense1.)
        self.avgpool = nn.AdaptiveAvgPool2d((6, 6))

    def forward(self, x):
        """Compute class logits.

        Args:
            x: Image batch with 3 channels, laid out as
                (batch, channel, height, width).

        Returns:
            Tensor of shape (batch, 10).
        """
        x = self.layer1(x)
        x = self.relu(self.bn(self.conv(x)))
        # Identity when the map is already 6x6 (32x32 inputs); otherwise it
        # resizes the map so the classifier head accepts other input sizes.
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.dense1(x)
        x = self.dense2(x)

        return x

def Model():
    """Build and return the custom model: a freshly constructed ``GoogleNet``."""
    net = GoogleNet()
    return net

# torchsummary prints a per-layer table of output shapes and parameter counts.
from torchsummary import summary

# Instantiate the network and summarise it for a 3-channel 32x32 input.
model = Model()
summary(model, (3, 32, 32))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 64, 16, 16]           1,728
       BatchNorm2d-2           [-1, 64, 16, 16]             128
              ReLU-3           [-1, 64, 16, 16]               0
         MaxPool2d-4             [-1, 64, 8, 8]               0
            Conv2d-5             [-1, 64, 8, 8]           4,096
       BatchNorm2d-6             [-1, 64, 8, 8]             128
              ReLU-7             [-1, 64, 8, 8]               0
            Conv2d-8            [-1, 192, 8, 8]         110,592
       BatchNorm2d-9            [-1, 192, 8, 8]             384
             ReLU-10            [-1, 192, 8, 8]               0
           Conv2d-11             [-1, 64, 8, 8]          12,288
      BatchNorm2d-12             [-1, 64, 8, 8]             128
             ReLU-13             [-1, 64, 8, 8]               0
           Conv2d-14             [-1, 9