In [1]:
import torch
import torch.nn.functional as F

from torch import nn
from torch import Tensor
from torchvision import transforms
from einops import rearrange, reduce, repeat
from einops.layers.torch import Rearrange, Reduce
from torchsummary import summary

import pytorch_lightning as pl

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def conv3x3(in_channels: int, out_channels: int, stride: int = 1, groups: int = 1, dilation: int = 1):
    # BatchNorm에 bias가 있으므로 bias는 False
    return nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=dilation, groups=groups, bias=False, dilation=dilation)

def conv1x1(in_channels: int, out_channels: int, stride: int = 1):
    return nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False)

In [3]:
class BasicBlock(pl.LightningModule):
    def __init__(self, in_channels: int = 3, out_channels: int = 3, stride: int = 1, downsample =  None, groups: int = 1, base_width: int = 64, dilation: int = 1, norm_layer = None):
        super().__init__()

        if norm_layer is None:
            norm_layer = nn.BatchNorm2d

        self.conv1 = conv3x3(in_channels, out_channels, stride)
        self.bn1 = norm_layer(out_channels)
        self.conv2 = conv3x3(out_channels, out_channels)
        self.bn2 = norm_layer(out_channels)
        self.downsample = downsample

    def forward(self, x):
        res = x

        x = self.conv1(x)
        x = self.bn1(x)
        x = nn.ReLU(inplace=True)(x)

        x = self.conv2(x)
        x = self.bn2(x)

        if self.downsample is not None:
            res = self.downsample(res)

        x += res
        x = nn.ReLU(inplace=True)(x)

        return x


In [4]:
class Bottleneck(pl.LightningModule):

    expansion = 4 # 3번째 conv layer에서 차원을 증가시키기 위한 확장계수

    def __init__(self, in_channels: int, out_channels: int, stride: int, downsample = None, groups: int = 1, base_width: int = 64, dilation: int = 1, norm_layer = None):
        super().__init__()

        if norm_layer is not None:
            norm_layer = nn.BatchNorm2d

        width = int(out_channels * (base_width / 64.)) * groups
        
        self.conv1 = conv1x1(in_channels, width)
        self.bn1 = norm_layer(width)
        # Kaming He에 따르면 stride를 어디에 배치하든 연산의 차이는 크게 없다. 즉 의미는 없다.
        # 요지는 conv1x1 -> conv3x3 -> conv1x1에 따른 연산량 조절이 핵심.
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = norm_layer(width)
        self.conv3 = conv1x1(width, out_channels * self.expansion)
        self.bn3 = norm_layer(out_channels * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        res = x

        # 1x1 conv layer
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        # 3x3 conv layer
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu(x)
        # 1x1 conv layer
        x = self.conv3(x)
        x = self.bn3(x)
        # skip connection
        if self.downsample is not None:
            res = self.downsample(res)
        
        x += res
        x = self.relu(x)

        return x


In [None]:
class ResNet(pl.LightningModule):
    def __init__(self, 
    block, 
    layers, 
    num_classes = 1000, 
    zero_init_residual=False, 
    groups=1, 
    width_per_group=64, 
    replace_stride_width_dilation=None,
    norm_layer=None
    ):

        super().__init__()

        if norm_layer is not None:
            norm_layer = nn.BatchNorm2d

        self.norm_layer = norm_layer
        
            


In [18]:
summary(BasicBlock(), (3, 224, 224), device='cpu')

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1          [-1, 3, 224, 224]              81
       BatchNorm2d-2          [-1, 3, 224, 224]               6
            Conv2d-3          [-1, 3, 224, 224]              81
       BatchNorm2d-4          [-1, 3, 224, 224]               6
Total params: 174
Trainable params: 174
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 4.59
Params size (MB): 0.00
Estimated Total Size (MB): 5.17
----------------------------------------------------------------
