In [2]:
import torch
from torch import nn
from utils.train_utils import train_classify
from utils.gpu_mem_maneger import GPUMemoryManager
from utils.fashion_mnist import load_data_fashion_mnist

# Fix the global torch RNG so weight initialisation (and, presumably, the data
# loader's shuffling, if it uses torch's default generator) is reproducible.
SEED = 0
torch.manual_seed(SEED)

# Prefer the GPU when one is available; models are moved with `.to(device)`.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [3]:
# Mini-batch size. Batch norm normalizes with per-batch statistics, so the batch
# must be large enough for those estimates to be meaningful.
BATCH_SIZE = 64
train_loader, test_loader = load_data_fashion_mnist(BATCH_SIZE)

## 批量规范化

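在每次训练迭代中，我们首先规范化输入：减去其均值并除以其标准差，两者均基于当前小批量计算。接下来，我们应用可学习的拉伸参数（scale）$\boldsymbol{\gamma}$ 和偏移参数（shift）$\boldsymbol{\beta}$：

$$\mathrm{BN}(\mathbf{x}) = \boldsymbol{\gamma} \odot \frac{\mathbf{x} - \hat{\boldsymbol{\mu}}_\mathcal{B}}{\sqrt{\hat{\boldsymbol{\sigma}}^2_\mathcal{B} + \epsilon}} + \boldsymbol{\beta}$$

其中 $\hat{\boldsymbol{\mu}}_\mathcal{B}$ 和 $\hat{\boldsymbol{\sigma}}^2_\mathcal{B}$ 是小批量 $\mathcal{B}$ 的样本均值和方差，$\epsilon > 0$ 是防止除以零的小常数。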

请注意，如果我们尝试使用大小为1的小批量应用批量规范化，我们将无法学到任何东西。这是因为在减去均值之后，每个隐藏单元将为0。所以，只有使用足够大的小批量，批量规范化这种方法才是有效且稳定的。请注意，在应用批量规范化时，批量大小的选择可能比没有批量规范化时更重要。

* 在模型训练过程中，批量规范化利用小批量的均值和标准差，不断调整神经网络的中间输出，使整个神经网络各层的中间输出值更加稳定。

* 目前关于批量规范化为何有效，理论解释仍较模糊。


### 批量规范化层

* 全连接层：在仿射变换之后、激活函数之前应用批量规范化，即 $\mathbf{h} = \phi(\mathrm{BN}(\mathbf{W}\mathbf{x} + \mathbf{b}))$

* 卷积层：在卷积之后、激活函数之前应用；对每个输出通道分别规范化，该通道的均值和方差在所有样本的所有空间位置上计算（即对批量、高、宽三个维度求均值）

#### 关于训练和推理

批量规范化在训练模式和预测模式下的行为通常不同。

首先，将训练好的模型用于预测时，我们不再需要（也不希望）由小批量估计样本均值和方差所带来的噪声。

其次，我们可能需要用模型对单个样本逐一进行预测，此时无法得到有意义的小批量统计量。一种常用的方法是通过移动平均估算整个训练数据集的样本均值和方差，并在预测时使用它们，从而得到确定的输出。

In [4]:
# Step-by-step (from-scratch) implementation of batch normalization.
def batch_norm(X, gamma, beta, moving_mean, moving_var, eps, momentum, training=None):
    """Batch-normalize ``X``; return the output and the running statistics.

    Args:
        X: input activations. In training mode it must be rank 2 (fully
            connected layer) or rank 4 (convolutional layer). Statistics are
            taken over dim 0 for rank 2 and over dims (0, 2, 3) for rank 4.
        gamma: learnable scale, broadcast against the normalized ``X``.
        beta: learnable shift, broadcast against the normalized ``X``.
        moving_mean: running mean, used in prediction mode.
        moving_var: running variance, used in prediction mode.
        eps: small constant added to the variance before the square root.
        momentum: weight given to the *old* running statistic, i.e.
            ``new = momentum * old + (1 - momentum) * batch_stat``.
            NOTE: ``nn.BatchNorm*``'s ``momentum`` uses the opposite
            convention (it weights the new batch statistic).
        training: if ``True``, normalize with the current batch's statistics
            and update the running ones. If ``False``, normalize with the
            running statistics and return them unchanged. If ``None``
            (default), keep the original behaviour of inferring the mode from
            ``torch.is_grad_enabled()``.

    Returns:
        ``(Y, moving_mean, moving_var)``: the normalized, scaled and shifted
        output, and the (possibly updated) running statistics, detached from
        the autograd graph.
    """
    if training is None:
        # Heuristic: gradient tracking on => training; off (e.g. inside
        # torch.no_grad()) => prediction.
        training = torch.is_grad_enabled()

    if not training:
        # Prediction mode: normalize with the running mean/variance passed in.
        X_hat = (X - moving_mean) / torch.sqrt(moving_var + eps)
    else:
        # Training mode: rank 2 for fully connected layers, rank 4 for conv layers.
        assert X.dim() in (2, 4), f'batch_norm expects a 2-D or 4-D input, got {X.dim()}-D'

        if X.dim() == 2:
            # Per-feature statistics over the batch dimension.
            mean = X.mean(dim=0)
            var = ((X - mean) ** 2).mean(dim=0)
        else:
            # Per-channel statistics over batch, height and width. keepdim keeps
            # them broadcastable against (N, C, H, W).
            mean = X.mean(dim=(0, 2, 3), keepdim=True)
            var = ((X - mean) ** 2).mean(dim=(0, 2, 3), keepdim=True)

        # Normalize with the current batch statistics (biased variance).
        X_hat = (X - mean) / torch.sqrt(var + eps)

        # The running statistics are state, not something gradients flow
        # through. Update them outside autograd so no graph is built for them.
        with torch.no_grad():
            moving_mean = momentum * moving_mean + (1.0 - momentum) * mean
            moving_var = momentum * moving_var + (1.0 - momentum) * var

    Y = gamma * X_hat + beta  # scale and shift
    return Y, moving_mean.detach(), moving_var.detach()


In [5]:
# Wrap batch_norm in an nn.Module so it composes with other PyTorch layers.
class BatchNorm(nn.Module):
    """From-scratch batch-normalization layer built on ``batch_norm``.

    Args:
        num_features: number of outputs of a fully connected layer, or number
            of output channels of a convolutional layer.
        num_dims: 2 for a fully connected layer, 4 for a convolutional layer.
        eps: constant added to the variance for numerical stability
            (default 1e-5, the value previously hard-coded).
        momentum: weight of the old running statistic in the moving-average
            update (default 0.9, the value previously hard-coded).

    In ``eval()`` mode the layer normalizes with its running statistics and
    leaves them untouched, like ``nn.BatchNorm1d``/``nn.BatchNorm2d``.
    """

    def __init__(self, num_features, num_dims, eps=1e-5, momentum=0.9):
        super().__init__()
        shape = (1, num_features) if num_dims == 2 else (1, num_features, 1, 1)
        # Learnable scale and shift, initialised to 1 and 0.
        self.gamma = nn.Parameter(torch.ones(shape))
        self.beta = nn.Parameter(torch.zeros(shape))
        # Running statistics are state, not trainable parameters. Registering
        # them as buffers puts them in state_dict() and lets module.to(...)
        # move/convert them together with the parameters.
        self.register_buffer('moving_mean', torch.zeros(shape))
        self.register_buffer('moving_var', torch.ones(shape))
        self.eps = eps
        self.momentum = momentum

    def forward(self, X):
        # Safety net: keep the running statistics on X's device even if the
        # module itself was not moved with .to(device).
        if self.moving_mean.device != X.device:
            self.moving_mean = self.moving_mean.to(X.device)
            self.moving_var = self.moving_var.to(X.device)

        if not self.training:
            # Honour module.eval(): normalize with the running statistics and do
            # not update them (batch_norm alone only looks at grad mode).
            X_hat = (X - self.moving_mean) / torch.sqrt(self.moving_var + self.eps)
            return self.gamma * X_hat + self.beta

        # Training mode: store the running statistics batch_norm returns
        # (assignment to a registered buffer name updates the buffer).
        Y, self.moving_mean, self.moving_var = batch_norm(
            X, self.gamma, self.beta, self.moving_mean,
            self.moving_var, eps=self.eps, momentum=self.momentum
        )
        return Y


In [6]:
# LeNet with the from-scratch BatchNorm layer inserted after every conv / linear
# layer and before its sigmoid activation.
net = nn.Sequential(
    # Convolution stage 1
    nn.Conv2d(1, 6, kernel_size=5),
    BatchNorm(6, num_dims=4),
    nn.Sigmoid(),
    nn.AvgPool2d(kernel_size=2, stride=2),
    # Convolution stage 2
    nn.Conv2d(6, 16, kernel_size=5),
    BatchNorm(16, num_dims=4),
    nn.Sigmoid(),
    nn.AvgPool2d(kernel_size=2, stride=2),
    nn.Flatten(),
    # Fully connected head: 16 channels x 4 x 4 spatial positions (for 28x28 inputs)
    nn.Linear(16 * 4 * 4, 120),
    BatchNorm(120, num_dims=2),
    nn.Sigmoid(),
    nn.Linear(120, 84),
    BatchNorm(84, num_dims=2),
    nn.Sigmoid(),
    nn.Linear(84, 10),
)

In [7]:
# Train the LeNet that uses the from-scratch BatchNorm layer.
# Move the model to the target device *before* constructing the optimizer, as
# the torch.optim docs recommend, so the optimizer is built over the parameters
# that will actually be trained.
net = net.to(device)
criterion = nn.CrossEntropyLoss()
# NOTE(review): AdamW's weight decay also applies to the BN gamma/beta parameters.
optimizer = torch.optim.AdamW(net.parameters(), lr=0.001, weight_decay=0.01)
train_classify(net, train_loader, test_loader, optimizer, criterion, num_epochs=10)

Epoch 1: 100%|██████████| 938/938 [00:11<00:00, 84.05it/s, accuracy=80.1, loss=0.71]  

Epoch: 1, loss: 0.7096558342070214, acc: 80.07





Epoch: 1, test loss: 0.008006558230519295, test acc: 82.3


Epoch 2: 100%|██████████| 938/938 [00:08<00:00, 110.78it/s, accuracy=85.5, loss=0.414]


Epoch: 2, loss: 0.41396947328978256, acc: 85.52166666666666
Epoch: 2, test loss: 0.008126352483034133, test acc: 81.6


Epoch 3: 100%|██████████| 938/938 [00:07<00:00, 119.78it/s, accuracy=87.2, loss=0.356]

Epoch: 3, loss: 0.35643584684713053, acc: 87.21666666666667





Epoch: 3, test loss: 0.0059641113802790646, test acc: 86.64


Epoch 4: 100%|██████████| 938/938 [00:08<00:00, 108.63it/s, accuracy=88.2, loss=0.328]

Epoch: 4, loss: 0.3280871064185715, acc: 88.155





Epoch: 4, test loss: 0.006109306104481221, test acc: 86.09


Epoch 5: 100%|██████████| 938/938 [00:08<00:00, 115.69it/s, accuracy=88.7, loss=0.31] 

Epoch: 5, loss: 0.3099025943433679, acc: 88.66166666666666





Epoch: 5, test loss: 0.0068259724974632265, test acc: 84.63000000000001


Epoch 6: 100%|██████████| 938/938 [00:08<00:00, 110.74it/s, accuracy=89.4, loss=0.293]


Epoch: 6, loss: 0.293259441113866, acc: 89.36666666666666
Epoch: 6, test loss: 0.0056387895554304126, test acc: 86.72999999999999


Epoch 7: 100%|██████████| 938/938 [00:08<00:00, 106.50it/s, accuracy=89.6, loss=0.283]


Epoch: 7, loss: 0.2826578642990289, acc: 89.625
Epoch: 7, test loss: 0.005060989627242088, test acc: 88.29


Epoch 8: 100%|██████████| 938/938 [00:08<00:00, 113.53it/s, accuracy=90.1, loss=0.268]

Epoch: 8, loss: 0.26813640620217905, acc: 90.085





Epoch: 8, test loss: 0.005220098730176688, test acc: 87.79


Epoch 9: 100%|██████████| 938/938 [00:08<00:00, 108.26it/s, accuracy=90.5, loss=0.259]

Epoch: 9, loss: 0.25900316415533325, acc: 90.46666666666667





Epoch: 9, test loss: 0.00508926648274064, test acc: 88.28


Epoch 10: 100%|██████████| 938/938 [00:09<00:00, 103.73it/s, accuracy=90.7, loss=0.253]

Epoch: 10, loss: 0.2525282779386811, acc: 90.7





Epoch: 10, test loss: 0.005327409267425537, test acc: 87.15


In [8]:
# The same BN-LeNet built from PyTorch's built-in layers: nn.BatchNorm2d after
# the convolutions and nn.BatchNorm1d after the linear layers.
concise_net = nn.Sequential(
    # Convolution stage 1
    nn.Conv2d(1, 6, kernel_size=5),
    nn.BatchNorm2d(6),
    nn.Sigmoid(),
    nn.AvgPool2d(kernel_size=2, stride=2),
    # Convolution stage 2
    nn.Conv2d(6, 16, kernel_size=5),
    nn.BatchNorm2d(16),
    nn.Sigmoid(),
    nn.AvgPool2d(kernel_size=2, stride=2),
    nn.Flatten(),
    # Fully connected head: 16 channels x 4 x 4 spatial positions = 256 inputs (for 28x28 inputs)
    nn.Linear(16 * 4 * 4, 120),
    nn.BatchNorm1d(120),
    nn.Sigmoid(),
    nn.Linear(120, 84),
    nn.BatchNorm1d(84),
    nn.Sigmoid(),
    nn.Linear(84, 10),
)

In [9]:
# Train the built-in-BatchNorm LeNet with the same loss, optimizer settings and
# epoch count as the from-scratch BatchNorm run.
# Move the model to the target device *before* constructing the optimizer, as
# the torch.optim docs recommend.
concise_net = concise_net.to(device)
criterion = nn.CrossEntropyLoss()
# NOTE(review): AdamW's weight decay also applies to the BN weight/bias parameters.
optimizer = torch.optim.AdamW(concise_net.parameters(), lr=0.001, weight_decay=0.01)
train_classify(concise_net, train_loader, test_loader, optimizer, criterion, num_epochs=10)

Epoch 1: 100%|██████████| 938/938 [00:05<00:00, 167.27it/s, accuracy=80.3, loss=0.714]

Epoch: 1, loss: 0.7142524112230425, acc: 80.33





Epoch: 1, test loss: 0.0075736663520336155, test acc: 82.99


Epoch 2: 100%|██████████| 938/938 [00:06<00:00, 148.60it/s, accuracy=85.5, loss=0.414]

Epoch: 2, loss: 0.4138561901030764, acc: 85.49833333333333





Epoch: 2, test loss: 0.007367620442807674, test acc: 82.89999999999999


Epoch 3: 100%|██████████| 938/938 [00:05<00:00, 159.99it/s, accuracy=87, loss=0.363]  

Epoch: 3, loss: 0.3629363060061103, acc: 86.96833333333333





Epoch: 3, test loss: 0.006311663088202477, test acc: 85.53


Epoch 4: 100%|██████████| 938/938 [00:06<00:00, 155.53it/s, accuracy=88, loss=0.335]  

Epoch: 4, loss: 0.3349387014169556, acc: 87.98333333333333





Epoch: 4, test loss: 0.005541987027227879, test acc: 87.15


Epoch 5: 100%|██████████| 938/938 [00:06<00:00, 151.02it/s, accuracy=88.5, loss=0.314]

Epoch: 5, loss: 0.3139982062663986, acc: 88.525





Epoch: 5, test loss: 0.005664476057887077, test acc: 86.88


Epoch 6: 100%|██████████| 938/938 [00:06<00:00, 149.38it/s, accuracy=89.1, loss=0.298]

Epoch: 6, loss: 0.29805169513484814, acc: 89.10166666666667





Epoch: 6, test loss: 0.005693399979174137, test acc: 86.8


Epoch 7: 100%|██████████| 938/938 [00:06<00:00, 152.21it/s, accuracy=89.4, loss=0.286]


Epoch: 7, loss: 0.2861076091874891, acc: 89.39166666666667
Epoch: 7, test loss: 0.004979297122359276, test acc: 88.19


Epoch 8: 100%|██████████| 938/938 [00:06<00:00, 148.34it/s, accuracy=89.8, loss=0.275]

Epoch: 8, loss: 0.27498836208508215, acc: 89.84





Epoch: 8, test loss: 0.005116794270277023, test acc: 88.17


Epoch 9: 100%|██████████| 938/938 [00:06<00:00, 148.71it/s, accuracy=90.2, loss=0.265]

Epoch: 9, loss: 0.26491802209603, acc: 90.16666666666667





Epoch: 9, test loss: 0.005244582363218069, test acc: 87.83999999999999


Epoch 10: 100%|██████████| 938/938 [00:05<00:00, 163.15it/s, accuracy=90.5, loss=0.255]

Epoch: 10, loss: 0.25522983417327977, acc: 90.47833333333334





Epoch: 10, test loss: 0.00473815803155303, test acc: 89.27000000000001
