## tensor hook

- [https://pytorch.org/docs/stable/notes/autograd.html](https://pytorch.org/docs/stable/notes/autograd.html)

- [https://pytorch.org/docs/stable/generated/torch.autograd.graph.Node.register_hook.html](https://pytorch.org/docs/stable/generated/torch.autograd.graph.Node.register_hook.html)

### example 1

In [10]:
# PyTorch itself plus the functional API (aliased as F), both used by later cells.
import torch
import torch.nn.functional as F
# Print the installed PyTorch version for reference.
print(f'torch.__version__: {torch.__version__}')

torch.__version__: 2.2.1+cu118


In [2]:
# Vanilla backprop: with no hooks registered, the gradient handed to
# backward() is stored in v.grad unchanged.
v = torch.zeros(3, requires_grad=True)
print(v.grad)  # nothing has been back-propagated yet
v.backward(gradient=torch.tensor([1., 2., 3.]))
print(v.grad)

None
tensor([1., 2., 3.])


In [3]:
# Register a tensor hook on v that doubles the gradient before it is stored.
v.grad = None  # discard the gradient left by the previous backward call
h = v.register_hook(lambda grad: 2 * grad)  # keep the handle so the hook can be removed
v.backward(gradient=torch.tensor([1., 2., 3.]))
print(v.grad)

tensor([2., 4., 6.])


In [4]:
# Remove the hook through its handle; the gradient is stored unmodified again.
v.grad = None  # start from a clean gradient
h.remove()
v.backward(gradient=torch.tensor([1., 2., 3.]))
print(v.grad)

tensor([1., 2., 3.])


### example 2

In [56]:
# Hooks on leaf and non-leaf tensors in a small graph: e = (a * b) * d.
a = torch.tensor(2.0, requires_grad=True)  # leaf
b = torch.tensor(2.0, requires_grad=True)  # leaf
c = a * b  # non-leaf

def c_hook(grad):
    """Print the gradient arriving at ``c`` and return it increased by 2.

    Args:
        grad: gradient of the backward target with respect to ``c``.

    Returns:
        ``grad + 2``; the returned tensor replaces the gradient, so hooks
        registered afterwards on ``c`` receive the shifted value.
    """
    print(grad)
    return grad + 2

c.retain_grad()  # keep c.grad after backward even though c is a non-leaf
c.register_hook(c_hook)
# This hook returns None, so it only observes the gradient (after c_hook changed it).
c.register_hook(lambda grad: print(grad))

d = torch.tensor(4.0, requires_grad=True)  # leaf
d.register_hook(lambda grad: grad + 100)

e = c * d  # non-leaf
e.retain_grad()  # a single call is enough; the duplicate call was redundant
e.register_hook(lambda grad: grad * 2)

e.backward()

tensor(8.)
tensor(10.)


### example 3

In [55]:
# Build a leaf tensor and a non-leaf copy, then show how they differ.
a = torch.tensor([0., 0., 0.], requires_grad=True)  # leaf node
b = a.clone()  # non-leaf node
# Each line reports the tensor's own requires_grad and grad_fn
# (the line for `a` previously printed b.requires_grad by mistake).
print(f'a: {a}, {a.requires_grad}, {a.grad_fn}')
print(f'b: {b}, {b.requires_grad}, {b.grad_fn}')
# Only the non-leaf tensor has a grad_fn, and that grad_fn is an autograd graph Node.
assert isinstance(b.grad_fn, torch.autograd.graph.Node)

a: tensor([0., 0., 0.], requires_grad=True), True, None
b: tensor([0., 0., 0.], grad_fn=<CloneBackward0>), True, <CloneBackward0 object at 0x7f6d9c337dc0>


In [51]:
# The hook on b.grad_fn doubles the gradient that is propagated on to a;
# it does not change the gradient of b itself.
# A Node hook receives (grad_inputs, grad_outputs) and may return a tuple
# that replaces grad_inputs.
handle = b.grad_fn.register_hook(
    lambda grad_inputs, grad_outputs: (grad_outputs[0] * 2,)
)

# loss = b1 + b2 + b3, so every element receives a gradient of 1;
# because the propagated gradient is doubled, a.grad should be [2., 2., 2.].
b.sum().backward(retain_graph=True)
print(a.grad)
print(b.grad)  # b is a non-leaf tensor without retain_grad(), so nothing is stored

# Remove the hook, reset the gradients and repeat the backward pass.
handle.remove()
a.grad = None
b.grad = None
b.sum().backward(retain_graph=True)
print(a.grad)
print(b.grad)

tensor([2., 2., 2.])
None
tensor([1., 1., 1.])
None


  print(b.grad)
  print(b.grad)


In [52]:
# Same setup as before, but b is asked to keep its gradient after backward.
a = torch.tensor([0., 0., 0.], requires_grad=True)  # leaf node
b = a.clone()  # non-leaf node
b.retain_grad()  # populate b.grad during backward even though b is a non-leaf
# Each line reports the tensor's own requires_grad and grad_fn
# (the line for `a` previously printed b.requires_grad by mistake).
print(f'a: {a}, {a.requires_grad}, {a.grad_fn}')
print(f'b: {b}, {b.requires_grad}, {b.grad_fn}')
assert isinstance(b.grad_fn, torch.autograd.graph.Node)

a: tensor([0., 0., 0.], requires_grad=True), None
b: tensor([0., 0., 0.], grad_fn=<CloneBackward0>), <CloneBackward0 object at 0x7f6da4b0ba30>


In [53]:
# Same experiment as above, now that b retains its gradient: the grad_fn hook
# still only scales what is propagated to a, while b.grad holds the unscaled gradient.
handle = b.grad_fn.register_hook(
    lambda grad_inputs, grad_outputs: (grad_outputs[0] * 2,)
)
b.sum().backward(retain_graph=True)
print(a.grad)
print(b.grad)

# Remove the hook, reset both gradients and run backward once more.
handle.remove()
a.grad = None
b.grad = None
b.sum().backward(retain_graph=True)
print(a.grad)
print(b.grad)

tensor([2., 2., 2.])
tensor([1., 1., 1.])
tensor([1., 1., 1.])
tensor([1., 1., 1.])


## module hook

- [https://pytorch.org/docs/stable/notes/autograd.html#backward-hooks-execution](https://pytorch.org/docs/stable/notes/autograd.html#backward-hooks-execution)

In [5]:
from src.models import get_dummy_mlp_model

# NOTE(review): the printed model shows Embedding(5, 8) and a 5-way output layer, so
# `embedding_dim` appears to act as the vocabulary/class count and `hidden_dim` as the
# embedding width — confirm against get_dummy_mlp_model.
embedding_dim = 5
hidden_dim = 8
# Use the GPU when one is available, otherwise fall back to CPU instead of
# failing on the previously hard-coded .cuda() call.
target_device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = get_dummy_mlp_model(embedding_dim, hidden_dim, torch.bfloat16).to(target_device).train()
# parameters() already returns an iterator, so next() can be applied directly.
device = next(model.parameters()).device
print(model)

GoodNet(
  (emb): Embedding(5, 8)
  (fc2): Linear(in_features=8, out_features=8, bias=True)
  (out): Linear(in_features=8, out_features=5, bias=True)
)


In [6]:
def forward_print_hook(module, input, output):
    """Forward hook that prints the module together with its input and output.

    Args:
        module: the module whose forward pass just ran.
        input: the positional inputs passed to the module.
        output: the value the module returned.

    Returns:
        None, so the module's output is left unchanged.
    """
    report = (
        "\n"
        f"    module        : {module}\n"
        f"    input         : {input}\n"
        f"    output        : {output}\n"
        "    "
    )
    print(report)

# Attach the printing hook to every module of the model and keep the returned
# handles, so the hooks can be detached later with handle.remove().
forward_hook_handles = []
for module in model.modules():
    forward_hook_handles.append(module.register_forward_hook(forward_print_hook))

In [7]:
# Collects every gradient-accumulation node that received a hook
# (presumably to keep the nodes referenced for as long as the hooks are needed).
grad_accs = []

def wrapper(param):
    """Hook the gradient-accumulation node of ``param`` to print its grad size.

    ``param.expand_as(param)`` yields a non-leaf view of the parameter; the
    first entry of that view's ``grad_fn.next_functions`` is the node that
    accumulates the gradient into ``param.grad``.

    Args:
        param: a tensor that requires grad (callers in this file pass model parameters).

    Side effects:
        Registers a hook on the node and appends the node to ``grad_accs``.
    """
    accumulator = param.expand_as(param).grad_fn.next_functions[0][0]

    def print_grad_size(*_unused):
        # The hook arguments are ignored; only the parameter's gradient shape is reported.
        print(param.grad.size())

    accumulator.register_hook(print_grad_size)
    grad_accs.append(accumulator)

# Install the grad-size printing hook on every trainable parameter.
for p in model.parameters():
    if not p.requires_grad:
        continue  # parameters that do not require grad are skipped
    wrapper(p)

In [8]:
# One random sequence of 5 token ids. The id range is tied to `embedding_dim`
# (the value the next cell also uses as the number of classes) instead of a
# hard-coded 5, so the two cannot drift apart.
# NOTE(review): `input` shadows the Python builtin; the name is kept because the next cell reads it.
input = torch.randint(embedding_dim, (1, 5)).to(device)

In [9]:
# Next-token objective: the outputs at every position except the last are
# scored against the input shifted left by one position.
outputs = model(input)  # forward hooks print here
logit = outputs[:, :-1].contiguous().view(-1, embedding_dim)
target = input[:, 1:].contiguous().view(-1)
print('=====' * 15)  # separator between forward-hook and backward-hook output
loss = F.cross_entropy(logit, target)
loss.backward()


    module        : Embedding(5, 8)
    input         : (tensor([[3, 1, 0, 1, 0]], device='cuda:0'),)
    output        : tensor([[[ 1.6641,  1.0312, -0.6055, -0.0190,  0.7773,  0.3750, -0.4453,
           0.3418],
         [-0.1748,  1.5469,  0.2988,  0.9102, -0.2334, -0.7695,  1.6016,
           0.1992],
         [ 0.6055, -0.2539,  1.1875, -0.7109,  0.8477, -0.7305, -0.2832,
          -1.1797],
         [-0.1748,  1.5469,  0.2988,  0.9102, -0.2334, -0.7695,  1.6016,
           0.1992],
         [ 0.6055, -0.2539,  1.1875, -0.7109,  0.8477, -0.7305, -0.2832,
          -1.1797]]], device='cuda:0', dtype=torch.bfloat16,
       grad_fn=<EmbeddingBackward0>)
    

    module        : Linear(in_features=8, out_features=8, bias=True)
    input         : (tensor([[[1.6641, 1.0312, 0.0000, 0.0000, 0.7773, 0.3750, 0.0000, 0.3418],
         [0.0000, 1.5469, 0.2988, 0.9102, 0.0000, 0.0000, 1.6016, 0.1992],
         [0.6055, 0.0000, 1.1875, 0.0000, 0.8477, 0.0000, 0.0000, 0.0000],
         [0.0