In [9]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchsummary import summary
from torch import optim
from torch.optim.lr_scheduler import StepLR
import numpy as np
import time
import copy

In [10]:
class BasicBlock(nn.Module):
    """Residual block with two 3x3 convolutions and an identity or projection shortcut.

    Maps ``in_channels`` to ``out_channels * expansion`` channels. ``stride`` is
    applied by the first 3x3 convolution (and by the 1x1 projection shortcut when
    one is needed).
    """

    # Output-channel multiplier: the block emits out_channels * expansion channels.
    expansion = 1

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()

        width = out_channels * BasicBlock.expansion

        # Main path: conv3x3(stride) -> BN -> ReLU -> conv3x3 -> BN.
        # Built before the shortcut so parameter initialization order is unchanged.
        self.residual_function = nn.Sequential(
            nn.Conv2d(in_channels, out_channels,
                      kernel_size=3, stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Conv2d(out_channels, width,
                      kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(width),
        )

        # The shortcut skips the main path, so it does not see the main path's
        # change of spatial size (stride != 1) or channel count. When either
        # changes, a 1x1 conv + BN projects the input to the main path's output
        # shape; otherwise the shortcut is the identity (an empty nn.Sequential).
        shape_changes = stride != 1 or in_channels != width
        if shape_changes:
            shortcut = nn.Sequential(
                nn.Conv2d(in_channels, width,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(width),
            )
        else:
            shortcut = nn.Sequential()
        self.shortcut = shortcut

        self.relu = nn.ReLU()

    def forward(self, x):
        """Return ReLU(residual_function(x) + shortcut(x))."""
        summed = self.residual_function(x) + self.shortcut(x)
        return self.relu(summed)
    
class ResNet(nn.Module):
    """ResNet classifier: stem -> four residual stages -> global average pool -> linear layer.

    Args:
        block: residual block class. It is called as
            ``block(in_channels, out_channels, stride)`` and must define a class
            attribute ``expansion``; each block is assumed to output
            ``out_channels * expansion`` channels.
        num_blocks: number of blocks in each of the four stages
            (``conv2_x`` .. ``conv5_x``); every entry must be >= 1.
        num_classes: output size of the final linear layer.

    Raises:
        ValueError: if any stage is given fewer than one block.
    """

    def __init__(self, block, num_blocks, num_classes=10):
        super().__init__()

        # Channel count fed to the next block. _make_layer reads and advances it,
        # so each block's input width matches the previous block's output width.
        self.in_channels = 64

        # Stem: 7x7 stride-2 conv + 3x3 stride-2 max-pool, 3 -> 64 channels.
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 64,
                      kernel_size=7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
        )

        # Residual stages. Stages 3-5 pass stride 2 to their first block.
        self.conv2_x = self._make_layer(block, 64, num_blocks[0], 1)
        self.conv3_x = self._make_layer(block, 128, num_blocks[1], 2)
        self.conv4_x = self._make_layer(block, 256, num_blocks[2], 2)
        self.conv5_x = self._make_layer(block, 512, num_blocks[3], 2)

        # Adaptive pooling to 1x1 makes the classifier head independent of the
        # input's spatial size.
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))

        self.fc = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, out_channels, num_blocks, stride):
        """Build one stage of ``num_blocks`` residual blocks as an ``nn.Sequential``.

        Only the first block receives ``stride``; the remaining blocks use stride 1.
        The first block maps ``self.in_channels`` to ``out_channels * block.expansion``
        channels; later blocks keep that width.

        Side effect: sets ``self.in_channels`` to ``out_channels * block.expansion``,
        so the next call starts from this stage's output width. For example, on the
        first call the first block is 64 -> out_channels1 * expansion, and on the
        second call the first block is out_channels1 * expansion ->
        out_channels2 * expansion.

        Raises:
            ValueError: if ``num_blocks`` < 1. ``[stride] + [1] * (num_blocks - 1)``
                would otherwise silently yield one block for 0.
        """
        if num_blocks < 1:
            raise ValueError(f"num_blocks must be >= 1, got {num_blocks}")

        block_strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for block_stride in block_strides:
            layers.append(block(self.in_channels, out_channels, block_stride))
            # Keep the running width in sync with this block's output.
            self.in_channels = out_channels * block.expansion

        return nn.Sequential(*layers)

    def forward(self, x):
        """Map an image batch ``x`` to class scores of shape (N, num_classes).

        ``x`` must have 3 channels (the stem's first conv takes 3 input channels).
        """
        out = self.conv1(x)        # 3 -> 64 channels
        out = self.conv2_x(out)    # 64 -> 64 * expansion
        out = self.conv3_x(out)    # 64 * expansion -> 128 * expansion
        out = self.conv4_x(out)    # 128 * expansion -> 256 * expansion
        out = self.conv5_x(out)    # 256 * expansion -> 512 * expansion
        out = self.avg_pool(out)   # (N, C, H, W) -> (N, C, 1, 1)
        out = out.flatten(1)       # (N, C, 1, 1) -> (N, C); works even if non-contiguous
        return self.fc(out)        # (N, C) -> (N, num_classes)
    
def resnet34(num_classes=10):
    """Build a ResNet-34: BasicBlock stages of 3, 4, 6 and 3 blocks.

    Args:
        num_classes: output size of the final linear layer (default 10, as before).
    """
    return ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes)
    
# Smoke test: build ResNet-34, check the output shape on a random batch,
# and print a layer-by-layer summary.

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = resnet34().to(device)

x = torch.randn(16, 3, 224, 224).to(device)

# Shape check only, so skip building the autograd graph (no stored activations).
# The model is still in training mode, so BatchNorm running statistics are
# updated by this pass, exactly as before.
with torch.no_grad():
    output = model(x)

print(output.size())

# torchsummary.summary defaults to device="cuda"; pass the model's actual device
# type so the summary's dummy input is created where the model lives.
summary(model, (3, 224, 224), device=device.type)

torch.Size([16, 10])
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 112, 112]           9,408
       BatchNorm2d-2         [-1, 64, 112, 112]             128
              ReLU-3         [-1, 64, 112, 112]               0
         MaxPool2d-4           [-1, 64, 56, 56]               0
            Conv2d-5           [-1, 64, 56, 56]          36,864
       BatchNorm2d-6           [-1, 64, 56, 56]             128
              ReLU-7           [-1, 64, 56, 56]               0
            Conv2d-8           [-1, 64, 56, 56]          36,864
       BatchNorm2d-9           [-1, 64, 56, 56]             128
             ReLU-10           [-1, 64, 56, 56]               0
       BasicBlock-11           [-1, 64, 56, 56]               0
           Conv2d-12           [-1, 64, 56, 56]          36,864
      BatchNorm2d-13           [-1, 64, 56, 56]             128
             ReLU-