在 PyTorch 中，hook 是一种用于在神经网络的前向传播或反向传播过程中插入自定义操作的机制。hook 可以用于调试、可视化、梯度裁剪等任务。PyTorch 提供了三种类型的 hook：

- Forward Hook：在前向传播过程中执行。

In [24]:
import torch
import torch.nn as nn

# 定义一个简单的神经网络
class SimpleNet(nn.Module):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.fc1 = nn.Linear(10, 5)
        self.fc2 = nn.Linear(5, 1)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# 创建网络实例
net = SimpleNet()

# 定义一个 forward hook 函数
def forward_hook(module, input, output):
    print(f"Inside {module.__class__.__name__} forward hook")
    print(f"Input: {input}")
    print(f"Output: {output}")

# 定义一个 forward hook 函数
def forward_hook2(module, input, output):
    print(f"Inside {module.__class__.__name__} forward hook")
    print(f"Input: {input}")
    print(f"Output: {output}")
    modified_input = (input[0] * 2,)
    print(f"Modified Input: {modified_input}")

    modified_output = output * 100000
    print(f"Modified Output: {modified_output}")
    return modified_output
    

# 注册 forward hook
hook_handle = net.fc1.register_forward_hook(forward_hook2)
# hook_handle2 = net.fc2.register_forward_hook(forward_hook2)

# 创建一个随机输入
x = torch.randn(1, 10)

# 前向传播
output = net(x)
print(output)

# 移除 hook
# hook_handle.remove()
hook_handle.remove()

Inside Linear forward hook
Input: (tensor([[ 1.2408, -0.7267,  0.2505, -1.8252, -1.7672, -1.7213,  1.7390, -0.4912,
          1.1297,  1.5345]]),)
Output: tensor([[ 1.2112, -0.6202, -0.4395, -0.1143, -0.4221]],
       grad_fn=<AddmmBackward0>)
Modified Input: (tensor([[ 2.4816, -1.4534,  0.5010, -3.6504, -3.5343, -3.4426,  3.4780, -0.9825,
          2.2594,  3.0689]]),)
Modified Output: tensor([[121118.4531, -62017.8086, -43953.1016, -11430.0859, -42211.1094]],
       grad_fn=<MulBackward0>)
tensor([[19626.5566]], grad_fn=<AddmmBackward0>)


- Backward Hook：在反向传播过程中执行

In [16]:
import torch
import torch.nn as nn

# 定义一个简单的神经网络
class SimpleNet(nn.Module):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.fc1 = nn.Linear(10, 5)
        self.fc2 = nn.Linear(5, 1)

    def forward(self, x):
        x = self.fc1(x)
        x = self.fc2(x)
        return x

# 创建网络实例
net = SimpleNet()

# 定义一个 backward hook 函数
def backward_hook(module, grad_input, grad_output):
    print(f"Inside {module.__class__.__name__} backward hook")
    print(f"Grad Input: {grad_input}")
    print(f"Grad Output: {grad_output}")

# 注册 backward hook
hook_handle = net.fc1.register_backward_hook(backward_hook)

# 创建一个随机输入
x = torch.randn(1, 10, requires_grad=True)

# 前向传播
output = net(x)

# 计算损失
loss = output.sum()

# 反向传播
loss.backward()

# 移除 hook
hook_handle.remove()

Inside Linear backward hook
Grad Input: (tensor([-0.0967, -0.3673,  0.0775,  0.1022, -0.0757]), tensor([[-0.0853,  0.0116, -0.0570, -0.0550, -0.0887, -0.0518,  0.0237,  0.0228,
          0.0409,  0.0147]]), tensor([[-0.0214, -0.0812,  0.0171,  0.0226, -0.0167],
        [ 0.0058,  0.0221, -0.0047, -0.0062,  0.0046],
        [ 0.1036,  0.3937, -0.0831, -0.1095,  0.0812],
        [ 0.0635,  0.2411, -0.0509, -0.0671,  0.0497],
        [ 0.0156,  0.0591, -0.0125, -0.0164,  0.0122],
        [ 0.0289,  0.1098, -0.0232, -0.0305,  0.0226],
        [-0.0808, -0.3071,  0.0648,  0.0854, -0.0633],
        [ 0.0338,  0.1285, -0.0271, -0.0357,  0.0265],
        [-0.0303, -0.1153,  0.0243,  0.0321, -0.0238],
        [ 0.1349,  0.5124, -0.1081, -0.1425,  0.1057]]))
Grad Output: (tensor([[-0.0967, -0.3673,  0.0775,  0.1022, -0.0757]]),)


  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)


- Pre-Forward Hook：在前向传播之前执行。

In [1]:
import torch
import torch.nn as nn

# 定义一个简单的神经网络
class SimpleNet(nn.Module):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.fc1 = nn.Linear(10, 5, bias=False)
        self.fc2 = nn.Linear(5, 1, bias=False)

    def forward(self, x):
        x = self.fc1(x)
        x = self.fc2(x)
        return x

# 创建网络实例
net = SimpleNet()

# 定义一个 pre-forward hook 函数
def pre_forward_hook(module, input):
    print(f"Inside {module.__class__.__name__} pre-forward hook")
    input = (input[0] * 0,)
    print(f"Input: {input}")
    return input

# 注册 pre-forward hook
hook_handle = net.fc1.register_forward_pre_hook(pre_forward_hook)

# 创建一个随机输入
x = torch.randn(1, 10)

# 前向传播
output = net(x)
print(f"Output: {output}")

# 移除 hook
hook_handle.remove()

Inside Linear pre-forward hook
Input: (tensor([[-0., 0., -0., -0., 0., -0., 0., 0., -0., 0.]]),)
Output: tensor([[0.]], grad_fn=<MmBackward0>)


In [None]:

# 定义一个 post-backward hook 函数
def post_backward_hook(module, grad_input, grad_output):
    print(f"\nInside {module.__class__.__name__} post-backward hook")
    print(f"Original grad_input shapes: {[g.shape if g is not None else None for g in grad_input]}")
    print(f"Original grad_output shapes: {[g.shape if g is not None else None for g in grad_output]}")
    
    # 修改梯度示例：将第一个梯度输入乘以0.5
    modified_grad_input = list(grad_input)
    if modified_grad_input[0] is not None:
        modified_grad_input[0] = modified_grad_input[0] * 0.01
    
    return tuple(modified_grad_input)

# 注册 post-backward hook
hook_handle = net.fc1.register_full_backward_hook(post_backward_hook)

# 创建一个随机输入和标签
x = torch.randn(1, 10, requires_grad=True)
target = torch.randn(1, 1)

# 前向传播
output = net(x)
print(f"Output: {output}")

# 定义损失函数
criterion = nn.MSELoss()
loss = criterion(output, target)

# 反向传播
loss.backward()

# 打印梯度
print("\nGradients after backward:")
for name, param in net.named_parameters():
    if param.grad is not None:
        print(f"{name}: {param.grad.shape}")

# 移除 hook
hook_handle.remove()

Output: tensor([[-0.1126]], grad_fn=<MmBackward0>)

Inside Linear post-backward hook
Original grad_input: [torch.Size([1, 10])]
Original grad_output: [torch.Size([1, 5])]

Gradients after backward:
fc1.weight: torch.Size([5, 10])
fc2.weight: torch.Size([1, 5])
