### 1. pytorch hooks  -> x.retain_grad

In [15]:
import torch

x = torch.Tensor([0,1,2,3]).requires_grad_()
y = torch.Tensor([0,1,2,3]).requires_grad_()
w = torch.Tensor([0,1,2,3]).requires_grad_()

z = x + y
z.retain_grad()  # if we save grad
output = w.matmul(z)
output.retain_grad() #if we save grad
output.backward()

print(z)
print(output)
print('x.requires_grad:', x.requires_grad) # True
print('y.requires_grad:', y.requires_grad) # True
print('z.requires_grad:', z.requires_grad) # True
print('w.requires_grad:', w.requires_grad) # True
print('o.requires_grad:', output.requires_grad) # True

# their grad is not avaiable
print('x.grad:', x.grad) # tensor([1., 2., 3., 4.])
print('y.grad:', y.grad) # tensor([1., 2., 3., 4.])
print('w.grad:', w.grad) # tensor([ 4.,  6.,  8., 10.])
print('z.grad:', z.grad) # None
print('o.grad:', output.grad) # None

tensor([0., 2., 4., 6.], grad_fn=<AddBackward0>)
tensor(28., grad_fn=<DotBackward>)
x.requires_grad: True
y.requires_grad: True
z.requires_grad: True
w.requires_grad: True
o.requires_grad: True
x.grad: tensor([0., 1., 2., 3.])
y.grad: tensor([0., 1., 2., 3.])
w.grad: tensor([0., 2., 4., 6.])
z.grad: tensor([0., 1., 2., 3.])
o.grad: tensor(1.)


### 2. hook_fn(grad) -> Tensor or None


In [26]:
import torch

x = torch.Tensor([0, 1, 2, 3]).requires_grad_()
y = torch.Tensor([4, 5, 6, 7]).requires_grad_()
w = torch.Tensor([1, 2, 3, 4]).requires_grad_()
z = x+y

# ===================
def hook_fn(grad):
    print(grad)

def hook_fn_update(grad):
    grad = 2 * grad
    print(grad)
    return grad

z.register_hook(hook_fn_update)
# ===================

o = w.matmul(z)

print('=====Start backprop=====')
o.backward()
print('=====End backprop=====')

print('x.grad:', x.grad)
print('y.grad:', y.grad)
print('w.grad:', w.grad)
print('z.grad:', z.grad)

=====Start backprop=====
tensor([2., 4., 6., 8.])
=====End backprop=====
x.grad: tensor([2., 4., 6., 8.])
y.grad: tensor([2., 4., 6., 8.])
w.grad: tensor([ 4.,  6.,  8., 10.])
z.grad: None




### 3. Hook for Modules

网络模块 module 不像上一节中的 Tensor，拥有显式的变量名可以直接访问，而是被封装在神经网络中间。我们通常只能获得网络整体的输入和输出，对于夹在网络中间的模块，我们不但很难得知它输入/输出的梯度,为了解决这个麻烦，PyTorch 设计了两种 hook.

+ ``register_forward_hook``: **获取/修改**前向传播过程中，各个网络模块的输入和输出，**不返回任何值**
+ ``register_backward_hook``： 获取神经网络反向传播过程中，各个模块***输入端***和***输出端的梯度值***

In [49]:
import torch
from torch import nn

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.fc1 = nn.Linear(3, 4)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(4, 1)
        self.initialize()
    
    def initialize(self):
        with torch.no_grad():
            self.fc1.weight = torch.nn.Parameter(
                torch.Tensor([[1., 2., 3.],
                              [-4., -5., -6.],
                              [7., 8., 9.],
                              [-10., -11., -12.]]))

            self.fc1.bias = torch.nn.Parameter(torch.Tensor([1.0, 2.0, 3.0, 4.0]))
            self.fc2.weight = torch.nn.Parameter(torch.Tensor([[1.0, 2.0, 3.0, 4.0]]))
            self.fc2.bias = torch.nn.Parameter(torch.Tensor([1.0]))
    
    def forward(self, x):
        o = self.fc1(x)
        o = self.relu1(o)
        o = self.fc2(o)
        return o

# 全局变量，存储中间 features
total_feat_out = []
total_feat_in = []

def hook_fn_forward(module, input, output):
    print(module) # 用于区分模块
    print('input', input) # 首先打印出来
    print('output', output)
    total_feat_out.append(output) # 然后分别存入全局 list 中
    total_feat_in.append(input)

    
total_grad_out = []
total_grad_in = []
def hook_fn_backward(module, grad_input, grad_output):
    print(module)
    print('grad input ', grad_input)
    print('grad ouput ', grad_output)
    total_grad_in.append(grad_input)
    total_grad_out.append(grad_output)

model = Model()
for name, module in model.named_children():
    module.register_forward_hook(hook_fn_forward)
    module.register_backward_hook(hook_fn_backward)

x = torch.Tensor([[1.0, 1.0, 1.0]]).requires_grad_() 
o = model(x)
o.backward()

print('\n\n==========Saved inputs and outputs==========')
for idx in range(len(total_feat_in)):
    print('input: ', total_feat_in[idx])
    print('output: ', total_feat_out[idx])
    
print('\n\n==========Saved backward grad==========')
for idx in range(len(total_feat_in)):
    print('input: ', total_grad_in[idx])
    print('output: ', total_grad_out[idx])

Linear(in_features=3, out_features=4, bias=True)
input (tensor([[1., 1., 1.]], requires_grad=True),)
output tensor([[  7., -13.,  27., -29.]], grad_fn=<AddmmBackward>)
ReLU()
input (tensor([[  7., -13.,  27., -29.]], grad_fn=<AddmmBackward>),)
output tensor([[ 7.,  0., 27.,  0.]], grad_fn=<ReluBackward0>)
Linear(in_features=4, out_features=1, bias=True)
input (tensor([[ 7.,  0., 27.,  0.]], grad_fn=<ReluBackward0>),)
output tensor([[89.]], grad_fn=<AddmmBackward>)
Linear(in_features=4, out_features=1, bias=True)
grad input  (tensor([1.]), tensor([[1., 2., 3., 4.]]), tensor([[ 7.],
        [ 0.],
        [27.],
        [ 0.]]))
grad ouput  (tensor([[1.]]),)
ReLU()
grad input  (tensor([[1., 0., 3., 0.]]),)
grad ouput  (tensor([[1., 2., 3., 4.]]),)
Linear(in_features=3, out_features=4, bias=True)
grad input  (tensor([1., 0., 3., 0.]), tensor([[22., 26., 30.]]), tensor([[1., 0., 3., 0.],
        [1., 0., 3., 0.],
        [1., 0., 3., 0.]]))
grad ouput  (tensor([[1., 0., 3., 0.]]),)


input