In [1]:
%load_ext autoreload 
%autoreload 2
import torch
import torch.nn.functional as F
from torch import nn, autograd, func

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:

# Seed the global RNG so the toy data below is identical on every run.
torch.manual_seed(0)

# Problem sizes: number of samples, input features, output features.
N, din, dout = 4, 3, 2

# Draw order is part of the reproducibility contract: inputs first, weights second.
x = torch.randn(N, din)
w_hat = torch.rand(din, dout)

# Targets are the linear map x @ w_hat shifted by a constant offset of 1.
y = torch.matmul(x, w_hat) + torch.ones(1)

In [3]:
def criterion(inp, target):
    """Toy scalar objective: the sum over all elements of ``2 * inp**2 - target``.

    This is not a regression loss; ``target`` only enters as a subtracted
    constant, so the gradient with respect to ``inp`` is ``4 * inp``.

    Args:
        inp: Tensor of predictions.
        target: Tensor broadcastable against ``inp``.

    Returns:
        A 0-d tensor holding the summed objective.
    """
    # Squared-error alternative, kept for reference: (inp - target).square().sum()
    doubled_square = 2 * inp ** 2
    return (doubled_square - target).sum()

In [4]:
class NetLinear_Noparam(nn.Module):
    """Two stacked ``nn.Linear`` layers (``in_feat -> 5 -> out_feat``) with constant init.

    Every weight is set to one and, when ``bias`` is true, every bias to zero,
    so outputs and gradients are easy to check by hand.

    Args:
        in_feat: Number of input features.
        out_feat: Number of output features.
        bias: Whether both linear layers carry a bias term.
    """

    def __init__(self, in_feat, out_feat, bias=True):
        super().__init__()
        self.net1 = nn.Linear(in_feat, 5, bias=bias)
        self.net2 = nn.Linear(5, out_feat, bias=bias)
        # Overwrite the default random init with constants on both layers.
        for layer in (self.net1, self.net2):
            nn.init.ones_(layer.weight)
            if bias:
                nn.init.zeros_(layer.bias)

    def forward(self, x):
        """Apply both layers, printing the output and hidden sizes on each call."""
        hidden = self.net1(x)
        result = self.net2(hidden)
        print(result.size(), hidden.size())
        return result


In [5]:
def normalize_grad_backward(module, grad_inp, grad_out):
    """Backward hook that logs its arguments and doubles the second ``grad_inp`` entry.

    The hook reads ``grad_inp[0]``, ``grad_inp[1]`` and ``grad_inp[2]``, so the
    tuple must hold at least three gradients. Variants that scale the other
    entries instead were tried earlier and dropped; only index 1 is scaled.

    Args:
        module: The module the hook is attached to (only its type is printed).
        grad_inp: Tuple of gradients handed to the hook as "grad input".
        grad_out: Tuple of gradients with respect to the module output.

    Returns:
        A 3-tuple ``(grad_inp[0], 2 * grad_inp[1], grad_inp[2])`` that replaces
        the original gradients.
    """
    print("grad scaling 2x ....", type(module))
    print(f"Grad input {len(grad_inp)}: ", grad_inp)
    print("Grad input : ", grad_inp[1].size())
    print()
    print(f"Grad output {len(grad_out)}: ", grad_out)
    kept_first, to_double, kept_last = grad_inp[0], grad_inp[1], grad_inp[2]
    return (kept_first, to_double * 2, kept_last)

# Re-seed before building the model so this cell gives the same result on every run.
torch.manual_seed(0)
net_lin = NetLinear_Noparam(din, dout)
net_lin.zero_grad()
# NOTE(review): register_backward_hook is the legacy module hook that PyTorch
# deprecates in favour of register_full_backward_hook. In the run recorded below
# it delivered a 3-entry grad_inp (shapes [2], [4, 5] and [5, 2]), which is what
# the hook indexes; a full hook would pass a different grad_inp, so switching
# requires rewriting the hook as well - confirm before changing.
net_lin.register_backward_hook(normalize_grad_backward)
yp = net_lin(x)
loss = criterion(yp, y)
# The hook registered above fires during this backward pass.
loss.backward()
print(loss)

torch.Size([4, 2]) torch.Size([4, 5])
grad scaling 2x .... <class '__main__.NetLinear_Noparam'>
Grad input 3:  (tensor([-62.8347, -62.8347]), tensor([[-37.2489, -37.2489, -37.2489, -37.2489, -37.2489],
        [-76.5875, -76.5875, -76.5875, -76.5875, -76.5875],
        [ 20.8846,  20.8846,  20.8846,  20.8846,  20.8846],
        [-32.7177, -32.7177, -32.7177, -32.7177, -32.7177]]), tensor([[109.4967, 109.4967],
        [109.4967, 109.4967],
        [109.4967, 109.4967],
        [109.4967, 109.4967],
        [109.4967, 109.4967]]))
Grad input :  torch.Size([4, 5])

Grad output 1:  (tensor([[-18.6244, -18.6244],
        [-38.2937, -38.2937],
        [ 10.4423,  10.4423],
        [-16.3588, -16.3588]]),)
tensor(540.8549, grad_fn=<SumBackward0>)




In [6]:
# First layer's gradients after the hooked backward pass. The bias gradient
# (-251.3388) is twice the column sum of the unscaled grad_inp[1] printed by the
# hook, i.e. the hook's 2x scaling propagated into net1.
net_lin.net1.weight.grad, net_lin.net1.bias.grad

(tensor([[-158.6298,  262.0262,  334.5902],
         [-158.6298,  262.0262,  334.5902],
         [-158.6298,  262.0262,  334.5902],
         [-158.6298,  262.0262,  334.5902],
         [-158.6298,  262.0262,  334.5902]]),
 tensor([-251.3388, -251.3388, -251.3388, -251.3388, -251.3388]))

In [7]:
# Second layer's gradients. They carry the same values as the unscaled
# grad_inp[2] (weight, transposed layout) and grad_inp[0] (bias) printed by the hook.
net_lin.net2.weight.grad, net_lin.net2.bias.grad

(tensor([[109.4967, 109.4967, 109.4967, 109.4967, 109.4967],
         [109.4967, 109.4967, 109.4967, 109.4967, 109.4967]]),
 tensor([-62.8347, -62.8347]))

In [4]:

class MatMul(autograd.Function):
    """Hand-written affine map ``input @ weights.T + bias`` with an explicit backward.

    Shapes: ``input`` is 2-D ``(batch, in_feat)``, ``weights`` is
    ``(out_feat, in_feat)`` and ``bias`` is either ``None`` or a 1-D tensor of
    length ``out_feat``. The backward pass prints its intermediate values so the
    gradient flow can be followed in the notebook output.
    """

    @staticmethod
    def forward(ctx, input, weights, bias) -> torch.Tensor:
        """Compute ``input @ weights.T`` and add ``bias`` when one is given.

        Args:
            ctx: Autograd context used to stash what backward needs.
            input: Tensor of shape ``(batch, in_feat)``.
            weights: Tensor of shape ``(out_feat, in_feat)``.
            bias: Tensor of shape ``(out_feat,)`` or ``None`` for no bias.

        Returns:
            Tensor of shape ``(batch, out_feat)``.
        """
        out = input @ weights.transpose(0, 1)
        if bias is not None:
            out = out + bias
        # Backward only needs the two matmul operands; the bias gradient depends
        # solely on the incoming gradient, so just remember whether a bias exists.
        ctx.save_for_backward(input, weights)
        ctx.has_bias = bias is not None
        return out

    @staticmethod
    def backward(ctx, *grad_outputs):
        """Return gradients for ``(input, weights, bias)`` given the output gradient.

        A gradient is ``None`` when the corresponding forward argument does not
        require one (or, for the bias, when no bias was passed).

        Raises:
            AssertionError: If the saved input and weights disagree on ``in_feat``.
        """
        grad_out = grad_outputs[0]
        print("--"* 100)
        print(grad_outputs)
        print("------" * 50)
        input, weights = ctx.saved_tensors
        bz, fz = input.size()
        out_feat, in_feat = weights.size()
        assert fz == in_feat
        print(input.size(), grad_out.size())

        # dL/dW = grad_out^T @ input, shape (out_feat, in_feat). A single matmul
        # avoids materialising a (batch, out_feat, in_feat) broadcast tensor.
        grad_w = None
        if ctx.needs_input_grad[1]:
            grad_w = grad_out.transpose(0, 1) @ input

        # dL/db sums the output gradient over the batch dimension.
        grad_b = None
        if ctx.has_bias and ctx.needs_input_grad[2]:
            grad_b = grad_out.sum(0)

        inp_grad = None
        if ctx.needs_input_grad[0]:
            print("Need inp grad ", grad_out.size(), weights.size())
            inp_grad = grad_out @ weights
            assert inp_grad.size() == input.size()
        print("Grad_w is : ", grad_w )
        print("Grad_b is : ", grad_b )
        print()
        return inp_grad, grad_w, grad_b

class NetLinear(nn.Module):
    """Linear layer whose forward pass runs through the custom ``MatMul`` function.

    Args:
        in_feat: Number of input features.
        out_feat: Number of output features.
        bias: Whether the layer has a learnable bias (initialised to zero).

    Attributes are documented inline. Note that ``scale`` is created as a
    parameter but is not used by ``forward``.
    """

    def __init__(self, in_feat, out_feat, bias=True):
        super().__init__()
        # Use plain Python ints for sizing; the tensor copies below are kept
        # only so the existing ``in_feat`` / ``out_feat`` buffers stay available.
        in_feat, out_feat = int(in_feat), int(out_feat)
        self.register_buffer("in_feat", torch.tensor(in_feat, dtype=torch.long))
        self.register_buffer("out_feat", torch.tensor(out_feat, dtype=torch.long))
        # torch.empty replaces the legacy torch.Tensor(...) constructor; the
        # values are overwritten by the Xavier init at the end of __init__.
        self.weight = nn.Parameter(torch.empty(out_feat, in_feat))
        self.scale = nn.Parameter(torch.ones(out_feat))
        if bias:
            self.bias = nn.Parameter(torch.zeros(out_feat))
        else:
            # Reserve the attribute as an absent parameter, like nn.Linear does.
            self.register_parameter("bias", None)

        # nn.init functions already run without gradient tracking.
        nn.init.xavier_normal_(self.weight)

    def forward(self, x):
        """Return ``x @ weight.T + bias`` computed by ``MatMul``."""
        bias = self.bias
        if bias is None:
            # Hand MatMul an explicit zero bias so the bias-free layer does not
            # rely on MatMul accepting ``None``.
            bias = x.new_zeros(self.weight.size(0))
        return MatMul.apply(x, self.weight, bias)

class ModelAutograd(nn.Module):
    """Two stacked ``NetLinear`` layers (``in_feat -> 5 -> out_feat``) with constant init.

    All weights are overwritten with ones and, when ``bias`` is true, all biases
    with zeros, mirroring the initialisation of ``NetLinear_Noparam``.

    Args:
        in_feat: Number of input features.
        out_feat: Number of output features.
        bias: Whether both layers carry a bias term.
    """

    def __init__(self, in_feat, out_feat, bias=True):
        super().__init__()
        self.net1 = NetLinear(in_feat, 5, bias=bias)
        self.net2 = NetLinear(5, out_feat, bias=bias)
        # Replace each layer's own init with constants.
        for layer in (self.net1, self.net2):
            nn.init.ones_(layer.weight)
            if bias:
                nn.init.zeros_(layer.bias)

    def forward(self, x):
        """Feed ``x`` through ``net1`` and then ``net2``."""
        hidden = self.net1(x)
        return self.net2(hidden)

In [9]:
def normalize_grad_backward(module, grad_inp, grad_out):
    """Backward hook that logs its arguments and doubles the first ``grad_inp`` entry.

    This redefines the hook of the same name from earlier in the notebook; the
    difference is that index 0 is scaled here rather than index 1.

    Args:
        module: The module the hook is attached to (unused).
        grad_inp: Tuple with at least three gradients, the "grad input" of the hook.
        grad_out: Tuple of gradients with respect to the module output.

    Returns:
        A 3-tuple ``(2 * grad_inp[0], grad_inp[1], grad_inp[2])`` that replaces
        the original gradients.
    """
    # Three blank lines set the hook's log apart from surrounding output.
    print("\n\n")
    print(f"Grad input {len(grad_inp)}: ", grad_inp)
    print()
    print(f"Grad output {len(grad_out)}: ", grad_out)
    print("\n\n")
    doubled_first = grad_inp[0] * 2
    return (doubled_first, grad_inp[1], grad_inp[2])
# Re-seed before building the model so this cell gives the same result on every run.
torch.manual_seed(0)
net_lin = ModelAutograd(din, dout)
net_lin.zero_grad()
# NOTE(review): register_backward_hook is the legacy module hook that PyTorch
# deprecates in favour of register_full_backward_hook. The hook indexes three
# grad_inp entries, which the recorded output below shows this API delivering;
# a full hook would pass a different grad_inp - confirm before switching.
net_lin.register_backward_hook(normalize_grad_backward)
yp = net_lin(x)
loss = criterion(yp, y)
# Backward triggers both MatMul.backward (which prints) and the hook above.
loss.backward()
print(loss)

--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(tensor([[-18.6244, -18.6244],
        [-38.2937, -38.2937],
        [ 10.4423,  10.4423],
        [-16.3588, -16.3588]]),)
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
torch.Size([4, 5]) torch.Size([4, 2])
Need inp grad  torch.Size([4, 2]) torch.Size([2, 5])
Grad_w is :  tensor([[109.4967, 109.4967, 109.4967, 109.4967, 109.4967],
        [109.4967, 109.4967, 109.4967, 109.4967, 109.4967]])
Grad_b is :  tensor([-62.8347, -62.8347])




Grad input 3:  (tensor([[-37.2489, -37.2489, -37.2489, -37.2489, -37.2489],
        [-76.5875, -76.5875



In [6]:
# First layer's gradients for the ModelAutograd run. The values are identical to
# those of the NetLinear_Noparam run earlier in this notebook, consistent with the
# hook's 2x scaling of grad_inp[0] reaching net1.
net_lin.net1.weight.grad, net_lin.net1.bias.grad

(tensor([[-158.6298,  262.0262,  334.5902],
         [-158.6298,  262.0262,  334.5902],
         [-158.6298,  262.0262,  334.5902],
         [-158.6298,  262.0262,  334.5902],
         [-158.6298,  262.0262,  334.5902]]),
 tensor([-251.3388, -251.3388, -251.3388, -251.3388, -251.3388]))

In [7]:
# Second layer's gradients for the ModelAutograd run; they equal the Grad_w and
# Grad_b values printed by MatMul.backward, so the hook left them unscaled.
net_lin.net2.weight.grad, net_lin.net2.bias.grad

(tensor([[109.4967, 109.4967, 109.4967, 109.4967, 109.4967],
         [109.4967, 109.4967, 109.4967, 109.4967, 109.4967]]),
 tensor([-62.8347, -62.8347]))