# Pytorch Hooks

Example PyTorch model
```
          +--+  z1  +--+  z2
     +--->|La|----->|Lb|-----+
     |    +--+      +--+     |
 x   |                       |   +--+ (o1,o2)
-----+                       +-->|SD|------->
     |                       |   +--+
     |    +--+ h1            |
     +--->|Lb|---------------+
          +--+
```

In [1]:
import hashlib
from typing import Callable, Optional, Tuple, Union

import torch


class SumDiff(torch.nn.Module):
    """Module with two outputs: ``(a + b, a - b)``."""

    def forward(self, a, b):
        total = a + b
        difference = a - b
        return total, difference


class LinearA(torch.nn.Linear):
    """Unmodified ``torch.nn.Linear`` under its own class name.

    The distinct name lets hooks that print ``module.__class__.__name__``
    tell this layer apart from other linear layers.
    """


class LinearB(torch.nn.Linear):
    """Unmodified ``torch.nn.Linear`` under its own class name.

    The distinct name lets hooks that print ``module.__class__.__name__``
    tell this layer apart from other linear layers.
    """


class Net(torch.nn.Module):
    """Two-branch toy network with a shared ``linear_b`` layer.

    ``linear_b`` is applied twice with the same weights: once to the output
    of ``linear_a`` and once directly to the input. ``SumDiff`` combines the
    two branch results, and the network returns their sum divided by their
    difference (a zero difference is not guarded against).
    """

    def __init__(self):
        super().__init__()
        # Registration order is kept as-is; it determines submodule order.
        self.linear_a = LinearA(in_features=2, out_features=2)
        self.linear_b = LinearB(in_features=2, out_features=4, bias=False)
        self.sum_diff = SumDiff()

    def forward(self, x):
        # Upper branch: La -> Lb
        upper = self.linear_b(self.linear_a(x))
        # Lower branch: Lb on the raw input (same weights as above)
        lower = self.linear_b(x)
        total, difference = self.sum_diff(upper, lower)
        return total / difference

The network runs forward and backward without errors:

In [2]:
# Run one forward/backward pass and report the shapes involved.
x = torch.rand(3, 2).requires_grad_()
net = Net()
out = net(x)
out.backward(torch.ones_like(out))
print(
    "\n".join(
        f"{label}: {tuple(t.shape)}"
        for label, t in (("x   ", x), ("grad", x.grad), ("out ", out))
    )
)

x   : (3, 2)
grad: (3, 2)
out : (3, 4)


## Module forward hook
A module forward hook runs after the module's `forward` has computed its output. It receives the module itself, its positional inputs, and its output.

In [3]:
def tensor_hex(tensor: torch.Tensor, digits: int = 4) -> str:
    """Return the last ``digits`` hex digits of ``id(tensor)`` as a short tag.

    Handy for eyeballing whether two printed tensors are the same Python
    object. The tag is truncated, so distinct objects can collide, and ids
    may be reused once an object has been freed.

    Raises:
        ValueError: if ``digits`` is less than 1.
    """
    if digits < 1:
        raise ValueError(f"digits must be >= 1, got {digits}")
    # Mask and zero-pad instead of slicing hex(): for an id with fewer than
    # `digits` hex digits, hex(id)[-digits:] would include the "0x" prefix.
    return f"{id(tensor) & ((1 << 4 * digits) - 1):0{digits}x}"


def module_forward_hook(
    module: torch.nn.Module,
    inputs: Tuple[torch.Tensor, ...],
    outputs: Union[torch.Tensor, Tuple[torch.Tensor, ...]],
) -> Optional[torch.Tensor]:
    """Print the class name and every input/output of ``module``.

    Intended for ``register_forward_hook``: PyTorch calls it after the
    module's forward with the positional inputs and the output. Tensors are
    printed as shape plus a short id tag; any non-tensor value (e.g. ``None``
    or an int passed positionally) is printed with its repr instead of
    raising. Returns ``None``, so the module's output is left unchanged.
    """

    def describe(value) -> str:
        # Only tensors have .shape; fall back to repr for anything else.
        if isinstance(value, torch.Tensor):
            return f"{tuple(value.shape)} {tensor_hex(value)}"
        return repr(value)

    print(module.__class__.__name__)
    if not isinstance(outputs, tuple):
        outputs = (outputs,)
    for idx, value in enumerate(inputs):
        print(f"IN[{idx}]  {describe(value)}")
    for idx, value in enumerate(outputs):
        print(f"OUT[{idx}] {describe(value)}")
    print()


# Attach the printing hook to the whole net and each submodule, run one
# forward pass, then detach everything.
handles = [
    module.register_forward_hook(module_forward_hook)
    for module in (net, net.linear_a, net.linear_b, net.sum_diff)
]
try:
    net(x)
finally:
    # Remove the hooks even if the forward pass raises; otherwise they keep
    # firing on every later call to `net` and stack up on a re-run.
    for h in handles:
        h.remove()

LinearA
IN[0]  (3, 2) b1c0
OUT[0] (3, 2) ccc0

LinearB
IN[0]  (3, 2) ccc0
OUT[0] (3, 4) 41c0

LinearB
IN[0]  (3, 2) b1c0
OUT[0] (3, 4) ce00

SumDiff
IN[0]  (3, 4) 41c0
IN[1]  (3, 4) ce00
OUT[0] (3, 4) b680
OUT[1] (3, 4) 4240

Net
IN[0]  (3, 2) b1c0
OUT[0] (3, 4) 4280



## Tensor backward hook
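A custom tensor backward hook that receives both the tensor and its gradient as arguments.
It is wrapped in a closure that removes the hook after its first call. The module forward hook below attaches one such hook to every input of each module, so the gradients are printed during the backward pass.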

In [4]:
# Signature of a tensor backward hook that also receives the tensor:
# (tensor, grad) -> replacement grad, or None to keep grad unchanged.
BackwardHookFn = Callable[
    [torch.Tensor, torch.Tensor],
    Union[torch.Tensor, None],
]


def module_forward_hook(module, inputs, outputs):
    """Register a one time tensor hook on the inputs of a module.

    Meant for ``register_forward_hook``. Each positional input that is a
    tensor requiring grad gets a one-shot ``PrintTensorGrad`` hook labelled
    ``"<ModuleClass> IN[i]"``. Other inputs are skipped, because
    ``Tensor.register_hook`` raises on tensors that don't require grad (and
    non-tensors have no hook API). ``outputs`` is unused.
    """
    name = module.__class__.__name__
    print("Module forward hook:", name)
    for i, tensor in enumerate(inputs):
        if not (isinstance(tensor, torch.Tensor) and tensor.requires_grad):
            continue
        # `i` is the original positional index, so labels stay stable even
        # when some inputs are skipped.
        one_time_tensor_hook(tensor, PrintTensorGrad(f"{name} IN[{i}]"))


def one_time_tensor_hook(
    tensor: torch.Tensor, backward_hook_fn: "BackwardHookFn"
) -> "torch.utils.hooks.RemovableHandle":
    """Register a one time tensor hook that will receive both the tensor and the grad.

    ``Tensor.register_hook`` hooks only receive the gradient; here
    ``backward_hook_fn`` is called as ``backward_hook_fn(tensor, grad)``.
    Its return value is handed back to autograd, so returning a tensor
    replaces the gradient and returning ``None`` leaves it unchanged.
    The hook unregisters itself after its first call, even if
    ``backward_hook_fn`` raises.

    Returns:
        The removable handle, so callers can unregister the hook early
        (e.g. when no backward pass ends up running).
    """

    def inner(grad: torch.Tensor) -> Optional[torch.Tensor]:
        try:
            return backward_hook_fn(tensor, grad)
        finally:
            # `handle` is bound below, before autograd can ever call `inner`.
            handle.remove()

    handle = tensor.register_hook(inner)
    return handle


class PrintTensorGrad:
    """Backward hook that prints its label plus the tensor and grad shapes.

    Instances are called as ``hook(tensor, grad)``, matching the
    ``BackwardHookFn`` signature, and return ``None``. When used as a tensor
    hook, this leaves the gradient unchanged. Defining ``__call__`` is enough
    to satisfy that callable type. The original version subclassed the
    subscripted ``typing.Callable`` alias, which silently made the class an
    ABC/``Generic`` subclass.
    """

    def __init__(self, label: str):
        self.label = label

    def __call__(self, tensor: torch.Tensor, grad: torch.Tensor) -> None:
        print(self)
        print(" - Tensor:", tuple(tensor.shape), tensor_hex(tensor))
        print(" - Grad  :", tuple(grad.shape))

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self.label})"


# Attach the input-grad hooks via module forward hooks, then run a full
# forward + backward pass so the one-time tensor hooks fire.
handles = [
    module.register_forward_hook(module_forward_hook)
    for module in (net, net.linear_a, net.linear_b, net.sum_diff)
]
try:
    print("Forward pass\n============")
    out = net(x)

    print("\nBackward pass\n=============")
    out.sum().backward()
finally:
    # Remove the module hooks even if forward/backward raises, so they don't
    # keep registering tensor hooks on every later call to `net`.
    for h in handles:
        h.remove()

Forward pass
Module forward hook: LinearA
Module forward hook: LinearB
Module forward hook: LinearB
Module forward hook: SumDiff
Module forward hook: Net

Backward pass
PrintTensorGrad(SumDiff IN[1])
 - Tensor: (3, 4) dcc0
 - Grad  : (3, 4)
PrintTensorGrad(SumDiff IN[0])
 - Tensor: (3, 4) 0380
 - Grad  : (3, 4)
PrintTensorGrad(LinearB IN[0])
 - Tensor: (3, 2) cdc0
 - Grad  : (3, 2)
PrintTensorGrad(LinearA IN[0])
 - Tensor: (3, 2) b1c0
 - Grad  : (3, 2)
PrintTensorGrad(LinearB IN[0])
 - Tensor: (3, 2) b1c0
 - Grad  : (3, 2)
PrintTensorGrad(Net IN[0])
 - Tensor: (3, 2) b1c0
 - Grad  : (3, 2)
