In [1]:
# hooks are a way to monitor states during the forward or backward pass
# - pre-forward hook -> runs before the forward pass
# - forward hook -> runs after layer execution
# - backward hook -> runs after gradient computation
# - tensor.register_hook() -> monitor gradients in specific tensors without affecting layers

# hooks are useful for debugging, visualization, and interpretability


In [2]:
import torch
from torch import nn

In [19]:
class SimpleCNN(nn.Module):
    """Minimal two-convolution network used as a target for the hook demos.

    Layers, in execution order: ``conv1`` (1 -> 4 channels, 3x3 kernel),
    ``conv2`` (4 -> 8 channels, 3x3 kernel) and ``fc1`` (128 -> 128 features).
    No activation functions are applied between the layers.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=4, kernel_size=3)
        self.conv2 = nn.Conv2d(in_channels=4, out_channels=8, kernel_size=3)
        # 128 = 8 channels * 4 * 4, i.e. what conv2 produces for an 8x8 input.
        self.fc1 = nn.Linear(in_features=128, out_features=128)

    def forward(self, x):
        """Apply conv1, conv2, flatten everything after dim 0, then fc1."""
        conv_features = self.conv2(self.conv1(x))
        flat_features = conv_features.flatten(start_dim=1)
        return self.fc1(flat_features)

model = SimpleCNN()

In [None]:
def activation_hook(module, input, output):
    """Forward hook that prints the hooked layer's output.

    Takes the (module, input, output) arguments of a forward hook; only
    ``output`` is used. Returns None.
    """
    label = "Activations: "
    print(label, output)

# Register the hook on the conv2 layer; the returned handle is what lets us
# detach the hook again afterwards.
hook_handle = model.conv2.register_forward_hook(activation_hook)

try:
    # Dummy input of shape (1, 1, 8, 8) to trigger a forward pass, which in
    # turn fires the hook once conv2 has produced its output.
    dummy_input = torch.randn(1, 1, 8, 8)
    output = model(dummy_input)
finally:
    # Always detach the hook, even if the forward pass raises. Otherwise it
    # stays attached to conv2, and re-running this cell would stack another
    # copy so the activations get printed once per leftover hook.
    hook_handle.remove()

Activations:  tensor([[[[-0.0555,  0.2934, -0.2884, -0.1731],
          [ 0.2727, -0.5773, -0.4750, -0.0273],
          [ 0.5547, -0.7700,  0.0366,  0.4569],
          [-0.1109,  0.3970, -0.0747,  0.2563]],

         [[-0.0559, -0.1338,  0.5201, -0.3818],
          [ 0.0336,  0.2275,  0.4396, -0.6283],
          [ 0.1205,  0.2646, -0.6408,  0.3123],
          [-0.4601,  0.3708, -0.1609,  0.3418]],

         [[-0.2659,  0.3472,  0.8026,  0.1176],
          [ 0.0028, -0.0660,  0.1390, -0.7788],
          [ 0.4445,  0.0030,  0.3412, -0.2067],
          [-0.0678,  0.2435,  0.1821,  0.2626]],

         [[ 0.3335,  0.3773, -0.0103, -0.3971],
          [-0.0759,  0.1605,  0.0917, -0.2314],
          [-0.4902,  0.2994, -0.5484, -0.1300],
          [-0.1805,  0.2096, -0.6668,  0.3159]],

         [[-0.3859, -0.3208, -0.1518,  0.1928],
          [ 0.2099, -0.4059,  0.4788, -0.2851],
          [ 0.5219, -0.2789,  0.0666,  0.2895],
          [ 0.1380, -0.1994, -0.1995,  0.4108]],

         [[ 0.39

In [22]:
# Second forward pass: the hook handle was removed in the previous cell, so
# this call is expected to run without printing any activations.
output = model(dummy_input)