# 파이토치

- 예를 들어 Fig.4는 3x3 Conv 필터가 7개이므로 코드 구현을 한다면 input channel 수랑 output channel 수가 7로 같은 것.


- 이때 depthwise 연산을 하려면 채널 수가 7인 feature map을 7개로 쪼개서 각각을 3x3 Conv 필터에 통과시켜야하는데, 이것을 nn.Conv2d의 groups 변수가 해준다.


- groups 인자: 입력 채널을 groups개의 그룹으로 나누고, 그룹마다 별도의 필터로 컨볼루션을 수행한다. 기본 값은 groups = 1(모든 입력 채널을 함께 보는 일반 컨볼루션)이며, groups = in_ch로 주면 채널 하나가 하나의 그룹이 되어 depthwise 연산이 된다.

In [1]:
import torch
import torch.nn as nn


class ConvBlock(nn.Module):
    """Convolution followed by batch normalization and ReLU.

    Args:
        in_ch: Number of input channels of the convolution.
        out_ch: Number of output channels (also the BatchNorm feature count).
        kernel_size: Convolution kernel size.
        stride: Convolution stride.
        padding: Zero padding applied by the convolution.
    """

    def __init__(self, in_ch, out_ch, kernel_size, stride, padding):
        super().__init__()

        # The convolution carries no bias term (bias=False); it is followed
        # directly by BatchNorm2d.
        self.conv = nn.Conv2d(
            in_channels=in_ch,
            out_channels=out_ch,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            bias=False,
        )
        self.bn = nn.BatchNorm2d(num_features=out_ch)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return ``relu(bn(conv(x)))``."""
        normalized = self.bn(self.conv(x))
        return self.relu(normalized)
    

class SeparableConv(nn.Module):
    """Depthwise separable convolution followed by batch normalization.

    A 3x3 depthwise convolution (one filter per input channel) handles the
    spatial mixing, then a 1x1 pointwise convolution handles the channel
    mixing. No activation is applied inside this module.

    Args:
        in_ch: Number of input channels.
        out_ch: Number of output channels produced by the pointwise step.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()

        # Spatial step: 3x3 filters with in_channels == out_channels so the
        # channel count is left alone. groups=in_ch is the important part:
        # it makes every channel its own group.
        self.depthwise = nn.Conv2d(
            in_channels=in_ch,
            out_channels=in_ch,
            kernel_size=3,
            stride=1,
            padding=1,
            groups=in_ch,
            bias=False,
        )

        # Channel step: 1x1 filters, so the feature-map size is left alone.
        self.pointwise = nn.Conv2d(
            in_channels=in_ch,
            out_channels=out_ch,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=False,
        )

        self.bn = nn.BatchNorm2d(num_features=out_ch)

    def forward(self, x):
        """Apply depthwise conv, pointwise conv and batch norm, in that order."""
        spatial = self.depthwise(x)
        mixed = self.pointwise(spatial)
        return self.bn(mixed)

In [2]:
# Entry Flow    

class Entry(nn.Module):
    """Entry flow of Xception: a two-layer stem plus three residual blocks.

    The stem is two 3x3 ``ConvBlock`` layers (3 -> 32 with stride 2, then
    32 -> 64). Each residual block runs two ``SeparableConv`` layers and a
    stride-2 max pool on the main path, and a stride-2 1x1 ``ConvBlock`` on
    the shortcut path; the two paths are summed. Channel widths go
    64 -> 128 -> 256 -> 728, and every block halves the spatial size, so a
    3x299x299 input yields a 728x19x19 output.
    """

    def __init__(self):
        super().__init__()

        self.relu = nn.ReLU()

        # Stem.
        self.conv_32 = ConvBlock(3, 32, kernel_size=3, stride=2, padding=1)
        self.conv_64 = ConvBlock(32, 64, kernel_size=3, stride=1, padding=1)

        # Shortcut projections: 1x1, stride 2, so that channels and spatial
        # size match the pooled main path of the corresponding block.
        # NOTE(review): ConvBlock ends in a ReLU, so the shortcuts are
        # rectified as well -- confirm this is intended.
        self.conv1x1_64 = ConvBlock(64, 128, kernel_size=1, stride=2, padding=0)
        self.conv1x1_128 = ConvBlock(128, 256, kernel_size=1, stride=2, padding=0)
        self.conv1x1_256 = ConvBlock(256, 728, kernel_size=1, stride=2, padding=0)

        # Block 1 main path.
        self.sepconv1 = SeparableConv(64, 128)
        self.sepconv2 = SeparableConv(128, 128)

        # Shared by all three blocks (max pooling has no parameters).
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Block 2 main path.
        self.sepconv3 = SeparableConv(128, 256)
        self.sepconv4 = SeparableConv(256, 256)

        # Block 3 main path.
        self.sepconv5 = SeparableConv(256, 728)
        self.sepconv6 = SeparableConv(728, 728)

        # NOTE(review): this layer is never used in forward(). It is kept only
        # so that the module's parameter set / state_dict keys stay unchanged;
        # its parameters never receive gradients.
        self.conv = ConvBlock(128, 100, kernel_size=1, stride=2, padding=0)

    def forward(self, x):
        """Run the stem and the three residual blocks.

        Each block's shortcut is taken from the block's actual input, i.e.
        the previous block's output *after* its residual sum. Previously the
        shortcuts of blocks 2 and 3 were computed before that sum, so they
        branched from the wrong tensor.
        """
        x = self.conv_32(x)
        x = self.conv_64(x)

        # Block 1 (no leading ReLU: the stem already ends with one).
        shortcut = self.conv1x1_64(x)
        x = self.sepconv1(x)
        x = self.relu(x)
        x = self.sepconv2(x)
        x = self.maxpool(x)
        x = x + shortcut

        # Block 2.
        shortcut = self.conv1x1_128(x)
        x = self.relu(x)
        x = self.sepconv3(x)
        x = self.relu(x)
        x = self.sepconv4(x)
        x = self.maxpool(x)
        x = x + shortcut

        # Block 3.
        shortcut = self.conv1x1_256(x)
        x = self.relu(x)
        x = self.sepconv5(x)
        x = self.relu(x)
        x = self.sepconv6(x)
        x = self.maxpool(x)

        return x + shortcut

In [3]:
if __name__ == '__main__':

    # Demo cell: print a per-layer summary (output shape and parameter count)
    # of the entry flow. torchsummary is imported locally, inside the guard.
    from torchsummary import summary
    model = Entry()
    # Input size is given without the batch dimension: 3 channels, 299x299.
    summary(model, (3,299,299))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 32, 150, 150]             864
       BatchNorm2d-2         [-1, 32, 150, 150]              64
              ReLU-3         [-1, 32, 150, 150]               0
         ConvBlock-4         [-1, 32, 150, 150]               0
            Conv2d-5         [-1, 64, 150, 150]          18,432
       BatchNorm2d-6         [-1, 64, 150, 150]             128
              ReLU-7         [-1, 64, 150, 150]               0
         ConvBlock-8         [-1, 64, 150, 150]               0
            Conv2d-9          [-1, 128, 75, 75]           8,192
      BatchNorm2d-10          [-1, 128, 75, 75]             256
             ReLU-11          [-1, 128, 75, 75]               0
        ConvBlock-12          [-1, 128, 75, 75]               0
           Conv2d-13         [-1, 64, 150, 150]             576
           Conv2d-14        [-1, 128, 1

In [4]:
# Middle Flow

class Middle(nn.Module):
    """One residual block of the Xception middle flow.

    The main path is three (ReLU -> SeparableConv(728, 728)) stages; the
    block input is added back unchanged at the end, so the output has the
    same shape as the input (728 channels).

    Each stage has its own ``SeparableConv``. Previously a single instance
    was applied three times, which tied the convolution weights and the
    BatchNorm statistics of all three stages together.
    """

    def __init__(self):
        super().__init__()

        self.relu = nn.ReLU()

        # Three independent separable convolutions, one per stage.
        self.sepconv1 = SeparableConv(728, 728)
        self.sepconv2 = SeparableConv(728, 728)
        self.sepconv3 = SeparableConv(728, 728)

    def forward(self, x):
        """Return ``x`` plus the output of the three-stage main path."""
        # self.relu is not in-place, so the identity shortcut stays intact.
        shortcut = x

        for sepconv in (self.sepconv1, self.sepconv2, self.sepconv3):
            x = self.relu(x)
            x = sepconv(x)

        return x + shortcut

In [5]:
if __name__ == '__main__':

    # Demo cell: print a per-layer summary (output shape and parameter count)
    # of a single middle-flow block.
    from torchsummary import summary
    model = Middle()
    # Input size without the batch dimension: 728 channels, 19x19 feature map.
    summary(model, (728,19,19))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
              ReLU-1          [-1, 728, 19, 19]               0
            Conv2d-2          [-1, 728, 19, 19]           6,552
            Conv2d-3          [-1, 728, 19, 19]         529,984
       BatchNorm2d-4          [-1, 728, 19, 19]           1,456
     SeparableConv-5          [-1, 728, 19, 19]               0
              ReLU-6          [-1, 728, 19, 19]               0
            Conv2d-7          [-1, 728, 19, 19]           6,552
            Conv2d-8          [-1, 728, 19, 19]         529,984
       BatchNorm2d-9          [-1, 728, 19, 19]           1,456
    SeparableConv-10          [-1, 728, 19, 19]               0
             ReLU-11          [-1, 728, 19, 19]               0
           Conv2d-12          [-1, 728, 19, 19]           6,552
           Conv2d-13          [-1, 728, 19, 19]         529,984
      BatchNorm2d-14          [-1, 728,

In [6]:
# Exit Flow

class Exit(nn.Module):
    """Exit flow of Xception: one residual block, two more separable
    convolutions, then the classification head.

    Args:
        num_classes: Size of the final linear layer's output. Defaults to
            1000, the value that was previously hard-coded.
    """

    def __init__(self, num_classes=1000):
        super().__init__()

        self.relu = nn.ReLU()

        # Shortcut projection: 1x1, stride 2, 728 -> 1024, matching the
        # channels and spatial size of the pooled main path.
        self.conv1x1_728 = ConvBlock(728, 1024, kernel_size=1, stride=2, padding=0)

        # Residual block main path.
        self.sepconv1 = SeparableConv(728, 728)
        self.sepconv2 = SeparableConv(728, 1024)

        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Post-residual separable convolutions.
        self.sepconv3 = SeparableConv(1024, 1536)
        self.sepconv4 = SeparableConv(1536, 2048)

        # Classification head.
        self.dropout = nn.Dropout(0.5)
        self.avgpool = nn.AdaptiveAvgPool2d((1,1))
        self.linear = nn.Linear(2048, num_classes)

    def forward(self, x):
        """Map a 728-channel feature map to ``num_classes`` logits."""
        shortcut = self.conv1x1_728(x)

        x = self.relu(x)
        x = self.sepconv1(x)

        x = self.relu(x)
        x = self.sepconv2(x)

        x = self.maxpool(x)

        x = self.sepconv3(x + shortcut)
        x = self.relu(x)

        x = self.sepconv4(x)
        x = self.relu(x)

        # Global average pool to (batch, 2048, 1, 1), then flatten to
        # (batch, 2048) for the linear layer.
        x = self.avgpool(x)
        x = x.flatten(1)

        # Dropout is applied to the pooled feature vector, directly before
        # the classifier. Applying it to the feature map before pooling (as
        # was done previously) lets the spatial average cancel most of the
        # dropped activations, so it regularizes very little.
        x = self.dropout(x)
        x = self.linear(x)

        return x

In [7]:
if __name__ == '__main__':

    # Demo cell: print a per-layer summary (output shape and parameter count)
    # of the exit flow.
    from torchsummary import summary
    model = Exit()
    # Input size without the batch dimension: 728 channels, 19x19 feature map.
    summary(model, (728,19,19))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 1024, 10, 10]         745,472
       BatchNorm2d-2         [-1, 1024, 10, 10]           2,048
              ReLU-3         [-1, 1024, 10, 10]               0
         ConvBlock-4         [-1, 1024, 10, 10]               0
              ReLU-5          [-1, 728, 19, 19]               0
            Conv2d-6          [-1, 728, 19, 19]           6,552
            Conv2d-7          [-1, 728, 19, 19]         529,984
       BatchNorm2d-8          [-1, 728, 19, 19]           1,456
     SeparableConv-9          [-1, 728, 19, 19]               0
             ReLU-10          [-1, 728, 19, 19]               0
           Conv2d-11          [-1, 728, 19, 19]           6,552
           Conv2d-12         [-1, 1024, 19, 19]         745,472
      BatchNorm2d-13         [-1, 1024, 19, 19]           2,048
    SeparableConv-14         [-1, 1024,

### 최종 Xception

In [8]:
class Xception(nn.Module):
    """Full Xception network: entry flow, repeated middle flow, exit flow.

    Args:
        num_middle_blocks: Number of ``Middle`` blocks applied between the
            entry and exit flows. Defaults to 8, the value that was
            previously hard-coded.
    """

    def __init__(self, num_middle_blocks=8):
        super().__init__()

        self.entry = Entry()
        # Each middle block is its own module instance with its own
        # parameters. Previously a single Middle instance was called eight
        # times, which made all eight blocks share one set of weights.
        self.middle = nn.Sequential(
            *(Middle() for _ in range(num_middle_blocks))
        )
        self.exit = Exit()

    def forward(self, x):
        """Run the entry flow, all middle blocks in order, then the exit flow."""
        x = self.entry(x)
        x = self.middle(x)
        x = self.exit(x)

        return x

In [9]:
if __name__ == '__main__':

    # Demo cell: print a per-layer summary (output shape and parameter count)
    # of the complete network.
    from torchsummary import summary
    model = Xception()
    # Input size without the batch dimension: 3 channels, 299x299.
    summary(model, (3,299,299))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 32, 150, 150]             864
       BatchNorm2d-2         [-1, 32, 150, 150]              64
              ReLU-3         [-1, 32, 150, 150]               0
         ConvBlock-4         [-1, 32, 150, 150]               0
            Conv2d-5         [-1, 64, 150, 150]          18,432
       BatchNorm2d-6         [-1, 64, 150, 150]             128
              ReLU-7         [-1, 64, 150, 150]               0
         ConvBlock-8         [-1, 64, 150, 150]               0
            Conv2d-9          [-1, 128, 75, 75]           8,192
      BatchNorm2d-10          [-1, 128, 75, 75]             256
             ReLU-11          [-1, 128, 75, 75]               0
        ConvBlock-12          [-1, 128, 75, 75]               0
           Conv2d-13         [-1, 64, 150, 150]             576
           Conv2d-14        [-1, 128, 1