In [2]:
import torch
import torch.nn as nn

In [3]:
# Inception module with dimension reduction
# Implemented as variant (b) of Figure 2

class Inception(nn.Module):
    """Inception block with 1x1 dimension reduction.

    The input is fed to four parallel branches whose outputs are concatenated
    along the channel axis (dim=1):

    * ``branch1``: 1x1 conv
    * ``branch2``: 1x1 conv (reduce) -> 3x3 conv, padding 1
    * ``branch3``: 1x1 conv (reduce) -> 5x5 conv, padding 2
    * ``branch4``: 3x3 max-pool (stride 1, padding 1) -> 1x1 conv

    Every convolution is followed by a ReLU. All layers use stride 1 with
    padding chosen so each branch keeps the input's height and width, which
    is what allows the channel-wise concatenation.

    Args:
        in_channels: channels of the incoming feature map.
        out_1x1: output channels of the 1x1 branch.
        reduce_3x3: channels after the 1x1 reduction in the 3x3 branch.
        out_3x3: output channels of the 3x3 branch.
        reduce_5x5: channels after the 1x1 reduction in the 5x5 branch.
        out_5x5: output channels of the 5x5 branch.
        pooling: output channels of the 1x1 conv in the pooling branch.

    The block outputs ``out_1x1 + out_3x3 + out_5x5 + pooling`` channels.
    """

    def __init__(self, in_channels, out_1x1, reduce_3x3, out_3x3, reduce_5x5, out_5x5, pooling):
        super().__init__()

        def conv_relu(c_in, c_out, k, pad=0):
            # One stride-1 convolution followed by its ReLU, returned as a
            # list so it can be unpacked straight into nn.Sequential.
            return [
                nn.Conv2d(c_in, c_out, kernel_size=k, stride=1, padding=pad),
                nn.ReLU(),
            ]

        # 1x1 conv only
        self.branch1 = nn.Sequential(*conv_relu(in_channels, out_1x1, 1))

        # 1x1 conv -> 3x3 conv
        self.branch2 = nn.Sequential(
            *conv_relu(in_channels, reduce_3x3, 1),
            *conv_relu(reduce_3x3, out_3x3, 3, pad=1),
        )

        # 1x1 conv -> 5x5 conv
        self.branch3 = nn.Sequential(
            *conv_relu(in_channels, reduce_5x5, 1),
            *conv_relu(reduce_5x5, out_5x5, 5, pad=2),
        )

        # 3x3 max-pooling -> 1x1 conv
        self.branch4 = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
            *conv_relu(in_channels, pooling, 1),
        )

    def forward(self, x):
        """Run all four branches on ``x`` and concatenate them channel-wise."""
        branches = (self.branch1, self.branch2, self.branch3, self.branch4)
        return torch.cat([branch(x) for branch in branches], dim=1)

In [4]:
class GoogleNet(nn.Module):
    """GoogLeNet-style image classifier built from stacked Inception blocks.

    The network is a convolutional stem (``layer1``), three groups of
    Inception blocks (``layer2``..``layer4``, the last one ending in global
    average pooling), dropout, and a single linear classifier (``fc``).

    Args:
        num_classes: size of the output layer. Defaults to 1000.
        in_channels: channels of the input image. Defaults to 3.
        dropout: drop probability applied to the pooled features before the
            classifier. Defaults to 0.4.

    Calling ``GoogleNet()`` with no arguments builds exactly the same layers
    (same names, same parameter shapes) as before these options existed.
    """

    def __init__(self, num_classes=1000, in_channels=3, dropout=0.4):
        super().__init__()

        # Stem: everything before the first Inception block
        self.layer1 = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
            nn.Conv2d(64, 192, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
        )
        # Inception 3a, 3b, then max-pool
        self.layer2 = nn.Sequential(
            Inception(192, 64, 96, 128, 16, 32, 32),
            Inception(256, 128, 128, 192, 32, 96, 64),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
        )
        # Inception 4a, 4b, 4c, 4d, 4e, then max-pool
        self.layer3 = nn.Sequential(
            Inception(480, 192, 96, 208, 16, 48, 64),
            Inception(512, 160, 112, 224, 24, 64, 64),
            Inception(512, 128, 128, 256, 24, 64, 64),
            Inception(512, 112, 144, 288, 32, 64, 64),
            Inception(528, 256, 160, 320, 32, 128, 128),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
        )
        # Inception 5a, 5b, then global average pooling to 1x1
        self.layer4 = nn.Sequential(
            Inception(832, 256, 160, 320, 32, 128, 128),
            Inception(832, 384, 192, 384, 48, 128, 128),
            nn.AdaptiveAvgPool2d((1, 1)),
        )
        # The pooled features are 1x1 per channel, so channel-wise Dropout2d
        # and element-wise Dropout zero the same units with the same
        # probability; plain Dropout on the flattened vector states the
        # intent directly.
        self.dropout = nn.Dropout(p=dropout)
        # Inception 5b outputs 384 + 384 + 128 + 128 = 1024 channels.
        self.fc = nn.Linear(1024, num_classes)

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)`` for images ``x``."""
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        # (batch, 1024, 1, 1) -> (batch, 1024)
        x = torch.flatten(x, 1)
        x = self.dropout(x)
        return self.fc(x)

In [10]:
# Smoke test: push a single random 3x224x224 input through a freshly
# initialised network and print the shape of the result.
model = GoogleNet()
img = torch.rand(1, 3, 224, 224)
output = model(img)
print(output.shape)

torch.Size([1, 1000])


In [7]:
from torchsummary import summary

# torchsummary's `device` argument defaults to "cuda"; when CUDA is available
# it then feeds a CUDA tensor to the model, which fails for a model that was
# never moved off the CPU. Pass the device type the model's parameters
# actually live on ("cpu" or "cuda") so the call works in both setups.
summary(model, (3, 224, 224), device=next(model.parameters()).device.type)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 112, 112]           9,472
              ReLU-2         [-1, 64, 112, 112]               0
         MaxPool2d-3           [-1, 64, 56, 56]               0
            Conv2d-4          [-1, 192, 56, 56]         110,784
              ReLU-5          [-1, 192, 56, 56]               0
         MaxPool2d-6          [-1, 192, 28, 28]               0
            Conv2d-7           [-1, 64, 28, 28]          12,352
              ReLU-8           [-1, 64, 28, 28]               0
            Conv2d-9           [-1, 96, 28, 28]          18,528
             ReLU-10           [-1, 96, 28, 28]               0
           Conv2d-11          [-1, 128, 28, 28]         110,720
             ReLU-12          [-1, 128, 28, 28]               0
           Conv2d-13           [-1, 16, 28, 28]           3,088
             ReLU-14           [-1, 16,