In [2]:
import torch
from torch import nn
from torchvision.ops.misc import Permute
from torchvision.ops import StochasticDepth

In [3]:
class CNBlock(nn.Module):
    def __init__(self, in_channels, layer_scale, stochastic_depth_prob):
        super().__init__()

        self.residual = nn.Sequential(nn.Conv2d(in_channels, in_channels, 7, padding=3, groups=in_channels), # torchvision 구현체 보면 bias=True
                                      Permute([0, 2, 3, 1]), # 개채행열 -> 개행열채
                                      nn.LayerNorm(in_channels, eps=1e-6), # pixel-wise로 채널 축의 값들을 이용해서 normalize 하는 것
                                      # nn.LayerNorm 은 평균, 분산을 "last D dimensions" 에 대해 구한다.
                                      # 예를 들어, x = torch.randn(개, 채, 행, 열) 라면 layer_norm = nn.LayerNorm([채, 행, 열]) 이렇게 주면
                                      # 채,행,열 에 대해서 평균, 분산 구하는 데 참여시킨다는 뜻
                                      # 그러나 [개, 행, 열] 이런 식으로 건너뛸 순 없다! 따라서 채널 축의 값들만 이용하고 싶다면 Permute 해줘야
                                      Permute([0, 3, 1, 2]), # 개행열채 -> 개채행열
                                      nn.Conv2d(in_channels, 4 * in_channels, 1),
                                      nn.GELU(),
                                      nn.Conv2d(4 * in_channels, in_channels, 1))
        self.layer_scale = nn.Parameter(torch.ones(1,in_channels, 1, 1) * layer_scale)
        self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row")

    def forward(self, x):
        residual = self.layer_scale * self.residual(x) # 어떤 channel이 중요한지를 학습시키자 (SE Net 아이디어 비슷)
        residual = self.stochastic_depth(residual)
        out = residual + x
        return out

class ConvNeXt(nn.Module):
    def __init__(self, block_setting, stochastic_depth_prob = 0.0, layer_scale = 1e-6, num_classes = 1000, **kwargs):
        super().__init__()

        layers = []
        layers += [nn.Sequential(nn.Conv2d(3, block_setting[0][0], kernel_size=4, stride=4),
                                 Permute([0, 2, 3, 1]),
                                 nn.LayerNorm(block_setting[0][0], eps=1e-6),
                                 Permute([0, 3, 1, 2]))]

        total_stage_blocks = sum([setting[2] for setting in block_setting])
        stage_block_id = 0
        for in_channels, out_channels, num_blocks in block_setting:
            stage = []
            for _ in range(num_blocks):
                sd_prob = stochastic_depth_prob * stage_block_id / (total_stage_blocks - 1) # 1 빼야 마지막 블록이 설정한 stochastic_depth_prob을 가지게 된다.
                stage.append(CNBlock(in_channels, layer_scale, sd_prob))
                stage_block_id += 1
            layers += [nn.Sequential(*stage)]
            if out_channels is not None:
                downsample = nn.Sequential(Permute([0, 2, 3, 1]),
                                           nn.LayerNorm(in_channels),
                                           Permute([0, 3, 1, 2]),
                                           nn.Conv2d(in_channels, out_channels, kernel_size=2, stride=2))
                layers += [downsample]

        self.features = nn.Sequential(*layers)

        # https://github.com/pytorch/vision/blob/main/torchvision/models/convnext.py#L160 참고
        # Swin과 달리 pre-activation이 아니라서 LN-GAP-fc 가 아님. 근데 GAP-fc 만 하기엔 마지막 LN이 멀어서 GAP-fc 사이에 LN을 추가한 듯
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.classifier = nn.Sequential(nn.LayerNorm(block_setting[-1][0]),
                                        nn.Linear(block_setting[-1][0], num_classes))

        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                nn.init.trunc_normal_(m.weight, std=0.02) # 논문엔 0.2 라 나와있는데 torchvision 코드는 0.02로 되어있음. 참고: https://github.com/pytorch/vision/blob/main/torchvision/models/convnext.py#L165
                # timm 코드도 0.02 로 되어있어서 코드도 0.02로 반영. 참고: https://github.com/huggingface/pytorch-image-models/blob/4d9c3ae2fb7cc4739ec57d4c06254d2ffc7e2c89/timm/models/convnext.py#L380
                # head init scale 은 fine-tuning 때 하는 것이므로 여기선 생략. 적용 코드 참고: https://github.com/huggingface/pytorch-image-models/blob/4d9c3ae2fb7cc4739ec57d4c06254d2ffc7e2c89/timm/models/convnext.py#L383
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def forward(self, x):
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x

In [4]:
def ConvNeXt_T(**kwargs):
    block_setting = [[96, 192, 3], # 192는 sep. d.s. conv에서 patch merging 따라하기 위해 똑같이 임베딩 차원 두 배 하는 것
                     [192, 384, 3],
                     [384, 768, 9],
                     [768, None, 3]]
    return ConvNeXt(block_setting, stochastic_depth_prob = 0.1,  **kwargs)

def ConvNeXt_S(**kwargs):
    block_setting = [[96, 192, 3],
                     [192, 384, 3],
                     [384, 768, 27],
                     [768, None, 3]]
    return ConvNeXt(block_setting, stochastic_depth_prob = 0.4, **kwargs)

def ConvNeXt_B(**kwargs):
    block_setting = [[128, 256, 3],
                     [256, 512, 3],
                     [512, 1024, 27],
                     [1024, None, 3]]
    return ConvNeXt(block_setting, stochastic_depth_prob = 0.5, **kwargs)

def ConvNeXt_L(**kwargs):
    block_setting = [[192, 384, 3],
                     [384, 768, 3],
                     [768, 1536, 27],
                     [1536, None, 3]]
    return ConvNeXt(block_setting, stochastic_depth_prob = 0.5, **kwargs)

In [5]:
model = ConvNeXt_L()
# print(model)
!pip install torchinfo
from torchinfo import summary
summary(model, input_size=(2,3,224,224), device='cpu')

Collecting torchinfo
  Downloading torchinfo-1.8.0-py3-none-any.whl.metadata (21 kB)
Downloading torchinfo-1.8.0-py3-none-any.whl (23 kB)
Installing collected packages: torchinfo
Successfully installed torchinfo-1.8.0


Layer (type:depth-idx)                        Output Shape              Param #
ConvNeXt                                      [2, 1000]                 --
├─Sequential: 1-1                             [2, 1536, 7, 7]           --
│    └─Sequential: 2-1                        [2, 192, 56, 56]          --
│    │    └─Conv2d: 3-1                       [2, 192, 56, 56]          9,408
│    │    └─Permute: 3-2                      [2, 56, 56, 192]          --
│    │    └─LayerNorm: 3-3                    [2, 56, 56, 192]          384
│    │    └─Permute: 3-4                      [2, 192, 56, 56]          --
│    └─Sequential: 2-2                        [2, 192, 56, 56]          --
│    │    └─CNBlock: 3-5                      [2, 192, 56, 56]          306,048
│    │    └─CNBlock: 3-6                      [2, 192, 56, 56]          306,048
│    │    └─CNBlock: 3-7                      [2, 192, 56, 56]          306,048
│    └─Sequential: 2-3                        [2, 384, 28, 28]          --
│

In [6]:
x = torch.randn(2,3,224,224)
print(model(x).shape)

torch.Size([2, 1000])


In [7]:
# Layer scale 기법에 대해
class MyModel(nn.Module):
    def __init__(self, in_channels, layer_scale):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)
        self.layer_scale = nn.Parameter(torch.ones(1, in_channels, 1, 1) * layer_scale)

    def forward(self, x):
        out = self.conv(x)
        print(out.shape)
        print(self.layer_scale.shape)
        print(self.layer_scale)
        print(out)
        out = out * self.layer_scale  # Apply layer scaling (broadcasting)
        print(out)
        return out

model = MyModel(in_channels=2, layer_scale=0.5)
y=model(torch.randn(1, 2, 4, 4))
print(y.shape)

torch.Size([1, 2, 4, 4])
torch.Size([1, 2, 1, 1])
Parameter containing:
tensor([[[[0.5000]],

         [[0.5000]]]], requires_grad=True)
tensor([[[[ 0.0205,  0.3084,  0.3084,  0.5715],
          [-0.0061,  0.5617,  0.5617,  1.1278],
          [-0.0061,  0.5617,  0.5617,  1.1278],
          [ 0.2057,  0.7788,  0.7788,  1.1372]],

         [[ 0.4279, -0.0342, -0.0342, -0.2843],
          [ 0.2425, -0.1454, -0.1454, -0.2460],
          [ 0.2425, -0.1454, -0.1454, -0.2460],
          [ 0.3192,  0.0039,  0.0039, -0.1718]]]],
       grad_fn=<ConvolutionBackward0>)
tensor([[[[ 0.0103,  0.1542,  0.1542,  0.2857],
          [-0.0030,  0.2809,  0.2809,  0.5639],
          [-0.0030,  0.2809,  0.2809,  0.5639],
          [ 0.1029,  0.3894,  0.3894,  0.5686]],

         [[ 0.2139, -0.0171, -0.0171, -0.1421],
          [ 0.1213, -0.0727, -0.0727, -0.1230],
          [ 0.1213, -0.0727, -0.0727, -0.1230],
          [ 0.1596,  0.0019,  0.0019, -0.0859]]]], grad_fn=<MulBackward0>)
torch.Size([1, 2, 4, 4

In [8]:
# Layer norm, 각 픽셀 위치 마다, 채널 축에 대해서 샘플의 평균, 분산 구하려면?
x = torch.randn(2, 3, 2, 2)
ln = nn.LayerNorm(3, eps=1e-5)
y = ln(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
print("After LayerNorm: ", y) # 각 픽셀 위치에서 채널 축으로 평균 0 분산 1 이 되도록 함
print("weight shape: ", ln.weight.shape)

After LayerNorm:  tensor([[[[-1.1600,  0.1625],
          [ 0.9776,  0.6300]],

         [[ 1.2806,  1.1354],
          [ 0.3957, -1.4115]],

         [[-0.1205, -1.2979],
          [-1.3733,  0.7815]]],


        [[[-0.3606,  0.1804],
          [-1.4109,  1.0528]],

         [[-1.0039, -1.3049],
          [ 0.6233,  0.2913]],

         [[ 1.3645,  1.1246],
          [ 0.7876, -1.3441]]]], grad_fn=<PermuteBackward0>)
weight shape:  torch.Size([3])


In [9]:
# Layer norm, 각 픽셀 위치 마다, 채널 축에 대해서 하려면? (직접 구하기)
mean = x.mean(dim=1, keepdim=True)
std = x.std(dim=1, keepdim=True, unbiased=False)
print(mean.shape)
print(std.shape)
y_manual = (x - mean) / (std + 1e-5)
print("Manual normalization: ", y_manual)

torch.Size([2, 1, 2, 2])
torch.Size([2, 1, 2, 2])
Manual normalization:  tensor([[[[-1.1600,  0.1625],
          [ 0.9778,  0.6300]],

         [[ 1.2805,  1.1354],
          [ 0.3958, -1.4115]],

         [[-0.1205, -1.2979],
          [-1.3736,  0.7815]]],


        [[[-0.3606,  0.1804],
          [-1.4110,  1.0528]],

         [[-1.0039, -1.3049],
          [ 0.6233,  0.2913]],

         [[ 1.3645,  1.1245],
          [ 0.7877, -1.3441]]]])


In [10]:
# 배치놈은 어땟나
x = torch.randn(2, 3, 2, 2)
bn = nn.BatchNorm2d(3, eps=1e-5) # BN은 반대로 3에 해당하는 '채' 제외, '개행열'에 대해 평균 분산 구하는 것으로 구현되어 있다
y_bn = bn(x)
print("After BatchNorm2d: ", y_bn)
print("weight shape: ", bn.weight.shape)

After BatchNorm2d:  tensor([[[[ 0.2045,  1.1235],
          [-1.4211,  0.1841]],

         [[ 0.2215,  0.4723],
          [ 1.8592, -1.3537]],

         [[ 1.0293,  0.4474],
          [ 1.3295, -1.1833]]],


        [[[ 0.9182,  0.7681],
          [ 0.0142, -1.7914]],

         [[-0.5079, -1.2443],
          [-0.2132,  0.7662]],

         [[ 0.8556, -1.5449],
          [-0.3716, -0.5620]]]], grad_fn=<NativeBatchNormBackward0>)
weight shape:  torch.Size([3])


In [11]:
# 배치놈은 어땟나 (직접 구하기)
mean = x.mean(dim=(0, 2, 3), keepdim=True)
std = x.std(dim=(0, 2, 3), keepdim=True, unbiased=False)
print(mean.shape)
print(std.shape)
y_manual = (x - mean) / (std + 1e-5)
print("Manual Batch Normalization: ", y_manual)

torch.Size([1, 3, 1, 1])
torch.Size([1, 3, 1, 1])
Manual Batch Normalization:  tensor([[[[ 0.2045,  1.1235],
          [-1.4211,  0.1841]],

         [[ 0.2215,  0.4723],
          [ 1.8591, -1.3537]],

         [[ 1.0293,  0.4474],
          [ 1.3295, -1.1833]]],


        [[[ 0.9182,  0.7681],
          [ 0.0142, -1.7914]],

         [[-0.5079, -1.2443],
          [-0.2132,  0.7662]],

         [[ 0.8556, -1.5449],
          [-0.3716, -0.5620]]]])


In [12]:
# 흔히 알려진 layer norm 그림에선..
x = torch.randn(2, 3, 2, 2)
ln = nn.LayerNorm([3, 2, 2], eps=1e-5)
y = ln(x)
print("After LayerNorm: ", y)
print("weight shape: ", ln.weight.shape) # nn.LayerNorm([C,H,W]) 이면 C*H*W 개에 대해서 평균, 분산 내서 normalize 하는 거고 재배치할 평균, 분산도 각각 CHW 개다

After LayerNorm:  tensor([[[[-2.0943, -1.3114],
          [-0.9862,  0.1284]],

         [[ 0.1934,  0.9875],
          [ 0.2357,  0.6703]],

         [[-0.5168,  1.5059],
          [ 0.3218,  0.8657]]],


        [[[-0.0390,  0.0820],
          [ 0.0447,  0.3354]],

         [[ 0.6128,  0.7946],
          [ 0.3978, -3.1131]],

         [[-0.3371, -0.0195],
          [ 0.8816,  0.3597]]]], grad_fn=<NativeLayerNormBackward0>)
weight shape:  torch.Size([3, 2, 2])


In [13]:
# 흔히 알려진 layer norm 그림에선.. (직접 구하기)
mean = x.mean(dim=(1, 2, 3), keepdim=True)
std = x.std(dim=(1, 2, 3), keepdim=True, unbiased=False)
print(mean.shape)
print(std.shape)
y_manual = (x - mean) / (std + 1e-5)
print("Manual normalization: ", y_manual)

torch.Size([2, 1, 1, 1])
torch.Size([2, 1, 1, 1])
Manual normalization:  tensor([[[[-2.0943, -1.3114],
          [-0.9862,  0.1284]],

         [[ 0.1934,  0.9875],
          [ 0.2357,  0.6703]],

         [[-0.5168,  1.5059],
          [ 0.3218,  0.8657]]],


        [[[-0.0390,  0.0820],
          [ 0.0447,  0.3354]],

         [[ 0.6128,  0.7946],
          [ 0.3978, -3.1131]],

         [[-0.3371, -0.0195],
          [ 0.8816,  0.3597]]]])


In [14]:
# # 내 필기용
# from functools import partial

# def power(base, exponent):
#     return base ** exponent

# square = partial(power, 2)  # 밑을 2로 고정
# cube = partial(power, 3)  # 밑을 3으로 고정

# print(square(4))  # 출력: 16 (2의 4제곱)
# print(cube(3))  # 출력: 27 (3의 3제곱)