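## PyTorch hooks
A hook is a callback that you register on a module or a tensor to inspect or modify its inputs, outputs, or gradients without changing the module's own code.  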

**Tutorial** [This one](https://www.youtube.com/watch?v=syLFCVYua6Q) by Elliot Waite is one of the best tutorials about hooks.

In [1]:
import torch
import torch.nn as nn 
import numpy as np
print("torch version:", torch.__version__)

torch version: 2.1.2


#### Forward hook and forward pre-hook
A forward pre-hook is called every time before `forward` computes an output; it may return modified inputs. [doc](https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.register_forward_pre_hook)  
A forward hook is called every time after `forward` has computed an output; it may return a modified output. [doc](https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.register_forward_hook)  

In [8]:
class SumNet(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, y):
        return x + y

def modify_x_pre_hook(module, inputs):
    x, y = inputs 
    return (x+10), y 

def modify_output_hook(module, args, output):
    return output + 0.5

def test_forward_pre_hook():
    m = SumNet()
    x = torch.tensor(2.)
    y = torch.tensor(3.)

    print("Original output:", m(x, y))
    handle = m.register_forward_pre_hook(modify_x_pre_hook)
    print("After the forward pre-hook:", m(x, y))
    handle.remove()
    print("After removing the hook:", m(x, y))

def test_forward_hook():
    m = SumNet()
    x = torch.tensor(2.)
    y = torch.tensor(3.)

    print("Original output:", m(x, y))
    handle = m.register_forward_hook(modify_output_hook)
    print("After the forward hook:", m(x, y))
    handle.remove()
    print("After removing the hook:", m(x, y))

test_forward_pre_hook()
test_forward_hook()

Original output: tensor(5.)
After the forward pre-hook: tensor(15.)
After removing the hook: tensor(5.)
Original output: tensor(5.)
After the forward hook: tensor(5.5000)
After removing the hook: tensor(5.)
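
Both hooks above return replacement values. A hook that returns `None` leaves the inputs or the output untouched, which makes hooks useful for inspection as well as modification. The following is a minimal sketch of that pattern; `AddNet`, `records`, and the two hook names are made up for this illustration and are not part of PyTorch.

```python
import torch
import torch.nn as nn

class AddNet(nn.Module):  # hypothetical stand-in for SumNet above
    def forward(self, x, y):
        return x + y

records = []  # illustrative store for what the hooks observe

def record_inputs_pre_hook(module, inputs):
    # Returning None (implicitly) keeps the inputs unchanged.
    records.append(("inputs", tuple(t.item() for t in inputs)))

def record_output_hook(module, args, output):
    # Returning None (implicitly) keeps the output unchanged.
    records.append(("output", output.item()))

m = AddNet()
handles = [
    m.register_forward_pre_hook(record_inputs_pre_hook),
    m.register_forward_hook(record_output_hook),
]

out = m(torch.tensor(2.), torch.tensor(3.))

# Remove the hooks so that later calls are no longer recorded.
for h in handles:
    h.remove()
m(torch.tensor(1.), torch.tensor(1.))

print("Output:", out)
print("Records:", records)
```

Only the first call is recorded, because both handles are removed before the second call.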


#### Backward hook and full backward hook
A module-level backward hook is called every time the gradients with respect to the module's inputs are computed.

- You can register a *full* backward hook with `register_full_backward_hook`. [PyTorch doc](https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.register_full_backward_hook)  
- The example also shows why the older `register_backward_hook` is considered buggy: when `forward` performs more than one operation, it reports incorrect input gradients.  

In [22]:
class Mult3Net(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, y, z):
        return x * y * z

# Backward, module-level hook example of gradient clipping
def clip_gradients(grad, bounds):
    return torch.clip(grad, min=bounds[0], max=bounds[1])

def print_and_clip_grad_backward_hook(module, grad_inputs, grad_outputs):
    """
    Inputs:
        module: the torch.nn.Module that computes the output from the input
        grad_inputs: tuple, containing the gradients w.r.t the module input
        grad_outputs: tuple, containing the gradients w.r.t the module output
    Returns:
        new_grad_input
    Note:
    - It's not allowed to modify the inputs/outputs in-place in a backward hook.
    """
    print("Grads wrt input:", grad_inputs)
    print("Grads wrt output:", grad_outputs)
    return tuple([clip_gradients(grad_input, [-5, 5]) for grad_input in grad_inputs])

def test_backward_module_level_hook():
    m = Mult3Net()
    x = torch.tensor(2., requires_grad=True)
    y = torch.tensor(3., requires_grad=True)
    z = torch.tensor(4., requires_grad=True)
    
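    # Uncomment the following line to register the full backward hook
    # m.register_full_backward_hook(print_and_clip_grad_backward_hook)

    # Uncommenting the following line instead leads to an incorrect "Grads wrt input":
    # the old hook only sees the last operation, (x * y) * z. This is a bug in PyTorch.
    # The output shown below matches a run with this hook registered.
    # m.register_backward_hook(print_and_clip_grad_backward_hook)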

    res = m(x, y, z)
    res.retain_grad()
    res.backward()
    print("Computed gradients:", res.grad, x.grad, y.grad, z.grad)
    
    
test_backward_module_level_hook()

Grads wrt input: (tensor(4.), tensor(6.))
Grads wrt output: (tensor(1.),)
Computed gradients: tensor(1.) tensor(12.) tensor(8.) tensor(5.)
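
For comparison, the true input gradients of `x * y * z` at (2, 3, 4) are `y*z = 12`, `x*z = 8`, and `x*y = 6`. The sketch below records what a full backward hook receives, assuming a PyTorch version that provides `register_full_backward_hook` (the 2.1.2 used here does). `ProductNet`, `captured`, and `record_grads_hook` are hypothetical names chosen for this illustration.

```python
import torch
import torch.nn as nn

class ProductNet(nn.Module):  # hypothetical stand-in for Mult3Net above
    def forward(self, x, y, z):
        return x * y * z

captured = {}  # illustrative store for the gradients the hook sees

def record_grads_hook(module, grad_input, grad_output):
    captured["grad_input"] = tuple(g.item() for g in grad_input)
    captured["grad_output"] = tuple(g.item() for g in grad_output)
    # Returning None (implicitly) keeps the gradients unchanged.

m = ProductNet()
handle = m.register_full_backward_hook(record_grads_hook)

x = torch.tensor(2., requires_grad=True)
y = torch.tensor(3., requires_grad=True)
z = torch.tensor(4., requires_grad=True)

m(x, y, z).backward()
handle.remove()

print("Grads wrt input:", captured["grad_input"])
print("Grads wrt output:", captured["grad_output"])
print("Leaf gradients:", x.grad, y.grad, z.grad)
```

Because this hook returns `None`, the leaf gradients are left as computed. Returning a tuple of tensors, as `print_and_clip_grad_backward_hook` does, replaces the input gradients instead.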




#### Backward, tensor-level hook
Alternatively, PyTorch lets you register a hook directly on a tensor with `Tensor.register_hook`. Tensor-level hooks exist only for the backward pass: the hook receives the tensor's gradient and may return a replacement.
[PyTorch documentation here](https://pytorch.org/docs/stable/generated/torch.Tensor.register_hook.html)  

In [25]:
def test_backward_tensor_level_hook():
    m = Mult3Net()
    x = torch.tensor(2., requires_grad=True)
    y = torch.tensor(3., requires_grad=True)
    z = torch.tensor(4., requires_grad=True)

    # Tensor-level hook
    x.register_hook(lambda grad: -grad)

    res = m(x, y, z)
    res.retain_grad()
    res.backward()
    print("Computed gradients:", res.grad, x.grad, y.grad, z.grad)
    
    
test_backward_tensor_level_hook()

Computed gradients: tensor(1.) tensor(-12.) tensor(8.) tensor(6.)
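
Tensor hooks can also be registered on intermediate (non-leaf) tensors, and `register_hook` returns a handle whose `remove()` detaches the hook, just like the module-level handles. A small sketch of both points, with `seen`, `h`, and `record_grad` as illustrative names:

```python
import torch

x = torch.tensor(2., requires_grad=True)
y = torch.tensor(3., requires_grad=True)
z = torch.tensor(4., requires_grad=True)

seen = []  # gradients observed by the hook

def record_grad(grad):
    # Returning None (implicitly) keeps the gradient unchanged.
    seen.append(grad.item())

h = x * y                       # intermediate tensor
handle = h.register_hook(record_grad)
out = h * z

# d(out)/dh = z, so the hook should observe the value of z.
out.backward(retain_graph=True)

# After removal, a second backward pass no longer triggers the hook.
handle.remove()
out.backward()

print("Gradients seen by the hook:", seen)
# Note: x.grad accumulates over the two backward passes.
print("x.grad:", x.grad)
```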


## Practice questions
1. If you want to view and save the intermediate representations of a module, how would you implement that?  
2. If you want to view and save the gradients of a module, how would you implement that?  
3. If you want to "copy-paste" weights onto a model, i.e., temporarily overwrite the weights of model A with the parameters from model B for just one `forward()` call, how would you implement that?  