In [1]:
import torch
from torch import nn

In [2]:
class Basic_Conv(nn.Module):
    """
        Convolution -> batch normalisation -> ReLU building block.

        Args:
            in_channels: number of channels of the incoming feature map.
            out_channels: number of channels produced by the convolution.
            kernel_size: kernel size forwarded to ``nn.Conv2d`` (int or pair).
            stride: stride forwarded to ``nn.Conv2d`` (default 1).
            padding: padding forwarded to ``nn.Conv2d`` (default 0).
    """
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):
        super().__init__()

        # Geometry of the convolution, collected once and passed through untouched.
        conv_kwargs = dict(kernel_size=kernel_size, stride=stride, padding=padding)

        self.conv = nn.Conv2d(in_channels, out_channels, **conv_kwargs)
        self.BN = nn.BatchNorm2d(num_features=out_channels)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Run ``x`` through the convolution, batch norm and ReLU, in that order."""
        return self.relu(self.BN(self.conv(x)))

In [3]:
class Inception_F5(nn.Module):
    """
        Inception module following Figure 5 of the paper.

        Four parallel branches receive the same input and their outputs are
        concatenated along the channel dimension (96 + 64 + 64 + 64 = 288
        output channels). Every layer uses stride 1 with padding chosen so the
        spatial size of the input is preserved.

        Args:
            in_channels: channel count of the incoming feature map.
    """
    def __init__(self, in_channels):
        super().__init__()

        # 1x1 reduction followed by two stacked 3x3 convolutions.
        self.branch1 = nn.Sequential(
            Basic_Conv(in_channels, 96, kernel_size=1),
            Basic_Conv(96, 96, kernel_size=3, stride=1, padding=1),
            Basic_Conv(96, 96, kernel_size=3, stride=1, padding=1),
        )
        # 1x1 reduction followed by a single 3x3 convolution.
        self.branch2 = nn.Sequential(
            Basic_Conv(in_channels, 64, kernel_size=1),
            Basic_Conv(64, 64, kernel_size=3, stride=1, padding=1),
        )
        # 3x3 max-pool (stride 1) followed by a 1x1 projection.
        self.branch3 = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
            Basic_Conv(in_channels, 64, kernel_size=1),
        )
        # Plain 1x1 convolution.
        self.branch4 = Basic_Conv(in_channels, 64, kernel_size=1)

    def forward(self, x):
        """Evaluate the four branches in order and concatenate them on dim 1."""
        branches = (self.branch1, self.branch2, self.branch3, self.branch4)
        return torch.cat([branch(x) for branch in branches], dim=1)

In [4]:
class Inception_F6(nn.Module):
    """
        Inception module following Figure 6 of the paper.

        The larger convolutions are expressed as pairs of 1x7 and 7x1
        convolutions. Four parallel branches each end in 192 channels and are
        concatenated along the channel dimension, giving 768 output channels.
        Every layer uses stride 1 with padding that keeps the spatial size.

        Args:
            in_channels: channel count of the incoming feature map.
            red_1x1: width of the 1x1 reduction and of the intermediate
                1x7 / 7x1 convolutions in the two factorised branches.
    """
    def __init__(self, in_channels, red_1x1):
        super().__init__()

        def conv_1x7(c_in, c_out):
            # Kernel of height 1 and width 7, padded by 3 along the width.
            return Basic_Conv(c_in, c_out, kernel_size=(1, 7), stride=1, padding=(0, 3))

        def conv_7x1(c_in, c_out):
            # Kernel of height 7 and width 1, padded by 3 along the height.
            return Basic_Conv(c_in, c_out, kernel_size=(7, 1), stride=1, padding=(3, 0))

        # 1x1 reduction, then two (1x7, 7x1) pairs; only the last layer widens to 192.
        self.branch1 = nn.Sequential(
            Basic_Conv(in_channels, red_1x1, kernel_size=1),
            conv_1x7(red_1x1, red_1x1),
            conv_7x1(red_1x1, red_1x1),
            conv_1x7(red_1x1, red_1x1),
            conv_7x1(red_1x1, 192),
        )
        # 1x1 reduction, then a single (1x7, 7x1) pair.
        self.branch2 = nn.Sequential(
            Basic_Conv(in_channels, red_1x1, kernel_size=1),
            conv_1x7(red_1x1, red_1x1),
            conv_7x1(red_1x1, 192),
        )
        # 3x3 max-pool (stride 1) followed by a 1x1 projection.
        self.branch3 = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
            Basic_Conv(in_channels, 192, kernel_size=(1, 1)),
        )
        # Plain 1x1 convolution.
        self.branch4 = Basic_Conv(in_channels, 192, kernel_size=(1, 1))

    def forward(self, x):
        """Evaluate the four branches in order and concatenate them on dim 1."""
        branches = (self.branch1, self.branch2, self.branch3, self.branch4)
        return torch.cat([branch(x) for branch in branches], dim=1)

In [5]:
class Inception_F7(nn.Module):
    """
        Inception module following Figure 7 of the paper.

        Two of the branches end in a parallel pair of 1x3 and 3x1 convolutions
        that both read the same stem output. The six resulting tensors are
        concatenated along the channel dimension:
        4 * 384 + 192 + 320 = 2048 output channels.

        Args:
            in_channels: channel count of the incoming feature map.
    """
    def __init__(self, in_channels):
        super().__init__()

        def conv_1x3():
            # 384 -> 384, kernel of height 1 and width 3, padded along the width.
            return Basic_Conv(384, 384, kernel_size=(1, 3), stride=1, padding=(0, 1))

        def conv_3x1():
            # 384 -> 384, kernel of height 3 and width 1, padded along the height.
            return Basic_Conv(384, 384, kernel_size=(3, 1), stride=1, padding=(1, 0))

        # Branch 1: 1x1 -> 3x3 stem, then the split 1x3 / 3x1 pair.
        self.branch1_stem = nn.Sequential(
            Basic_Conv(in_channels, 448, kernel_size=1),
            Basic_Conv(448, 384, kernel_size=(3, 3), stride=1, padding=1),
        )
        self.branch1_left = conv_1x3()
        self.branch1_right = conv_3x1()

        # Branch 2: 1x1 stem, then the split 1x3 / 3x1 pair.
        self.branch2_stem = Basic_Conv(in_channels, 384, kernel_size=(1, 1))
        self.branch2_left = conv_1x3()
        self.branch2_right = conv_3x1()

        # Branch 3: 3x3 max-pool (stride 1) followed by a 1x1 projection.
        self.branch3 = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
            Basic_Conv(in_channels, 192, kernel_size=(1, 1)),
        )

        # Branch 4: plain 1x1 convolution.
        self.branch4 = Basic_Conv(in_channels, 320, kernel_size=(1, 1))

    def forward(self, x):
        """Evaluate every branch and concatenate the six outputs on dim 1."""
        stem_a = self.branch1_stem(x)
        pieces = [self.branch1_left(stem_a), self.branch1_right(stem_a)]

        stem_b = self.branch2_stem(x)
        pieces += [self.branch2_left(stem_b), self.branch2_right(stem_b)]

        pieces.append(self.branch3(x))
        pieces.append(self.branch4(x))

        return torch.cat(pieces, dim=1)

In [6]:
class Inception_Red(nn.Module):
    """
        Reduction module: three parallel branches that each end in a stride-2,
        unpadded 3x3 operation, concatenated along the channel dimension.

        Output channels are (96 + add_ch) + (384 + add_ch) + in_channels; the
        pooling branch passes the input channels through unchanged.

        Args:
            in_channels: channel count of the incoming feature map.
            red_1x1: width of the 1x1 reduction at the start of both
                convolutional branches.
            add_ch: extra channels added to the last convolution of each
                convolutional branch (default 0).
    """
    def __init__(self, in_channels, red_1x1, add_ch=0):
        super().__init__()

        # 1x1 reduction -> 3x3 (stride 1) -> 3x3 (stride 2).
        self.branch1 = nn.Sequential(
            Basic_Conv(in_channels, red_1x1, kernel_size=1),
            Basic_Conv(red_1x1, 96, kernel_size=3, stride=1, padding=1),
            Basic_Conv(96, 96 + add_ch, kernel_size=3, stride=2, padding=0),
        )
        # 1x1 reduction -> 3x3 (stride 2).
        self.branch2 = nn.Sequential(
            Basic_Conv(in_channels, red_1x1, kernel_size=1),
            Basic_Conv(red_1x1, 384 + add_ch, kernel_size=3, stride=2, padding=0),
        )
        # 3x3 max-pool (stride 2), no learnable parameters.
        self.branch3 = nn.MaxPool2d(kernel_size=3, stride=2, padding=0)

    def forward(self, x):
        """Evaluate the three branches in order and concatenate them on dim 1."""
        branches = (self.branch1, self.branch2, self.branch3)
        return torch.cat([branch(x) for branch in branches], dim=1)

In [7]:
class InceptionAux(nn.Module):
    """
        From the paper, auxiliary classifier.

        The input is adaptively average-pooled to a 4x4 grid, projected to 128
        channels by a 1x1 convolution, flattened (128 * 4 * 4 = 2048 features)
        and passed through two fully connected layers with dropout in between.

        Args:
            in_channels: channel count of the incoming feature map.
            num_classes: size of the output of the final linear layer.
    """
    def __init__(self, in_channels, num_classes):
        super().__init__()

        self.pool = nn.AdaptiveAvgPool2d(output_size=(4, 4))
        self.conv = nn.Conv2d(in_channels, 128, kernel_size=1, stride=1, padding=0)
        # One ReLU instance is reused after the convolution and after fc1.
        self.act = nn.ReLU()
        self.fc1 = nn.Linear(in_features=2048, out_features=1024)
        self.dropout = nn.Dropout(p=0.7)
        self.fc2 = nn.Linear(in_features=1024, out_features=num_classes)

    def forward(self, x):
        """Return the auxiliary class scores for the feature map ``x``."""
        features = self.act(self.conv(self.pool(x)))
        hidden = self.act(self.fc1(torch.flatten(features, 1)))
        return self.fc2(self.dropout(hidden))

In [8]:
class InceptionV3(nn.Module):
    """
        Inception-v3 style classifier assembled from Basic_Conv, Inception_F5,
        Inception_F6, Inception_F7, Inception_Red and InceptionAux.

        Layout (channel counts follow from the constructor arguments below):
            * stem: conv1..conv6 with one max-pool, ending in 288 channels
            * three Inception_F5 modules (288 -> 288), each with its OWN weights
            * red_a reduction (288 -> 768)
            * five Inception_F6 modules (768 -> 768)
            * auxiliary classifier fed from the output of inception_6e
            * red_b reduction (768 -> 1280)
            * two Inception_F7 modules (1280 -> 2048, 2048 -> 2048)
            * global average pool, dropout (p=0.4) and a linear classifier

        Note: the three Figure-5 stages are registered as ``inception_5a``,
        ``inception_5b`` and ``inception_5c``. A previous revision called a
        single ``inception_5`` module three times, which tied the weights of
        all three stages together; checkpoints saved from that revision use
        different ``state_dict`` keys.

        Args:
            in_channels: number of channels of the input image (default 3).
            num_classes: number of output classes for both the main and the
                auxiliary classifier (default 10).
    """
    def __init__(self, in_channels=3, num_classes=10):
        super().__init__()

        # Stem.
        self.conv1 = Basic_Conv(in_channels=in_channels, out_channels=32, kernel_size=(3, 3), stride=2, padding=0)
        self.conv2 = Basic_Conv(32, 32, 3)
        self.conv3 = Basic_Conv(32, 64, 3, 1, 1)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2)
        self.conv4 = Basic_Conv(64, 80, 3)
        self.conv5 = Basic_Conv(80, 192, 3, 2)
        self.conv6 = Basic_Conv(192, 288, 3, 1, 1)

        # Three independent Figure-5 modules. Reusing one module instance three
        # times would share a single set of weights across all three stages.
        self.inception_5a = Inception_F5(288)
        self.inception_5b = Inception_F5(288)
        self.inception_5c = Inception_F5(288)
        self.red_a = Inception_Red(288, red_1x1=64)

        self.inception_6a = Inception_F6(768, red_1x1=128)
        self.inception_6b = Inception_F6(768, red_1x1=160)
        self.inception_6c = Inception_F6(768, red_1x1=160)
        self.inception_6d = Inception_F6(768, red_1x1=160)
        self.inception_6e = Inception_F6(768, red_1x1=192)
        self.red_b = Inception_Red(768, red_1x1=192, add_ch=16)

        # Auxiliary head attached to the 768-channel feature map.
        self.aux = InceptionAux(768, num_classes)

        self.inception_7a = Inception_F7(1280)
        self.inception_7b = Inception_F7(2048)

        # Classifier head.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.dropout = nn.Dropout(0.4)
        self.fc = nn.Linear(in_features=2048, out_features=num_classes)

    def forward(self, x):
        """
            Run the network on a batch of images.

            Returns:
                A tuple ``(logits, aux_logits)``: the output of the main
                classifier and the output of the auxiliary classifier. Both
                are returned on every call, regardless of training mode.
        """
        # Stem.
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.maxpool(x)
        x = self.conv4(x)
        x = self.conv5(x)
        x = self.conv6(x)

        # Three distinct Figure-5 stages.
        x = self.inception_5a(x)
        x = self.inception_5b(x)
        x = self.inception_5c(x)

        x = self.red_a(x)

        x = self.inception_6a(x)
        x = self.inception_6b(x)
        x = self.inception_6c(x)
        x = self.inception_6d(x)
        x = self.inception_6e(x)

        # The auxiliary head reads the feature map before the second reduction.
        aux = self.aux(x)

        x = self.red_b(x)

        x = self.inception_7a(x)
        x = self.inception_7b(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.dropout(x)
        x = self.fc(x)

        return x, aux

In [9]:
from torchsummary import summary # utility that prints a layer-by-layer summary of the designed model

# Build the network with 1000 output classes purely to inspect its structure.
debug_model = InceptionV3(in_channels = 3, num_classes=1000)

# Print per-layer output shapes and parameter counts for one 3-channel input on the CPU.
# NOTE(review): the Inception-v3 paper uses 299x299 inputs; 229 here may be a typo - confirm.
summary(debug_model, input_size=(3, 229, 229), device='cpu')

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 32, 114, 114]             896
       BatchNorm2d-2         [-1, 32, 114, 114]              64
              ReLU-3         [-1, 32, 114, 114]               0
        Basic_Conv-4         [-1, 32, 114, 114]               0
            Conv2d-5         [-1, 32, 112, 112]           9,248
       BatchNorm2d-6         [-1, 32, 112, 112]              64
              ReLU-7         [-1, 32, 112, 112]               0
        Basic_Conv-8         [-1, 32, 112, 112]               0
            Conv2d-9         [-1, 64, 112, 112]          18,496
      BatchNorm2d-10         [-1, 64, 112, 112]             128
             ReLU-11         [-1, 64, 112, 112]               0
       Basic_Conv-12         [-1, 64, 112, 112]               0
        MaxPool2d-13           [-1, 64, 55, 55]               0
           Conv2d-14           [-1, 80,