In [1]:
%load_ext autoreload 
%autoreload 2
import torch
import torch.nn.functional as F
from torch import nn, autograd, func

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
# Synthetic regression data: N samples with din features; targets come from a
# random (din, dout) map plus a constant offset of 1. The seed makes x / w_hat
# reproducible so the gradients printed further down are stable.
torch.manual_seed(0)
N, din, dout = 4, 3, 2
x = torch.randn(N, din)
w_hat = torch.rand(din, dout)
y = torch.matmul(x, w_hat) + 1.0

In [50]:

class MatMul(autograd.Function):
    """Custom autograd op computing ``input @ weights.T (+ bias)`` with a hand-written backward.

    Shapes (as used by ``NetLinear``): ``weights`` is ``(out_feat, in_feat)``,
    ``input`` is ``(..., in_feat)`` (any number of leading dims) and ``bias`` is
    ``(out_feat,)`` or ``None``.

    ``backward`` prints what it receives and what it computes; this notebook uses
    those prints to trace when autograd and the module hooks fire. Set
    ``MatMul.verbose = False`` to silence them.
    """

    # Class-level switch for the debug prints in ``backward``.
    verbose = True

    @staticmethod
    def forward(ctx, input, weights, bias) -> torch.Tensor:
        out = input @ weights.transpose(0, 1)
        if bias is not None:
            out = out + bias
        # Only the operands are needed for the gradients; the output and the bias
        # values are never read in backward, so they are not kept alive.
        ctx.save_for_backward(input, weights)
        ctx.has_bias = bias is not None
        return out

    @staticmethod
    def backward(ctx, *grad_outputs):
        input, weights = ctx.saved_tensors
        grad_out = grad_outputs[0]
        verbose = MatMul.verbose
        if verbose:
            print()
            print()
            print()
            print("received ", grad_outputs)

        out_feat, in_feat = weights.size()
        # Flatten any leading batch dims so the parameter reductions are plain 2-D matmuls.
        grad_2d = grad_out.reshape(-1, out_feat)
        inp_2d = input.reshape(-1, in_feat)

        inp_grad = grad_w = grad_b = None
        if ctx.needs_input_grad[0]:
            if verbose:
                print("Need inp grad ")
            inp_grad = grad_out @ weights  # (..., out) @ (out, in) -> (..., in)
        if ctx.needs_input_grad[1]:
            # dL/dW[o, i] = sum_n grad[n, o] * input[n, i]
            grad_w = grad_2d.transpose(0, 1) @ inp_2d
        if ctx.has_bias and ctx.needs_input_grad[2]:
            grad_b = grad_2d.sum(0)
        if verbose:
            print("Grad_w is : ", grad_w)
            print("Grad_b is : ", grad_b)
            print()
        return inp_grad, grad_w, grad_b

class NetLinear(nn.Module):
    """Linear layer ``y = x @ weight.T + bias`` whose forward/backward run through ``MatMul``.

    Args:
        in_feat: number of input features.
        out_feat: number of output features.
        bias: if True, adds a learnable bias initialised to zeros; otherwise
            ``self.bias`` is None.
    """

    def __init__(self, in_feat, out_feat, bias=True):
        super().__init__()
        # Sizes stay registered as buffers (as before) so existing state_dicts still
        # load; plain ints are used for actually building tensors.
        self.register_buffer("in_feat", torch.tensor(in_feat).long())
        self.register_buffer("out_feat", torch.tensor(out_feat).long())
        n_in, n_out = int(in_feat), int(out_feat)
        # torch.empty instead of the legacy torch.Tensor(*sizes) constructor; the
        # contents are overwritten by the xavier init below.
        self.weight = nn.Parameter(torch.empty(n_out, n_in))
        # NOTE(review): `scale` is never used in forward(), so it never receives a
        # gradient. Kept only so parameter lists / state_dicts stay compatible.
        self.scale = nn.Parameter(torch.ones(n_out))
        if bias:
            self.bias = nn.Parameter(torch.zeros(n_out))
        else:
            # Empty parameter slot, the same convention nn.Linear uses.
            self.register_parameter("bias", None)
        # nn.init functions already run under torch.no_grad().
        nn.init.xavier_normal_(self.weight)

    def forward(self, x):
        return MatMul.apply(x, self.weight, self.bias)

In [51]:
def backward_pre_hook(module, grad_out):
    """Full backward pre-hook: log the output gradient and return a doubled copy.

    The returned 1-tuple is handed back to autograd as a replacement for ``grad_out``.
    """
    print("Backward pre hook ", grad_out)
    doubled = 2 * grad_out[0]
    return (doubled,)

def backward_hook(module, grad_inp, grad_out):
    """Full backward hook that only logs both gradient tuples.

    Returns None, so autograd keeps ``grad_inp`` as computed.
    """
    for grads in (grad_inp, grad_out):
        print("Backward hook : {}".format(grads))

def criterion(inp, target):
    """Toy scalar objective ``sum(inp**2 - target)`` (not a distance / MSE).

    Its gradient w.r.t. ``inp`` is ``2 * inp``.
    """
    residual = inp ** 2 - target
    return residual.sum()

# Single NetLinear layer with a full backward pre-hook and a full backward hook.
# The seed pins the xavier-initialised weight so the printed gradients are reproducible.
torch.manual_seed(0)
net_lin = NetLinear(din, dout)
net_lin.register_full_backward_pre_hook(backward_pre_hook)
net_lin.register_full_backward_hook(backward_hook)
net_lin.zero_grad()

yp = net_lin(x)
loss = criterion(yp, y)
loss.backward()
print(loss)

Backward pre hook  (tensor([[ 9.1173,  5.3650],
        [ 5.3650,  4.3707],
        [ 2.4574,  0.4128],
        [-1.0664,  0.2064]]),)
Backward hook : (None,)
Backward hook : (tensor([[18.2347, 10.7301],
        [10.7301,  8.7415],
        [ 4.9149,  0.8257],
        [-2.1329,  0.4129]]),)



received  (tensor([[ 9.1173,  5.3650],
        [ 5.3650,  4.3707],
        [ 2.4574,  0.4128],
        [-1.0664,  0.2064]]),)
Grad_w is :  tensor([[ 18.5208,  -5.7981, -29.3299],
        [ 10.8352,  -6.0916, -18.0615]])
Grad_b is :  tensor([15.8734, 10.3550])

tensor(35.1679, grad_fn=<SumBackward0>)


In [52]:
class ModelAutograd(nn.Module):
    """Two stacked ``NetLinear`` layers (in_feat -> hidden -> out_feat) with all-ones weights.

    Weights are set to 1 and biases (when present) to 0, so the gradients seen by the
    hooks in this notebook are easy to check by hand.

    Args:
        in_feat: input feature size.
        out_feat: output feature size.
        bias: whether both layers carry a bias.
        hidden: width of the intermediate layer (defaults to 5, the value that was
            previously hard-coded).
    """

    def __init__(self, in_feat, out_feat, bias=True, hidden=5):
        super().__init__()
        self.net1 = NetLinear(in_feat, hidden, bias=bias)
        self.net2 = NetLinear(hidden, out_feat, bias=bias)
        # nn.init helpers run under no_grad themselves, so no `.data` access is needed.
        for layer in (self.net1, self.net2):
            nn.init.ones_(layer.weight)
            if bias:
                nn.init.zeros_(layer.bias)

    def forward(self, x):
        return self.net2(self.net1(x))

In [54]:
def backward_pre_hook(module, grad_out):
    """Model-level full backward pre-hook that doubles the output gradient *in place*.

    Returns None, so any effect comes only from mutating the tensor inside
    ``grad_out``. PyTorch documents that hooks should not modify their arguments;
    this cell deliberately probes what happens when one does.
    """
    print("Backward pre hook ", grad_out)
    incoming = grad_out[0]
    incoming *= 2

def backward_hook(module, grad_inp, grad_out):
    """Model-level full backward hook: log both gradient tuples, then return ``(None,)``.

    The returned tuple is handed to autograd as the replacement input gradient.
    """
    lines = (
        "Backward hook grad inp: " + str(grad_inp),
        "Backward hook grad out : " + str(grad_out),
        "",
    )
    for line in lines:
        print(line)
    return (None,)

def backward_pre_hook1(module, grad_out):
    """Pre-hook for ``net1``: only announce that it fired (the gradient is not shown)."""
    message = "Backward pre hook net 1 "
    print(message)

def backward_pre_hook2(module, grad_out):
    """Pre-hook for ``net2``: log the gradient and return a tripled copy (no in-place change)."""
    print("Backward pre hook net 2 ", grad_out)
    print("multiply by 3")
    tripled = 3 * grad_out[0]
    return (tripled,)

def backward_hook2(module, grad_inp, grad_out):
    """Log ``net2``'s input/output gradients between '+' rulers; returns None."""
    ruler = "+" * 20
    print(ruler)
    print("Backward hook 2  grad inp: {}".format(grad_inp))
    print("Backward hook 2 grad out : {}".format(grad_out))
    print(ruler)
    
def backward_hook1(module, grad_inp, grad_out):
    """Log ``net1``'s input/output gradients between '-' rulers; returns None."""
    ruler = "-" * 20
    body = (
        "Backward hook 1  grad inp: " + str(grad_inp),
        "Backward hook 1 grad out : " + str(grad_out),
    )
    print(ruler)
    for line in body:
        print(line)
    print(ruler)

def criterion(inp, target):
    """Toy scalar objective ``sum(inp**2 - target)`` (not a distance / MSE).

    Its gradient w.r.t. ``inp`` is ``2 * inp``.
    """
    residual = inp ** 2 - target
    return residual.sum()

# Two-layer model with logging / scaling hooks on the parent module and on each child.
torch.manual_seed(0)
net_lin = ModelAutograd(din, dout)
net_lin.zero_grad()

# Parent-module hooks.
net_lin.register_full_backward_pre_hook(backward_pre_hook)
net_lin.register_full_backward_hook(backward_hook)
# Output-layer hooks.
net_lin.net2.register_full_backward_pre_hook(backward_pre_hook2)
net_lin.net2.register_full_backward_hook(backward_hook2)
# Input-layer hooks.
net_lin.net1.register_full_backward_pre_hook(backward_pre_hook1)
net_lin.net1.register_full_backward_hook(backward_hook1)

yp = net_lin(x)
loss = criterion(yp, y)
loss.backward()
print(loss)

Backward pre hook  (tensor([[ -9.3122,  -9.3122],
        [-19.1469, -19.1469],
        [  5.2212,   5.2212],
        [ -8.1794,  -8.1794]]),)
Backward hook grad inp: (None,)
Backward hook grad out : (tensor([[-18.6244, -18.6244],
        [-38.2937, -38.2937],
        [ 10.4423,  10.4423],
        [-16.3588, -16.3588]]),)

Backward pre hook net 2  (tensor([[-18.6244, -18.6244],
        [-38.2937, -38.2937],
        [ 10.4423,  10.4423],
        [-16.3588, -16.3588]]),)
multiply by 3



received  (tensor([[-18.6244, -18.6244],
        [-38.2937, -38.2937],
        [ 10.4423,  10.4423],
        [-16.3588, -16.3588]]),)
Need inp grad 
Grad_w is :  tensor([[109.4967, 109.4967, 109.4967, 109.4967, 109.4967],
        [109.4967, 109.4967, 109.4967, 109.4967, 109.4967]])
Grad_b is :  tensor([-62.8347, -62.8347])

++++++++++++++++++++
Backward hook 2  grad inp: (tensor([[-37.2489, -37.2489, -37.2489, -37.2489, -37.2489],
        [-76.5875, -76.5875, -76.5875, -76.5875, -76.5875],
        [ 20.8

In [56]:
def backward_pre_hook(module, grad_out):
    """Model-level full backward pre-hook that doubles the output gradient *in place*.

    Returns None, so any effect comes only from mutating the tensor inside
    ``grad_out``. PyTorch documents that hooks should not modify their arguments;
    this cell deliberately probes what happens when one does.
    """
    print("Backward pre hook ", grad_out)
    incoming = grad_out[0]
    incoming *= 2

def backward_hook(module, grad_inp, grad_out):
    """Model-level full backward hook: log both gradient tuples, then return ``(None,)``.

    The returned tuple is handed to autograd as the replacement input gradient.
    """
    lines = (
        "Backward hook grad inp: " + str(grad_inp),
        "Backward hook grad out : " + str(grad_out),
        "",
    )
    for line in lines:
        print(line)
    return (None,)

def backward_pre_hook1(module, grad_out):
    """Pre-hook for ``net1``: only announce that it fired (the gradient is not shown)."""
    message = "Backward pre hook net 1 "
    print(message)

def backward_pre_hook2(module, grad_out):
    """Pre-hook for ``net2`` that triples the gradient in place AND returns 3x it.

    The tensor is mutated before the return value is built, so the returned tensor
    is 9x the gradient that arrived (3x from the in-place update, 3x more from the
    multiplication). In this cell's recorded output the layer received only the 3x
    value, i.e. the in-place change is what took effect there.
    """
    print("Backward pre hook net 2 ", grad_out)
    print("multiply by 3 fixed*************************************")
    incoming = grad_out[0]
    incoming *= 3
    return (3 * grad_out[0],)

def backward_hook2(module, grad_inp, grad_out):
    """Log ``net2``'s input/output gradients between '+' rulers; returns None."""
    ruler = "+" * 20
    print(ruler)
    print("Backward hook 2  grad inp: {}".format(grad_inp))
    print("Backward hook 2 grad out : {}".format(grad_out))
    print(ruler)
    
def backward_hook1(module, grad_inp, grad_out):
    """Log ``net1``'s input/output gradients between '-' rulers; returns None."""
    ruler = "-" * 20
    body = (
        "Backward hook 1  grad inp: " + str(grad_inp),
        "Backward hook 1 grad out : " + str(grad_out),
    )
    print(ruler)
    for line in body:
        print(line)
    print(ruler)

def criterion(inp, target):
    """Toy scalar objective ``sum(inp**2 - target)`` (not a distance / MSE).

    Its gradient w.r.t. ``inp`` is ``2 * inp``.
    """
    residual = inp ** 2 - target
    return residual.sum()

# Two-layer model with logging / scaling hooks on the parent module and on each child.
torch.manual_seed(0)
net_lin = ModelAutograd(din, dout)
net_lin.zero_grad()

# Parent-module hooks.
net_lin.register_full_backward_pre_hook(backward_pre_hook)
net_lin.register_full_backward_hook(backward_hook)
# Output-layer hooks.
net_lin.net2.register_full_backward_pre_hook(backward_pre_hook2)
net_lin.net2.register_full_backward_hook(backward_hook2)
# Input-layer hooks.
net_lin.net1.register_full_backward_pre_hook(backward_pre_hook1)
net_lin.net1.register_full_backward_hook(backward_hook1)

yp = net_lin(x)
loss = criterion(yp, y)
loss.backward()
print(loss)

Backward pre hook  (tensor([[ -9.3122,  -9.3122],
        [-19.1469, -19.1469],
        [  5.2212,   5.2212],
        [ -8.1794,  -8.1794]]),)
Backward hook grad inp: (None,)
Backward hook grad out : (tensor([[-18.6244, -18.6244],
        [-38.2937, -38.2937],
        [ 10.4423,  10.4423],
        [-16.3588, -16.3588]]),)

Backward pre hook net 2  (tensor([[-18.6244, -18.6244],
        [-38.2937, -38.2937],
        [ 10.4423,  10.4423],
        [-16.3588, -16.3588]]),)
multiply by 3 fixed*************************************



received  (tensor([[ -55.8733,  -55.8733],
        [-114.8812, -114.8812],
        [  31.3269,   31.3269],
        [ -49.0765,  -49.0765]]),)
Need inp grad 
Grad_w is :  tensor([[328.4901, 328.4901, 328.4901, 328.4901, 328.4901],
        [328.4901, 328.4901, 328.4901, 328.4901, 328.4901]])
Grad_b is :  tensor([-188.5041, -188.5041])

++++++++++++++++++++
Backward hook 2  grad inp: (tensor([[-111.7467, -111.7467, -111.7467, -111.7467, -111.7467],
        [-229.76

In [76]:
def backward_pre_hook(module, grad_out):
    """Model-level full backward pre-hook: log the output gradient without changing it.

    Returns None, so autograd uses ``grad_out`` unchanged. (The unused
    ``g = grad_out[0]`` left over from an earlier experiment was removed.)
    """
    print("Backward pre hook model", grad_out)
    print()

def backward_hook(module, grad_inp, grad_out):
    """Model-level full backward hook: log both gradient tuples, then return ``(None,)``.

    The returned tuple is handed to autograd as the replacement input gradient.
    """
    for label, grads in (("grad inp: ", grad_inp), ("grad out : ", grad_out)):
        print("Backward hook model " + label + str(grads))
    print()
    return (None,)

def backward_pre_hook1(module, grad_out):
    """Pre-hook for ``net1``: log the incoming gradient, followed by a blank line."""
    print("Backward pre hook net 1 ", grad_out, end="\n\n")

def backward_pre_hook2(module, grad_out):
    """Pre-hook for ``net2``: log the incoming gradient and the layer's current ``weight.grad``.

    Returns None, so the gradient passes through unchanged. (The unused
    ``g = grad_out[0]`` was removed.) NOTE(review): the "multiply by 3" banner is
    left over from an earlier experiment; this hook does not scale anything.
    """
    print("Backward pre hook net 2 ", grad_out)
    print('Gradient for weight  is: ', module.weight.grad)
    print("multiply by 3 fixed*************************************")
    print()

def backward_hook2(module, grad_inp, grad_out):
    """Log ``net2``'s gradients and its current ``weight.grad`` between '+' rulers; returns None."""
    ruler = "+" * 20
    print(ruler)
    print("Backward hook 2  grad inp: {}".format(grad_inp))
    print("Backward hook 2 grad out : {}".format(grad_out))
    print('Gradient for weight  is: ', module.weight.grad)
    print(ruler, end="\n\n")
    
    
def backward_hook1(module, grad_inp, grad_out):
    """Log ``net1``'s input/output gradients between '-' rulers; returns None."""
    ruler = "-" * 20
    body = (
        "Backward hook 1  grad inp: " + str(grad_inp),
        "Backward hook 1 grad out : " + str(grad_out),
    )
    print(ruler)
    for line in body:
        print(line)
    print(ruler)
    
def backward_hook_criterion(module, grad_inp, grad_out):
    """Loss-module full backward hook: log both gradient tuples between '@' banners; returns None."""
    banner = "@" * 20
    print(banner)
    print("Backward hook criterion  grad inp: {}".format(grad_inp))
    print("Backward hook criterion grad out : {}".format(grad_out))
    print(banner, end="\n\n")
    
def backward_pre_hook_criterion(module, grad_out):
    """Loss-module pre-hook: log the incoming gradient between '@' banners; returns None."""
    banner = "@" * 20
    print(banner)
    print("Backward pre hook criterion  grad out : {}".format(grad_out))
    print(banner, end="\n\n")
    
    
class Loss(nn.Module):
    """Module form of the toy objective ``sum(inp**2 - target)``.

    Wrapping it in an ``nn.Module`` lets module backward hooks be attached to the
    loss. Its gradient w.r.t. ``inp`` is ``2 * inp``.
    """

    def forward(self, inp, target):
        residual = inp ** 2 - target
        return residual.sum()

# Same two-layer model, but the loss is now an nn.Module so hooks can be attached to it.
torch.manual_seed(0)
net_lin = ModelAutograd(din, dout)
criterion = Loss()  # NOTE: rebinds the name of the earlier `criterion` function
net_lin.zero_grad()

# Loss-module hooks.
criterion.register_full_backward_pre_hook(backward_pre_hook_criterion)
criterion.register_full_backward_hook(backward_hook_criterion)
# Parent-model hooks.
net_lin.register_full_backward_pre_hook(backward_pre_hook)
net_lin.register_full_backward_hook(backward_hook)
# Output-layer hooks.
net_lin.net2.register_full_backward_pre_hook(backward_pre_hook2)
net_lin.net2.register_full_backward_hook(backward_hook2)
# Input-layer hooks.
net_lin.net1.register_full_backward_pre_hook(backward_pre_hook1)
net_lin.net1.register_full_backward_hook(backward_hook1)

yp = net_lin(x)
loss = criterion(yp, y)
loss.backward()
print(loss)

@@@@@@@@@@@@@@@@@@@@
Backward pre hook criterion  grad out : (tensor(1.),)
@@@@@@@@@@@@@@@@@@@@

@@@@@@@@@@@@@@@@@@@@
Backward hook criterion  grad inp: (tensor([[ -9.3122,  -9.3122],
        [-19.1469, -19.1469],
        [  5.2212,   5.2212],
        [ -8.1794,  -8.1794]]), tensor([[0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.]]))
Backward hook criterion grad out : (tensor(1.),)
@@@@@@@@@@@@@@@@@@@@

Backward pre hook model (tensor([[ -9.3122,  -9.3122],
        [-19.1469, -19.1469],
        [  5.2212,   5.2212],
        [ -8.1794,  -8.1794]]),)

Backward hook model grad inp: (None,)
Backward hook model grad out : (tensor([[ -9.3122,  -9.3122],
        [-19.1469, -19.1469],
        [  5.2212,   5.2212],
        [ -8.1794,  -8.1794]]),)

Backward pre hook net 2  (tensor([[ -9.3122,  -9.3122],
        [-19.1469, -19.1469],
        [  5.2212,   5.2212],
        [ -8.1794,  -8.1794]]),)
Gradient for weight  is:  None
multiply by 3 fixed*************************************

In [64]:
# Analytic gradient of sum(yp**2 - y) w.r.t. yp, for comparison with the hook logs above.
2 * yp

tensor([[ -9.3122,  -9.3122],
        [-19.1469, -19.1469],
        [  5.2212,   5.2212],
        [ -8.1794,  -8.1794]], grad_fn=<MulBackward0>)

In [73]:
# NOTE(review): this cell ran as In[73], before the In[76] cell that rebuilds
# `net_lin`, so the stored output reflects an earlier kernel state.
net_lin.get_submodule("net2").weight.grad

tensor([[54.7483, 54.7483, 54.7483, 54.7483, 54.7483],
        [54.7483, 54.7483, 54.7483, 54.7483, 54.7483]])

In [78]:
def backward_pre_hook(module, grad_out):
    """Model-level full backward pre-hook: log the output gradient without changing it.

    Returns None, so autograd uses ``grad_out`` unchanged. (The unused
    ``g = grad_out[0]`` and the commented-out scaling experiments were removed.)
    """
    print("Backward pre hook model", grad_out)
    print()

def backward_hook(module, grad_inp, grad_out):
    """Model-level full backward hook: log both gradient tuples, then return ``(None,)``.

    The returned tuple is handed to autograd as the replacement input gradient.
    """
    for label, grads in (("grad inp: ", grad_inp), ("grad out : ", grad_out)):
        print("Backward hook model " + label + str(grads))
    print()
    return (None,)

def backward_pre_hook1(module, grad_out):
    """Pre-hook for ``net1``: log the incoming gradient, followed by a blank line."""
    print("Backward pre hook net 1 ", grad_out, end="\n\n")

def backward_pre_hook2(module, grad_out):
    """Pre-hook for ``net2``: log the incoming gradient and the layer's current ``weight.grad``.

    Returns None, so the gradient passes through unchanged. (The unused
    ``g = grad_out[0]`` and the commented-out scaling code were removed.)
    NOTE(review): the "multiply by 3" banner is left over from an earlier
    experiment; this hook does not scale anything.
    """
    print("Backward pre hook net 2 ", grad_out)
    print('Gradient for weight  is: ', module.weight.grad)
    print("multiply by 3 fixed*************************************")
    print()

def backward_hook2(module, grad_inp, grad_out):
    """Log ``net2``'s gradients and its current ``weight.grad`` between '+' rulers; returns None."""
    ruler = "+" * 20
    print(ruler)
    print("Backward hook 2  grad inp: {}".format(grad_inp))
    print("Backward hook 2 grad out : {}".format(grad_out))
    print('Gradient for weight  is: ', module.weight.grad)
    print(ruler, end="\n\n")
    
    
def backward_hook1(module, grad_inp, grad_out):
    """Log ``net1``'s input/output gradients between '-' rulers; returns None."""
    ruler = "-" * 20
    body = (
        "Backward hook 1  grad inp: " + str(grad_inp),
        "Backward hook 1 grad out : " + str(grad_out),
    )
    print(ruler)
    for line in body:
        print(line)
    print(ruler)
    
def backward_hook_criterion(module, grad_inp, grad_out):
    """Loss-module full backward hook: log both gradient tuples and hand ``grad_inp`` back unchanged."""
    banner = "@" * 20
    print(banner)
    print("Backward hook criterion  grad inp: {}".format(grad_inp))
    print("Backward hook criterion grad out : {}".format(grad_out))
    print(banner, end="\n\n")
    return grad_inp
    
def backward_pre_hook_criterion(module, grad_out):
    """Loss-module pre-hook: log the incoming gradient and return a doubled copy.

    In this cell's recorded output the returned value shows up as the criterion's
    ``grad_out`` (2.) in its full backward hook, while the input gradients printed
    there are unchanged. That is why the original note said "Does not change gradient".
    """
    banner = "@" * 20
    print(banner)
    print("Backward pre hook criterion  grad out : {}".format(grad_out))
    print(banner, end="\n\n")
    return (2 * grad_out[0],)
    
    
class Loss(nn.Module):
    """Module form of the toy objective ``sum(inp**2 - target)``.

    Wrapping it in an ``nn.Module`` lets module backward hooks be attached to the
    loss. Its gradient w.r.t. ``inp`` is ``2 * inp``.
    """

    def forward(self, inp, target):
        residual = inp ** 2 - target
        return residual.sum()

# Same two-layer model, but the loss is now an nn.Module so hooks can be attached to it.
torch.manual_seed(0)
net_lin = ModelAutograd(din, dout)
criterion = Loss()  # NOTE: rebinds the name of the earlier `criterion` function
net_lin.zero_grad()

# Loss-module hooks.
criterion.register_full_backward_pre_hook(backward_pre_hook_criterion)
criterion.register_full_backward_hook(backward_hook_criterion)
# Parent-model hooks.
net_lin.register_full_backward_pre_hook(backward_pre_hook)
net_lin.register_full_backward_hook(backward_hook)
# Output-layer hooks.
net_lin.net2.register_full_backward_pre_hook(backward_pre_hook2)
net_lin.net2.register_full_backward_hook(backward_hook2)
# Input-layer hooks.
net_lin.net1.register_full_backward_pre_hook(backward_pre_hook1)
net_lin.net1.register_full_backward_hook(backward_hook1)

yp = net_lin(x)
loss = criterion(yp, y)
loss.backward()
print(loss)

@@@@@@@@@@@@@@@@@@@@
Backward pre hook criterion  grad out : (tensor(1.),)
@@@@@@@@@@@@@@@@@@@@

@@@@@@@@@@@@@@@@@@@@
Backward hook criterion  grad inp: (tensor([[ -9.3122,  -9.3122],
        [-19.1469, -19.1469],
        [  5.2212,   5.2212],
        [ -8.1794,  -8.1794]]), tensor([[0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.]]))
Backward hook criterion grad out : (tensor(2.),)
@@@@@@@@@@@@@@@@@@@@

Backward pre hook model (tensor([[ -9.3122,  -9.3122],
        [-19.1469, -19.1469],
        [  5.2212,   5.2212],
        [ -8.1794,  -8.1794]]),)

Backward hook model grad inp: (None,)
Backward hook model grad out : (tensor([[ -9.3122,  -9.3122],
        [-19.1469, -19.1469],
        [  5.2212,   5.2212],
        [ -8.1794,  -8.1794]]),)

Backward pre hook net 2  (tensor([[ -9.3122,  -9.3122],
        [-19.1469, -19.1469],
        [  5.2212,   5.2212],
        [ -8.1794,  -8.1794]]),)
Gradient for weight  is:  None
multiply by 3 fixed*************************************

In [80]:
def backward_pre_hook(module, grad_out):
    """Model-level full backward pre-hook: log the output gradient without changing it.

    Returns None, so autograd uses ``grad_out`` unchanged. (The unused
    ``g = grad_out[0]`` and the commented-out scaling experiments were removed.)
    """
    print("Backward pre hook model", grad_out)
    print()

def backward_hook(module, grad_inp, grad_out):
    """Model-level full backward hook: log both gradient tuples, then return ``(None,)``.

    The returned tuple is handed to autograd as the replacement input gradient.
    """
    for label, grads in (("grad inp: ", grad_inp), ("grad out : ", grad_out)):
        print("Backward hook model " + label + str(grads))
    print()
    return (None,)

def backward_pre_hook1(module, grad_out):
    """Pre-hook for ``net1``: log the incoming gradient, followed by a blank line."""
    print("Backward pre hook net 1 ", grad_out, end="\n\n")

def backward_pre_hook2(module, grad_out):
    """Pre-hook for ``net2``: log the incoming gradient and the layer's current ``weight.grad``.

    Returns None, so the gradient passes through unchanged. (The unused
    ``g = grad_out[0]`` and the commented-out scaling code were removed.)
    NOTE(review): the "multiply by 3" banner is left over from an earlier
    experiment; this hook does not scale anything.
    """
    print("Backward pre hook net 2 ", grad_out)
    print('Gradient for weight  is: ', module.weight.grad)
    print("multiply by 3 fixed*************************************")
    print()

def backward_hook2(module, grad_inp, grad_out):
    """Log ``net2``'s gradients and its current ``weight.grad`` between '+' rulers; returns None."""
    ruler = "+" * 20
    print(ruler)
    print("Backward hook 2  grad inp: {}".format(grad_inp))
    print("Backward hook 2 grad out : {}".format(grad_out))
    print('Gradient for weight  is: ', module.weight.grad)
    print(ruler, end="\n\n")
    
    
def backward_hook1(module, grad_inp, grad_out):
    """Log ``net1``'s input/output gradients between '-' rulers; returns None."""
    ruler = "-" * 20
    body = (
        "Backward hook 1  grad inp: " + str(grad_inp),
        "Backward hook 1 grad out : " + str(grad_out),
    )
    print(ruler)
    for line in body:
        print(line)
    print(ruler)
    
def backward_hook_criterion(module, grad_inp, grad_out):
    """Loss-module full backward hook: log both gradient tuples and hand ``grad_inp`` back unchanged."""
    banner = "@" * 20
    print(banner)
    print("Backward hook criterion  grad inp: {}".format(grad_inp))
    print("Backward hook criterion grad out : {}".format(grad_out))
    print(banner, end="\n\n")
    return grad_inp
    
def backward_pre_hook_criterion(module, grad_out):
    """Loss-module pre-hook: log the incoming gradient, then double it *in place*.

    Nothing is returned; the doubling acts only by mutating the tensor in
    ``grad_out``. PyTorch documents that hooks should not modify their arguments.
    In this cell's recorded output the criterion's input gradients come out doubled.
    """
    banner = "@" * 20
    print(banner)
    print("Backward pre hook criterion  grad out : {}".format(grad_out))
    print(banner, end="\n\n")
    incoming = grad_out[0]
    incoming *= 2
    
    
class Loss(nn.Module):
    """Module form of the toy objective ``sum(inp**2 - target)``.

    Wrapping it in an ``nn.Module`` lets module backward hooks be attached to the
    loss. Its gradient w.r.t. ``inp`` is ``2 * inp``.
    """

    def forward(self, inp, target):
        residual = inp ** 2 - target
        return residual.sum()

# Same two-layer model, but the loss is now an nn.Module so hooks can be attached to it.
torch.manual_seed(0)
net_lin = ModelAutograd(din, dout)
criterion = Loss()  # NOTE: rebinds the name of the earlier `criterion` function
net_lin.zero_grad()

# Loss-module hooks.
criterion.register_full_backward_pre_hook(backward_pre_hook_criterion)
criterion.register_full_backward_hook(backward_hook_criterion)
# Parent-model hooks.
net_lin.register_full_backward_pre_hook(backward_pre_hook)
net_lin.register_full_backward_hook(backward_hook)
# Output-layer hooks.
net_lin.net2.register_full_backward_pre_hook(backward_pre_hook2)
net_lin.net2.register_full_backward_hook(backward_hook2)
# Input-layer hooks.
net_lin.net1.register_full_backward_pre_hook(backward_pre_hook1)
net_lin.net1.register_full_backward_hook(backward_hook1)

yp = net_lin(x)
loss = criterion(yp, y)
loss.backward()
print(loss)

@@@@@@@@@@@@@@@@@@@@
Backward pre hook criterion  grad out : (tensor(1.),)
@@@@@@@@@@@@@@@@@@@@

@@@@@@@@@@@@@@@@@@@@
Backward hook criterion  grad inp: (tensor([[-18.6244, -18.6244],
        [-38.2937, -38.2937],
        [ 10.4423,  10.4423],
        [-16.3588, -16.3588]]), tensor([[0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.]]))
Backward hook criterion grad out : (tensor(2.),)
@@@@@@@@@@@@@@@@@@@@

Backward pre hook model (tensor([[-18.6244, -18.6244],
        [-38.2937, -38.2937],
        [ 10.4423,  10.4423],
        [-16.3588, -16.3588]]),)

Backward hook model grad inp: (None,)
Backward hook model grad out : (tensor([[-18.6244, -18.6244],
        [-38.2937, -38.2937],
        [ 10.4423,  10.4423],
        [-16.3588, -16.3588]]),)

Backward pre hook net 2  (tensor([[-18.6244, -18.6244],
        [-38.2937, -38.2937],
        [ 10.4423,  10.4423],
        [-16.3588, -16.3588]]),)
Gradient for weight  is:  None
multiply by 3 fixed*************************************