### PyTorch hooks
* a mechanism for attaching user-defined code to tensors or module instances in the PyTorch computation graph, at points that have no other explicit means of access
    * allows: 1) inspection; 2) alteration of nodes in the compute graph (tensors or modules) that are not otherwise accessible 
* see [this tutorial](https://www.youtube.com/watch?v=syLFCVYua6Q) for more details
* a hook is any callable object (a function, a lambda, or an instance that defines `__call__`)
* a hook is registered to a node in the compute graph with one of the following methods:
    * tensor hooks:
        * .register_hook(callable)
    * module hooks:
        * 1) .register_forward_pre_hook(callable): the hook is called before the forward() pass
        * 2) .register_forward_hook(callable): the hook is called after the forward() pass
        * 3) .register_backward_hook(callable): the hook is called during the backward pass, once the gradients for the module have been computed
            * note: .register_backward_hook() can report incorrect gradients for modules that perform more than one operation and is deprecated in favor of .register_full_backward_hook(); using tensor hooks is another workaround
* two types of hooks:
    * tensor hooks: e.g., for gradients in the backward graph
    * module hooks: e.g., for each layer in a network instance
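
The module hook methods listed above can be sketched as follows. This is a minimal illustration, not a definitive recipe: the layer, the hook names (`pre_hook`, `fwd_hook`, `bwd_hook`), and the `calls` list are made up for this example, and `.register_full_backward_hook()` is used in place of the deprecated `.register_backward_hook()`.

```python
import torch
import torch.nn as nn

layer = nn.Linear(3, 2)
calls = []  # hypothetical log of the order in which the hooks fire

def pre_hook(module, args):
    # args is a tuple holding the positional inputs of forward()
    calls.append("forward_pre")

def fwd_hook(module, args, output):
    # returning None keeps the output as-is; returning a value would replace it
    calls.append("forward")

def bwd_hook(module, grad_input, grad_output):
    # grad_input / grad_output are tuples of gradients w.r.t. inputs / outputs
    calls.append("backward")

# each register_* call returns a handle that can later remove the hook
h1 = layer.register_forward_pre_hook(pre_hook)
h2 = layer.register_forward_hook(fwd_hook)
h3 = layer.register_full_backward_hook(bwd_hook)

x = torch.randn(4, 3, requires_grad=True)
out = layer(x)
out.sum().backward()
print(calls)

# detach the hooks once they are no longer needed
for h in (h1, h2, h3):
    h.remove()
```

Removing the handles at the end keeps the hooks from firing on later passes, which matters when hooks are only meant for temporary inspection.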
    
![compute graph](./assets/images/pytorch-09-hooks-compute-graph-illustration.png)

* when a hook is registered to a node:
    * in the forward graph, a pointer to the hook callable is added to an ordered dictionary field of the node
        * multiple hooks can be registered to the same node; they are called in the order in which they were registered
    * in the backward graph, if the node is an intermediate node, the corresponding backward node receives a pointer to that ordered dictionary in the forward node
* each entry of the ordered dictionary is a key:value pair where:
    * the key identifies the hook handle; the handle is also what the .register_hook() call returns, and it can be used to remove the hook later
    * the value is the hook callable
* tensor hooks: never use in-place operations (or any operation that alters the input) inside a hook callable
    * this is because, in the backward graph, the gradient that the hook receives is likely also passed on to other nodes, so modifying it in place would corrupt their results
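
The points above (ordered calls, handles, and out-of-place modification) can be sketched with a small graph. The tensor values and the names `double`, `record`, and `seen` are assumptions chosen for illustration only.

```python
import torch

x = torch.tensor(1.0, requires_grad=True)
y = x * 3.0              # intermediate node
z = y * 5.0
seen = []                # hypothetical log of the gradients each hook receives

def double(grad):
    seen.append(grad.item())
    return grad * 2      # out-of-place: returns a new tensor, input untouched

def record(grad):
    seen.append(grad.item())   # returns None, so the gradient passes through

# hooks fire in registration order; each register_hook() returns a handle
h_double = y.register_hook(double)
h_record = y.register_hook(record)

z.backward(retain_graph=True)
first_grad = x.grad.item()     # gradient at y: 5 -> doubled to 10 -> x.grad = 10 * 3

# remove the first hook through its handle and run backward again
h_double.remove()
x.grad = None                  # reset the accumulated gradient
z.backward()
second_grad = x.grad.item()    # only record() runs now -> x.grad = 5 * 3

print(seen, first_grad, second_grad)
```

Because `double` returns `grad * 2` rather than calling something like `grad.mul_(2)`, the tensor it received is left intact for any other node that shares it.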

In [None]:
""" register tensor hooks """
import torch

# leaf node A
a = torch.tensor(2.0, requires_grad=True)
print(a)
# leaf node B
b = torch.tensor(3.0, requires_grad=True)
print(b)
# intermediate node C
c = a * b
print(c)
# leaf node D
d = torch.tensor(4.0, requires_grad=True)
print(d)
# intermediate node E
e = c * d
print(e)

# hook function
def c_hook(grad):
    print(grad)
    return grad + 2

# register hooks to intermediate node C
c.register_hook(c_hook)
c.register_hook(lambda grad: print(grad))   # can use lambda functions
c.retain_grad()     # retain_grad() is also registered as a hook function

# register hooks to leaf node D
d.register_hook(lambda grad: grad + 100)

# register hooks to intermediate node E
e.retain_grad()
e.register_hook(lambda grad: grad * 2)

e.backward()
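
As a quick check of the chain of hooks in the cell above, the sketch below rebuilds the same graph without the print calls and asserts the gradients by hand. `e.retain_grad()` is left out on purpose: the value it records depends on when the retain-grad hook runs relative to the other hook on E, which is assumed here to vary between PyTorch versions.

```python
import torch

a = torch.tensor(2.0, requires_grad=True)
b = torch.tensor(3.0, requires_grad=True)
c = a * b                      # c = 6
d = torch.tensor(4.0, requires_grad=True)
e = c * d                      # e = 24

c.register_hook(lambda grad: grad + 2)
c.retain_grad()
d.register_hook(lambda grad: grad + 100)
e.register_hook(lambda grad: grad * 2)

e.backward()

# gradient entering E is 1, doubled by the hook on E -> 2
# gradient at C: 2 * d = 8, then + 2 from the hook -> 10
# gradient at D: 2 * c = 12, then + 100 from the hook -> 112
# gradient at A: 10 * b = 30; gradient at B: 10 * a = 20
assert c.grad.item() == 10.0
assert d.grad.item() == 112.0
assert a.grad.item() == 30.0
assert b.grad.item() == 20.0
print(c.grad, d.grad, a.grad, b.grad)
```

The altered gradient at C propagates to the leaves A and B, which is the main practical consequence of returning a new value from a tensor hook.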