In [100]:
import torch
import torch.nn as nn
class BasicBlock(nn.Module):
    """
    2개의 3*3 컨볼루션 연산
    첫 컨볼루션 연산시 strid값에 따라 down sampling 발생 여부 결정됨
    맨 마지막 output채널수는 output_channels*expansion
    """

    expansion = 1
    def __init__(self , input_channels , output_channels , stride:int = 1 , downsample:bool=False):
        """(conv=>bn=>relu)*2 (kernel_size = 3 , stride = 1 , padding = 1 , bias = False)
        downsample 이 필요한 경우 처음 입력 Tensor에 대하여(conv=>bn) (kernel_size = 1, strid = 2 , padding=0)

        Args:
            input_channels (_type_): 입력 채널 갯수
            output_channels (_type_): 출력 채널 갯수
            stride (int, optional): . Defaults to 1.
            downsample (bool, optional): downsampling 여부. Defaults to False.
        """
        super().__init__()
        self.conv1 = nn.Conv2d(input_channels , output_channels , kernel_size=3 , stride=stride , padding = 1,bias=False)
        self.bn1 = nn.BatchNorm2d(output_channels)
        self.conv2 = nn.Conv2d(output_channels , output_channels , kernel_size=3 , stride=1 , padding = 1,bias=False)
        self.bn2 = nn.BatchNorm2d(output_channels)
        self.relu = nn.ReLU(inplace=True)
        
        if downsample:
            conv = nn.Conv2d(input_channels , output_channels , kernel_size = 1 , stride=2 ,bias=False )
            bn = nn.BatchNorm2d(output_channels)
            self.downsample = nn.Sequential(
                conv,bn
            )
        else:
            self.downsample = None
    def forward(self , x):
        i= x
        # print("input : " , i.shape , self.downsample)
        output = self.conv1(x)
        output = self.bn1(output)
        output = self.relu(output)
        output = self.conv2(output)
        output = self.bn2(output)
        # print("before down => output : ", output.shape , "I=>sampled ", {i.shape})
        if self.downsample is not None:
            i = self.downsample(i)
        # print("output : ", output.shape , "I=>sampled ", {i.shape})
        output += i
        output = self.relu(output)
        return output

        

In [52]:

class Bottleneck(nn.Module):
    """
    3번의 컨볼루션 연산 ()
    1x1(K:1,S:1,P:0) => 3x3(K:3,S:1,P:1) => 1x1(K:1,S:1,P:0) convolution layers
    
    stride가 1이 아닌 경우 (3x3)에서 downsamling 발생
    맨 마지막 output채널수는 output_channels*expansion
    
    """
    expansion = 4
    def __init__(self , input_channels , output_channels , stride = 1 , downsample=False):
        super().__init__()
        # 3개의 conv2(1x1 -> 3x3 -> 1x1)
        self.conv1 = nn.Conv2d(input_channels , output_channels , kernel_size=1 , stride=1 , padding=0,bias=False)
        self.bn1 = nn.BatchNorm2d(output_channels)
        self.conv2 = nn.Conv2d(output_channels , output_channels , kernel_size=3 , stride=stride , padding = 1,bias=False)
        self.bn2 = nn.BatchNorm2d(output_channels)
        self.conv3 = nn.Conv2d(output_channels , self.expansion*output_channels , kernel_size=1 , stride=1 , padding=0,bias=False)
        self.bn3 = nn.BatchNorm2d(self.expansion*output_channels)
        self.relu = nn.ReLU(inplace = True)
        
        if downsample: 
            # downsample = True 여도, stride =1이면 (H,W)는 변경되지 않고 output 채널수만 변경됨
            # 만약 stride = 2이면 (H,W) => (H/2 , W/2)로 되고 output 채널수 변경됨
            self.downsample = nn.Sequential(
                nn.Conv2d( input_channels , self.expansion*output_channels , kernel_size=1, stride=stride,bias=False),
                nn.BatchNorm2d(self.expansion*output_channels)
            )
        else:
            self.downsample = None
    def forward(self , x):
        I = x
        output = self.conv1(x)
        output = self.bn1(output)
        output = self.relu(output)
        output = self.conv2(output)
        output = self.bn2(output)
        output = self.conv3(output)
        output = self.bn3(output)
        if self.downsample is not None:
            I = self.downsample(I)
        output += I
        output = self.relu(output)
        return output

In [45]:
from collections import namedtuple
class ResNet(nn.Module):
    def __init__(self , config:namedtuple, output_dim:int , zero_init_residual=False):
        """ResNet 클래스

        Args:
            config (_type_): 모델의 형태에 대한 정보
                "block" : basicblock | bottlenect
                "n_blocks": 
                "channels" : 각 블럭에 해당되는 입력 채널 갯수
            output_dim (_type_): class 개수
            zero_init_residual (bool, optional): 초기 가중치 초기화 방법. Defaults to False.
        """
        super().__init__()
        
        block , n_blocks , channels = config
        self.input_channels = channels[0]
        assert len(n_blocks) == len(channels) ==4 ,f"Error : number of n_blocks and channesls must 4 but len(n_blocks) : {len(n_blocks)} , len(channels) :{len(channels)}"
        
        self.conv1 = nn.Conv2d(3 , self.input_channels , kernel_size=7 , stride=2, padding=3,bias=False)
        self.bn1 = nn.BatchNorm2d(self.input_channels)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3 , stride=2, padding=1)
        # 여기까지는 항상(N,channels[0] , 64,64)
        # 크게 4가지의 layer로 구성되며 3번의 downsampling을 통해 64/2^3 = 64/8 = 8로 최종 (H,W) = (8,8)
        self.layer1 = self._get_resnet_layer(block , n_blocks[0] , channels[0])
        self.layer2 = self._get_resnet_layer(block , n_blocks[1] , channels[1],stride = 2)
        self.layer3 = self._get_resnet_layer(block , n_blocks[2] , channels[2],stride = 2)
        self.layer4 = self._get_resnet_layer(block , n_blocks[3] , channels[3],stride = 2)
        
        
        
        self.avgpool = nn.AdaptiveAvgPool2d((1,1)) # ([N, output_chanels, 8, 8]) 이거 적용하면 => [N, output_chanels, 1,1] 됨
        
        self.fc = nn.Linear(self.input_channels, output_dim)
        
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m , Bottleneck):
                    nn.init.constant_(m.bn3.weight,0)
                elif isinstance(m , BasicBlock):
                    nn.init.constant_(m.bn2.weight,0)
    
    def _get_resnet_layer(self , block , n_blocks , channels , stride=1):
        layers = []
        
        """총 4번의 _get_resnet_layer함수를 호출
        처음 호출 시는 self.input_channels = 64 , 

        """
        if self.input_channels != block.expansion*channels:
            downsample = True
        else:
            downsample = False
        
        layers.append(block(self.input_channels , channels , stride , downsample))
        # 한개의 layer를 먼저 쌓으면 출력 채널은 입력 채널*block.expansion이 됨
        
        for i in range(1, n_blocks):
            layers.append(block(block.expansion*channels , channels))
        
        # 다음 입력 채널 변경(맨 마지막 for문에서의 출력 채널수는 입력 채널*block.expansion)
        self.input_channels = block.expansion*channels
        return nn.Sequential(*layers)
    
    def forward(self , x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x= self.layer1(x)
        x= self.layer2(x)
        x= self.layer3(x)
        x= self.layer4(x)
        x = self.avgpool(x)
        h = x.view(x.shape[0] , -1)
        x = self.fc(h)
        return x,h
        
    
        

In [101]:
import torchvision.models.resnet as resnet
from collections import namedtuple
ResNetConfig = namedtuple("ResNetConfig" , ["block" , "n_blocks" , "channels"])


In [102]:
resnet18_config = ResNetConfig(block=BasicBlock , n_blocks=[2,2,2,2] , channels=[64,128,256,512])
resnet50_config = ResNetConfig(block=Bottleneck , n_blocks=[3,4,6,3] , channels=[64,128,256,512])

In [103]:
from torchinfo import summary
resnet18 = ResNet(resnet18_config, 2)
summary(resnet18 , input_size = (1,3,256,256))

Layer (type:depth-idx)                   Output Shape              Param #
ResNet                                   [1, 2]                    --
├─Conv2d: 1-1                            [1, 64, 128, 128]         9,408
├─BatchNorm2d: 1-2                       [1, 64, 128, 128]         128
├─ReLU: 1-3                              [1, 64, 128, 128]         --
├─MaxPool2d: 1-4                         [1, 64, 64, 64]           --
├─Sequential: 1-5                        [1, 64, 64, 64]           --
│    └─BasicBlock: 2-1                   [1, 64, 64, 64]           --
│    │    └─Conv2d: 3-1                  [1, 64, 64, 64]           36,864
│    │    └─BatchNorm2d: 3-2             [1, 64, 64, 64]           128
│    │    └─ReLU: 3-3                    [1, 64, 64, 64]           --
│    │    └─Conv2d: 3-4                  [1, 64, 64, 64]           36,864
│    │    └─BatchNorm2d: 3-5             [1, 64, 64, 64]           128
│    │    └─ReLU: 3-6                    [1, 64, 64, 64]           --
│

In [105]:
from torchinfo import summary
resnet50_config = ResNetConfig(block=Bottleneck , n_blocks=[3,4,6,3] , channels=[64,128,256,512])
resnet50 = ResNet(resnet50_config, 2)
summary(resnet50 , input_size = (1,3,256,256))

Layer (type:depth-idx)                   Output Shape              Param #
ResNet                                   [1, 2]                    --
├─Conv2d: 1-1                            [1, 64, 128, 128]         9,408
├─BatchNorm2d: 1-2                       [1, 64, 128, 128]         128
├─ReLU: 1-3                              [1, 64, 128, 128]         --
├─MaxPool2d: 1-4                         [1, 64, 64, 64]           --
├─Sequential: 1-5                        [1, 256, 64, 64]          --
│    └─Bottleneck: 2-1                   [1, 256, 64, 64]          --
│    │    └─Conv2d: 3-1                  [1, 64, 64, 64]           4,096
│    │    └─BatchNorm2d: 3-2             [1, 64, 64, 64]           128
│    │    └─ReLU: 3-3                    [1, 64, 64, 64]           --
│    │    └─Conv2d: 3-4                  [1, 64, 64, 64]           36,864
│    │    └─BatchNorm2d: 3-5             [1, 64, 64, 64]           128
│    │    └─Conv2d: 3-6                  [1, 256, 64, 64]          16,38

In [41]:
from torchvision import models
pretrain_model = models.resnet50(pretrained=True)

In [48]:
pretrain_model.fc = nn.Linear(in_features=2048 , out_features=2 , bias=True)


In [69]:
from torchinfo import summary
summary(pretrain_model , input_size = (1,3,256,256))

Layer (type:depth-idx)                   Output Shape              Param #
ResNet                                   [1, 2]                    --
├─Conv2d: 1-1                            [1, 64, 128, 128]         9,408
├─BatchNorm2d: 1-2                       [1, 64, 128, 128]         128
├─ReLU: 1-3                              [1, 64, 128, 128]         --
├─MaxPool2d: 1-4                         [1, 64, 64, 64]           --
├─Sequential: 1-5                        [1, 256, 64, 64]          --
│    └─Bottleneck: 2-1                   [1, 256, 64, 64]          --
│    │    └─Conv2d: 3-1                  [1, 64, 64, 64]           4,096
│    │    └─BatchNorm2d: 3-2             [1, 64, 64, 64]           128
│    │    └─ReLU: 3-3                    [1, 64, 64, 64]           --
│    │    └─Conv2d: 3-4                  [1, 64, 64, 64]           36,864
│    │    └─BatchNorm2d: 3-5             [1, 64, 64, 64]           128
│    │    └─ReLU: 3-6                    [1, 64, 64, 64]           --
│ 

In [49]:
pretrain_model

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

In [51]:
ResNet(resnet50_config, 2).load_state_dict(pretrain_model.state_dict() )

<All keys matched successfully>

In [24]:
import torchvision
torchvision.models.resnet50

<function torchvision.models.resnet.resnet50(*, weights: Optional[torchvision.models.resnet.ResNet50_Weights] = None, progress: bool = True, **kwargs: Any) -> torchvision.models.resnet.ResNet>