In [1]:
import d2lzh as d2l
from torch import nn
from torch.nn import init
from torch import optim
from torch.nn import functional as F
import torch

def batch_norm(X, gamma, beta, moving_mean, moving_var, eps, momentum, train=True):
    """Batch-normalize ``X`` and return the updated running statistics.

    Args:
        X: Input tensor. In training mode it must be 2-D or 4-D; statistics are
            taken over dim 0 for 2-D input and over dims (0, 2, 3) for 4-D input.
        gamma: Scale applied to the normalized input (broadcast against it).
        beta: Shift added after scaling (broadcast against the input).
        moving_mean: Running mean, used for normalization when ``train`` is False.
        moving_var: Running variance, used for normalization when ``train`` is False.
        eps: Constant added to the variance before the square root.
        momentum: Weight given to the previous running value; the current batch
            statistic gets weight ``1 - momentum``.
        train: If True, normalize with the batch statistics and update the
            running statistics; otherwise normalize with the running statistics.

    Returns:
        Tuple ``(Y, moving_mean, moving_var)``. In training mode the returned
        running statistics carry no autograd history.

    Raises:
        AssertionError: In training mode, if ``X`` is neither 2-D nor 4-D.
    """
    # Keep the running statistics on the same device as the input so a model
    # moved to another device does not fail on mixed-device arithmetic.
    # ``.to`` returns the same tensor when the device already matches.
    moving_mean = moving_mean.to(X.device)
    moving_var = moving_var.to(X.device)
    if not train:
        X_hat = (X - moving_mean) / torch.sqrt(moving_var + eps)
    else:
        assert X.dim() in (2, 4)
        if X.dim() == 2:
            mean = X.mean(dim=0)
            var = ((X - mean) ** 2).mean(dim=0)
        else:
            mean = X.mean(dim=(0, 2, 3), keepdim=True)
            var = ((X - mean) ** 2).mean(dim=(0, 2, 3), keepdim=True)
        X_hat = (X - mean) / torch.sqrt(var + eps)
        # The running statistics are bookkeeping, not part of the loss. Updating
        # them outside autograd stops them from chaining every past iteration's
        # graph together through ``mean`` and ``var``.
        with torch.no_grad():
            moving_mean = momentum * moving_mean + (1.0 - momentum) * mean
            moving_var = momentum * moving_var + (1.0 - momentum) * var
    Y = gamma * X_hat + beta
    return Y, moving_mean, moving_var

In [2]:
class BatchNorm(nn.Module):
    """Batch normalization layer backed by the ``batch_norm`` function.

    Args:
        num_features: Number of features (2-D input) or channels (4-D input).
        num_dims: 2 selects parameters of shape ``(1, num_features)``; any other
            value selects shape ``(1, num_features, 1, 1)``.
    """

    def __init__(self, num_features, num_dims):
        super(BatchNorm, self).__init__()
        if num_dims == 2:
            shape = (1, num_features)
        else:
            shape = (1, num_features, 1, 1)
        # Learnable scale and shift.
        self.gamma = nn.Parameter(torch.ones(shape))
        self.beta = nn.Parameter(torch.zeros(shape))
        # Running statistics are registered as buffers so they follow the module
        # across ``.to(device)`` calls and are included in ``state_dict``, while
        # staying out of ``parameters()`` and therefore out of the optimizer.
        self.register_buffer('moving_mean', torch.zeros(shape))
        self.register_buffer('moving_var', torch.zeros(shape))

    def forward(self, X, train=None):
        """Normalize ``X`` and refresh the running statistics.

        Args:
            X: Input tensor.
            train: Explicit mode override. When left as None the module's own
                ``self.training`` flag is used, so ``net.train()`` and
                ``net.eval()`` take effect even when a container such as
                ``nn.Sequential`` calls ``forward(X)`` with no extra arguments.

        Returns:
            The normalized, scaled and shifted tensor.
        """
        if train is None:
            train = self.training
        Y, moving_mean, moving_var = batch_norm(
            X, self.gamma, self.beta, self.moving_mean, self.moving_var,
            eps=1e-5, momentum=0.9, train=train)
        # Store the statistics without autograd history so the buffers never
        # keep a computation graph alive between iterations.
        self.moving_mean = moving_mean.detach()
        self.moving_var = moving_var.detach()
        return Y

In [3]:
# LeNet-style network using the hand-written BatchNorm layer after every
# convolution and every hidden linear layer. Layers are grouped by stage and
# unpacked in order, so positional indices (e.g. net[1]) are unchanged.
conv_stage = [
    nn.Conv2d(1, 6, kernel_size=5),
    BatchNorm(6, num_dims=4),
    nn.Sigmoid(),
    nn.MaxPool2d(2),
    nn.Conv2d(6, 16, kernel_size=5),
    BatchNorm(16, num_dims=4),
    nn.Sigmoid(),
    nn.MaxPool2d(2),
]
dense_stage = [
    d2l.FlattenLayer(),
    nn.Linear(256, 120),
    BatchNorm(120, num_dims=2),
    nn.Sigmoid(),
    nn.Linear(120, 84),
    BatchNorm(84, num_dims=2),
    nn.Sigmoid(),
    nn.Linear(84, 10),
]
net = nn.Sequential(*conv_stage, *dense_stage)

In [4]:
# Hyperparameters; later cells reuse these names.
lr = 1.0
num_epochs = 5
batch_size = 256

# Use the GPU when one is available, otherwise the CPU.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'

# Xavier-initialize every parameter whose name ends in 'weight'.
for name, param in net.named_parameters():
    if not name.endswith('weight'):
        continue
    init.xavier_normal_(param)

optimizer = optim.SGD(net.parameters(), lr=lr)
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)
# NOTE(review): train_ch5 is assumed to move the model to `device` itself - confirm in d2l.
d2l.train_ch5(
    net, train_iter, test_iter, batch_size, optimizer, device, num_epochs
)

training on cpu
epoch 1, loss 0.0025, train acc 0.777, test acc 0.831, time 17.8 sec
epoch 2, loss 0.0015, train acc 0.861, test acc 0.856, time 18.1 sec
epoch 3, loss 0.0013, train acc 0.879, test acc 0.833, time 18.5 sec
epoch 4, loss 0.0012, train acc 0.887, test acc 0.879, time 18.6 sec
epoch 5, loss 0.0011, train acc 0.893, test acc 0.880, time 18.5 sec


In [5]:
# Learned scale (gamma) and shift (beta) of the first batch-norm layer, each flattened to 1-D.
tuple(p.data.reshape(-1) for p in (net[1].gamma, net[1].beta))

(tensor([1.3439, 1.7160, 1.9706, 1.6053, 1.8013, 1.1581]),
 tensor([-1.8126,  0.0580, -0.0977,  1.2077, -1.8987,  0.3900]))

In [6]:
# The same LeNet-style network, now using PyTorch's built-in batch-norm layers.
# Layers are grouped by stage and unpacked in order, so indices are unchanged.
conv_stage = [
    nn.Conv2d(1, 6, kernel_size=5),
    nn.BatchNorm2d(6),
    nn.Sigmoid(),
    nn.MaxPool2d(2),
    nn.Conv2d(6, 16, kernel_size=5),
    nn.BatchNorm2d(16),
    nn.Sigmoid(),
    nn.MaxPool2d(2),
]
dense_stage = [
    d2l.FlattenLayer(),
    nn.Linear(256, 120),
    nn.BatchNorm1d(120),
    nn.Sigmoid(),
    nn.Linear(120, 84),
    nn.BatchNorm1d(84),
    nn.Sigmoid(),
    nn.Linear(84, 10),
]
net = nn.Sequential(*conv_stage, *dense_stage)

In [7]:
# Xavier-initialize only multi-dimensional 'weight' parameters; 1-D weights
# are skipped.
for name, param in net.named_parameters():
    if param.dim() <= 1 or not name.endswith('weight'):
        continue
    init.xavier_normal_(param)

# Reuses lr, batch_size, device and num_epochs from the earlier training cell.
optimizer = optim.SGD(net.parameters(), lr=lr)
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)
d2l.train_ch5(
    net, train_iter, test_iter, batch_size, optimizer, device, num_epochs
)

training on cpu
epoch 1, loss 0.0027, train acc 0.759, test acc 0.792, time 18.3 sec
epoch 2, loss 0.0015, train acc 0.858, test acc 0.855, time 18.7 sec
epoch 3, loss 0.0013, train acc 0.877, test acc 0.848, time 18.6 sec
epoch 4, loss 0.0012, train acc 0.886, test acc 0.837, time 18.6 sec
epoch 5, loss 0.0012, train acc 0.890, test acc 0.882, time 18.5 sec
