&emsp;&emsp;在mini-batch较小的情况下,批次归一化并不是一个很好的选择.
因为批次较小,可能会造成平均值和方差的波动比较大,在这种情况下进行批次归一化,
最后结果反而会降低训练过程中数值的稳定性.为了解决这个问题,这里引入了组归一化的方法,
这个方法的特点是减少了对批次的依赖性,而且不需要跨批次进行估算.

In [130]:
import torch
import torch.nn as nn

In [131]:
# All-ones demo input of shape (N=54, C=12, H=28, W=28).
entry = torch.ones(54, 12, 28, 28)
gn = nn.GroupNorm(
    num_groups=3,     # number of groups; must evenly divide num_channels
    num_channels=12,  # number of input channels C
    # eps must stay positive: every group of this all-ones input has zero
    # variance, so eps=0 would evaluate 0 / 0 and fill the output with NaN.
    eps=1e-5,
    affine=True,      # learn a per-channel weight and bias (affine transform)
)
# NOTE: GroupNorm has no track_running_stats parameter; the mean and variance
# are always computed from the current input, so calling gn.eval() is not
# required before inference.
gn(entry).shape

torch.Size([54, 12, 28, 28])

In [132]:
# Per-channel scale of the affine transform (enabled by affine=True above);
# the output below shows it holds twelve ones right after construction.
gn.weight

Parameter containing:
tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], requires_grad=True)

In [133]:
# Per-channel shift of the affine transform (enabled by affine=True above);
# the output below shows it holds twelve zeros right after construction.
gn.bias

Parameter containing:
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], requires_grad=True)

In [134]:
def MyGroupNorm(x, G, eps=1e-5):
    """Normalize ``x`` per sample and per channel group (GroupNorm, no affine).

    The channels are split into ``G`` consecutive groups. For every sample,
    each group is shifted to zero mean and scaled to unit variance using the
    biased (population) variance taken over the group's channels and all
    trailing dimensions.

    Args:
        x: Tensor of shape ``(N, C, *)``, e.g. ``(N, C, H, W)``.
        G: Number of groups; ``C`` must be divisible by ``G``, otherwise the
            grouping reshape fails with a ``RuntimeError``.
        eps: Small constant added to the variance before the square root so
            that a constant group yields zeros instead of ``0 / 0 = NaN``.
            The default equals the default ``eps`` of ``nn.GroupNorm``.

    Returns:
        Tensor with the same shape as ``x`` holding the normalized values.
    """
    N, C = x.shape[:2]
    trailing = x.shape[2:]
    # Spell out the full target shape (no -1): an indivisible C / G must fail
    # here rather than silently regroup elements across channel boundaries.
    grouped = torch.reshape(x, [N, G, C // G, *trailing])
    # Reduce over everything that belongs to one group of one sample.
    reduce_dims = tuple(range(2, grouped.dim()))
    mean = torch.mean(grouped, dim=reduce_dims, keepdim=True)
    var = torch.var(grouped, dim=reduce_dims, keepdim=True, unbiased=False)
    normalized = (grouped - mean) / torch.sqrt(var + eps)
    return torch.reshape(normalized, x.shape)

In [135]:
# Deterministic input holding 0..1469 laid out as (N=5, C=6, H=7, W=7),
# normalized with the hand-written implementation using 2 groups.
entry1 = torch.arange(5 * 6 * 7 * 7, dtype=torch.float32).reshape(5, 6, 7, 7)
a = MyGroupNorm(entry1, G=2)
a

tensor([[[[-1.7203, -1.6967, -1.6732,  ..., -1.6260, -1.6025, -1.5789],
          [-1.5553, -1.5318, -1.5082,  ..., -1.4611, -1.4375, -1.4140],
          [-1.3904, -1.3668, -1.3433,  ..., -1.2961, -1.2726, -1.2490],
          ...,
          [-1.0605, -1.0369, -1.0133,  ..., -0.9662, -0.9426, -0.9191],
          [-0.8955, -0.8719, -0.8484,  ..., -0.8012, -0.7777, -0.7541],
          [-0.7305, -0.7070, -0.6834,  ..., -0.6363, -0.6127, -0.5891]],

         [[-0.5656, -0.5420, -0.5184,  ..., -0.4713, -0.4478, -0.4242],
          [-0.4006, -0.3771, -0.3535,  ..., -0.3064, -0.2828, -0.2592],
          [-0.2357, -0.2121, -0.1885,  ..., -0.1414, -0.1178, -0.0943],
          ...,
          [ 0.0943,  0.1178,  0.1414,  ...,  0.1885,  0.2121,  0.2357],
          [ 0.2592,  0.2828,  0.3064,  ...,  0.3535,  0.3771,  0.4006],
          [ 0.4242,  0.4478,  0.4713,  ...,  0.5184,  0.5420,  0.5656]],

         [[ 0.5891,  0.6127,  0.6363,  ...,  0.6834,  0.7070,  0.7305],
          [ 0.7541,  0.7777,  

In [136]:
# Reference result from the built-in layer on the same input;
# the positional arguments are (num_groups, num_channels).
b = nn.GroupNorm(2, 6)(entry1)
b

tensor([[[[-1.7203, -1.6967, -1.6732,  ..., -1.6260, -1.6025, -1.5789],
          [-1.5553, -1.5318, -1.5082,  ..., -1.4611, -1.4375, -1.4140],
          [-1.3904, -1.3668, -1.3433,  ..., -1.2961, -1.2726, -1.2490],
          ...,
          [-1.0605, -1.0369, -1.0133,  ..., -0.9662, -0.9426, -0.9191],
          [-0.8955, -0.8719, -0.8484,  ..., -0.8012, -0.7777, -0.7541],
          [-0.7305, -0.7070, -0.6834,  ..., -0.6363, -0.6127, -0.5891]],

         [[-0.5656, -0.5420, -0.5184,  ..., -0.4713, -0.4478, -0.4242],
          [-0.4006, -0.3771, -0.3535,  ..., -0.3064, -0.2828, -0.2592],
          [-0.2357, -0.2121, -0.1885,  ..., -0.1414, -0.1178, -0.0943],
          ...,
          [ 0.0943,  0.1178,  0.1414,  ...,  0.1885,  0.2121,  0.2357],
          [ 0.2592,  0.2828,  0.3064,  ...,  0.3535,  0.3771,  0.4006],
          [ 0.4242,  0.4478,  0.4713,  ...,  0.5184,  0.5420,  0.5656]],

         [[ 0.5891,  0.6127,  0.6363,  ...,  0.6834,  0.7070,  0.7305],
          [ 0.7541,  0.7777,  

In [137]:
# Sum of the element-wise differences between the manual and built-in results.
# NOTE(review): a signed sum can cancel out; (a - b).abs().max() or
# torch.allclose(a, b) would be a stricter equality check.
(a - b).sum()


tensor(0., grad_fn=<SumBackward0>)