In [None]:
import torch
from torch import nn

class BasicConv2d(nn.Module):
    """Convolution followed by batch normalisation and ReLU.

    Args:
        in_channels: Number of channels of the incoming feature map.
        out_channels: Number of channels produced by the convolution.
        **kwargs: Forwarded unchanged to ``nn.Conv2d`` (kernel size, stride,
            padding, ...).
    """

    def __init__(self, in_channels, out_channels, **kwargs):
        super().__init__()
        # The convolution is created without a bias term; the batch-norm
        # layer that follows provides its own learnable shift.
        self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
        self.bn = nn.BatchNorm2d(out_channels, eps=0.001)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return ``relu(bn(conv(x)))``."""
        return self.relu(self.bn(self.conv(x)))

class InceptionA(nn.Module):
    """Four-branch block: 1x1, 1x1->5x5, 1x1->3x3->3x3, avg-pool->1x1.

    Every convolution and the pooling use stride 1 with padding that keeps
    the spatial size, so the branch outputs can be concatenated along the
    channel axis. The result has ``64 + 64 + 96 + pool_features`` channels.

    Args:
        in_channels: Channels of the input feature map.
        pool_features: Output channels of the pooling branch.
    """

    def __init__(self, in_channels, pool_features):
        super().__init__()
        self.branch1x1 = BasicConv2d(in_channels, 64, kernel_size=1)

        layers_5x5 = [
            BasicConv2d(in_channels, 48, kernel_size=1),
            BasicConv2d(48, 64, kernel_size=5, padding=2),
        ]
        self.branch5x5 = nn.Sequential(*layers_5x5)

        layers_3x3dbl = [
            BasicConv2d(in_channels, 64, kernel_size=1),
            BasicConv2d(64, 96, kernel_size=3, padding=1),
            BasicConv2d(96, 96, kernel_size=3, padding=1),
        ]
        self.branch3x3dbl = nn.Sequential(*layers_3x3dbl)

        self.avg_pool = nn.AvgPool2d(kernel_size=3, stride=1, padding=1)
        self.branch_pool = BasicConv2d(in_channels, pool_features, kernel_size=1)

    def forward(self, x):
        """Run all branches on ``x`` and concatenate them on dim 1."""
        branch_outputs = [
            self.branch1x1(x),
            self.branch5x5(x),
            self.branch3x3dbl(x),
            self.branch_pool(self.avg_pool(x)),
        ]
        return torch.cat(branch_outputs, 1)


class InceptionB(nn.Module):
    """Grid-reduction block with three stride-2 branches.

    Branches: a strided 3x3 convolution (384 channels), a 1x1->3x3->strided
    3x3 stack (96 channels) and a strided 3x3 max-pool (keeps
    ``in_channels``). The output therefore has ``384 + 96 + in_channels``
    channels.

    Args:
        in_channels: Channels of the input feature map.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.branch3x3 = BasicConv2d(in_channels, 384, kernel_size=3, stride=2)

        layers_3x3dbl = [
            BasicConv2d(in_channels, 64, kernel_size=1),
            BasicConv2d(64, 96, kernel_size=3, padding=1),
            BasicConv2d(96, 96, kernel_size=3, stride=2),
        ]
        self.branch3x3dbl = nn.Sequential(*layers_3x3dbl)

        self.max_pool = nn.MaxPool2d(kernel_size=3, stride=2)

    def forward(self, x):
        """Apply the three branches to ``x`` and concatenate on dim 1."""
        branch_outputs = [
            self.branch3x3(x),
            self.branch3x3dbl(x),
            self.max_pool(x),
        ]
        return torch.cat(branch_outputs, 1)

class InceptionC(nn.Module):
    """Block built from factorised 7x7 convolutions (1x7 and 7x1 pairs).

    Each of the four branches ends with 192 channels, so the concatenated
    output has 768 channels. All layers use stride 1 with size-preserving
    padding.

    Args:
        in_channels: Channels of the input feature map.
        channels_7x7: Width of the intermediate layers in the 7x7 branches.
    """

    def __init__(self, in_channels, channels_7x7):
        super().__init__()
        mid = channels_7x7
        self.branch1x1 = BasicConv2d(in_channels, 192, kernel_size=1)

        # 1x1 reduction, then a single 1x7 / 7x1 pair.
        layers_7x7 = [
            BasicConv2d(in_channels, mid, kernel_size=1),
            BasicConv2d(mid, mid, kernel_size=(1, 7), padding=(0, 3)),
            BasicConv2d(mid, 192, kernel_size=(7, 1), padding=(3, 0)),
        ]
        self.branch7x7 = nn.Sequential(*layers_7x7)

        # 1x1 reduction, then two 7x1 / 1x7 pairs.
        layers_7x7dbl = [
            BasicConv2d(in_channels, mid, kernel_size=1),
            BasicConv2d(mid, mid, kernel_size=(7, 1), padding=(3, 0)),
            BasicConv2d(mid, mid, kernel_size=(1, 7), padding=(0, 3)),
            BasicConv2d(mid, mid, kernel_size=(7, 1), padding=(3, 0)),
            BasicConv2d(mid, 192, kernel_size=(1, 7), padding=(0, 3)),
        ]
        self.branch7x7dbl = nn.Sequential(*layers_7x7dbl)

        self.avg_pool = nn.AvgPool2d(kernel_size=3, stride=1, padding=1)
        self.branch_pool = BasicConv2d(in_channels, 192, kernel_size=1)

    def forward(self, x):
        """Run all branches on ``x`` and concatenate them on dim 1."""
        branch_outputs = [
            self.branch1x1(x),
            self.branch7x7(x),
            self.branch7x7dbl(x),
            self.branch_pool(self.avg_pool(x)),
        ]
        return torch.cat(branch_outputs, 1)

class InceptionD(nn.Module):
    """Second grid-reduction block; every branch ends with stride 2.

    Branches: 1x1->strided 3x3 (320 channels), 1x1->1x7->7x1->strided 3x3
    (192 channels) and a strided 3x3 max-pool (keeps ``in_channels``). The
    output has ``320 + 192 + in_channels`` channels.

    Args:
        in_channels: Channels of the input feature map.
    """

    def __init__(self, in_channels):
        super().__init__()
        layers_3x3 = [
            BasicConv2d(in_channels, 192, kernel_size=1),
            BasicConv2d(192, 320, kernel_size=3, stride=2),
        ]
        self.branch3x3 = nn.Sequential(*layers_3x3)

        layers_7x7x3 = [
            BasicConv2d(in_channels, 192, kernel_size=1),
            BasicConv2d(192, 192, kernel_size=(1, 7), padding=(0, 3)),
            BasicConv2d(192, 192, kernel_size=(7, 1), padding=(3, 0)),
            BasicConv2d(192, 192, kernel_size=3, stride=2),
        ]
        self.branch7x7x3 = nn.Sequential(*layers_7x7x3)

        self.max_pool = nn.MaxPool2d(kernel_size=3, stride=2)

    def forward(self, x):
        """Apply the three branches to ``x`` and concatenate on dim 1."""
        branch_outputs = [
            self.branch3x3(x),
            self.branch7x7x3(x),
            self.max_pool(x),
        ]
        return torch.cat(branch_outputs, 1)


class InceptionE(nn.Module):
    """Block whose 3x3 branches fan out into parallel 1x3 and 3x1 convs.

    Output channels: 320 (1x1) + 2*384 (3x3 branch) + 2*384 (double 3x3
    branch) + 192 (pool branch) = 2048. All layers keep the spatial size.

    Args:
        in_channels: Channels of the input feature map.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.branch1x1 = BasicConv2d(in_channels, 320, kernel_size=1)

        # Shared 1x1 stem feeding a 1x3 head and a 3x1 head.
        self.branch3x3 = BasicConv2d(in_channels, 384, kernel_size=1)
        self.branch3x3_2a = BasicConv2d(384, 384, kernel_size=(1, 3), padding=(0, 1))
        self.branch3x3_2b = BasicConv2d(384, 384, kernel_size=(3, 1), padding=(1, 0))

        # Shared 1x1 -> 3x3 stem feeding a 1x3 head and a 3x1 head.
        layers_3x3dbl = [
            BasicConv2d(in_channels, 448, kernel_size=1),
            BasicConv2d(448, 384, kernel_size=3, padding=1),
        ]
        self.branch3x3dbl = nn.Sequential(*layers_3x3dbl)
        self.branch3x3dbl_2a = BasicConv2d(384, 384, kernel_size=(1, 3), padding=(0, 1))
        self.branch3x3dbl_2b = BasicConv2d(384, 384, kernel_size=(3, 1), padding=(1, 0))

        self.avg_pool = nn.AvgPool2d(kernel_size=3, stride=1, padding=1)
        self.branch_pool = BasicConv2d(in_channels, 192, kernel_size=1)

    def forward(self, x):
        """Run all branches on ``x`` and concatenate them on dim 1."""
        out_1x1 = self.branch1x1(x)

        stem_3x3 = self.branch3x3(x)
        heads_3x3 = [self.branch3x3_2a(stem_3x3), self.branch3x3_2b(stem_3x3)]
        out_3x3 = torch.cat(heads_3x3, 1)

        stem_dbl = self.branch3x3dbl(x)
        heads_dbl = [self.branch3x3dbl_2a(stem_dbl), self.branch3x3dbl_2b(stem_dbl)]
        out_dbl = torch.cat(heads_dbl, 1)

        out_pool = self.branch_pool(self.avg_pool(x))

        return torch.cat([out_1x1, out_3x3, out_dbl, out_pool], 1)

class InceptionAux(nn.Module):
    """Auxiliary classifier head.

    Pipeline: 5x5 average pool (stride 3) -> 1x1 conv to 128 channels ->
    5x5 conv to 768 channels -> adaptive average pool to 1x1 -> flatten ->
    linear layer to ``num_classes``.

    Args:
        in_channels: Channels of the input feature map.
        num_classes: Size of the output logit vector.
    """

    def __init__(self, in_channels, num_classes):
        super().__init__()
        self.avg_pool = nn.AvgPool2d(kernel_size=5, stride=3)
        self.conv0 = BasicConv2d(in_channels, 128, kernel_size=1)
        self.conv1 = BasicConv2d(128, 768, kernel_size=5)
        self.fc = nn.Linear(768, num_classes)
        self.adaptive_avg_pool = nn.AdaptiveAvgPool2d((1, 1))

    def forward(self, x):
        """Return auxiliary logits of shape ``N x num_classes``.

        The shape comments below assume an ``N x 768 x 17 x 17`` input.
        """
        pooled = self.avg_pool(x)            # N x 768 x 5 x 5
        reduced = self.conv0(pooled)         # N x 128 x 5 x 5
        features = self.conv1(reduced)       # N x 768 x 1 x 1
        # The adaptive pool pins the spatial size to 1x1 for other input sizes.
        features = self.adaptive_avg_pool(features)
        flat = torch.flatten(features, 1)    # N x 768
        return self.fc(flat)                 # N x num_classes
    
class InceptionV3(nn.Module):
    """Inception-v3 classifier assembled from the blocks defined above.

    Args:
        num_classes: Size of the output logit vector.
        aux_logits: When true, an ``InceptionAux`` head is attached after
            ``Mixed_6e``; it is only evaluated in training mode.
        dropout: Drop probability applied before the final linear layer.

    ``forward`` returns a ``(logits, aux)`` tuple. ``aux`` is ``None`` when
    the auxiliary head is disabled or the module is in eval mode.
    """

    def __init__(self, num_classes = 1000, aux_logits = True, dropout = 0.5):
        super().__init__()
        self.aux_logits = aux_logits

        # Stem: plain convolutions and max-pools on the 3-channel input.
        self.Conv2d_1a_3x3 = BasicConv2d(3, 32, kernel_size=3, stride=2)
        self.Conv2d_2a_3x3 = BasicConv2d(32, 32, kernel_size=3)
        self.Conv2d_2b_3x3 = BasicConv2d(32, 64, kernel_size=3, padding=1)
        self.maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2)
        self.Conv2d_3b_1x1 = BasicConv2d(64, 80, kernel_size=1)
        self.Conv2d_4a_3x3 = BasicConv2d(80, 192, kernel_size=3)
        self.maxpool2 = nn.MaxPool2d(kernel_size=3, stride=2)

        self.Mixed_5b = InceptionA(192, pool_features=32)
        self.Mixed_5c = InceptionA(256, pool_features=64)
        self.Mixed_5d = InceptionA(288, pool_features=64)

        self.Mixed_6a = InceptionB(288)

        self.Mixed_6b = InceptionC(768, channels_7x7=128)
        self.Mixed_6c = InceptionC(768, channels_7x7=160)
        self.Mixed_6d = InceptionC(768, channels_7x7=160)
        self.Mixed_6e = InceptionC(768, channels_7x7=192)

        # Always define the attribute: forward() checks it unconditionally,
        # and nn.Module raises AttributeError for attributes never assigned.
        self.AuxLogits = InceptionAux(768, num_classes) if aux_logits else None

        self.Mixed_7a = InceptionD(768)
        self.Mixed_7b = InceptionE(1280)
        self.Mixed_7c = InceptionE(2048)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.dropout = nn.Dropout(p=dropout)
        self.fc = nn.Linear(2048, num_classes)

    def forward(self, x):
        """Compute ``(logits, aux)`` for a batch of 3-channel images.

        The shape comments assume an ``N x 3 x 299 x 299`` input.
        """
        # N x 3 x 299 x 299
        x = self.Conv2d_1a_3x3(x)
        # N x 32 x 149 x 149
        x = self.Conv2d_2a_3x3(x)
        # N x 32 x 147 x 147
        x = self.Conv2d_2b_3x3(x)
        # N x 64 x 147 x 147
        x = self.maxpool1(x)
        # N x 64 x 73 x 73
        x = self.Conv2d_3b_1x1(x)
        # N x 80 x 73 x 73
        x = self.Conv2d_4a_3x3(x)
        # N x 192 x 71 x 71
        x = self.maxpool2(x)

        # N x 192 x 35 x 35
        x = self.Mixed_5b(x)
        # N x 256 x 35 x 35
        x = self.Mixed_5c(x)
        # N x 288 x 35 x 35
        x = self.Mixed_5d(x)

        # N x 288 x 35 x 35
        x = self.Mixed_6a(x)

        # N x 768 x 17 x 17
        x = self.Mixed_6b(x)
        # N x 768 x 17 x 17
        x = self.Mixed_6c(x)
        # N x 768 x 17 x 17
        x = self.Mixed_6d(x)
        # N x 768 x 17 x 17
        x = self.Mixed_6e(x)

        # N x 768 x 17 x 17
        # The auxiliary head only runs while training and only if it exists.
        aux = None
        if self.AuxLogits is not None and self.training:
            aux = self.AuxLogits(x)

        # N x 768 x 17 x 17
        x = self.Mixed_7a(x)

        # N x 1280 x 8 x 8
        x = self.Mixed_7b(x)
        # N x 2048 x 8 x 8
        x = self.Mixed_7c(x)

        # N x 2048 x 8 x 8
        # Adaptive average pooling
        x = self.avgpool(x)
        # N x 2048 x 1 x 1
        x = self.dropout(x)
        # N x 2048 x 1 x 1
        x = torch.flatten(x, 1)
        # N x 2048
        x = self.fc(x)
        # N x num_classes
        return x, aux

In [14]:
from torchinfo import summary

# Build the default InceptionV3, move it to the CUDA device and show a
# torchinfo summary for a batch of two 3 x 299 x 299 inputs.
model = InceptionV3().cuda()
summary(model, input_size=(2, 3, 299, 299))

Layer (type:depth-idx)                   Output Shape              Param #
InceptionV3                              [2, 1000]                 3,326,696
├─BasicConv2d: 1-1                       [2, 32, 149, 149]         --
│    └─Conv2d: 2-1                       [2, 32, 149, 149]         864
│    └─BatchNorm2d: 2-2                  [2, 32, 149, 149]         64
│    └─ReLU: 2-3                         [2, 32, 149, 149]         --
├─BasicConv2d: 1-2                       [2, 32, 147, 147]         --
│    └─Conv2d: 2-4                       [2, 32, 147, 147]         9,216
│    └─BatchNorm2d: 2-5                  [2, 32, 147, 147]         64
│    └─ReLU: 2-6                         [2, 32, 147, 147]         --
├─BasicConv2d: 1-3                       [2, 64, 147, 147]         --
│    └─Conv2d: 2-7                       [2, 64, 147, 147]         18,432
│    └─BatchNorm2d: 2-8                  [2, 64, 147, 147]         128
│    └─ReLU: 2-9                         [2, 64, 147, 147]         --

In [None]:
import torch
from torch import nn

In [None]:
class BasicConv2d(nn.Module):
    """Bias-free convolution followed by batch normalisation and ReLU.

    Args:
        in_channels: Number of channels of the incoming feature map.
        out_channels: Number of channels produced by the convolution.
        **kwargs: Forwarded unchanged to ``nn.Conv2d`` (kernel size, stride,
            padding, ...).
    """

    def __init__(self, in_channels, out_channels, **kwargs):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
        self.batchnorm = nn.BatchNorm2d(out_channels, eps=0.001)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return ``relu(batchnorm(conv(x)))``."""
        normalised = self.batchnorm(self.conv(x))
        return self.relu(normalised)

class InceptionF5(nn.Module):  # Figure 5
    """Four-branch block made only of 1x1 and 3x3 convolutions.

    Branches: 1x1->3x3->3x3 (96 channels), 1x1->3x3 (64), max-pool->1x1
    (64) and a plain 1x1 (64). All layers keep the spatial size, and the
    concatenated output has 96 + 64 * 3 = 288 channels.

    Args:
        in_channels: Channels of the input feature map.
    """

    def __init__(self, in_channels):
        super().__init__()

        double_3x3 = [
            BasicConv2d(in_channels, 64, kernel_size=1),
            BasicConv2d(64, 96, kernel_size=3, padding=1),
            BasicConv2d(96, 96, kernel_size=3, padding=1),
        ]
        self.branch1 = nn.Sequential(*double_3x3)

        single_3x3 = [
            BasicConv2d(in_channels, 48, kernel_size=1),
            BasicConv2d(48, 64, kernel_size=3, padding=1),
        ]
        self.branch2 = nn.Sequential(*single_3x3)

        pooled_1x1 = [
            nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
            BasicConv2d(in_channels, 64, kernel_size=1),
        ]
        self.branch3 = nn.Sequential(*pooled_1x1)

        self.branch4 = BasicConv2d(in_channels, 64, kernel_size=1)

    def forward(self, x):
        """Run the four branches on ``x`` and concatenate them on dim 1."""
        branches = (self.branch1, self.branch2, self.branch3, self.branch4)
        return torch.cat([branch(x) for branch in branches], dim=1)

class InceptionF6(nn.Module):  # Figure 6
    """Four-branch block using factorised 7x7 convolutions (1x7 and 7x1).

    Every branch ends with 192 channels, so the concatenated output has
    192 * 4 = 768 channels. All layers keep the spatial size.

    Args:
        in_channels: Channels of the input feature map.
        f_7x7: Width of the intermediate layers in the two 7x7 branches.
    """

    def __init__(self, in_channels, f_7x7):
        super().__init__()
        width = f_7x7

        # 1x1 reduction followed by two 1x7 / 7x1 pairs.
        double_7x7 = [
            BasicConv2d(in_channels, width, kernel_size=1),
            BasicConv2d(width, width, kernel_size=(1, 7), padding=(0, 3)),
            BasicConv2d(width, width, kernel_size=(7, 1), padding=(3, 0)),
            BasicConv2d(width, width, kernel_size=(1, 7), padding=(0, 3)),
            BasicConv2d(width, 192, kernel_size=(7, 1), padding=(3, 0)),
        ]
        self.branch1 = nn.Sequential(*double_7x7)

        # 1x1 reduction followed by one 1x7 / 7x1 pair.
        single_7x7 = [
            BasicConv2d(in_channels, width, kernel_size=1),
            BasicConv2d(width, width, kernel_size=(1, 7), padding=(0, 3)),
            BasicConv2d(width, 192, kernel_size=(7, 1), padding=(3, 0)),
        ]
        self.branch2 = nn.Sequential(*single_7x7)

        pooled_1x1 = [
            nn.MaxPool2d(3, stride=1, padding=1),
            BasicConv2d(in_channels, 192, kernel_size=1),
        ]
        self.branch3 = nn.Sequential(*pooled_1x1)

        self.branch4 = BasicConv2d(in_channels, 192, kernel_size=1)

    def forward(self, x):
        """Run the four branches on ``x`` and concatenate them on dim 1."""
        branches = (self.branch1, self.branch2, self.branch3, self.branch4)
        return torch.cat([branch(x) for branch in branches], dim=1)

class InceptionF7(nn.Module):  # Figure 7
    """Block whose two convolution branches split into 1x3 and 3x1 heads.

    Output channels: 2*384 (branch 1) + 2*384 (branch 2) + 192 (pool
    branch) + 320 (1x1 branch) = 2048. All layers keep the spatial size.

    Args:
        in_channels: Channels of the input feature map.
    """

    def __init__(self, in_channels):
        super().__init__()

        # Branch 1: 1x1 -> 3x3 stem, then parallel 1x3 and 3x1 heads.
        stem_layers = [
            BasicConv2d(in_channels, 448, kernel_size=1),
            BasicConv2d(448, 384, kernel_size=3, padding=1),
        ]
        self.branch1_stem = nn.Sequential(*stem_layers)
        self.branch1_left = BasicConv2d(384, 384, kernel_size=(1, 3), padding=(0, 1))
        self.branch1_right = BasicConv2d(384, 384, kernel_size=(3, 1), padding=(1, 0))

        # Branch 2: 1x1 stem, then parallel 1x3 and 3x1 heads.
        self.branch2_stem = BasicConv2d(in_channels, 384, kernel_size=1)
        self.branch2_left = BasicConv2d(384, 384, kernel_size=(1, 3), padding=(0, 1))
        self.branch2_right = BasicConv2d(384, 384, kernel_size=(3, 1), padding=(1, 0))

        pooled_1x1 = [
            nn.MaxPool2d(3, stride=1, padding=1),
            BasicConv2d(in_channels, 192, kernel_size=1),
        ]
        self.branch3 = nn.Sequential(*pooled_1x1)

        self.branch4 = BasicConv2d(in_channels, 320, kernel_size=1)

    def forward(self, x):
        """Run all branches on ``x`` and concatenate them on dim 1."""
        stem1 = self.branch1_stem(x)
        stem2 = self.branch2_stem(x)

        heads1 = [self.branch1_left(stem1), self.branch1_right(stem1)]
        heads2 = [self.branch2_left(stem2), self.branch2_right(stem2)]

        outputs = [
            torch.cat(heads1, dim=1),
            torch.cat(heads2, dim=1),
            self.branch3(x),
            self.branch4(x),
        ]
        return torch.cat(outputs, dim=1)

class Inception_ReduceA(nn.Module):  # Figure 10 : conv (stride 2) -> pooling operation
    """Grid-reduction block with three stride-2 branches.

    Implementations found online differ slightly from one another, so the
    layer layout follows the PyTorch source code.

    Output channels: 96 (1x1->3x3->strided 3x3) + 384 (strided 3x3) +
    ``in_channels`` (strided max-pool).

    Args:
        in_channels: Channels of the input feature map.
    """

    def __init__(self, in_channels):
        super().__init__()

        double_3x3 = [
            BasicConv2d(in_channels, 64, kernel_size=1),
            BasicConv2d(64, 96, kernel_size=3, padding=1),
            BasicConv2d(96, 96, kernel_size=3, stride=2),
        ]
        self.branch1 = nn.Sequential(*double_3x3)

        self.branch2 = BasicConv2d(in_channels, 384, kernel_size=3, stride=2)
        self.branch3 = nn.MaxPool2d(3, stride=2)

    def forward(self, x):
        """Apply the three branches to ``x`` and concatenate on dim 1."""
        branches = (self.branch1, self.branch2, self.branch3)
        return torch.cat([branch(x) for branch in branches], dim=1)

class Inception_ReduceB(nn.Module):  # Figure 10 : conv (stride 2) -> pooling operation
    """Second grid-reduction block; every branch ends with stride 2.

    Implementations found online differ slightly from one another, so the
    layer layout follows the PyTorch source code.

    Output channels: 192 (1x1->1x7->7x1->strided 3x3) + 320 (1x1->strided
    3x3) + ``in_channels`` (strided max-pool).

    Args:
        in_channels: Channels of the input feature map.
    """

    def __init__(self, in_channels):
        super().__init__()

        factorised_7x7 = [
            BasicConv2d(in_channels, 192, kernel_size=1),
            BasicConv2d(192, 192, kernel_size=(1, 7), padding=(0, 3)),
            BasicConv2d(192, 192, kernel_size=(7, 1), padding=(3, 0)),
            BasicConv2d(192, 192, kernel_size=3, stride=2),
        ]
        self.branch1 = nn.Sequential(*factorised_7x7)

        strided_3x3 = [
            BasicConv2d(in_channels, 192, kernel_size=1),
            BasicConv2d(192, 320, kernel_size=3, stride=2),
        ]
        self.branch2 = nn.Sequential(*strided_3x3)

        self.branch3 = nn.MaxPool2d(3, stride=2)

    def forward(self, x):
        """Apply the three branches to ``x`` and concatenate on dim 1."""
        branches = (self.branch1, self.branch2, self.branch3)
        return torch.cat([branch(x) for branch in branches], dim=1)

class Inception_Aux(nn.Module):
    """Auxiliary classifier head.

    Pipeline: adaptive average pool to 5x5 -> 1x1 conv to 128 channels ->
    flatten -> linear (3200 -> 1024) -> ReLU -> dropout (p=0.7) -> linear
    (1024 -> ``num_classes``).

    Args:
        in_channels: Channels of the input feature map.
        num_classes: Size of the output logit vector.
    """

    def __init__(self, in_channels, num_classes):
        super().__init__()

        # The paper uses nn.AvgPool2d(kernel_size = 5, stride = 3) here; the
        # adaptive pool fixes the output at 5x5 instead.
        self.avgpool = nn.AdaptiveAvgPool2d((5, 5))
        self.conv = BasicConv2d(in_channels, 128, kernel_size=1)
        self.fc1 = nn.Linear(5 * 5 * 128, 1024)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.7)
        self.fc2 = nn.Linear(1024, num_classes)

    def forward(self, x):
        """Return auxiliary logits of shape ``N x num_classes``.

        The shape comments below assume an ``N x 768 x 17 x 17`` input.
        """
        pooled = self.avgpool(x)                # -> N x 768 x 5 x 5
        reduced = self.conv(pooled)             # -> N x 128 x 5 x 5
        flat = torch.flatten(reduced, 1)        # -> N x 3200

        hidden = self.relu(self.fc1(flat))      # -> N x 1024
        hidden = self.dropout(hidden)
        return self.fc2(hidden)                 # -> N x num_classes

  
class Inception_V3(nn.Module):
    """Inception-v3 style classifier assembled from the blocks defined above.

    Args:
        num_classes: Size of the output logit vector.
        use_aux: When true, an ``Inception_Aux`` head is attached after
            ``inception4e``; it is only evaluated in training mode.
        drop_p: Drop probability applied before the final linear layer.

    ``forward`` returns a ``(logits, aux)`` tuple. ``aux`` is ``None`` when
    the auxiliary head is disabled or the module is in eval mode.
    """

    def __init__(self, num_classes = 1000, use_aux = True, drop_p = 0.5):
        super().__init__()
        in_channels = 3 #RGB

        self.conv1a = BasicConv2d(in_channels, 32, kernel_size = 3, stride = 2)
        self.conv1b = BasicConv2d(32, 32, kernel_size = 3)
        self.conv1c = BasicConv2d(32, 64, kernel_size = 3, padding = 1)

        self.pool1 = nn.MaxPool2d(3, stride = 2)

        self.conv2a = BasicConv2d(64, 80, kernel_size = 3)
        self.conv2b = BasicConv2d(80, 192, kernel_size = 3, stride = 2)
        self.conv2c = BasicConv2d(192, 288, kernel_size = 3, padding = 1)

        self.inception3a = InceptionF5(288)
        self.inception3b = InceptionF5(288)
        self.inception3c = InceptionF5(288)

        self.inception_red1 = Inception_ReduceA(288)

        self.inception4a = InceptionF6(768, f_7x7 = 128)
        self.inception4b = InceptionF6(768, f_7x7 = 160)
        self.inception4c = InceptionF6(768, f_7x7 = 160)
        self.inception4d = InceptionF6(768, f_7x7 = 160)
        self.inception4e = InceptionF6(768, f_7x7 = 192)

        # Always define the attribute: forward() checks it unconditionally,
        # and nn.Module raises AttributeError for attributes never assigned.
        self.aux = Inception_Aux(768, num_classes = num_classes) if use_aux else None

        self.inception_red2 = Inception_ReduceB(768)

        self.inception5a = InceptionF7(1280)
        self.inception5b = InceptionF7(2048)

        self.pool6 = nn.AdaptiveAvgPool2d((1, 1))
        self.dropout = nn.Dropout(p = drop_p)
        self.fc = nn.Linear(2048, num_classes)

    def forward(self, x):
        """Compute ``(logits, aux)`` for a batch of 3-channel images.

        The shape comments assume an ``N x 3 x 299 x 299`` input.
        """
        x = self.conv1a(x)  # -> N x 32 x 149 x 149
        x = self.conv1b(x)  # -> N x 32 x 147 x 147
        x = self.conv1c(x)  # -> N x 64 x 147 x 147

        x = self.pool1(x)   # -> N x 64 x 73 x 73

        x = self.conv2a(x)  # -> N x 80 x 71 x 71
        x = self.conv2b(x)  # -> N x 192 x 35 x 35
        x = self.conv2c(x)  # -> N x 288 x 35 x 35

        x = self.inception3a(x) # -> N x (96 + 64 * 3) x 35 x 35 = N x 288 x 35 x 35
        x = self.inception3b(x) # -> N x 288 x 35 x 35
        x = self.inception3c(x) # -> N x 288 x 35 x 35

        x = self.inception_red1(x)  # -> N x 768 x 17 x 17

        x = self.inception4a(x) # -> N x (192 * 4) x 17 x 17 = N x 768 x 17 x 17
        x = self.inception4b(x) # -> N x 768 x 17 x 17
        x = self.inception4c(x) # -> N x 768 x 17 x 17
        x = self.inception4d(x) # -> N x 768 x 17 x 17
        x = self.inception4e(x) # -> N x 768 x 17 x 17

        # The auxiliary head only runs while training and only if it exists;
        # aux is bound up front so the return statement never sees an
        # undefined name.
        aux = None
        if self.aux is not None and self.training:
            aux = self.aux(x)

        x = self.inception_red2(x) # -> N x 1280 x 8 x 8

        x = self.inception5a(x) # -> N x (384 * 2 * 2 + 192 + 320) x 8 x 8 = N x 2048 x 8 x 8
        x = self.inception5b(x) # -> N x 2048 x 8 x 8
        x = self.pool6(x)       # -> N x 2048 x 1 x 1

        x = torch.flatten(x, 1) # -> N x 2048

        x = self.dropout(x)
        x = self.fc(x) # -> N x num_classes

        return x, aux

In [6]:
# Build the default Inception_V3, move it to the CUDA device and show a
# summary for a batch of two 3 x 299 x 299 inputs. `summary` is the name
# imported from torchinfo in an earlier cell.
model = Inception_V3().cuda()
summary(model, input_size=(2, 3, 299, 299))

Layer (type:depth-idx)                   Output Shape              Param #
Inception_V3                             [2, 1000]                 4,401,384
├─BasicConv2d: 1-1                       [2, 32, 149, 149]         --
│    └─Conv2d: 2-1                       [2, 32, 149, 149]         864
│    └─BatchNorm2d: 2-2                  [2, 32, 149, 149]         64
│    └─ReLU: 2-3                         [2, 32, 149, 149]         --
├─BasicConv2d: 1-2                       [2, 32, 147, 147]         --
│    └─Conv2d: 2-4                       [2, 32, 147, 147]         9,216
│    └─BatchNorm2d: 2-5                  [2, 32, 147, 147]         64
│    └─ReLU: 2-6                         [2, 32, 147, 147]         --
├─BasicConv2d: 1-3                       [2, 64, 147, 147]         --
│    └─Conv2d: 2-7                       [2, 64, 147, 147]         18,432
│    └─BatchNorm2d: 2-8                  [2, 64, 147, 147]         128
│    └─ReLU: 2-9                         [2, 64, 147, 147]         --