# 파이토치

- BN을 정확히 어디에 적용을 해야 할까? Inception을 통과한, 즉 1x1 Conv를 통과한 값에 BN을 적용해야 하는지? 아니면 pre_x + x에 BN을 적용해야 하는지? 아니면 ConvBlock 자체에 BN을 적용시켜서 모두 적용해야 하는지?


- Fig.10, 11, 13의 1x1 Conv의 Linear라는게 여기는 BN을 적용하지 않는다는 뜻인듯. 논문에서 "we used batch-normalization only on top of the traditional layers, but not on top of the summations." 라고 나온 부분. traditional layer가 뭔가? 그냥 Conv 레이어를 의미하는건가?


- 깃허브 스타를 제일 많이 받은 코드를 따르면, ConvBlock 자체에 BN을 적용시키지만, 각 모듈의 Inception 부분을 통과하고 1x1 Conv 레이어를 통과시킬 때는 BN 적용 안한다. 그렇다면 traditional layer는 그냥 Conv 레이어를 의미하는듯.


- 모든 ResNet 관련 모듈에서, 인셉션 모듈을 통과한 값인 residual, 즉 1x1 Conv까지 통과한 값을 0.1~0.3의 값을 곱해서 scaling하고 원래 인풋 값과 더해서 relu를 통과시켜준다. (Fig.20과 "3.3: Scaling of the Residuals"을 참고) 


- 스케일링 값은 깃허브를 참고해서 A, B, C 모듈 각각 0.17, 0.1, 0.2 사용.


- BN 관련 참고 깃허브: https://github.com/Cadene/pretrained-models.pytorch/blob/master/pretrainedmodels/models/inceptionresnetv2.py


- Reduction B는 두가지 있음. Fig.12는 smaller, Fig.18은 wider 버전에 쓰이는건데, Fig.15의 설명을 보녀 smaller가 v1, wider가 v2에 쓰이는 것이다. 따라서 여기선 smaller, 즉 Fig.12를 쓴다. Fig.18은 v2의 Reduction B에 쓰인다.


- 논문에선 Inception-ResNet-V1과 Inception v3의 파라미터 수가 비슷하다는데, 내가 구현한 것으로는 Inception-ResNet-V1이 약 1천만개 정도 더 적다. 이는 Inception v3를 구현할 때 제대로 지정된 채널 수가 많지 않아서 그런걸로 보인다(ex. 해당 논문의 Fig.6 모듈 구현에서 f6 변수 등). 그걸 잘 조절하면 비슷해질 것 같다.

In [1]:
import torch
import torch.nn as nn


class ConvBlock(nn.Module):
    def __init__(self, in_ch, out_ch, kernel_size, stride, padding):
        super(ConvBlock, self).__init__()
        
        self.conv = nn.Conv2d(in_ch, out_ch, kernel_size, stride, padding, bias=False)
        self.bn = nn.BatchNorm2d(out_ch)
        self.relu = nn.ReLU()
        
    def forward(self, x):
        
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        
        return x
    

class Stem(nn.Module):
    def __init__(self):
        super(Stem, self).__init__()
        
        self.layer = nn.Sequential(
            ConvBlock(3, 32, kernel_size=3, stride=2, padding=0),
            ConvBlock(32, 32, kernel_size=3, stride=1, padding=0),
            ConvBlock(32, 64, kernel_size=3, stride=1, padding=1),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=0),
            ConvBlock(64, 80, kernel_size=1, stride=1, padding=0),
            ConvBlock(80, 192, kernel_size=3, stride=1, padding=0),
            ConvBlock(192, 256, kernel_size=3, stride=2, padding=0))
        
    def forward(self, x):
        
        x = self.layer(x)
        
        return x
    
    
class InceptionResA(nn.Module):
    def __init__(self, in_ch, scale):
        super(InceptionResA, self).__init__()
        
        self.scaling = scale
        
        self.branch1 = nn.Sequential(
            ConvBlock(in_ch, 32, kernel_size=1, stride=1, padding=0))
        
        self.branch2 = nn.Sequential(
            ConvBlock(in_ch, 32, kernel_size=1, stride=1, padding=0),
            ConvBlock(32, 32, kernel_size=3, stride=1, padding=1))
        
        self.branch3 = nn.Sequential(
            ConvBlock(in_ch, 32, kernel_size=1, stride=1, padding=0),
            ConvBlock(32, 32, kernel_size=3, stride=1, padding=1),
            ConvBlock(32, 32, kernel_size=3, stride=1, padding=1))
        
        self.conv1x1 = nn.Conv2d(96, 256, kernel_size=1, stride=1, padding=0)
        self.relu = nn.ReLU()
        
    def forward(self, x):
        
        pre_x = x
        
        x1 = self.branch1(x)
        x2 = self.branch2(x)
        x3 = self.branch3(x)
        
        x = torch.cat([x1, x2, x3], dim=1)
        
        x = self.conv1x1(x)
        x = self.relu(x)
        
        out = pre_x + x*self.scaling
        out = self.relu(out)
        
        return out

    
class ReductionA(nn.Module): # k=192, l=192, m=256, n=384
    def __init__(self, in_ch, k, l, m, n):
        super(ReductionA, self).__init__()
        
        self.branch1 = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=2, padding=0))
        
        self.branch2 = nn.Sequential(
            ConvBlock(in_ch, n, kernel_size=3, stride=2, padding=0))
        
        self.branch3 = nn.Sequential(
            ConvBlock(in_ch, k, kernel_size=1, stride=1, padding=0),
            ConvBlock(k, l, kernel_size=3, stride=1, padding=1),
            ConvBlock(l, m, kernel_size=3, stride=2, padding=0))
        
    def forward(self, x):
        
        x1 = self.branch1(x)
        x2 = self.branch2(x)
        x3 = self.branch3(x)
        
        return torch.cat([x1, x2, x3], dim=1)

    
class InceptionResB(nn.Module):
    def __init__(self, in_ch, scale):
        super(InceptionResB, self).__init__()
        
        self.scaling = scale
        
        self.branch1 = nn.Sequential(
            ConvBlock(in_ch, 128, kernel_size=1, stride=1, padding=0))
        
        self.branch2 = nn.Sequential(
            ConvBlock(in_ch, 128, kernel_size=1, stride=1, padding=0),
            ConvBlock(128, 128, kernel_size=(1,7), stride=1, padding=(0,3)),
            ConvBlock(128, 128, kernel_size=(7,1), stride=1, padding=(3,0)))
        
        self.conv1x1 = nn.Conv2d(256, 896, kernel_size=1, stride=1, padding=0)
        self.relu = nn.ReLU()
        
    def forward(self, x):
        
        pre_x = x
        
        x1 = self.branch1(x)
        x2 = self.branch2(x)
        
        x = torch.cat([x1, x2], dim=1)
        
        x = self.conv1x1(x)
        x = self.relu(x)
        
        out = pre_x + x*self.scaling
        out = self.relu(out)
        
        return out
    
class ReductionB(nn.Module):
    def __init__(self, in_ch):
        super(ReductionB, self).__init__()
        
        self.branch1 = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=2, padding=0))
        
        self.branch2 = nn.Sequential(
            ConvBlock(in_ch, 256, kernel_size=1, stride=1, padding=0),
            ConvBlock(256, 384, kernel_size=3, stride=2, padding=0))
        
        self.branch3 = nn.Sequential(
            ConvBlock(in_ch, 256, kernel_size=1, stride=1, padding=0),
            ConvBlock(256, 256, kernel_size=3, stride=2, padding=0))
        
        self.branch4 = nn.Sequential(
            ConvBlock(in_ch, 256, kernel_size=1, stride=1, padding=0),
            ConvBlock(256, 256, kernel_size=3, stride=1, padding=1),
            ConvBlock(256, 256, kernel_size=3, stride=2, padding=0))
        
    def forward(self, x):
        
        x1 = self.branch1(x)
        x2 = self.branch2(x)
        x3 = self.branch3(x)
        x4 = self.branch4(x)
        
        x = torch.cat([x1, x2, x3, x4], dim=1)
        
        return x
    
    
class InceptionResC(nn.Module):
    def __init__(self, in_ch, scale):
        super(InceptionResC, self).__init__()
        
        self.scaling = scale
        
        self.branch1 = nn.Sequential(
            ConvBlock(in_ch, 192, kernel_size=1, stride=1, padding=0))
        
        self.branch2 = nn.Sequential(
            ConvBlock(in_ch, 192, kernel_size=1, stride=1, padding=0),
            ConvBlock(192, 192, kernel_size=(1,3), stride=1, padding=(0,1)),
            ConvBlock(192, 192, kernel_size=(3,1), stride=1, padding=(1,0)))
        
        self.conv1x1 = nn.Conv2d(384, 1792, kernel_size=1, stride=1, padding=0)
        self.relu = nn.ReLU()
        
    def forward(self, x):
        
        pre_x = x
        
        x1 = self.branch1(x)
        x2 = self.branch2(x)
        
        x = torch.cat([x1, x2], dim=1)
        
        x = self.conv1x1(x)
        x = self.relu(x)
        
        out = pre_x + x*self.scaling
        out = self.relu(out)
        
        return out

In [4]:
class Inception_ResNet_V1(nn.Module):
    def __init__(self, num_classes = 1000):
        super(Inception_ResNet_V1, self).__init__()
        
        layers = []
        layers.append(Stem())
        
        for _ in range(5):
            layers.append(InceptionResA(256, 0.17))
            
        layers.append(ReductionA(256, 192, 192, 256, 384))
        
        for _ in range(10):
            layers.append(InceptionResB(896, 0.1))
            
        layers.append(ReductionB(896))
        
        for _ in range(5):
            layers.append(InceptionResC(1792, 0.2))
        
        self.feature = nn.Sequential(*layers)
        
        self.globalavgpool = nn.AdaptiveAvgPool2d((1,1))
        self.dropout = nn.Dropout(0.8)
        self.linear = nn.Linear(1792, num_classes)
        
    def forward(self, x):
        
        x = self.feature(x)            
        x = self.globalavgpool(x)
        x = self.dropout(x)
        x = x.view(x.size(0),-1)
        x = self.linear(x)
            
        return x

In [5]:
if __name__ == '__main__':

    from torchsummary import summary
    model = Inception_ResNet_V1()
    summary(model, (3,299,299))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 32, 149, 149]             864
       BatchNorm2d-2         [-1, 32, 149, 149]              64
              ReLU-3         [-1, 32, 149, 149]               0
         ConvBlock-4         [-1, 32, 149, 149]               0
            Conv2d-5         [-1, 32, 147, 147]           9,216
       BatchNorm2d-6         [-1, 32, 147, 147]              64
              ReLU-7         [-1, 32, 147, 147]               0
         ConvBlock-8         [-1, 32, 147, 147]               0
            Conv2d-9         [-1, 64, 147, 147]          18,432
      BatchNorm2d-10         [-1, 64, 147, 147]             128
             ReLU-11         [-1, 64, 147, 147]               0
        ConvBlock-12         [-1, 64, 147, 147]               0
        MaxPool2d-13           [-1, 64, 73, 73]               0
           Conv2d-14           [-1, 80,

            ReLU-125           [-1, 32, 35, 35]               0
       ConvBlock-126           [-1, 32, 35, 35]               0
          Conv2d-127           [-1, 32, 35, 35]           9,216
     BatchNorm2d-128           [-1, 32, 35, 35]              64
            ReLU-129           [-1, 32, 35, 35]               0
       ConvBlock-130           [-1, 32, 35, 35]               0
          Conv2d-131           [-1, 32, 35, 35]           9,216
     BatchNorm2d-132           [-1, 32, 35, 35]              64
            ReLU-133           [-1, 32, 35, 35]               0
       ConvBlock-134           [-1, 32, 35, 35]               0
          Conv2d-135          [-1, 256, 35, 35]          24,832
            ReLU-136          [-1, 256, 35, 35]               0
            ReLU-137          [-1, 256, 35, 35]               0
   InceptionResA-138          [-1, 256, 35, 35]               0
          Conv2d-139           [-1, 32, 35, 35]           8,192
     BatchNorm2d-140           [-1, 32, 

          Conv2d-253          [-1, 128, 17, 17]         114,688
     BatchNorm2d-254          [-1, 128, 17, 17]             256
            ReLU-255          [-1, 128, 17, 17]               0
       ConvBlock-256          [-1, 128, 17, 17]               0
          Conv2d-257          [-1, 128, 17, 17]         114,688
     BatchNorm2d-258          [-1, 128, 17, 17]             256
            ReLU-259          [-1, 128, 17, 17]               0
       ConvBlock-260          [-1, 128, 17, 17]               0
          Conv2d-261          [-1, 896, 17, 17]         230,272
            ReLU-262          [-1, 896, 17, 17]               0
            ReLU-263          [-1, 896, 17, 17]               0
   InceptionResB-264          [-1, 896, 17, 17]               0
          Conv2d-265          [-1, 128, 17, 17]         114,688
     BatchNorm2d-266          [-1, 128, 17, 17]             256
            ReLU-267          [-1, 128, 17, 17]               0
       ConvBlock-268          [-1, 128, 

          Conv2d-381          [-1, 896, 17, 17]         230,272
            ReLU-382          [-1, 896, 17, 17]               0
            ReLU-383          [-1, 896, 17, 17]               0
   InceptionResB-384          [-1, 896, 17, 17]               0
       MaxPool2d-385            [-1, 896, 8, 8]               0
          Conv2d-386          [-1, 256, 17, 17]         229,376
     BatchNorm2d-387          [-1, 256, 17, 17]             512
            ReLU-388          [-1, 256, 17, 17]               0
       ConvBlock-389          [-1, 256, 17, 17]               0
          Conv2d-390            [-1, 384, 8, 8]         884,736
     BatchNorm2d-391            [-1, 384, 8, 8]             768
            ReLU-392            [-1, 384, 8, 8]               0
       ConvBlock-393            [-1, 384, 8, 8]               0
          Conv2d-394          [-1, 256, 17, 17]         229,376
     BatchNorm2d-395          [-1, 256, 17, 17]             512
            ReLU-396          [-1, 256, 

            ReLU-509            [-1, 192, 8, 8]               0
       ConvBlock-510            [-1, 192, 8, 8]               0
          Conv2d-511           [-1, 1792, 8, 8]         689,920
            ReLU-512           [-1, 1792, 8, 8]               0
            ReLU-513           [-1, 1792, 8, 8]               0
   InceptionResC-514           [-1, 1792, 8, 8]               0
AdaptiveAvgPool2d-515           [-1, 1792, 1, 1]               0
         Dropout-516           [-1, 1792, 1, 1]               0
          Linear-517                 [-1, 1000]       1,793,000
Total params: 22,756,328
Trainable params: 22,756,328
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 1.02
Forward/backward pass size (MB): 409.69
Params size (MB): 86.81
Estimated Total Size (MB): 497.52
----------------------------------------------------------------


In [6]:
# Find total parameters and trainable parameters
model = Inception_ResNet_V1()
total_params = sum(p.numel() for p in model.parameters())
print(f'{total_params:,} total parameters.')
total_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f'{total_trainable_params:,} training parameters.')

22,756,328 total parameters.
22,756,328 training parameters.
