In [1]:
%load_ext autoreload 
%autoreload 2
from functools import partial
import torch
import torch.nn.functional as F
from torch import nn, autograd, func

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [32]:
# Reproducible toy regression data: N samples with din input features and
# dout targets generated by a random linear map plus a constant offset of 1.
torch.manual_seed(0)
N, din, dout = 4, 3, 2
x = torch.randn(N, din)
w_hat = torch.rand(din, dout)  # "true" weights used to build the targets
y = x.matmul(w_hat) + torch.tensor([1.])

In [5]:
class NetLinear(nn.Module):
    """Toy linear layer ``out = x @ weight.T + bias`` with weights set to ones.

    Args:
        in_feat: number of input features.
        out_feat: number of output features.
        bias: if True, add a learnable bias initialised to zeros; if False,
            ``bias`` is registered as a ``None`` buffer and no bias is added.

    ``in_feat``/``out_feat`` are also stored as long-tensor buffers.
    ``scale`` is a learnable parameter of ones that ``forward`` does not use.
    """

    def __init__(self, in_feat, out_feat, bias=True):
        super().__init__()
        self.register_buffer("in_feat", torch.tensor(in_feat).long())
        self.register_buffer("out_feat", torch.tensor(out_feat).long())
        # All-ones weights (built directly instead of allocating uninitialised
        # memory and overwriting it) keep results easy to check by hand.
        self.weight = nn.Parameter(torch.ones(int(out_feat), int(in_feat)))
        self.scale = nn.Parameter(torch.ones(int(out_feat)))
        if bias:
            self.bias = nn.Parameter(torch.zeros(int(out_feat)))
        else:
            self.register_buffer("bias", None)

    def forward(self, x):
        """Return ``x @ weight.T`` (+ ``bias`` when present), printing a trace line."""
        print("Netlinear forward...")
        out = x @ self.weight.transpose(0, 1)
        # With bias=False, ``self.bias`` is None; adding None would raise TypeError.
        if self.bias is not None:
            out = out + self.bias
        return out
    

In [13]:
def forward_hook(module, args, output, scale=1.0):
    """Forward hook that rescales a module's output.

    Meant to be bound with ``functools.partial(forward_hook, scale=...)`` and
    registered via ``register_forward_hook``. A forward hook that returns a
    non-None value replaces the module's output, so everything downstream
    (including the gradients flowing back) is scaled by ``scale``.

    Args:
        module: the hooked module (unused).
        args: positional inputs of the module's forward call (unused).
        output: the module's forward output.
        scale: multiplicative factor applied to ``output``.

    Returns:
        ``scale * output``.
    """
    # ``:g`` prints whole numbers compactly ("2x") without truncating
    # fractional factors the way ``int(scale)`` did (0.5 used to show "0x").
    print(f"Forward scale: {scale:g}x...")
    return scale * output

def criterion(inp, target):
    """Toy loss: the plain sum of ``inp``; ``target`` is accepted but ignored.

    Because the loss is a sum, d(loss)/d(inp) is all ones, which makes any
    hook-induced scaling of the output directly visible in the gradients.
    """
    return torch.sum(inp)

# Fresh model plus a forward hook that multiplies the module output by `scale`
# (1.0 here, so the output is unchanged), then one forward/backward pass.
torch.manual_seed(0)
net_lin = NetLinear(din, dout)
net_lin.zero_grad()
# Keep the returned RemovableHandle so the hook can be detached later.
handle = net_lin.register_forward_hook(partial(forward_hook, scale=1.))
yp = net_lin(x)
loss = criterion(yp, y)
loss.backward()
print(loss)

Netlinear forward...
Forward scale: 1x...
tensor(-6.2835, grad_fn=<SumBackward0>)


In [14]:
# The RemovableHandle returned by register_forward_hook; .remove() detaches the hook.
handle

<torch.utils.hooks.RemovableHandle at 0x7f1ec8464730>

In [16]:
# Detach the scale hook; subsequent forward calls on net_lin no longer invoke it.
handle.remove()

In [17]:
# Same forward/backward after handle.remove(): "Forward scale" is no longer printed.
# zero_grad() is not called here, so backward() accumulates these gradients on
# top of the ones already stored in net_lin's .grad tensors.
yp = net_lin(x)
loss = criterion(yp, y)
loss.backward()
print(loss)

Netlinear forward...
tensor(-6.2835, grad_fn=<SumBackward0>)


In [18]:
# Rebuild the model from the same seed (so gradients start fresh) and
# re-register the scale=1 hook before a single forward/backward pass.
torch.manual_seed(0)
net_lin = NetLinear(din, dout)
net_lin.zero_grad()
handle = net_lin.register_forward_hook(partial(forward_hook, scale=1.))
yp = net_lin(x)
loss = criterion(yp, y)
loss.backward()
print(loss)

Netlinear forward...
Forward scale: 1x...
tensor(-6.2835, grad_fn=<SumBackward0>)


In [21]:
# For loss = sum(x @ W.T + b): every row of W.grad is the column sum of x,
# and every bias entry's gradient equals the number of samples N.
net_lin.weight.grad, net_lin.bias.grad

(tensor([[ 2.1094, -1.1366, -4.1146],
         [ 2.1094, -1.1366, -4.1146]]),
 tensor([4., 4.]))

In [25]:
# Same setup with scale=2: the hook doubles the module output, so the loss
# (and the gradients flowing back through the hook) double as well.
torch.manual_seed(0)
net_lin = NetLinear(din, dout)
net_lin.zero_grad()
handle = net_lin.register_forward_hook(partial(forward_hook, scale=2.))
yp = net_lin(x)
loss = criterion(yp, y)
loss.backward()
print(loss)

Netlinear forward...
Forward scale: 2x...
tensor(-12.5669, grad_fn=<SumBackward0>)


In [26]:
# Exactly twice the gradients obtained with the scale=1 hook.
net_lin.weight.grad, net_lin.bias.grad

(tensor([[ 4.2189, -2.2731, -8.2292],
         [ 4.2189, -2.2731, -8.2292]]),
 tensor([8., 8.]))

In [27]:
class NetLinear(nn.Module):
    """Toy linear layer ``out = x @ weight.T + bias`` that also prints its input.

    Args:
        in_feat: number of input features.
        out_feat: number of output features.
        bias: if True, add a learnable bias initialised to zeros; if False,
            ``bias`` is registered as a ``None`` buffer and no bias is added.

    Weights are initialised to ones. ``in_feat``/``out_feat`` are also stored
    as long-tensor buffers. ``scale`` is a learnable parameter of ones that
    ``forward`` does not use.
    """

    def __init__(self, in_feat, out_feat, bias=True):
        super().__init__()
        self.register_buffer("in_feat", torch.tensor(in_feat).long())
        self.register_buffer("out_feat", torch.tensor(out_feat).long())
        # All-ones weights (built directly instead of allocating uninitialised
        # memory and overwriting it) keep results easy to check by hand.
        self.weight = nn.Parameter(torch.ones(int(out_feat), int(in_feat)))
        self.scale = nn.Parameter(torch.ones(int(out_feat)))
        if bias:
            self.bias = nn.Parameter(torch.zeros(int(out_feat)))
        else:
            self.register_buffer("bias", None)

    def forward(self, x):
        """Print ``x`` and return ``x @ weight.T`` (+ ``bias`` when present)."""
        print("Netlinear forward...", x)
        out = x @ self.weight.transpose(0, 1)
        # With bias=False, ``self.bias`` is None; adding None would raise TypeError.
        if self.bias is not None:
            out = out + self.bias
        return out
    

In [34]:
def forward_hook_prep(module, args):
    """Forward pre-hook that squares the module's single positional input.

    Returns a replacement args tuple rather than mutating ``x`` in place. The
    caller's tensor is therefore left untouched, and the hook also works for
    inputs that are leaf tensors requiring grad (in-place ops on those raise).

    Args:
        module: the hooked module (unused).
        args: positional args of the forward call; exactly one tensor expected.

    Returns:
        A 1-tuple ``(x * x,)`` that replaces ``args`` for ``module.forward``.
    """
    x, = args
    print("before x ", x)
    x_sq = x * x
    print("After x ", x_sq)
    return (x_sq,)


# Fresh model with a forward *pre*-hook: it runs before NetLinear.forward and
# squares the input, so forward receives x**2 (compare the printed tensors).
torch.manual_seed(0)
net_lin = NetLinear(din, dout)
net_lin.zero_grad()
handle = net_lin.register_forward_pre_hook(forward_hook_prep)
# Pass a clone so the shared `x` used by other cells cannot be altered by the hook.
yp = net_lin(torch.clone(x))
loss = criterion(yp, y)
loss.backward()
print(loss)

before x  tensor([[ 1.5410, -0.2934, -2.1788],
        [ 0.5684, -1.0845, -1.3986],
        [ 0.4033,  0.8380, -0.7193],
        [-0.4033, -0.5966,  0.1820]])
After x  tensor([[2.3747, 0.0861, 4.7471],
        [0.3231, 1.1762, 1.9561],
        [0.1627, 0.7023, 0.5173],
        [0.1627, 0.3560, 0.0331]])
Netlinear forward... tensor([[2.3747, 0.0861, 4.7471],
        [0.3231, 1.1762, 1.9561],
        [0.1627, 0.7023, 0.5173],
        [0.1627, 0.3560, 0.0331]])
tensor(25.1947, grad_fn=<SumBackward0>)


In [35]:
class NetLinear(nn.Module):
    """Toy linear layer ``out = x @ weight.T + bias`` that also prints its input.

    Args:
        in_feat: number of input features.
        out_feat: number of output features.
        bias: if True, add a learnable bias initialised to zeros; if False,
            ``bias`` is registered as a ``None`` buffer and no bias is added.

    Weights are initialised to ones. ``in_feat``/``out_feat`` are also stored
    as long-tensor buffers. ``scale`` is a learnable parameter of ones that
    ``forward`` does not use.
    """

    def __init__(self, in_feat, out_feat, bias=True):
        super().__init__()
        self.register_buffer("in_feat", torch.tensor(in_feat).long())
        self.register_buffer("out_feat", torch.tensor(out_feat).long())
        # All-ones weights (built directly instead of allocating uninitialised
        # memory and overwriting it) keep results easy to check by hand.
        self.weight = nn.Parameter(torch.ones(int(out_feat), int(in_feat)))
        self.scale = nn.Parameter(torch.ones(int(out_feat)))
        if bias:
            self.bias = nn.Parameter(torch.zeros(int(out_feat)))
        else:
            self.register_buffer("bias", None)

    def forward(self, x):
        """Print ``x`` and return ``x @ weight.T`` (+ ``bias`` when present)."""
        print("Netlinear forward...", x)
        out = x @ self.weight.transpose(0, 1)
        # With bias=False, ``self.bias`` is None; adding None would raise TypeError.
        if self.bias is not None:
            out = out + self.bias
        return out

class ModelLinear(nn.Module):
    """Two stacked layers applied in sequence: ``net2(net1(x))``.

    On construction both layers are re-initialised: weights to ones and,
    where a layer has a bias, that bias to zeros. ``forward`` prints progress
    markers so the firing order of module hooks is visible in cell output.

    Args:
        net1: first layer; must expose ``weight`` and ``bias`` (may be None).
        net2: second layer; same requirements as ``net1``.
    """

    def __init__(self, net1: "NetLinear", net2: "NetLinear"):
        super().__init__()
        self.net1 = net1
        self.net2 = net2
        for net in (self.net1, self.net2):
            nn.init.ones_(net.weight)
            # Check each layer on its own: a bias=False layer holds bias=None,
            # and the two layers need not agree on whether they have a bias.
            if net.bias is not None:
                nn.init.zeros_(net.bias)

    def forward(self, x):
        """Run ``net1`` then ``net2`` on ``x``, printing a marker around each step."""
        print("Start forwarding model")
        print("--- Forwarding NET1 ---")
        print()

        out1 = self.net1(x)
        print("--- Forwarding NET2 ---")
        print()
        out = self.net2(out1)
        print("Done forwarding model")
        return out



In [37]:
def forward_hook_net1(module, args, output):
    """Forward hook for net1: print a banner and the shape of its first positional input."""
    first_input = args[0]
    print("************ Forward hook net 1******************", first_input.size())

def forward_hook_net2(module, args, output):
    """Forward hook for net2: print a banner and the shape of its first positional input."""
    first_input = args[0]
    print("************ Forward hook net 2******************", first_input.size())

def forward_hook_model(module, args, output):
    """Forward hook for the parent model: print a banner and its first input's shape."""
    first_input = args[0]
    print("************ Forward hook model******************", first_input.size())


# Two NetLinear layers wrapped in ModelLinear, with a forward hook on each child
# and one on the parent, to show the order in which forward hooks fire: each
# child's hook right after that child's forward returns, the parent's hook last.
# args[0] is the input a module was called with, so the parent hook reports the
# original (4, 3) input rather than the final output.
torch.manual_seed(0)
net1 = NetLinear(din, 5, bias=True)
net2 = NetLinear(5, dout , bias=True)
net_lin = ModelLinear(net1, net2)

# The returned handles are discarded, so these hooks cannot be removed later.
net1.register_forward_hook(forward_hook_net1)
net2.register_forward_hook(forward_hook_net2)
net_lin.register_forward_hook(forward_hook_model)

net_lin.zero_grad()

# Only the forward pass matters here; the loss is computed but not backpropagated.
yp = net_lin(torch.clone(x))
loss = criterion(yp, y)

Start forwarding model
--- Forwarding NET1 ---

Netlinear forward... tensor([[ 1.5410, -0.2934, -2.1788],
        [ 0.5684, -1.0845, -1.3986],
        [ 0.4033,  0.8380, -0.7193],
        [-0.4033, -0.5966,  0.1820]])
************ Forward hook net 1****************** torch.Size([4, 3])
--- Forwarding NET2 ---

Netlinear forward... tensor([[-0.9312, -0.9312, -0.9312, -0.9312, -0.9312],
        [-1.9147, -1.9147, -1.9147, -1.9147, -1.9147],
        [ 0.5221,  0.5221,  0.5221,  0.5221,  0.5221],
        [-0.8179, -0.8179, -0.8179, -0.8179, -0.8179]], grad_fn=<AddBackward0>)
************ Forward hook net 2****************** torch.Size([4, 5])
Done forwarding model
************ Forward hook model****************** torch.Size([4, 3])


In [40]:
def forward_hook_net1(module, args, output):
    """Forward hook for net1: print a banner and the shape of its first positional input."""
    first_input = args[0]
    print("************ Forward hook net 1******************", first_input.size())

def forward_hook_net2(module, args, output):
    """Forward hook for net2: print a banner and the shape of its first positional input."""
    first_input = args[0]
    print("************ Forward hook net 2******************", first_input.size())

def forward_hook_model(module, args, output):
    """First forward hook on the parent model: print a banner and its first input's shape."""
    first_input = args[0]
    print("************ Forward hook model 1 ******************", first_input.size())
def forward_hook_model2(module, args, output):
    """Second forward hook on the parent model: print a banner and its first input's shape."""
    first_input = args[0]
    print("************ Forward hook model 2 ******************", first_input.size())

# Same nested setup, but with two forward hooks on the parent model: hooks on
# one module run in registration order, so "model 1" prints before "model 2".
torch.manual_seed(0)
net1 = NetLinear(din, 5, bias=True)
net2 = NetLinear(5, dout , bias=True)
net_lin = ModelLinear(net1, net2)

net1.register_forward_hook(forward_hook_net1)
net2.register_forward_hook(forward_hook_net2)
net_lin.register_forward_hook(forward_hook_model,)
net_lin.register_forward_hook(forward_hook_model2)

net_lin.zero_grad()

# Only the forward pass matters here; the loss is computed but not backpropagated.
yp = net_lin(torch.clone(x))
loss = criterion(yp, y)

Start forwarding model
--- Forwarding NET1 ---

Netlinear forward... tensor([[ 1.5410, -0.2934, -2.1788],
        [ 0.5684, -1.0845, -1.3986],
        [ 0.4033,  0.8380, -0.7193],
        [-0.4033, -0.5966,  0.1820]])
************ Forward hook net 1****************** torch.Size([4, 3])
--- Forwarding NET2 ---

Netlinear forward... tensor([[-0.9312, -0.9312, -0.9312, -0.9312, -0.9312],
        [-1.9147, -1.9147, -1.9147, -1.9147, -1.9147],
        [ 0.5221,  0.5221,  0.5221,  0.5221,  0.5221],
        [-0.8179, -0.8179, -0.8179, -0.8179, -0.8179]], grad_fn=<AddBackward0>)
************ Forward hook net 2****************** torch.Size([4, 5])
Done forwarding model
************ Forward hook model 1 ****************** torch.Size([4, 3])
************ Forward hook model 2 ****************** torch.Size([4, 3])


In [41]:
def forward_hook_net1(module, args, output):
    """Forward hook for net1: print a banner and the shape of its first positional input."""
    first_input = args[0]
    print("************ Forward hook net 1******************", first_input.size())

def forward_hook_net2(module, args, output):
    """Forward hook for net2: print a banner and the shape of its first positional input."""
    first_input = args[0]
    print("************ Forward hook net 2******************", first_input.size())

def forward_hook_model(module, args, output):
    """Plain forward hook on the parent model: print a banner and its first input's shape."""
    first_input = args[0]
    print("************ Forward hook model 1 ******************", first_input.size())
def forward_hook_model2(module, args, kwargs, output):
    """Kwargs-aware forward hook (for ``with_kwargs=True``): print a banner and first input shape.

    ``kwargs`` holds the keyword arguments of the forward call and is not used.
    """
    first_input = args[0]
    print("************ Forward hook model 2 ******************", first_input.size())

# Same nested setup, with the second parent hook registered via
# with_kwargs=True (the hook also receives the forward call's kwargs as its
# third argument) and prepend=True (it runs before hooks registered earlier),
# so "model 2" now prints before "model 1".
torch.manual_seed(0)
net1 = NetLinear(din, 5, bias=True)
net2 = NetLinear(5, dout , bias=True)
net_lin = ModelLinear(net1, net2)

net1.register_forward_hook(forward_hook_net1)
net2.register_forward_hook(forward_hook_net2)
net_lin.register_forward_hook(forward_hook_model)
net_lin.register_forward_hook(forward_hook_model2, with_kwargs=True, prepend=True)

net_lin.zero_grad()

# Only the forward pass matters here; the loss is computed but not backpropagated.
yp = net_lin(torch.clone(x))
loss = criterion(yp, y)

Start forwarding model
--- Forwarding NET1 ---

Netlinear forward... tensor([[ 1.5410, -0.2934, -2.1788],
        [ 0.5684, -1.0845, -1.3986],
        [ 0.4033,  0.8380, -0.7193],
        [-0.4033, -0.5966,  0.1820]])
************ Forward hook net 1****************** torch.Size([4, 3])
--- Forwarding NET2 ---

Netlinear forward... tensor([[-0.9312, -0.9312, -0.9312, -0.9312, -0.9312],
        [-1.9147, -1.9147, -1.9147, -1.9147, -1.9147],
        [ 0.5221,  0.5221,  0.5221,  0.5221,  0.5221],
        [-0.8179, -0.8179, -0.8179, -0.8179, -0.8179]], grad_fn=<AddBackward0>)
************ Forward hook net 2****************** torch.Size([4, 5])
Done forwarding model
************ Forward hook model 2 ****************** torch.Size([4, 3])
************ Forward hook model 1 ****************** torch.Size([4, 3])
