## Batch NomalizationとLayer Normalization

### Hook

In [42]:
import torch
from torch import nn, optim
from torch.nn import functional as F
from torch.utils.data import DataLoader
from functools import partial
import torchvision
from torchvision import transforms

import utils

#### .register_hook

In [43]:
a = torch.ones(5, requires_grad=True)
b = 2 * a
b.retain_grad()
b.register_hook(lambda grad: print(grad))
c = b.mean()
c.backward()

tensor([0.2000, 0.2000, 0.2000, 0.2000, 0.2000])


#### .register_forward_hook()

In [44]:
conv_model = nn.Sequential(
    # 1x28x28 -> 4x14x14
    nn.Conv2d(1, 4, kernel_size=3, stride=2, padding=1),
    nn.ReLU(),
    # 4x14x14 -> 8x7x7
    nn.Conv2d(4, 8, kernel_size=3, stride=2, padding=1),
    nn.ReLU(),
    # 8x7x7 -> 16x3x3
    nn.Conv2d(8, 16, kernel_size=3, stride=2, padding=1),
    nn.ReLU(),
    # 16x4x4 -> 32x1x1
    nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1),
    nn.ReLU(),
    #nn.AdaptiveAvgPool2d(1),   # <- GAP
    # 32
    nn.Flatten(),
    # 10
    nn.Linear(32, 10)
)

In [45]:
outputs = {}
def save_output(name, module, inp, out):
    module_name = f'{name}_{str(module)}'
    outputs[module_name] = out.shape
for name, module in conv_model.named_modules():
    if name:
        module.register_forward_hook(partial(save_output, name))

In [46]:
def print_hooks(model):
    for name, module in model.named_modules():
        if hasattr(module, '_forward_hooks'):
            for hook in module._forward_hooks.values():
                print(f'Module {name} has forward hook: {hook}')
        if hasattr(module, '_backward_hooks'):
            for hook in module._backward_hooks.values():
                print(f'Module {name} has backward hook: {hook}')

In [47]:
print_hooks(conv_model)

Module 0 has forward hook: functools.partial(<function save_output at 0x7fb8d312db80>, '0')
Module 1 has forward hook: functools.partial(<function save_output at 0x7fb8d312db80>, '1')
Module 2 has forward hook: functools.partial(<function save_output at 0x7fb8d312db80>, '2')
Module 3 has forward hook: functools.partial(<function save_output at 0x7fb8d312db80>, '3')
Module 4 has forward hook: functools.partial(<function save_output at 0x7fb8d312db80>, '4')
Module 5 has forward hook: functools.partial(<function save_output at 0x7fb8d312db80>, '5')
Module 6 has forward hook: functools.partial(<function save_output at 0x7fb8d312db80>, '6')
Module 7 has forward hook: functools.partial(<function save_output at 0x7fb8d312db80>, '7')
Module 8 has forward hook: functools.partial(<function save_output at 0x7fb8d312db80>, '8')
Module 9 has forward hook: functools.partial(<function save_output at 0x7fb8d312db80>, '9')


In [48]:
# forwardでhook発動
X = torch.randn((1, 1, 28, 28))
output = conv_model(X)

RuntimeError: mat1 and mat2 shapes cannot be multiplied (1x128 and 32x10)

#### .register_full_backward_hook

In [None]:
grads = {}
def save_grad_in(name, module, grad_in, grad_out):
    module_name = f'{name}_{str(module)}'
    grads[module_name] = grad_in

for name, module in conv_model.named_modules():
    if name:
        module.register_full_backward_hook(partial(save_grad_in, name))

In [None]:
print_hooks(conv_model) 

Module 0 has forward hook: functools.partial(<function save_output at 0x7fb8c1526dc0>, '0')
Module 0 has backward hook: functools.partial(<function save_grad_in at 0x7fb8a17fb280>, '0')
Module 1 has forward hook: functools.partial(<function save_output at 0x7fb8c1526dc0>, '1')
Module 1 has backward hook: functools.partial(<function save_grad_in at 0x7fb8a17fb280>, '1')
Module 2 has forward hook: functools.partial(<function save_output at 0x7fb8c1526dc0>, '2')
Module 2 has backward hook: functools.partial(<function save_grad_in at 0x7fb8a17fb280>, '2')
Module 3 has forward hook: functools.partial(<function save_output at 0x7fb8c1526dc0>, '3')
Module 3 has backward hook: functools.partial(<function save_grad_in at 0x7fb8a17fb280>, '3')
Module 4 has forward hook: functools.partial(<function save_output at 0x7fb8c1526dc0>, '4')
Module 4 has backward hook: functools.partial(<function save_grad_in at 0x7fb8a17fb280>, '4')
Module 5 has forward hook: functools.partial(<function save_output at 

In [None]:
# backward
X = torch.randn((1, 1, 28, 28))
output = conv_model(X)
loss = output.mean()
loss.backward()

RuntimeError: mat1 and mat2 shapes cannot be multiplied (1x128 and 32x10)

### Activationの可視化

In [49]:
conv_model = nn.Sequential(
    # 1x28x28 -> 4x14x14
    nn.Conv2d(1, 4, kernel_size=3, stride=2, padding=1),
    nn.ReLU(),
    # 4x14x14 -> 8x7x7
    nn.Conv2d(4, 8, kernel_size=3, stride=2, padding=1),
    nn.ReLU(),
    # 8x7x7 -> 16x3x3
    nn.Conv2d(8, 16, kernel_size=3, stride=2, padding=1),
    nn.ReLU(),
    # 16x4x4 -> 32x1x1
    nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1),
    nn.ReLU(),
    #nn.AdaptiveAvgPool2d(1),   # <- GAP
    # 32
    nn.Flatten(),
    # 10
    nn.Linear(128, 10)
)

In [50]:
# forward hook
act_means = [[] for module in conv_model if isinstance(module, nn.ReLU)] # RELUの数だけリストを作成
act_stds = [[] for module in conv_model if isinstance(module, nn.ReLU)]

def save_out_stats(module, inp, out):
    act_means.append(out.mean().item())
    act_stds.append(out.std().item())

relu_layers = [module for module in conv_model if isinstance(module, nn.ReLU)]
for i, relu in enumerate(relu_layers):
    relu.register_forward_hook(partial(save_out_stats, i))

In [51]:
# データ準備
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
    ])
train_dataset = torchvision.datasets.FashionMNIST('~/tmp/fashion_mnist', download=True, train=True, transform=transform)
val_dataset = torchvision.datasets.FashionMNIST('~/tmp/fashion_mnist', download=True, train=False, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=1024, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=1024, shuffle=False)

opt = optim.SGD(conv_model.parameters(), lr=0.6)

In [52]:
train_losses, val_losses, val_accuracies = utils.learn(conv_model, train_loader, val_loader, opt, F.cross_entropy, 3)

                                                

TypeError: save_out_stats() takes 3 positional arguments but 4 were given

## Batch normalization

### Batch norm スクラッチで実装

In [54]:
def batch_norm(X, gamma, beta, eps=1e-5):
    mean = X.mean(dim=(0, 2, 3), keepdim=True)
    var = X.var(dim=(0, 2, 3), keepdim=True)
    X_norm = (X - mean) / torch.sqrt(var + eps)
    return gamma * X_norm + beta

### nn.BatchNorm2d

In [57]:
X , y = next(iter(train_loader))
conv_out = nn.Conv2d(1, 8, kernel_size=3, stride=2, padding=1)(X)
norm_out = nn.BatchNorm2d(8)(conv_out)

In [58]:
norm_out.mean(), norm_out.var()

(tensor(7.9828e-10, grad_fn=<MeanBackward0>),
 tensor(1.0000, grad_fn=<VarBackward0>))

### Layer norm スクラッチ実装

In [59]:
def layer_norm(X, gamma, beta, eps=1e-5):
    mean = X.mean(dim=(1, 2, 3), keepdim=True)
    var = X.var(dim=(1, 2, 3), keepdim=True)
    X_norm = (X - mean) / torch.sqrt(var + eps)
    return gamma * X_norm + beta

### nn.LayerNorm

In [60]:
X , y = next(iter(train_loader))
conv_out = nn.Conv2d(1, 8, kernel_size=3, stride=2, padding=1)(X)
norm_out = nn.LayerNorm([8, 14, 14])(conv_out)