In [1]:
%load_ext autoreload 
%autoreload 2
import torch
import torch.nn.functional as F
from torch import nn, autograd, func

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [9]:
class NetLinear(nn.Module):
    """Plain linear layer ``y = x @ W^T (+ b)`` that prints its output on every forward.

    Used to inspect what ``register_backward_hook`` receives for a module whose
    forward is a matmul followed by an add.

    Args:
        in_feat: number of input features.
        out_feat: number of output features.
        bias: if True, learn an additive bias initialised to zeros; otherwise
            ``self.bias`` is registered as a ``None`` buffer and skipped in forward.
    """

    def __init__(self, in_feat, out_feat, bias=True):
        super().__init__()
        # Feature sizes are stored as buffers so they are part of the state_dict.
        self.register_buffer("in_feat", torch.tensor(in_feat).long())
        self.register_buffer("out_feat", torch.tensor(out_feat).long())
        # torch.empty on plain ints instead of the legacy torch.Tensor(...) constructor;
        # the contents are overwritten by the xavier init below.
        self.weight = nn.Parameter(torch.empty(int(out_feat), int(in_feat)))
        # NOTE(review): `scale` is never used in forward, so it never receives a gradient.
        self.scale = nn.Parameter(torch.ones(int(out_feat)))
        if bias:
            self.bias = nn.Parameter(torch.zeros(int(out_feat)))
        else:
            self.register_buffer("bias", None)

        with torch.no_grad():
            nn.init.xavier_normal_(self.weight)

    def forward(self, x):
        """Return ``x @ W^T`` plus the bias when one exists; prints the result."""
        out = x @ self.weight.transpose(0, 1)
        if self.bias is not None:
            # An unconditional `+ self.bias` raised TypeError when bias=False.
            out = out + self.bias
        print(out)
        return out

    def __repr__(self):
        return f"NetLinear(in_feat={self.in_feat}, out_feat={self.out_feat}, bias={self.bias is not None})"

In [10]:
def normalize_grad_backward(module, grad_inp, grad_out):
    """Debug backward hook: print the gradients it receives and pass them through.

    For ``NetLinear`` (matmul followed by add), the recorded ``grad_inp`` is the
    gradient w.r.t. the matmul result and the bias, not w.r.t. the module input.

    Returns:
        A replacement ``grad_inp`` tuple: entry 0 multiplied by 1 (values unchanged)
        and entry 1 as-is. ``None`` entries (inputs that don't require grad) are
        passed through instead of raising.
    """
    print("No grad scaling....")
    print(f"Grad input {len(grad_inp)}: ", grad_inp, )
    first = grad_inp[0]
    # Only summarise real tensors; None has no .size()/.sum().
    if first is not None:
        print("Grad input : ", first.size(), first.sum(0))
    print()
    print(f"Grad output {len(grad_out)}: ", grad_out, )
    return (None if first is None else first * 1, grad_inp[1])

def criterion(inp, target):
    """Toy objective: the sum over all elements of ``2 * inp**2 - target``.

    Not a real regression loss; it is chosen so that ``dL/d(inp) = 4 * inp``,
    which makes the gradients reaching the hooks easy to check by hand.
    """
    doubled_square = 2 * inp ** 2
    return (doubled_square - target).sum()
# Problem sizes and synthetic data. The seed is reset before building the model and
# again before sampling data, so each is reproducible on its own.
N = 4
din = 3
dout = 2
torch.manual_seed(0)
net_lin = NetLinear(din, dout)
net_lin.zero_grad()
# NOTE(review): register_backward_hook is deprecated. For a multi-op forward, the recorded
# grad_input is for the last op (here the add), not for the module input x.
# register_full_backward_hook is the supported API, but its grad_input semantics differ.
net_lin.register_backward_hook(normalize_grad_backward)
torch.manual_seed(0)
x = torch.randn(N, din)
# Targets: a random linear map of x plus a constant offset of 1.
w_hat = torch.rand(din, dout)
y =  x @ w_hat + torch.tensor([1.])

yp = net_lin(x)

tensor([[ 4.5587,  2.6825],
        [ 2.6825,  2.1854],
        [ 1.2287,  0.2064],
        [-0.5332,  0.1032]], grad_fn=<AddBackward0>)




In [11]:
# One backward pass; the hook registered on net_lin prints the gradients it sees.
loss = criterion(yp, y)
net_lin.zero_grad()
loss.backward()
print(loss)

No grad scaling....
Grad input 2:  (tensor([[18.2347, 10.7301],
        [10.7301,  8.7415],
        [ 4.9149,  0.8257],
        [-2.1329,  0.4129]]), tensor([31.7467, 20.7100]))
Grad input :  torch.Size([4, 2]) tensor([31.7467, 20.7100])

Grad output 1:  (tensor([[18.2347, 10.7301],
        [10.7301,  8.7415],
        [ 4.9149,  0.8257],
        [-2.1329,  0.4129]]),)
tensor(76.9643, grad_fn=<SumBackward0>)


In [38]:
# Parameter gradients after loss.backward(), for comparison with the scaled runs.
net_lin.weight.grad, net_lin.bias.grad

(tensor([[ 37.0416, -11.5962, -58.6599],
         [ 21.6704, -12.1832, -36.1230]]),
 tensor([31.7467, 20.7100]))

In [39]:
def normalize_grad_backward(module, grad_inp, grad_out):
    """Debug backward hook: print the gradients and double entry 0 of ``grad_inp``.

    For ``NetLinear`` (matmul followed by add), entry 0 is the gradient w.r.t. the
    matmul result and entry 1 is the bias gradient. Entry 1 is passed through.
    ``None`` entries (inputs that don't require grad) are returned as ``None``
    instead of raising.
    """
    print("grad scaling 2x ....")
    print(f"Grad input {len(grad_inp)}: ", grad_inp, )
    first = grad_inp[0]
    # Only summarise real tensors; None has no .size()/.sum().
    if first is not None:
        print("Grad input : ", first.size(), first.sum(0))
    print()
    print(f"Grad output {len(grad_out)}: ", grad_out, )
    return (None if first is None else first * 2, grad_inp[1])

In [40]:
# Rebuild the same model from seed 0 and attach the 2x hook. Then run one
# forward/backward on the same x, y to compare parameter gradients with the 1x run.
torch.manual_seed(0)
net_lin = NetLinear(din, dout)
net_lin.zero_grad()
net_lin.register_backward_hook(normalize_grad_backward)
yp = net_lin(x)
loss = criterion(yp, y)
loss.backward()
print(loss)
print(net_lin.weight)

grad scaling 2x ....
Grad input 2:  (tensor([[18.2347, 10.7301],
        [10.7301,  8.7415],
        [ 4.9149,  0.8257],
        [-2.1329,  0.4129]]), tensor([119.0857,  48.0999]))
Grad input :  torch.Size([4, 2]) tensor([31.7467, 20.7100])

Grad output 1:  (tensor([[18.2347, 10.7301],
        [10.7301,  8.7415],
        [ 4.9149,  0.8257],
        [-2.1329,  0.4129]]),)
tensor(76.9643, grad_fn=<SumBackward0>)
Parameter containing:
tensor([[ 0.9746, -0.1856, -1.3780],
        [ 0.3595, -0.6859, -0.8845]], requires_grad=True)


In [41]:
# NOTE(review): in the recorded output both weight and bias grads are 2x the unscaled run,
# although the hook only rescales grad_inp[0]. register_backward_hook's grad_input mapping
# is unreliable for multi-op modules.
net_lin.weight.grad, net_lin.bias.grad

(tensor([[  74.0831,  -23.1925, -117.3197],
         [  43.3408,  -24.3665,  -72.2460]]),
 tensor([63.4934, 41.4201]))

In [42]:
# Snapshot the seed-0 weights as a plain tensor. torch.as_tensor on a Parameter returns
# that same Parameter object (requires_grad=True), so `weight` would alias live model state.
# A detached clone keeps the same values without the aliasing.
weight = net_lin.weight.detach().clone()

In [43]:
class NetLinear_Noparam(nn.Module):
    """Wrap ``nn.Linear`` with a given initial weight and a zero bias.

    Used to compare hook behaviour of the built-in linear layer against ``NetLinear``
    with identical weights.

    Args:
        in_feat, out_feat, bias: forwarded to ``nn.Linear``.
        init_weight: tensor copied into ``self.net.weight``; it must be broadcastable
            to ``(out_feat, in_feat)``. When omitted, falls back to the notebook-level
            global ``weight`` (the original behaviour). Pass it explicitly to avoid
            depending on hidden kernel state.
    """

    def __init__(self, in_feat, out_feat, bias=True, init_weight=None):
        super().__init__()
        self.net = nn.Linear(in_feat, out_feat, bias=bias)
        if init_weight is None:
            # Backward-compatible fallback to the global defined in an earlier cell.
            init_weight = weight
        with torch.no_grad():
            # Under no_grad, in-place writes to the Parameter are untracked; `.data` is unnecessary.
            self.net.weight.copy_(init_weight)
            if bias:
                self.net.bias.zero_()

    def forward(self, x):
        """Apply the wrapped linear layer."""
        return self.net(x)


In [46]:
def normalize_grad_backward(module, grad_inp, grad_out):
    """Debug backward hook that doubles entries 0 and 2 of ``grad_inp``.

    For ``nn.Linear`` the recorded ``grad_inp`` was ``(grad bias, grad x, grad W^T)``.
    In that run ``grad x`` was ``None`` because the input did not require grad.
    Entries 0 and 2 are doubled and entry 1 is passed through. ``None`` entries are
    returned as ``None`` instead of raising.
    """
    print("grad scaling 2x ....")
    print(f"Grad input {len(grad_inp)}: ", grad_inp, )
    first = grad_inp[0]
    # Only summarise real tensors; None has no .size()/.sum().
    if first is not None:
        print("Grad input : ", first.size(), first.sum(0))
    print()
    print(f"Grad output {len(grad_out)}: ", grad_out, )

    def _scaled(grad, factor):
        # None marks an input that doesn't require grad; keep it as None.
        return None if grad is None else grad * factor

    return (_scaled(first, 2), grad_inp[1], _scaled(grad_inp[2], 2.))

# Same experiment with the built-in nn.Linear, initialised with the snapshot `weight`.
# The hook registered here expects a 3-entry grad_inp (grad bias, grad x, grad W^T).
torch.manual_seed(0)
net_lin = NetLinear_Noparam(din, dout)
net_lin.zero_grad()
net_lin.register_backward_hook(normalize_grad_backward)
yp = net_lin(x)
loss = criterion(yp, y)
loss.backward()
print(loss)
print(net_lin.net.weight)

grad scaling 2x ....
Grad input 3:  (tensor([31.7467, 20.7100]), None, tensor([[ 37.0416,  21.6704],
        [-11.5962, -12.1832],
        [-58.6599, -36.1230]]))
Grad input :  torch.Size([2]) tensor(52.4567)

Grad output 1:  (tensor([[18.2347, 10.7301],
        [10.7301,  8.7415],
        [ 4.9149,  0.8257],
        [-2.1329,  0.4129]]),)
tensor(76.9643, grad_fn=<SumBackward0>)
Parameter containing:
tensor([[ 0.9746, -0.1856, -1.3780],
        [ 0.3595, -0.6859, -0.8845]], requires_grad=True)


In [47]:
# nn.Linear parameter gradients; in the recorded output both are 2x the unscaled run.
net_lin.net.weight.grad, net_lin.net.bias.grad

(tensor([[  74.0831,  -23.1925, -117.3197],
         [  43.3408,  -24.3665,  -72.2460]]),
 tensor([63.4934, 41.4201]))

In [2]:

class MatMul(autograd.Function):
    """Custom autograd op computing ``input @ weights^T + bias`` with a hand-written backward.

    Shapes: ``input`` is ``(batch, in_feat)``, ``weights`` is ``(out_feat, in_feat)``,
    ``bias`` is ``(out_feat,)`` or ``None``. The output is ``(batch, out_feat)``.
    """

    @staticmethod
    def forward(ctx, input, weights, bias) -> torch.Tensor:
        out = input @ weights.transpose(0, 1)
        if bias is not None:
            out = out + bias
        # Backward only needs input and weights; the output itself is not required.
        ctx.save_for_backward(input, weights)
        return out

    @staticmethod
    def backward(ctx, *grad_outputs):
        """Return gradients for ``(input, weights, bias)`` given dL/d(out).

        Each gradient is computed only if ``ctx.needs_input_grad`` requests it;
        otherwise ``None`` is returned for that input.
        """
        input, weights = ctx.saved_tensors
        grad_out = grad_outputs[0]  # (batch, out_feat)

        inp_grad = grad_w = grad_b = None
        if ctx.needs_input_grad[0]:
            print("Need inp grad ")
            # dL/d(input) = grad_out @ weights -> (batch, in_feat).
            # The former `grad_out @ weights.t()` only worked when in_feat == out_feat.
            inp_grad = grad_out @ weights
        if ctx.needs_input_grad[1]:
            # Batch sum of outer products grad_out[b] x input[b] -> (out_feat, in_feat).
            grad_w = grad_out.transpose(0, 1) @ input
        if ctx.needs_input_grad[2]:
            # The bias is broadcast over the batch, so its gradient is the batch sum.
            grad_b = grad_out.sum(0)
        print("Grad_w is : ", grad_w )
        print("Grad_b is : ", grad_b )
        print()
        return inp_grad, grad_w, grad_b

class NetLinear(nn.Module):
    """Linear layer ``y = x @ W^T (+ b)`` whose forward and backward go through ``MatMul``.

    Because the whole forward is one custom op, the recorded ``register_backward_hook``
    ``grad_inp`` is ``MatMul``'s ``(grad x, grad weight, grad bias)``.

    Args:
        in_feat: number of input features.
        out_feat: number of output features.
        bias: if True, learn an additive bias initialised to zeros; otherwise
            ``self.bias`` is registered as a ``None`` buffer.
    """

    def __init__(self, in_feat, out_feat, bias=True):
        super().__init__()
        # Feature sizes are stored as buffers so they are part of the state_dict.
        self.register_buffer("in_feat", torch.tensor(in_feat).long())
        self.register_buffer("out_feat", torch.tensor(out_feat).long())
        # torch.empty on plain ints instead of the legacy torch.Tensor(...) constructor;
        # the contents are overwritten by the xavier init below.
        self.weight = nn.Parameter(torch.empty(int(out_feat), int(in_feat)))
        # NOTE(review): `scale` is never used in forward, so it never receives a gradient.
        self.scale = nn.Parameter(torch.ones(int(out_feat)))
        if bias:
            self.bias = nn.Parameter(torch.zeros(int(out_feat)))
        else:
            self.register_buffer("bias", None)

        with torch.no_grad():
            nn.init.xavier_normal_(self.weight)

    def forward(self, x):
        """Apply ``MatMul`` with this layer's weight and bias."""
        bias = self.bias
        if bias is None:
            # When bias=False, pass an explicit zero bias (no grad) so MatMul never
            # receives None. The result is the same as having no bias.
            bias = self.weight.new_zeros(self.weight.size(0))
        return MatMul.apply(x, self.weight, bias)

    def __repr__(self):
        return f"NetLinear(in_feat={self.in_feat}, out_feat={self.out_feat}, bias={self.bias is not None})"

In [5]:
def normalize_grad_backward(module, grad_inp, grad_out):
    """Debug backward hook that doubles the weight gradient of the MatMul-based NetLinear.

    For that module, the recorded ``grad_inp`` equals ``MatMul.backward``'s output
    ``(grad x, grad weight, grad bias)``, where ``grad x`` is ``None`` when the input
    does not require grad. Only entry 1 is doubled. A ``None`` entry 1 is passed
    through instead of raising.
    """
    print(f"Grad input {len(grad_inp)}: ", grad_inp, )
    print()
    print(f"Grad output {len(grad_out)}: ", grad_out, )
    grad_weight = grad_inp[1]
    return (grad_inp[0], None if grad_weight is None else grad_weight * 2, grad_inp[2])

def criterion(inp, target):
    """Toy objective: the sum over all elements of ``2 * inp**2 - target``.

    Its gradient is ``dL/d(inp) = 4 * inp``, so the upstream gradient seen by
    the hooks can be verified by hand.
    """
    per_element = 2 * inp ** 2 - target
    return per_element.sum()
# Same setup as the first experiment, now with the MatMul-based NetLinear.
# The seed is reset before the model and again before the data, so both are reproducible.
N = 4
din = 3
dout = 2
torch.manual_seed(0)
net_lin = NetLinear(din, dout)
net_lin.zero_grad()
# NOTE(review): register_backward_hook is deprecated; register_full_backward_hook is the
# supported API. Here the forward is a single MatMul node, so grad_inp maps to (x, weight, bias).
net_lin.register_backward_hook(normalize_grad_backward)
torch.manual_seed(0)
x = torch.randn(N, din)
# Targets: a random linear map of x plus a constant offset of 1.
w_hat = torch.rand(din, dout)
y =  x @ w_hat + torch.tensor([1.])

yp = net_lin(x)



In [6]:
# Backward through MatMul: MatMul.backward prints Grad_w/Grad_b, then the hook prints its tuples.
loss = criterion(yp, y)
loss.backward()
print(loss)

Grad_w is :  tensor([[ 37.0416, -11.5962, -58.6599],
        [ 21.6704, -12.1832, -36.1230]])
Grad_b is :  tensor([31.7467, 20.7100])

Grad input 3:  (None, tensor([[ 37.0416, -11.5962, -58.6599],
        [ 21.6704, -12.1832, -36.1230]]), tensor([31.7467, 20.7100]))

Grad output 1:  (tensor([[18.2347, 10.7301],
        [10.7301,  8.7415],
        [ 4.9149,  0.8257],
        [-2.1329,  0.4129]]),)
tensor(76.9643, grad_fn=<SumBackward0>)


In [8]:
# In the recorded output the weight grad is 2x Grad_w while the bias grad is unchanged,
# matching a hook return of (grad_inp[0], grad_inp[1] * 2, grad_inp[2]).
net_lin.weight.grad, net_lin.bias.grad

(tensor([[  74.0831,  -23.1925, -117.3197],
         [  43.3408,  -24.3665,  -72.2460]]),
 tensor([31.7467, 20.7100]))