# Pre-Activation ResNet
기존 residual block의 연산 순서를 바꿔 성능을 향상시킨 구조 (pre-activation)
- 각 convolution 앞에 Batch Normalization -> 활성화 함수(ReLU) -> convolution layer 순으로 배치

In [8]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchsummary import summary
from torch import optim
import numpy as np
import time
import copy

In [9]:
class BottleNeck(nn.Module):
    """Pre-activation bottleneck block.

    Every convolution in the residual branch is preceded by BatchNorm and
    ReLU. The residual branch maps ``in_channels`` to
    ``out_channels * expansion`` through a 1x1 -> 3x3 -> 1x1 convolution
    stack, and the block output is the sum of the residual and shortcut
    branches.

    Args:
        in_channels: Number of channels of the incoming feature map.
        out_channels: Width of the inner convolutions; the block emits
            ``out_channels * expansion`` channels.
        stride: Stride of the first 1x1 convolution and, when a projection
            shortcut is needed, of that projection as well.
    """

    # Ratio between the block's output channel count and ``out_channels``.
    expansion = 4

    def __init__(self, in_channels, out_channels, stride):
        super().__init__()

        expanded_channels = out_channels * BottleNeck.expansion

        # Residual branch: in_channels -> out_channels * expansion.
        residual_layers = [
            # 1x1 convolution; carries the block's stride.
            nn.BatchNorm2d(in_channels),
            nn.ReLU(),
            nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride),
            # 3x3 convolution; padding=1 keeps the spatial size.
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
            # 1x1 convolution; expands to the block's output width.
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Conv2d(out_channels, expanded_channels, kernel_size=1),
        ]
        self.residual = nn.Sequential(*residual_layers)

        # Shortcut branch: in_channels -> out_channels * expansion.
        # An empty Sequential passes the input through unchanged; a 1x1
        # projection is used only when the shape of the input would not
        # match the residual branch's output.
        shape_changes = stride != 1 or in_channels != expanded_channels
        if shape_changes:
            self.shortcut = nn.Conv2d(in_channels, expanded_channels,
                                      kernel_size=1, stride=stride)
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        """Return ``residual(x) + shortcut(x)``."""
        # Residual branch runs first, then the shortcut; the sum is
        # accumulated in place into the residual branch's output tensor.
        branch_out = self.residual(x)
        skip_out = self.shortcut(x)
        return branch_out.add_(skip_out)
    
class PreActResNet(nn.Module):
    """ResNet-style classifier assembled from ``BottleNeck`` blocks.

    Layout: a 3x3 convolution stem (3 -> 64 channels) with BatchNorm, ReLU
    and max pooling, four stages of bottleneck blocks with base widths
    64 / 128 / 256 / 512 (the last three stages start with stride 2),
    adaptive average pooling to 1x1, and a linear classifier.

    Args:
        num_blocks: Indexable with at least four entries; entry ``i`` is the
            number of bottleneck blocks in stage ``i``.
        num_classes: Number of outputs of the final linear layer.
    """

    def __init__(self, num_blocks, num_classes=10):
        super().__init__()

        # Stem: 3x3 convolution, normalisation, activation, then a
        # stride-2 max pool.
        stem = [
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
        ]
        self.conv1 = nn.Sequential(*stem)

        # Running input-channel count; _make_layers updates it after every
        # block it creates, so the stages must be built in order.
        self.in_channels = 64

        self.conv2 = self._make_layers(num_blocks[0], out_channels=64, stride=1)
        self.conv3 = self._make_layers(num_blocks[1], out_channels=128, stride=2)
        self.conv4 = self._make_layers(num_blocks[2], out_channels=256, stride=2)
        self.conv5 = self._make_layers(num_blocks[3], out_channels=512, stride=2)

        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))

        classifier_inputs = 512 * BottleNeck.expansion
        self.fc = nn.Linear(classifier_inputs, num_classes)

    def _make_layers(self, num_blocks, out_channels, stride):
        """Build one stage as a ``nn.Sequential`` of ``BottleNeck`` blocks.

        The first block uses ``stride``; each of the remaining
        ``num_blocks - 1`` blocks uses stride 1. ``self.in_channels`` is
        advanced as a side effect.
        """
        block_strides = (stride, *([1] * (num_blocks - 1)))

        stage = []
        for block_stride in block_strides:
            stage.append(BottleNeck(self.in_channels, out_channels, block_stride))
            # Hand the current block's output channel count to the next
            # block as its input channel count.
            self.in_channels = out_channels * BottleNeck.expansion

        return nn.Sequential(*stage)

    def forward(self, x):
        """Run the stem, the four stages, pooling, and the classifier on ``x``."""
        for stage in (self.conv1, self.conv2, self.conv3, self.conv4, self.conv5):
            x = stage(x)

        pooled = self.avg_pool(x)

        # Collapse everything except the leading dimension before the
        # linear layer.
        flat = pooled.view(pooled.size(0), -1)
        return self.fc(flat)
    
def PreActResNet50(num_classes=10):
    """Build the PreActResNet variant with a [3, 4, 6, 3] block layout.

    Args:
        num_classes: Number of outputs of the final linear layer. Defaults
            to 10, which is what this factory always produced before the
            parameter was exposed.

    Returns:
        A newly constructed ``PreActResNet``.
    """
    return PreActResNet([3, 4, 6, 3], num_classes=num_classes)

# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Build the [3, 4, 6, 3] variant and move it to the selected device.
model = PreActResNet50()
model = model.to(device)

# Print a per-layer table of output shapes and parameter counts for a
# 3 x 224 x 224 input.
summary(model, (3, 224, 224))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 224, 224]           1,792
       BatchNorm2d-2         [-1, 64, 224, 224]             128
              ReLU-3         [-1, 64, 224, 224]               0
         MaxPool2d-4         [-1, 64, 112, 112]               0
       BatchNorm2d-5         [-1, 64, 112, 112]             128
              ReLU-6         [-1, 64, 112, 112]               0
            Conv2d-7         [-1, 64, 112, 112]           4,160
       BatchNorm2d-8         [-1, 64, 112, 112]             128
              ReLU-9         [-1, 64, 112, 112]               0
           Conv2d-10         [-1, 64, 112, 112]          36,928
      BatchNorm2d-11         [-1, 64, 112, 112]             128
             ReLU-12         [-1, 64, 112, 112]               0
           Conv2d-13        [-1, 256, 112, 112]          16,640
           Conv2d-14        [-1, 256, 1