In [17]:
import torch
import torch.nn as nn

class CustomLayerNorm(nn.Module):
    """Layer normalization over the channel dimension of a channels-first tensor.

    For an input of shape ``(N, C, *spatial)``, each position is normalized
    with the mean and biased variance of its ``C`` channel values. The result
    is then scaled and shifted by per-channel learnable ``weight`` and ``bias``.
    This gives the same result as permuting channels last and applying
    ``nn.LayerNorm(C, eps=eps)``, without the permute.

    Args:
        normalized_shape: Number of channels ``C`` (size of dim 1 of the input).
        eps: Value added to the variance (inside the square root) for
            numerical stability.
    """

    def __init__(self, normalized_shape, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps

    def forward(self, x):
        """Normalize ``x`` of shape ``(N, C, *spatial)`` over dim 1 and apply the affine."""
        mean = x.mean(dim=1, keepdim=True)
        var = x.var(dim=1, keepdim=True, unbiased=False)
        # eps goes inside the square root, matching nn.LayerNorm's
        # (x - E[x]) / sqrt(Var[x] + eps).
        x = (x - mean) / torch.sqrt(var + self.eps)
        # Reshape the (C,) affine params to (C, 1, ..., 1) so they broadcast
        # against any number of trailing spatial dims, not only two.
        affine_shape = (-1,) + (1,) * (x.dim() - 2)
        return x * self.weight.reshape(affine_shape) + self.bias.reshape(affine_shape)

# Create the input tensor, shape (N, C, H, W) = (2, 96, 56, 56).
# Seed the RNG so the comparison output is reproducible across runs.
torch.manual_seed(0)
x = torch.randn(2, 96, 56, 56)

# Method 1: nn.LayerNorm([96, 56, 56]) normalizes over the last three dims,
# i.e. one mean/variance per sample computed jointly over (C, H, W).
layer_norm1 = nn.LayerNorm([96, 56, 56])
normalized_x1 = layer_norm1(x)

# Method 2: move channels last, then nn.LayerNorm(96) normalizes over the
# last dim, i.e. one mean/variance per (n, h, w) position over the 96 channels.
x_permuted = x.permute(0, 2, 3, 1)  # [2, 56, 56, 96]
layer_norm2 = nn.LayerNorm(96)
normalized_x2 = layer_norm2(x_permuted)

# Method 3: CustomLayerNorm normalizes over dim 1 (channels) without permuting,
# so it is meant to reproduce method 2. Pass layer_norm2's eps explicitly
# (CustomLayerNorm defaults to 1e-6) so the two are compared with the same eps.
custom_layer_norm = CustomLayerNorm(96, eps=layer_norm2.eps)
normalized_x3 = custom_layer_norm(x)

# Method 4: nn.GroupNorm(96, 96) has one channel per group, so its statistics
# are per (n, c) over (H, W) (instance-norm style), not over channels.
# nn.GroupNorm(1, 96) would instead use the same statistics as method 1.
group_norm = nn.GroupNorm(96, 96)
normalized_x4 = group_norm(x)

# Per-channel standard deviation over (N, H, W) for each result.
# NOTE: this summary can be close to 1 for methods that normalize over
# different axes, so on its own it does not show the methods are equivalent;
# compare the tensors element-wise (e.g. torch.allclose) for that.
std1 = normalized_x1.std(dim=[0, 2, 3], unbiased=False)
std2 = normalized_x2.permute(0, 3, 1, 2).std(dim=[0, 2, 3], unbiased=False)  # back to (N, C, H, W)
std3 = normalized_x3.std(dim=[0, 2, 3], unbiased=False)
std4 = normalized_x4.std(dim=[0, 2, 3], unbiased=False)

std1, std2, std3, std4

(tensor([0.9917, 0.9980, 1.0210, 0.9905, 1.0053, 1.0068, 0.9942, 1.0201, 0.9816,
         0.9958, 1.0180, 1.0059, 0.9985, 0.9910, 0.9986, 0.9837, 0.9886, 1.0084,
         0.9957, 0.9957, 1.0077, 0.9890, 0.9877, 1.0070, 1.0104, 0.9996, 0.9904,
         1.0004, 0.9909, 0.9894, 0.9895, 1.0045, 0.9990, 0.9849, 1.0025, 0.9937,
         1.0088, 1.0016, 0.9880, 1.0066, 0.9954, 1.0027, 0.9957, 1.0057, 0.9925,
         1.0021, 1.0067, 0.9926, 1.0107, 0.9716, 0.9954, 1.0064, 0.9946, 0.9923,
         1.0039, 0.9959, 1.0022, 1.0017, 1.0136, 1.0150, 0.9984, 1.0175, 0.9983,
         0.9916, 0.9975, 1.0173, 0.9998, 1.0065, 1.0120, 1.0021, 0.9995, 0.9929,
         0.9996, 1.0072, 1.0122, 1.0131, 0.9930, 0.9912, 1.0137, 0.9920, 0.9977,
         1.0005, 0.9932, 1.0052, 0.9938, 1.0091, 0.9908, 1.0182, 1.0022, 1.0080,
         0.9961, 0.9909, 1.0015, 0.9928, 0.9902, 1.0062],
        grad_fn=<StdBackward0>),
 tensor([0.9921, 0.9977, 1.0203, 0.9895, 1.0054, 1.0042, 0.9930, 1.0199, 0.9818,
         0.9960, 1