In [1]:
import torch
import torch.nn as nn

## BatchNormalization

In [None]:
class BatchNorm_1d_2d(nn.Module):
    """Hand-written batch normalization for 2-D and 4-D inputs.

    The layer normalizes each channel (axis 1 of the input) and then applies
    a learnable per-channel scale ``gamma`` and shift ``beta``.

    While autograd is enabled the statistics of the current mini-batch are
    used and folded into running estimates; while autograd is disabled
    (e.g. inside ``torch.no_grad()``) the running estimates are used instead.

    Args:
        num_channels: Number of channels (size of axis 1 of the input).
        num_dims: Rank of the expected input. ``2`` creates parameters of
            shape ``(1, num_channels)``; any other value creates parameters
            of shape ``(1, num_channels, 1, 1)`` for 4-D inputs.
    """

    def __init__(self, num_channels, num_dims):
        super().__init__()

        # One value per channel, shaped so it broadcasts against the input.
        if num_dims == 2:
            shape = (1, num_channels)
        else:
            shape = (1, num_channels, 1, 1)

        self.gamma = nn.Parameter(torch.ones(shape))
        self.beta = nn.Parameter(torch.zeros(shape))

        # Registered as buffers (not plain attributes) so the running
        # statistics follow ``module.to(...)`` and are saved in ``state_dict()``.
        self.register_buffer("moving_mean", torch.zeros(shape))
        self.register_buffer("moving_var", torch.ones(shape))

    def batch_norm(self, gamma, beta, x, moving_mean, moving_var, eps, momentum):
        """Normalize ``x`` and return the output with the updated statistics.

        Args:
            gamma: Per-channel scale, broadcastable against ``x``.
            beta: Per-channel shift, broadcastable against ``x``.
            x: Input tensor; 2-D inputs are reduced over axis 0, all other
                inputs over axes ``(0, 2, 3)``.
            moving_mean: Running mean used when autograd is disabled.
            moving_var: Running variance used when autograd is disabled.
            eps: Small constant added to the variance before the square root.
            momentum: Weight kept on the previous running estimate; the
                current batch statistic gets weight ``1 - momentum``.

        Returns:
            A tuple ``(y, moving_mean, moving_var)``. The statistics are
            returned unchanged when autograd is disabled.
        """
        if not torch.is_grad_enabled():
            # Inference path: normalize with the accumulated running statistics.
            x_hat = (x - moving_mean) / torch.sqrt(moving_var + eps)
        else:
            # Reduce over every axis except the channel axis. ``keepdim=True``
            # keeps the statistics in the same shape as the buffers.
            reduce_dims = (0,) if len(x.shape) == 2 else (0, 2, 3)
            cur_mean = torch.mean(x, dim=reduce_dims, keepdim=True)
            cur_var = torch.mean((x - cur_mean) ** 2, dim=reduce_dims, keepdim=True)

            x_hat = (x - cur_mean) / torch.sqrt(cur_var + eps)

            # Exponential moving average: lets the mean and variance of
            # earlier mini-batches contribute to the global estimate.
            # The batch statistics are detached so the running estimates do
            # not keep the autograd graph of every training step alive.
            moving_mean = momentum * moving_mean + (1 - momentum) * cur_mean.detach()
            moving_var = momentum * moving_var + (1 - momentum) * cur_var.detach()

        y = gamma * x_hat + beta

        return y, moving_mean, moving_var

    def forward(self, x):
        """Apply batch normalization to ``x`` and refresh the running statistics."""
        output, self.moving_mean, self.moving_var = self.batch_norm(
            self.gamma,
            self.beta,
            x,
            self.moving_mean,
            self.moving_var,
            eps=1e-5,
            momentum=0.9,
        )

        return output


# Smoke test: run the layer once on a 4-D batch (3 channels) and once on a
# 2-D batch (64 channels), printing the output shape each time.
for sample_shape, channels in (((4, 3, 64, 64), 3), ((4, 64), 64)):
    x_test = torch.randn(*sample_shape)
    batch_norm = BatchNorm_1d_2d(channels, len(sample_shape))
    print(batch_norm(x_test).shape)