Hooks in PyTorch are under-documented but provide powerful functionality during backpropagation. They can be compared to Doctor Fate, a lesser-known superhero.

One reason hooks are valuable is that they let you act during backpropagation. A hook can be registered on a Tensor or an `nn.Module`. It is a function that is executed when the forward or backward pass runs.

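Here, "forward" and "backward" refer to the two passes rather than to one particular method. A module's forward hook runs right after the module's `forward()` method has computed its output. Backward hooks run during `backward()`, as autograd executes the backward functions of the `torch.autograd.Function` objects recorded in each tensor's `grad_fn`.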

PyTorch provides two types of hooks:
- **Forward Hook**
- **Backward Hook**

A forward hook is executed during the forward pass, while the backward hook is executed when the backward function is called.
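To make the timing concrete, here is a small sketch (an illustration, not part of the original walkthrough) that records when each kind of hook fires on an `nn.Linear` layer. It uses `register_forward_hook` and `register_full_backward_hook`; the `events` list and the layer sizes are arbitrary choices for this example.

```python
import torch
import torch.nn as nn

events = []  # records the order in which the hooks fire

layer = nn.Linear(4, 2)

# Forward hook: runs right after layer.forward() has returned its output
layer.register_forward_hook(lambda module, inputs, output: events.append("forward"))

# Full backward hook: runs once the gradients w.r.t. the layer's inputs are computed
layer.register_full_backward_hook(lambda module, grad_input, grad_output: events.append("backward"))

x = torch.randn(3, 4, requires_grad=True)
out = layer(x)          # triggers the forward hook
print(events)           # ['forward']

out.sum().backward()    # triggers the backward hook
print(events)           # ['forward', 'backward']
```

Both lambdas return `None` (the return value of `list.append`), so neither hook alters the output or the gradients.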

#### Hooks for Tensors
A hook is essentially a function with a specific signature. For tensors, the signature for the backward hook is:

```python
hook(grad) -> Tensor or None
```

There is no forward hook for a tensor.

`grad` is the gradient with respect to this tensor computed during `backward()`, i.e. the value that would be accumulated into its `.grad` attribute. The hook should not modify its argument in place. It must either return `None` (leaving the gradient unchanged) or a Tensor that is used in place of `grad` for the rest of the gradient computation. Here's an example:

```python
import torch

def backward_hook(grad):
    # Modify the gradient or perform some action
    modified_grad = grad * 2
    return modified_grad

tensor = torch.tensor(1., requires_grad=True)
handle = tensor.register_hook(backward_hook)  # returns a torch.utils.hooks.RemovableHandle

out = tensor * 3
out.backward()
print(tensor.grad)  # tensor(6.): the hook doubled d(out)/d(tensor) = 3
```

In the first experiment, we use autograd to compute the gradients of the scalar `c` with respect to tensors `a` and `b`, and print them. In the second experiment, we repeat the computation, but this time we register a hook on tensor `b` that multiplies its gradient by 2. Here's the code:

```python
import torch

# Experiment 1
a = torch.ones(5)
a.requires_grad = True

b = 2*a
b.retain_grad()   # b is a non-leaf tensor, so its grad is discarded unless retained

c = b.mean()

c.backward()

print("Experiment 1:")
print("Gradient of a:", a.grad)
print("Gradient of b:", b.grad)

# Experiment 2
a = torch.ones(5)
a.requires_grad = True

b = 2*a
b.retain_grad()

# Register a hook that multiplies b's grad by 2
b.register_hook(lambda grad: grad * 2)

b.mean().backward()

print("Experiment 2:")
print("Gradient of a:", a.grad)
print("Gradient of b:", b.grad)
```

```text
Experiment 1:
Gradient of a: tensor([0.4000, 0.4000, 0.4000, 0.4000, 0.4000])
Gradient of b: tensor([0.2000, 0.2000, 0.2000, 0.2000, 0.2000])
Experiment 2:
Gradient of a: tensor([0.8000, 0.8000, 0.8000, 0.8000, 0.8000])
Gradient of b: tensor([0.2000, 0.2000, 0.2000, 0.2000, 0.2000])
```

In Experiment 2 the hook doubles the gradient flowing back through `b`, so `a.grad` doubles from 0.4 to 0.8. `b.grad` still shows 0.2 in this run because the `retain_grad()` hook stored the gradient before the doubling hook ran. The relative order of these two hooks has changed across PyTorch releases, so newer versions may report the doubled value (0.4) for `b.grad`.

Hooks provide versatile functionality during the backward pass. Here are some key uses:

1. **Debugging and Logging**: Hooks allow you to print or log gradient values for debugging purposes. This is particularly useful for non-leaf variables whose gradients may not be retained unless you explicitly call `retain_grad()` on them. Using hooks provides a cleaner way to aggregate and inspect these values.

2. **Gradient Modification**: Hooks enable you to modify gradients during the backward pass. This is crucial for various applications. For example, in the previous example where we multiplied b's gradient by 2, subsequent gradient calculations (such as those of a or any tensor that depends on b for gradient) now use the modified gradient (`2 * grad(b)`) instead of the original gradient. If you were to update parameters individually after the backward pass, you'd have to manually multiply each gradient by 2, which could be cumbersome and error-prone.
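As a sketch of the logging use case, the snippet below captures the gradient of a non-leaf tensor without calling `retain_grad()`. The `grad_log` dictionary and the `log_grad` helper are hypothetical names chosen for this illustration; the hook returns `None`, so the gradient itself is not changed.

```python
import torch

grad_log = {}  # name -> gradient captured during backward

def log_grad(name):
    # Returns a hook that stores a copy of the incoming gradient and returns None,
    # so the gradient flowing backward is left unchanged
    def hook(grad):
        grad_log[name] = grad.detach().clone()
    return hook

a = torch.ones(5, requires_grad=True)
b = 2 * a            # non-leaf: b.grad would normally not be retained
b.register_hook(log_grad("b"))

c = b.mean()
c.backward()

print(grad_log["b"])  # tensor([0.2000, 0.2000, 0.2000, 0.2000, 0.2000])
print(a.grad)         # tensor([0.4000, 0.4000, 0.4000, 0.4000, 0.4000])
```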

Using hooks streamlines the process of debugging, logging, and modifying gradients during backpropagation, enhancing the flexibility and efficiency of your PyTorch code.


#### Hooks for nn.Module
Hooks for `nn.Module` objects provide additional flexibility during the forward and backward passes of the neural network. However, their usage can sometimes lead to breaking the abstraction of `nn.Module`. Here's an explanation of the hook function signatures:

**Backward Hook:**

```python
hook(module, grad_input, grad_output) -> Tensor or None
```

This hook function is executed during the backward pass. `module` is the `nn.Module` object to which the hook is attached. `grad_input` is a tuple of gradients of the loss with respect to the module's inputs (e.g., `dL / dx`, `dL / dw`, `dL / db`), and `grad_output` is a tuple of gradients of the loss with respect to the module's outputs. With the legacy `register_backward_hook` used below, however, `grad_input` refers to the inputs of the last autograd operation executed inside the module's `forward`, so when `forward` performs several operations the meaning of these gradients can be ambiguous. Recent PyTorch versions deprecate `register_backward_hook` in favor of `register_full_backward_hook`, whose `grad_input` holds gradients with respect to the module's positional tensor inputs only.

**Forward Hook:**

```python
hook(module, input, output) -> None or modified output
```

This hook function is executed during the forward pass, right after the module's `forward()` has computed its output. As with the backward hook, `module` is the `nn.Module` object. `input` is a tuple of the positional arguments passed to `forward()`, and `output` is the value `forward()` returned. If the hook returns a value, that value replaces the module's output.
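A forward hook may also return a tensor, in which case that tensor replaces the module's output. The sketch below uses this to clamp a layer's output; `clamp_output` is a hypothetical hook written for this example, and the layer sizes and seed are arbitrary.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
layer = nn.Linear(4, 3)

def clamp_output(module, inputs, output):
    # `inputs` is a tuple of the positional arguments passed to forward();
    # returning a tensor replaces the module's output
    return output.clamp(min=0)

handle = layer.register_forward_hook(clamp_output)

x = torch.randn(2, 4)
hooked = layer(x)
handle.remove()          # detach the hook again
plain = layer(x)

print(torch.equal(hooked, plain.clamp(min=0)))  # True
```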

Using hooks on `nn.Module` objects can be powerful but may break the abstraction of `nn.Module`. A module is meant to represent a single modularized layer, yet its `forward` can perform an arbitrary number of autograd operations, and a legacy backward hook does not tell you which of them the reported gradients belong to. Therefore, careful consideration should be given before using hooks on `nn.Module` objects.

```python
import torch
import torch.nn as nn

class myNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 10, 2, stride=2)
        self.relu = nn.ReLU()
        self.flatten = lambda x: x.view(-1)
        self.fc1 = nn.Linear(160, 5)

    def forward(self, x):
        x = self.relu(self.conv(x))
        return self.fc1(self.flatten(x))

net = myNet()

def hook_fn(m, i, o):
    # m: the module, i: grad_input tuple, o: grad_output tuple
    print(m)
    print("------------Input Grad------------")
    for grad in i:
        try:
            print(grad.shape)
        except AttributeError:
            print("None found for Gradient")
    print("------------Output Grad------------")
    for grad in o:
        try:
            print(grad.shape)
        except AttributeError:
            print("None found for Gradient")
    print("\n")

# Note: register_backward_hook is deprecated in recent PyTorch releases
# in favor of register_full_backward_hook
net.conv.register_backward_hook(hook_fn)
net.fc1.register_backward_hook(hook_fn)

inp = torch.randn(1, 3, 8, 8)
out = net(inp)

(1 - out.mean()).backward()
```

```text
Linear(in_features=160, out_features=5, bias=True)
------------Input Grad------------
torch.Size([5])
torch.Size([5])
------------Output Grad------------
torch.Size([5])


Conv2d(3, 10, kernel_size=(2, 2), stride=(2, 2))
------------Input Grad------------
None found for Gradient
torch.Size([10, 3, 2, 2])
torch.Size([10])
------------Output Grad------------
torch.Size([1, 10, 4, 4])
```

In this code snippet, we define a simple neural network `myNet` consisting of a convolutional layer, a ReLU activation function, and a fully connected layer. We then define a hook function `hook_fn` that prints information about the input and output gradients of specific layers during the backward pass. Finally, we register the hook function to the convolutional and fully connected layers of the network and perform a forward pass followed by a backward pass.

Here's a breakdown of what the code does:

- We define a neural network class `myNet` that inherits from `nn.Module`.
- In the `__init__` method of `myNet`, we define the layers of the network: a convolutional layer, a ReLU activation function, and a fully connected layer.
- We define the `forward` method of `myNet`, which specifies the forward pass of the network.
- We define a hook function `hook_fn` that takes three arguments: `m` (the module to which the hook is attached), `i` (the `grad_input` tuple), and `o` (the `grad_output` tuple).
- Inside the hook function, we print the module and the shapes of its input and output gradients.
- We register the hook function to the convolutional and fully connected layers of the network using the `register_backward_hook` method.
- We create an input tensor `inp` with random values and perform a forward pass through the network.
- We compute the loss as `1 - out.mean()` and perform a backward pass to compute gradients.

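When the backward pass is performed, `hook_fn` is called for each layer it is registered on. Because backpropagation visits layers in reverse order, the `Linear` layer is reported before the `Conv2d` layer. For the convolution, the `None` entry is the gradient with respect to its input, which is not computed because `inp` does not require gradients; the other two entries match the shapes of the weight and bias. For the `Linear` layer, neither entry of `grad_input` has the shape of the layer's input (160) or weight (5 × 160): the legacy hook reports gradients for the inputs of the last operation inside the layer, which is the ambiguity described earlier.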

This allows us to inspect the gradients flowing through specific layers of the network during the backward pass, which can be helpful for debugging and understanding the behavior of the network during training.
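For comparison, here is a minimal sketch using `register_full_backward_hook`, the non-deprecated replacement. The standalone `nn.Linear(160, 5)` mirrors `fc1` from the example above, and the input is made to require gradients so that `grad_input` is populated; `full_hook` and `shapes` are names made up for this illustration.

```python
import torch
import torch.nn as nn

layer = nn.Linear(160, 5)
shapes = {}

def full_hook(module, grad_input, grad_output):
    # grad_input: gradients w.r.t. the module's positional tensor inputs (None where not required)
    # grad_output: gradients w.r.t. the module's outputs
    shapes["grad_input"] = [None if g is None else tuple(g.shape) for g in grad_input]
    shapes["grad_output"] = [tuple(g.shape) for g in grad_output]

layer.register_full_backward_hook(full_hook)

x = torch.randn(1, 160, requires_grad=True)
(1 - layer(x).mean()).backward()

print(shapes)  # {'grad_input': [(1, 160)], 'grad_output': [(1, 5)]}
```

Unlike the legacy hook, `grad_input` now contains exactly one entry, `dL/dx`; the weight and bias gradients are available through `layer.weight.grad` and `layer.bias.grad` instead.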

#### Proper Way of Using Hooks

```python
import torch
import torch.nn as nn

class myNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 10, 2, stride=2)
        self.relu = nn.ReLU()
        self.flatten = lambda x: x.view(-1)
        self.fc1 = nn.Linear(160, 5)

    def forward(self, x):
        x = self.relu(self.conv(x))

        # Clamp the gradient w.r.t. the ReLU output to be non-negative
        x.register_hook(lambda grad: torch.clamp(grad, min=0))

        # Print whether any negative gradients remain (runs after the clamp hook)
        x.register_hook(lambda grad: print("Gradients less than zero:", bool((grad < 0).any())))

        return self.fc1(self.flatten(x))

net = myNet()

# Register hooks on the biases of the linear layers to zero out their gradients
for name, param in net.named_parameters():
    if "fc" in name and "bias" in name:
        param.register_hook(lambda grad: torch.zeros_like(grad))

# Forward pass
out = net(torch.randn(1, 3, 8, 8))

# Backward pass
(1 - out).mean().backward()

# The bias gradients should be zero
print("The biases are", net.fc1.bias.grad)
```

```text
Gradients less than zero: False
The biases are tensor([0., 0., 0., 0., 0.])
```

- **Clamping Convolutional Layer Gradients**: The hook registered with `x.register_hook(lambda grad: torch.clamp(grad, min=0))` clamps the gradient of the loss with respect to `x` (the ReLU output) to be non-negative, so no negative gradient flows back into the ReLU and convolutional layers.

- **Checking Negative Gradients**: The hook registered with `x.register_hook(lambda grad: print("Gradients less than zero:", bool((grad < 0).any())))` prints whether any negative gradients remain. Hooks on a tensor run in the order they were registered, so this hook sees the already-clamped gradient, which is why the output reports `False`. This is helpful for monitoring and debugging.

- **Zeroing Bias Gradients**: The loop over named parameters applies hooks to the biases of linear layers (fc), ensuring that their gradients are zeroed out during the backward pass.
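Hooks like the bias-zeroing ones in the loop above stay attached for every later backward pass. Each `register_hook` call returns a `torch.utils.hooks.RemovableHandle`, so a hook can be detached once it is no longer needed. The tensor below is a toy stand-in chosen for this sketch.

```python
import torch

x = torch.ones(3, requires_grad=True)

# register_hook returns a RemovableHandle
handle = x.register_hook(lambda grad: grad * 0)

(x * 2).sum().backward()
print(x.grad)        # tensor([0., 0., 0.]): the hook zeroed the gradient

handle.remove()      # the hook no longer runs
x.grad = None        # clear the accumulated gradient
(x * 2).sum().backward()
print(x.grad)        # tensor([2., 2., 2.])
```

The same handle mechanism applies to `register_forward_hook` and `register_full_backward_hook` on modules.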

#### Visualizing Activations

Using forward hooks for visualizing activations is a common technique in deep learning. It allows you to capture intermediate feature maps during the forward pass for visualization and analysis. Here's how you can implement it:

```python
import torch
import torch.nn as nn

class myNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 10, 2, stride=2)
        self.relu = nn.ReLU()
        self.flatten = lambda x: x.view(-1)
        self.fc1 = nn.Linear(160, 5)

    def forward(self, x):
        x = self.relu(self.conv(x))
        return self.fc1(self.flatten(x))

# Dictionary to store intermediate feature maps, keyed by module
visualisation = {}

# Input tensor
inp = torch.randn(1, 3, 8, 8)

# Forward hook that saves each module's output
def hook_fn(m, i, o):
    visualisation[m] = o

# Create an instance of the network
net = myNet()

# Register a forward hook on each direct child module (conv, relu, fc1)
for name, layer in net.named_children():
    layer.register_forward_hook(hook_fn)

# Forward pass through the network
out = net(inp)

# The intermediate feature maps are now stored in the visualisation dictionary
```

Alternatively, you can modify the `forward` method of the `nn.Module` subclass to store intermediate outputs in a dictionary or list yourself, so they are available after the forward pass without any hooks. This approach records only what you explicitly store: in the version below, the output of the `nn.Sequential` block is captured as a whole, not the outputs of the individual layers inside it.

Here's the adjusted code incorporating the approach:

```python
import torch
import torch.nn as nn

class myNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 10, 2, stride=2)
        self.relu = nn.ReLU()
        self.flatten = lambda x: x.view(-1)
        self.fc1 = nn.Linear(160, 5)
        self.seq = nn.Sequential(nn.Linear(5, 3), nn.Linear(3, 2))

        # Dictionary to store intermediate feature maps
        self.visualisation = {}

    def forward(self, x):
        x = self.relu(self.conv(x))
        x = self.fc1(self.flatten(x))

        # Store the intermediate output of the fc1 layer
        self.visualisation['fc1_output'] = x

        # Apply the sequential layers
        x = self.seq(x)

        # Store the output of the sequential block
        self.visualisation['seq_output'] = x

        return x

net = myNet()

# Perform a forward pass through the network
out = net(torch.randn(1, 3, 8, 8))

# Check the keys of the visualisation dictionary to verify what was captured
print(net.visualisation.keys())
```

```text
dict_keys(['fc1_output', 'seq_output'])
```
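If you also want the outputs of the layers inside a nested `nn.Sequential`, one option is to walk the whole module tree with `named_modules()` and attach a forward hook to every leaf module. The sketch below assumes a plain `nn.Sequential` model as a stand-in for `myNet` (it swaps the `view(-1)` lambda for `nn.Flatten`); `save_activation` and `activations` are hypothetical names.

```python
import torch
import torch.nn as nn

model = nn.Sequential(
    nn.Conv2d(3, 10, 2, stride=2),
    nn.ReLU(),
    nn.Flatten(),
    nn.Sequential(nn.Linear(160, 5), nn.Linear(5, 2)),  # nested block
)

activations = {}

def save_activation(name):
    # Returns a forward hook that stores a detached copy of the module's output
    def hook(module, inputs, output):
        activations[name] = output.detach()
    return hook

# named_modules() visits every module, including layers inside nested Sequentials;
# only leaf modules (those without children) get a hook here
for name, module in model.named_modules():
    if len(list(module.children())) == 0:
        module.register_forward_hook(save_activation(name))

out = model(torch.randn(1, 3, 8, 8))
print(list(activations))  # ['0', '1', '2', '3.0', '3.1']
```

Keys follow the dotted names that `named_modules()` assigns, so `'3.0'` and `'3.1'` are the two layers inside the nested block.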


Finally, you can convert these tensors into NumPy arrays (detaching them from the autograd graph first) and plot the activations.
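A minimal sketch of that conversion, assuming a single `Conv2d` layer like the one used throughout and leaving the plotting library to you (the `plt.imshow` mention in the comments assumes matplotlib). `feature_maps` and `save_conv_output` are illustrative names.

```python
import torch
import torch.nn as nn

conv = nn.Conv2d(3, 10, 2, stride=2)
feature_maps = {}

def save_conv_output(module, inputs, output):
    # detach() drops the autograd history so .numpy() is allowed;
    # cpu() copies GPU tensors to host memory (a no-op on CPU tensors)
    feature_maps["conv"] = output.detach().cpu()

conv.register_forward_hook(save_conv_output)
conv(torch.randn(1, 3, 8, 8))

arr = feature_maps["conv"].numpy()   # shape (1, 10, 4, 4)
channel0 = arr[0, 0]                 # one 4x4 feature map, e.g. for plt.imshow(channel0)
print(arr.shape, channel0.shape)     # (1, 10, 4, 4) (4, 4)
```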