In [11]:
import numpy as np
import torch

In [15]:
x = np.arange(32).reshape(2, 4, 2, 2)
print(x)

[[[[ 0  1]
   [ 2  3]]

  [[ 4  5]
   [ 6  7]]

  [[ 8  9]
   [10 11]]

  [[12 13]
   [14 15]]]


 [[[16 17]
   [18 19]]

  [[20 21]
   [22 23]]

  [[24 25]
   [26 27]]

  [[28 29]
   [30 31]]]]


In [31]:
###################################
# claude
###################################

def group_norm(x, num_groups, num_channels, eps=1e-5):
    """
    Applies Group Normalization to a tensor.

    Args:
        x: Input tensor of shape (batch_size, channels, height, width)
        num_groups: Number of groups to divide the channels into
        num_channels: Total number of channels
        eps: Small constant for numerical stability

    Returns:
        Normalized tensor of the same shape as input
    """
    # Get input shape
    batch_size, _, height, width = x.shape

    # Ensure number of channels is divisible by number of groups
    assert num_channels % num_groups == 0, "Number of channels must be divisible by number of groups"
    channels_per_group = num_channels // num_groups

    # Reshape input to separate groups
    # From: (batch_size, channels, height, width)
    # To: (batch_size, num_groups, channels_per_group, height, width)
    x = x.reshape(batch_size, num_groups, channels_per_group, height, width)

    # Calculate mean and variance for each group
    # Mean and var shape: (batch_size, num_groups, 1, 1, 1)
    mean = np.mean(x, axis=(2, 3, 4), keepdims=True)
    print('mean', mean)
    var = np.var(x, axis=(2, 3, 4), keepdims=True)

    # Normalize
    x = (x - mean) / np.sqrt(var + eps)

    # Reshape back to original shape
    x = x.reshape(batch_size, num_channels, height, width)

    return x

In [32]:
####################################################
# chatgpt
####################################################

# GroupNorm 설정
num_groups = 2
channels_per_group = x.shape[1] // num_groups  # 4 채널을 2 그룹으로 나누므로, 각 그룹당 2 채널

# 정규화된 결과를 저장할 배열
normalized_x = np.zeros_like(x, dtype=np.float32)

# 각 배치에 대해 독립적으로 정규화를 수행
for n in range(x.shape[0]):
    for g in range(num_groups):
        # 각 그룹에 대한 채널 범위를 설정
        start_channel = g * channels_per_group
        end_channel = start_channel + channels_per_group

        # 해당 그룹의 데이터 추출
        group_data = x[n, start_channel:end_channel, :, :]

        # 평균과 표준편차 계산
        group_mean = np.mean(group_data)
        print("mean", group_mean)

        group_std = np.std(group_data)

        # 정규화 수행
        normalized_x[n, start_channel:end_channel, :, :] = ((group_data - group_mean) / (group_std + 1e-5))  # 1e-5는 수치적 안정성을 위한 작은 값

print("Normalized output:\n", normalized_x)


mean 3.5
mean 11.5
mean 19.5
mean 27.5
Normalized output:
 [[[[-1.5275185  -1.0910847 ]
   [-0.6546508  -0.21821694]]

  [[ 0.21821694  0.6546508 ]
   [ 1.0910847   1.5275185 ]]

  [[-1.5275185  -1.0910847 ]
   [-0.6546508  -0.21821694]]

  [[ 0.21821694  0.6546508 ]
   [ 1.0910847   1.5275185 ]]]


 [[[-1.5275185  -1.0910847 ]
   [-0.6546508  -0.21821694]]

  [[ 0.21821694  0.6546508 ]
   [ 1.0910847   1.5275185 ]]

  [[-1.5275185  -1.0910847 ]
   [-0.6546508  -0.21821694]]

  [[ 0.21821694  0.6546508 ]
   [ 1.0910847   1.5275185 ]]]]


In [34]:
print( group_norm(x, 2, 4) )

mean [[[[[ 3.5]]]


  [[[11.5]]]]



 [[[[19.5]]]


  [[[27.5]]]]]
[[[[-1.52752378 -1.09108841]
   [-0.65465305 -0.21821768]]

  [[ 0.21821768  0.65465305]
   [ 1.09108841  1.52752378]]

  [[-1.52752378 -1.09108841]
   [-0.65465305 -0.21821768]]

  [[ 0.21821768  0.65465305]
   [ 1.09108841  1.52752378]]]


 [[[-1.52752378 -1.09108841]
   [-0.65465305 -0.21821768]]

  [[ 0.21821768  0.65465305]
   [ 1.09108841  1.52752378]]

  [[-1.52752378 -1.09108841]
   [-0.65465305 -0.21821768]]

  [[ 0.21821768  0.65465305]
   [ 1.09108841  1.52752378]]]]


In [35]:
x_torch = torch.tensor(x, dtype=torch.float)
m = torch.nn.GroupNorm(2,4)
print( m(x_torch) )


tensor([[[[-1.5275, -1.0911],
          [-0.6547, -0.2182]],

         [[ 0.2182,  0.6547],
          [ 1.0911,  1.5275]],

         [[-1.5275, -1.0911],
          [-0.6547, -0.2182]],

         [[ 0.2182,  0.6547],
          [ 1.0911,  1.5275]]],


        [[[-1.5275, -1.0911],
          [-0.6547, -0.2182]],

         [[ 0.2182,  0.6547],
          [ 1.0911,  1.5275]],

         [[-1.5275, -1.0911],
          [-0.6547, -0.2182]],

         [[ 0.2182,  0.6547],
          [ 1.0911,  1.5275]]]], grad_fn=<NativeGroupNormBackward0>)


tensor([[[[-1.5275, -1.0911],
          [-0.6547, -0.2182]],

         [[ 0.2182,  0.6547],
          [ 1.0911,  1.5275]],

         [[-1.5275, -1.0911],
          [-0.6547, -0.2182]],

         [[ 0.2182,  0.6547],
          [ 1.0911,  1.5275]]],


        [[[-1.5275, -1.0911],
          [-0.6547, -0.2182]],

         [[ 0.2182,  0.6547],
          [ 1.0911,  1.5275]],

         [[-1.5275, -1.0911],
          [-0.6547, -0.2182]],

         [[ 0.2182,  0.6547],
          [ 1.0911,  1.5275]]]], grad_fn=<NativeGroupNormBackward0>)

array([[[[-1.52752378, -1.09108841],
         [-0.65465305, -0.21821768]],

        [[ 0.21821768,  0.65465305],
         [ 1.09108841,  1.52752378]],

        [[-1.52752378, -1.09108841],
         [-0.65465305, -0.21821768]],

        [[ 0.21821768,  0.65465305],
         [ 1.09108841,  1.52752378]]],


       [[[-1.52752378, -1.09108841],
         [-0.65465305, -0.21821768]],

        [[ 0.21821768,  0.65465305],
         [ 1.09108841,  1.52752378]],

        [[-1.52752378, -1.09108841],
         [-0.65465305, -0.21821768]],

        [[ 0.21821768,  0.65465305],
         [ 1.09108841,  1.52752378]]]])