# 批量规范化

In [None]:
import torch
from torch import nn

def batch_norm(x, gamma, beta, moving_mean,
               moving_var, eps, moemntum):
    """Apply batch normalization to ``x`` and update the moving statistics.

    The mode is selected by ``torch.is_grad_enabled()``: when autograd is
    disabled the supplied moving statistics are used for normalization
    (inference); otherwise the statistics of the current batch are used
    and the moving averages are updated (training).

    Args:
        x: Input tensor. In training mode it must have 2 dimensions
            (statistics taken over dim 0) or 4 dimensions (statistics
            taken over dims 0, 2 and 3, kept broadcastable against ``x``).
        gamma: Scale applied to the normalized input.
        beta: Shift added after scaling.
        moving_mean: Moving average of the mean; must broadcast against
            ``x`` and against the batch mean.
        moving_var: Moving average of the variance; same requirement.
        eps: Constant added to the variance before the square root to
            avoid division by zero.
        moemntum: Weight given to the previous moving statistics in the
            update (the misspelled name is kept so that existing keyword
            callers keep working).

    Returns:
        A tuple ``(y, moving_mean, moving_var)``. In inference mode the
        moving statistics are returned unchanged. In training mode they
        are new tensors that are not attached to the autograd graph.

    Raises:
        AssertionError: In training mode, if ``x`` is neither 2-D nor 4-D.
    """
    # is_grad_enabled tells us whether we are in training or inference mode.
    if not torch.is_grad_enabled():
        # Inference: normalize with the moving statistics passed in.
        x_hat = (x - moving_mean) / torch.sqrt(moving_var + eps)
    else:
        assert len(x.shape) in (2, 4)
        if len(x.shape) == 2:
            # 2-D input: mean and (biased) variance over dim 0.
            mean = x.mean(dim=0)
            var = ((x - mean) ** 2).mean(dim=0)
        else:
            # 4-D input: statistics over dims 0, 2 and 3. keepdim=True keeps
            # the reduced dims so the result broadcasts against x below.
            mean = x.mean(dim=(0, 2, 3), keepdim=True)
            var = ((x - mean) ** 2).mean(dim=(0, 2, 3), keepdim=True)
        # Training: normalize with the statistics of the current batch.
        x_hat = (x - mean) / torch.sqrt(var + eps)
        # Update the moving statistics outside of autograd. Without this the
        # returned tensors reference the graph of the current batch, and a
        # caller feeding them back in on the next step would chain every
        # iteration's graph together and keep it alive.
        with torch.no_grad():
            moving_mean = moemntum * moving_mean + (1.0 - moemntum) * mean
            moving_var = moemntum * moving_var + (1.0 - moemntum) * var
    y = gamma * x_hat + beta  # scale and shift
    return y, moving_mean, moving_var
            
        
