# 파이토치

In [2]:
import torch
import torch.nn as nn

class Bottleneck(nn.Module):
    """DenseNet bottleneck layer: BN-ReLU-Conv(1x1) followed by BN-ReLU-Conv(3x3).

    The 1x1 convolution maps the input to ``4 * growth_rate`` channels and the
    3x3 convolution reduces that to ``growth_rate`` new feature maps. Those
    new maps are concatenated onto the input along the channel dimension, so
    the output has ``in_ch + growth_rate`` channels.

    Args:
        in_ch: Number of channels of the incoming tensor.
        growth_rate: Number of feature maps this layer contributes.
    """

    def __init__(self, in_ch, growth_rate):
        super().__init__()
        inner_ch = 4 * growth_rate
        self.bn1 = nn.BatchNorm2d(in_ch)
        self.conv1 = nn.Conv2d(in_ch, inner_ch, kernel_size=1, bias=False)
        self.bn2 = nn.BatchNorm2d(inner_ch)
        self.conv2 = nn.Conv2d(
            inner_ch, growth_rate, kernel_size=3, stride=1, padding=1, bias=False
        )
        # In-place ReLU is only ever applied to fresh BatchNorm outputs, so the
        # caller's tensor is never overwritten.
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Return ``x`` with the newly computed feature maps appended on dim 1."""
        squeezed = self.conv1(self.relu(self.bn1(x)))
        new_features = self.conv2(self.relu(self.bn2(squeezed)))
        return torch.cat([x, new_features], dim=1)

class DenseBlock(nn.Module):
    """A stack of ``n_bottleneck`` Bottleneck layers applied one after another.

    Layer ``i`` is constructed with ``in_ch + i * growth_rate`` input channels
    and registered as the submodule ``DenseBlock_{i}``.

    Args:
        n_bottleneck: Number of Bottleneck layers in the block.
        in_ch: Number of channels entering the first layer.
        growth_rate: Growth rate passed to every Bottleneck layer.
    """

    def __init__(self, n_bottleneck, in_ch, growth_rate):
        super().__init__()

        # Stored so that forward() knows how many layers to look up.
        self.n_bottleneck = n_bottleneck

        for idx in range(n_bottleneck):
            # Each layer's input width grows by growth_rate per preceding layer.
            layer = Bottleneck(in_ch + idx * growth_rate, growth_rate)
            setattr(self, 'DenseBlock_{}'.format(idx), layer)

    def forward(self, x):
        """Feed ``x`` through the registered layers in index order."""
        out = x
        for idx in range(self.n_bottleneck):
            # Fetch the layer that __init__ registered under this name.
            layer = getattr(self, 'DenseBlock_{}'.format(idx))
            out = layer(out)
        return out

class Transition(nn.Module):
    """Transition layer between dense blocks: BN-ReLU-Conv(1x1)-AvgPool(2x2).

    The convolution compresses the channel count to ``int(in_ch * theta)`` and
    the average pooling halves the spatial resolution.

    Args:
        in_ch: Number of channels of the incoming tensor.
        theta: Compression factor applied to the channel count. Defaults to
            0.5, the value used in the DenseNet paper.
    """

    def __init__(self, in_ch, theta=0.5):
        super().__init__()
        num_ch = int(in_ch * theta)
        self.bn = nn.BatchNorm2d(in_ch)
        self.relu = nn.ReLU(True)
        # The DenseNet transition layer uses a 1x1 convolution: its only job is
        # to reduce the number of channels, not to mix spatial neighbours.
        self.conv = nn.Conv2d(in_ch, num_ch, kernel_size=1, bias=False)
        self.avgpool = nn.AvgPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        """Return ``x`` after normalisation, channel compression and pooling."""
        x = self.bn(x)
        x = self.relu(x)
        x = self.conv(x)
        x = self.avgpool(x)
        return x

class DenseNet(nn.Module):  # DenseNet-121
    """DenseNet-121 style classifier for 3-channel images.

    Layout: 7x7 stride-2 convolution, BN, ReLU and 3x3 max pooling, then four
    dense blocks with 6, 12, 24 and 16 bottleneck layers (growth rate 32)
    separated by transition layers, adaptive average pooling to 7x7 and a
    final linear layer.

    Args:
        num_classes: Size of the output of the final linear layer.
            Defaults to 10.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        self.conv = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        # The growth rate is fixed at 32 for every block.
        self.dense1 = DenseBlock(6, 64, 32)
        # Channels leaving a block: in_ch + n_bottleneck * growth_rate.
        self.trans1 = Transition(256)  # 64 + 6*32 = 256
        # Each transition halves the channel count (theta = 0.5).
        self.dense2 = DenseBlock(12, 128, 32)  # 256 / 2 = 128
        self.trans2 = Transition(512)  # 128 + 12*32 = 512
        self.dense3 = DenseBlock(24, 256, 32)  # 512 / 2 = 256
        self.trans3 = Transition(1024)  # 256 + 24*32 = 1024
        self.dense4 = DenseBlock(16, 512, 32)  # 1024 / 2 = 512
        self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
        # dense4 emits 512 + 16*32 = 1024 channels, pooled to 7x7 -> 50176.
        self.linear = nn.Linear(1024 * 7 * 7, num_classes)

    def forward(self, x):
        """Compute class scores for a batch of images.

        Args:
            x: Input tensor with 3 channels on dim 1 (batch on dim 0).

        Returns:
            Tensor of shape ``(batch, num_classes)``.
        """
        in_size = x.size(0)
        x = self.conv(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.dense1(x)
        x = self.trans1(x)
        x = self.dense2(x)
        x = self.trans2(x)
        x = self.dense3(x)
        x = self.trans3(x)
        x = self.dense4(x)
        x = self.avgpool(x)
        # Flatten everything except the batch dimension for the linear layer.
        x = x.view(in_size, -1)
        x = self.linear(x)
        # Without this return the model would always yield None.
        return x

from torchsummary import summary

# torchsummary builds its dummy input on CUDA by default whenever CUDA is
# available, so the model must be placed on the same device as that input.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = DenseNet().to(device)
summary(model, (3, 32, 32), device=device)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 64, 16, 16]           9,408
       BatchNorm2d-2           [-1, 64, 16, 16]             128
              ReLU-3           [-1, 64, 16, 16]               0
         MaxPool2d-4             [-1, 64, 8, 8]               0
       BatchNorm2d-5             [-1, 64, 8, 8]             128
              ReLU-6             [-1, 64, 8, 8]               0
            Conv2d-7            [-1, 128, 8, 8]           8,192
       BatchNorm2d-8            [-1, 128, 8, 8]             256
              ReLU-9            [-1, 128, 8, 8]               0
           Conv2d-10             [-1, 32, 8, 8]          36,864
       Bottleneck-11             [-1, 96, 8, 8]               0
      BatchNorm2d-12             [-1, 96, 8, 8]             192
             ReLU-13             [-1, 96, 8, 8]               0
           Conv2d-14            [-1, 12

          Conv2d-253             [-1, 32, 2, 2]          36,864
      Bottleneck-254            [-1, 768, 2, 2]               0
     BatchNorm2d-255            [-1, 768, 2, 2]           1,536
            ReLU-256            [-1, 768, 2, 2]               0
          Conv2d-257            [-1, 128, 2, 2]          98,304
     BatchNorm2d-258            [-1, 128, 2, 2]             256
            ReLU-259            [-1, 128, 2, 2]               0
          Conv2d-260             [-1, 32, 2, 2]          36,864
      Bottleneck-261            [-1, 800, 2, 2]               0
     BatchNorm2d-262            [-1, 800, 2, 2]           1,600
            ReLU-263            [-1, 800, 2, 2]               0
          Conv2d-264            [-1, 128, 2, 2]         102,400
     BatchNorm2d-265            [-1, 128, 2, 2]             256
            ReLU-266            [-1, 128, 2, 2]               0
          Conv2d-267             [-1, 32, 2, 2]          36,864
      Bottleneck-268            [-1, 832