In [1]:
import torch
import torch.nn as nn

def channel_norm(h):
    """
    使用GroupNorm对H按照-3轴进行归一化，并把m,std拼接在h的后面
    """
    b, c, rx, tx = h.shape
    
    # 使用GroupNorm进行归一化 (将每个通道视为一个group)
    h_norm = nn.GroupNorm(c, c)(h)  # 这样每个通道独立归一化
    
    # 计算均值和标准差
    m = torch.mean(h, dim=-3, keepdim=True)  # shape: (b, 1, rx, tx)
    std = torch.std(h, dim=-3, keepdim=True)  # shape: (b, 1, rx, tx)
    
    # 拼接
    h_with_stats = torch.cat([h_norm, m, std], dim=-3)
    
    return h_with_stats

In [2]:
h = torch.randn(2, 3, 4, 5)
print(h.shape)
hnorm = channel_norm(h)
print(hnorm.shape)

torch.Size([2, 3, 4, 5])
torch.Size([2, 5, 4, 5])


In [4]:
m = torch.mean(h, dim=-3, keepdim=True)  # shape: (b, 1, rx, tx)
std = torch.std(h, dim=-3, keepdim=True)  # shape: (b, 1, rx, tx)
m.shape, std.shape

(torch.Size([2, 1, 4, 5]), torch.Size([2, 1, 4, 5]))

In [11]:
m1 = hnorm[:, 3, ...].unsqueeze(1)
m-m1 

tensor([[[[0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.]]],


        [[[0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.]]]], grad_fn=<SubBackward0>)

In [12]:
m1 = hnorm[:, 4, ...].unsqueeze(1)
std-m1 

tensor([[[[0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.]]],


        [[[0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.]]]], grad_fn=<SubBackward0>)