## ResNet

##### Model Architecture

- A model designed to train very deep neural networks effectively

- **Residual Block**

    - Adds a skip connection so that gradients propagate well through the network

    - The shortcut path helps prevent the vanishing gradient problem

- Skip Connection

  - Whereas a conventional layer takes `x` and outputs `F(x)`, a skip connection makes the block output `x + F(x)`

<img src="https://blog.kakaocdn.net/dn/cjhsNq/btq0Skr7p89/uQGihddf4XJ8z5chvEAZr0/img.png" width="400px">
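The `x + F(x)` idea can be sketched as a generic PyTorch wrapper. This is a minimal illustration, not the code used later in these notes: `Residual` is a hypothetical name, and `F` is assumed to be any module that preserves the input shape. `F` is deliberately zero-initialized here to show the difference between the two outputs.

```python
import torch
import torch.nn as nn


class Residual(nn.Module):
    """Hypothetical wrapper: turns a shape-preserving module F into x + F(x)."""

    def __init__(self, fn: nn.Module):
        super().__init__()
        self.fn = fn

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Skip connection: add the unchanged input back onto F(x)
        return x + self.fn(x)


# Illustrative F: one linear layer whose weights and bias start at zero
f = nn.Linear(8, 8)
nn.init.zeros_(f.weight)
nn.init.zeros_(f.bias)

plain = f               # conventional layer: outputs F(x)
residual = Residual(f)  # skip connection: outputs x + F(x)

x = torch.randn(4, 8)
print(torch.allclose(plain(x), torch.zeros_like(x)))  # F(x) is all zeros
print(torch.allclose(residual(x), x))                 # x + F(x) passes x through
```

With the same zero weights, the plain layer erases its input while the residual version reproduces it exactly.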

- Bottleneck Block

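  - Stacking more layers quickly drives up the number of parameters; the bottleneck block is used to address this

  - Unlike ResNet34, ResNet50 wraps each 3x3 convolution with **1x1 convolutions**, one before and one after it

    - The first 1x1 convolution reduces the number of feature-map channels before the 3x3 convolution and the second restores it afterward, which reduces the parameter count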

<img src="https://blog.kakaocdn.net/dn/GcSbV/btqTkgJVRN1/vkap6By64KbxsCLdpTpyhk/img.png" width="400px">
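The parameter savings can be checked by counting convolution weights. This sketch compares two plain 3x3 convolutions at 256 channels with a 256 → 64 → 64 → 256 bottleneck; the channel sizes follow the first stage of ResNet50, and BatchNorm/ReLU are omitted (an assumption to keep the arithmetic to convolution weights only). `n_params` is a hypothetical helper.

```python
import torch.nn as nn


def n_params(module: nn.Module) -> int:
    """Total number of parameters in a module."""
    return sum(p.numel() for p in module.parameters())


# Two stacked 3x3 convolutions on a 256-channel feature map (basic-block style)
two_3x3 = nn.Sequential(
    nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False),
    nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False),
)

# Bottleneck: 1x1 reduces 256 -> 64, 3x3 works at 64 channels, 1x1 expands 64 -> 256
bottleneck = nn.Sequential(
    nn.Conv2d(256, 64, kernel_size=1, bias=False),
    nn.Conv2d(64, 64, kernel_size=3, padding=1, bias=False),
    nn.Conv2d(64, 256, kernel_size=1, bias=False),
)

print(n_params(two_3x3))     # 2 * (3*3*256*256) = 1,179,648
print(n_params(bottleneck))  # 256*64 + 3*3*64*64 + 64*256 = 69,632
```

The expensive 3x3 kernel only ever sees 64 channels, so the bottleneck needs roughly 17x fewer weights for the same input and output width.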

---

#### Understanding Skip Connections

Suppose the ideal output we want to produce from an input `x` is `H(x)`. There are two ways to build it:

(1) Produce `H(x)` as `F(x)`

(2) Produce `H(x)` as `x + F(x)`

Which one is easier to learn?

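Now assume the ideal `H(x)` is close to `x`, i.e. `H(x) ≈ x`.

- For an MLP without a skip connection:

  - we need `(x * weight)` to equal `x`,

    - so the weight matrix must become the **identity matrix**.

- For an MLP with a skip connection:

  - we need `x + (x * weight)` to equal `x`,

    - so the weight matrix must become the **zero matrix**.

- *Weights are initialized around a mean of 0 to begin with.*

- So which case makes `H(x) ≈ x` easier to reach?
    - (2), the skip connection: weights that start near zero are already close to the zero matrix, whereas reaching the identity matrix requires learning specific non-zero values.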
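The argument above can be checked numerically. This sketch assumes a single linear map `x @ W` with weights drawn near zero (the standard deviation of 0.01 and the sizes are illustrative choices) and measures how far each version is from `H(x) = x`.

```python
import torch

torch.manual_seed(0)

d = 64
x = torch.randn(16, d)
# Weights initialized near zero (std 0.01 is an illustrative assumption)
W = torch.randn(d, d) * 0.01

plain_err = (x @ W - x).norm().item()            # without skip: W would need to be the identity
residual_err = ((x + x @ W) - x).norm().item()   # with skip: W would need to be the zero matrix

print(f"without skip: {plain_err:.3f}")
print(f"with skip:    {residual_err:.3f}")
```

At initialization, the residual version is already much closer to `H(x) ≈ x`, since the remaining error is only the small `x @ W` term.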



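If the network is very deep, it is ideal to transform the input gradually, changing it only a little at each layer.

In other words, the change from `x` to `H(x)` between neighboring layers should be small:

-> `H(x) ≈ x`

Because a skip connection makes it easy to produce an `H(x)` close to `x`, adding one is

*like hinting to the model: "The values won't change much, so don't try to do everything in one layer; just change them a little at a time."*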
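The vanishing-gradient point from the Residual Block notes can also be checked with a toy experiment. This sketch stacks 20 small `tanh(Linear(h))` layers (depth, width, activation, and the 0.01 weight scale are all illustrative assumptions; `make_layers` and `input_grad_norm` are hypothetical helpers) and compares the gradient reaching the input with and without skip connections.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)


def make_layers(depth: int, width: int) -> nn.ModuleList:
    layers = nn.ModuleList([nn.Linear(width, width) for _ in range(depth)])
    for layer in layers:
        # Small weights near zero, as assumed in the text above
        nn.init.normal_(layer.weight, std=0.01)
        nn.init.zeros_(layer.bias)
    return layers


def input_grad_norm(layers: nn.ModuleList, use_skip: bool) -> float:
    x = torch.randn(8, layers[0].in_features, requires_grad=True)
    h = x
    for layer in layers:
        out = torch.tanh(layer(h))            # F(h)
        h = h + out if use_skip else out      # h + F(h) vs. F(h)
    h.sum().backward()
    return x.grad.norm().item()


layers = make_layers(depth=20, width=64)
plain_grad = input_grad_norm(layers, use_skip=False)
residual_grad = input_grad_norm(layers, use_skip=True)
print("plain   :", plain_grad)
print("residual:", residual_grad)
```

Without skips, the gradient is multiplied by a small weight matrix at every layer and all but disappears; with skips, the identity path carries it straight back to the input.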

###### Ref: https://youtu.be/Fypk0ec32BU

---

#### Basic Block

  - 3x3 convolution

  - 3x3 convolution

```python
import torch
import torch.nn as nn


class BasicBlock(nn.Module):
    expansion = 1  # output channels = out_channels * expansion

    def __init__(self, in_channels, out_channels, stride=1, downsample=False):
        super().__init__()

        # 3x3 Conv (carries the stride when the block downsamples)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)

        # 3x3 Conv
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)

        self.relu = nn.ReLU(inplace=True)

        # Projection shortcut: a 1x1 conv matches the main path's channels and spatial size
        if downsample:
            conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False)
            bn = nn.BatchNorm2d(out_channels)
            downsample = nn.Sequential(conv, bn)
        else:
            downsample = None

        self.downsample = downsample

    def forward(self, x):
        i = x
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.bn2(x)

        if self.downsample is not None:
            i = self.downsample(i)  # apply downsampling to the shortcut

        x += i  # apply the skip connection
        x = self.relu(x)

        return x
```
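The `downsample` branch exists because `x + F(x)` only works when both tensors have the same shape. This standalone sketch (layer sizes are an illustrative stage transition, 64 → 128 channels with stride 2) shows the mismatch and how a strided 1x1 convolution on the shortcut resolves it.

```python
import torch
import torch.nn as nn

x = torch.randn(1, 64, 56, 56)

# Main path of an illustrative stage-transition block: 3x3 conv, stride 2, 64 -> 128 channels
main = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1, bias=False)
fx = main(x)
print(fx.shape)  # torch.Size([1, 128, 28, 28])

# x has shape (1, 64, 56, 56), so it cannot be added to F(x) directly.
# A 1x1 conv with the same stride projects x to the matching shape.
proj = nn.Sequential(
    nn.Conv2d(64, 128, kernel_size=1, stride=2, bias=False),
    nn.BatchNorm2d(128),
)
shortcut = proj(x)
print(shortcut.shape)  # torch.Size([1, 128, 28, 28])

out = torch.relu(fx + shortcut)  # shapes now agree, so the skip connection is valid
```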

#### Bottleneck Block: stack layers deeper while reducing computation

  - **1x1 convolution**

  - 3x3 convolution

  - **1x1 convolution**

```python
class Bottleneck(nn.Module):
    expansion = 4  # output channels = out_channels * 4

    def __init__(self, in_channels, out_channels, stride=1, downsample=False):
        super().__init__()

        # 1x1 Conv: reduce the channel count
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)

        # 3x3 Conv (carries the stride when the block downsamples)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)

        # 1x1 Conv: expand the channel count by `expansion`
        self.conv3 = nn.Conv2d(out_channels, self.expansion*out_channels, kernel_size=1, stride=1, bias=False)
        self.bn3 = nn.BatchNorm2d(self.expansion*out_channels)

        self.relu = nn.ReLU(inplace=True)

        # Projection shortcut to the expanded channel count (and stride, if any)
        if downsample:
            conv = nn.Conv2d(in_channels, self.expansion*out_channels, kernel_size=1, stride=stride, bias=False)
            bn = nn.BatchNorm2d(self.expansion*out_channels)
            downsample = nn.Sequential(conv, bn)
        else:
            downsample = None

        self.downsample = downsample

    def forward(self, x):
        i = x
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu(x)
        x = self.conv3(x)
        x = self.bn3(x)

        if self.downsample is not None:
            i = self.downsample(i)  # apply downsampling to the shortcut

        x += i  # apply the skip connection
        x = self.relu(x)

        return x
```

#### ResNet Network

```python
class ResNet(nn.Module):
    def __init__(self, config, output_dim, zero_init_residual=False):
        super().__init__()

        block, n_blocks, channels = config  # block type, blocks per stage, channels per stage

        self.in_channels = channels[0]

        assert len(n_blocks) == len(channels) == 4

        # Stem: 7x7 conv with stride 2, then 3x3 max pooling with stride 2
        self.conv1 = nn.Conv2d(3, self.in_channels, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(self.in_channels)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Four stages; stages 2-4 halve the spatial size with stride 2
        self.layer1 = self.get_resnet_layer(block, n_blocks[0], channels[0])
        self.layer2 = self.get_resnet_layer(block, n_blocks[1], channels[1], stride=2)
        self.layer3 = self.get_resnet_layer(block, n_blocks[2], channels[2], stride=2)
        self.layer4 = self.get_resnet_layer(block, n_blocks[3], channels[3], stride=2)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        # After the stages, self.in_channels holds the final channel count (512 or 2048)
        self.fc = nn.Linear(self.in_channels, output_dim)

        # Optionally zero the last BN weight of each block so every residual branch starts at 0
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def get_resnet_layer(self, block, n_blocks, channels, stride=1):
        layers = []
        # A projection shortcut is needed when the block changes the channel count
        if self.in_channels != block.expansion * channels:
            downsample = True
        else:
            downsample = False

        # Only the first block of a stage changes the stride and channel count
        layers.append(block(self.in_channels, channels, stride, downsample))
        for i in range(1, n_blocks):
            layers.append(block(block.expansion*channels, channels))

        self.in_channels = block.expansion * channels
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        h = x.view(x.shape[0], -1)  # flattened features, returned alongside the logits
        x = self.fc(h)
        return x, h
```
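The `zero_init_residual` option sets the weight (gamma) of the last BatchNorm in each block to 0, so the residual branch outputs zeros at initialization and each block starts as the identity, which is exactly the "weight ≈ zero matrix" case from the Skip Connection section. This standalone sketch uses a hypothetical one-conv branch instead of the full `BasicBlock`, and assumes a non-negative input (as produced by a previous block's ReLU).

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical minimal residual branch ending in BatchNorm, like bn2 in BasicBlock
branch = nn.Sequential(
    nn.Conv2d(16, 16, kernel_size=3, padding=1, bias=False),
    nn.BatchNorm2d(16),
)
nn.init.constant_(branch[1].weight, 0)  # what zero_init_residual does to the last BN

x = torch.relu(torch.randn(2, 16, 8, 8))  # non-negative, like a previous block's output
out = torch.relu(branch(x) + x)            # F(x) + x, followed by ReLU

print(torch.allclose(out, x))  # True: the block starts out as the identity
```

Because the BN bias also starts at 0, the branch output is exactly zero, and the final ReLU leaves the non-negative input unchanged.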

```python
from collections import namedtuple

ResNetConfig = namedtuple('ResNetConfig', ['block', 'n_blocks', 'channels'])
```

- ResNets built from the basic block: 18, 34

```python
resnet18_config = ResNetConfig(
    block=BasicBlock,
    n_blocks=[2, 2, 2, 2],
    channels=[64, 128, 256, 512]
)

resnet34_config = ResNetConfig(
    block=BasicBlock,
    n_blocks=[3, 4, 6, 3],
    channels=[64, 128, 256, 512]
)
```

- ResNets built from the bottleneck block: 50, 101, 152

```python
resnet50_config = ResNetConfig(
    block=Bottleneck,
    n_blocks=[3, 4, 6, 3],
    channels=[64, 128, 256, 512]
)

resnet101_config = ResNetConfig(
    block=Bottleneck,
    n_blocks=[3, 4, 23, 3],
    channels=[64, 128, 256, 512]
)

resnet152_config = ResNetConfig(
    block=Bottleneck,
    n_blocks=[3, 8, 36, 3],
    channels=[64, 128, 256, 512]
)
```
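The model names follow from these configs by counting weighted layers: the 7x7 stem conv, every convolution inside the blocks (2 per basic block, 3 per bottleneck), and the final `fc` layer; by the usual convention the 1x1 projection shortcuts are not counted. The same configs also fix the `fc` input size. `resnet_depth` and `feature_dim` are hypothetical helpers for this arithmetic.

```python
def resnet_depth(n_blocks, convs_per_block):
    # stem conv1 + all convolutions inside the blocks + the final fc layer
    return 1 + sum(n_blocks) * convs_per_block + 1


def feature_dim(channels, expansion):
    # channels leaving layer4, i.e. the in_features of self.fc
    return channels[-1] * expansion


channels = [64, 128, 256, 512]
configs = {
    "resnet18": ([2, 2, 2, 2], 2, 1),   # (n_blocks, convs per block, expansion)
    "resnet34": ([3, 4, 6, 3], 2, 1),
    "resnet50": ([3, 4, 6, 3], 3, 4),
    "resnet101": ([3, 4, 23, 3], 3, 4),
    "resnet152": ([3, 8, 36, 3], 3, 4),
}

results = {}
for name, (n_blocks, convs, expansion) in configs.items():
    results[name] = (resnet_depth(n_blocks, convs), feature_dim(channels, expansion))
    print(name, *results[name])
```

ResNet34 and ResNet50 share `n_blocks=[3, 4, 6, 3]`; the extra depth comes entirely from the third convolution in each bottleneck.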

```python
OUTPUT_DIM = 2
model = ResNet(resnet50_config, OUTPUT_DIM)
print(model)
```

```text
ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_run
```

#### Pre-trained Model

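```python
import torchvision

# `pretrained=True` is deprecated since torchvision 0.13 in favor of the `weights` argument.
# IMAGENET1K_V1 is the checkpoint that `pretrained=True` loads (resnet50-0676ba61.pth).
pretrained_model = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.IMAGENET1K_V1)
print(pretrained_model)
```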

```text
Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth
100%|██████████| 97.8M/97.8M [00:00<00:00, 224MB/s]

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 
```

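```python
import torch
from torchsummary import summary  # third-party package: pip install torchsummary

# Fall back to the CPU when no GPU is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
summary(pretrained_model.to(device), input_size=(3, 224, 224), device=device.type)
```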

```text
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 112, 112]           9,408
       BatchNorm2d-2         [-1, 64, 112, 112]             128
              ReLU-3         [-1, 64, 112, 112]               0
         MaxPool2d-4           [-1, 64, 56, 56]               0
            Conv2d-5           [-1, 64, 56, 56]           4,096
       BatchNorm2d-6           [-1, 64, 56, 56]             128
              ReLU-7           [-1, 64, 56, 56]               0
            Conv2d-8           [-1, 64, 56, 56]          36,864
       BatchNorm2d-9           [-1, 64, 56, 56]             128
             ReLU-10           [-1, 64, 56, 56]               0
           Conv2d-11          [-1, 256, 56, 56]          16,384
      BatchNorm2d-12          [-1, 256, 56, 56]             512
           Conv2d-13          [-1, 256, 56, 56]          16,384
      BatchNorm2d-14          [-1, 256,
```
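The spatial sizes in the summary (112 after `conv1`, 56 after the max pool) follow from the standard convolution/pooling output formula, `floor((size + 2*padding - kernel) / stride) + 1`. This sketch applies it through the whole network for a 224x224 input; `conv_out` is a hypothetical helper, and it assumes, as in the code above, that each of stages 2-4 starts with a stride-2 3x3 convolution with padding 1.

```python
def conv_out(size: int, kernel: int, stride: int, padding: int) -> int:
    """Spatial output size of a conv/pool layer (dilation 1, floor rounding)."""
    return (size + 2 * padding - kernel) // stride + 1


sizes = {"input": 224}
sizes["conv1"] = conv_out(sizes["input"], kernel=7, stride=2, padding=3)
sizes["maxpool"] = conv_out(sizes["conv1"], kernel=3, stride=2, padding=1)
sizes["layer1"] = sizes["maxpool"]  # stride 1 throughout: spatial size unchanged
# layer2-4 each start with a stride-2 3x3 conv (padding 1)
sizes["layer2"] = conv_out(sizes["layer1"], kernel=3, stride=2, padding=1)
sizes["layer3"] = conv_out(sizes["layer2"], kernel=3, stride=2, padding=1)
sizes["layer4"] = conv_out(sizes["layer3"], kernel=3, stride=2, padding=1)

for name, s in sizes.items():
    print(f"{name:8s} {s} x {s}")
```

The final 7x7 map is what `AdaptiveAvgPool2d((1, 1))` averages down to a single vector per channel before `fc`.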