In [1]:
import time
import torch
from torch import nn
import torch.nn.functional as F

import sys
sys.path.append('..')
import d2lzh_pytorch as d2l
# Prefer the default CUDA device when one is visible; otherwise run on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')


In [2]:
def batch_norm(is_training, X, gamma, beta, moving_mean, moving_var, eps, momentum):
    """Batch-normalize ``X`` and return the updated running statistics.

    Args:
        is_training: If true, normalize with the statistics of the current
            batch and update the running statistics; otherwise normalize with
            ``moving_mean`` / ``moving_var`` and leave them untouched.
        X: Input tensor of rank 2 or rank 4. In training mode rank 2 is
            reduced over dim 0, rank 4 over dims 0, 2 and 3.
        gamma: Scale applied to the normalized input (broadcast against ``X``).
        beta: Shift added after scaling (broadcast against ``X``).
        moving_mean: Running mean used in inference mode.
        moving_var: Running variance used in inference mode.
        eps: Small constant added to the variance before the square root.
        momentum: Weight of the previous running statistic in the update
            ``new = momentum * old + (1 - momentum) * batch_stat``.

    Returns:
        Tuple ``(Y, moving_mean, moving_var)``: the normalized, scaled and
        shifted output plus the (possibly updated) running statistics.

    Raises:
        AssertionError: In training mode, if ``X`` is neither rank 2 nor rank 4.
    """
    if not is_training:
        # Inference: use the accumulated statistics, never the batch's own.
        X_hat = (X - moving_mean) / torch.sqrt(moving_var + eps)
    else:
        assert len(X.shape) in (2, 4)
        if len(X.shape) == 2:
            # One mean/variance per column, reduced over dim 0.
            mean = X.mean(dim=0)
            var = ((X - mean) ** 2).mean(dim=0)
        else:
            # One mean/variance per index of dim 1; keepdim keeps the result
            # broadcastable against X as (1, C, 1, 1).
            mean = X.mean(dim=(0, 2, 3), keepdim=True)
            var = ((X - mean) ** 2).mean(dim=(0, 2, 3), keepdim=True)
        X_hat = (X - mean) / torch.sqrt(var + eps)

        # The running statistics are bookkeeping, not part of the loss. Updating
        # them outside autograd keeps the returned tensors free of grad_fn, so
        # they do not chain the graphs of successive iterations together.
        with torch.no_grad():
            moving_mean = momentum * moving_mean + (1 - momentum) * mean
            moving_var = momentum * moving_var + (1 - momentum) * var
    Y = gamma * X_hat + beta
    return Y, moving_mean, moving_var

In [3]:
class BatchNorm(nn.Module):
    """Batch normalization layer with learnable scale and shift.

    Args:
        num_features: Number of features (rank-2 input) or channels
            (rank-4 input) that get their own statistics and parameters.
        num_dims: Rank of the expected input. ``2`` creates parameters of
            shape ``(1, num_features)``; any other value creates
            ``(1, num_features, 1, 1)``.

    Attributes:
        gamma: Learnable scale, initialised to ones.
        beta: Learnable shift, initialised to zeros, so a freshly built layer
            performs plain normalization.
        moving_mean: Running mean buffer, initialised to zeros.
        moving_var: Running variance buffer, initialised to ones, so running
            the layer in eval mode before any training step does not divide by
            ``sqrt(eps)``.
    """

    def __init__(self, num_features, num_dims):
        super(BatchNorm, self).__init__()
        if num_dims == 2:
            shape = (1, num_features)
        else:
            shape = (1, num_features, 1, 1)
        # Learnable affine parameters: identity transform at initialisation.
        self.gamma = nn.Parameter(torch.ones(shape))
        self.beta = nn.Parameter(torch.zeros(shape))

        # Running statistics are state, not parameters. Registering them as
        # buffers makes them follow .to()/.cuda() and appear in state_dict().
        self.register_buffer('moving_mean', torch.zeros(shape))
        self.register_buffer('moving_var', torch.ones(shape))

    def forward(self, X):
        """Normalize ``X`` and, in training mode, refresh the running statistics."""
        # Defensive: keep the statistics on the input's device even if the
        # module itself was never moved.
        if self.moving_mean.device != X.device:
            self.moving_mean = self.moving_mean.to(X.device)
            self.moving_var = self.moving_var.to(X.device)

        Y, moving_mean, moving_var = batch_norm(self.training, X,
                                                self.gamma, self.beta, self.moving_mean,
                                                self.moving_var, eps=1e-5, momentum=0.9)
        # Store detached copies so the buffers never hold on to an autograd graph.
        self.moving_mean = moving_mean.detach()
        self.moving_var = moving_var.detach()
        return Y

In [4]:
# LeNet-style network using the hand-written BatchNorm layer after every
# convolution / hidden linear layer and before its sigmoid activation.
net = nn.Sequential(
    # Convolutional stage 1
    nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5),
    BatchNorm(num_features=6, num_dims=4),
    nn.Sigmoid(),
    nn.MaxPool2d(kernel_size=2, stride=2),
    # Convolutional stage 2
    nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5),
    BatchNorm(num_features=16, num_dims=4),
    nn.Sigmoid(),
    nn.MaxPool2d(kernel_size=2, stride=2),
    # Classifier head; 16 * 4 * 4 inputs presumably correspond to 28x28 images
    # (28 -> 24 -> 12 -> 8 -> 4) -- verify against the data loader.
    d2l.FlattenLayer(),
    nn.Linear(in_features=16 * 4 * 4, out_features=120),
    BatchNorm(num_features=120, num_dims=2),
    nn.Sigmoid(),
    nn.Linear(in_features=120, out_features=84),
    BatchNorm(num_features=84, num_dims=2),
    nn.Sigmoid(),
    nn.Linear(in_features=84, out_features=10),
)

In [5]:
# Hyperparameters for the run with the hand-written BatchNorm layers.
# NOTE(review): the later nn.BatchNorm run uses batch_size = 256, so the two
# sets of timings/accuracies are not directly comparable.
batch_size = 128
lr = 0.001
num_epochs = 5

train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size, root='../data')
optimizer = torch.optim.Adam(params=net.parameters(), lr=lr)
d2l.train_ch5(
    net, train_iter, test_iter, batch_size, optimizer, device, num_epochs
)

training on  cuda:1
epoch 1, loss 0.8857, train acc 0.800, test acc 0.715, time 6.8 sec
epoch 2, loss 0.2002, train acc 0.870, test acc 0.855, time 5.6 sec
epoch 3, loss 0.1114, train acc 0.884, test acc 0.825, time 5.6 sec
epoch 4, loss 0.0757, train acc 0.894, test acc 0.753, time 5.5 sec
epoch 5, loss 0.0567, train acc 0.900, test acc 0.836, time 5.5 sec


In [6]:
# Flattened scale and shift of the first BatchNorm layer (index 1 in `net`).
net[1].gamma.view(-1), net[1].beta.view(-1)

(tensor([0.8695, 0.7605, 0.7448, 1.1751, 1.1841, 1.1055], device='cuda:1',
        grad_fn=<ViewBackward>),
 tensor([0.4669, 0.3190, 0.7185, 1.0706, 1.2025, 0.6424], device='cuda:1',
        grad_fn=<ViewBackward>))

In [9]:
# Same LeNet-style architecture, now built from PyTorch's own batch
# normalization layers (BatchNorm2d after convolutions, BatchNorm1d after
# linear layers).
net = nn.Sequential(
    # Convolutional stage 1
    nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5),
    nn.BatchNorm2d(num_features=6),
    nn.Sigmoid(),
    nn.MaxPool2d(kernel_size=2, stride=2),
    # Convolutional stage 2
    nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5),
    nn.BatchNorm2d(num_features=16),
    nn.Sigmoid(),
    nn.MaxPool2d(kernel_size=2, stride=2),
    # Classifier head
    d2l.FlattenLayer(),
    nn.Linear(in_features=16 * 4 * 4, out_features=120),
    nn.BatchNorm1d(num_features=120),
    nn.Sigmoid(),
    nn.Linear(in_features=120, out_features=84),
    nn.BatchNorm1d(num_features=84),
    nn.Sigmoid(),
    nn.Linear(in_features=84, out_features=10),
)

In [10]:
# Hyperparameters for the run with PyTorch's built-in batch normalization.
batch_size = 256
lr = 0.001
num_epochs = 5

train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size, root='../data')
optimizer = torch.optim.Adam(params=net.parameters(), lr=lr)
d2l.train_ch5(
    net, train_iter, test_iter, batch_size, optimizer, device, num_epochs
)

training on  cuda:1
epoch 1, loss 1.0034, train acc 0.781, test acc 0.817, time 2.3 sec
epoch 2, loss 0.2314, train acc 0.862, test acc 0.812, time 2.2 sec
epoch 3, loss 0.1230, train acc 0.879, test acc 0.860, time 2.2 sec
epoch 4, loss 0.0833, train acc 0.885, test acc 0.854, time 2.2 sec
epoch 5, loss 0.0623, train acc 0.891, test acc 0.862, time 2.3 sec
