# 파이토치

- Inception-ResNet-v1과 모든 모듈의 구조가 같지만, 1x1 Conv 레이어를 통과하고 나온 값의 채널 수가 다르다. 예를 들어, InceptionResA의 경우 v1은 256, v2는 384이다. 


- Fig.15에 나온 채널 수는 v1 기준이기 때문에 v2 구현 시에는 하나하나 계산해야한다.

In [1]:
import torch
import torch.nn as nn


class ConvBlock(nn.Module):
    def __init__(self, in_ch, out_ch, kernel_size, stride, padding):
        super(ConvBlock, self).__init__()
        
        self.conv = nn.Conv2d(in_ch, out_ch, kernel_size, stride, padding, bias=False)
        self.bn = nn.BatchNorm2d(out_ch)
        self.relu = nn.ReLU()
        
    def forward(self, x):
        
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        
        return x
    
    
class Stem(nn.Module):
    def __init__(self):
        super(Stem, self).__init__()
        
        self.branch1 = nn.Sequential(
            ConvBlock(3, 32, kernel_size=3, stride=2, padding=0),
            ConvBlock(32, 32, kernel_size=3, stride=1, padding=0),
            ConvBlock(32, 64, kernel_size=3, stride=1, padding=1))
        
        self.maxpool_96 = nn.MaxPool2d(kernel_size=3, stride=2, padding=0)
        self.maxpool_192 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        
        self.conv_96 = ConvBlock(64, 96, kernel_size=3, stride=2, padding=0)
        self.conv_192 = ConvBlock(192, 192, kernel_size=3, stride=2, padding=0) # 논문엔 stride = 2 표시가 안된듯
        
        self.branch2_1 = nn.Sequential(
            ConvBlock(160, 64, kernel_size=1, stride=1, padding=0),
            ConvBlock(64, 96, kernel_size=3, stride=1, padding=0))
        
        self.branch2_2 = nn.Sequential(
            ConvBlock(160, 64, kernel_size=1, stride=1, padding=0),
            ConvBlock(64, 64, kernel_size=(7,1), stride=1, padding=(3,0)),
            ConvBlock(64, 64, kernel_size=(1,7), stride=1, padding=(0,3)),
            ConvBlock(64, 96, kernel_size=3, stride=1, padding=0))
        
    def forward(self, x):
        
        x = self.branch1(x)
        
        x1_1 = self.maxpool_96(x)
        x1_2 = self.conv_96(x)
        
        x = torch.cat([x1_1, x1_2], dim=1)
        
        x2_1 = self.branch2_1(x)
        x2_2 = self.branch2_2(x)
        
        x = torch.cat([x2_1, x2_2], dim=1)
        
        x3_1 = self.conv_192(x)
        x3_2 = self.maxpool_192(x)
        
        return torch.cat([x3_1, x3_2], dim=1) # 아웃풋 채널: 384 = InceptionResA 인풋 채널
    

class InceptionResA(nn.Module):
    def __init__(self, in_ch, scale):
        super(InceptionResA, self).__init__()
        
        self.scaling = scale
        
        self.branch1 = nn.Sequential(
            ConvBlock(in_ch, 32, kernel_size=1, stride=1, padding=0))
        
        self.branch2 = nn.Sequential(
            ConvBlock(in_ch, 32, kernel_size=1, stride=1, padding=0),
            ConvBlock(32, 32, kernel_size=3, stride=1, padding=1))
        
        self.branch3 = nn.Sequential(
            ConvBlock(in_ch, 32, kernel_size=1, stride=1, padding=0),
            ConvBlock(32, 48, kernel_size=3, stride=1, padding=1),
            ConvBlock(48, 64, kernel_size=3, stride=1, padding=1))
        
        # 이 conv1x1 필터의 아웃풋 채널이 인풋으로 들어온 feature map의 채널과 같아야 더해질 수 있음
        self.conv1x1 = nn.Conv2d(128, 384, kernel_size=1, stride=1, padding=0)
        self.relu = nn.ReLU()
        
    def forward(self, x):
        
        pre_x = x
        
        x1 = self.branch1(x)
        x2 = self.branch2(x)
        x3 = self.branch3(x)
        
        x = torch.cat([x1, x2, x3], dim=1)
        
        x = self.conv1x1(x)
        x = self.relu(x)
        
        out = pre_x + x*self.scaling
        out = self.relu(out)
        
        return out # 아웃풋 채널: 384 = ReductionA 인풋 채널
    
    
class ReductionA(nn.Module): # k=256, l=256, m=384, n=384
    def __init__(self, in_ch, k, l, m, n):
        super(ReductionA, self).__init__()
        
        self.branch1 = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=2, padding=0))
        
        self.branch2 = nn.Sequential(
            ConvBlock(in_ch, n, kernel_size=3, stride=2, padding=0))
        
        self.branch3 = nn.Sequential(
            ConvBlock(in_ch, k, kernel_size=1, stride=1, padding=0),
            ConvBlock(k, l, kernel_size=3, stride=1, padding=1),
            ConvBlock(l, m, kernel_size=3, stride=2, padding=0))
        
    def forward(self, x):
        
        x1 = self.branch1(x)
        x2 = self.branch2(x)
        x3 = self.branch3(x)
        
        return torch.cat([x1, x2, x3], dim=1) # 아웃풋 채널: 1152 = InceptionResB 인풋 채널
    
    
class InceptionResB(nn.Module):
    def __init__(self, in_ch, scale):
        super(InceptionResB, self).__init__()
        
        self.scaling = scale
        
        self.branch1 = nn.Sequential(
            ConvBlock(in_ch, 192, kernel_size=1, stride=1, padding=0))
        
        self.branch2 = nn.Sequential(
            ConvBlock(in_ch, 128, kernel_size=1, stride=1, padding=0),
            ConvBlock(128, 160, kernel_size=(1,7), stride=1, padding=(0,3)),
            ConvBlock(160, 192, kernel_size=(7,1), stride=1, padding=(3,0)))
        
        self.conv1x1 = nn.Conv2d(384, 1152, kernel_size=1, stride=1, padding=0)
        self.relu = nn.ReLU()
        
    def forward(self, x):
        
        pre_x = x
        
        x1 = self.branch1(x)
        x2 = self.branch2(x)
        
        x = torch.cat([x1, x2], dim=1)
        
        x = self.conv1x1(x)
        x = self.relu(x)
        
        out = pre_x + x*self.scaling
        out = self.relu(out)
        
        return out # 아웃풋 채널: 1152 = ReductionB 인풋 채널
    
    
class ReductionB(nn.Module):
    def __init__(self, in_ch):
        super(ReductionB, self).__init__()
        
        self.branch1 = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=2, padding=0))
        
        self.branch2 = nn.Sequential(
            ConvBlock(in_ch, 256, kernel_size=1, stride=1, padding=0),
            ConvBlock(256, 384, kernel_size=3, stride=2, padding=0))
        
        self.branch3 = nn.Sequential(
            ConvBlock(in_ch, 256, kernel_size=1, stride=1, padding=0),
            ConvBlock(256, 288, kernel_size=3, stride=2, padding=0))
        
        self.branch4 = nn.Sequential(
            ConvBlock(in_ch, 256, kernel_size=1, stride=1, padding=0),
            ConvBlock(256, 288, kernel_size=3, stride=1, padding=1),
            ConvBlock(288, 320, kernel_size=3, stride=2, padding=0))
        
    def forward(self, x):
        
        x1 = self.branch1(x)
        x2 = self.branch2(x)
        x3 = self.branch3(x)
        x4 = self.branch4(x) 
        
        return torch.cat([x1, x2, x3, x4], dim=1) # 아웃풋 채널: 2144 = InceptionResC 인풋 채널
    

class InceptionResC(nn.Module):
    def __init__(self, in_ch, scale):
        super(InceptionResC, self).__init__()
        
        self.scaling = scale
        
        self.branch1 = nn.Sequential(
            ConvBlock(in_ch, 192, kernel_size=1, stride=1, padding=0))
        
        self.branch2 = nn.Sequential(
            ConvBlock(in_ch, 192, kernel_size=1, stride=1, padding=0),
            ConvBlock(192, 224, kernel_size=(1,3), stride=1, padding=(0,1)),
            ConvBlock(224, 256, kernel_size=(3,1), stride=1, padding=(1,0)))
        
        self.conv1x1 = nn.Conv2d(448, 2144, kernel_size=1, stride=1, padding=0)
        self.relu = nn.ReLU()
        
    def forward(self, x):
        
        pre_x = x
        
        x1 = self.branch1(x)
        x2 = self.branch2(x)
        
        x = torch.cat([x1, x2], dim=1)
        
        x = self.conv1x1(x)
        x = self.relu(x)
        
        out = pre_x + x*self.scaling
        out = self.relu(out)
        
        return out

In [2]:
class Inception_ResNet_V2(nn.Module):
    def __init__(self, num_classes = 1000):
        super(Inception_ResNet_V2, self).__init__()
        
        layers = []
        layers.append(Stem())
        
        for _ in range(5):
            layers.append(InceptionResA(384, 0.17))
            
        layers.append(ReductionA(384, 256, 256, 384, 384))
        
        for _ in range(10):
            layers.append(InceptionResB(1152, 0.1))
            
        layers.append(ReductionB(1152))
        
        for _ in range(5):
            layers.append(InceptionResC(2144, 0.2))
        
        self.feature = nn.Sequential(*layers)        
        
        self.globalavgpool = nn.AdaptiveAvgPool2d((1,1))
        self.dropout = nn.Dropout(0.8)
        self.linear = nn.Linear(2144, num_classes)
        
    def forward(self, x):
        
        x = self.feature(x)
        x = self.globalavgpool(x)
        x = self.dropout(x)
        x = x.view(x.size(0),-1)
        x = self.linear(x)
            
        return x

In [3]:
if __name__ == '__main__':

    from torchsummary import summary
    model = Inception_ResNet_V2()
    summary(model, (3,299,299))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 32, 149, 149]             864
       BatchNorm2d-2         [-1, 32, 149, 149]              64
              ReLU-3         [-1, 32, 149, 149]               0
         ConvBlock-4         [-1, 32, 149, 149]               0
            Conv2d-5         [-1, 32, 147, 147]           9,216
       BatchNorm2d-6         [-1, 32, 147, 147]              64
              ReLU-7         [-1, 32, 147, 147]               0
         ConvBlock-8         [-1, 32, 147, 147]               0
            Conv2d-9         [-1, 64, 147, 147]          18,432
      BatchNorm2d-10         [-1, 64, 147, 147]             128
             ReLU-11         [-1, 64, 147, 147]               0
        ConvBlock-12         [-1, 64, 147, 147]               0
        MaxPool2d-13           [-1, 64, 73, 73]               0
           Conv2d-14           [-1, 96,

     BatchNorm2d-125           [-1, 64, 35, 35]             128
            ReLU-126           [-1, 64, 35, 35]               0
       ConvBlock-127           [-1, 64, 35, 35]               0
          Conv2d-128          [-1, 384, 35, 35]          49,536
            ReLU-129          [-1, 384, 35, 35]               0
            ReLU-130          [-1, 384, 35, 35]               0
   InceptionResA-131          [-1, 384, 35, 35]               0
          Conv2d-132           [-1, 32, 35, 35]          12,288
     BatchNorm2d-133           [-1, 32, 35, 35]              64
            ReLU-134           [-1, 32, 35, 35]               0
       ConvBlock-135           [-1, 32, 35, 35]               0
          Conv2d-136           [-1, 32, 35, 35]          12,288
     BatchNorm2d-137           [-1, 32, 35, 35]              64
            ReLU-138           [-1, 32, 35, 35]               0
       ConvBlock-139           [-1, 32, 35, 35]               0
          Conv2d-140           [-1, 32, 

       ConvBlock-253          [-1, 128, 17, 17]               0
          Conv2d-254          [-1, 160, 17, 17]         143,360
     BatchNorm2d-255          [-1, 160, 17, 17]             320
            ReLU-256          [-1, 160, 17, 17]               0
       ConvBlock-257          [-1, 160, 17, 17]               0
          Conv2d-258          [-1, 192, 17, 17]         215,040
     BatchNorm2d-259          [-1, 192, 17, 17]             384
            ReLU-260          [-1, 192, 17, 17]               0
       ConvBlock-261          [-1, 192, 17, 17]               0
          Conv2d-262         [-1, 1152, 17, 17]         443,520
            ReLU-263         [-1, 1152, 17, 17]               0
            ReLU-264         [-1, 1152, 17, 17]               0
   InceptionResB-265         [-1, 1152, 17, 17]               0
          Conv2d-266          [-1, 192, 17, 17]         221,184
     BatchNorm2d-267          [-1, 192, 17, 17]             384
            ReLU-268          [-1, 192, 

       ConvBlock-381          [-1, 192, 17, 17]               0
          Conv2d-382         [-1, 1152, 17, 17]         443,520
            ReLU-383         [-1, 1152, 17, 17]               0
            ReLU-384         [-1, 1152, 17, 17]               0
   InceptionResB-385         [-1, 1152, 17, 17]               0
          Conv2d-386          [-1, 192, 17, 17]         221,184
     BatchNorm2d-387          [-1, 192, 17, 17]             384
            ReLU-388          [-1, 192, 17, 17]               0
       ConvBlock-389          [-1, 192, 17, 17]               0
          Conv2d-390          [-1, 128, 17, 17]         147,456
     BatchNorm2d-391          [-1, 128, 17, 17]             256
            ReLU-392          [-1, 128, 17, 17]               0
       ConvBlock-393          [-1, 128, 17, 17]               0
          Conv2d-394          [-1, 160, 17, 17]         143,360
     BatchNorm2d-395          [-1, 160, 17, 17]             320
            ReLU-396          [-1, 160, 

     BatchNorm2d-509            [-1, 256, 8, 8]             512
            ReLU-510            [-1, 256, 8, 8]               0
       ConvBlock-511            [-1, 256, 8, 8]               0
          Conv2d-512           [-1, 2144, 8, 8]         962,656
            ReLU-513           [-1, 2144, 8, 8]               0
            ReLU-514           [-1, 2144, 8, 8]               0
   InceptionResC-515           [-1, 2144, 8, 8]               0
          Conv2d-516            [-1, 192, 8, 8]         411,648
     BatchNorm2d-517            [-1, 192, 8, 8]             384
            ReLU-518            [-1, 192, 8, 8]               0
       ConvBlock-519            [-1, 192, 8, 8]               0
          Conv2d-520            [-1, 192, 8, 8]         411,648
     BatchNorm2d-521            [-1, 192, 8, 8]             384
            ReLU-522            [-1, 192, 8, 8]               0
       ConvBlock-523            [-1, 192, 8, 8]               0
          Conv2d-524            [-1, 224

In [4]:
# Find total parameters and trainable parameters
model = Inception_ResNet_V2()
total_params = sum(p.numel() for p in model.parameters())
print(f'{total_params:,} total parameters.')
total_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f'{total_trainable_params:,} training parameters.')

32,433,928 total parameters.
32,433,928 training parameters.
