# Mastering PyTorch Hooks

Hooks are callbacks that PyTorch runs while a module executes its forward or backward pass. They let you inspect or modify activations and gradients without editing the model's code, which makes them useful for tasks such as:
- **Feature Visualization** (e.g., Grad-CAM)
- **Debugging** (checking gradients for NaNs)
- **Modifying Gradients** (clipping, scaling)
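
As a minimal sketch of the debugging use case, the block below attaches a full backward hook to every `nn.Linear` layer of a small, made-up `nn.Sequential` model and records which layers receive non-finite output gradients. The helper `make_nan_checker` is hypothetical, and the NaN is injected by hand to simulate a numerical problem.

```python
import torch
import torch.nn as nn

def make_nan_checker(name, found):
    """Hypothetical helper: return a full backward hook that records `name`
    if any gradient w.r.t. the module's outputs contains NaN or Inf."""
    def hook(module, grad_input, grad_output):
        for g in grad_output:
            if g is not None and not torch.isfinite(g).all():
                found.append(name)
                break
    return hook

model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))
bad_layers = []
handles = [
    m.register_full_backward_hook(make_nan_checker(name, bad_layers))
    for name, m in model.named_modules()
    if isinstance(m, nn.Linear)
]

x = torch.randn(2, 4)
out = model(x)
# Inject a NaN into the upstream gradient to simulate a numerical problem.
out.backward(torch.tensor([[float("nan")], [1.0]]))
print(bad_layers)  # names of Linear layers that saw a non-finite gradient

# Detach the hooks once the check is done.
for h in handles:
    h.remove()
```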

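This notebook covers the two main types of module hooks:
1. **Forward Hooks**: Run each time the module's `forward()` produces an output.
2. **Backward Hooks**: Run during the backward pass, when the gradients with respect to the module are computed.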
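
To see when each kind fires, this sketch registers both hooks on a single `nn.Linear` layer (an arbitrary choice) and logs the order in which they run.

```python
import torch
import torch.nn as nn

events = []  # records hook calls in the order they fire

layer = nn.Linear(3, 2)

def forward_hook(module, input, output):
    events.append("forward")

def backward_hook(module, grad_input, grad_output):
    events.append("backward")

h_fwd = layer.register_forward_hook(forward_hook)
# Full backward hooks must be registered before the forward pass they should observe.
h_bwd = layer.register_full_backward_hook(backward_hook)

x = torch.randn(4, 3, requires_grad=True)
y = layer(x)          # forward hook fires here
y.sum().backward()    # backward hook fires here
print(events)  # → ['forward', 'backward']

h_fwd.remove()
h_bwd.remove()
```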

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

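# Let's define a simple CNN for 1x28x28 inputs (e.g., MNIST-sized images)
class SimpleNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=3)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=3)
        # Spatial size: 28 -> conv1 -> 26 -> pool -> 13 -> conv2 -> 11 -> pool -> 5,
        # so the flattened features have 20 * 5 * 5 = 500 elements
        self.fc = nn.Linear(20 * 5 * 5, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2(x), 2))
        x = torch.flatten(x, 1)  # keep the batch dimension, flatten the rest
        x = self.fc(x)
        return x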

model = SimpleNet()

## 1. Forward Hooks

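A forward hook has the signature `hook(module, input, output)` and runs right after the module's `forward()` returns. `input` is a tuple of the positional arguments passed to the module and `output` is whatever `forward()` returned. If the hook returns a value, that value replaces the module's output.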
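
A sketch of a forward hook that changes what a layer returns, assuming you want to clamp a layer's output to be non-negative (an arbitrary illustrative choice). The hook also records how many positional inputs it received.

```python
import torch
import torch.nn as nn

layer = nn.Linear(4, 4)
seen = {}

def clamp_output(module, input, output):
    # `input` is a tuple of the positional arguments passed to the module.
    seen["n_inputs"] = len(input)
    # Returning a tensor replaces the module's output for everything downstream.
    return output.clamp(min=0)

handle = layer.register_forward_hook(clamp_output)
y = layer(torch.randn(8, 4))
print(seen["n_inputs"], bool((y >= 0).all()))  # → 1 True
handle.remove()
```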

In [None]:
# Define a forward hook factory: each hook stores a layer's output under `name`
activation = {}
def get_activation(name):
    def hook(module, input, output):
        # detach() so the stored tensor does not keep the autograd graph alive
        activation[name] = output.detach()
    return hook

# Register the hook on the second (last) conv layer.
# Grad-CAM uses the feature maps of the last conv layer, captured this way.
handle_fwd = model.conv2.register_forward_hook(get_activation('conv2'))

# Run a forward pass; the hook fires inside model(...)
dummy_input = torch.randn(1, 1, 28, 28)
output = model(dummy_input)

print(f"Captured activation shape: {activation['conv2'].shape}")
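
The same idea scales to several layers. A sketch, assuming a small hypothetical `nn.Sequential` model, that hooks every `Conv2d` and `Linear` found by `named_modules()`. Binding `name` as a default argument plays the same role as the `get_activation` factory above: it avoids Python's late-binding closures, which would otherwise make every hook write to the last name in the loop.

```python
import torch
import torch.nn as nn

# Hypothetical model: one conv layer, then a linear classifier.
model = nn.Sequential(
    nn.Conv2d(1, 4, kernel_size=3),   # 28x28 -> 26x26
    nn.ReLU(),
    nn.Flatten(),
    nn.Linear(4 * 26 * 26, 10),
)

activations = {}
handles = []
for name, module in model.named_modules():
    if isinstance(module, (nn.Conv2d, nn.Linear)):
        def hook(mod, inp, out, name=name):  # bind `name` now, not at call time
            activations[name] = out.detach()
        handles.append(module.register_forward_hook(hook))

model(torch.randn(2, 1, 28, 28))
for name, act in activations.items():
    print(name, tuple(act.shape))

for h in handles:
    h.remove()
```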

## 2. Backward Hooks

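A full backward hook, registered with `register_full_backward_hook`, has the signature `hook(module, grad_input, grad_output)`. `grad_input` and `grad_output` are tuples holding the gradients with respect to the module's positional inputs and its outputs (entries can be `None`). The hook must not modify its arguments in place, but it can return a new tuple to be used in place of `grad_input`.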
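
A sketch of the gradient-clipping use case, assuming a bound of ±0.1 (arbitrary). Because a full backward hook only sees gradients for the module's positional inputs, the returned tuple changes what flows back to `x`; the layer's own parameter gradients are computed from `grad_output` and are not affected by the return value.

```python
import torch
import torch.nn as nn

layer = nn.Linear(3, 3)

def clip_grad_input(module, grad_input, grad_output):
    # One entry per positional input (None if that input does not need a gradient).
    # Returning a tuple of the same length replaces the gradient passed further back.
    return tuple(g.clamp(-0.1, 0.1) if g is not None else None for g in grad_input)

handle = layer.register_full_backward_hook(clip_grad_input)
x = torch.randn(5, 3, requires_grad=True)
(layer(x) * 100).sum().backward()   # large upstream gradients
print(x.grad.abs().max().item())    # at most 0.1 after clipping
handle.remove()
```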

In [None]:
gradients = {}
def get_gradients(name):
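    def hook(module, grad_input, grad_output):
        # grad_output[0] is the gradient of the loss w.r.t. this layer's output
        gradients[name] = grad_output[0].detach()
        print(f"Gradients for {name} captured!")
    return hook

# Register the backward hook.
# register_full_backward_hook supersedes the deprecated register_backward_hook.
handle_bwd = model.conv2.register_full_backward_hook(get_gradients('conv2'))

# Full backward hooks are wired into the graph during the forward pass, so the
# `output` computed before registration would never trigger this hook.
# Run the forward pass again now that the hook is registered.
output = model(dummy_input)

# Backward pass
loss = output.sum()
loss.backward()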

print(f"Captured gradient shape: {gradients['conv2'].shape}")
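
With a forward hook and a backward hook on the same layer you have the two ingredients Grad-CAM needs. The sketch below uses a small hypothetical CNN rather than `SimpleNet` and follows the usual Grad-CAM recipe: average each channel's gradients over its spatial positions, use those averages to weight the feature maps, sum over channels, and apply ReLU. Here `padding=1` keeps the map at the input resolution; with a smaller map you would typically upsample it (for example with `F.interpolate`) before overlaying it on the image.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Hypothetical small CNN; padding=1 keeps the 28x28 spatial size through both convs.
features = nn.Sequential(
    nn.Conv2d(1, 8, kernel_size=3, padding=1), nn.ReLU(),
    nn.Conv2d(8, 16, kernel_size=3, padding=1), nn.ReLU(),
)
head = nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(16, 10))
model = nn.Sequential(features, head)
target_layer = features[2]  # last conv layer

store = {}
def save_activation(module, input, output):
    store["act"] = output.detach()

def save_gradient(module, grad_input, grad_output):
    store["grad"] = grad_output[0].detach()

h_fwd = target_layer.register_forward_hook(save_activation)
h_bwd = target_layer.register_full_backward_hook(save_gradient)  # before the forward pass

x = torch.randn(1, 1, 28, 28)
logits = model(x)
class_idx = logits.argmax(dim=1).item()
model.zero_grad()
logits[0, class_idx].backward()  # gradient of the chosen class score only

# Grad-CAM: weight each channel by its spatially averaged gradient,
# sum the weighted feature maps, and keep only positive evidence.
weights = store["grad"].mean(dim=(2, 3), keepdim=True)           # (1, 16, 1, 1)
cam = F.relu((weights * store["act"]).sum(dim=1, keepdim=True))  # (1, 1, 28, 28)
cam = cam / (cam.max() + 1e-8)                                   # scale to [0, 1]
print(cam.shape)  # → torch.Size([1, 1, 28, 28])

h_fwd.remove()
h_bwd.remove()
```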

## 3. Removing Hooks

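Hooks stay attached to a module until you remove them. Registering a hook again (for example, by re-running a notebook cell) adds a second copy that also fires on every pass, and hooks that store tensors, like the `activation` and `gradients` dictionaries above, keep those tensors in memory. Call `remove()` on each handle once you are done.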
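
A sketch of both points: first, registering the same function twice makes it fire twice per pass; second, a hypothetical `capture_outputs` helper built on `contextlib.contextmanager` removes its hook when the `with` block ends, even if the body raises.

```python
import contextlib

import torch
import torch.nn as nn

layer = nn.Linear(2, 2)

# 1) Hooks accumulate: each registration gets its own handle and fires separately.
calls = []
def count_calls(module, input, output):
    calls.append(module)

h1 = layer.register_forward_hook(count_calls)
h2 = layer.register_forward_hook(count_calls)  # e.g. a re-run notebook cell
layer(torch.randn(1, 2))
print(len(calls))  # → 2
h1.remove()
h2.remove()

# 2) A context manager guarantees removal.
@contextlib.contextmanager
def capture_outputs(module, store):
    """Hypothetical helper: record `module`'s outputs only inside a `with` block."""
    def hook(mod, input, output):
        store.append(output.detach())
    handle = module.register_forward_hook(hook)
    try:
        yield store
    finally:
        handle.remove()  # runs even if the body raises

outputs = []
with capture_outputs(layer, outputs):
    layer(torch.randn(1, 2))  # recorded
layer(torch.randn(1, 2))      # hook already removed: not recorded
print(len(outputs))  # → 1
```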

In [None]:
# Remove the hooks using the handles we stored earlier
handle_fwd.remove()
handle_bwd.remove()
print("Hooks removed!")