In [1]:
import torch
import torch.nn as nn

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

class MyMul(nn.Module):
    """Custom module that multiplies its input by 2."""

    def forward(self, input):
        """Return ``input * 2``."""
        return input * 2

class MyMean(nn.Module):
    """Custom division module: divides its input by 4."""

    def forward(self, input):
        """Return ``input / 4``."""
        # Kept as a true division so autograd records a division node.
        return input / 4

def tensor_hook(grad):
    """Gradient hook for a tensor: print the gradient and pass it through.

    Args:
        grad: the gradient handed to the hook.

    Returns:
        The same ``grad`` object, unmodified.
    """
    # Each tuple is one printed line; the output matches two separate print calls.
    for fields in (('tensor hook',), ('grad:', grad)):
        print(*fields)
    return grad

class MyNet(nn.Module):
    """Tiny demo network: a 4->1 linear layer followed by a divide-by-4 module.

    The linear layer is initialised deterministically (weight 8, bias 2) so
    that forward results and gradients are easy to verify by hand.
    """

    def __init__(self):
        super().__init__()
        self.f1 = nn.Linear(4, 1, bias=True)
        self.f2 = MyMean()
        self.weight_init()

    def forward(self, input):
        """Apply the linear layer first, then the division module.

        The input is also stored on ``self.input``; the commented-out
        experiment in ``my_hook`` reads it from there.
        """
        self.input = input
        output = self.f1(input)       # operation 1: linear layer
        output = self.f2(output)      # operation 2: division
        return output

    def weight_init(self):
        """Fill the linear layer's weight with 8.0 and its bias with 2.0."""
        # nn.init.constant_ fills in place under no_grad, avoiding the legacy
        # `.data` attribute that bypasses autograd's bookkeeping.
        nn.init.constant_(self.f1.weight, 8.0)
        nn.init.constant_(self.f1.bias, 2.0)

    def my_hook(self, module, grad_input, grad_output):
        """Module backward hook: print both gradient tuples, return grad_input as-is.

        Args:
            module: the module the hook fired on (unused here).
            grad_input: tuple passed by the hook machinery as the input gradients.
            grad_output: tuple passed by the hook machinery as the output gradients.

        Returns:
            ``grad_input`` unchanged.
        """
        print('doing my_hook')
        print('original grad:', grad_input)
        print('original outgrad:', grad_output)
        # The in-hook modification of grad_input is intentionally disabled.
        # If re-enabled, the returned value must be a tuple, hence the wrapping:
        # grad_input = grad_input[0]*self.input
        # grad_input = tuple([grad_input])
        # print('now grad:', grad_input)

        return grad_input


In [2]:
# Create the input directly on the target device so it stays a leaf tensor.
# Calling .to(device) on a requires_grad tensor returns a non-leaf copy when
# the device is CUDA, and input.grad would then be None after backward().
input = torch.tensor([1, 2, 3, 4], dtype=torch.float32, device=device, requires_grad=True)

net = MyNet()
net.to(device)

# Register both hooks before the forward pass `result = net(input)`; the
# original author notes that the hooks are bound while forward runs.
# NOTE(review): Module.register_backward_hook is deprecated in newer PyTorch in
# favour of register_full_backward_hook, which reports different grad_input
# values for multi-op modules. Kept as-is to preserve the demonstrated output.
net.register_backward_hook(net.my_hook)
input.register_hook(tensor_hook)
result = net(input)

print('result =', result)

result.backward()

print('input.grad:', input.grad)
for param in net.parameters():
    print('{}:grad->{}'.format(param, param.grad))

result = tensor([20.5000], grad_fn=<DivBackward0>)
doing my_hook
original grad: (tensor([0.2500]), None)
original outgrad: (tensor([1.]),)
tensor hook
grad: tensor([2., 2., 2., 2.])
input.grad: tensor([2., 2., 2., 2.])
Parameter containing:
tensor([[8., 8., 8., 8.]], requires_grad=True):grad->tensor([[0.2500, 0.5000, 0.7500, 1.0000]])
Parameter containing:
tensor([2.], requires_grad=True):grad->tensor([0.2500])


