In [5]:
import torch
import torch.nn as nn

class CustomBatchNorm(nn.Module):
    """Minimal batch normalization over the channel dimension (dim 1), without affine parameters.

    Expects input of shape ``(N, C)`` or ``(N, C, *)`` with ``C == num_features``.
    In training mode each channel is normalized with the statistics of the current
    batch (computed over every dimension except dim 1), and the running estimates
    are updated with an exponential moving average. In eval mode the running
    estimates are used instead.

    Args:
        num_features: Number of channels ``C``.
        eps: Small constant added to the variance before the square root, for
            numerical stability.
        momentum: Weight of the current batch in the running-statistics update:
            ``running = (1 - momentum) * running + momentum * batch``.
    """

    def __init__(self, num_features, eps=1e-10, momentum=0.1):
        super(CustomBatchNorm, self).__init__()
        self.num_features = num_features
        self.eps = eps
        # Kept separate from eps: previously eps doubled as the momentum, which
        # (at 1e-10) meant the running statistics effectively never updated.
        self.momentum = momentum
        # Buffers: saved in state_dict and moved by .to(), but not trained.
        self.register_buffer('running_mean', torch.zeros(num_features))
        self.register_buffer('running_var', torch.ones(num_features))

    def forward(self, x):
        """Normalize ``x`` per channel; returns a tensor with the same shape as ``x``.

        Raises:
            ValueError: if ``x`` has fewer than 2 dims or ``x.size(1) != num_features``.
        """
        if x.dim() < 2 or x.size(1) != self.num_features:
            raise ValueError(
                f"expected input of shape (N, {self.num_features}, *), got {tuple(x.shape)}"
            )
        # Reduce over every dim except the channel dim; stat_shape broadcasts a
        # (C,) vector against x.
        reduce_dims = [0] + list(range(2, x.dim()))
        stat_shape = [1, self.num_features] + [1] * (x.dim() - 2)

        if self.training:
            mean = x.mean(dim=reduce_dims, keepdim=True)
            var = x.var(dim=reduce_dims, keepdim=True, unbiased=False)
            with torch.no_grad():
                # Update buffers in place from flattened (C,) batch stats so they
                # keep their shape and do not retain the autograd graph.
                n = x.numel() // self.num_features
                # Running variance uses the unbiased estimator (as nn.BatchNorm does).
                unbiased_var = var * n / (n - 1) if n > 1 else var
                self.running_mean.mul_(1 - self.momentum).add_(self.momentum * mean.reshape(-1))
                self.running_var.mul_(1 - self.momentum).add_(self.momentum * unbiased_var.reshape(-1))
        else:
            mean = self.running_mean.view(stat_shape)
            var = self.running_var.view(stat_shape)

        return (x - mean) / torch.sqrt(var + self.eps)

In [6]:
# Fix the RNG seed so the random input (and therefore the output) is reproducible.
torch.manual_seed(0)

# Create a random input tensor of shape (batch=64, channels=624, length=1).
input_tensor = torch.randn(64, 624, 1)

# Create a CustomBatchNorm instance with one set of statistics per channel.
norm_layer = CustomBatchNorm(num_features=624)

# Pass the input tensor through norm_layer (a new nn.Module is in training mode).
output_tensor = norm_layer(input_tensor)

# Normalization must preserve the input shape.
assert output_tensor.shape == input_tensor.shape

# Print the shape of the output tensor.
print(output_tensor.shape)

torch.Size([64, 624, 1])
